Browse Source

perf(mge/functional): speed up RemoveAxis

GitOrigin-RevId: 9c5d27fe1d
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
3e5e08b0b4
5 changed files with 80 additions and 30 deletions
  1. +3
    -22
      imperative/python/megengine/core/tensor/array_method.py
  2. +3
    -2
      imperative/python/megengine/functional/tensor.py
  3. +2
    -0
      imperative/python/src/tensor.cpp
  4. +70
    -6
      imperative/python/src/tensor_utils.cpp
  5. +2
    -0
      imperative/python/src/tensor_utils.h

+ 3
- 22
imperative/python/megengine/core/tensor/array_method.py View File

@@ -17,6 +17,7 @@ from .. import _config
from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion
from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar
from .._imperative_rt.core2 import squeeze_cpp
from ..ops import builtin
from . import amp
from .indexing import getitem, setitem
@@ -448,26 +449,6 @@ def _logical_binary_elwise(mode, rev=False):
return f


def _remove_axis(inp: Tensor, axis) -> Tensor:
def get_axes():
if axis is None:
shp = inp.shape
return [i for i, s in enumerate(shp) if s == 1]
try:
return [int(axis)]
except (TypeError, ValueError):
pass
return list(map(int, axis))

axis = get_axes()
axis = _normalize_axis(inp.ndim, axis)
axis = [a - i for i, a in enumerate(axis)]

op = builtin.RemoveAxis(axis=axis)
(result,) = apply(op, inp)
return result


def _reduce(mode):
def f(self, axis=None, keepdims: bool = False):
data = self
@@ -480,7 +461,7 @@ def _reduce(mode):
op = builtin.Reduce(mode=mode, axis=ai)
(data,) = apply(op, data)
if not keepdims:
data = _remove_axis(data, ai)
data = squeeze_cpp(data, ai)
result = data
else:
# builtin.Reduce already accept negtive axis
@@ -488,7 +469,7 @@ def _reduce(mode):
(result,) = apply(op, data)

if not keepdims:
result = _remove_axis(result, axis)
result = squeeze_cpp(result, axis)
return result

return f


+ 3
- 2
imperative/python/megengine/functional/tensor.py View File

@@ -18,12 +18,13 @@ from ..core._imperative_rt.core2 import (
dtype_promotion,
expand_dims_cpp,
split_cpp,
squeeze_cpp,
)
from ..core._wrap import as_device
from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity
from ..core.ops.special import Const
from ..core.tensor.array_method import _broadcast, _remove_axis
from ..core.tensor.array_method import _broadcast
from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
from ..device import get_default_device
from ..tensor import Tensor
@@ -996,7 +997,7 @@ def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Te

(1, 1, 2)
"""
return _remove_axis(inp, axis)
return squeeze_cpp(inp, axis)


def linspace(


+ 2
- 0
imperative/python/src/tensor.cpp View File

@@ -635,6 +635,7 @@ WRAP_FUNC_PY35(getitem_cpp);
WRAP_FUNC_PY35(setitem_cpp);
WRAP_FUNC_PY35(split_cpp);
WRAP_FUNC_PY35(expand_dims_cpp);
WRAP_FUNC_PY35(squeeze_cpp);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
@@ -769,6 +770,7 @@ void init_tensor(py::module m) {
MGE_PY_INTERFACE(setitem_cpp, setitem_cpp),
MGE_PY_INTERFACE(split_cpp, split_cpp),
MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp),
MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp),
{nullptr, nullptr, 0, nullptr}};
for (auto&& def : method_defs) {
if (def.ml_meth != nullptr) {


+ 70
- 6
imperative/python/src/tensor_utils.cpp View File

@@ -683,17 +683,21 @@ py::object _split_cpp(
return py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
}

py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
std::vector<int32_t> list2vector(py::handle li) {
std::vector<int32_t> axis;
if (is_py_sequence(axis_hdl.ptr())) {
py::list tmp_list =
py::reinterpret_steal<py::list>(PySequence_List(axis_hdl.ptr()));
if (is_py_sequence(li.ptr())) {
py::list tmp_list = py::reinterpret_steal<py::list>(PySequence_List(li.ptr()));
for (size_t i = 0; i < tmp_list.size(); ++i) {
axis.push_back(tmp_list[i].attr("__int__")().cast<int32_t>());
}
} else {
axis.push_back(getattr(axis_hdl, "__int__")().cast<int>());
axis.push_back(getattr(li, "__int__")().cast<int32_t>());
}
return axis;
}

py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
std::vector<int32_t> axis = list2vector(axis_hdl);
bool unknown_ndim = true;
size_t ndim = axis.size();
if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) {
@@ -718,7 +722,7 @@ py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
"Does not support negative index when tensor's ndim is "
"unknown");
}
axis[i] += ndim;
axis[i] += static_cast<int32_t>(ndim);
}
}
if (!axis.size()) {
@@ -736,6 +740,59 @@ py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
return ret[0];
}

py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
std::vector<int32_t> axis;
size_t ndim;
if (axis_hdl.ptr() != Py_None) {
axis = list2vector(axis_hdl);
}
if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) {
auto&& shape = p->m_tensor->shape();
if (shape) {
ndim = shape->ndim;
if (axis_hdl.ptr() == Py_None) {
for (size_t i = 0; i < shape->ndim; ++i) {
if (shape->shape[i] == 1) {
axis.push_back(i);
}
}
}
}
} else {
auto&& var = inp_hdl.cast<PySymbolVar*>();
auto&& mgr = var->m_node->owner_graph()->static_infer_manager();
auto&& shape = mgr.infer_shape_fallible(var->m_node);
if (shape) {
ndim = shape->ndim;
if (axis_hdl.ptr() == Py_None) {
for (size_t i = 0; i < shape->ndim; ++i) {
if (shape->shape[i] == 1) {
axis.push_back(i);
}
}
}
}
}
for (size_t i = 0; i < axis.size(); ++i) {
if (axis[i] < 0) {
axis[i] += static_cast<int32_t>(ndim);
}
}
std::sort(axis.begin(), axis.end());
for (size_t i = 0; i < axis.size(); ++i) {
axis[i] -= static_cast<int32_t>(i);
}
std::shared_ptr<OpDef> op = RemoveAxis::make(axis = axis);
std::vector<PyObject*> p;
p.resize(2);
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = inp_hdl.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
}

PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _make_shape_tuple(py::handle(args[0])).release().ptr();
@@ -778,4 +835,11 @@ PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _squeeze_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

} // namespace mgb::imperative::python

+ 2
- 0
imperative/python/src/tensor_utils.h View File

@@ -12,4 +12,6 @@ PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs);

PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs);

PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs);

} // namespace mgb::imperative::python

Loading…
Cancel
Save