diff --git a/dnn/src/common/elemwise/opr_impl.cpp b/dnn/src/common/elemwise/opr_impl.cpp
index c8d30c26..5ba020b7 100644
--- a/dnn/src/common/elemwise/opr_impl.cpp
+++ b/dnn/src/common/elemwise/opr_impl.cpp
@@ -174,11 +174,12 @@ void ElemwiseForward::deduce_shape(const TensorShapeArray& src,
         if (cur_idx >= 0 && dst_idx >= 0) {
             size_t v0 = dst.shape[dst_idx], v1 = cur.shape[cur_idx];
             if (v0 != v1) {
-                if (v0 != 1 && v1 != 1)
+                if (v0 > 1 && v1 > 1)
                     err();
             }
             int final_idx = std::max(cur_idx, dst_idx);
-            dst.shape[final_idx] = std::max(v0, v1);
+            dst.shape[final_idx] =
+                    (v0 != 0 && v1 != 0) ? std::max(v0, v1) : 0;
         } else {
             if (dst_idx < 0) {
                 dst.shape[cur_idx] = cur.shape[cur_idx];
diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp
index 249e6504..a1c8e4bb 100644
--- a/src/opr/impl/basic_arith.cpp
+++ b/src/opr/impl/basic_arith.cpp
@@ -132,6 +132,7 @@ Elemwise::Elemwise(
         Super{inputs.at(0)->owner_graph(), config, mode_trait.name, inputs} {
     init_megdnn_opr(*this, param);
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     if (mode_trait.commutable) {
         mgb_assert(inputs.size() == 2);
         add_input({inputs[0], inputs[1]}, AddInputSortType::CUR_ADDED);
@@ -371,8 +372,14 @@ void Elemwise::scn_do_execute() {
     mgb_assert(megdnn_inp.capacity() >= inp.size(),
               "heap allocation in elemwise exec");
     megdnn_inp.resize(inp.size());
-    for (size_t i = 0; i < inp.size(); ++ i)
+    for (size_t i = 0; i < inp.size(); ++ i) {
+        if (inp[i]->dev_tensor().empty()) {
+            mgb_assert(output(0)->dev_tensor().empty());
+            return;
+        }
         megdnn_inp[i] = (inp[i]->dev_tensor().as_megdnn());
+    }
+    mgb_assert(!output(0)->dev_tensor().empty());
 
     megdnn_opr()->param() = param();
     call_megdnn_opr_exec(
@@ -747,6 +754,15 @@ void Elemwise::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
 
+Elemwise::NodeProp* Elemwise::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    for (auto& inp : input()) {
+        ret->add_dep_type_existing_var(inp,
+                NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    }
+    return ret;
+}
+
 /* =========================== TypeCvt =========================== */
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(TypeCvt);
diff --git a/src/opr/include/megbrain/opr/basic_arith.h b/src/opr/include/megbrain/opr/basic_arith.h
index 45fa5fea..1dfda12d 100644
--- a/src/opr/include/megbrain/opr/basic_arith.h
+++ b/src/opr/include/megbrain/opr/basic_arith.h
@@ -163,6 +163,7 @@ MGB_DEFINE_OPR_CLASS(Elemwise, intl::ElemwiseBase) // {
 
     void record_execute_deps(ExecDependencyArray& deps) override;
     void add_input_layout_constraint() override;
+    NodeProp* do_make_node_prop() const override;
 };
 
 namespace intl {
diff --git a/src/opr/test/basic_arith/elemwise.cpp b/src/opr/test/basic_arith/elemwise.cpp
index 3bbae075..231abc22 100644
--- a/src/opr/test/basic_arith/elemwise.cpp
+++ b/src/opr/test/basic_arith/elemwise.cpp
@@ -649,6 +649,17 @@ namespace {
     > TernaryTraitTypes;
     TYPED_TEST_CASE(TestOprBasicArithTernaryElemwise, TernaryTraitTypes);
 
+    ::testing::AssertionResult assert_shape_equal(const TensorShape& v0,
+                                                  const TensorShape& v1) {
+        if (v0.eq_shape(v1))
+            return ::testing::AssertionSuccess()
+                   << v0.to_string() << " == " << v1.to_string();
+        else
+            return ::testing::AssertionFailure()
+                   << v0.to_string() << " != " << v1.to_string();
+    }
+#define ASSERT_SHAPE_EQ(v0, v1) ASSERT_TRUE(assert_shape_equal(v0, v1))
+
 }  // anonymous namespace
 
 template
@@ -950,4 +961,58 @@ TEST(TestLayoutUtil, CollectiveCollapse) {
     check(cc_res5, std_res5);
 }
 
+TEST(TestOprBasicArithElemwise, EmptyInputOutputUnary) {
+    HostTensorGenerator<> gen;
+    auto graph = ComputingGraph::make();
+    auto host_x = gen({3, 0, 1, 3});
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+         y = opr::Elemwise::make(
+                 {x}, opr::Elemwise::Param(opr::Elemwise::Param::Mode::RELU));
+    HostTensorND host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y)});
+
+    ASSERT_NO_THROW(func->execute().wait());
+    ASSERT_TRUE(host_y.empty());
+    ASSERT_TRUE(host_y.shape().is_empty());
+    ASSERT_SHAPE_EQ(host_y.shape(), TensorShape({3, 0, 1, 3}));
+}
+
+TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) {
+    HostTensorGenerator<> gen;
+    auto graph = ComputingGraph::make();
+    auto host_x = gen({0, 8, 1, 7}), host_y = gen({0, 8, 1, 7});
+
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+         y = opr::Host2DeviceCopy::make(*graph, host_y),
+         z = x + y;
+    HostTensorND host_z;
+    auto func = graph->compile({make_callback_copy(z, host_z)});
+
+    // Invalid broadcast
+    host_y->resize({0, 9, 1, 7});
+    ASSERT_ANY_THROW(func->execute().wait());
+
+    // Broadcast to 0
+    host_y->resize({1, 8, 0, 7});
+    ASSERT_NO_THROW(func->execute().wait());
+    ASSERT_TRUE(host_z.empty());
+    ASSERT_TRUE(host_z.shape().is_empty());
+    ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 0, 7}));
+
+    // Broadcast to 0 (2)
+    host_y->resize({2, 8, 1, 7});
+    ASSERT_NO_THROW(func->execute().wait());
+    ASSERT_TRUE(host_z.empty());
+    ASSERT_TRUE(host_z.shape().is_empty());
+    ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7}));
+
+    // Scalar broadcast
+    z = x + x.make_scalar(1.f);
+    func = graph->compile({make_callback_copy(z, host_z)});
+    ASSERT_NO_THROW(func->execute().wait());
+    ASSERT_TRUE(host_z.empty());
+    ASSERT_TRUE(host_z.shape().is_empty());
+    ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7}));
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
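
Reviewer note: the deduce_shape() hunk above relaxes the broadcast rule so that a 0
extent broadcasts like a 1 extent, with 0 winning over any other value; two unequal
extents only conflict when both are greater than 1. Below is a minimal standalone
sketch of that rule (plain C++; broadcast_dim is a hypothetical name for
illustration, not MegBrain API):

    #include <algorithm>
    #include <cstddef>
    #include <stdexcept>

    // Hypothetical helper: combine one pair of dim extents under the relaxed
    // rule from the deduce_shape() change. 0 and 1 both broadcast, and a 0 on
    // either side forces the result extent to 0.
    size_t broadcast_dim(size_t v0, size_t v1) {
        // Only two unequal extents that are both > 1 are a real mismatch
        // (the new `v0 > 1 && v1 > 1` check).
        if (v0 != v1 && v0 > 1 && v1 > 1)
            throw std::runtime_error("non-broadcastable dims");
        // A zero extent wins over anything (the new ternary), so combining
        // {2, 8, 1, 7} with {0, 8, 1, 7} yields {0, 8, 1, 7}, which is what
        // the "Broadcast to 0 (2)" test asserts.
        return (v0 != 0 && v1 != 0) ? std::max(v0, v1) : 0;
    }

Because any empty input under this rule always produces an empty output,
scn_do_execute() can simply return once it sees an empty input tensor (after
asserting the output is empty too), and do_make_node_prop() marks every input
with VALUE_ALLOW_EMPTY so the executor accepts those empty values.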