diff --git a/imperative/python/test/unit/functional/test_math.py b/imperative/python/test/unit/functional/test_math.py index 89ba78db..428e65ff 100644 --- a/imperative/python/test/unit/functional/test_math.py +++ b/imperative/python/test/unit/functional/test_math.py @@ -110,16 +110,42 @@ def test_sort(): data2_shape = (12, 2) data1 = np.random.random(data1_shape).astype(np.float32) data2 = np.random.random(data2_shape).astype(np.float32) - output0 = [np.sort(data1), np.argsort(data1).astype(np.int32)] - output1 = [np.sort(data2), np.argsort(data2).astype(np.int32)] + output1 = [np.sort(data1), np.argsort(data1).astype(np.int32)] + output2 = [np.sort(data2), np.argsort(data2).astype(np.int32)] cases = [ - {"input": data1, "output": output0}, - {"input": data2, "output": output1}, + {"input": data1, "output": output1}, + {"input": data2, "output": output2}, ] opr_test(cases, F.sort) +@pytest.mark.parametrize("is_symbolic", [None, False, True]) +def test_sort_empty(is_symbolic): + data_shapes = [ + (0,), + (10, 0), + ] + + def fn(x): + return F.sort(x) + + for shape in data_shapes: + if is_symbolic is not None: + fn_ = jit.trace(symbolic=is_symbolic)(fn) + else: + fn_ = fn + data = np.random.random(shape).astype(np.float32) + for _ in range(3): + outs = fn_(tensor(data)) + ref_outs = (np.sort(data), np.argsort(data)) + assert len(ref_outs) == len(outs) + for i in range(len(outs)): + np.testing.assert_equal(outs[i].numpy(), ref_outs[i]) + if is_symbolic is None: + break + + def test_normalize(): cases = [ diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp index c73752d0..f4650515 100644 --- a/src/opr/impl/misc.cpp +++ b/src/opr/impl/misc.cpp @@ -75,7 +75,16 @@ MEGDNN_OPR_INIT1(Argmin, "argmin") /* ================= ArgsortForward ================= */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(ArgsortForward); -MEGDNN_OPR_CTOR_INIT1(ArgsortForward, "argsort") +// MEGDNN_OPR_CTOR_INIT1(ArgsortForward, "argsort") + +ArgsortForward::ArgsortForward(VarNode *i0, const Param ¶m, const OperatorNodeConfig &config): + Super(OperatorNodeBaseCtorParam{ i0->owner_graph(), config, "argsort", {i0}} ) { + init_megdnn_opr(*this, param); + add_input({i0}); + output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); // sorted value + output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); // sorted index + intl::MegDNNOprInitPostCtor::apply(*this); +} std::array ArgsortForward::make( SymbolVar in_tensor, const Param ¶m, @@ -87,6 +96,32 @@ std::array ArgsortForward::make( return {node->output(0), node->output(1)}; } +void ArgsortForward::scn_do_execute() { + if (input(0)->dev_tensor().empty()) { + mgb_assert(output(0)->dev_tensor().empty() && + output(1)->dev_tensor().empty()); + return; + } + mgb_assert(!output(0)->dev_tensor().empty() && + !output(1)->dev_tensor().empty()); + Super::scn_do_execute(); +} + +void ArgsortForward::get_output_var_shape( + const TensorShapeArray &inp_shape, + TensorShapeArray &out_shape) const { + mgb_assert(inp_shape.size() == 1 && out_shape.size() == 2); + out_shape[0] = inp_shape[0]; + out_shape[1] = inp_shape[0]; +} + +ArgsortForward::NodeProp* ArgsortForward::do_make_node_prop() const { + auto ret = Super::do_make_node_prop(); + ret->add_dep_type_existing_var(input(0), + NodeProp::DepType::VALUE_ALLOW_EMPTY); + return ret; +} + #if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ArgsortForward) { mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); diff --git a/src/opr/include/megbrain/opr/misc.h b/src/opr/include/megbrain/opr/misc.h index e3914684..51cade1d 100644 --- a/src/opr/include/megbrain/opr/misc.h +++ b/src/opr/include/megbrain/opr/misc.h @@ -55,6 +55,12 @@ MGB_DEFINE_OPR_CLASS(Argmin, */ MGB_DEFINE_OPR_CLASS(ArgsortForward, intl::MegDNNOprWrapperFwd) // { + protected: + NodeProp* do_make_node_prop() const override; + void scn_do_execute() override; + void get_output_var_shape( + const TensorShapeArray &inp_shape, + TensorShapeArray &out_shape) const override; public: ArgsortForward(VarNode *in_tensor, const Param ¶m,