@@ -12,7 +12,7 @@ from typing import Iterable, Optional, Sequence, Tuple, Union | |||||
import numpy as np | import numpy as np | ||||
from ..core._imperative_rt import CompNode | from ..core._imperative_rt import CompNode | ||||
from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion | |||||
from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion, split_cpp | |||||
from ..core._wrap import as_device | from ..core._wrap import as_device | ||||
from ..core.ops import builtin | from ..core.ops import builtin | ||||
from ..core.ops.builtin import Copy, Identity | from ..core.ops.builtin import Copy, Identity | ||||
@@ -477,50 +477,8 @@ def split(inp, nsplits_or_sections, axis=0): | |||||
[(4, 20), (3, 20), (3, 20)] | [(4, 20), (3, 20), (3, 20)] | ||||
[(10, 6), (10, 11), (10, 3)] | [(10, 6), (10, 11), (10, 3)] | ||||
""" | """ | ||||
ndim = len(inp.shape) | |||||
if axis >= ndim: | |||||
raise ValueError("Invalid axis {}".format(axis)) | |||||
Ntotal = inp.shape[axis] | |||||
if isinstance(nsplits_or_sections, Sequence): | |||||
Nsections = len(nsplits_or_sections) + 1 | |||||
is_array = True | |||||
else: | |||||
Nsections = int(nsplits_or_sections) | |||||
is_array = False | |||||
if is_array: | |||||
partitions = [] | |||||
div_points = [0] + list(nsplits_or_sections) + [Ntotal] | |||||
for i in range(1, len(div_points)): | |||||
if div_points[i - 1] > div_points[i]: | |||||
raise ValueError( | |||||
"Invalid nsplits_or_secions: {}".format(nsplits_or_sections) | |||||
) | |||||
partitions.append(div_points[i] - div_points[i - 1]) | |||||
else: # scalar | |||||
if Nsections <= 0: | |||||
raise ValueError("Number sections must be larger than 0") | |||||
if Nsections > Ntotal: | |||||
raise ValueError( | |||||
"The size {} at dim {} cannot be split into {} sections".format( | |||||
Ntotal, axis, Nsections | |||||
) | |||||
) | |||||
partitions = [] | |||||
for i in range(Nsections): | |||||
section_size = (Ntotal + Nsections - i - 1) // Nsections | |||||
partitions.append(section_size) | |||||
partitions = [ | |||||
part | |||||
if isinstance(part, (SymbolVar, Tensor)) | |||||
else Const(part, dtype="int32", device=inp.device)(inp)[0] | |||||
for part in partitions | |||||
] | |||||
op = builtin.Split(axis=axis) | |||||
return apply(op, inp, *partitions) | |||||
return split_cpp(inp, nsplits_or_sections, axis) | |||||
def _get_idx(index, axis): | def _get_idx(index, axis): | ||||
@@ -633,6 +633,7 @@ WRAP_FUNC_PY35(get_device); | |||||
WRAP_FUNC_PY35(make_shape_tuple); | WRAP_FUNC_PY35(make_shape_tuple); | ||||
WRAP_FUNC_PY35(getitem_cpp); | WRAP_FUNC_PY35(getitem_cpp); | ||||
WRAP_FUNC_PY35(setitem_cpp); | WRAP_FUNC_PY35(setitem_cpp); | ||||
WRAP_FUNC_PY35(split_cpp); | |||||
#undef WRAP_FUNC_PY35 | #undef WRAP_FUNC_PY35 | ||||
#define MGE_PY_INTERFACE(NAME, FUNC) \ | #define MGE_PY_INTERFACE(NAME, FUNC) \ | ||||
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } | { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } | ||||
@@ -765,6 +766,7 @@ void init_tensor(py::module m) { | |||||
MGE_PY_INTERFACE(make_shape_tuple, make_shape_tuple), | MGE_PY_INTERFACE(make_shape_tuple, make_shape_tuple), | ||||
MGE_PY_INTERFACE(getitem_cpp, getitem_cpp), | MGE_PY_INTERFACE(getitem_cpp, getitem_cpp), | ||||
MGE_PY_INTERFACE(setitem_cpp, setitem_cpp), | MGE_PY_INTERFACE(setitem_cpp, setitem_cpp), | ||||
MGE_PY_INTERFACE(split_cpp, split_cpp), | |||||
{nullptr, nullptr, 0, nullptr}}; | {nullptr, nullptr, 0, nullptr}}; | ||||
for (auto&& def : method_defs) { | for (auto&& def : method_defs) { | ||||
if (def.ml_meth != nullptr) { | if (def.ml_meth != nullptr) { | ||||
@@ -603,6 +603,86 @@ py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_h | |||||
return res; | return res; | ||||
} | } | ||||
bool is_tensor_or_symbolvar(py::handle arg) { | |||||
return bool(TensorWrapper::try_cast(arg.ptr())) || py::isinstance<PySymbolVar>(arg); | |||||
} | |||||
bool is_py_sequence(py::handle arg) { | |||||
if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr()) || | |||||
py::isinstance<PySymbolVar>(arg)) { | |||||
return false; | |||||
} | |||||
return PySequence_Check(arg.ptr()); | |||||
} | |||||
py::object _split_cpp( | |||||
py::handle inp_hdl, py::handle nsplits_or_sections_hdl, py::handle axis_hdl) { | |||||
py::object shape_obj = getattr(inp_hdl, "shape"); | |||||
py::object n_total = shape_obj[axis_hdl]; | |||||
int ndim = shape_obj.attr("__len__")().cast<int>(); | |||||
int axis = axis_hdl.cast<int>(); | |||||
if (axis >= ndim) { | |||||
throw py::value_error("Invalid axis " + std::to_string(axis)); | |||||
} | |||||
int n_sections; | |||||
bool is_array; | |||||
if (is_py_sequence(nsplits_or_sections_hdl)) { | |||||
n_sections = PySequence_Length(nsplits_or_sections_hdl.ptr()) + 1; | |||||
is_array = true; | |||||
} else { | |||||
n_sections = getattr(nsplits_or_sections_hdl, "__int__")().cast<int>(); | |||||
is_array = false; | |||||
} | |||||
py::list partitions; | |||||
std::shared_ptr<OpDef> op; | |||||
std::vector<PyObject*> p; | |||||
if (is_array) { | |||||
py::list div_points; | |||||
py::list sections = py::reinterpret_borrow<py::object>(nsplits_or_sections_hdl); | |||||
div_points.append(0); | |||||
for (size_t i = 0; i < sections.size(); ++i) { | |||||
div_points.append(sections[i]); | |||||
} | |||||
div_points.append(n_total); | |||||
for (size_t i = 1; i < div_points.size(); ++i) { | |||||
if (div_points[i - 1] > div_points[i]) { | |||||
throw py::value_error( | |||||
"Invalid nsplits_or_secions: " + | |||||
repr(nsplits_or_sections_hdl).cast<std::string>()); | |||||
} | |||||
py::object pos = div_points[i] - div_points[i - 1]; | |||||
if (is_tensor_or_symbolvar(pos)) { | |||||
partitions.append(pos); | |||||
} else { | |||||
partitions.append( | |||||
_Const(pos, py::cast((mgb::DType)dtype::Int32()), | |||||
getattr(inp_hdl, "device"), inp_hdl)); | |||||
} | |||||
} | |||||
op = Split::make(axis, 0); | |||||
p.resize(partitions.size() + 2); | |||||
for (size_t i = 0; i < partitions.size(); ++i) { | |||||
p[i + 2] = partitions[i].ptr(); | |||||
} | |||||
} else { | |||||
if (n_sections <= 0) { | |||||
throw py::value_error("Number sections must be larger than 0"); | |||||
} | |||||
if (py::int_(n_sections) > n_total) { | |||||
throw py::value_error( | |||||
"The size " + repr(n_total).cast<std::string>() + " at dim " + | |||||
std::to_string(axis) + " cannot be split into " + | |||||
std::to_string(n_sections) + " sections"); | |||||
} | |||||
op = Split::make(axis, n_sections); | |||||
p.resize(2); | |||||
} | |||||
py::object Op = py::cast(op); | |||||
p[0] = Op.ptr(); | |||||
p[1] = inp_hdl.ptr(); | |||||
return py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size())); | |||||
} | |||||
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { | PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { | ||||
try { | try { | ||||
return _make_shape_tuple(py::handle(args[0])).release().ptr(); | return _make_shape_tuple(py::handle(args[0])).release().ptr(); | ||||
@@ -627,4 +707,13 @@ PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
PYEXT17_TRANSLATE_EXC_RET(nullptr) | PYEXT17_TRANSLATE_EXC_RET(nullptr) | ||||
} | } | ||||
PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
try { | |||||
return _split_cpp(py::handle(args[0]), py::handle(args[1]), py::handle(args[2])) | |||||
.release() | |||||
.ptr(); | |||||
} | |||||
PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
} | |||||
} // namespace mgb::imperative::python | } // namespace mgb::imperative::python |
@@ -8,4 +8,6 @@ PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs); | PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs); | ||||
PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
} // namespace mgb::imperative::python | } // namespace mgb::imperative::python |
@@ -285,7 +285,7 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||||
opt.method == Options::Method::SPECIFY, | opt.method == Options::Method::SPECIFY, | ||||
"only Split with SPECIFY output shapes is supported"); | "only Split with SPECIFY output shapes is supported"); | ||||
mgb_assert(opt.partition.size() == opt.nr_part); | mgb_assert(opt.partition.size() == opt.nr_part); | ||||
return Split::make(axis); | |||||
return Split::make(axis, 0); | |||||
} | } | ||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | ||||
@@ -293,13 +293,18 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& sp = static_cast<const Split&>(def); | auto&& sp = static_cast<const Split&>(def); | ||||
OperatorNodeConfig config{sp.make_name()}; | OperatorNodeConfig config{sp.make_name()}; | ||||
opr::Split::Options opt; | opr::Split::Options opt; | ||||
opt.axis = sp.axis; | |||||
opt.method = Options::Method::SPECIFY; | |||||
mgb_assert(inputs.size() > 1); | |||||
opt.nr_part = inputs.size() - 1; | |||||
opt.partition.resize(opt.nr_part); | |||||
for (size_t i = 1; i < inputs.size(); ++i) | |||||
opt.partition[i - 1] = inputs[i]; | |||||
if (sp.nsections) { | |||||
opt = Options::make_average(sp.axis, sp.nsections); | |||||
opt.method = Options::Method::CALL_BACK; | |||||
} else { | |||||
opt.axis = sp.axis; | |||||
opt.method = Options::Method::SPECIFY; | |||||
mgb_assert(inputs.size() > 1); | |||||
opt.nr_part = inputs.size() - 1; | |||||
opt.partition.resize(opt.nr_part); | |||||
for (size_t i = 1; i < inputs.size(); ++i) | |||||
opt.partition[i - 1] = inputs[i]; | |||||
} | |||||
return opr::Split::make(inputs[0], opt, config); | return opr::Split::make(inputs[0], opt, config); | ||||
} | } | ||||
@@ -426,7 +426,8 @@ def Cumsum: MgbHashableOp<"Cumsum", [CumsumParam]>; | |||||
def Split: MgbHashableOp<"Split", [EmptyParam]> { | def Split: MgbHashableOp<"Split", [EmptyParam]> { | ||||
let extraArguments = (ins | let extraArguments = (ins | ||||
MgbI32Attr:$axis | |||||
MgbI32Attr:$axis, | |||||
MgbI32Attr:$nsections | |||||
); | ); | ||||
} | } | ||||
@@ -422,7 +422,7 @@ public: | |||||
/*! | /*! | ||||
* \brief make split option by splitting into average parts | * \brief make split option by splitting into average parts | ||||
*/ | */ | ||||
static Options make_average(int axis, size_t nr_part); | |||||
MGE_WIN_DECLSPEC_FUC static Options make_average(int axis, size_t nr_part); | |||||
static Options make_partition(int axis, const SymbolVarArray& partition); | static Options make_partition(int axis, const SymbolVarArray& partition); | ||||
static Options make_partition( | static Options make_partition( | ||||