From 7a9f2ed95e4ef1b19afb038d012033054d34ab7e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 20 Aug 2021 19:16:51 +0800 Subject: [PATCH] perf(functional/split): add python binding for Split opr GitOrigin-RevId: 3359f5da26790d902573acdd75727c7e57da982b --- imperative/python/megengine/functional/tensor.py | 4 ++- imperative/src/impl/ops/tensor_manip.cpp | 34 ++++++++++++++++++++++++ src/core/include/megbrain/ir/ops.td | 6 +++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 2677c6b0..67aa9f4a 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -450,12 +450,14 @@ def split(inp, nsplits_or_sections, axis=0): is_array = False if is_array: + partitions = [] div_points = [0] + list(nsplits_or_sections) + [Ntotal] for i in range(1, len(div_points)): - if div_points[i - 1] >= div_points[i]: + if div_points[i - 1] > div_points[i]: raise ValueError( "Invalid nsplits_or_secions: {}".format(nsplits_or_sections) ) + partitions.append(div_points[i] - div_points[i - 1]) else: # scalar if Nsections <= 0: raise ValueError("Number sections must be larger than 0") diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp index 10862e0b..b64169b9 100644 --- a/imperative/src/impl/ops/tensor_manip.cpp +++ b/imperative/src/impl/ops/tensor_manip.cpp @@ -384,4 +384,38 @@ OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat) .fallback(); } // param_pack +namespace split { +std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { + using Options = opr::Split::Options; + auto* node = &node_->cast_final_safe(); + auto&& opt = node->options(); + int axis = opt.axis; + mgb_assert(opt.method == Options::Method::SPECIFY, + "only Split with SPECIFY output shapes is supported"); + mgb_assert(opt.partition.size() == opt.nr_part); + return Split::make(axis); +} + +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + using Options = opr::Split::Options; + auto&& sp = static_cast(def); + OperatorNodeConfig config{sp.make_name()}; + opr::Split::Options opt; + opt.axis = sp.axis; + opt.method = Options::Method::SPECIFY; + mgb_assert(inputs.size() > 1); + opt.nr_part = inputs.size() - 1; + opt.partition.resize(opt.nr_part); + for (size_t i = 1; i < inputs.size(); ++ i) + opt.partition[i - 1] = inputs[i]; + return opr::Split::make(inputs[0], opt, config); +} + +OP_TRAIT_REG(Split, Split, opr::Split) + .make_from_op_node(make_from_op_node) + .apply_on_var_node(apply_on_var_node) + .fallback(); + +} // namespace split + } // namespace mgb::imperative diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 07223e7f..6dd7ed7c 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -383,4 +383,10 @@ def FastpathCopy: MgbHashableOp<"FastpathCopy">; def Cumsum: MgbHashableOp<"Cumsum", [CumsumParam]>; +def Split: MgbHashableOp<"Split", [EmptyParam]> { + let extraArguments = (ins + MgbI32Attr:$axis + ); +} + #endif // MGB_OPS