@@ -12,7 +12,13 @@ from typing import Iterable, Optional, Sequence, Tuple, Union | |||||
import numpy as np | import numpy as np | ||||
from ..core._imperative_rt import CompNode | from ..core._imperative_rt import CompNode | ||||
from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion, split_cpp | |||||
from ..core._imperative_rt.core2 import ( | |||||
SymbolVar, | |||||
apply, | |||||
dtype_promotion, | |||||
expand_dims_cpp, | |||||
split_cpp, | |||||
) | |||||
from ..core._wrap import as_device | from ..core._wrap import as_device | ||||
from ..core.ops import builtin | from ..core.ops import builtin | ||||
from ..core.ops.builtin import Copy, Identity | from ..core.ops.builtin import Copy, Identity | ||||
@@ -959,27 +965,7 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||||
(1, 2) | (1, 2) | ||||
""" | """ | ||||
def get_axes(): | |||||
try: | |||||
return [int(axis)] | |||||
except (TypeError, ValueError): | |||||
pass | |||||
return list(map(int, axis)) | |||||
axis = get_axes() | |||||
try: | |||||
ndim = inp.ndim + len(axis) | |||||
axis = sorted(i + ndim if i < 0 else i for i in axis) | |||||
except ValueError: | |||||
if any([ind < 0 for ind in axis]): | |||||
raise IndexError( | |||||
"Does not support negative index when tensor's ndim is unknown" | |||||
) | |||||
axis = sorted(axis) | |||||
assert axis, "axis could not be empty" | |||||
op = builtin.AddAxis(axis=axis) | |||||
(result,) = apply(op, inp) | |||||
return result | |||||
return expand_dims_cpp(inp, axis) | |||||
def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Tensor: | def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Tensor: | ||||
@@ -634,6 +634,7 @@ WRAP_FUNC_PY35(make_shape_tuple); | |||||
WRAP_FUNC_PY35(getitem_cpp); | WRAP_FUNC_PY35(getitem_cpp); | ||||
WRAP_FUNC_PY35(setitem_cpp); | WRAP_FUNC_PY35(setitem_cpp); | ||||
WRAP_FUNC_PY35(split_cpp); | WRAP_FUNC_PY35(split_cpp); | ||||
WRAP_FUNC_PY35(expand_dims_cpp); | |||||
#undef WRAP_FUNC_PY35 | #undef WRAP_FUNC_PY35 | ||||
#define MGE_PY_INTERFACE(NAME, FUNC) \ | #define MGE_PY_INTERFACE(NAME, FUNC) \ | ||||
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } | { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } | ||||
@@ -767,6 +768,7 @@ void init_tensor(py::module m) { | |||||
MGE_PY_INTERFACE(getitem_cpp, getitem_cpp), | MGE_PY_INTERFACE(getitem_cpp, getitem_cpp), | ||||
MGE_PY_INTERFACE(setitem_cpp, setitem_cpp), | MGE_PY_INTERFACE(setitem_cpp, setitem_cpp), | ||||
MGE_PY_INTERFACE(split_cpp, split_cpp), | MGE_PY_INTERFACE(split_cpp, split_cpp), | ||||
MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp), | |||||
{nullptr, nullptr, 0, nullptr}}; | {nullptr, nullptr, 0, nullptr}}; | ||||
for (auto&& def : method_defs) { | for (auto&& def : method_defs) { | ||||
if (def.ml_meth != nullptr) { | if (def.ml_meth != nullptr) { | ||||
@@ -683,6 +683,59 @@ py::object _split_cpp( | |||||
return py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size())); | return py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size())); | ||||
} | } | ||||
py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) { | |||||
std::vector<int32_t> axis; | |||||
if (is_py_sequence(axis_hdl.ptr())) { | |||||
py::list tmp_list = | |||||
py::reinterpret_steal<py::list>(PySequence_List(axis_hdl.ptr())); | |||||
for (size_t i = 0; i < tmp_list.size(); ++i) { | |||||
axis.push_back(tmp_list[i].attr("__int__")().cast<int32_t>()); | |||||
} | |||||
} else { | |||||
axis.push_back(getattr(axis_hdl, "__int__")().cast<int>()); | |||||
} | |||||
bool unknown_ndim = true; | |||||
size_t ndim = axis.size(); | |||||
if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) { | |||||
auto&& shape = p->m_tensor->shape(); | |||||
if (shape) { | |||||
unknown_ndim = false; | |||||
ndim += shape->ndim; | |||||
} | |||||
} else { | |||||
auto&& var = inp_hdl.cast<PySymbolVar*>(); | |||||
auto&& mgr = var->m_node->owner_graph()->static_infer_manager(); | |||||
auto&& shape = mgr.infer_shape_fallible(var->m_node); | |||||
if (shape) { | |||||
unknown_ndim = false; | |||||
ndim += shape->ndim; | |||||
} | |||||
} | |||||
for (size_t i = 0; i < axis.size(); ++i) { | |||||
if (axis[i] < 0) { | |||||
if (unknown_ndim) { | |||||
throw py::index_error( | |||||
"Does not support negative index when tensor's ndim is " | |||||
"unknown"); | |||||
} | |||||
axis[i] += ndim; | |||||
} | |||||
} | |||||
if (!axis.size()) { | |||||
throw py::index_error("axis could not be empty"); | |||||
} | |||||
std::sort(axis.begin(), axis.end()); | |||||
std::shared_ptr<OpDef> op = AddAxis::make(axis = axis); | |||||
std::vector<PyObject*> p; | |||||
p.resize(2); | |||||
py::object Op = py::cast(op); | |||||
p[0] = Op.ptr(); | |||||
p[1] = inp_hdl.ptr(); | |||||
py::tuple ret = | |||||
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size())); | |||||
return ret[0]; | |||||
} | |||||
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { | PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { | ||||
try { | try { | ||||
return _make_shape_tuple(py::handle(args[0])).release().ptr(); | return _make_shape_tuple(py::handle(args[0])).release().ptr(); | ||||
@@ -716,4 +769,13 @@ PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
PYEXT17_TRANSLATE_EXC_RET(nullptr) | PYEXT17_TRANSLATE_EXC_RET(nullptr) | ||||
} | } | ||||
PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
try { | |||||
return _expand_dims_cpp(py::handle(args[0]), py::handle(args[1])) | |||||
.release() | |||||
.ptr(); | |||||
} | |||||
PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
} | |||||
} // namespace mgb::imperative::python | } // namespace mgb::imperative::python |
@@ -10,4 +10,6 @@ PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs); | PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs); | ||||
PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
} // namespace mgb::imperative::python | } // namespace mgb::imperative::python |