... and add tests; this should have been done the first time.
GitOrigin-RevId: 5d73fc4c7b
tags/v0.5.0
@@ -174,11 +174,12 @@ void ElemwiseForward::deduce_shape(const TensorShapeArray& src, | |||||
if (cur_idx >= 0 && dst_idx >= 0) { | if (cur_idx >= 0 && dst_idx >= 0) { | ||||
size_t v0 = dst.shape[dst_idx], v1 = cur.shape[cur_idx]; | size_t v0 = dst.shape[dst_idx], v1 = cur.shape[cur_idx]; | ||||
if (v0 != v1) { | if (v0 != v1) { | ||||
if (v0 != 1 && v1 != 1) | |||||
if (v0 > 1 && v1 > 1) | |||||
err(); | err(); | ||||
} | } | ||||
int final_idx = std::max(cur_idx, dst_idx); | int final_idx = std::max(cur_idx, dst_idx); | ||||
dst.shape[final_idx] = std::max(v0, v1); | |||||
dst.shape[final_idx] = | |||||
(v0 != 0 && v1 != 0) ? std::max(v0, v1) : 0; | |||||
} else { | } else { | ||||
if (dst_idx < 0) { | if (dst_idx < 0) { | ||||
dst.shape[cur_idx] = cur.shape[cur_idx]; | dst.shape[cur_idx] = cur.shape[cur_idx]; | ||||
@@ -132,6 +132,7 @@ Elemwise::Elemwise( | |||||
Super{inputs.at(0)->owner_graph(), config, mode_trait.name, inputs} | Super{inputs.at(0)->owner_graph(), config, mode_trait.name, inputs} | ||||
{ | { | ||||
init_megdnn_opr(*this, param); | init_megdnn_opr(*this, param); | ||||
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); | |||||
if (mode_trait.commutable) { | if (mode_trait.commutable) { | ||||
mgb_assert(inputs.size() == 2); | mgb_assert(inputs.size() == 2); | ||||
add_input({inputs[0], inputs[1]}, AddInputSortType::CUR_ADDED); | add_input({inputs[0], inputs[1]}, AddInputSortType::CUR_ADDED); | ||||
@@ -371,8 +372,14 @@ void Elemwise::scn_do_execute() { | |||||
mgb_assert(megdnn_inp.capacity() >= inp.size(), | mgb_assert(megdnn_inp.capacity() >= inp.size(), | ||||
"heap allocation in elemwise exec"); | "heap allocation in elemwise exec"); | ||||
megdnn_inp.resize(inp.size()); | megdnn_inp.resize(inp.size()); | ||||
for (size_t i = 0; i < inp.size(); ++ i) | |||||
for (size_t i = 0; i < inp.size(); ++ i) { | |||||
if (inp[i]->dev_tensor().empty()) { | |||||
mgb_assert(output(0)->dev_tensor().empty()); | |||||
return; | |||||
} | |||||
megdnn_inp[i] = (inp[i]->dev_tensor().as_megdnn()); | megdnn_inp[i] = (inp[i]->dev_tensor().as_megdnn()); | ||||
} | |||||
mgb_assert(!output(0)->dev_tensor().empty()); | |||||
megdnn_opr()->param() = param(); | megdnn_opr()->param() = param(); | ||||
call_megdnn_opr_exec( | call_megdnn_opr_exec( | ||||
@@ -747,6 +754,15 @@ void Elemwise::record_execute_deps(ExecDependencyArray& deps) { | |||||
record_megdnn_opr(deps); | record_megdnn_opr(deps); | ||||
} | } | ||||
/*!
 * \brief Build the node property for this operator.
 *
 * Starts from the property produced by Super::do_make_node_prop() and then
 * re-registers every input var with the VALUE_ALLOW_EMPTY dependency type.
 * NOTE(review): presumably this is what allows the operator to run with
 * empty-shaped inputs -- confirm against the NodeProp::DepType docs.
 *
 * \return the property object obtained from Super, after the dep types of
 *      all inputs have been updated
 */
Elemwise::NodeProp* Elemwise::do_make_node_prop() const {
    const auto allow_empty_dep = NodeProp::DepType::VALUE_ALLOW_EMPTY;
    auto prop = Super::do_make_node_prop();
    for (auto& in_var : input())
        prop->add_dep_type_existing_var(in_var, allow_empty_dep);
    return prop;
}
/* =========================== TypeCvt =========================== */ | /* =========================== TypeCvt =========================== */ | ||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(TypeCvt); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(TypeCvt); | ||||
@@ -163,6 +163,7 @@ MGB_DEFINE_OPR_CLASS(Elemwise, intl::ElemwiseBase) // { | |||||
void record_execute_deps(ExecDependencyArray& deps) override; | void record_execute_deps(ExecDependencyArray& deps) override; | ||||
void add_input_layout_constraint() override; | void add_input_layout_constraint() override; | ||||
NodeProp* do_make_node_prop() const override; | |||||
}; | }; | ||||
namespace intl { | namespace intl { | ||||
@@ -649,6 +649,17 @@ namespace { | |||||
> TernaryTraitTypes; | > TernaryTraitTypes; | ||||
TYPED_TEST_CASE(TestOprBasicArithTernaryElemwise, TernaryTraitTypes); | TYPED_TEST_CASE(TestOprBasicArithTernaryElemwise, TernaryTraitTypes); | ||||
/*!
 * \brief gtest predicate comparing two tensor shapes with eq_shape().
 *
 * The returned AssertionResult always carries both shapes rendered through
 * to_string(), joined by " == " on success and " != " on failure, so a
 * failing assertion prints the two shapes that were compared.
 */
::testing::AssertionResult assert_shape_equal(const TensorShape& v0,
                                              const TensorShape& v1) {
    const bool same = v0.eq_shape(v1);
    ::testing::AssertionResult verdict = same
            ? ::testing::AssertionSuccess()
            : ::testing::AssertionFailure();
    verdict << v0.to_string() << (same ? " == " : " != ") << v1.to_string();
    return verdict;
}
//! fatal assertion that two TensorShape values compare equal via eq_shape()
#define ASSERT_SHAPE_EQ(v0, v1) ASSERT_TRUE(assert_shape_equal(v0, v1))
} // anonymous namespace | } // anonymous namespace | ||||
template<typename Trait, typename dtype> | template<typename Trait, typename dtype> | ||||
@@ -950,4 +961,58 @@ TEST(TestLayoutUtil, CollectiveCollapse) { | |||||
check(cc_res5, std_res5); | check(cc_res5, std_res5); | ||||
} | } | ||||
//! A unary elemwise opr (RELU) fed with a tensor that has a zero-sized axis
//! must execute without throwing and yield an empty output of the same shape.
TEST(TestOprBasicArithElemwise, EmptyInputOutputUnary) {
    using Mode = opr::Elemwise::Param::Mode;

    HostTensorGenerator<> gen;
    auto graph = ComputingGraph::make();

    // axis 1 has extent 0, so the tensor holds no elements
    auto host_inp = gen({3, 0, 1, 3});
    auto inp = opr::Host2DeviceCopy::make(*graph, host_inp);
    auto out = opr::Elemwise::make({inp}, opr::Elemwise::Param(Mode::RELU));

    HostTensorND host_out;
    auto func = graph->compile({make_callback_copy(out, host_out)});
    ASSERT_NO_THROW(func->execute().wait());

    const TensorShape expected_shape({3, 0, 1, 3});
    ASSERT_TRUE(host_out.empty());
    ASSERT_TRUE(host_out.shape().is_empty());
    ASSERT_SHAPE_EQ(host_out.shape(), expected_shape);
}
//! Binary elemwise (x + y) where x is always empty: checks that mismatching
//! non-unit extents are still rejected, and that broadcasting against a
//! zero-sized axis produces an empty output with the expected shape.
TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) {
    HostTensorGenerator<> gen;
    auto graph = ComputingGraph::make();

    auto host_lhs = gen({0, 8, 1, 7});
    auto host_rhs = gen({0, 8, 1, 7});
    auto lhs = opr::Host2DeviceCopy::make(*graph, host_lhs);
    auto rhs = opr::Host2DeviceCopy::make(*graph, host_rhs);
    auto sum = lhs + rhs;

    HostTensorND host_sum;
    auto func = graph->compile({make_callback_copy(sum, host_sum)});

    // case 1: axis 1 is 8 vs 9 -- not broadcastable, execution has to throw
    host_rhs->resize({0, 9, 1, 7});
    ASSERT_ANY_THROW(func->execute().wait());

    // case 2: zero-sized axes on both operands (axis 0 of lhs, axis 2 of rhs)
    const TensorShape both_zero_shape({0, 8, 0, 7});
    host_rhs->resize({1, 8, 0, 7});
    ASSERT_NO_THROW(func->execute().wait());
    ASSERT_TRUE(host_sum.empty());
    ASSERT_TRUE(host_sum.shape().is_empty());
    ASSERT_SHAPE_EQ(host_sum.shape(), both_zero_shape);

    // case 3: rhs has extent 2 on axis 0 while lhs has extent 0 there;
    // this is expected to run and give the shape of lhs
    const TensorShape lhs_shape({0, 8, 1, 7});
    host_rhs->resize({2, 8, 1, 7});
    ASSERT_NO_THROW(func->execute().wait());
    ASSERT_TRUE(host_sum.empty());
    ASSERT_TRUE(host_sum.shape().is_empty());
    ASSERT_SHAPE_EQ(host_sum.shape(), lhs_shape);

    // case 4: adding a scalar to the empty tensor keeps the shape of lhs
    sum = lhs + lhs.make_scalar(1.f);
    func = graph->compile({make_callback_copy(sum, host_sum)});
    ASSERT_NO_THROW(func->execute().wait());
    ASSERT_TRUE(host_sum.empty());
    ASSERT_TRUE(host_sum.shape().is_empty());
    ASSERT_SHAPE_EQ(host_sum.shape(), lhs_shape);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |