|
|
@@ -33,6 +33,37 @@ void mixin::init_rt_force_dynamic_mem_alloc_imply_chain_for_dyn_pass_i2o( |
|
|
|
valid_out->add_rt_force_dynamic_mem_alloc_imply_chain(opr.input(0)); |
|
|
|
} |
|
|
|
|
|
|
|
/* ===================== FwdIn2OutWritableHelper ===================== */ |
|
|
|
void FwdIn2OutWritableHelper::mixin_mem_plan_fwd_in2out_writable( |
|
|
|
OperatorNodeBase& opr) { |
|
|
|
auto&& inp = opr.input(); |
|
|
|
auto isize = inp.size(); |
|
|
|
std::vector<bool> have_conflict(isize, false); |
|
|
|
for (size_t i = 0; i < isize; ++i) { |
|
|
|
for (size_t j = i + 1; j < isize; ++j) { |
|
|
|
auto type = cg::get_mem_plan_intersection_type(inp[i], inp[j]); |
|
|
|
using Type = cg::MemPlanIntersectionType; |
|
|
|
bool overlap = type == Type::OVERLAP; |
|
|
|
bool self_fwd = type == Type::IDENTICAL && |
|
|
|
(!inp[i]->layout().is_contiguous() || |
|
|
|
!inp[j]->layout().is_contiguous()); |
|
|
|
if (overlap || self_fwd) { |
|
|
|
have_conflict[i] = true; |
|
|
|
have_conflict[j] = true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
auto o = opr.output(0); |
|
|
|
for (size_t idx = 0; idx < isize; ++ idx) { |
|
|
|
auto i = inp[idx]; |
|
|
|
// equal shape means no broadcast |
|
|
|
if (!have_conflict[idx] && o->shape().eq_shape(i->shape()) && |
|
|
|
o->dtype().enumv() == i->dtype().enumv() && |
|
|
|
i->layout().is_contiguous()) |
|
|
|
o->set_fwd_in2out_writable(i); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/* ===================== ReadonlyFwdHelper ===================== */ |
|
|
|
|
|
|
|
void ReadonlyFwdHelper::mixin_rofwd_init_mem_plan(OperatorNodeBase &opr) { |
|
|
|