|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845 |
- /**
- * \file imperative/python/src/tensor.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
-
- #include "megbrain/common.h"
- #include "megbrain/dtype.h"
- #include "megbrain/imperative/ops/autogen.h"
- #include "megbrain/imperative/ops/backward_graph.h"
- #include "megbrain/imperative/ops/utility.h"
- #include "megbrain/imperative/profiler.h"
- #include "megbrain/imperative/transformations/eval.h"
- #include "megbrain/imperative/transformations/lazy.h"
- #include "megbrain/imperative/transformations/scalar.h"
- #include "megbrain/imperative/transformations/symbol.h"
- #include "megbrain/imperative/transformations/trace.h"
- #include "megbrain/imperative/utils/map.h"
- #include "megbrain/imperative/utils/stats.h"
- #include "megbrain/opr/io.h"
- #include "megbrain/plugin/profiler.h"
-
- #include "./common.h"
- #include "./grad.h"
- #include "./graph_rt.h"
- #include "./helper.h"
- #include "./module_trace.h"
- #include "./numpy_dtypes.h"
- #include "./tensor.h"
- #include "./tensor_utils.h"
- #include "./transformation.h"
-
- #include <object.h>
- #include <pybind11/numpy.h>
- #include <pybind11/operators.h>
- #include <pybind11/pytypes.h>
- #include <pyerrors.h>
- #include <range/v3/all.hpp>
- #include <string>
-
- #include <unordered_map>
-
- #include "../../src/impl/mgb_cg_impl.h"
-
- namespace py = pybind11;
- namespace views = ranges::views;
-
- namespace mgb::imperative::python {
-
- bool is_scalar(PyObject* tensor) {
- if (py::isinstance<PySymbolVar>(py::handle(tensor))) {
- auto var = py::handle(tensor).cast<PySymbolVar*>();
- return var->is_scalar;
- }
- auto* tw = TensorWrapper::try_cast(tensor);
- if (tw) {
- return tw->m_tensor->is_scalar();
- }
- return PyArray_CheckAnyScalar(tensor);
- }
-
- bool is_bool_list(PyObject* arg) {
- if (!PyList_Check(arg)) {
- return false;
- }
- size_t sz = PyList_Size(arg);
- if (!sz) {
- return false;
- }
- for (size_t i = 0; i < sz; ++i) {
- PyObject* handle = PyList_GetItem(arg, i);
- if (!PyBool_Check(handle)) {
- return false;
- }
- }
- return true;
- }
-
- bool is_bool_dtype(PyObject* args) {
- if (!PyObject_HasAttrString(args, "dtype"))
- return false;
- PyObject* dobj = PyObject_GetAttrString(args, "dtype");
- PyArray_Descr* dtype;
- PyArray_DescrConverter(dobj, &dtype);
- bool ret = (dtype->kind == 'b');
- Py_XDECREF(dtype);
- Py_XDECREF(dobj);
- return ret;
- }
-
// Build a constant from an arbitrary python `value`.
// `dtype`/`device` select the target DType and CompNode.  `ref` decides the
// flavor: if it is a PySymbolVar the constant becomes an ImmutableTensor node
// in ref's graph (wrapped in ref's python type); otherwise a TensorWrapper
// tensor is created.
py::object _Const(
        py::handle value, py::handle dtype, py::handle device, py::handle ref) {
    py::object val = py::reinterpret_borrow<py::object>(value);
    if (PyArray_Check(value.ptr())) {
        // A zero stride means elements alias the same memory (e.g. the result
        // of numpy broadcasting); squeeze+reshape presumably materializes a
        // compact copy before handing it to np2tensor — TODO confirm.
        py::tuple strides =
                py::reinterpret_borrow<py::tuple>(getattr(value, "strides"));
        bool need_squeeze = false;
        for (size_t i = 0; i < strides.size(); ++i) {
            if (strides[i].cast<ptrdiff_t>() == 0) {
                need_squeeze = true;
            }
        }
        if (need_squeeze) {
            val = py::reinterpret_borrow<py::array>(value);
            val = val.attr("squeeze")();
            val = val.attr("reshape")(val.attr("shape"));
        }
    }
    if (py::isinstance<PySymbolVar>(ref)) {
        // Symbolic path: create the ImmutableTensor in the graph owning `ref`.
        auto ref_var = ref.cast<PySymbolVar*>();
        auto* graph = ref_var->m_node->owner_graph();
        auto cn = device.cast<CompNode>();
        OperatorNodeConfig config(cn);
        auto hv = npy::np2tensor(
                val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
        auto typeobj = ref.get_type();
        return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
    }
    // Imperative path: delegate to TensorWrapper's constructor tuple
    // (value, dtype, device, ..., py::none()) — exact flag meanings defined
    // by TensorWrapper::make; the two bools look like is_const/no_cache —
    // TODO confirm against TensorWrapper's ctor.
    py::tuple tup = py::make_tuple(val, dtype, device, true, false, py::none());
    return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
}
-
// Flatten a "shape-like" python object (tensor / SymbolVar / numpy array /
// list / tuple / scalar) into a flat python tuple of plain ints.
py::tuple _make_shape_tuple(py::handle shape) {
    py::list orig;
    py::list ret(0);
    // Append the integer value(s) of a single element to `ret`.
    auto solve_one = [&](py::handle val) {
        if (TensorWrapper::try_cast(val.ptr()) || py::isinstance<PySymbolVar>(val)) {
            // Tensor-valued element: fetch its numpy value, then flatten.
            py::object np = getattr(val, "numpy")();
            PyArrayObject* arr = (PyArrayObject*)np.ptr();
            PyObject* maybe_list = PyArray_ToList(arr);  // new reference
            if (PyList_Check(maybe_list)) {
                // Multi-element array: splice each entry into ret.
                // reinterpret_steal takes over the new reference.
                py::list may = py::reinterpret_steal<py::list>(maybe_list);
                for (size_t i = 0; i < may.size(); ++i) {
                    ret.append(may[i]);
                }
            } else {
                // 0-d array: PyArray_ToList returned a python int.
                mgb_assert(PyLong_Check(maybe_list));
                ret.append(PyLong_AsLong(maybe_list));
                Py_XDECREF(maybe_list);  // drop the new reference ourselves
            }
        } else if (PyArray_Check(val.ptr())) {
            ret.append(PyArray_PyIntAsInt(val.ptr()));
        } else {
            ret.append(PyLong_AsLong(val.ptr()));
        }
    };
    if (PyArray_Check(shape.ptr()) && !PyArray_CheckAnyScalar(shape.ptr())) {
        // numpy array of dims (but not a numpy scalar).
        orig = py::reinterpret_steal<py::list>(
                PyArray_ToList((PyArrayObject*)shape.ptr()));
        for (size_t i = 0; i < orig.size(); ++i) {
            solve_one(orig[i]);
        }
    } else if (PyList_Check(shape.ptr())) {
        orig = py::reinterpret_borrow<py::list>(shape);
        for (size_t i = 0; i < orig.size(); ++i) {
            solve_one(orig[i]);
        }
    } else if (PyTuple_Check(shape.ptr())) {
        py::tuple tup = py::reinterpret_borrow<py::tuple>(shape);
        for (size_t i = 0; i < tup.size(); ++i) {
            solve_one(tup[i]);
        }
    } else {
        // Single scalar-like entry.
        solve_one(shape);
    }
    return py::reinterpret_steal<py::tuple>(PyList_AsTuple(ret.ptr()));
}
-
- py::object _get_index(py::object tensor, py::object src) {
- if (!TensorWrapper::try_cast(tensor.ptr()) &&
- !py::isinstance<PySymbolVar>(tensor)) {
- auto get_const = [&](mgb::DType dtype) -> py::object {
- return _Const(tensor, py::cast(dtype), src.attr("device"), src);
- };
- if (is_bool_list(tensor.ptr()) || is_bool_dtype(tensor.ptr())) {
- tensor = get_const(dtype::Bool());
- } else {
- tensor = get_const(dtype::Int32());
- }
- if (!is_bool_dtype(tensor.ptr())) {
- return tensor;
- }
- } else {
- if (!is_bool_dtype(tensor.ptr())) {
- return tensor;
- }
- }
- static std::shared_ptr<OpDef> op = CondTake::make();
- std::vector<PyObject*> p;
- p.resize(3);
- py::object Op = py::cast(op);
- p[0] = Op.ptr();
- p[1] = tensor.ptr();
- p[2] = tensor.ptr();
- py::tuple ret =
- py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
- return ret[1];
- }
-
- py::tuple _try_cond_take(py::handle tensor, py::handle index) {
- if (!hasattr(index, "dtype") || !hasattr(index, "shape")) {
- return py::tuple();
- }
- if (!is_bool_dtype(index.ptr()) ||
- _make_shape_tuple(getattr(index, "shape"))
- .not_equal(_make_shape_tuple(getattr(tensor, "shape")))) {
- return py::tuple();
- }
- py::object iobj;
- if (PyArray_Check(index.ptr())) {
- iobj =
- _Const(index, py::cast((mgb::DType)dtype::Bool()),
- getattr(tensor, "device"), tensor);
- } else {
- iobj = py::reinterpret_borrow<py::object>(index);
- }
- static std::shared_ptr<OpDef> op = CondTake::make();
- std::vector<PyObject*> p;
- p.resize(3);
- py::object Op = py::cast(op);
- p[0] = Op.ptr();
- p[1] = tensor.ptr();
- p[2] = iobj.ptr();
- py::tuple ret =
- py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
- return ret;
- }
-
// Expand a single `...` inside a subscript tuple into the right number of
// full slices (`:`) so the tuple addresses every dimension of `tensor`.
// Throws index_error for a second ellipsis, for bool indices of unknown
// ndim, or when tensor's ndim itself is unknown.
py::tuple _remove_ellipsis(py::object tensor, py::tuple tuple_val) {
    size_t tuple_size = tuple_val.size();
    // ndim_sum: ndim of tensor; cur_sum: dims consumed by non-ellipsis items.
    size_t ndim_sum = 0, cur_sum = 0;
    int pos = -1;  // position of the ellipsis, -1 if absent
    bool has_unknown_ndim_bool_index = false;
    for (size_t i = 0; i < tuple_size; ++i) {
        py::object handle = tuple_val[i];
        if (handle.ptr() == Py_Ellipsis) {
            pos = static_cast<int>(i);
            // Reject any earlier ellipsis.
            for (size_t j = 0; j < i; ++j) {
                py::object t = tuple_val[j];
                if (t.ptr() == Py_Ellipsis) {
                    throw py::index_error("only one ellipsis is allowed.");
                }
            }
        } else {
            // A bool mask consumes as many dims as its own ndim; every other
            // index kind consumes exactly one.
            size_t ndim_incr = 1;
            if (hasattr(handle, "dtype") && is_bool_dtype(handle.ptr()) &&
                hasattr(handle, "ndim")) {
                py::object ndim = getattr(handle, "ndim");
                if (PyLong_Check(ndim.ptr())) {
                    ndim_incr = PyLong_AsLong(ndim.ptr());
                } else {
                    has_unknown_ndim_bool_index = true;
                }
            }
            cur_sum += ndim_incr;
        }
    }
    if (pos == -1) {
        // No ellipsis: nothing to expand.
        return tuple_val;
    } else {
        if (has_unknown_ndim_bool_index) {
            throw py::index_error(
                    "does not support bool index with unknown shape when using "
                    "Ellipsis.");
        }
        try {
            ndim_sum = getattr(tensor, "ndim").cast<size_t>();
        } catch (py::error_already_set& err) {
            throw py::index_error(
                    "does not support Ellipsis when tensor's ndim is unknown.");
        }
        // Result: all original items, with the one ellipsis replaced by
        // (ndim_sum - cur_sum) full slices.
        py::tuple ret(ndim_sum - cur_sum + tuple_size - 1);
        size_t idx = 0;
        for (size_t i = 0; i < tuple_size; ++i) {
            if (i == pos) {
                for (size_t j = cur_sum; j < ndim_sum; ++j) {
                    ret[idx++] = PySlice_New(NULL, NULL, NULL);
                }
            } else {
                ret[idx++] = tuple_val[i];
            }
        }
        return ret;
    }
}
-
// Handle multi-dimensional boolean indices: each bool mask of ndim > 1 is
// flattened to 1-d, and the tensor itself is reshaped so the dims covered by
// the mask collapse into one.  Returns (possibly-reshaped tensor, new index
// tuple).  Throws index_error when a mask's shape does not match the tensor.
py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
    py::tuple cur_shape = _make_shape_tuple(py::handle(getattr(tensor, "shape")));
    py::list new_tuple_val(0);

    // offset: how many dims have been collapsed so far;
    // tdim: tensor dim the current index item refers to.
    size_t offset = 0;
    size_t tdim = 0;
    for (size_t i = 0; i < tuple_val.size(); ++i) {
        py::handle k = tuple_val[i];
        if (is_bool_dtype(k.ptr())) {
            size_t ndim = getattr(k, "ndim").cast<size_t>();
            if (ndim > 1) {
                // Validate that each mask dim matches the tensor dim it covers.
                py::tuple ishape = _make_shape_tuple(py::handle(getattr(k, "shape")));
                for (size_t j = 0; j < ndim; ++j) {
                    if (cur_shape[tdim + j - offset].cast<size_t>() !=
                        ishape[j].cast<size_t>()) {
                        std::string msg =
                                "boolean index did not match tensor along dimension " +
                                std::to_string(tdim + j) + "; dimension is " +
                                std::to_string(
                                        cur_shape[tdim + j - offset].cast<size_t>()) +
                                " but corresponding boolean dimension is " +
                                std::to_string(ishape[j].cast<size_t>());
                        throw py::index_error(msg.c_str());
                    }
                }
                // Flatten the mask; the tensor's matching dims collapse too.
                py::object new_k = getattr(k, "reshape")(-1);
                py::object kshape = getattr(new_k, "shape");
                py::list new_shape(0);
                // cpp_use_symbolic_shape decides whether shapes are tensors
                // (symbolic/traced mode) or plain ints.
                PyObject* sym = PyObject_CallObject(cpp_use_symbolic_shape, nullptr);
                bool is_sym = (sym == Py_True);
                Py_XDECREF(sym);
                if (is_sym) {
                    // Symbolic shapes: build the new shape from shape tensors
                    // and reshape via an astensor1d-built shape tensor.
                    py::object tshape = getattr(tensor, "shape");
                    for (size_t j = 0; j < i; ++j) {
                        new_shape.append(tshape[py::int_(j)]);
                    }
                    new_shape.append(kshape[py::int_(0)]);
                    for (size_t j = tdim + ndim - offset; j < cur_shape.size(); ++j) {
                        new_shape.append(cur_shape[j]);
                    }
                    py::tuple args = py::make_tuple(new_shape);
                    PyObject* shape_tensor =
                            PyObject_CallObject(cpp_astensor1d, args.ptr());
                    py::object reshape_func = getattr(tensor, "reshape");
                    // Extra INCREF: PyTuple_SetItem steals a reference, and we
                    // still use shape_tensor after Args is released.
                    Py_INCREF(shape_tensor);
                    PyObject* Args = PyTuple_New(1);
                    PyTuple_SetItem(Args, 0, shape_tensor);
                    PyObject* new_tensor =
                            PyObject_CallObject(reshape_func.ptr(), Args);
                    Py_XDECREF(Args);
                    tensor = py::reinterpret_steal<py::object>(new_tensor);
                    cur_shape = _make_shape_tuple(py::handle(shape_tensor));
                    Py_XDECREF(shape_tensor);
                } else {
                    // Concrete shapes: splice the collapsed dim into the shape
                    // and reshape directly.
                    for (size_t j = 0; j < i; ++j) {
                        new_shape.append(cur_shape[j]);
                    }
                    new_shape.append(py::reinterpret_borrow<py::tuple>(kshape)[0]);
                    for (size_t j = tdim + ndim - offset; j < cur_shape.size(); ++j) {
                        new_shape.append(cur_shape[j]);
                    }
                    cur_shape = new_shape;
                    tensor = getattr(tensor, "reshape")(cur_shape);
                }
                offset++;
                tdim += ndim;
            }
            new_tuple_val.append(k);
        } else {
            new_tuple_val.append(k);
            tdim++;
        }
    }
    return py::make_tuple(tensor, py::reinterpret_borrow<py::tuple>(new_tuple_val));
}
-
// Decompose a python subscript into the form the indexing ops consume.
// Returns (tensor, index_tensors, items, use_subtensor, expanded_bool_dim):
//  - tensor: the (possibly reshaped, see _expand_bool_dim) input;
//  - index_tensors: tensor-ified index operands, in item order;
//  - items: per-axis descriptors [axis, has_start, has_stop, has_step, has_index];
//  - use_subtensor: true when every index is a scalar or slice, so the
//    cheaper Subtensor op can be used instead of IndexingMultiAxisVec.
py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
    py::object inp = py::reinterpret_borrow<py::object>(inp_hdl);
    py::tuple tuple_val;
    if (py::isinstance<py::tuple>(idx_hdl)) {
        tuple_val = py::reinterpret_borrow<py::tuple>(idx_hdl);
    } else {
        // Single index: normalize to a 1-tuple.
        tuple_val = py::make_tuple(idx_hdl);
    }

    bool use_subtensor = true;
    bool need_remove_ellipsis = false;
    bool need_expand_bool_dim = false;
    size_t idx_ndim = 0;
    // First pass: classify items and count how many dims they consume.
    for (size_t i = 0; i < tuple_val.size(); ++i) {
        py::object k = tuple_val[i];
        if (k.ptr() == Py_None) {
            throw py::index_error("newaxis is not allowed here");
        } else if (k.ptr() == Py_Ellipsis) {
            need_remove_ellipsis = true;
        } else {
            if (is_bool_dtype(k.ptr()) && hasattr(k, "ndim")) {
                // A bool mask consumes as many dims as its own ndim.
                size_t ndim = getattr(k, "ndim").cast<size_t>();
                idx_ndim += ndim;
                if (ndim > 1) {
                    need_expand_bool_dim = true;
                }
            } else {
                idx_ndim++;
            }
        }
    }
    try {
        size_t inp_ndim = getattr(inp, "ndim").cast<size_t>();
        if (idx_ndim > inp_ndim) {
            std::string msg = "too many indices for tensor: tensor is " +
                              std::to_string(inp_ndim) + "-dimensional, but " +
                              std::to_string(idx_ndim) + " were indexed";
            throw py::index_error(msg.c_str());
        }
    } catch (py::error_already_set& err) {
        ;  // ndim unknown (symbolic): skip the arity check
    }
    if (need_remove_ellipsis) {
        tuple_val = _remove_ellipsis(inp, tuple_val);
    }

    if (need_expand_bool_dim) {
        // Only possible when the shape is known.
        py::object shape = getattr(inp, "shape");
        if (shape.ptr() != Py_None) {
            py::tuple ret = _expand_bool_dim(inp, tuple_val);
            inp = ret[0];
            tuple_val = ret[1];
        }
    }

    py::list items;
    py::list tensors;
    int cur_axis = -1;

    // Second pass: build per-axis descriptors and collect index tensors.
    for (size_t i = 0; i < tuple_val.size(); ++i) {
        py::object handle = tuple_val[i];
        cur_axis++;
        if (!is_scalar(handle.ptr()) && !PySlice_Check(handle.ptr())) {
            // Fancy (array) indexing forces IndexingMultiAxisVec.
            use_subtensor = false;
        }
        py::list item;
        item.append(cur_axis);
        // Append a present/absent flag, tensor-ifying present operands.
        auto push = [&](PyObject* v) {
            if (v == Py_None) {
                item.append(false);
            } else {
                item.append(true);
                tensors.append(_get_index(py::reinterpret_borrow<py::object>(v), inp));
            }
        };

        if (PySlice_Check(handle.ptr())) {
            PySliceObject* s = (PySliceObject*)handle.ptr();
            if (s->start == Py_None && s->stop == Py_None && s->step == Py_None) {
                // Full slice `:` — no-op on this axis.
                continue;
            }
            push(s->start);
            push(s->stop);
            push(s->step);
            item.append(false);  // not a direct index
        } else {
            // Direct index: no start/stop/step, only the index flag.
            for (size_t j = 0; j < 3; j++)
                item.append(false);
            push(handle.ptr());
        }
        items.append(item);
    }

    return py::make_tuple(inp, tensors, items, use_subtensor, need_expand_bool_dim);
}
-
- py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
- py::tuple try_res = _try_cond_take(inp_hdl, idx_hdl);
- if (try_res.size() == 2) {
- return try_res[0];
- }
- py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
- py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
- py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
- py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
- std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
- for (size_t i = 0; i < py_items.size(); ++i) {
- py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
- cpp_items.push_back(
- {item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
- item[3].cast<bool>(), item[4].cast<bool>()});
- }
- static std::shared_ptr<OpDef> op;
- if (up[3].cast<bool>()) {
- op = Subtensor::make(cpp_items);
- } else {
- op = IndexingMultiAxisVec::make(cpp_items);
- }
- std::vector<PyObject*> p;
- p.resize(tensors.size() + 2);
- py::object Op = py::cast(op);
- p[0] = Op.ptr();
- p[1] = tensor.ptr();
- for (size_t i = 0; i < tensors.size(); ++i) {
- p[i + 2] = tensors[i].ptr();
- }
- py::tuple ret =
- py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
- return ret[0];
- }
-
- py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_hdl) {
- py::object org_shape = getattr(inp_hdl, "shape");
- py::object val = py::reinterpret_borrow<py::object>(val_hdl);
- if (!TensorWrapper::try_cast(val.ptr()) && !py::isinstance<PySymbolVar>(val)) {
- val =
- _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"),
- inp_hdl);
- }
-
- py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
- py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
- py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
- py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
- std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
- for (size_t i = 0; i < py_items.size(); ++i) {
- py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
- cpp_items.push_back(
- {item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
- item[3].cast<bool>(), item[4].cast<bool>()});
- }
- static std::shared_ptr<OpDef> op, set_op;
- if (up[3].cast<bool>()) {
- op = Subtensor::make(cpp_items);
- } else {
- op = IndexingMultiAxisVec::make(cpp_items);
- }
- std::vector<PyObject*> p;
- p.resize(tensors.size() + 2);
- py::object Op = py::cast(op);
- p[0] = Op.ptr();
- p[1] = tensor.ptr();
- for (size_t i = 0; i < tensors.size(); ++i) {
- p[i + 2] = tensors[i].ptr();
- }
- py::tuple ret =
- py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
- py::object tmp_result = ret[0];
-
- try {
- py::object value_tuple_shape = val.attr("_tuple_shape");
- py::object tmp_result_tuple_shape = tmp_result.attr("_tuple_shape");
- py::tuple value_shape = py::reinterpret_borrow<py::tuple>(value_tuple_shape);
- py::tuple tmp_result_shape =
- py::reinterpret_borrow<py::tuple>(tmp_result_tuple_shape);
- for (size_t i = 0; i < value_shape.size() && i < tmp_result_shape.size(); ++i) {
- size_t vs = value_shape[value_shape.size() - i - 1].cast<size_t>();
- size_t ts =
- tmp_result_shape[tmp_result_shape.size() - i - 1].cast<size_t>();
- if (vs != 1 && vs != ts) {
- std::string lhs = "", rhs = "";
- for (size_t j = 0; j < tmp_result_shape.size(); ++j) {
- lhs += std::to_string(tmp_result_shape[j].cast<size_t>());
- if (j)
- lhs += ",";
- }
- for (size_t j = 0; j < value_shape.size(); ++j) {
- rhs += std::to_string(value_shape[j].cast<size_t>());
- if (j)
- rhs += ",";
- }
- throw py::value_error(
- "cannot copy tensor with shape (" + rhs +
- ") to subtensor with shape (" + lhs + ")");
- }
- }
- } catch (py::error_already_set& err) {
- ;
- }
-
- py::object broadcast_func = getattr(val, "_broadcast");
- PyObject* Args = PyTuple_New(1);
- PyTuple_SetItem(Args, 0, getattr(tmp_result, "shape").release().ptr());
- PyObject* new_val = PyObject_CallObject(broadcast_func.ptr(), Args);
- Py_XDECREF(Args);
- val = py::reinterpret_steal<py::object>(new_val);
-
- if (up[3].cast<bool>()) {
- set_op = SetSubtensor::make(cpp_items);
- } else {
- set_op = IndexingSetMultiAxisVec::make(cpp_items);
- }
-
- std::vector<PyObject*> q;
- q.resize(tensors.size() + 3);
- py::object Set_Op = py::cast(set_op);
- q[0] = Set_Op.ptr();
- q[1] = tensor.ptr();
- q[2] = val.ptr();
- for (size_t i = 0; i < tensors.size(); ++i) {
- q[i + 3] = tensors[i].ptr();
- }
- py::tuple result =
- py::reinterpret_steal<py::object>(py_apply(NULL, q.data(), q.size()));
- py::object res = result[0];
-
- if (up[4].cast<bool>()) {
- py::object reshape_func = getattr(res, "reshape");
- PyObject* Args = PyTuple_New(1);
- PyTuple_SetItem(Args, 0, org_shape.release().ptr());
- PyObject* new_tensor = PyObject_CallObject(reshape_func.ptr(), Args);
- Py_XDECREF(Args);
- res = py::reinterpret_steal<py::object>(new_tensor);
- }
-
- return res;
- }
-
- bool is_tensor_or_symbolvar(py::handle arg) {
- return bool(TensorWrapper::try_cast(arg.ptr())) || py::isinstance<PySymbolVar>(arg);
- }
-
- bool is_py_sequence(py::handle arg) {
- if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr()) ||
- py::isinstance<PySymbolVar>(arg)) {
- return false;
- }
- return PySequence_Check(arg.ptr());
- }
-
- py::object _split_cpp(
- py::handle inp_hdl, py::handle nsplits_or_sections_hdl, py::handle axis_hdl) {
- py::object shape_obj = getattr(inp_hdl, "shape");
- py::object n_total = shape_obj[axis_hdl];
- int ndim = shape_obj.attr("__len__")().cast<int>();
- int axis = axis_hdl.cast<int>();
- if (axis >= ndim) {
- throw py::value_error("Invalid axis " + std::to_string(axis));
- }
- int n_sections;
- bool is_array;
- if (is_py_sequence(nsplits_or_sections_hdl)) {
- n_sections = PySequence_Length(nsplits_or_sections_hdl.ptr()) + 1;
- is_array = true;
- } else {
- n_sections = getattr(nsplits_or_sections_hdl, "__int__")().cast<int>();
- is_array = false;
- }
- py::list partitions;
- std::shared_ptr<OpDef> op;
- std::vector<PyObject*> p;
- if (is_array) {
- py::list div_points;
- py::list sections = py::reinterpret_borrow<py::object>(nsplits_or_sections_hdl);
- div_points.append(0);
- for (size_t i = 0; i < sections.size(); ++i) {
- div_points.append(sections[i]);
- }
- div_points.append(n_total);
- for (size_t i = 1; i < div_points.size(); ++i) {
- if (div_points[i - 1] > div_points[i]) {
- throw py::value_error(
- "Invalid nsplits_or_secions: " +
- repr(nsplits_or_sections_hdl).cast<std::string>());
- }
- py::object pos = div_points[i] - div_points[i - 1];
- if (is_tensor_or_symbolvar(pos)) {
- partitions.append(pos);
- } else {
- partitions.append(
- _Const(pos, py::cast((mgb::DType)dtype::Int32()),
- getattr(inp_hdl, "device"), inp_hdl));
- }
- }
- op = Split::make(axis, 0);
- p.resize(partitions.size() + 2);
- for (size_t i = 0; i < partitions.size(); ++i) {
- p[i + 2] = partitions[i].ptr();
- }
- } else {
- if (n_sections <= 0) {
- throw py::value_error("Number sections must be larger than 0");
- }
- if (py::int_(n_sections) > n_total) {
- throw py::value_error(
- "The size " + repr(n_total).cast<std::string>() + " at dim " +
- std::to_string(axis) + " cannot be split into " +
- std::to_string(n_sections) + " sections");
- }
- op = Split::make(axis, n_sections);
- p.resize(2);
- }
- py::object Op = py::cast(op);
- p[0] = Op.ptr();
- p[1] = inp_hdl.ptr();
- return py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
- }
-
- std::vector<int32_t> list2vector(py::handle li) {
- std::vector<int32_t> axis;
- if (is_py_sequence(li.ptr())) {
- py::list tmp_list = py::reinterpret_steal<py::list>(PySequence_List(li.ptr()));
- for (size_t i = 0; i < tmp_list.size(); ++i) {
- axis.push_back(tmp_list[i].attr("__int__")().cast<int32_t>());
- }
- } else {
- axis.push_back(getattr(li, "__int__")().cast<int32_t>());
- }
- return axis;
- }
-
- py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
- std::vector<int32_t> axis = list2vector(axis_hdl);
- bool unknown_ndim = true;
- size_t ndim = axis.size();
- if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) {
- auto&& shape = p->m_tensor->shape();
- if (shape) {
- unknown_ndim = false;
- ndim += shape->ndim;
- }
- } else {
- auto&& var = inp_hdl.cast<PySymbolVar*>();
- auto&& mgr = var->m_node->owner_graph()->static_infer_manager();
- auto&& shape = mgr.infer_shape_fallible(var->m_node);
- if (shape) {
- unknown_ndim = false;
- ndim += shape->ndim;
- }
- }
- for (size_t i = 0; i < axis.size(); ++i) {
- if (axis[i] < 0) {
- if (unknown_ndim) {
- throw py::index_error(
- "Does not support negative index when tensor's ndim is "
- "unknown");
- }
- axis[i] += static_cast<int32_t>(ndim);
- }
- }
- if (!axis.size()) {
- throw py::index_error("axis could not be empty");
- }
- std::sort(axis.begin(), axis.end());
- std::shared_ptr<OpDef> op = AddAxis::make(axis = axis);
- std::vector<PyObject*> p;
- p.resize(2);
- py::object Op = py::cast(op);
- p[0] = Op.ptr();
- p[1] = inp_hdl.ptr();
- py::tuple ret =
- py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
- return ret[0];
- }
-
- py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
- std::vector<int32_t> axis;
- size_t ndim;
- if (axis_hdl.ptr() != Py_None) {
- axis = list2vector(axis_hdl);
- }
- if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) {
- auto&& shape = p->m_tensor->shape();
- if (shape) {
- ndim = shape->ndim;
- if (axis_hdl.ptr() == Py_None) {
- for (size_t i = 0; i < shape->ndim; ++i) {
- if (shape->shape[i] == 1) {
- axis.push_back(i);
- }
- }
- }
- }
- } else {
- auto&& var = inp_hdl.cast<PySymbolVar*>();
- auto&& mgr = var->m_node->owner_graph()->static_infer_manager();
- auto&& shape = mgr.infer_shape_fallible(var->m_node);
- if (shape) {
- ndim = shape->ndim;
- if (axis_hdl.ptr() == Py_None) {
- for (size_t i = 0; i < shape->ndim; ++i) {
- if (shape->shape[i] == 1) {
- axis.push_back(i);
- }
- }
- }
- }
- }
- for (size_t i = 0; i < axis.size(); ++i) {
- if (axis[i] < 0) {
- axis[i] += static_cast<int32_t>(ndim);
- }
- }
- std::sort(axis.begin(), axis.end());
- for (size_t i = 0; i < axis.size(); ++i) {
- axis[i] -= static_cast<int32_t>(i);
- }
- std::shared_ptr<OpDef> op = RemoveAxis::make(axis = axis);
- std::vector<PyObject*> p;
- p.resize(2);
- py::object Op = py::cast(op);
- p[0] = Op.ptr();
- p[1] = inp_hdl.ptr();
- py::tuple ret =
- py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
- return ret[0];
- }
-
// Vectorcall entry point: shape-like python object -> flat tuple of ints.
// C++ exceptions are translated to python exceptions by the macro.
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _make_shape_tuple(py::handle(args[0])).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
-
// Vectorcall entry point for Tensor.__getitem__ (args: tensor, index).
PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _getitem_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
-
// Vectorcall entry point for Tensor.__setitem__ (args: tensor, index, value).
PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _setitem_cpp(
                       py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
                .release()
                .ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
-
// Vectorcall entry point for split (args: tensor, nsplits_or_sections, axis).
PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _split_cpp(py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
                .release()
                .ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
-
// Vectorcall entry point for expand_dims (args: tensor, axis).
PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _expand_dims_cpp(py::handle(args[0]), py::handle(args[1]))
                .release()
                .ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
-
// Vectorcall entry point for squeeze (args: tensor, axis-or-None).
PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _squeeze_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
-
- } // namespace mgb::imperative::python
|