|
@@ -384,4 +384,38 @@ OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat) |
|
|
.fallback(); |
|
|
.fallback(); |
|
|
} // param_pack |
|
|
} // param_pack |
|
|
|
|
|
|
|
|
|
|
|
namespace split { |
|
|
|
|
|
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { |
|
|
|
|
|
using Options = opr::Split::Options; |
|
|
|
|
|
auto* node = &node_->cast_final_safe<opr::Split>(); |
|
|
|
|
|
auto&& opt = node->options(); |
|
|
|
|
|
int axis = opt.axis; |
|
|
|
|
|
mgb_assert(opt.method == Options::Method::SPECIFY, |
|
|
|
|
|
"only Split with SPECIFY output shapes is supported"); |
|
|
|
|
|
mgb_assert(opt.partition.size() == opt.nr_part); |
|
|
|
|
|
return Split::make(axis); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { |
|
|
|
|
|
using Options = opr::Split::Options; |
|
|
|
|
|
auto&& sp = static_cast<const Split&>(def); |
|
|
|
|
|
OperatorNodeConfig config{sp.make_name()}; |
|
|
|
|
|
opr::Split::Options opt; |
|
|
|
|
|
opt.axis = sp.axis; |
|
|
|
|
|
opt.method = Options::Method::SPECIFY; |
|
|
|
|
|
mgb_assert(inputs.size() > 1); |
|
|
|
|
|
opt.nr_part = inputs.size() - 1; |
|
|
|
|
|
opt.partition.resize(opt.nr_part); |
|
|
|
|
|
for (size_t i = 1; i < inputs.size(); ++ i) |
|
|
|
|
|
opt.partition[i - 1] = inputs[i]; |
|
|
|
|
|
return opr::Split::make(inputs[0], opt, config); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
OP_TRAIT_REG(Split, Split, opr::Split) |
|
|
|
|
|
.make_from_op_node(make_from_op_node) |
|
|
|
|
|
.apply_on_var_node(apply_on_var_node) |
|
|
|
|
|
.fallback(); |
|
|
|
|
|
|
|
|
|
|
|
} // namespace split |
|
|
|
|
|
|
|
|
} // namespace mgb::imperative |
|
|
} // namespace mgb::imperative |