|
@@ -699,7 +699,9 @@ CondExecMark::CondExecMark(VarNode* ppv, const VarNodeArrayView& inputs, |
|
|
|
|
|
|
|
|
for (size_t i = 0; i < inputs.size(); ++i) { |
|
|
for (size_t i = 0; i < inputs.size(); ++i) { |
|
|
add_input({inputs[i]}); |
|
|
add_input({inputs[i]}); |
|
|
add_output(ssprintf("fwd%zu", i))->dtype(inputs[i]->dtype()); |
|
|
|
|
|
|
|
|
add_output(ssprintf("fwd%zu", i)) |
|
|
|
|
|
->dtype(inputs[i]->dtype()) |
|
|
|
|
|
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); |
|
|
} |
|
|
} |
|
|
add_input({ppv}); |
|
|
add_input({ppv}); |
|
|
add_equivalence_component<PODHash<Param>>(&m_param); |
|
|
add_equivalence_component<PODHash<Param>>(&m_param); |
|
@@ -789,6 +791,10 @@ void CondExecMark::add_input_layout_constraint() { |
|
|
CondExecMark::NodeProp* CondExecMark::do_make_node_prop() const { |
|
|
CondExecMark::NodeProp* CondExecMark::do_make_node_prop() const { |
|
|
auto ret = Super::do_make_node_prop(); |
|
|
auto ret = Super::do_make_node_prop(); |
|
|
ret->dep_map().at(input().back()) = NodeProp::DepType::DEV_COMP_ORDER; |
|
|
ret->dep_map().at(input().back()) = NodeProp::DepType::DEV_COMP_ORDER; |
|
|
|
|
|
for (size_t i = 0; i < input().size() - 1; ++ i) { |
|
|
|
|
|
ret->add_dep_type_existing_var(input(i), |
|
|
|
|
|
NodeProp::DepType::VALUE_ALLOW_EMPTY); |
|
|
|
|
|
} |
|
|
return ret; |
|
|
return ret; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@@ -859,7 +865,8 @@ CondExecMerge::CondExecMerge(const VarNodeArrayView& inputs, |
|
|
// 2. dynamic allocator would wait for all inputs to become ready (see |
|
|
// 2. dynamic allocator would wait for all inputs to become ready (see |
|
|
// VarNodeMemManager::DynamicAllocOprInfo::host_wait_input_ready), |
|
|
// VarNodeMemManager::DynamicAllocOprInfo::host_wait_input_ready), |
|
|
// which would cause infinite waiting for unselected inputs. |
|
|
// which would cause infinite waiting for unselected inputs. |
|
|
ovar->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC); |
|
|
|
|
|
|
|
|
ovar->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC) |
|
|
|
|
|
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
MGB_MARK_USED_VAR(mask2str); |
|
|
MGB_MARK_USED_VAR(mask2str); |
|
@@ -1056,7 +1063,9 @@ void CondExecMerge::init_output_static_infer_desc() { |
|
|
desc.infer_func = [this](DeviceTensorND& dest, const InpVal& inp) { |
|
|
desc.infer_func = [this](DeviceTensorND& dest, const InpVal& inp) { |
|
|
auto nr_branch = m_branch_masks.size(); |
|
|
auto nr_branch = m_branch_masks.size(); |
|
|
bool found = false, first = true; |
|
|
bool found = false, first = true; |
|
|
for (size_t i = 0; i < nr_branch; ++i) { |
|
|
|
|
|
|
|
|
auto&& shape = inp.val.at(nr_branch).shape(); |
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < nr_branch && !shape.is_empty(); ++i) { |
|
|
if (!inp.val[i].value().ptr<int>()[0]) |
|
|
if (!inp.val[i].value().ptr<int>()[0]) |
|
|
continue; |
|
|
continue; |
|
|
auto&& cur = inp.val.at(nr_branch + i).value(); |
|
|
auto&& cur = inp.val.at(nr_branch + i).value(); |
|
@@ -1083,7 +1092,6 @@ void CondExecMerge::init_output_static_infer_desc() { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
if (!found) { |
|
|
if (!found) { |
|
|
auto&& shape = inp.val.at(nr_branch).shape(); |
|
|
|
|
|
if (dest.storage().raw_storage().use_count() > 1) { |
|
|
if (dest.storage().raw_storage().use_count() > 1) { |
|
|
// likely to be assigned from some input in previous |
|
|
// likely to be assigned from some input in previous |
|
|
// runs; we create a new tensor to avoid modifying input |
|
|
// runs; we create a new tensor to avoid modifying input |
|
@@ -1115,6 +1123,7 @@ void CondExecMerge::scn_do_execute() { |
|
|
|
|
|
|
|
|
bool first = true; |
|
|
bool first = true; |
|
|
auto&& forwarded = m_mem_forwarded; |
|
|
auto&& forwarded = m_mem_forwarded; |
|
|
|
|
|
std::vector<bool> is_shape_empty(nr_out, false); |
|
|
for (size_t br = 0; br < m_branch_masks.size(); ++br) { |
|
|
for (size_t br = 0; br < m_branch_masks.size(); ++br) { |
|
|
if (!m_branch_masks[br]->enabled()) { |
|
|
if (!m_branch_masks[br]->enabled()) { |
|
|
continue; |
|
|
continue; |
|
@@ -1125,6 +1134,10 @@ void CondExecMerge::scn_do_execute() { |
|
|
for (size_t oidx = 0; oidx < nr_out; ++oidx) { |
|
|
for (size_t oidx = 0; oidx < nr_out; ++oidx) { |
|
|
bool succ = output(oidx)->reset_dev_tensor_from_other_var( |
|
|
bool succ = output(oidx)->reset_dev_tensor_from_other_var( |
|
|
inp(br, oidx)); |
|
|
inp(br, oidx)); |
|
|
|
|
|
if (inp(br, oidx)->shape().is_empty()) { |
|
|
|
|
|
is_shape_empty[oidx] = true; |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
if (!is_exact_one()) { |
|
|
if (!is_exact_one()) { |
|
|
if (forwarded.empty()) { |
|
|
if (forwarded.empty()) { |
|
|
forwarded.resize(nr_out); |
|
|
forwarded.resize(nr_out); |
|
@@ -1144,6 +1157,11 @@ void CondExecMerge::scn_do_execute() { |
|
|
auto ovar = output(oidx); |
|
|
auto ovar = output(oidx); |
|
|
auto&& src = inp(br, oidx)->dev_tensor().as_megdnn(); |
|
|
auto&& src = inp(br, oidx)->dev_tensor().as_megdnn(); |
|
|
auto&& dest = ovar->dev_tensor().as_megdnn(); |
|
|
auto&& dest = ovar->dev_tensor().as_megdnn(); |
|
|
|
|
|
mgb_assert(src.layout.eq_shape(dest.layout), |
|
|
|
|
|
"shape mismatch: %s vs %s in CondExecMerge", |
|
|
|
|
|
src.layout.to_string().c_str(), |
|
|
|
|
|
dest.layout.to_string().c_str()); |
|
|
|
|
|
if (is_shape_empty[oidx]) continue; |
|
|
if (forwarded[oidx]) { |
|
|
if (forwarded[oidx]) { |
|
|
ovar->shape_alloc(ovar->shape()); |
|
|
ovar->shape_alloc(ovar->shape()); |
|
|
auto&& own_dest = ovar->dev_tensor().as_megdnn(); |
|
|
auto&& own_dest = ovar->dev_tensor().as_megdnn(); |
|
@@ -1200,6 +1218,10 @@ CondExecMerge::NodeProp* CondExecMerge::do_make_node_prop() const { |
|
|
// directly |
|
|
// directly |
|
|
ret->dep_map().at(input().back()) = NodeProp::DepType::DEV_COMP_ORDER; |
|
|
ret->dep_map().at(input().back()) = NodeProp::DepType::DEV_COMP_ORDER; |
|
|
} |
|
|
} |
|
|
|
|
|
for (size_t i = 0; i < m_param.nr_output * m_branch_masks.size(); ++ i) { |
|
|
|
|
|
ret->add_dep_type_existing_var(input(i), |
|
|
|
|
|
NodeProp::DepType::VALUE_ALLOW_EMPTY); |
|
|
|
|
|
} |
|
|
return ret; |
|
|
return ret; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|