From 1fe8a2129939c7c79b6c4564aa4b1e004e20bd0f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 7 Sep 2020 16:54:39 +0800 Subject: [PATCH] fix(mge): fix sublinear memory in jit.trace GitOrigin-RevId: 190a330a8ce2016d24faed0246cc0c294c755946 --- imperative/python/megengine/jit/tracing.py | 9 +++++++-- imperative/src/impl/mgb_core_impl | 1 + imperative/src/impl/opr_utility.cpp | 22 ++++++++++++---------- src/core/impl/graph/memory_optimizer.cpp | 4 +++- 4 files changed, 23 insertions(+), 13 deletions(-) create mode 120000 imperative/src/impl/mgb_core_impl diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index e93b7087..6bee7fb6 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -210,6 +210,7 @@ class trace: info.external = True info.device = x.device info.dtype = x.dtype + info.shape = x.shape if self._capture_as_const: info.bound_data = x @@ -338,7 +339,7 @@ class trace: for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()): info = self._tinfo[h] opnode = info.data_setter = G.InputNode( - device=info.device, dtype=info.dtype, graph=graph + device=info.device, dtype=info.dtype, shape=info.shape, graph=graph ) need_reset_nodes.append(opnode) info.varnode = opnode.outputs[0] @@ -355,7 +356,11 @@ class trace: info.varnode = graph.make_const(info.bound_data._dev_tensor()) else: opnode = info.data_setter = G.InputNode( - *links, device=info.device, dtype=info.dtype, graph=graph + *links, + device=info.device, + dtype=info.dtype, + shape=info.shape, + graph=graph, ) need_reset_nodes.append(opnode) info.varnode, *links = opnode.outputs diff --git a/imperative/src/impl/mgb_core_impl b/imperative/src/impl/mgb_core_impl new file mode 120000 index 00000000..71d5fae4 --- /dev/null +++ b/imperative/src/impl/mgb_core_impl @@ -0,0 +1 @@ +../../../src/core/impl \ No newline at end of file diff --git a/imperative/src/impl/opr_utility.cpp b/imperative/src/impl/opr_utility.cpp index 5a983f0d..3206e50a 100644 --- a/imperative/src/impl/opr_utility.cpp +++ b/imperative/src/impl/opr_utility.cpp @@ -10,6 +10,7 @@ */ #include "megbrain/imperative/opr_utility.h" +#include "./mgb_core_impl/graph/cg_impl.h" // FIXME; setup_config_cn is copied from src/opr/impl/utility.cpp namespace { @@ -64,14 +65,18 @@ SymbolVarArray InputCallback::make(cg::ComputingGraph& graph, void InputCallback::init_output_static_infer_desc() { if (m_output_shape.ndim) { + // Write this shape to static infer manager. The effect is + // that infer_shape_fallible() will return a non-empty shape + // while get_infer_type() remains NO_DESC. Most places check + // infer type before relying on inferred shape so things + // won't break. Memory optimizer however, deliberately omits + // infer type check so it will be able to use this shape for hint. using namespace cg::static_infer; - auto &&mgr = owner_graph()->static_infer_manager(); - auto infer_shape = [this](TensorShape &dest, const InpVal &) { - dest = m_output_shape; - return true; - }; - mgr.register_shape_infer(output(0), - {SourceType::CONSTANT, {}, infer_shape}); + auto* var = output(0); + var->shape(m_output_shape); + auto&& mgr = cg::ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl(); + auto* handle = mgr.get_tag_handler_for_shape(var); + handle->sync_from_var(); } } @@ -86,9 +91,6 @@ cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const { void InputCallback::scn_do_execute() { auto dev_tensor = m_callback(); - if (m_output_shape.ndim) { - mgb_assert(dev_tensor.shape().eq_shape(m_output_shape)); - } output(0)->reset_dev_tensor_from_tensor(dev_tensor); } diff --git a/src/core/impl/graph/memory_optimizer.cpp b/src/core/impl/graph/memory_optimizer.cpp index 00c16ad5..fc542d7a 100644 --- a/src/core/impl/graph/memory_optimizer.cpp +++ b/src/core/impl/graph/memory_optimizer.cpp @@ -99,7 +99,9 @@ MemoryOptimizerHelper::split_into_cn2oprseq(const OprNodeArray& oprseq, auto&& infer_mgr = m_owner_graph->static_infer_manager(); for (auto j : i->output()) { - if (!j->contain_flag(BAD_VAR_FLAG) && is_static_var_shape(j)) { + if (!j->contain_flag(BAD_VAR_FLAG)) { + // omit infer type check + // inferred shape will be used as-is if (auto shape = infer_mgr.infer_shape_fallible(j)) { have_static_shape_out = true; m_var_memsize[j] = j->dtype().size(shape->total_nr_elems());