diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py
index 20f047c7..689e0f0a 100755
--- a/imperative/python/megengine/functional/tensor.py
+++ b/imperative/python/megengine/functional/tensor.py
@@ -12,7 +12,13 @@ from typing import Iterable, Optional, Sequence, Tuple, Union
 import numpy as np
 
 from ..core._imperative_rt import CompNode
-from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion, split_cpp
+from ..core._imperative_rt.core2 import (
+    SymbolVar,
+    apply,
+    dtype_promotion,
+    expand_dims_cpp,
+    split_cpp,
+)
 from ..core._wrap import as_device
 from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
@@ -959,27 +965,7 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
         (1, 2)
 
     """
-
-    def get_axes():
-        try:
-            return [int(axis)]
-        except (TypeError, ValueError):
-            pass
-        return list(map(int, axis))
-
-    axis = get_axes()
-    try:
-        ndim = inp.ndim + len(axis)
-        axis = sorted(i + ndim if i < 0 else i for i in axis)
-    except ValueError:
-        if any([ind < 0 for ind in axis]):
-            raise IndexError(
-                "Does not support negative index when tensor's ndim is unknown"
-            )
-        axis = sorted(axis)
-    assert axis, "axis could not be empty"
-    op = builtin.AddAxis(axis=axis)
-    (result,) = apply(op, inp)
-    return result
+    return expand_dims_cpp(inp, axis)
 
 
 def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Tensor:
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 6cc1afc2..6efcb0f7 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -634,6 +634,7 @@ WRAP_FUNC_PY35(make_shape_tuple);
 WRAP_FUNC_PY35(getitem_cpp);
 WRAP_FUNC_PY35(setitem_cpp);
 WRAP_FUNC_PY35(split_cpp);
+WRAP_FUNC_PY35(expand_dims_cpp);
 #undef WRAP_FUNC_PY35
 #define MGE_PY_INTERFACE(NAME, FUNC) \
     { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
@@ -767,6 +768,7 @@ void init_tensor(py::module m) {
             MGE_PY_INTERFACE(getitem_cpp, getitem_cpp),
             MGE_PY_INTERFACE(setitem_cpp, setitem_cpp),
             MGE_PY_INTERFACE(split_cpp, split_cpp),
+            MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp),
             {nullptr, nullptr, 0, nullptr}};
     for (auto&& def : method_defs) {
         if (def.ml_meth != nullptr) {
diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp
index d6a89f3f..97798fc0 100644
--- a/imperative/python/src/tensor_utils.cpp
+++ b/imperative/python/src/tensor_utils.cpp
@@ -683,6 +683,59 @@ py::object _split_cpp(
     return py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
 }
 
+py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
+    std::vector<int32_t> axis;
+    if (is_py_sequence(axis_hdl.ptr())) {
+        py::list tmp_list =
+                py::reinterpret_steal<py::list>(PySequence_List(axis_hdl.ptr()));
+        for (size_t i = 0; i < tmp_list.size(); ++i) {
+            axis.push_back(tmp_list[i].attr("__int__")().cast<int32_t>());
+        }
+    } else {
+        axis.push_back(getattr(axis_hdl, "__int__")().cast<int32_t>());
+    }
+    bool unknown_ndim = true;
+    size_t ndim = axis.size();
+    if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) {
+        auto&& shape = p->m_tensor->shape();
+        if (shape) {
+            unknown_ndim = false;
+            ndim += shape->ndim;
+        }
+    } else {
+        auto&& var = inp_hdl.cast<PySymbolVar*>();
+        auto&& mgr = var->m_node->owner_graph()->static_infer_manager();
+        auto&& shape = mgr.infer_shape_fallible(var->m_node);
+        if (shape) {
+            unknown_ndim = false;
+            ndim += shape->ndim;
+        }
+    }
+    for (size_t i = 0; i < axis.size(); ++i) {
+        if (axis[i] < 0) {
+            if (unknown_ndim) {
+                throw py::index_error(
+                        "Does not support negative index when tensor's ndim is "
+                        "unknown");
+            }
+            axis[i] += ndim;
+        }
+    }
+    if (!axis.size()) {
+        throw py::index_error("axis could not be empty");
+    }
+    std::sort(axis.begin(), axis.end());
+    std::shared_ptr<OpDef> op = AddAxis::make(axis = axis);
+    std::vector<PyObject*> p;
+    p.resize(2);
+    py::object Op = py::cast(op);
+    p[0] = Op.ptr();
+    p[1] = inp_hdl.ptr();
+    py::tuple ret =
+            py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
+    return ret[0];
+}
+
 PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
     try {
         return _make_shape_tuple(py::handle(args[0])).release().ptr();
@@ -716,4 +769,13 @@ PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
     PYEXT17_TRANSLATE_EXC_RET(nullptr)
 }
 
+PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
+    try {
+        return _expand_dims_cpp(py::handle(args[0]), py::handle(args[1]))
+                .release()
+                .ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
+
 } // namespace mgb::imperative::python
diff --git a/imperative/python/src/tensor_utils.h b/imperative/python/src/tensor_utils.h
index cc35ec41..bf6bc046 100644
--- a/imperative/python/src/tensor_utils.h
+++ b/imperative/python/src/tensor_utils.h
@@ -10,4 +10,6 @@ PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs);
 
 PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs);
 
+PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs);
+
 } // namespace mgb::imperative::python
\ No newline at end of file
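
Note (not part of the patch): a minimal usage sketch, assuming only the public megengine API, of the behavior the new C++ path is expected to preserve from the removed Python implementation of expand_dims: a single int axis, a sequence of axes (sorted before AddAxis is applied), and negative axes resolved against the known ndim.

    import megengine.functional as F
    from megengine import Tensor

    x = Tensor([1, 2])                              # shape (2,)
    print(F.expand_dims(x, 0).numpy().shape)        # (1, 2): single int axis
    print(F.expand_dims(x, [0, 2]).numpy().shape)   # (1, 2, 1): sequence of axes
    print(F.expand_dims(x, -1).numpy().shape)       # (2, 1): negative axis, ndim known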