
tensor_utils.cpp 31 kB

/**
 * \file imperative/python/src/tensor_utils.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/common.h"
#include "megbrain/dtype.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/transformations/trace.h"
#include "megbrain/imperative/utils/map.h"
#include "megbrain/imperative/utils/stats.h"
#include "megbrain/opr/io.h"
#include "megbrain/plugin/profiler.h"

#include "./common.h"
#include "./grad.h"
#include "./graph_rt.h"
#include "./helper.h"
#include "./module_trace.h"
#include "./numpy_dtypes.h"
#include "./tensor.h"
#include "./tensor_utils.h"
#include "./transformation.h"

#include <object.h>
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <pybind11/pytypes.h>
#include <pyerrors.h>

#include <range/v3/all.hpp>

#include <string>
#include <unordered_map>

#include "../../src/impl/mgb_cg_impl.h"

namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {
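
// Returns true if `tensor` is a 0-dim value: a PySymbolVar marked scalar, a
// TensorWrapper whose underlying tensor is scalar, or a numpy scalar object.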
bool is_scalar(PyObject* tensor) {
    if (py::isinstance<PySymbolVar>(py::handle(tensor))) {
        auto var = py::handle(tensor).cast<PySymbolVar*>();
        return var->is_scalar;
    }
    auto* tw = TensorWrapper::try_cast(tensor);
    if (tw) {
        return tw->m_tensor->is_scalar();
    }
    return PyArray_CheckAnyScalar(tensor);
}
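
// Returns true only for a non-empty Python list whose elements are all bools;
// such a list is treated as a boolean mask index rather than integer indices.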
bool is_bool_list(PyObject* arg) {
    if (!PyList_Check(arg)) {
        return false;
    }
    size_t sz = PyList_Size(arg);
    if (!sz) {
        return false;
    }
    for (size_t i = 0; i < sz; ++i) {
        PyObject* handle = PyList_GetItem(arg, i);
        if (!PyBool_Check(handle)) {
            return false;
        }
    }
    return true;
}
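
// Checks whether `args` exposes a `dtype` attribute that numpy resolves to a
// boolean descriptor (dtype kind 'b').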
bool is_bool_dtype(PyObject* args) {
    if (!PyObject_HasAttrString(args, "dtype"))
        return false;
    PyObject* dobj = PyObject_GetAttrString(args, "dtype");
    PyArray_Descr* dtype = nullptr;
    PyArray_DescrConverter(dobj, &dtype);
    // Guard against converter failure leaving `dtype` unset: treat an
    // unconvertible dtype as "not bool" instead of dereferencing an
    // uninitialized pointer, and drop the pending TypeError.
    bool ret = dtype && (dtype->kind == 'b');
    if (!dtype) {
        PyErr_Clear();
    }
    Py_XDECREF(dtype);
    Py_XDECREF(dobj);
    return ret;
}
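
// Builds a constant tensor from a Python value on the given dtype/device. If
// `ref` is a PySymbolVar, the constant becomes an ImmutableTensor node in the
// same computing graph; otherwise a regular TensorWrapper is created. Numpy
// inputs containing a zero stride are squeezed first (apparently to normalize
// broadcasted arrays before conversion).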
py::object _Const(
        py::handle value, py::handle dtype, py::handle device, py::handle ref) {
    py::object val = py::reinterpret_borrow<py::object>(value);
    if (PyArray_Check(value.ptr())) {
        py::tuple strides =
                py::reinterpret_borrow<py::tuple>(getattr(value, "strides"));
        bool need_squeeze = false;
        for (size_t i = 0; i < strides.size(); ++i) {
            if (strides[i].cast<ptrdiff_t>() == 0) {
                need_squeeze = true;
            }
        }
        if (need_squeeze) {
            val = py::reinterpret_borrow<py::array>(value);
            val = val.attr("squeeze")();
            val = val.attr("reshape")(val.attr("shape"));
        }
    }
    if (py::isinstance<PySymbolVar>(ref)) {
        auto ref_var = ref.cast<PySymbolVar*>();
        auto* graph = ref_var->m_node->owner_graph();
        auto cn = device.cast<CompNode>();
        OperatorNodeConfig config(cn);
        auto hv = npy::np2tensor(
                val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
        auto typeobj = ref.get_type();
        return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
    }
    py::tuple tup = py::make_tuple(val, dtype, device, true, false, py::none());
    return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
}
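
// Normalizes a shape-like object (tensor, symbol var, numpy array, list,
// tuple or scalar) into a flat py::tuple of Python ints.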
py::tuple _make_shape_tuple(py::handle shape) {
    py::list orig;
    py::list ret(0);
    auto solve_one = [&](py::handle val) {
        if (TensorWrapper::try_cast(val.ptr()) || py::isinstance<PySymbolVar>(val)) {
            py::object np = getattr(val, "numpy")();
            PyArrayObject* arr = (PyArrayObject*)np.ptr();
            PyObject* maybe_list = PyArray_ToList(arr);
            if (PyList_Check(maybe_list)) {
                py::list may = py::reinterpret_steal<py::list>(maybe_list);
                for (size_t i = 0; i < may.size(); ++i) {
                    ret.append(may[i]);
                }
            } else {
                mgb_assert(PyLong_Check(maybe_list));
                ret.append(PyLong_AsLong(maybe_list));
                Py_XDECREF(maybe_list);
            }
        } else if (PyArray_Check(val.ptr())) {
            ret.append(PyArray_PyIntAsInt(val.ptr()));
        } else {
            ret.append(PyLong_AsLong(val.ptr()));
        }
    };
    if (PyArray_Check(shape.ptr()) && !PyArray_CheckAnyScalar(shape.ptr())) {
        orig = py::reinterpret_steal<py::list>(
                PyArray_ToList((PyArrayObject*)shape.ptr()));
        for (size_t i = 0; i < orig.size(); ++i) {
            solve_one(orig[i]);
        }
    } else if (PyList_Check(shape.ptr())) {
        orig = py::reinterpret_borrow<py::list>(shape);
        for (size_t i = 0; i < orig.size(); ++i) {
            solve_one(orig[i]);
        }
    } else if (PyTuple_Check(shape.ptr())) {
        py::tuple tup = py::reinterpret_borrow<py::tuple>(shape);
        for (size_t i = 0; i < tup.size(); ++i) {
            solve_one(tup[i]);
        }
    } else {
        solve_one(shape);
    }
    return py::reinterpret_steal<py::tuple>(PyList_AsTuple(ret.ptr()));
}
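
// Converts one indexing element into a tensor usable by the indexing ops.
// Non-tensor inputs are first turned into constants (Bool for masks, Int32
// otherwise); boolean masks are then converted to the indices of their true
// entries via CondTake (the second output of CondTake is the index tensor).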
py::object _get_index(py::object tensor, py::object src) {
    if (!TensorWrapper::try_cast(tensor.ptr()) &&
        !py::isinstance<PySymbolVar>(tensor)) {
        auto get_const = [&](mgb::DType dtype) -> py::object {
            return _Const(tensor, py::cast(dtype), src.attr("device"), src);
        };
        if (is_bool_list(tensor.ptr()) || is_bool_dtype(tensor.ptr())) {
            tensor = get_const(dtype::Bool());
        } else {
            tensor = get_const(dtype::Int32());
        }
        if (!is_bool_dtype(tensor.ptr())) {
            return tensor;
        }
    } else {
        if (!is_bool_dtype(tensor.ptr())) {
            return tensor;
        }
    }
    static std::shared_ptr<OpDef> op = CondTake::make();
    std::vector<PyObject*> p;
    p.resize(3);
    py::object Op = py::cast(op);
    p[0] = Op.ptr();
    p[1] = tensor.ptr();
    p[2] = tensor.ptr();
    py::tuple ret =
            py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
    return ret[1];
}
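
// Fast path for `tensor[mask]` where `mask` is a boolean tensor with exactly
// the same shape as `tensor`: the whole indexing expression lowers to a
// single CondTake. Returns an empty tuple when the fast path does not apply.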
py::tuple _try_cond_take(py::handle tensor, py::handle index) {
    if (!hasattr(index, "dtype") || !hasattr(index, "shape")) {
        return py::tuple();
    }
    if (!is_bool_dtype(index.ptr()) ||
        _make_shape_tuple(getattr(index, "shape"))
                .not_equal(_make_shape_tuple(getattr(tensor, "shape")))) {
        return py::tuple();
    }
    py::object iobj;
    if (PyArray_Check(index.ptr())) {
        iobj =
                _Const(index, py::cast((mgb::DType)dtype::Bool()),
                       getattr(tensor, "device"), tensor);
    } else {
        iobj = py::reinterpret_borrow<py::object>(index);
    }
    static std::shared_ptr<OpDef> op = CondTake::make();
    std::vector<PyObject*> p;
    p.resize(3);
    py::object Op = py::cast(op);
    p[0] = Op.ptr();
    p[1] = tensor.ptr();
    p[2] = iobj.ptr();
    py::tuple ret =
            py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
    return ret;
}
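
// Replaces a single Py_Ellipsis in an index tuple with the right number of
// full slices (`:`), so downstream code only sees explicit per-axis items.
// Rejects multiple ellipses and cases where the expansion width cannot be
// computed (unknown tensor ndim, or bool indexes of unknown ndim).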
py::tuple _remove_ellipsis(py::object tensor, py::tuple tuple_val) {
    size_t tuple_size = tuple_val.size();
    size_t ndim_sum = 0, cur_sum = 0;
    int pos = -1;
    bool has_unknown_ndim_bool_index = false;
    for (size_t i = 0; i < tuple_size; ++i) {
        py::object handle = tuple_val[i];
        if (handle.ptr() == Py_Ellipsis) {
            pos = static_cast<int>(i);
            for (size_t j = 0; j < i; ++j) {
                py::object t = tuple_val[j];
                if (t.ptr() == Py_Ellipsis) {
                    throw py::index_error("only one ellipsis is allowed.");
                }
            }
        } else {
            size_t ndim_incr = 1;
            if (hasattr(handle, "dtype") && is_bool_dtype(handle.ptr()) &&
                hasattr(handle, "ndim")) {
                py::object ndim = getattr(handle, "ndim");
                if (PyLong_Check(ndim.ptr())) {
                    ndim_incr = PyLong_AsLong(ndim.ptr());
                } else {
                    has_unknown_ndim_bool_index = true;
                }
            }
            cur_sum += ndim_incr;
        }
    }
    if (pos == -1) {
        return tuple_val;
    } else {
        if (has_unknown_ndim_bool_index) {
            throw py::index_error(
                    "does not support bool index with unknown shape when using "
                    "Ellipsis.");
        }
        try {
            ndim_sum = getattr(tensor, "ndim").cast<size_t>();
        } catch (py::error_already_set& err) {
            throw py::index_error(
                    "does not support Ellipsis when tensor's ndim is unknown.");
        }
        py::tuple ret(ndim_sum - cur_sum + tuple_size - 1);
        size_t idx = 0;
        for (size_t i = 0; i < tuple_size; ++i) {
            if (i == static_cast<size_t>(pos)) {
                for (size_t j = cur_sum; j < ndim_sum; ++j) {
                    ret[idx++] = PySlice_New(NULL, NULL, NULL);
                }
            } else {
                ret[idx++] = tuple_val[i];
            }
        }
        return ret;
    }
}
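
// Handles multi-dimensional boolean mask indexes: the masked dimensions of
// `tensor` are flattened into one axis (via reshape) so that each mask can be
// consumed as a 1-d index. Shape arithmetic is done symbolically when
// cpp_use_symbolic_shape is enabled. Returns (reshaped tensor, new index tuple).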
py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
    py::tuple cur_shape = _make_shape_tuple(py::handle(getattr(tensor, "shape")));
    py::list new_tuple_val(0);
    size_t offset = 0;
    size_t tdim = 0;
    for (size_t i = 0; i < tuple_val.size(); ++i) {
        py::handle k = tuple_val[i];
        if (is_bool_dtype(k.ptr())) {
            size_t ndim = getattr(k, "ndim").cast<size_t>();
            if (ndim > 1) {
                py::tuple ishape = _make_shape_tuple(py::handle(getattr(k, "shape")));
                for (size_t j = 0; j < ndim; ++j) {
                    if (cur_shape[tdim + j - offset].cast<size_t>() !=
                        ishape[j].cast<size_t>()) {
                        std::string msg =
                                "boolean index did not match tensor along dimension " +
                                std::to_string(tdim + j) + "; dimension is " +
                                std::to_string(
                                        cur_shape[tdim + j - offset].cast<size_t>()) +
                                " but corresponding boolean dimension is " +
                                std::to_string(ishape[j].cast<size_t>());
                        throw py::index_error(msg.c_str());
                    }
                }
                py::object new_k = getattr(k, "reshape")(-1);
                py::object kshape = getattr(new_k, "shape");
                py::list new_shape(0);
                PyObject* sym = PyObject_CallObject(cpp_use_symbolic_shape, nullptr);
                bool is_sym = (sym == Py_True);
                Py_XDECREF(sym);
                if (is_sym) {
                    py::object tshape = getattr(tensor, "shape");
                    for (size_t j = 0; j < i; ++j) {
                        new_shape.append(tshape[py::int_(j)]);
                    }
                    new_shape.append(kshape[py::int_(0)]);
                    for (size_t j = tdim + ndim - offset; j < cur_shape.size(); ++j) {
                        new_shape.append(cur_shape[j]);
                    }
                    py::tuple args = py::make_tuple(new_shape);
                    PyObject* shape_tensor =
                            PyObject_CallObject(cpp_astensor1d, args.ptr());
                    py::object reshape_func = getattr(tensor, "reshape");
                    Py_INCREF(shape_tensor);
                    PyObject* Args = PyTuple_New(1);
                    PyTuple_SetItem(Args, 0, shape_tensor);
                    PyObject* new_tensor =
                            PyObject_CallObject(reshape_func.ptr(), Args);
                    Py_XDECREF(Args);
                    tensor = py::reinterpret_steal<py::object>(new_tensor);
                    cur_shape = _make_shape_tuple(py::handle(shape_tensor));
                    Py_XDECREF(shape_tensor);
                } else {
                    for (size_t j = 0; j < i; ++j) {
                        new_shape.append(cur_shape[j]);
                    }
                    new_shape.append(py::reinterpret_borrow<py::tuple>(kshape)[0]);
                    for (size_t j = tdim + ndim - offset; j < cur_shape.size(); ++j) {
                        new_shape.append(cur_shape[j]);
                    }
                    cur_shape = new_shape;
                    tensor = getattr(tensor, "reshape")(cur_shape);
                }
                offset++;
                tdim += ndim;
            }
            new_tuple_val.append(k);
        } else {
            new_tuple_val.append(k);
            tdim++;
        }
    }
    return py::make_tuple(tensor, py::reinterpret_borrow<py::tuple>(new_tuple_val));
}
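
// Normalizes an arbitrary __getitem__/__setitem__ index into a canonical
// form. Returns (tensor, index tensors, per-axis item descriptors,
// use_subtensor, need_expand_bool_dim), where each item descriptor is
// [axis, has_start, has_stop, has_step, is_index]. Subtensor is usable only
// when every item is a scalar or a slice; otherwise IndexingMultiAxisVec
// semantics are required.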
py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
    py::object inp = py::reinterpret_borrow<py::object>(inp_hdl);
    py::tuple tuple_val;
    if (py::isinstance<py::tuple>(idx_hdl)) {
        tuple_val = py::reinterpret_borrow<py::tuple>(idx_hdl);
    } else {
        tuple_val = py::make_tuple(idx_hdl);
    }
    bool use_subtensor = true;
    bool need_remove_ellipsis = false;
    bool need_expand_bool_dim = false;
    size_t idx_ndim = 0;
    for (size_t i = 0; i < tuple_val.size(); ++i) {
        py::object k = tuple_val[i];
        if (k.ptr() == Py_None) {
            throw py::index_error("newaxis is not allowed here");
        } else if (k.ptr() == Py_Ellipsis) {
            need_remove_ellipsis = true;
        } else {
            if (is_bool_dtype(k.ptr()) && hasattr(k, "ndim")) {
                size_t ndim = getattr(k, "ndim").cast<size_t>();
                idx_ndim += ndim;
                if (ndim > 1) {
                    need_expand_bool_dim = true;
                }
            } else {
                idx_ndim++;
            }
        }
    }
    try {
        size_t inp_ndim = getattr(inp, "ndim").cast<size_t>();
        if (idx_ndim > inp_ndim) {
            std::string msg = "too many indices for tensor: tensor is " +
                              std::to_string(inp_ndim) + "-dimensional, but " +
                              std::to_string(idx_ndim) + " were indexed";
            throw py::index_error(msg.c_str());
        }
    } catch (py::error_already_set& err) {
        ;  // ignore
    }
    if (need_remove_ellipsis) {
        tuple_val = _remove_ellipsis(inp, tuple_val);
    }
    if (need_expand_bool_dim) {
        py::object shape = getattr(inp, "shape");
        if (shape.ptr() != Py_None) {
            py::tuple ret = _expand_bool_dim(inp, tuple_val);
            inp = ret[0];
            tuple_val = ret[1];
        }
    }
    py::list items;
    py::list tensors;
    int cur_axis = -1;
    for (size_t i = 0; i < tuple_val.size(); ++i) {
        py::object handle = tuple_val[i];
        cur_axis++;
        if (!is_scalar(handle.ptr()) && !PySlice_Check(handle.ptr())) {
            use_subtensor = false;
        }
        py::list item;
        item.append(cur_axis);
        auto push = [&](PyObject* v) {
            if (v == Py_None) {
                item.append(false);
            } else {
                item.append(true);
                tensors.append(_get_index(py::reinterpret_borrow<py::object>(v), inp));
            }
        };
        if (PySlice_Check(handle.ptr())) {
            PySliceObject* s = (PySliceObject*)handle.ptr();
            if (s->start == Py_None && s->stop == Py_None && s->step == Py_None) {
                continue;
            }
            push(s->start);
            push(s->stop);
            push(s->step);
            item.append(false);
        } else {
            for (size_t j = 0; j < 3; j++)
                item.append(false);
            push(handle.ptr());
        }
        items.append(item);
    }
    return py::make_tuple(inp, tensors, items, use_subtensor, need_expand_bool_dim);
}
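
// C++ implementation of Tensor.__getitem__. Tries the CondTake fast path
// first, then dispatches the unpacked index to either Subtensor (scalars and
// slices only) or IndexingMultiAxisVec.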
py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
    py::tuple try_res = _try_cond_take(inp_hdl, idx_hdl);
    if (try_res.size() == 2) {
        return try_res[0];
    }
    py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
    py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
    py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
    py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
    std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
    for (size_t i = 0; i < py_items.size(); ++i) {
        py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
        cpp_items.push_back(
                {item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
                 item[3].cast<bool>(), item[4].cast<bool>()});
    }
    static std::shared_ptr<OpDef> op;
    if (up[3].cast<bool>()) {
        op = Subtensor::make(cpp_items);
    } else {
        op = IndexingMultiAxisVec::make(cpp_items);
    }
    std::vector<PyObject*> p;
    p.resize(tensors.size() + 2);
    py::object Op = py::cast(op);
    p[0] = Op.ptr();
    p[1] = tensor.ptr();
    for (size_t i = 0; i < tensors.size(); ++i) {
        p[i + 2] = tensors[i].ptr();
    }
    py::tuple ret =
            py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
    return ret[0];
}
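
// C++ implementation of Tensor.__setitem__. Reads the current subtensor to
// learn the destination shape, validates that `val` broadcasts to it, then
// writes back through SetSubtensor or IndexingSetMultiAxisVec. If bool masks
// forced a reshape, the result is reshaped back to the original shape.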
py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_hdl) {
    py::object org_shape = getattr(inp_hdl, "shape");
    py::object val = py::reinterpret_borrow<py::object>(val_hdl);
    if (!TensorWrapper::try_cast(val.ptr()) && !py::isinstance<PySymbolVar>(val)) {
        val =
                _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"),
                       inp_hdl);
    }
    py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
    py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
    py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
    py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
    std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
    for (size_t i = 0; i < py_items.size(); ++i) {
        py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
        cpp_items.push_back(
                {item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
                 item[3].cast<bool>(), item[4].cast<bool>()});
    }
    static std::shared_ptr<OpDef> op, set_op;
    if (up[3].cast<bool>()) {
        op = Subtensor::make(cpp_items);
    } else {
        op = IndexingMultiAxisVec::make(cpp_items);
    }
    std::vector<PyObject*> p;
    p.resize(tensors.size() + 2);
    py::object Op = py::cast(op);
    p[0] = Op.ptr();
    p[1] = tensor.ptr();
    for (size_t i = 0; i < tensors.size(); ++i) {
        p[i + 2] = tensors[i].ptr();
    }
    py::tuple ret =
            py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
    py::object tmp_result = ret[0];
    try {
        py::object value_tuple_shape = val.attr("_tuple_shape");
        py::object tmp_result_tuple_shape = tmp_result.attr("_tuple_shape");
        py::tuple value_shape = py::reinterpret_borrow<py::tuple>(value_tuple_shape);
        py::tuple tmp_result_shape =
                py::reinterpret_borrow<py::tuple>(tmp_result_tuple_shape);
        for (size_t i = 0; i < value_shape.size() && i < tmp_result_shape.size(); ++i) {
            size_t vs = value_shape[value_shape.size() - i - 1].cast<size_t>();
            size_t ts =
                    tmp_result_shape[tmp_result_shape.size() - i - 1].cast<size_t>();
            if (vs != 1 && vs != ts) {
                // Build comma-separated shape strings for the error message;
                // the separator goes between elements, not after each one.
                std::string lhs = "", rhs = "";
                for (size_t j = 0; j < tmp_result_shape.size(); ++j) {
                    if (j)
                        lhs += ",";
                    lhs += std::to_string(tmp_result_shape[j].cast<size_t>());
                }
                for (size_t j = 0; j < value_shape.size(); ++j) {
                    if (j)
                        rhs += ",";
                    rhs += std::to_string(value_shape[j].cast<size_t>());
                }
                throw py::value_error(
                        "cannot copy tensor with shape (" + rhs +
                        ") to subtensor with shape (" + lhs + ")");
            }
        }
    } catch (py::error_already_set& err) {
        ;
    }
    py::object broadcast_func = getattr(val, "_broadcast");
    PyObject* Args = PyTuple_New(1);
    PyTuple_SetItem(Args, 0, getattr(tmp_result, "shape").release().ptr());
    PyObject* new_val = PyObject_CallObject(broadcast_func.ptr(), Args);
    Py_XDECREF(Args);
    val = py::reinterpret_steal<py::object>(new_val);
    if (up[3].cast<bool>()) {
        set_op = SetSubtensor::make(cpp_items);
    } else {
        set_op = IndexingSetMultiAxisVec::make(cpp_items);
    }
    std::vector<PyObject*> q;
    q.resize(tensors.size() + 3);
    py::object Set_Op = py::cast(set_op);
    q[0] = Set_Op.ptr();
    q[1] = tensor.ptr();
    q[2] = val.ptr();
    for (size_t i = 0; i < tensors.size(); ++i) {
        q[i + 3] = tensors[i].ptr();
    }
    py::tuple result =
            py::reinterpret_steal<py::object>(py_apply(NULL, q.data(), q.size()));
    py::object res = result[0];
    if (up[4].cast<bool>()) {
        py::object reshape_func = getattr(res, "reshape");
        PyObject* Args = PyTuple_New(1);
        PyTuple_SetItem(Args, 0, org_shape.release().ptr());
        PyObject* new_tensor = PyObject_CallObject(reshape_func.ptr(), Args);
        Py_XDECREF(Args);
        res = py::reinterpret_steal<py::object>(new_tensor);
    }
    return res;
}
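
// Small type predicates shared by the helpers below: anything that is already
// a Tensor/SymbolVar (or a numpy array) is deliberately not treated as a
// Python sequence.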
bool is_tensor_or_symbolvar(py::handle arg) {
    return bool(TensorWrapper::try_cast(arg.ptr())) || py::isinstance<PySymbolVar>(arg);
}

bool is_py_sequence(py::handle arg) {
    if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr()) ||
        py::isinstance<PySymbolVar>(arg)) {
        return false;
    }
    return PySequence_Check(arg.ptr());
}
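
// C++ implementation of F.split. An integer argument splits the axis into
// that many equal sections; a sequence argument is interpreted as split
// points, which are converted into per-partition sizes for Split.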
py::object _split_cpp(
        py::handle inp_hdl, py::handle nsplits_or_sections_hdl, py::handle axis_hdl) {
    py::object shape_obj = getattr(inp_hdl, "shape");
    py::object n_total = shape_obj[axis_hdl];
    int ndim = shape_obj.attr("__len__")().cast<int>();
    int axis = axis_hdl.cast<int>();
    if (axis >= ndim) {
        throw py::value_error("Invalid axis " + std::to_string(axis));
    }
    int n_sections;
    bool is_array;
    if (is_py_sequence(nsplits_or_sections_hdl)) {
        n_sections = PySequence_Length(nsplits_or_sections_hdl.ptr()) + 1;
        is_array = true;
    } else {
        n_sections = getattr(nsplits_or_sections_hdl, "__int__")().cast<int>();
        is_array = false;
    }
    py::list partitions;
    std::shared_ptr<OpDef> op;
    std::vector<PyObject*> p;
    if (is_array) {
        py::list div_points;
        py::list sections = py::reinterpret_borrow<py::object>(nsplits_or_sections_hdl);
        div_points.append(0);
        for (size_t i = 0; i < sections.size(); ++i) {
            div_points.append(sections[i]);
        }
        div_points.append(n_total);
        for (size_t i = 1; i < div_points.size(); ++i) {
            if (div_points[i - 1] > div_points[i]) {
                throw py::value_error(
                        "Invalid nsplits_or_sections: " +
                        repr(nsplits_or_sections_hdl).cast<std::string>());
            }
            py::object pos = div_points[i] - div_points[i - 1];
            if (is_tensor_or_symbolvar(pos)) {
                partitions.append(pos);
            } else {
                partitions.append(
                        _Const(pos, py::cast((mgb::DType)dtype::Int32()),
                               getattr(inp_hdl, "device"), inp_hdl));
            }
        }
        op = Split::make(axis, 0);
        p.resize(partitions.size() + 2);
        for (size_t i = 0; i < partitions.size(); ++i) {
            p[i + 2] = partitions[i].ptr();
        }
    } else {
        if (n_sections <= 0) {
            throw py::value_error("Number of sections must be larger than 0");
        }
        if (py::int_(n_sections) > n_total) {
            throw py::value_error(
                    "The size " + repr(n_total).cast<std::string>() + " at dim " +
                    std::to_string(axis) + " cannot be split into " +
                    std::to_string(n_sections) + " sections");
        }
        op = Split::make(axis, n_sections);
        p.resize(2);
    }
    py::object Op = py::cast(op);
    p[0] = Op.ptr();
    p[1] = inp_hdl.ptr();
    return py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
}
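
// Coerces an int or a sequence of ints (each via __int__) into a
// std::vector<int32_t> of axis indices.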
std::vector<int32_t> list2vector(py::handle li) {
    std::vector<int32_t> axis;
    if (is_py_sequence(li.ptr())) {
        py::list tmp_list = py::reinterpret_steal<py::list>(PySequence_List(li.ptr()));
        for (size_t i = 0; i < tmp_list.size(); ++i) {
            axis.push_back(tmp_list[i].attr("__int__")().cast<int32_t>());
        }
    } else {
        axis.push_back(getattr(li, "__int__")().cast<int32_t>());
    }
    return axis;
}
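
// C++ implementation of F.expand_dims. Resolves negative axes against the
// resulting ndim (input ndim + number of inserted axes), then applies AddAxis.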
py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
    std::vector<int32_t> axis = list2vector(axis_hdl);
    if (axis.empty()) {
        throw py::index_error("axis could not be empty");
    }
    bool unknown_ndim = true;
    size_t ndim = axis.size();
    if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) {
        auto&& shape = p->m_tensor->shape();
        if (shape) {
            unknown_ndim = false;
            ndim += shape->ndim;
        }
    } else {
        auto&& var = inp_hdl.cast<PySymbolVar*>();
        auto&& mgr = var->m_node->owner_graph()->static_infer_manager();
        auto&& shape = mgr.infer_shape_fallible(var->m_node);
        if (shape) {
            unknown_ndim = false;
            ndim += shape->ndim;
        }
    }
    for (size_t i = 0; i < axis.size(); ++i) {
        if (axis[i] < 0) {
            if (unknown_ndim) {
                throw py::index_error(
                        "Does not support negative index when tensor's ndim is "
                        "unknown");
            }
            axis[i] += static_cast<int32_t>(ndim);
        }
    }
    std::sort(axis.begin(), axis.end());
    std::shared_ptr<OpDef> op = AddAxis::make(axis);
    std::vector<PyObject*> p;
    p.resize(2);
    py::object Op = py::cast(op);
    p[0] = Op.ptr();
    p[1] = inp_hdl.ptr();
    py::tuple ret =
            py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
    return ret[0];
}
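
// C++ implementation of F.squeeze. With axis=None all size-1 dimensions are
// removed; explicit axes are normalized, sorted, and shifted so RemoveAxis
// sees indices that stay valid as earlier axes disappear.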
py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
    std::vector<int32_t> axis;
    size_t ndim = 0;  // initialized in case the shape cannot be inferred
    if (axis_hdl.ptr() != Py_None) {
        axis = list2vector(axis_hdl);
    }
    if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) {
        auto&& shape = p->m_tensor->shape();
        if (shape) {
            ndim = shape->ndim;
            if (axis_hdl.ptr() == Py_None) {
                for (size_t i = 0; i < shape->ndim; ++i) {
                    if (shape->shape[i] == 1) {
                        axis.push_back(i);
                    }
                }
            }
        }
    } else {
        auto&& var = inp_hdl.cast<PySymbolVar*>();
        auto&& mgr = var->m_node->owner_graph()->static_infer_manager();
        auto&& shape = mgr.infer_shape_fallible(var->m_node);
        if (shape) {
            ndim = shape->ndim;
            if (axis_hdl.ptr() == Py_None) {
                for (size_t i = 0; i < shape->ndim; ++i) {
                    if (shape->shape[i] == 1) {
                        axis.push_back(i);
                    }
                }
            }
        }
    }
    for (size_t i = 0; i < axis.size(); ++i) {
        if (axis[i] < 0) {
            axis[i] += static_cast<int32_t>(ndim);
        }
    }
    std::sort(axis.begin(), axis.end());
    for (size_t i = 0; i < axis.size(); ++i) {
        axis[i] -= static_cast<int32_t>(i);
    }
    std::shared_ptr<OpDef> op = RemoveAxis::make(axis);
    std::vector<PyObject*> p;
    p.resize(2);
    py::object Op = py::cast(op);
    p[0] = Op.ptr();
    p[1] = inp_hdl.ptr();
    py::tuple ret =
            py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
    return ret[0];
}
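
// Vectorcall entry points exposed to Python. Each wrapper forwards its raw
// arguments to the corresponding _*_cpp helper and converts any C++ exception
// into a Python one via PYEXT17_TRANSLATE_EXC_RET. At the Python level these
// back expressions such as `x[idx]`, `x[idx] = v` and `F.split(x, n, axis)`
// (a rough sketch of the intended call sites, not an exhaustive list).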
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _make_shape_tuple(py::handle(args[0])).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _getitem_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _setitem_cpp(
                       py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
                .release()
                .ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _split_cpp(py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
                .release()
                .ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _expand_dims_cpp(py::handle(args[0]), py::handle(args[1]))
                .release()
                .ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _squeeze_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

}  // namespace mgb::imperative::python