From b74afde89e90046223665592218b05c3a4ee8ea2 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 16 Jul 2021 18:50:35 +0800 Subject: [PATCH] feat(mgb/opr): let reduce support empty IO GitOrigin-RevId: 88b37123a8fa7f7dafbb1b0c506fb79f1e5a24c4 --- .../python/test/unit/functional/test_math.py | 27 +++++++++- imperative/src/impl/proxy_graph.cpp | 5 ++ src/core/test/graph/misc.cpp | 6 +-- src/opr/impl/basic_arith.cpp | 53 +++++++++++++++++++- src/opr/include/megbrain/opr/basic_arith.h | 1 + src/opr/test/basic_arith/reduction.cpp | 57 ++++++++++++++++++++++ 6 files changed, 143 insertions(+), 6 deletions(-) diff --git a/imperative/python/test/unit/functional/test_math.py b/imperative/python/test/unit/functional/test_math.py index d013dfbd..89ba78db 100644 --- a/imperative/python/test/unit/functional/test_math.py +++ b/imperative/python/test/unit/functional/test_math.py @@ -13,7 +13,7 @@ import pytest from utils import opr_test import megengine.functional as F -from megengine import tensor +from megengine import jit, tensor def common_test_reduce(opr, ref_opr): @@ -204,3 +204,28 @@ def test_topk(descending, sorted, inp1d, kth_only): if not sorted: values = np_sort(values) np.testing.assert_equal(values, np_sort(data)[..., :k]) + + +@pytest.mark.parametrize("is_trace", [True, False]) +def test_reduce_on_empty_tensor(is_trace): + dtypes = [np.float32, np.int32, np.bool] + inputs = [ + (np.random.random((0,)), None), + (np.random.random((3, 0, 2)), 1), + (np.random.random((10, 10, 0, 10)), 0), + ] + + def run_test(fn, ref_fn, input, dtype, axis=None, symbolic=False): + if is_trace: + fn = jit.trace(symbolic=symbolic)(fn) + for i in range(3): + out = fn(tensor(input, dtype=dtype), axis=axis).numpy() + out_ref = ref_fn(input.astype(dtype), axis=axis) + np.testing.assert_equal(out, out_ref) + + for dtype in dtypes: + for inp, axis in inputs: + run_test(F.sum, np.sum, inp, dtype, axis, True) + run_test(F.sum, np.sum, inp, dtype, axis, False) + run_test(F.prod, np.prod, inp, dtype, axis, True) + run_test(F.prod, np.prod, inp, dtype, axis, False) diff --git a/imperative/src/impl/proxy_graph.cpp b/imperative/src/impl/proxy_graph.cpp index 77f6f5ba..a54e6a46 100644 --- a/imperative/src/impl/proxy_graph.cpp +++ b/imperative/src/impl/proxy_graph.cpp @@ -84,6 +84,11 @@ public: auto&& dev_tensor = tensor.dev_tensor(); var->m_comp_node = dev_tensor.comp_node(); var->m_shape = dev_tensor.shape(); + if (dev_tensor.empty()) { + auto layout = dev_tensor.layout(); + layout.init_contiguous_stride(); + dev_tensor.reset(dev_tensor.storage(), layout); + } var->m_dev_tensor = dev_tensor; var->m_mem_plan.reset_from_owner_var().chunk() .mem_alloc_status.set_from_owner_var(); diff --git a/src/core/test/graph/misc.cpp b/src/core/test/graph/misc.cpp index cf1963a5..d7d724b1 100644 --- a/src/core/test/graph/misc.cpp +++ b/src/core/test/graph/misc.cpp @@ -1364,7 +1364,7 @@ TEST(TestGraph, EmptyShapeCheck) { using Param = opr::CondTake::Param; auto x = opr::Host2DeviceCopy::make(*graph, host_x), y = opr::CondTake::make(x, x, {Param::Mode::GT})[0], - z = opr::reduce_sum(y, y.make_scalar(1)); + z = opr::reduce_max(y, y.make_scalar(1)); HostTensorND host_z; auto func = graph->compile({make_callback_copy(z, host_z)}); func->execute(); @@ -1377,7 +1377,7 @@ TEST(TestGraph, EmptyShapeCheck) { func->execute(); } catch (const MegBrainError& exc) { std::string msg{exc.what()}; - ASSERT_TRUE(msg.find("empty output var") != + ASSERT_TRUE(msg.find("empty input is not allowed") != std::string::npos) << "bad message " << msg; throw; @@ -2413,8 +2413,6 @@ TEST(TestMemReuse, ResetEmptyDevTensor) { y = opr::Reduce::make(x, {opr::Reduce::Mode::MAX, 0}); HostTensorND host_y; auto func = g->compile({make_callback_copy(y, host_y)}); - auto &&recv = x.node()->owner_graph()->var_receiver_in_current_comp_seq(x.node()); - ASSERT_TRUE(!recv.is_empty_allowed()); if (inp_shp.is_empty()) { ASSERT_ANY_THROW(func->execute().wait()); } else { diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index ee63cafb..16bf25a9 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -1072,6 +1072,7 @@ class Reduce::KernScheduler { m_apply_side_effect; std::unique_ptr m_elemwise_trans_opr; std::unique_ptr m_typecvt_opr; + std::unique_ptr m_fill_opr; DeviceTensorND m_side_affect_wkspc; }; @@ -1338,6 +1339,47 @@ void Reduce::KernScheduler::execute( } mgb_assert(!m_kern_param.empty()); + + // empty input + if (input.shape_valid() && input.empty()) { + auto mode = m_kern_param[0].kparam.mode; + if (!m_fill_opr) { + m_fill_opr = intl::get_megdnn_handle(dest.comp_node())-> + create_operator(); + } + std::string err_msg; + switch (mode) { + case Reduce::Mode::SUM: + if (!dest.empty()) { + m_fill_opr->param() = 0; + m_fill_opr->exec(dest.as_megdnn(), {}); + } + break; + case Reduce::Mode::PRODUCT: + if (!dest.empty()) { + m_fill_opr->param() = 1; + m_fill_opr->exec(dest.as_megdnn(), {}); + } + break; + case Reduce::Mode::MEAN: + err_msg = "mean"; break; + case Reduce::Mode::MIN: + err_msg = "min"; break; + case Reduce::Mode::MAX: + err_msg = "max"; break; + case Reduce::Mode::SUM_SQR: + err_msg = "sum_sqr"; break; + default: + mgb_throw(MegBrainError, "bad reduce mode"); + } + if (!err_msg.empty()) { + mgb_throw( + MegBrainError, + "empty input is not allowed for reduce mode: %s", + err_msg.c_str()); + } + return; + } mgb_assert(input.layout().is_contiguous() && input.raw_ptr() == m_kern_param[0].input.raw_ptr && dest.raw_ptr() == m_kern_param.back().output.raw_ptr); @@ -1425,7 +1467,9 @@ Reduce::Reduce(VarNode *inp, VarNode *target_shape, const Param ¶m, mgb_throw(GraphError, "invalid param data_type: %d", int(param.data_type)); } - add_output(None)->dtype(out_dtype); + add_output(None) + ->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) + .dtype(out_dtype); cg::add_workspace_output(this); add_equivalence_component>(&m_param); @@ -1703,6 +1747,13 @@ void Reduce::perform( ksched.execute(opr.get(), *input_contig, dest); } +Reduce::NodeProp* Reduce::do_make_node_prop() const { + auto ret = Super::do_make_node_prop(); + ret->add_dep_type_existing_var(input(0), + NodeProp::DepType::VALUE_ALLOW_EMPTY); + return ret; +} + void Reduce::create_megdnn_opr() { set_megdnn_opr(intl::get_megdnn_handle(comp_node())-> create_operator()); diff --git a/src/opr/include/megbrain/opr/basic_arith.h b/src/opr/include/megbrain/opr/basic_arith.h index 73a67757..69edb1b7 100644 --- a/src/opr/include/megbrain/opr/basic_arith.h +++ b/src/opr/include/megbrain/opr/basic_arith.h @@ -335,6 +335,7 @@ MGB_DEFINE_OPR_CLASS(Reduce, intl::DynamicOutputIfInputDynamic< void add_input_layout_constraint() override final; void scn_do_execute() override final; void init_output_static_infer_desc() override final; + NodeProp* do_make_node_prop() const override; void create_megdnn_opr() override; void record_execute_deps(ExecDependencyArray& deps) override; diff --git a/src/opr/test/basic_arith/reduction.cpp b/src/opr/test/basic_arith/reduction.cpp index 58dec096..c842e81f 100644 --- a/src/opr/test/basic_arith/reduction.cpp +++ b/src/opr/test/basic_arith/reduction.cpp @@ -900,4 +900,61 @@ TEST(TestBasicArithReduction, StaticInferValueDType) { run_test(F16, F16, ParamType::FLOAT_O16xC32); } +TEST(TestBasicArithReduction, EmptyInput) { + using Param = opr::Reduce::Param; + using Mode = opr::Reduce::Mode; + + auto check_allow_empty = [](const Param& param, const TensorShape& inpshp, double target_val) { + HostTensorGenerator<> gen; + auto graph = ComputingGraph::make(); + auto host_x = gen(inpshp); + auto x = opr::Host2DeviceCopy::make(*graph, host_x), + y = opr::Reduce::make(x, param, {}); + HostTensorND host_y; + auto func = graph->compile({make_callback_copy(y, host_y)}); + func->execute().wait(); + if (!host_y.shape().is_empty()) { + size_t size = host_y.layout().total_nr_elems(); + +#define cb(DType) \ + if (host_y.layout().dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + auto ptr = host_y.ptr(); \ + ctype target = static_cast(target_val); \ + for (size_t i = 0; i < size; ++i) { \ + ASSERT_TRUE(ptr[i] == target); \ + } \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb + + } else { + ASSERT_TRUE(host_y.empty()); + } + }; + + auto check_forbid_empty = [](const Param& param, const TensorShape& inpshp) { + HostTensorGenerator<> gen; + auto graph = ComputingGraph::make(); + auto host_x = gen(inpshp); + auto x = opr::Host2DeviceCopy::make(*graph, host_x), + y = opr::Reduce::make(x, param, {}); + HostTensorND host_y; + auto func = graph->compile({make_callback_copy(y, host_y)}); + ASSERT_ANY_THROW(func->execute().wait()); + }; + + check_allow_empty({Mode::SUM, 0, {}}, {0}, 0); + check_allow_empty({Mode::SUM, -1, {}}, {2, 0, 3}, 0); + check_allow_empty({Mode::SUM, 1, {}}, {2, 0, 3}, 0); + check_allow_empty({Mode::PRODUCT, 0, {}}, {0, 1, 2}, 1); + check_allow_empty({Mode::PRODUCT, 1, {}}, {0, 0, 0}, 1); + check_allow_empty({Mode::PRODUCT, 2, {}}, {0, 0, 0}, 1); + + check_forbid_empty({Mode::MAX, 0, {}}, {0}); + check_forbid_empty({Mode::MIN, -1, {}}, {0, 1, 2}); + check_forbid_empty({Mode::MEAN, 0, {}}, {0, 0}); + check_forbid_empty({Mode::SUM_SQR, 1, {}}, {2, 1, 0}); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}