|
|
@@ -793,6 +793,57 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) { |
|
|
|
return ret[0]; |
|
|
|
} |
|
|
|
|
|
|
|
size_t fast_ndim(py::handle tensor) { |
|
|
|
if (auto p = TensorWrapper::try_cast(tensor.ptr())) { |
|
|
|
return p->m_tensor->shape()->ndim; |
|
|
|
} |
|
|
|
return getattr(tensor, "ndim").cast<size_t>(); |
|
|
|
} |
|
|
|
|
|
|
|
py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { |
|
|
|
py::tuple args_tup = py::reinterpret_borrow<py::tuple>(args.ptr()); |
|
|
|
if (fast_ndim(inp_hdl) == 0) { |
|
|
|
if (args_tup.size() != 0) { |
|
|
|
throw py::index_error( |
|
|
|
"transpose for scalar does not accept additional args"); |
|
|
|
} |
|
|
|
return getattr(inp_hdl, "to")(getattr(inp_hdl, "device")); |
|
|
|
} |
|
|
|
std::vector<int32_t> pattern; |
|
|
|
if (!args_tup.size()) { |
|
|
|
size_t ndim = getattr(inp_hdl, "ndim").cast<size_t>(); |
|
|
|
for (size_t i = 0; i < ndim; ++i) { |
|
|
|
pattern.push_back(ndim - i - 1); |
|
|
|
} |
|
|
|
} else { |
|
|
|
py::list lis; |
|
|
|
if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) || |
|
|
|
is_tensor_or_symbolvar(args_tup[0].ptr()))) { |
|
|
|
lis = py::reinterpret_steal<py::list>(PySequence_List(args_tup[0].ptr())); |
|
|
|
} else { |
|
|
|
lis = py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr())); |
|
|
|
} |
|
|
|
for (size_t i = 0; i < lis.size(); ++i) { |
|
|
|
if (PyLong_Check(lis[i].ptr())) { |
|
|
|
pattern.push_back(lis[i].cast<int32_t>()); |
|
|
|
} else { |
|
|
|
if (lis[i].cast<std::string>() == "x") { |
|
|
|
pattern.push_back(-1); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
std::shared_ptr<OpDef> op = Dimshuffle::make(pattern); |
|
|
|
std::vector<PyObject*> p; |
|
|
|
p.resize(2); |
|
|
|
py::object Op = py::cast(op); |
|
|
|
p[0] = Op.ptr(); |
|
|
|
p[1] = inp_hdl.ptr(); |
|
|
|
py::tuple ret = |
|
|
|
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size())); |
|
|
|
return ret[0]; |
|
|
|
} |
|
|
|
|
|
|
|
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { |
|
|
|
try { |
|
|
|
return _make_shape_tuple(py::handle(args[0])).release().ptr(); |
|
|
@@ -842,4 +893,11 @@ PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs) { |
|
|
|
PYEXT17_TRANSLATE_EXC_RET(nullptr) |
|
|
|
} |
|
|
|
|
|
|
|
PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs) { |
|
|
|
try { |
|
|
|
return _transpose_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr(); |
|
|
|
} |
|
|
|
PYEXT17_TRANSLATE_EXC_RET(nullptr) |
|
|
|
} |
|
|
|
|
|
|
|
} // namespace mgb::imperative::python |