GitOrigin-RevId: 190a330a8c
tags/v1.0.0-rc1
@@ -210,6 +210,7 @@ class trace: | |||||
info.external = True | info.external = True | ||||
info.device = x.device | info.device = x.device | ||||
info.dtype = x.dtype | info.dtype = x.dtype | ||||
info.shape = x.shape | |||||
if self._capture_as_const: | if self._capture_as_const: | ||||
info.bound_data = x | info.bound_data = x | ||||
@@ -338,7 +339,7 @@ class trace: | |||||
for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()): | for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()): | ||||
info = self._tinfo[h] | info = self._tinfo[h] | ||||
opnode = info.data_setter = G.InputNode( | opnode = info.data_setter = G.InputNode( | ||||
device=info.device, dtype=info.dtype, graph=graph | |||||
device=info.device, dtype=info.dtype, shape=info.shape, graph=graph | |||||
) | ) | ||||
need_reset_nodes.append(opnode) | need_reset_nodes.append(opnode) | ||||
info.varnode = opnode.outputs[0] | info.varnode = opnode.outputs[0] | ||||
@@ -355,7 +356,11 @@ class trace: | |||||
info.varnode = graph.make_const(info.bound_data._dev_tensor()) | info.varnode = graph.make_const(info.bound_data._dev_tensor()) | ||||
else: | else: | ||||
opnode = info.data_setter = G.InputNode( | opnode = info.data_setter = G.InputNode( | ||||
*links, device=info.device, dtype=info.dtype, graph=graph | |||||
*links, | |||||
device=info.device, | |||||
dtype=info.dtype, | |||||
shape=info.shape, | |||||
graph=graph, | |||||
) | ) | ||||
need_reset_nodes.append(opnode) | need_reset_nodes.append(opnode) | ||||
info.varnode, *links = opnode.outputs | info.varnode, *links = opnode.outputs | ||||
@@ -0,0 +1 @@ | |||||
../../../src/core/impl |
@@ -10,6 +10,7 @@ | |||||
*/ | */ | ||||
#include "megbrain/imperative/opr_utility.h" | #include "megbrain/imperative/opr_utility.h" | ||||
#include "./mgb_core_impl/graph/cg_impl.h" | |||||
// FIXME; setup_config_cn is copied from src/opr/impl/utility.cpp | // FIXME; setup_config_cn is copied from src/opr/impl/utility.cpp | ||||
namespace { | namespace { | ||||
@@ -64,14 +65,18 @@ SymbolVarArray InputCallback::make(cg::ComputingGraph& graph, | |||||
void InputCallback::init_output_static_infer_desc() { | void InputCallback::init_output_static_infer_desc() { | ||||
if (m_output_shape.ndim) { | if (m_output_shape.ndim) { | ||||
// Write this shape to the static infer manager. The effect is | |||||
// that infer_shape_fallible() will return a non-empty shape | |||||
// while get_infer_type() remains NO_DESC. Most places check the | |||||
// infer type before relying on an inferred shape, so things | |||||
// won't break. The memory optimizer, however, deliberately omits | |||||
// the infer type check so that it can use this shape as a hint. | |||||
using namespace cg::static_infer; | using namespace cg::static_infer; | ||||
auto &&mgr = owner_graph()->static_infer_manager(); | |||||
auto infer_shape = [this](TensorShape &dest, const InpVal &) { | |||||
dest = m_output_shape; | |||||
return true; | |||||
}; | |||||
mgr.register_shape_infer(output(0), | |||||
{SourceType::CONSTANT, {}, infer_shape}); | |||||
auto* var = output(0); | |||||
var->shape(m_output_shape); | |||||
auto&& mgr = cg::ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl(); | |||||
auto* handle = mgr.get_tag_handler_for_shape(var); | |||||
handle->sync_from_var(); | |||||
} | } | ||||
} | } | ||||
@@ -86,9 +91,6 @@ cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const { | |||||
void InputCallback::scn_do_execute() { | void InputCallback::scn_do_execute() { | ||||
auto dev_tensor = m_callback(); | auto dev_tensor = m_callback(); | ||||
if (m_output_shape.ndim) { | |||||
mgb_assert(dev_tensor.shape().eq_shape(m_output_shape)); | |||||
} | |||||
output(0)->reset_dev_tensor_from_tensor(dev_tensor); | output(0)->reset_dev_tensor_from_tensor(dev_tensor); | ||||
} | } | ||||
@@ -99,7 +99,9 @@ MemoryOptimizerHelper::split_into_cn2oprseq(const OprNodeArray& oprseq, | |||||
auto&& infer_mgr = m_owner_graph->static_infer_manager(); | auto&& infer_mgr = m_owner_graph->static_infer_manager(); | ||||
for (auto j : i->output()) { | for (auto j : i->output()) { | ||||
if (!j->contain_flag(BAD_VAR_FLAG) && is_static_var_shape(j)) { | |||||
if (!j->contain_flag(BAD_VAR_FLAG)) { | |||||
// Deliberately omit the infer type check (is_static_var_shape): | |||||
// any inferred shape is used as-is to compute the memory size. | |||||
if (auto shape = infer_mgr.infer_shape_fallible(j)) { | if (auto shape = infer_mgr.infer_shape_fallible(j)) { | ||||
have_static_shape_out = true; | have_static_shape_out = true; | ||||
m_var_memsize[j] = j->dtype().size(shape->total_nr_elems()); | m_var_memsize[j] = j->dtype().size(shape->total_nr_elems()); | ||||