Browse Source

fix(mge/parampacksplit): fix param pack split mem forward

GitOrigin-RevId: 8c001b73ff
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
a66d4b8bb8
3 changed files with 18 additions and 4 deletions
  1. +1
    -0
      imperative/python/megengine/jit/tracing.py
  2. +14
    -3
      src/opr/impl/tensor_manip.cpp
  3. +3
    -1
      src/opr/include/megbrain/opr/tensor_manip.h

+ 1
- 0
imperative/python/megengine/jit/tracing.py View File

@@ -308,6 +308,7 @@ class trace:


def _apply_graph_options(self, graph): def _apply_graph_options(self, graph):


graph.options.seq_opt.enable_seq_comp_node_opt = False
# sublinear # sublinear
if self._sublinear_memory_config is not None: if self._sublinear_memory_config is not None:
graph.options.enable_sublinear_memory_opt = True graph.options.enable_sublinear_memory_opt = True


+ 14
- 3
src/opr/impl/tensor_manip.cpp View File

@@ -1496,6 +1496,14 @@ void ParamPackSplit::init_output_dtype() {
// already initialized in constructor // already initialized in constructor
} }


/*!
 * \brief Chain input(0) and every output var together, in both directions,
 *      via add_rt_force_dynamic_mem_alloc_imply_chain().
 *
 * For each output index the input var registers the output and the output
 * registers the input, so the relation is symmetric.
 *
 * NOTE(review): presumably this is needed because the outputs are forwarded
 * views of the input's memory (see mem_plan_fwd_in2out_readonly), so a forced
 * dynamic allocation on one side has to imply the same on the other --
 * confirm against the VarNode documentation.
 */
void ParamPackSplit::init_rt_force_dynamic_mem_alloc_imply_chain() {
    const size_t nr_out = output().size();
    for (size_t idx = 0; idx != nr_out; ++idx) {
        auto packed = input(0);
        auto part = output(idx);
        // input -> output first, then output -> input (same order as before)
        packed->add_rt_force_dynamic_mem_alloc_imply_chain(part);
        part->add_rt_force_dynamic_mem_alloc_imply_chain(packed);
    }
}

void ParamPackSplit::mem_plan_fwd_in2out_readonly() { void ParamPackSplit::mem_plan_fwd_in2out_readonly() {
mgb_assert(m_offsets.size() == output().size() * 2); mgb_assert(m_offsets.size() == output().size() * 2);
for (size_t i = 0; i < output().size(); i++) { for (size_t i = 0; i < output().size(); i++) {
@@ -1516,16 +1524,19 @@ void ParamPackSplit::init_output_static_infer_desc() {
using namespace std::placeholders; using namespace std::placeholders;
auto&& mgr = owner_graph()->static_infer_manager(); auto&& mgr = owner_graph()->static_infer_manager();


DepVal shp_deps{{input(0), DepType::SHAPE}};

for (size_t i = 0; i < output().size(); i++) { for (size_t i = 0; i < output().size(); i++) {
auto ov = output(i); auto ov = output(i);
mgr.register_shape_infer( mgr.register_shape_infer(
ov, {SourceType::DEP, shp_deps,
ov, {SourceType::CONSTANT, {},
std::bind(&ParamPackSplit::infer_shape, this, i, _1, _2)}); std::bind(&ParamPackSplit::infer_shape, this, i, _1, _2)});
} }
} }


/*!
 * \brief Runtime sanity check: the element count of input(0) must equal the
 *      last entry of m_offsets.
 *
 * No data is touched here; the function only validates the input shape.
 * NOTE(review): presumably the outputs need no computation because they are
 * set up by mem_plan_fwd_in2out_readonly -- confirm.
 */
void ParamPackSplit::scn_do_execute() {
    // back() on an empty vector is undefined behaviour, so fail loudly first.
    mgb_assert(!m_offsets.empty(), "offsets should not be empty");
    // Keep the element count in its native unsigned width: narrowing it to
    // int could wrap for very large inputs and spuriously match the offset.
    // NOTE(review): assumes total_nr_elems() returns size_t -- confirm.
    const size_t inp_size = input(0)->shape().total_nr_elems();
    // A (corrupt) negative offset converts to a huge unsigned value and thus
    // fails the comparison, as it should.
    mgb_assert(inp_size == static_cast<size_t>(m_offsets.back()),
               "input shape should match offsets");
}

#ifdef MGB_ENABLE_GRAD #ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ParamPackSplit) { MGB_IMPL_OPR_GRAD(ParamPackSplit) {
mgb_assert(out_grad.size() == opr.output().size()); mgb_assert(out_grad.size() == opr.output().size());


+ 3
- 1
src/opr/include/megbrain/opr/tensor_manip.h View File

@@ -591,7 +591,7 @@ MGB_DEFINE_OPR_CLASS(ParamPackSplit, cg::SingleCNOperatorNodeBase) // {
TensorShapeArray m_shapes; TensorShapeArray m_shapes;
std::vector<dt_int32> m_offsets; std::vector<dt_int32> m_offsets;


void scn_do_execute() override{};
void scn_do_execute() override;
void init_output_static_infer_desc() override; void init_output_static_infer_desc() override;
bool infer_shape(size_t index, TensorShape &dest, bool infer_shape(size_t index, TensorShape &dest,
const cg::static_infer::InpVal &inp); const cg::static_infer::InpVal &inp);
@@ -615,6 +615,8 @@ public:
const TensorShapeArray& get_output_shapes() const { const TensorShapeArray& get_output_shapes() const {
return m_shapes; return m_shapes;
} }

void init_rt_force_dynamic_mem_alloc_imply_chain() override;
}; };


/*! /*!


Loading…
Cancel
Save