@@ -699,7 +699,9 @@ CondExecMark::CondExecMark(VarNode* ppv, const VarNodeArrayView& inputs,
     for (size_t i = 0; i < inputs.size(); ++i) {
         add_input({inputs[i]});
-        add_output(ssprintf("fwd%zu", i))->dtype(inputs[i]->dtype());
+        add_output(ssprintf("fwd%zu", i))
+                ->dtype(inputs[i]->dtype())
+                .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     }
     add_input({ppv});
     add_equivalence_component<PODHash<Param>>(&m_param);
@@ -789,6 +791,10 @@ void CondExecMark::add_input_layout_constraint() {
 CondExecMark::NodeProp* CondExecMark::do_make_node_prop() const {
     auto ret = Super::do_make_node_prop();
     ret->dep_map().at(input().back()) = NodeProp::DepType::DEV_COMP_ORDER;
+    for (size_t i = 0; i < input().size() - 1; ++i) {
+        ret->add_dep_type_existing_var(input(i),
+                NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    }
     return ret;
 }
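The hunk above is the consumer-side half of empty-tensor support: every data input of CondExecMark is registered with VALUE_ALLOW_EMPTY so a zero-sized value is accepted instead of being rejected at execution time. As a minimal sketch of the same pattern for some other operator (hypothetical `MyOpr`, not part of this patch), assuming only the NodeProp API already used above:

    MyOpr::NodeProp* MyOpr::do_make_node_prop() const {
        auto ret = Super::do_make_node_prop();
        // declare that every input value may be an empty (zero-sized) tensor
        for (auto i : input()) {
            ret->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY);
        }
        return ret;
    }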
@@ -859,7 +865,8 @@ CondExecMerge::CondExecMerge(const VarNodeArrayView& inputs,
         // 2. dynamic allocator would wait for all inputs to become ready (see
         //    VarNodeMemManager::DynamicAllocOprInfo::host_wait_input_ready),
         //    which would cause infinite waiting for unselected inputs.
-        ovar->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC);
+        ovar->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
+                .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     }
     MGB_MARK_USED_VAR(mask2str);
@@ -1056,7 +1063,9 @@ void CondExecMerge::init_output_static_infer_desc() {
         desc.infer_func = [this](DeviceTensorND& dest, const InpVal& inp) {
             auto nr_branch = m_branch_masks.size();
             bool found = false, first = true;
-            for (size_t i = 0; i < nr_branch; ++i) {
+            auto&& shape = inp.val.at(nr_branch).shape();
+            for (size_t i = 0; i < nr_branch && !shape.is_empty(); ++i) {
                 if (!inp.val[i].value().ptr<int>()[0])
                     continue;
                 auto&& cur = inp.val.at(nr_branch + i).value();
@@ -1083,7 +1092,6 @@ void CondExecMerge::init_output_static_infer_desc() {
                 }
             }
             if (!found) {
-                auto&& shape = inp.val.at(nr_branch).shape();
                 if (dest.storage().raw_storage().use_count() > 1) {
                     // likely to be assigned from some input in previous
                     // runs; we create a new tensor to avoid modifying input
@@ -1115,6 +1123,7 @@ void CondExecMerge::scn_do_execute() {
     bool first = true;
     auto&& forwarded = m_mem_forwarded;
+    std::vector<bool> is_shape_empty(nr_out, false);
     for (size_t br = 0; br < m_branch_masks.size(); ++br) {
         if (!m_branch_masks[br]->enabled()) {
             continue;
@@ -1125,6 +1134,10 @@
             for (size_t oidx = 0; oidx < nr_out; ++oidx) {
                 bool succ = output(oidx)->reset_dev_tensor_from_other_var(
                         inp(br, oidx));
+                if (inp(br, oidx)->shape().is_empty()) {
+                    is_shape_empty[oidx] = true;
+                    continue;
+                }
                 if (!is_exact_one()) {
                     if (forwarded.empty()) {
                         forwarded.resize(nr_out);
@@ -1144,6 +1157,11 @@
                 auto ovar = output(oidx);
                 auto&& src = inp(br, oidx)->dev_tensor().as_megdnn();
                 auto&& dest = ovar->dev_tensor().as_megdnn();
+                mgb_assert(src.layout.eq_shape(dest.layout),
+                        "shape mismatch: %s vs %s in CondExecMerge",
+                        src.layout.to_string().c_str(),
+                        dest.layout.to_string().c_str());
+                if (is_shape_empty[oidx]) continue;
                 if (forwarded[oidx]) {
                     ovar->shape_alloc(ovar->shape());
                     auto&& own_dest = ovar->dev_tensor().as_megdnn();
@@ -1200,6 +1218,10 @@ CondExecMerge::NodeProp* CondExecMerge::do_make_node_prop() const {
         // directly
         ret->dep_map().at(input().back()) = NodeProp::DepType::DEV_COMP_ORDER;
     }
+    for (size_t i = 0; i < m_param.nr_output * m_branch_masks.size(); ++i) {
+        ret->add_dep_type_existing_var(input(i),
+                NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    }
     return ret;
 }
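Together with the NodeProp changes, the constructor hunks above form the producer-side half of the contract: a var that may legally end up with zero elements is flagged ALLOW_EMPTY_SHAPE when it is created; without that flag an empty shape is rejected by the memory planner. A minimal sketch of that side for a hypothetical operator (constructor fragment, using the same add_output/add_flag chaining as the CondExecMark hunk):

    // inside a hypothetical operator's constructor: the forwarded output may
    // be empty (e.g. when nothing is selected), so allow a zero-sized shape
    add_input({inp});
    add_output("out")
            ->dtype(inp->dtype())
            .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);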
@@ -14,6 +14,7 @@
 #include "megbrain/opr/basic_arith_wrapper.h"
 #include "megbrain/opr/cond.h"
 #include "megbrain/opr/io.h"
+#include "megbrain/opr/misc.h"
 #include "megbrain/opr/tensor_manip.h"
 #include "megbrain/opr/utility.h"
 #include "megbrain/utils/timer.h"
@@ -1285,6 +1286,80 @@ TEST(TestCondExec, MultiShape) {
     check(host_d2);
 }
 
+TEST(TestCondExec, EmptyShape) {
+    HostTensorGenerator<> gen;
+    auto host_pred = gen({1});
+    host_pred->ptr<float>()[0] = 0;
+    static auto empty_in_empty_out = [](SymbolVar x) {
+        return x;
+    };
+    static auto empty_in_scalar_out = [](SymbolVar x) {
+        return opr::Concat::make({x, x.make_scalar(1.f)}, 0);
+    };
+    static auto scalar_in_empty_out = [](SymbolVar x) {
+        // default CondTake param keeps elements equal to 0, so a nonzero
+        // scalar input yields an empty output
+        return opr::CondTake::make(x, x, {})[0];
+    };
+
+    { // EXACT_ONE
+        auto graph = ComputingGraph::make();
+        auto pred = opr::Host2DeviceCopy::make(*graph, host_pred),
+             empty = opr::ImmutableTensor::make(*graph, *gen({0})),
+             scalar = pred.make_scalar(1.f),
+             y0 = empty_in_empty_out(make_one_cond(pred + 1, empty)),
+             y1 = empty_in_scalar_out(make_one_cond(pred, empty)),
+             y2 = scalar_in_empty_out(make_one_cond(pred - 1, scalar)),
+             z = merge_one_out({y0, y1, y2}, MergeMode::EXACT_ONE);
+        HostTensorND host_z;
+        auto func = graph->compile({make_callback_copy(z, host_z)});
+        func->execute();
+        ASSERT_TRUE(host_z.layout().is_empty());
+        host_pred->ptr<float>()[0] = 1;
+        func->execute();
+        ASSERT_EQ(1.f, host_z.ptr<float>()[0]);
+        host_pred->ptr<float>()[0] = 2;
+        func->execute();
+        ASSERT_TRUE(host_z.layout().is_empty());
+    }
+
+    { // SUM
+        auto graph = ComputingGraph::make();
+        host_pred->ptr<float>()[0] = 1;
+        auto pred = opr::Host2DeviceCopy::make(*graph, host_pred),
+             empty = opr::ImmutableTensor::make(*graph, *gen({0})),
+             scalar = pred.make_scalar(1.f),
+             y0 = empty_in_empty_out(make_one_cond(pred, empty)),
+             y1 = scalar_in_empty_out(make_one_cond(pred, scalar)),
+             z = merge_one_out({y0, y1}, MergeMode::SUM);
+        HostTensorND host_z;
+        auto func = graph->compile({make_callback_copy(z, host_z)});
+        func->execute();
+        ASSERT_TRUE(host_z.layout().is_empty());
+    }
+
+    { // TAKE GRAD
+        auto graph = ComputingGraph::make();
+        host_pred->ptr<float>()[0] = 0;
+        auto pred = opr::Host2DeviceCopy::make(*graph, host_pred),
+             x = pred.make_scalar(1.2f),
+             y0 = opr::CondTake::make(make_one_cond(pred + 1, x), pred, {})[0],
+             y1 = make_one_cond(pred, x.make_scalar(3.4f)),
+             z = merge_one_out({y0, y1}, MergeMode::EXACT_ONE),
+             g = cg::grad(z, x);
+        HostTensorND host_z, host_g;
+        auto func = graph->compile({
+                make_callback_copy(z, host_z), make_callback_copy(g, host_g)});
+        func->execute();
+        ASSERT_EQ(1.2f, host_z.ptr<float>()[0]);
+        ASSERT_EQ(1.f, host_g.ptr<float>()[0]);
+        host_pred->ptr<float>()[0] = 1;
+        func->execute();
+        ASSERT_EQ(3.4f, host_z.ptr<float>()[0]);
+        ASSERT_EQ(0.f, host_g.ptr<float>()[0]);
+    }
+}
 
 #endif // MGB_ENABLE_COND_EXEC
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}