|
|
@@ -158,7 +158,13 @@ auto apply_on_var_node( |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
OP_TRAIT_REG(Reduce, Reduce) |
|
|
|
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { |
|
|
|
auto* node = &node_->cast_final_safe<opr::Reduce>(); |
|
|
|
return Reduce::make(node->param()); |
|
|
|
} |
|
|
|
|
|
|
|
OP_TRAIT_REG(Reduce, Reduce, opr::Reduce) |
|
|
|
.make_from_op_node(make_from_op_node) |
|
|
|
.apply_on_var_node(apply_on_var_node) |
|
|
|
.fallback(); |
|
|
|
}} // reduce |
|
|
@@ -439,12 +445,13 @@ OP_TRAIT_REG(GaussianRNG, GaussianRNG) |
|
|
|
}} // gaussian_rng |
|
|
|
|
|
|
|
namespace { namespace roi_align { |
|
|
|
auto apply_on_var_node( |
|
|
|
VarNodeArray apply_on_var_node( |
|
|
|
const OpDef& def, |
|
|
|
const VarNodeArray& inputs) { |
|
|
|
auto&& op = static_cast<const ROIAlign&>(def); |
|
|
|
mgb_assert(inputs.size() == 2); |
|
|
|
return opr::ROIAlign::make(inputs[0], inputs[1], op.param()); |
|
|
|
auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param()).node()->owner_opr(); |
|
|
|
return {opr->output(0), opr->output(1)}; |
|
|
|
} |
|
|
|
OP_TRAIT_REG(ROIAlign, ROIAlign) |
|
|
|
.apply_on_var_node(apply_on_var_node) |
|
|
@@ -496,12 +503,13 @@ OP_TRAIT_REG(Eye, Eye) |
|
|
|
}} // eye |
|
|
|
|
|
|
|
namespace { namespace roi_pooling { |
|
|
|
auto apply_on_var_node( |
|
|
|
VarNodeArray apply_on_var_node( |
|
|
|
const OpDef& def, |
|
|
|
const VarNodeArray& inputs) { |
|
|
|
auto&& op = static_cast<const ROIPooling&>(def); |
|
|
|
mgb_assert(inputs.size() == 3); |
|
|
|
return opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param()); |
|
|
|
auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param()).node()->owner_opr(); |
|
|
|
return {opr->output(0), opr->output(1)}; |
|
|
|
} |
|
|
|
OP_TRAIT_REG(ROIPooling, ROIPooling) |
|
|
|
.apply_on_var_node(apply_on_var_node) |
|
|
@@ -620,11 +628,11 @@ auto apply_on_var_node( |
|
|
|
const VarNodeArray& inputs) { |
|
|
|
auto&& op = static_cast<const SVD&>(def); |
|
|
|
mgb_assert(inputs.size() == 1); |
|
|
|
return opr::SVD::make(inputs[0], op.param()); |
|
|
|
return opr::SVD::make(inputs[0], op.param())[0].node()->owner_opr()->usable_output(); |
|
|
|
} |
|
|
|
OP_TRAIT_REG(SVD, SVD) |
|
|
|
.apply_on_var_node(apply_on_var_node) |
|
|
|
.fallback(); |
|
|
|
}} // svd |
|
|
|
|
|
|
|
} // namespace mgb::imperative |
|
|
|
} // namespace mgb::imperative |