
tensor.cpp 26 kB

/**
 * \file imperative/python/src/tensor.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/dtype.h"
#include "megbrain/common.h"
#include "megbrain/imperative/ops/utility.h"

#include "./tensor.h"
#include "./grad.h"
#include "./trace.h"
#include "./common.h"
#include "./numpy_dtypes.h"
#include "./graph_rt.h"
#include "./helper.h"

#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <range/v3/all.hpp>

#include <unordered_map>

namespace py = pybind11;
namespace views = ranges::views;
namespace mgb::imperative::python {

std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;

py::object cpp_apply_with_tracing, cpp_apply_const_with_tracing,
           cpp_apply_compiled_mode, cpp_apply_const_compiled_mode;
py::object cpp_apply_backward_varnode;

void release_trace_apply_func() {
    cpp_apply_with_tracing.release();
    cpp_apply_const_with_tracing.release();
    cpp_apply_compiled_mode.release();
    cpp_apply_const_compiled_mode.release();
    cpp_apply_backward_varnode.release();
}

#define REGISTE_APPLY_FUNC(mode)                                 \
    void set_##mode(py::object pyf) {                            \
        mode = pybind11::reinterpret_steal<py::object>(pyf);     \
    }

REGISTE_APPLY_FUNC(cpp_apply_with_tracing)
REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing)
REGISTE_APPLY_FUNC(cpp_apply_compiled_mode)
REGISTE_APPLY_FUNC(cpp_apply_const_compiled_mode)
REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)

#undef REGISTE_APPLY_FUNC
bool is_tracing = false;
bool is_symbolic = false;
bool is_compiled = false;

#define SET_UNSET_PROP(mode)     \
    void set_##mode() {          \
        is_##mode = true;        \
    }                            \
    void unset_##mode() {        \
        is_##mode = false;       \
    }

SET_UNSET_PROP(tracing)
SET_UNSET_PROP(symbolic)
SET_UNSET_PROP(compiled)

#undef SET_UNSET_PROP

bool skip_tracing = false;
Tensor::flags_t ApplyContext::global_disable = 0;

apply_result_t apply(ApplyContext& ctx) {
    // Emulating scalars should be put into each specific op's apply, e.g.
    // elementwise, reduce, typecvt. Currently it is still handled on the
    // Python side. It could be moved to the C++ side if it has an impact
    // on performance.
    auto flags = ctx.flags & ~ApplyContext::global_disable;

    if (flags & Tensor::Flags::SCALAR) {
        // TODO: emulate scalar
    }

    if (flags & Tensor::Flags::GRAD) {
        return apply_grad(ctx);
    }

    if (auto* op = ctx.op->try_cast_final<GenericPyOp>()) {
        py::tuple pyin(ctx.nargs);
        for (size_t i = 0; i < ctx.nargs; ++i) {
            pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
        }
        auto f = py::getattr(op->obj, "_default_rule");
        auto pyout = py::reinterpret_steal<py::object>(PyObject_Call(f.ptr(), pyin.ptr(), nullptr));
        if (!pyout) throw py::error_already_set();
        if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) {
            return {tw->m_tensor};
        }
        apply_result_t ret;
        ret.reserve(py::len(pyout));
        for (auto&& i : pyout) {
            auto* tw = TensorWrapper::try_cast(i.ptr());
            mgb_assert(tw);
            ret.push_back(tw->m_tensor);
        }
        return ret;
    }

    if (flags & Tensor::Flags::TRACE) {
        return apply_trace(ctx);
    } else {
        SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs);
        for (size_t i = 0; i < ctx.nargs; ++i) {
            handles[i] = ctx.args[i]->m_handle.get();
        }
        auto output_handles = interpreter_for_py->apply_op(ctx.op, handles);

        apply_result_t outputs;
        outputs.reserve(output_handles.size());
        for (auto h : output_handles) {
            outputs.emplace_back(std::make_shared<Tensor>(h));
        }
        return outputs;
    }

    mgb_assert(0);
}
PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) {
    try {
        // if (kwnames && PyTuple_GET_SIZE(kwnames)) {
        //     PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
        //     return nullptr;
        // }
        if (nargs < 2) {
            PyErr_SetString(PyExc_TypeError,
                            "py_apply expects one Op and at least one tensor "
                            "as argument");
            return nullptr;
        }

        auto* op = args[0];

        PyTypeObject* pytype = args[1]->ob_type;
        ++args;
        --nargs;

        ApplyContext ctx;
        ctx.flags = 0;
        ctx.op = py::handle(op).cast<std::shared_ptr<OpDef>>();
        SmallVector<Tensor*, 64> tensors(nargs);
        ctx.args = &tensors[0];
        ctx.nargs = nargs;
        ctx.pytype = pytype;
        if (strstr(op->ob_type->tp_name, "BackwardGraph")) {
            ctx.backward = true;
        }

        for (size_t i = 0; i < nargs; ++i) {
            if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
                auto* t = tensors[i] = tw->m_tensor.get();
                ctx.flags |= t->m_flags;
            } else {
                PyErr_SetString(PyExc_TypeError, "expect Tensor");
                return nullptr;
            }
        }

        if (is_tracing) {
            ctx.flags |= Tensor::Flags::TRACE;
        }

        auto outputs = apply(ctx);
        size_t nout = outputs.size();
        auto ret = py::tuple(nout);
        for (size_t i = 0; i < nout; ++i) {
            ret[i] = TensorWrapper::make(pytype, std::move(outputs[i]));
        }
        return ret.release().ptr();
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    }
}
TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
    if (kwargs && PyDict_Size(kwargs)) {
        throw py::type_error("keyword argument not allowed");
    }
    auto nargs = PyTuple_Size(args);
    auto tup = py::reinterpret_borrow<py::tuple>(args);
    if (nargs == 0) {
        throw py::type_error("too few arguments");
    }
    if (auto* t = try_cast(tup[0].ptr())) {
        if (nargs > 1) {
            throw py::type_error("expect 1 argument");
        }
        m_tensor = t->m_tensor;
    } else {
        if (nargs == 1) {
            auto arg0 = PyTuple_GetItem(args, 0);
            // for lazy_eval_tensor
            if (strstr(arg0->ob_type->tp_name, "VarNode")) {
                if (PyObject_HasAttrString(arg0, "_node")) {
                    arg0 = PyObject_GetAttrString(arg0, "_node");
                }
                m_tensor = std::make_shared<Tensor>(py::handle(arg0).cast<cg::VarNode*>());
            } else {
                // for DeviceTensorND
                if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
                    auto dv = py::handle(arg0).cast<DeviceTensorND>();
                    interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv);
                    m_tensor = std::make_shared<Tensor>(handle);
                } else {
                    throw py::type_error("single argument is not tensor, varnode or devicetensor");
                }
            }
        } else {
            py::detail::loader_life_support life_sup; // FIXME!!! required to cast DType
            if (nargs != 4 && nargs != 5) {
                throw py::type_error("expect 4 or 5 arguments");
            }
            auto data = tup[0].cast<py::array>();
            DType dtype = tup[1].cast<DType>();
            CompNode cn = tup[2].cast<CompNode>();
            bool is_const = tup[3].cast<bool>();
            bool no_cache = nargs == 5 ? tup[4].cast<bool>() : false;

            // const op
            if (is_const && is_tracing) {
                py::object pyf;
                if (is_compiled) {
                    pyf = cpp_apply_const_compiled_mode;
                } else {
                    pyf = cpp_apply_const_with_tracing;
                }
                auto ret = pyf(*tup);
                auto py_ret = py::reinterpret_borrow<py::list>(ret);
                if (auto* t = try_cast(py_ret[0].ptr())) {
                    m_tensor = t->m_tensor;
                }
                return;
            }

            interpreter::Interpreter::Handle handle;
            constexpr auto size_threshhold = TensorShape::MAX_NDIM;
            if (data.size() > size_threshhold) {
                handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype), no_cache);
            } else {
                HostTensorND ret(cn);
                handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype), no_cache);
            }

            m_tensor = std::make_shared<Tensor>(handle);

            if (data.ndim() == 0) {
                m_tensor->m_flags |= Tensor::Flags::SCALAR;
            }
        }
    }
}
#define REGISTE_TENSORWRAPPER_FUNC(type, member)                            \
    PyObject* TensorWrapper::member() {                                     \
        return py::cast(m_tensor->m_trace_info.member).release().ptr();     \
    }                                                                       \
    void TensorWrapper::set_##member(PyObject* dest) {                      \
        auto py_dest = py::reinterpret_borrow<py::object>(dest);            \
        type real_dest = py_dest.cast<type>();                              \
        m_tensor->m_trace_info.member = real_dest;                          \
    }

REGISTE_TENSORWRAPPER_FUNC(bool, data_read)
REGISTE_TENSORWRAPPER_FUNC(bool, value_read)
REGISTE_TENSORWRAPPER_FUNC(bool, shape_read)
REGISTE_TENSORWRAPPER_FUNC(int64_t, mixin_handle)

#undef REGISTE_TENSORWRAPPER_FUNC

PyObject* TensorWrapper::handle() {
    return py::cast(m_tensor->m_handle).release().ptr();
}

void TensorWrapper::set_handle(PyObject* dest) {
    auto py_dest = py::reinterpret_borrow<py::object>(dest);
    SharedHandle real_dest = py_dest.cast<SharedHandle>();
    m_tensor->m_handle = std::move(real_dest);
}
PyObject* TensorWrapper::shape() {
    if (!skip_tracing) {
        set_shape_read(py::cast(true).release().ptr());
    }
    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        return PyTuple_New(0);
    }
    TensorShape shape;
    if (m_tensor->m_var) {
        shape = m_tensor->m_var->shape();
    } else {
        shape = m_tensor->shape();
    }

    if (!shape.ndim) {
        Py_RETURN_NONE;
    }
    py::tuple ret(shape.ndim);
    for (size_t i = 0; i < shape.ndim; ++i) {
        ret[i] = shape[i];
    }
    return ret.release().ptr();
}

PyObject* TensorWrapper::dtype() {
    if (m_tensor->m_var) {
        return py::cast(m_tensor->m_var->dtype()).release().ptr();
    }
    return py::cast(m_tensor->dtype()).release().ptr();
}

PyObject* TensorWrapper::device() {
    if (m_tensor->m_var) {
        return py::cast(m_tensor->m_var->comp_node()).release().ptr();
    }
    return py::cast(m_tensor->comp_node()).release().ptr();
}

PyObject* TensorWrapper::numpy() {
    if (!skip_tracing) {
        set_value_read(py::cast(true).release().ptr());
    }
    if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) {
        auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
        auto&& type = mgr.get_infer_type(m_tensor->m_var);
        using InferType = cg::static_infer::InferType;
        if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
            PyErr_SetString(PyExc_ValueError, "tensor invalid");
            return nullptr;
        }
        auto* val = mgr.infer_value_fallible(m_tensor->m_var);
        if (!val) {
            PyErr_SetString(PyExc_ValueError, "tensor invalid");
            return nullptr;
        }
        return py::cast(*val).attr("numpy")().release().ptr();
    }
    auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get());
    auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
    if (!arr) {
        PyErr_SetString(PyExc_ValueError, "tensor invalid");
        return nullptr;
    }

    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        mgb_assert(PyArray_Check(arr.ptr()));
        return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
    }
    return arr.release().ptr();
}

PyObject* TensorWrapper::varnode() {
    if (m_tensor->m_var) {
        return py::cast(m_tensor->m_var).release().ptr();
    }
    return py::none().release().ptr();
}
void TensorWrapper::reset(PyObject* tensor) {
    TensorWrapper* t = TensorWrapper::try_cast(tensor);
    if (!t) {
        throw py::type_error("expect Tensor");
    }
    m_tensor = t->m_tensor;
}

void TensorWrapper::reset_varnode() {
    m_tensor->m_var = nullptr;
}

PyObject* TensorWrapper::detach() {
    PyObject* self = wrap_t::pycast(this);
    PyTypeObject* pytype = self->ob_type;
    std::shared_ptr<Tensor> new_tensor;
    if (m_tensor->m_handle.get()) {
        new_tensor = std::make_shared<Tensor>(m_tensor->m_handle);
    } else {
        new_tensor = std::make_shared<Tensor>(m_tensor->m_var);
    }
    new_tensor->m_trace_info = m_tensor->m_trace_info;
    auto ret = TensorWrapper::make(pytype, std::move(new_tensor));
    return ret.release().ptr();
}

PyObject* TensorWrapper::_dev_tensor() {
    if (!skip_tracing) {
        set_data_read(py::cast(true).release().ptr());
    }
    auto dev_tensor = interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get());
    return py::cast(dev_tensor).release().ptr();
}

void TensorWrapper::_swap_out() {
    interpreter_for_py->swap_out(m_tensor->m_handle.get());
}

void TensorWrapper::_swap_in() {
    interpreter_for_py->swap_in(m_tensor->m_handle.get());
}

void TensorWrapper::_drop() {
    interpreter_for_py->drop(m_tensor->m_handle.get());
}

PyObject* TensorWrapper::isscalar() {
    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        Py_RETURN_TRUE;
    } else {
        Py_RETURN_FALSE;
    }
}

void TensorWrapper::setscalar() {
    m_tensor->m_flags |= Tensor::Flags::SCALAR;
}

struct TensorWeakRef {
    std::weak_ptr<Tensor> wptr;

    TensorWeakRef(const TensorWrapper& tw) : wptr(tw.m_tensor) {}

    py::object operator()() {
        if (auto p = wptr.lock()) {
            return TensorWrapper::make(p);
        }
        return py::none();
    }
};
/* ============== convert inputs ============== */

// map numpy.dtype.kind to priority
inline uint8_t category_priority(char c) {
    switch (c) {
        case 'f': return 3; // floating-point
        case 'i': return 2; // signed integer
        case 'u': return 2; // unsigned integer
        case 'b': return 1; // boolean
        default: return 0;
    }
}

// Returns the maximum value of the priority of each type in the list `types`.
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
    if (types.size() == 0) {
        return 0;
    } else {
        uint8_t max_p = 0;
        for (auto&& desc: types) {
            max_p = std::max(max_p, category_priority(desc->kind));
        }
        return max_p;
    }
}

// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> used_types;
    for (auto&& desc: types) {
        auto&& v = category_priority(desc->kind);
        if (v == cat) {
            used_types.emplace_back(desc);
        }
    }
    mgb_assert(used_types.size() > 0, "size of used_types is 0");
    PyArray_Descr* res = used_types[0];
    Py_INCREF(res);
    for (size_t i = 1; i < used_types.size(); ++i) {
        PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
        Py_DECREF(res);
        res = tmp;
    }
    return res;
}

PyArray_Descr* scalar2dtype(PyObject* arg) {
    // Return value: New reference
    if (PyBool_Check(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_BOOL);
        return descr;
    }
    if (PyLong_CheckExact(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_INT32);
        return descr;
    }
    if (PyFloat_CheckExact(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
        return descr;
    }
    return nullptr;
}
PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> tensors;
    SmallVector<PyArray_Descr*> scalars;

    bool is_tuple = false;
    PyObject* tuple = nullptr;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }

    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        if (handle == Py_None) continue;
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        if (tw) {
            mgb::DType type = tw->m_tensor->dtype();
            auto&& descr = npy::dtype_mgb2np_descr(type);
            Py_INCREF(descr.get());
            tensors.emplace_back(descr.get());
        } else {
            if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
                auto&& descr = PyArray_DescrFromObject(handle, nullptr);
                tensors.emplace_back(descr);
                continue;
            }

            PyArray_Descr* descr = scalar2dtype(handle);
            if (descr) {
                scalars.emplace_back(descr);
                continue;
            }
        }
    }

    auto max_pri_scalars = max_priority(scalars);
    auto max_pri_tensors = max_priority(tensors);

    if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
        throw py::value_error("invalid input, no dtype available");
    }

    PyArray_Descr* res;
    if (max_pri_scalars > max_pri_tensors) {
        res = promote_types(scalars, max_pri_scalars);
    } else {
        res = promote_types(tensors, max_pri_tensors);
    }

    for (auto* p: tensors) { Py_DECREF(p); }
    for (auto* p: scalars) { Py_DECREF(p); }
    // only release the temporary tuple if one was actually created above
    if (is_tuple) {
        Py_DECREF(tuple);
    }
    return res;
}
CompNode _get_device(PyObject*const* args, size_t nargs) {
    bool is_tuple = false;
    PyObject* tuple = nullptr;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }
    bool valid = false;
    CompNode cn;
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        if (tw) {
            if (!valid) {
                cn = tw->m_tensor->comp_node();
                valid = true;
            } else {
                CompNode cn1 = tw->m_tensor->comp_node();
                if (cn1 != cn) {
                    throw py::value_error(ssprintf("ambiguous device: %s vs %s",
                                                   cn.to_string().c_str(), cn1.to_string().c_str()));
                }
            }
        }
    }
    if (!valid) {
        mgb_assert(0, "expect at least 1 device");
    }
    // only release the temporary tuple if one was actually created above
    if (is_tuple) {
        Py_DECREF(tuple);
    }
    return cn;
}
// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject*const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        PyArray_Descr* res = _dtype_promotion(args, nargs);
        return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    }
}

PyObject* get_device(PyObject* self, PyObject*const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        CompNode cn = _get_device(args, nargs);
        return py::cast(cn).release().ptr();
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    }
}

#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
#else
#define WRAP_FUNC_PY35(FUNC)                                    \
    PyObject* py35_##FUNC(PyObject* self, PyObject* args) {     \
        auto* arr = &PyTuple_GET_ITEM(args, 0);                 \
        auto size = PyTuple_GET_SIZE(args);                     \
        return FUNC(self, arr, size);                           \
    }
WRAP_FUNC_PY35(py_apply);
WRAP_FUNC_PY35(dtype_promotion);
WRAP_FUNC_PY35(get_device);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif

py::object make_empty_tensorwrapper() {
    return TensorWrapper::make(std::move(std::make_shared<Tensor>()));
}
void init_tensor(py::module m) {
    interpreter_for_py = interpreter::Interpreter::inst().create_channel();

    auto* tensor_type = TensorWrapper::wrap_t::type()
        .def<&TensorWrapper::numpy>("numpy")
        .def_getset<&TensorWrapper::shape>("shape")
        .def_getset<&TensorWrapper::dtype>("dtype")
        .def_getset<&TensorWrapper::device>("device")
        .def<&TensorWrapper::reset>("_reset")
        .def<&TensorWrapper::isscalar>("isscalar")
        .def<&TensorWrapper::setscalar>("setscalar")
        .def<&TensorWrapper::detach>("detach")
        .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
        .def<&TensorWrapper::_swap_out>("_swap_out")
        .def<&TensorWrapper::_swap_in>("_swap_in")
        .def<&TensorWrapper::_drop>("_drop")
        .def<&TensorWrapper::reset_varnode>("_reset_varnode")
        .def_getset<&TensorWrapper::varnode>("_varnode")
        .def_getset<&TensorWrapper::data_read, &TensorWrapper::set_data_read>("data_read")
        .def_getset<&TensorWrapper::value_read, &TensorWrapper::set_value_read>("value_read")
        .def_getset<&TensorWrapper::shape_read, &TensorWrapper::set_shape_read>("shape_read")
        .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("mixin_handle")
        .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
        .finalize();
    if (!tensor_type) throw py::error_already_set();
    py::setattr(m, "Tensor", tensor_type);

    py::class_<TensorWeakRef>(m, "TensorWeakRef")
        .def(py::init<const TensorWrapper&>())
        .def("__call__", &TensorWeakRef::operator());

    static PyMethodDef method_defs[] = {
        MGE_PY_INTERFACE(apply, py_apply),
        MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
        MGE_PY_INTERFACE(get_device, get_device),
        {nullptr, nullptr, 0, nullptr}};
    for (auto&& def: method_defs) {
        if (def.ml_meth != nullptr) {
            auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
            if (!func) throw py::error_already_set();
            py::setattr(m, def.ml_name, func);
        }
    }

    m.def("_set_swap_flag",
          [](bool flag) { interpreter_for_py->set_swap_flag(flag); });
    m.def("_set_drop_flag",
          [](bool flag) { interpreter_for_py->set_drop_flag(flag); });
    m.def("config_async_level",
          [](int level) { interpreter_for_py->config_async_level(level); });
    m.def("get_async_level",
          []() { return interpreter_for_py->get_async_level(); });
    m.def("sync",
          []() {
              interpreter_for_py->sync();
              py_task_q.wait_all_task_finish();
          },
          py::call_guard<py::gil_scoped_release>());
    m.def("full_sync",
          []() {
              interpreter_for_py->sync();
              CompNode::sync_all();
              py_task_q.wait_all_task_finish();
          },
          py::call_guard<py::gil_scoped_release>());
    m.def("release_trace_apply_func", &release_trace_apply_func);

    py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
        .def<&GradKeyWrapper::attach>("attach")
        .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
        .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>("name")
        .finalize();
    if (!grad_key_type) throw py::error_already_set();
    py::setattr(m, "GradKey", grad_key_type);
    m.def("backward", &GradKeyWrapper::backward);

    m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing);
    m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing);
    m.def("set_cpp_apply_compiled_mode", &set_cpp_apply_compiled_mode);
    m.def("set_cpp_apply_const_compiled_mode", &set_cpp_apply_const_compiled_mode);
    m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode);
    m.attr("skip_tracing") = &skip_tracing;

    py::class_<SharedHandle>(m, "SharedHandle")
        .def(py::init<const SharedHandle&>());

    m.def("set_tracing", &set_tracing);
    m.def("unset_tracing", &unset_tracing);
    m.def("set_symbolic", &set_symbolic);
    m.def("unset_symbolic", &unset_symbolic);
    m.def("set_compiled", &set_compiled);
    m.def("unset_compiled", &unset_compiled);

    m.def("__make_empty_tensor", &make_empty_tensorwrapper);
}

#undef MGE_PY_INTERFACE

} // namespace mgb::imperative::python
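
The names registered in init_tensor above (Tensor, TensorWeakRef, apply, dtype_promotion, get_device, the tracing setters, and so on) become attributes of the extension module that calls init_tensor. The following Python sketch only illustrates that surface; the import path and the exact constructor argument forms are assumptions for illustration, not taken from this file.

import numpy as np
# Hypothetical import path; the module that init_tensor() is compiled into
# may differ between MegEngine versions.
from megengine.core._imperative_rt import imperative as rt

raw = np.arange(6, dtype=np.float32).reshape(2, 3)

# TensorWrapper's 4-argument constructor: (data, dtype, comp_node, is_const).
# Passing "xpux" for the CompNode and np.float32 for the DType assumes the
# casters registered elsewhere in these bindings accept those forms.
t = rt.Tensor(raw, np.float32, "xpux", False)

print(t.shape, t.dtype, t.device)   # getters defined via def_getset above
print(t.numpy())                    # host value fetched through the interpreter channel

# Free functions registered through MGE_PY_INTERFACE / m.def above.
promoted = rt.dtype_promotion(t, 1.0)   # tensor dtype combined with a Python float
device = rt.get_device(t)

weak = rt.TensorWeakRef(t)
assert weak() is not None           # alive as long as `t` keeps the Tensor referenced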

The MegEngine package ships with the CUDA environment needed to run code on GPUs, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has a GPU and that its driver is properly installed. If you would like to try deep learning development on cloud GPUs, you are welcome to visit the MegStudio platform.
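
Since the same package runs on both CPU and GPU, a script can pick its device at run time. A minimal sketch, assuming the top-level helpers is_cuda_available() and set_default_device() (in some versions they live in megengine.device) and MegEngine's usual device strings such as "gpu0"/"cpu0":

import megengine as mge

# Pick the GPU when one is present, otherwise fall back to the CPU.
if mge.is_cuda_available():
    mge.set_default_device("gpu0")   # needs GPU hardware and a working driver
else:
    mge.set_default_device("cpu0")

x = mge.tensor([1.0, 2.0, 3.0])
print(x.device)                      # shows which device the tensor was placed on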