|
@@ -75,7 +75,16 @@ MEGDNN_OPR_INIT1(Argmin, "argmin") |
|
|
/* ================= ArgsortForward ================= */ |
|
|
/* ================= ArgsortForward ================= */ |
|
|
|
|
|
|
|
|
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ArgsortForward); |
|
|
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ArgsortForward); |
|
|
MEGDNN_OPR_CTOR_INIT1(ArgsortForward, "argsort") |
|
|
|
|
|
|
|
|
// MEGDNN_OPR_CTOR_INIT1(ArgsortForward, "argsort") |
|
|
|
|
|
|
|
|
|
|
|
ArgsortForward::ArgsortForward(VarNode *i0, const Param ¶m, const OperatorNodeConfig &config): |
|
|
|
|
|
Super(OperatorNodeBaseCtorParam{ i0->owner_graph(), config, "argsort", {i0}} ) { |
|
|
|
|
|
init_megdnn_opr(*this, param); |
|
|
|
|
|
add_input({i0}); |
|
|
|
|
|
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); // sorted value |
|
|
|
|
|
output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); // sorted index |
|
|
|
|
|
intl::MegDNNOprInitPostCtor<ArgsortForward>::apply(*this); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
std::array<SymbolVar, 2> ArgsortForward::make( |
|
|
std::array<SymbolVar, 2> ArgsortForward::make( |
|
|
SymbolVar in_tensor, const Param ¶m, |
|
|
SymbolVar in_tensor, const Param ¶m, |
|
@@ -87,6 +96,32 @@ std::array<SymbolVar, 2> ArgsortForward::make( |
|
|
return {node->output(0), node->output(1)}; |
|
|
return {node->output(0), node->output(1)}; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void ArgsortForward::scn_do_execute() { |
|
|
|
|
|
if (input(0)->dev_tensor().empty()) { |
|
|
|
|
|
mgb_assert(output(0)->dev_tensor().empty() && |
|
|
|
|
|
output(1)->dev_tensor().empty()); |
|
|
|
|
|
return; |
|
|
|
|
|
} |
|
|
|
|
|
mgb_assert(!output(0)->dev_tensor().empty() && |
|
|
|
|
|
!output(1)->dev_tensor().empty()); |
|
|
|
|
|
Super::scn_do_execute(); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void ArgsortForward::get_output_var_shape( |
|
|
|
|
|
const TensorShapeArray &inp_shape, |
|
|
|
|
|
TensorShapeArray &out_shape) const { |
|
|
|
|
|
mgb_assert(inp_shape.size() == 1 && out_shape.size() == 2); |
|
|
|
|
|
out_shape[0] = inp_shape[0]; |
|
|
|
|
|
out_shape[1] = inp_shape[0]; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
ArgsortForward::NodeProp* ArgsortForward::do_make_node_prop() const { |
|
|
|
|
|
auto ret = Super::do_make_node_prop(); |
|
|
|
|
|
ret->add_dep_type_existing_var(input(0), |
|
|
|
|
|
NodeProp::DepType::VALUE_ALLOW_EMPTY); |
|
|
|
|
|
return ret; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
#if MGB_ENABLE_GRAD |
|
|
#if MGB_ENABLE_GRAD |
|
|
MGB_IMPL_OPR_GRAD(ArgsortForward) { |
|
|
MGB_IMPL_OPR_GRAD(ArgsortForward) { |
|
|
mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); |
|
|
mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); |
|
|