Browse Source

feat(opr): let Argsort support empty IO

GitOrigin-RevId: 05fcac6e47
release-1.6
Megvii Engine Team 3 years ago
parent
commit
1a1748daf1
3 changed files with 72 additions and 5 deletions
  1. +30
    -4
      imperative/python/test/unit/functional/test_math.py
  2. +36
    -1
      src/opr/impl/misc.cpp
  3. +6
    -0
      src/opr/include/megbrain/opr/misc.h

+ 30
- 4
imperative/python/test/unit/functional/test_math.py View File

@@ -110,16 +110,42 @@ def test_sort():
data2_shape = (12, 2) data2_shape = (12, 2)
data1 = np.random.random(data1_shape).astype(np.float32) data1 = np.random.random(data1_shape).astype(np.float32)
data2 = np.random.random(data2_shape).astype(np.float32) data2 = np.random.random(data2_shape).astype(np.float32)
output0 = [np.sort(data1), np.argsort(data1).astype(np.int32)]
output1 = [np.sort(data2), np.argsort(data2).astype(np.int32)]
output1 = [np.sort(data1), np.argsort(data1).astype(np.int32)]
output2 = [np.sort(data2), np.argsort(data2).astype(np.int32)]


cases = [ cases = [
{"input": data1, "output": output0},
{"input": data2, "output": output1},
{"input": data1, "output": output1},
{"input": data2, "output": output2},
] ]
opr_test(cases, F.sort) opr_test(cases, F.sort)




@pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_sort_empty(is_symbolic):
data_shapes = [
(0,),
(10, 0),
]

def fn(x):
return F.sort(x)

for shape in data_shapes:
if is_symbolic is not None:
fn_ = jit.trace(symbolic=is_symbolic)(fn)
else:
fn_ = fn
data = np.random.random(shape).astype(np.float32)
for _ in range(3):
outs = fn_(tensor(data))
ref_outs = (np.sort(data), np.argsort(data))
assert len(ref_outs) == len(outs)
for i in range(len(outs)):
np.testing.assert_equal(outs[i].numpy(), ref_outs[i])
if is_symbolic is None:
break


def test_normalize(): def test_normalize():


cases = [ cases = [


+ 36
- 1
src/opr/impl/misc.cpp View File

@@ -75,7 +75,16 @@ MEGDNN_OPR_INIT1(Argmin, "argmin")
/* ================= ArgsortForward ================= */ /* ================= ArgsortForward ================= */


MGB_DYN_TYPE_OBJ_FINAL_IMPL(ArgsortForward); MGB_DYN_TYPE_OBJ_FINAL_IMPL(ArgsortForward);
MEGDNN_OPR_CTOR_INIT1(ArgsortForward, "argsort")
// MEGDNN_OPR_CTOR_INIT1(ArgsortForward, "argsort")

ArgsortForward::ArgsortForward(VarNode *i0, const Param &param, const OperatorNodeConfig &config):
Super(OperatorNodeBaseCtorParam{ i0->owner_graph(), config, "argsort", {i0}} ) {
init_megdnn_opr(*this, param);
add_input({i0});
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); // sorted value
output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); // sorted index
intl::MegDNNOprInitPostCtor<ArgsortForward>::apply(*this);
}


std::array<SymbolVar, 2> ArgsortForward::make( std::array<SymbolVar, 2> ArgsortForward::make(
SymbolVar in_tensor, const Param &param, SymbolVar in_tensor, const Param &param,
@@ -87,6 +96,32 @@ std::array<SymbolVar, 2> ArgsortForward::make(
return {node->output(0), node->output(1)}; return {node->output(0), node->output(1)};
} }


void ArgsortForward::scn_do_execute() {
if (input(0)->dev_tensor().empty()) {
mgb_assert(output(0)->dev_tensor().empty() &&
output(1)->dev_tensor().empty());
return;
}
mgb_assert(!output(0)->dev_tensor().empty() &&
!output(1)->dev_tensor().empty());
Super::scn_do_execute();
}

void ArgsortForward::get_output_var_shape(
const TensorShapeArray &inp_shape,
TensorShapeArray &out_shape) const {
mgb_assert(inp_shape.size() == 1 && out_shape.size() == 2);
out_shape[0] = inp_shape[0];
out_shape[1] = inp_shape[0];
}

ArgsortForward::NodeProp* ArgsortForward::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(input(0),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
return ret;
}

#if MGB_ENABLE_GRAD #if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ArgsortForward) { MGB_IMPL_OPR_GRAD(ArgsortForward) {
mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]);


+ 6
- 0
src/opr/include/megbrain/opr/misc.h View File

@@ -55,6 +55,12 @@ MGB_DEFINE_OPR_CLASS(Argmin,
*/ */
MGB_DEFINE_OPR_CLASS(ArgsortForward, MGB_DEFINE_OPR_CLASS(ArgsortForward,
intl::MegDNNOprWrapperFwd<megdnn::ArgsortForward>) // { intl::MegDNNOprWrapperFwd<megdnn::ArgsortForward>) // {
protected:
NodeProp* do_make_node_prop() const override;
void scn_do_execute() override;
void get_output_var_shape(
const TensorShapeArray &inp_shape,
TensorShapeArray &out_shape) const override;
public: public:
ArgsortForward(VarNode *in_tensor, ArgsortForward(VarNode *in_tensor,
const Param &param, const Param &param,


Loading…
Cancel
Save