Browse Source

perf(functional/split): add python binding for Split opr

GitOrigin-RevId: 3359f5da26
release-1.6
Megvii Engine Team 3 years ago
parent
commit
7a9f2ed95e
3 changed files with 43 additions and 1 deletions
  1. +3
    -1
      imperative/python/megengine/functional/tensor.py
  2. +34
    -0
      imperative/src/impl/ops/tensor_manip.cpp
  3. +6
    -0
      src/core/include/megbrain/ir/ops.td

+ 3
- 1
imperative/python/megengine/functional/tensor.py View File

@@ -450,12 +450,14 @@ def split(inp, nsplits_or_sections, axis=0):
is_array = False

if is_array:
partitions = []
div_points = [0] + list(nsplits_or_sections) + [Ntotal]
for i in range(1, len(div_points)):
if div_points[i - 1] >= div_points[i]:
if div_points[i - 1] > div_points[i]:
raise ValueError(
"Invalid nsplits_or_secions: {}".format(nsplits_or_sections)
)
partitions.append(div_points[i] - div_points[i - 1])
else: # scalar
if Nsections <= 0:
raise ValueError("Number sections must be larger than 0")


+ 34
- 0
imperative/src/impl/ops/tensor_manip.cpp View File

@@ -384,4 +384,38 @@ OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
.fallback();
} // param_pack

namespace split {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
using Options = opr::Split::Options;
auto* node = &node_->cast_final_safe<opr::Split>();
auto&& opt = node->options();
int axis = opt.axis;
mgb_assert(opt.method == Options::Method::SPECIFY,
"only Split with SPECIFY output shapes is supported");
mgb_assert(opt.partition.size() == opt.nr_part);
return Split::make(axis);
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
using Options = opr::Split::Options;
auto&& sp = static_cast<const Split&>(def);
OperatorNodeConfig config{sp.make_name()};
opr::Split::Options opt;
opt.axis = sp.axis;
opt.method = Options::Method::SPECIFY;
mgb_assert(inputs.size() > 1);
opt.nr_part = inputs.size() - 1;
opt.partition.resize(opt.nr_part);
for (size_t i = 1; i < inputs.size(); ++ i)
opt.partition[i - 1] = inputs[i];
return opr::Split::make(inputs[0], opt, config);
}

OP_TRAIT_REG(Split, Split, opr::Split)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.fallback();

} // namespace split

} // namespace mgb::imperative

+ 6
- 0
src/core/include/megbrain/ir/ops.td View File

@@ -383,4 +383,10 @@ def FastpathCopy: MgbHashableOp<"FastpathCopy">;

def Cumsum: MgbHashableOp<"Cumsum", [CumsumParam]>;

def Split: MgbHashableOp<"Split", [EmptyParam]> {
let extraArguments = (ins
MgbI32Attr:$axis
);
}

#endif // MGB_OPS

Loading…
Cancel
Save