|
|
@@ -1496,6 +1496,14 @@ void ParamPackSplit::init_output_dtype() { |
|
|
|
// already initialized in constructor |
|
|
|
} |
|
|
|
|
|
|
|
void ParamPackSplit::init_rt_force_dynamic_mem_alloc_imply_chain() { |
|
|
|
for (size_t i = 0; i < output().size(); ++i) { |
|
|
|
auto s = input(0), t = output(i); |
|
|
|
s->add_rt_force_dynamic_mem_alloc_imply_chain(t); |
|
|
|
t->add_rt_force_dynamic_mem_alloc_imply_chain(s); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void ParamPackSplit::mem_plan_fwd_in2out_readonly() { |
|
|
|
mgb_assert(m_offsets.size() == output().size() * 2); |
|
|
|
for (size_t i = 0; i < output().size(); i++) { |
|
|
@@ -1516,16 +1524,19 @@ void ParamPackSplit::init_output_static_infer_desc() { |
|
|
|
using namespace std::placeholders; |
|
|
|
auto&& mgr = owner_graph()->static_infer_manager(); |
|
|
|
|
|
|
|
DepVal shp_deps{{input(0), DepType::SHAPE}}; |
|
|
|
|
|
|
|
for (size_t i = 0; i < output().size(); i++) { |
|
|
|
auto ov = output(i); |
|
|
|
mgr.register_shape_infer( |
|
|
|
ov, {SourceType::DEP, shp_deps, |
|
|
|
ov, {SourceType::CONSTANT, {}, |
|
|
|
std::bind(&ParamPackSplit::infer_shape, this, i, _1, _2)}); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void ParamPackSplit::scn_do_execute() { |
|
|
|
int inp_size = input(0)->shape().total_nr_elems(); |
|
|
|
mgb_assert(inp_size == m_offsets.back(), "input shape should match offsets"); |
|
|
|
} |
|
|
|
|
|
|
|
#ifdef MGB_ENABLE_GRAD |
|
|
|
MGB_IMPL_OPR_GRAD(ParamPackSplit) { |
|
|
|
mgb_assert(out_grad.size() == opr.output().size()); |
|
|
|