From a66d4b8bb87276f176f397617b96a9fe47081a9c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 9 Sep 2020 22:08:33 +0800 Subject: [PATCH] fix(mge/parampacksplit): fix param pack split mem forward GitOrigin-RevId: 8c001b73ffbd086f0cfff7cac2a4c1037bfcecfb --- imperative/python/megengine/jit/tracing.py | 1 + src/opr/impl/tensor_manip.cpp | 17 ++++++++++++++--- src/opr/include/megbrain/opr/tensor_manip.h | 4 +++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 1b40822f..2c116f29 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -308,6 +308,7 @@ class trace: def _apply_graph_options(self, graph): + graph.options.seq_opt.enable_seq_comp_node_opt = False # sublinear if self._sublinear_memory_config is not None: graph.options.enable_sublinear_memory_opt = True diff --git a/src/opr/impl/tensor_manip.cpp b/src/opr/impl/tensor_manip.cpp index e15d909b..862e44ce 100644 --- a/src/opr/impl/tensor_manip.cpp +++ b/src/opr/impl/tensor_manip.cpp @@ -1496,6 +1496,14 @@ void ParamPackSplit::init_output_dtype() { // already initialized in constructor } +void ParamPackSplit::init_rt_force_dynamic_mem_alloc_imply_chain() { + for (size_t i = 0; i < output().size(); ++i) { + auto s = input(0), t = output(i); + s->add_rt_force_dynamic_mem_alloc_imply_chain(t); + t->add_rt_force_dynamic_mem_alloc_imply_chain(s); + } +} + void ParamPackSplit::mem_plan_fwd_in2out_readonly() { mgb_assert(m_offsets.size() == output().size() * 2); for (size_t i = 0; i < output().size(); i++) { @@ -1516,16 +1524,19 @@ void ParamPackSplit::init_output_static_infer_desc() { using namespace std::placeholders; auto&& mgr = owner_graph()->static_infer_manager(); - DepVal shp_deps{{input(0), DepType::SHAPE}}; - for (size_t i = 0; i < output().size(); i++) { auto ov = output(i); mgr.register_shape_infer( - ov, {SourceType::DEP, shp_deps, + ov, {SourceType::CONSTANT, {}, std::bind(&ParamPackSplit::infer_shape, this, i, _1, _2)}); } } +void ParamPackSplit::scn_do_execute() { + int inp_size = input(0)->shape().total_nr_elems(); + mgb_assert(inp_size == m_offsets.back(), "input shape should match offsets"); +} + #ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ParamPackSplit) { mgb_assert(out_grad.size() == opr.output().size()); diff --git a/src/opr/include/megbrain/opr/tensor_manip.h b/src/opr/include/megbrain/opr/tensor_manip.h index f02c9a38..345c7f6d 100644 --- a/src/opr/include/megbrain/opr/tensor_manip.h +++ b/src/opr/include/megbrain/opr/tensor_manip.h @@ -591,7 +591,7 @@ MGB_DEFINE_OPR_CLASS(ParamPackSplit, cg::SingleCNOperatorNodeBase) // { TensorShapeArray m_shapes; std::vector m_offsets; - void scn_do_execute() override{}; + void scn_do_execute() override; void init_output_static_infer_desc() override; bool infer_shape(size_t index, TensorShape &dest, const cg::static_infer::InpVal &inp); @@ -615,6 +615,8 @@ public: const TensorShapeArray& get_output_shapes() const { return m_shapes; } + + void init_rt_force_dynamic_mem_alloc_imply_chain() override; }; /*!