From 0f736a0ab47455690db78d270811f1f01f158bbb Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 4 Mar 2022 17:28:54 +0800
Subject: [PATCH] perf(mge/functional): speed up Dimshuffle

GitOrigin-RevId: 8160c9522bc59d06aab1321b87408ecb410b4a81
---
 .../python/megengine/core/tensor/array_method.py   |  18 +---
 imperative/python/megengine/functional/tensor.py   |   2 +-
 imperative/python/src/tensor.cpp                   |   2 +
 imperative/python/src/tensor_utils.cpp             | 100 ++++++++++++++++++++++
 imperative/python/src/tensor_utils.h               |   2 +
 5 files changed, 107 insertions(+), 17 deletions(-)

diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py
index 51d6becd..da591cc3 100644
--- a/imperative/python/megengine/core/tensor/array_method.py
+++ b/imperative/python/megengine/core/tensor/array_method.py
@@ -17,7 +17,7 @@ from .. import _config
 from .._imperative_rt.common import CompNode
 from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion
 from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar
-from .._imperative_rt.core2 import squeeze_cpp
+from .._imperative_rt.core2 import squeeze_cpp, transpose_cpp
 from ..ops import builtin
 from . import amp
 from .indexing import getitem, setitem
@@ -331,12 +331,6 @@ def _matmul(
     return result


-def _transpose(data, axes):
-    op = builtin.Dimshuffle(axes)
-    (result,) = apply(op, data)
-    return result
-
-
 def _broadcast(inp, shape):
     auto_infer = False
     if isinstance(shape, (list, tuple)):
@@ -681,15 +675,7 @@ class ArrayMethodMixin(abc.ABC):

     def transpose(self, *args):
         r"""See :func:`~.transpose`."""
-        if self.ndim == 0:
-            assert (
-                len(args) == 0
-            ), "transpose for scalar does not accept additional args"
-            ret = self.to(self.device)
-            return ret
-        if not args:
-            args = range(self.ndim)[::-1]
-        return _transpose(self, _expand_args(args))
+        return transpose_cpp(self, args)

     def flatten(self):
         r"""See :func:`~.flatten`."""
diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py
index 20ad39db..68ce956c 100755
--- a/imperative/python/megengine/functional/tensor.py
+++ b/imperative/python/megengine/functional/tensor.py
@@ -865,7 +865,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
         [[1 0]
          [1 0]]
     """
-    return inp.transpose(list(-1 if _ == "x" else _ for _ in pattern))
+    return inp.transpose(pattern)


 def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index f271a689..58c9beb4 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -636,6 +636,7 @@ WRAP_FUNC_PY35(setitem_cpp);
 WRAP_FUNC_PY35(split_cpp);
 WRAP_FUNC_PY35(expand_dims_cpp);
 WRAP_FUNC_PY35(squeeze_cpp);
+WRAP_FUNC_PY35(transpose_cpp);
 #undef WRAP_FUNC_PY35
 #define MGE_PY_INTERFACE(NAME, FUNC) \
     { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
@@ -771,6 +772,7 @@ void init_tensor(py::module m) {
             MGE_PY_INTERFACE(split_cpp, split_cpp),
             MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp),
             MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp),
+            MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
             {nullptr, nullptr, 0, nullptr}};
     for (auto&& def : method_defs) {
         if (def.ml_meth != nullptr) {
diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp
index b46549c8..16f4062e 100644
--- a/imperative/python/src/tensor_utils.cpp
+++ b/imperative/python/src/tensor_utils.cpp
@@ -793,6 +793,88 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
     return ret[0];
 }

+/**
+ * Return the number of dimensions of \p tensor.
+ *
+ * The shape is read directly from the wrapped tensor when \p tensor casts to a
+ * TensorWrapper; any other object is asked for its Python "ndim" attribute.
+ */
+size_t fast_ndim(py::handle tensor) {
+    if (auto p = TensorWrapper::try_cast(tensor.ptr())) {
+        return p->m_tensor->shape()->ndim;
+    }
+    return getattr(tensor, "ndim").cast<size_t>();
+}
+
+/**
+ * Build a Dimshuffle op from \p args and apply it to \p inp_hdl.
+ *
+ * \param inp_hdl input tensor, or any object exposing "ndim"
+ * \param args tuple of the pattern arguments: empty to reverse all axes, a
+ *      single sequence/tensor holding the pattern, or the pattern entries
+ *      themselves. An int entry is used as is; the string "x" is stored as -1.
+ * \return first output of the applied op; for a 0-dim input, the result of
+ *      calling inp_hdl.to(inp_hdl.device)
+ * \throw py::index_error if a 0-dim input is given any pattern argument
+ * \throw py::value_error if a string entry is not "x"
+ * \throw py::error_already_set if converting the pattern to a list raised
+ */
+py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
+    py::tuple args_tup = py::reinterpret_borrow<py::tuple>(args.ptr());
+    // query ndim once; it serves both the scalar check and the default pattern
+    const size_t ndim = fast_ndim(inp_hdl);
+    if (ndim == 0) {
+        if (args_tup.size() != 0) {
+            throw py::index_error(
+                    "transpose for scalar does not accept additional args");
+        }
+        return getattr(inp_hdl, "to")(getattr(inp_hdl, "device"));
+    }
+    std::vector<int32_t> pattern;
+    if (args_tup.size() == 0) {
+        // no pattern given: reverse the axis order
+        pattern.reserve(ndim);
+        for (size_t i = 0; i < ndim; ++i) {
+            pattern.push_back(ndim - i - 1);
+        }
+    } else {
+        // a lone sequence/tensor argument is the pattern itself; otherwise the
+        // positional arguments are the pattern entries
+        PyObject* src = args_tup.ptr();
+        if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) ||
+                                     is_tensor_or_symbolvar(args_tup[0].ptr()))) {
+            src = args_tup[0].ptr();
+        }
+        PyObject* raw = PySequence_List(src);
+        if (raw == nullptr) {
+            // PySequence_List failed and left a Python error set; report it
+            // instead of wrapping a null pointer in a py::list
+            throw py::error_already_set();
+        }
+        py::list lis = py::reinterpret_steal<py::list>(raw);
+        pattern.reserve(lis.size());
+        for (size_t i = 0; i < lis.size(); ++i) {
+            py::object item = lis[i];
+            if (PyLong_Check(item.ptr())) {
+                pattern.push_back(item.cast<int32_t>());
+            } else if (item.cast<std::string>() == "x") {
+                pattern.push_back(-1);
+            } else {
+                // such entries used to be dropped silently, which produced a
+                // pattern shorter than the one requested
+                throw py::value_error(
+                        "transpose pattern entries must be int or \"x\"");
+            }
+        }
+    }
+    std::shared_ptr<OpDef> op = Dimshuffle::make(pattern);
+    py::object Op = py::cast(op);
+    std::vector<PyObject*> p = {Op.ptr(), inp_hdl.ptr()};
+    py::tuple ret =
+            py::reinterpret_steal<py::object>(py_apply(nullptr, p.data(), p.size()));
+    return ret[0];
+}
+
 PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
     try {
         return _make_shape_tuple(py::handle(args[0])).release().ptr();
@@ -842,4 +924,22 @@ PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
     PYEXT17_TRANSLATE_EXC_RET(nullptr)
 }

+/**
+ * Python-facing wrapper of _transpose_cpp(): args[0] is the input tensor and
+ * args[1] the tuple of pattern arguments. Exceptions thrown by the call are
+ * handled by PYEXT17_TRANSLATE_EXC_RET(nullptr).
+ */
+PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
+    if (nargs < 2) {
+        // args[0] and args[1] are read below; do not index past the array
+        PyErr_SetString(
+                PyExc_TypeError, "transpose_cpp expects 2 arguments");
+        return nullptr;
+    }
+    try {
+        return _transpose_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
+
 } // namespace mgb::imperative::python
diff --git a/imperative/python/src/tensor_utils.h b/imperative/python/src/tensor_utils.h
index 55ea2c7d..422ee461 100644
--- a/imperative/python/src/tensor_utils.h
+++ b/imperative/python/src/tensor_utils.h
@@ -14,4 +14,6 @@ PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs);

 PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs);

+PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs);
+
 } // namespace mgb::imperative::python
\ No newline at end of file