Browse Source

fix(mge): fix sublinear memory in jit.trace

GitOrigin-RevId: 190a330a8c
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
1fe8a21299
4 changed files with 23 additions and 13 deletions
  1. +7
    -2
      imperative/python/megengine/jit/tracing.py
  2. +1
    -0
      imperative/src/impl/mgb_core_impl
  3. +12
    -10
      imperative/src/impl/opr_utility.cpp
  4. +3
    -1
      src/core/impl/graph/memory_optimizer.cpp

+ 7
- 2
imperative/python/megengine/jit/tracing.py View File

@@ -210,6 +210,7 @@ class trace:
info.external = True info.external = True
info.device = x.device info.device = x.device
info.dtype = x.dtype info.dtype = x.dtype
info.shape = x.shape
if self._capture_as_const: if self._capture_as_const:
info.bound_data = x info.bound_data = x


@@ -338,7 +339,7 @@ class trace:
for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()): for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
info = self._tinfo[h] info = self._tinfo[h]
opnode = info.data_setter = G.InputNode( opnode = info.data_setter = G.InputNode(
device=info.device, dtype=info.dtype, graph=graph
device=info.device, dtype=info.dtype, shape=info.shape, graph=graph
) )
need_reset_nodes.append(opnode) need_reset_nodes.append(opnode)
info.varnode = opnode.outputs[0] info.varnode = opnode.outputs[0]
@@ -355,7 +356,11 @@ class trace:
info.varnode = graph.make_const(info.bound_data._dev_tensor()) info.varnode = graph.make_const(info.bound_data._dev_tensor())
else: else:
opnode = info.data_setter = G.InputNode( opnode = info.data_setter = G.InputNode(
*links, device=info.device, dtype=info.dtype, graph=graph
*links,
device=info.device,
dtype=info.dtype,
shape=info.shape,
graph=graph,
) )
need_reset_nodes.append(opnode) need_reset_nodes.append(opnode)
info.varnode, *links = opnode.outputs info.varnode, *links = opnode.outputs


+ 1
- 0
imperative/src/impl/mgb_core_impl View File

@@ -0,0 +1 @@
../../../src/core/impl

+ 12
- 10
imperative/src/impl/opr_utility.cpp View File

@@ -10,6 +10,7 @@
*/ */


#include "megbrain/imperative/opr_utility.h" #include "megbrain/imperative/opr_utility.h"
#include "./mgb_core_impl/graph/cg_impl.h"


// FIXME: setup_config_cn is copied from src/opr/impl/utility.cpp // FIXME: setup_config_cn is copied from src/opr/impl/utility.cpp
namespace { namespace {
@@ -64,14 +65,18 @@ SymbolVarArray InputCallback::make(cg::ComputingGraph& graph,


void InputCallback::init_output_static_infer_desc() { void InputCallback::init_output_static_infer_desc() {
if (m_output_shape.ndim) { if (m_output_shape.ndim) {
// Write this shape to the static infer manager. The effect is
// that infer_shape_fallible() will return a non-empty shape
// while get_infer_type() remains NO_DESC. Most places check the
// infer type before relying on an inferred shape, so things
// won't break. The memory optimizer, however, deliberately omits
// the infer type check, so it is able to use this shape as a hint.
using namespace cg::static_infer; using namespace cg::static_infer;
auto &&mgr = owner_graph()->static_infer_manager();
auto infer_shape = [this](TensorShape &dest, const InpVal &) {
dest = m_output_shape;
return true;
};
mgr.register_shape_infer(output(0),
{SourceType::CONSTANT, {}, infer_shape});
auto* var = output(0);
var->shape(m_output_shape);
auto&& mgr = cg::ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl();
auto* handle = mgr.get_tag_handler_for_shape(var);
handle->sync_from_var();
} }
} }


@@ -86,9 +91,6 @@ cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const {


void InputCallback::scn_do_execute() { void InputCallback::scn_do_execute() {
auto dev_tensor = m_callback(); auto dev_tensor = m_callback();
if (m_output_shape.ndim) {
mgb_assert(dev_tensor.shape().eq_shape(m_output_shape));
}
output(0)->reset_dev_tensor_from_tensor(dev_tensor); output(0)->reset_dev_tensor_from_tensor(dev_tensor);
} }




+ 3
- 1
src/core/impl/graph/memory_optimizer.cpp View File

@@ -99,7 +99,9 @@ MemoryOptimizerHelper::split_into_cn2oprseq(const OprNodeArray& oprseq,


auto&& infer_mgr = m_owner_graph->static_infer_manager(); auto&& infer_mgr = m_owner_graph->static_infer_manager();
for (auto j : i->output()) { for (auto j : i->output()) {
if (!j->contain_flag(BAD_VAR_FLAG) && is_static_var_shape(j)) {
if (!j->contain_flag(BAD_VAR_FLAG)) {
// omit the infer type check;
// the inferred shape will be used as-is
if (auto shape = infer_mgr.infer_shape_fallible(j)) { if (auto shape = infer_mgr.infer_shape_fallible(j)) {
have_static_shape_out = true; have_static_shape_out = true;
m_var_memsize[j] = j->dtype().size(shape->total_nr_elems()); m_var_memsize[j] = j->dtype().size(shape->total_nr_elems());


Loading…
Cancel
Save