
tensor.cpp

  1. /**
  2. * \file imperative/python/src/tensor.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/common.h"
  12. #include "megbrain/dtype.h"
  13. #include "megbrain/imperative/cpp_cupti.h"
  14. #include "megbrain/imperative/ops/autogen.h"
  15. #include "megbrain/imperative/ops/backward_graph.h"
  16. #include "megbrain/imperative/ops/utility.h"
  17. #include "megbrain/imperative/profiler.h"
  18. #include "megbrain/imperative/transformations/dim_expansion.h"
  19. #include "megbrain/imperative/transformations/dtype_promote.h"
  20. #include "megbrain/imperative/transformations/eval.h"
  21. #include "megbrain/imperative/transformations/lazy.h"
  22. #include "megbrain/imperative/transformations/scalar.h"
  23. #include "megbrain/imperative/transformations/symbol.h"
  24. #include "megbrain/imperative/transformations/trace.h"
  25. #include "megbrain/imperative/utils/map.h"
  26. #include "megbrain/opr/io.h"
  27. #include "megbrain/plugin/profiler.h"
  28. #include "megbrain/utils/stats.h"
  29. #include "megdnn/algorithm_cache.h"
  30. #include "./common.h"
  31. #include "./grad.h"
  32. #include "./graph_rt.h"
  33. #include "./helper.h"
  34. #include "./module_trace.h"
  35. #include "./numpy_dtypes.h"
  36. #include "./tensor.h"
  37. #include "./tensor_utils.h"
  38. #include "./transformation.h"
  39. #include <object.h>
  40. #include <pybind11/numpy.h>
  41. #include <pybind11/operators.h>
  42. #include <pybind11/pytypes.h>
  43. #include <pyerrors.h>
  44. #include <iterator>
  45. #include <range/v3/all.hpp>
  46. #include <string>
  47. #include <unordered_map>
  48. #include "../../src/impl/mgb_cg_impl.h"
  49. namespace py = pybind11;
  50. namespace views = ranges::views;
  51. namespace mgb::imperative::python {
  52. namespace {
  53. WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;
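// SymbolVarContext (below) is an RAII helper: the constructor swaps in a fresh
// TransformationContext so SymbolVar inputs can be handled by the symbol, scalar,
// dtype-promote and dim-expansion transformations, and the destructor swaps the
// previous context back.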
  54. struct SymbolVarContext {
  55. TransformationContext context;
  56. std::shared_ptr<SymbolTransformation> symbol_tsf;
  57. std::shared_ptr<ScalarTransformation> scalar_tsf;
  58. std::shared_ptr<DTypePromoteTransformation> dtype_promote_tsf;
  59. std::shared_ptr<DimExpansionTransformation> dim_expansion_tsf;
  60. SymbolVarContext(cg::ComputingGraph* graph) {
  61. symbol_tsf = std::make_shared<SymbolTransformation>(graph);
  62. scalar_tsf = std::make_shared<ScalarTransformation>();
  63. dtype_promote_tsf = std::make_shared<DTypePromoteTransformation>();
  64. dim_expansion_tsf = std::make_shared<DimExpansionTransformation>();
  65. Transformation::swap_context(context);
  66. }
  67. void init() {
  68. symbol_tsf->register_at(Transformation::top());
  69. scalar_tsf->register_at(Transformation::top());
  70. dtype_promote_tsf->register_at(Transformation::top());
  71. dim_expansion_tsf->register_at(Transformation::top());
  72. }
  73. ValueRef symvar2val(py::handle py_symbol_var) {
  74. auto* symbol_var = py_symbol_var.cast<PySymbolVar*>();
  75. ValueRef value = symbol_tsf->value_type().make(symbol_var->m_node);
  76. if (symbol_var->is_scalar) {
  77. value = scalar_tsf->value_type().make(value);
  78. }
  79. return value;
  80. }
  81. py::object val2symvar(py::handle typeobj, ValueRef value) {
  82. bool is_scalar = false;
  83. if (auto* scalar_value = value.as(scalar_tsf->value_type())) {
  84. value = scalar_value->value();
  85. is_scalar = true;
  86. }
  87. auto* node = value.cast(symbol_tsf->value_type()).node();
  88. auto py_symbol_var =
  89. typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic));
  90. py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar;
  91. return py_symbol_var;
  92. }
  93. ~SymbolVarContext() { Transformation::swap_context(context); }
  94. };
  95. } // namespace
  96. interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
  97. PyTypeObject* py_tensor_type = nullptr;
  98. pybind11::handle py_device_type = nullptr;
  99. PyObject* cpp_use_symbolic_shape;
  100. #define REGISTE_APPLY_FUNC(mode) \
  101. void set_##mode(py::object pyf) { mode = pyf.ptr(); }
  102. REGISTE_APPLY_FUNC(cpp_use_symbolic_shape)
  103. #undef REGISTE_APPLY_FUNC
  104. PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs);
  105. CompNode _get_device(PyObject* const* args, size_t nargs);
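// Fast-call entry point for applying an OpDef from Python: args[0] is the op,
// the remaining arguments are Tensors (or SymbolVars when building a static graph).
// Inputs are converted to ValueRefs, dispatched through imperative::apply, and the
// outputs are wrapped back into Python objects.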
  106. PyObject* py_apply(
  107. PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) {
  108. try {
  109. // if (kwnames && PyTuple_GET_SIZE(kwnames)) {
  110. // PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
  111. // return nullptr;
  112. // }
  113. if (nargs < 2) {
  114. PyErr_SetString(
  115. PyExc_TypeError,
  116. "py_apply expects one Op and at least one tensor "
  117. "as argument");
  118. return nullptr;
  119. }
  120. auto* py_op = args[0];
  121. ++args;
  122. --nargs;
  123. auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
  124. SmallVector<ValueRef, 8> tensors(nargs);
  125. SmallVector<bool, 8> is_symbol_var(nargs, false);
  126. ComputingGraph* cg = nullptr;
  127. for (size_t i = 0; i < nargs; ++i) {
  128. if ((!TensorWrapper::try_cast(args[i])) &&
  129. py::isinstance<PySymbolVar>(py::handle(args[i]))) {
  130. is_symbol_var[i] = true;
  131. ComputingGraph* cur_cg =
  132. py::handle(args[i]).cast<PySymbolVar*>()->m_node->owner_graph();
  133. if (cg == nullptr) {
  134. cg = cur_cg;
  135. } else {
  136. mgb_assert(cg == cur_cg);
  137. }
  138. }
  139. }
  140. mgb::CompNode target_cn;
  141. mgb::DType target_dtype;
  142. auto convert_pyinput_to_tensor = [&](size_t i) -> ValueRef {
  143. if (!target_dtype.valid()) {
  144. target_dtype = npy::dtype_np2mgb_descr(_dtype_promotion(args, nargs));
  145. target_cn = _get_device(args, nargs);
  146. }
  147. HostTensorND ht(target_cn);
  148. ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype);
  149. if (PyArray_Check(args[i]) || PyList_Check(args[i])) { // non-scalar
  150. // py_tuple is not allowed here because of tracing
  151. return imperative::apply(
  152. CreateTensor(CreateTensor::Const, target_cn, ht.layout()),
  153. HostStorage::make(ht.storage()))[0];
  154. } else { // scalar
  155. return imperative::apply(
  156. CreateTensor(CreateTensor::Const, target_cn, target_dtype, {}),
  157. HostStorage::make(ht.storage()))[0];
  158. }
  159. };
  160. if (cg != nullptr) {
  161. // swap to a special context to reuse scalar handle
  162. size_t symbol_var_idx = 8;
  163. SymbolVarContext context(cg);
  164. context.init();
  165. for (size_t i = 0; i < nargs; ++i) {
  166. if (is_symbol_var[i]) {
  167. symbol_var_idx = i;
  168. tensors[i] = context.symvar2val(args[i]);
  169. } else if (
  170. DTypePromoteCfg::convert_input_enabled &&
  171. op->same_type<Elemwise>()) {
  172. tensors[i] = convert_pyinput_to_tensor(i);
  173. } else {
  174. PyErr_SetString(
  175. PyExc_TypeError, "py_apply expects tensor as inputs");
  176. return nullptr;
  177. }
  178. }
  179. auto outputs = imperative::apply(*op, tensors);
  180. auto ret = pybind11::tuple(outputs.size());
  181. auto typeobj = py::handle(args[symbol_var_idx]).get_type();
  182. for (size_t i = 0; i < outputs.size(); ++i) {
  183. ret[i] = context.val2symvar(typeobj, outputs[i]);
  184. }
  185. return ret.release().ptr();
  186. }
  187. for (size_t i = 0; i < nargs; ++i) {
  188. if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
  189. tensors[i] = tw->m_tensor->data();
  190. } else if (
  191. DTypePromoteCfg::convert_input_enabled &&
  192. op->same_type<Elemwise>()) {
  193. tensors[i] = convert_pyinput_to_tensor(i);
  194. } else {
  195. PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs");
  196. return nullptr;
  197. }
  198. }
  199. auto outputs = [&] { return imperative::apply(*op, tensors); }();
  200. size_t nout = outputs.size();
  201. auto ret = py::tuple(nout);
  202. for (size_t i = 0; i < nout; ++i) {
  203. ret[i] = TensorWrapper::make(py_tensor_type, std::move(outputs[i]));
  204. }
  205. return ret.release().ptr();
  206. }
  207. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  208. }
  209. namespace {
  210. template <typename T>
  211. py::handle py_type() {
  212. if constexpr (std::is_same_v<T, py::int_>) {
  213. return (PyObject*)&PyLong_Type;
  214. } else if constexpr (std::is_same_v<T, py::float_>) {
  215. return (PyObject*)&PyFloat_Type;
  216. } else if constexpr (std::is_same_v<T, py::tuple>) {
  217. return (PyObject*)&PyTuple_Type;
  218. } else if constexpr (std::is_same_v<T, py::list>) {
  219. return (PyObject*)&PyList_Type;
  220. } else {
  221. static_assert(std::is_same_v<T, T>);
  222. }
  223. }
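// Copies a single scalar of the given dtype into freshly allocated host storage.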
  224. template <typename T>
  225. auto scalar2storage(T val, CompNode cn, DType dtype) {
  226. using max_ctype_t = DTypeScalar::max_ctype;
  227. DTypeScalar scalar(dtype);
  228. scalar.set_retain_dtype(val);
  229. HostTensorStorage storage(cn);
  230. auto* raw_ptr = reinterpret_cast<dt_byte*>(new max_ctype_t());
  231. std::shared_ptr<dt_byte> raw_storage = {
  232. raw_ptr, [](dt_byte* ptr) { delete reinterpret_cast<max_ctype_t*>(ptr); }};
  233. storage.only_reset_raw_storage(cn, dtype.size(), raw_storage, 0);
  234. std::memcpy(storage.ptr(), scalar.storage(), dtype.size());
  235. return HostStorage::make(std::move(storage));
  236. }
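// Packs a small vector of scalars (at most MEGDNN_MAX_NDIM elements) into host
// storage with element type ctype.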
  237. template <typename ctype>
  238. auto vec2storage(Span<DTypeScalar> vec, CompNode cn, DType dtype) {
  239. mgb_assert(vec.size() <= MEGDNN_MAX_NDIM);
  240. // TODO: use storage cache and modify ConstTensorCache to return (Host, Device)
  241. auto* raw_ptr = new ctype[MEGDNN_MAX_NDIM];
  242. for (size_t i = 0; i < vec.size(); ++i) {
  243. raw_ptr[i] = vec[i].get_cast<ctype>();
  244. }
  245. mgb_assert(sizeof(ctype) == dtype.size());
  246. std::shared_ptr<dt_byte> raw_storage = {
  247. reinterpret_cast<dt_byte*>(raw_ptr),
  248. [](dt_byte* ptr) { delete[] reinterpret_cast<ctype*>(ptr); }};
  249. HostTensorStorage storage(cn);
  250. storage.only_reset_raw_storage(cn, sizeof(ctype) * vec.size(), raw_storage, 0);
  251. return HostStorage::make(std::move(storage));
  252. }
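// Shape/dtype/storage triple describing a host tensor about to be created;
// as_tensor_nd() views it as a HostTensorND on the default CPU comp node.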
  253. struct HostTensorArgs {
  254. ValueShape shape;
  255. DType dtype;
  256. HostStorage::ref_t storage;
  257. HostTensorND as_tensor_nd() const {
  258. HostTensorND ret(CompNode::default_cpu(), shape.as_tensor_shape(), dtype);
  259. ret.only_reset_raw_storage(*storage);
  260. return ret;
  261. }
  262. };
  263. template <typename seq_type, typename ctype>
  264. bool pyseq2hval(seq_type obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
  265. auto size = obj.size();
  266. if (size > MEGDNN_MAX_NDIM) {
  267. return false;
  268. }
  269. ctype items[size];
  270. for (size_t i = 0; i < size; ++i) {
  271. py::handle item = obj[i];
  272. if (item.get_type().is(py_type<py::int_>())) {
  273. items[i] = (ctype)(dt_int32)item.template cast<py::int_>();
  274. } else if (item.get_type().is(py_type<py::float_>())) {
  275. items[i] = (ctype)(dt_float32)item.template cast<py::float_>();
  276. } else {
  277. return false;
  278. }
  279. }
  280. mgb_assert(sizeof(ctype) == dtype.size());
  281. auto* raw_ptr = new ctype[size];
  282. std::shared_ptr<dt_byte> raw_storage = {
  283. reinterpret_cast<dt_byte*>(raw_ptr),
  284. [](dt_byte* ptr) { delete[] reinterpret_cast<ctype*>(ptr); }};
  285. HostTensorStorage storage(cn);
  286. storage.only_reset_raw_storage(cn, sizeof(ctype) * size, raw_storage, 0);
  287. std::memcpy(storage.ptr(), items, sizeof(ctype) * size);
  288. ret.dtype = dtype;
  289. ret.shape = {size};
  290. ret.storage = HostStorage::make(std::move(storage));
  291. return true;
  292. }
  293. template <typename seq_type>
  294. bool pyseq2hval(seq_type obj, CompNode cn, HostTensorArgs& ret) {
  295. auto size = obj.size();
  296. if (size > MEGDNN_MAX_NDIM) {
  297. return false;
  298. }
  299. DTypeScalar items[size];
  300. DType dtype;
  301. for (size_t i = 0; i < size; ++i) {
  302. auto&& item = obj[i];
  303. if (item.get_type().is(py_type<py::int_>())) {
  304. items[i] = (dt_int32)item.template cast<py::int_>();
  305. if (!dtype.valid()) {
  306. dtype = dtype::Int32();
  307. } else if (dtype != dtype::Int32() && dtype != dtype::Float32()) {
  308. return false;
  309. }
  310. } else if (item.get_type().is(py_type<py::float_>())) {
  311. items[i] = (dt_float32)item.template cast<py::float_>();
  312. if (!dtype.valid()) {
  313. dtype = dtype::Float32();
  314. } else if (dtype == dtype::Int32()) {
  315. dtype = dtype::Float32();
  316. } else if (dtype != dtype::Float32()) {
  317. return false;
  318. }
  319. } else {
  320. return false;
  321. }
  322. }
  323. if (!dtype.valid()) {
  324. dtype = dtype::Float32();
  325. }
  326. ret.dtype = dtype;
  327. ret.shape = {size};
  328. if (dtype == dtype::Int32()) {
  329. ret.storage = vec2storage<dt_int32>({items, size}, cn, dtype);
  330. } else if (dtype == dtype::Float32()) {
  331. ret.storage = vec2storage<dt_float32>({items, size}, cn, dtype);
  332. } else {
  333. mgb_assert(false);
  334. }
  335. return true;
  336. }
  337. template <typename seq_type>
  338. bool pyseq2hval(seq_type obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
  339. if (dtype == dtype::Int32()) {
  340. return pyseq2hval<seq_type, dt_int32>(obj, cn, dtype, ret);
  341. } else if (dtype == dtype::Float32()) {
  342. return pyseq2hval<seq_type, dt_float32>(obj, cn, dtype, ret);
  343. } else if (!dtype.valid()) {
  344. return pyseq2hval<seq_type>(obj, cn, ret);
  345. } else {
  346. return false;
  347. }
  348. }
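// Fallback conversion path: interprets the object as a numpy array (squeezing
// zero-stride broadcast dimensions first) and copies it into contiguous host storage.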
  349. bool pyarr2hval(py::array obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
  350. auto data = obj.cast<py::array>();
  351. auto strides = data.strides();
  352. bool need_squeeze = false;
  353. for (size_t i = 0; i < data.ndim(); ++i) {
  354. if (strides[i] == 0) {
  355. need_squeeze = true;
  356. break;
  357. }
  358. }
  359. if (need_squeeze) {
  360. std::vector<size_t> shape;
  361. for (size_t i = 0; i < data.ndim(); ++i) {
  362. shape.push_back(data.shape(i));
  363. }
  364. data = data.squeeze();
  365. data.resize(shape);
  366. }
  367. HostTensorND retnd(cn);
  368. retnd = npy::np2tensor(data.ptr(), npy::Meth::copy_into(&retnd), dtype);
  369. if (!dtype.valid()) {
  370. dtype = retnd.dtype();
  371. }
  372. mgb_assert(
  373. retnd.layout().is_empty() || retnd.layout().is_contiguous(),
  374. "host value should be continuous");
  375. for (size_t i = 0; i < data.ndim(); ++i) {
  376. ret.shape[ret.shape.ndim++] = data.shape(i);
  377. }
  378. ret.dtype = dtype;
  379. ret.storage = HostStorage::make(retnd.storage());
  380. return true;
  381. }
  382. bool pyint2hval(py::int_ obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
  383. if (!dtype.valid()) {
  384. dtype = dtype::Int32();
  385. }
  386. ret.dtype = dtype;
  387. ret.storage = scalar2storage((dt_int32)obj, cn, dtype);
  388. return true;
  389. }
  390. bool pyfloat2hval(py::float_ obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
  391. if (!dtype.valid()) {
  392. dtype = dtype::Float32();
  393. }
  394. ret.dtype = dtype;
  395. ret.storage = scalar2storage((dt_float32)obj, cn, dtype);
  396. return true;
  397. }
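// Converts an arbitrary Python object (float, int/bool, tuple, list, None or
// numpy array) into HostTensorArgs, trying the cheap exact-type paths before
// falling back to pyarr2hval.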
  398. HostTensorArgs pyobj2hval(py::object obj, CompNode cn, DType dtype) {
  399. HostTensorArgs ret;
  400. bool success = false;
  401. // check order: float -> int -> tuple(int -> float) -> list(int -> float)
  402. // only handle `exact` pytype, isinstance also accepts subtype
  403. // for example, isinstance(True, int) == True
  404. if (obj.get_type().is(py_type<py::float_>())) {
  405. success = pyfloat2hval(py::float_(obj), cn, dtype, ret);
  406. } else if (obj.get_type().is(py_type<py::int_>())) { // py::bool_ is py::int_
  407. success = pyint2hval(py::int_(obj), cn, dtype, ret);
  408. } else if (obj.get_type().is(py_type<py::tuple>())) {
  409. success = pyseq2hval<py::tuple>(py::tuple(obj), cn, dtype, ret);
  410. } else if (obj.get_type().is(py_type<py::list>())) {
  411. success = pyseq2hval<py::list>(py::list(obj), cn, dtype, ret);
  412. } else if (obj.is_none()) {
  413. obj = py::list(0);
  414. }
  415. if (!success) {
  416. success = pyarr2hval(obj, cn, dtype, ret);
  417. }
  418. mgb_assert(success);
  419. return ret;
  420. }
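// Minimal argument-parsing machinery for the Tensor(...) constructor: PyArgDescs
// lists the accepted keyword names with their default-value factories, and
// name2idx (defined below) maps a keyword name to its positional index.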
  421. struct PyArgDesc {
  422. const char* name;
  423. py::object (*default_value)();
  424. };
  425. struct PyArgDescs {
  426. std::vector<PyArgDesc> items;
  427. ssize_t (*name2idx)(const char* name);
  428. };
  429. py::tuple parse_args(py::tuple args, const PyArgDescs& descs) {
  430. size_t nr_args = args.size();
  431. size_t nr_items = descs.items.size();
  432. mgb_assert(nr_args <= nr_items, "too many args");
  433. if (nr_args == nr_items) {
  434. return args;
  435. }
  436. py::tuple ret(nr_items);
  437. for (size_t i = 0; i < nr_args; ++i) {
  438. ret[i] = args[i];
  439. }
  440. for (size_t i = nr_args; i < nr_items; ++i) {
  441. ret[i] = descs.items[i].default_value();
  442. }
  443. return ret;
  444. }
  445. py::tuple parse_args_and_kwargs(
  446. py::tuple args, py::dict kwargs, const PyArgDescs& descs) {
  447. size_t nr_args = args.size();
  448. size_t nr_kwargs = kwargs.size();
  449. size_t nr_items = descs.items.size();
  450. mgb_assert(nr_args + nr_kwargs <= nr_items, "too many args");
  451. if (nr_args == nr_items) {
  452. return args;
  453. }
  454. py::tuple ret(nr_items);
  455. for (size_t i = 0; i < nr_args; ++i) {
  456. ret[i] = args[i];
  457. }
  458. bool has_value[nr_items - nr_args];
  459. for (size_t i = nr_args; i < nr_items; ++i) {
  460. has_value[i - nr_args] = false;
  461. }
  462. for (auto&& [k, v] : kwargs) {
  463. auto key = py::str(k).cast<std::string>();
  464. ssize_t index = descs.name2idx(key.c_str());
  465. mgb_assert(index >= nr_args);
  466. ret[index] = v;
  467. has_value[index - nr_args] = true;
  468. }
  469. for (size_t i = nr_args; i < nr_items; ++i) {
  470. if (!has_value[i - nr_args]) {
  471. ret[i] = descs.items[i].default_value();
  472. }
  473. }
  474. return ret;
  475. }
  476. CompNode as_comp_node(const std::string& name) {
  477. thread_local struct {
  478. std::string name;
  479. CompNode cn;
  480. } cached;
  481. if (cached.name != name) {
  482. cached.name = name;
  483. cached.cn = CompNode::load(name);
  484. }
  485. return cached.cn;
  486. }
  487. CompNode as_comp_node(py::object py_device) {
  488. std::optional<std::string> device_name;
  489. if (py_device.is_none() || py::str::check_(py_device)) {
  490. auto cls = py::handle(reinterpret_cast<PyObject*>(py_tensor_type));
  491. auto dmap_callback = cls.attr("dmap_callback");
  492. std::string name;
  493. if (dmap_callback.is_none() && py_device.is_none()) {
  494. name = get_default_device();
  495. } else {
  496. if (py_device.is_none()) {
  497. py_device = py::str(get_default_device());
  498. }
  499. if (!dmap_callback.is_none()) {
  500. py_device = dmap_callback(py_device);
  501. }
  502. name = py::str(py_device).cast<std::string>();
  503. }
  504. return as_comp_node(name);
  505. } else {
  506. if (py::isinstance(py_device, py_device_type)) {
  507. py_device = py_device.attr("_cn");
  508. }
  509. mgb_assert(py::isinstance(py_device, py_comp_node_type));
  510. return py_device.cast<CompNode>();
  511. }
  512. }
  513. template <char... Chars>
  514. bool compare_cstr(const char* cstr) {
  515. return (((*cstr++) == Chars) && ...) && *cstr == '\0';
  516. }
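// Hand-rolled keyword lookup for the Tensor constructor: returns the positional
// index of "data", "dtype", "device", "is_const", "no_cache" or "name", or -1.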
  517. ssize_t name2idx(const char* name) {
  518. const char* ch = name;
  519. // TODO: trie
  520. // clang-format off
  521. switch (*ch++) {
  522. case 'd':
  523. switch (*ch++) {
  524. // data
  525. case 'a': return compare_cstr<'t', 'a'>(ch) ? 0 : -1;
  526. // dtype
  527. case 't': return compare_cstr<'y', 'p', 'e'>(ch) ? 1 : -1;
  528. // device
  529. case 'e': return compare_cstr<'v', 'i', 'c', 'e'>(ch) ? 2 : -1;
  530. }
  531. case 'i':
  532. // is_const
  533. return compare_cstr<'s', '_', 'c', 'o', 'n', 's', 't'>(ch) ? 3 : -1;
  534. case 'n':
  535. switch (*ch++) {
  536. // no_cache
  537. case 'o': return compare_cstr<'_', 'c', 'a', 'c', 'h', 'e'>(ch) ? 4 : -1;
  538. // name
  539. case 'a': return compare_cstr<'m', 'e'>(ch) ? 5 : -1;
  540. }
  541. }
  542. // clang-format on
  543. return -1;
  544. }
  545. } // namespace
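// Tensor constructor exposed to Python: parses positional and keyword arguments
// (data, dtype, device, is_const, no_cache, name) against the descriptors above,
// converts the data to a host value and creates the underlying imperative tensor.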
  546. TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
  547. static PyArgDescs descs = {
  548. {
  549. {"data", []() -> py::object { return py::none(); }},
  550. {"dtype", []() -> py::object { return py::none(); }},
  551. {"device", []() -> py::object { return py::none(); }},
  552. {"is_const", []() -> py::object { return py::bool_(false); }},
  553. {"no_cache", []() -> py::object { return py::bool_(false); }},
  554. {"name", []() -> py::object { return py::none(); }},
  555. },
  556. name2idx};
  557. py::detail::loader_life_support life_sup; // FIXME!!! required to cast DType
  558. auto tup = py::reinterpret_borrow<py::tuple>(args);
  559. if (kwargs) {
  560. tup = parse_args_and_kwargs(
  561. tup, py::reinterpret_borrow<py::dict>(kwargs), descs);
  562. } else {
  563. tup = parse_args(tup, descs);
  564. }
  565. mgb_assert(tup.size() == 6);
  566. if (auto* t = try_cast(tup[0].ptr())) {
  567. m_tensor = t->m_tensor->copy();
  568. } else {
  569. auto data = tup[0];
  570. DType dtype = tup[1].cast<DType>();
  571. bool is_const = tup[3].cast<bool>();
  572. bool no_cache = tup[4].cast<bool>();
  573. std::string name;
  574. if (!tup[5].is_none()) {
  575. name = tup[5].cast<std::string>();
  576. }
  577. CompNode cn = as_comp_node(tup[2]);
  578. {
  579. CreateTensor::Kind kind = is_const ? CreateTensor::Const
  580. : no_cache ? CreateTensor::Unique
  581. : CreateTensor::Common;
  582. auto&& hval = pyobj2hval(data, cn, dtype);
  583. auto val = imperative::apply(
  584. CreateTensor(kind, cn, hval.dtype, hval.shape), hval.storage)[0];
  585. m_tensor.emplace(val);
  586. }
  587. if (!name.empty()) {
  588. m_tensor->reset(imperative::apply(RenameValue(name), m_tensor->data())[0]);
  589. }
  590. }
  591. mgb_assert(m_tensor->data());
  592. }
  593. PyObject* TensorWrapper::module_trace_info() {
  594. if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) {
  595. if (module_trace_info->ptr()) {
  596. return module_trace_info->inc_ref().ptr();
  597. }
  598. }
  599. PyErr_SetString(
  600. PyExc_AttributeError,
  601. "Has no attribute named \'_NodeMixin__node\', please "
  602. "set it first");
  603. return nullptr;
  604. }
  605. void TensorWrapper::set_module_trace_info(PyObject* obj) {
  606. // TODO: erase when obj == nullptr
  607. module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj);
  608. }
  609. void TensorWrapper::_set_name(PyObject* dest) {
  610. auto py_dest = py::reinterpret_borrow<py::object>(dest);
  611. auto name = py_dest.cast<std::string>();
  612. m_tensor->set_name(name);
  613. }
  614. PyObject* TensorWrapper::_detail() {
  615. return py::str(m_tensor->data().unwrap().to_string()).release().ptr();
  616. }
  617. void TensorWrapper::_watch() {
  618. m_tensor->data().watch();
  619. }
  620. PyObject* TensorWrapper::shape() {
  621. auto shape = m_tensor->shape();
  622. if (!shape) {
  623. Py_RETURN_NONE;
  624. }
  625. py::tuple ret(shape->ndim);
  626. for (size_t i = 0; i < shape->ndim; ++i) {
  627. ret[i] = shape->at(i);
  628. }
  629. return ret.release().ptr();
  630. }
  631. PyObject* TensorWrapper::dtype() {
  632. return py::cast(m_tensor->dtype()).release().ptr();
  633. }
  634. PyObject* TensorWrapper::device() {
  635. return py::cast(m_tensor->comp_node()).release().ptr();
  636. }
  637. PyObject* TensorWrapper::numpy() {
  638. auto hv = m_tensor->numpy();
  639. if (!hv) {
  640. PyErr_SetString(PyExc_ValueError, "tensor invalid");
  641. return nullptr;
  642. }
  643. auto arr = py::reinterpret_steal<py::array>(
  644. npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE));
  645. if (hv->shape().is_scalar()) {
  646. mgb_assert(PyArray_Check(arr.ptr()));
  647. return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
  648. }
  649. return arr.release().ptr();
  650. }
  651. void TensorWrapper::reset(PyObject* tensor) {
  652. TensorWrapper* t = TensorWrapper::try_cast(tensor);
  653. if (!t) {
  654. throw py::type_error("expect Tensor");
  655. }
  656. m_tensor->reset(t->m_tensor->data());
  657. }
  658. PyObject* TensorWrapper::detach() {
  659. auto detached = imperative::apply(DetachGrad(), m_tensor->data())[0];
  660. return TensorWrapper::make(py_tensor_type, detached).release().ptr();
  661. }
  662. PyObject* TensorWrapper::_dev_tensor() {
  663. auto dv = m_tensor->data().dev_tensor();
  664. // TODO: handle scalar
  665. return py::cast(dv->as_nd(true)).release().ptr();
  666. }
  667. void TensorWrapper::_drop() {
  668. imperative::apply(DTRCommand(DTRCommand::Drop), m_tensor->data());
  669. }
  670. PyObject* TensorWrapper::isscalar() {
  671. if (m_tensor->is_scalar()) {
  672. Py_RETURN_TRUE;
  673. } else {
  674. Py_RETURN_FALSE;
  675. }
  676. }
  677. struct TensorWeakRef {
  678. ValueWeakRef data;
  679. TensorWeakRef(const TensorWrapper& tw) : data(tw.m_tensor->data()) {}
  680. py::object operator()() {
  681. if (auto p = data.lock()) {
  682. return TensorWrapper::make(py_tensor_type, p);
  683. }
  684. return py::none();
  685. }
  686. };
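// When METH_FASTCALL is available the functions below are registered directly;
// otherwise each one is wrapped so the classic (self, args-tuple) convention is
// adapted to the fastcall-style (self, args-array, nargs) signature.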
  687. #ifdef METH_FASTCALL
  688. #define MGE_PY_INTERFACE(NAME, FUNC) \
  689. { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
  690. #else
  691. #define WRAP_FUNC_PY35(FUNC) \
  692. PyObject* py35_##FUNC(PyObject* self, PyObject* args) { \
  693. auto* arr = &PyTuple_GET_ITEM(args, 0); \
  694. auto size = PyTuple_GET_SIZE(args); \
  695. return FUNC(self, arr, size); \
  696. }
  697. WRAP_FUNC_PY35(py_apply);
  698. WRAP_FUNC_PY35(dtype_promotion);
  699. WRAP_FUNC_PY35(get_device);
  700. WRAP_FUNC_PY35(make_shape_tuple);
  701. WRAP_FUNC_PY35(getitem_cpp);
  702. WRAP_FUNC_PY35(setitem_cpp);
  703. WRAP_FUNC_PY35(split_cpp);
  704. WRAP_FUNC_PY35(expand_dims_cpp);
  705. WRAP_FUNC_PY35(squeeze_cpp);
  706. WRAP_FUNC_PY35(transpose_cpp);
  707. WRAP_FUNC_PY35(broadcast_cpp);
  708. WRAP_FUNC_PY35(reshape_cpp);
  709. WRAP_FUNC_PY35(adaptive_pool2d_cpp);
  710. WRAP_FUNC_PY35(Const);
  711. WRAP_FUNC_PY35(astype_cpp);
  712. WRAP_FUNC_PY35(matmul_cpp);
  713. WRAP_FUNC_PY35(batched_matmul_cpp);
  714. WRAP_FUNC_PY35(convert_single_value_cpp);
  715. WRAP_FUNC_PY35(convert_inputs_cpp);
  716. WRAP_FUNC_PY35(astensor1d_cpp);
  717. WRAP_FUNC_PY35(pixel_shuffle_cpp);
  718. #undef WRAP_FUNC_PY35
  719. #define MGE_PY_INTERFACE(NAME, FUNC) \
  720. { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
  721. #endif
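// Module initialization: creates the interpreter channel, installs the default
// transformation stack (eval, scalar, dtype-promote, dim-expansion), registers
// exception translators, and exposes Tensor, SymbolVar, Trace and the helper
// functions above to Python.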
  722. void init_tensor(py::module m) {
  723. imperative::Tensor::static_initialize();
  724. static auto& transformations = TransformationManager::get_instance();
  725. using Segment = TransformationManager::Segment;
  726. using Channel = interpreter::Interpreter::Channel;
  727. auto* channel =
  728. imperative::ResourceManager::create_global<std::unique_ptr<Channel>>(
  729. interpreter::Interpreter::inst().create_channel())
  730. ->get();
  731. interpreter_for_py = channel;
  732. MGB_MARK_USED_VAR(
  733. transformations
  734. .register_at<Segment::Eval>(
  735. std::make_shared<InterpreterTransformation>(
  736. std::shared_ptr<Channel>(channel, [](Channel*) {})))
  737. .release());
  738. MGB_MARK_USED_VAR(transformations
  739. .register_at<Segment::Scalar>(
  740. std::make_shared<ScalarTransformation>())
  741. .release());
  742. MGB_MARK_USED_VAR(transformations
  743. .register_at<Segment::DTypePromote>(
  744. std::make_shared<DTypePromoteTransformation>())
  745. .release());
  746. MGB_MARK_USED_VAR(transformations
  747. .register_at<Segment::DimExpansion>(
  748. std::make_shared<DimExpansionTransformation>())
  749. .release());
  750. static py::exception<interpreter::AsyncError> py_async_error(
  751. m, "AsyncError", PyExc_RuntimeError);
  752. py::register_exception_translator([](std::exception_ptr p) {
  753. try {
  754. if (p)
  755. std::rethrow_exception(p);
  756. } catch (const interpreter::AsyncError& e) {
  757. pyext17::pybind11_translate_exception(e.nested_ptr());
  758. if (PyErr_Occurred()) {
  759. PyObject *exc, *val, *tb;
  760. PyErr_Fetch(&exc, &val, &tb);
  761. PyErr_NormalizeException(&exc, &val, &tb);
  762. if (tb) {
  763. PyException_SetTraceback(val, tb);
  764. }
  765. auto val2 = py_async_error.py::object::operator()(
  766. "An async error is reported. See above for the actual cause."
  767. " Hint: This is where it is reported, not where it happened."
  768. " You may call `megengine.config.async_level = 0 "
  769. "to get better error reporting.");
  770. PyException_SetCause(
  771. val2.ptr(), val); // PyException_SetCause steals reference
  772. Py_XDECREF(exc);
  773. Py_XDECREF(tb);
  774. PyErr_Restore(
  775. py_async_error.inc_ref().ptr(), val2.release().ptr(), nullptr);
  776. } else {
  777. py_async_error("Unknown async error");
  778. }
  779. }
  780. });
  781. auto* tensor_type =
  782. TensorWrapper::wrap_t::type()
  783. .def<&TensorWrapper::numpy>("numpy")
  784. .def_getset<&TensorWrapper::shape>("shape")
  785. .def_getset<&TensorWrapper::dtype>("dtype")
  786. .def_getset<&TensorWrapper::device>("device")
  787. .def<&TensorWrapper::reset>("_reset")
  788. .def<&TensorWrapper::isscalar>("_isscalar")
  789. .def<&TensorWrapper::detach>("detach")
  790. // TODO: remove this
  791. .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
  792. .def<&TensorWrapper::_drop>("_drop")
  793. .def<&TensorWrapper::_detail>("_detail")
  794. .def<&TensorWrapper::_set_name>("_set_name")
  795. .def<&TensorWrapper::_watch>("_watch")
  796. .def_getset<
  797. &TensorWrapper::module_trace_info,
  798. &TensorWrapper::set_module_trace_info>("_NodeMixin__node")
  799. .finalize();
  800. if (!tensor_type)
  801. throw py::error_already_set();
  802. py::setattr(m, "Tensor", tensor_type);
  803. py::class_<TensorWeakRef>(m, "TensorWeakRef")
  804. .def(py::init<const TensorWrapper&>())
  805. .def("__call__", &TensorWeakRef::operator());
  806. py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
  807. .def_property_readonly(
  808. "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); })
  809. .def_property(
  810. "var", [](PySymbolVar* v) { return v->m_node; },
  811. [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; })
  812. .def_property_readonly(
  813. "device", [](PySymbolVar* v) { return v->m_node->comp_node(); })
  814. .def_property_readonly(
  815. "graph", [](PySymbolVar* v) { return v->m_node->owner_graph(); })
  816. .def_property_readonly(
  817. "shape",
  818. [](PySymbolVar* v) -> const TensorShape* {
  819. auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
  820. return mgr.infer_shape_fallible(v->m_node);
  821. })
  822. .def("numpy",
  823. [](PySymbolVar* v) {
  824. auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
  825. auto&& type = mgr.get_infer_type(v->m_node);
  826. using InferType = cg::static_infer::InferType;
  827. if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
  828. throw py::value_error("value invalid!");
  829. }
  830. auto* val = mgr.infer_value_fallible(v->m_node);
  831. if (!val) {
  832. throw py::value_error("value invalid!");
  833. }
  834. auto np_val = py::cast(*val).attr("numpy")();
  835. return np_val;
  836. })
  837. .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
  838. .def(py::init([](cg::VarNode* node) {
  839. return std::make_shared<PySymbolVar>(node);
  840. }),
  841. py::arg() = nullptr);
  842. static PyMethodDef method_defs[] = {
  843. MGE_PY_INTERFACE(apply, py_apply),
  844. MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
  845. MGE_PY_INTERFACE(get_device, get_device),
  846. MGE_PY_INTERFACE(make_shape_tuple, make_shape_tuple),
  847. MGE_PY_INTERFACE(getitem_cpp, getitem_cpp),
  848. MGE_PY_INTERFACE(setitem_cpp, setitem_cpp),
  849. MGE_PY_INTERFACE(split_cpp, split_cpp),
  850. MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp),
  851. MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp),
  852. MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
  853. MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
  854. MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
  855. MGE_PY_INTERFACE(adaptive_pool2d_cpp, adaptive_pool2d_cpp),
  856. MGE_PY_INTERFACE(Const, Const),
  857. MGE_PY_INTERFACE(astype_cpp, astype_cpp),
  858. MGE_PY_INTERFACE(matmul_cpp, matmul_cpp),
  859. MGE_PY_INTERFACE(batched_matmul_cpp, batched_matmul_cpp),
  860. MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
  861. MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
  862. MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp),
  863. MGE_PY_INTERFACE(pixel_shuffle_cpp, pixel_shuffle_cpp),
  864. {nullptr, nullptr, 0, nullptr}};
  865. for (auto&& def : method_defs) {
  866. if (def.ml_meth != nullptr) {
  867. auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
  868. if (!func)
  869. throw py::error_already_set();
  870. py::setattr(m, def.ml_name, func);
  871. }
  872. }
  873. static constexpr auto sync_py_task_q = [] {
  874. py::gil_scoped_release _;
  875. py_task_q.wait_all_task_finish();
  876. };
  877. m.def("clear_candidates", [channel]() { channel->clear_candidates(); });
  878. m.def("set_option", [channel](std::string name, size_t value) {
  879. channel->set_option(name, value);
  880. });
  881. m.def("get_option",
  882. [channel](std::string name) { return channel->get_option(name); });
  883. m.def("push_scope", [channel](std::string name) {
  884. Transformation::push_scope(name);
  885. channel->push_scope(name);
  886. });
  887. m.def("pop_scope", [channel](std::string name) {
  888. channel->pop_scope(name);
  889. Transformation::pop_scope(name);
  890. });
  891. m.def("start_profile", [channel](imperative::Profiler::options_t options) {
  892. channel->sync();
  893. imperative::Profiler::load_options(std::move(options));
  894. imperative::Profiler::start_profile();
  895. channel->start_profile();
  896. });
  897. m.def("stop_profile", [channel]() -> std::function<void(std::string, std::string)> {
  898. channel->stop_profile();
  899. channel->sync();
  900. CompNode::sync_all();
  901. imperative::Profiler::stop_profile();
  902. auto results = std::make_shared<imperative::Profiler::bundle_t>(
  903. imperative::Profiler::collect());
  904. return [results = results](std::string basename, std::string format) mutable {
  905. imperative::Profiler::dump_profile(basename, format, std::move(*results));
  906. results = nullptr;
  907. };
  908. });
  909. m.def("enable_cupti", &cupti::enable);
  910. m.def("disable_cupti", &cupti::disable);
  911. m.def("cupti_available", &cupti::available);
  912. m.def("sync", [channel]() {
  913. if (channel->check_available()) {
  914. channel->sync();
  915. }
  916. sync_py_task_q();
  917. });
  918. m.def("full_sync", [channel]() {
  919. if (channel->check_available()) {
  920. channel->sync();
  921. }
  922. CompNode::sync_all();
  923. CompNode::foreach ([](CompNode cn) {
  924. auto err = cn.check_async_error();
  925. mgb_assert(!err, "%s", err->what());
  926. });
  927. sync_py_task_q();
  928. });
  929. m.def("close", [channel]() {
  930. channel->close();
  931. sync_py_task_q();
  932. });
  933. py::handle grad_key_type =
  934. GradKeyWrapper::wrap_t::type()
  935. .def<&GradKeyWrapper::attach>("attach")
  936. .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
  937. .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>(
  938. "name")
  939. .def<&GradKeyWrapper::enter>("enter")
  940. .def<&GradKeyWrapper::exit>("exit")
  941. .def<&GradKeyWrapper::suppress>("suppress")
  942. .def<&GradKeyWrapper::resume>("resume")
  943. .finalize();
  944. if (!grad_key_type)
  945. throw py::error_already_set();
  946. py::setattr(m, "GradKey", grad_key_type);
  947. m.def("backward", &GradKeyWrapper::backward);
  948. m.def("get_backward_closure", &GradKeyWrapper::get_backward_closure);
  949. m.def("set_py_tensor_type", [](py::object type_obj) {
  950. py_tensor_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
  951. });
  952. m.def("set_py_device_type",
  953. [](py::object type_obj) { py_device_type = type_obj.inc_ref(); });
  954. /**
  955. * \brief trace proxy
  956. *
  957. */
  958. struct Trace {
  959. bool symbolic = false;
  960. bool no_exec = false;
  961. bool capture_as_const = false;
  962. bool profile = false;
  963. bool record_input_shapes = false;
  964. py::function options_visitor;
  965. std::shared_ptr<TracingTransformation> tracing;
  966. std::shared_ptr<CompiledTransformation> compiled;
  967. std::shared_ptr<LazyEvalTransformation> lazy_eval;
  968. std::pair<size_t, std::shared_ptr<GraphProfiler>> profiler;
  969. std::optional<TraceResult> trace_result;
  970. std::function<bool(py::object, py::object)> array_comparator;
  971. std::unique_ptr<CleanupGuard<>> tracing_guard;
  972. std::unique_ptr<CleanupGuard<>> compiled_guard;
  973. std::unique_ptr<CleanupGuard<>> lazy_eval_guard;
  974. bool compare_value(ValueRef lhs, ValueRef rhs) {
  975. auto lvalue = lhs.cast_ref<HostValue>();
  976. auto rvalue = rhs.cast_ref<HostValue>();
  977. if (lvalue->shape() != rvalue->shape()) {
  978. return false;
  979. }
  980. if (lvalue->shape().total_nr_elems() == 1) {
  981. return lvalue->item() == rvalue->item();
  982. }
  983. HostTensorND lnd = lvalue->as_nd(true);
  984. HostTensorND rnd = rvalue->as_nd(true);
  985. auto larr = py::reinterpret_steal<py::array>(
  986. npy::ndarray_from_tensor(lnd, npy::ShareType::TRY_SHARE));
  987. auto rarr = py::reinterpret_steal<py::array>(
  988. npy::ndarray_from_tensor(rnd, npy::ShareType::TRY_SHARE));
  989. return array_comparator(larr, rarr);
  990. }
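// Entering the trace scope: on the first run install a TracingTransformation
// (plus a LazyEvalTransformation when symbolic); on later runs reuse or build the
// CompiledTransformation from the recorded trace and start executing it.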
  991. void enter() {
  992. auto& self = *this;
  993. if (!self.trace_result) { // untraced
  994. self.tracing = std::make_shared<TracingTransformation>(
  995. self.capture_as_const, self.record_input_shapes);
  996. if (self.symbolic) {
  997. self.lazy_eval =
  998. std::make_shared<LazyEvalTransformation>(self.no_exec);
  999. self.options_visitor(py::cast(&self.lazy_eval->options()));
  1000. }
  1001. } else if (!self.compiled) { // traced but not compiled
  1002. using namespace std::placeholders;
  1003. self.compiled = std::make_shared<CompiledTransformation>(
  1004. *self.trace_result, self.record_input_shapes);
  1005. self.compiled->set_value_comparator(
  1006. std::bind(&Trace::compare_value, this, _1, _2));
  1007. self.options_visitor(py::cast(&self.compiled->options()));
  1008. self.compiled->compile();
  1009. }
  1010. // register transformations
  1011. if (self.compiled) {
  1012. if (self.profile) {
  1013. auto& current_graph = self.compiled->graph();
  1014. if (self.profiler.first != self.compiled->graph().id()) {
  1015. // graph changed
  1016. self.profiler = std::make_pair(
  1017. current_graph.id(),
  1018. std::make_shared<GraphProfiler>(&current_graph));
  1019. }
  1020. }
  1021. compiled_guard =
  1022. transformations.register_at<Segment::Trace>(self.compiled);
  1023. // start executing now because InputCallback depends on it
  1024. self.compiled->execute();
  1025. } else if (self.tracing) {
  1026. tracing_guard =
  1027. transformations.register_at<Segment::Trace>(self.tracing);
  1028. if (self.lazy_eval) {
  1029. lazy_eval_guard =
  1030. transformations.register_at<Segment::Eval>(self.lazy_eval);
  1031. }
  1032. } else {
  1033. mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
  1034. }
  1035. }
  1036. void exit() {
  1037. auto& self = *this;
  1038. if (self.tracing) {
  1039. tracing_guard.reset();
  1040. self.trace_result = self.tracing->get_result();
  1041. self.tracing.reset();
  1042. if (self.lazy_eval) {
  1043. auto lazy_eval = std::move(self.lazy_eval);
  1044. lazy_eval_guard.reset();
  1045. lazy_eval->check_exception();
  1046. }
  1047. } else if (self.compiled) {
  1048. compiled_guard.reset();
  1049. self.compiled->wait();
  1050. } else {
  1051. mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
  1052. }
  1053. }
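// Serializes the traced graph: input/output "marks" recorded during tracing are
// resolved to variable indices and dumped into the given ComputingGraph.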
  1054. VarNodeArray dump(
  1055. std::shared_ptr<ComputingGraph> graph,
  1056. std::vector<std::tuple<std::string, std::string, TensorShape>> inputs,
  1057. std::vector<std::pair<std::string, std::string>> outputs,
  1058. bool prefer_input_names) {
  1059. auto& self = *this;
  1060. mgb_assert(self.trace_result);
  1061. // mark is like "arg_0", "kwarg_xxx", "output_0" ...
  1062. std::unordered_map<std::string, size_t> mark2var;
  1063. for (size_t i = 0; i < self.trace_result->vars.size(); ++i) {
  1064. auto& name = self.trace_result->vars[i].mark;
  1065. if (!name.empty()) {
  1066. mark2var[name] = i;
  1067. }
  1068. }
  1069. std::vector<std::tuple<size_t, std::string, TensorShape>> input_vars;
  1070. std::vector<std::pair<size_t, std::string>> output_vars;
  1071. for (auto&& [input_mark, input_name, input_shape] : inputs) {
  1072. mgb_assert(input_shape.ndim, "input shape invalid");
  1073. input_vars.push_back(
  1074. {mark2var.at(input_mark), input_name, input_shape});
  1075. }
  1076. for (auto&& [output_name, repr] : outputs) {
  1077. output_vars.push_back({mark2var.at(output_name), repr});
  1078. }
  1079. self.options_visitor(py::cast(&graph->options()));
  1080. auto vars = self.trace_result->dump(
  1081. *graph, input_vars, output_vars, prefer_input_names);
  1082. return vars;
  1083. }
  1084. };
  1085. py::class_<Trace>(m, "Trace")
  1086. .def(py::init<>())
  1087. .def_readwrite("record_input_shapes", &Trace::record_input_shapes)
  1088. .def_readwrite("array_comparator", &Trace::array_comparator)
  1089. .def_readwrite("profile", &Trace::profile)
  1090. .def_property_readonly(
  1091. "options",
  1092. [](Trace& self) {
  1093. if (self.compiled) {
  1094. return &self.compiled->options();
  1095. } else {
  1096. return (ComputingGraph::Options*)nullptr;
  1097. }
  1098. })
  1099. .def("get_profile",
  1100. [](Trace& self) -> py::object {
  1101. if (self.profiler.second && self.compiled) {
  1102. auto json = self.profiler.second->to_json_full(
  1103. self.compiled->graph().current_comp_seq());
  1104. return py::str(json->to_string());
  1105. } else {
  1106. return py::none();
  1107. }
  1108. })
  1109. .def_readwrite("symbolic", &Trace::symbolic)
  1110. .def_readwrite("capture_as_const", &Trace::capture_as_const)
  1111. .def_readwrite("no_exec", &Trace::no_exec)
  1112. .def_readwrite("options_visitor", &Trace::options_visitor)
  1113. .def("enter", &Trace::enter)
  1114. .def("exit", &Trace::exit)
  1115. .def("dump", &Trace::dump)
  1116. .def("begin_excluded_region",
  1117. [](Trace& self) {
  1118. mgb_assert(bool(self.tracing) ^ bool(self.compiled));
  1119. if (self.tracing) {
  1120. self.tracing_guard.reset();
  1121. } else if (self.compiled) {
  1122. self.compiled_guard.reset();
  1123. }
  1124. })
  1125. .def("end_excluded_region", [](Trace& self) {
  1126. mgb_assert(bool(self.tracing) ^ bool(self.compiled));
  1127. if (self.tracing) {
  1128. self.tracing_guard =
  1129. transformations.register_at<Segment::Trace>(self.tracing);
  1130. } else if (self.compiled) {
  1131. self.compiled_guard =
  1132. transformations.register_at<Segment::Trace>(self.compiled);
  1133. }
  1134. });
  1135. m.def("reduce_to_scalar", [](py::object op, py::object tensor) -> py::object {
  1136. auto reduce_to_scalar = [](const OpDef& op, const ValueRef& input) {
  1137. auto make_scalar_shape = [&](CompNode device) {
  1138. return imperative::apply(
  1139. CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
  1140. HostStorage::make(device))[0];
  1141. };
  1142. return imperative::apply(op, input, make_scalar_shape(*input.device()))[0];
  1143. };
  1144. if (py::isinstance<PySymbolVar>(tensor)) {
  1145. auto* graph = tensor.cast<PySymbolVar*>()->m_node->owner_graph();
  1146. SymbolVarContext context(graph);
  1147. context.init();
  1148. auto output = reduce_to_scalar(
  1149. *op.cast<std::shared_ptr<OpDef>>(), context.symvar2val(tensor));
  1150. auto typeobj = tensor.get_type();
  1151. return context.val2symvar(typeobj, output);
  1152. } else {
  1153. auto* tw = TensorWrapper::try_cast(tensor.ptr());
  1154. auto output = reduce_to_scalar(
  1155. *op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data());
  1156. return TensorWrapper::make(py_tensor_type, output);
  1157. }
  1158. });
  1159. m.def("name_tensor", [](std::string name, py::object tensor) {
  1160. auto* tw = TensorWrapper::try_cast(tensor.ptr());
  1161. auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0];
  1162. tw->m_tensor->reset(output);
  1163. });
  1164. m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool {
  1165. SmallVector<ValueRef> values(tensors.size());
  1166. for (size_t i = 0; i < tensors.size(); ++i) {
  1167. values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
  1168. }
  1169. auto outputs = imperative::apply(GetGradKey(), values);
  1170. if (outputs[0].is<GradKeyValue>()) {
  1171. return true;
  1172. } else {
  1173. return false;
  1174. }
  1175. });
  1176. m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object {
  1177. SmallVector<ValueRef> values(tensors.size());
  1178. for (size_t i = 0; i < tensors.size(); ++i) {
  1179. values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
  1180. }
  1181. auto output = imperative::apply(GetGradKey(), values)[0];
  1182. if (!output) {
  1183. return py::none();
  1184. }
  1185. return py::reinterpret_borrow<py::object>(GradKeyWrapper::wrap_t::pycast(
  1186. GradKeyWrapper::get(output.cast<GradKeyValue>())));
  1187. });
  1188. m.def("set_grad", [](py::function backward_fn, std::vector<py::object> inputs,
  1189. std::vector<py::object> outputs) {
  1190. GenericFunction generic_backward_fn =
  1191. [backward_fn](Span<ValueRef> output_grads) -> ValueRefList {
  1192. py::list output_grad_tws;
  1193. for (auto&& output_grad : output_grads) {
  1194. if (output_grad) {
  1195. output_grad_tws.append(
  1196. TensorWrapper::make(py_tensor_type, output_grad));
  1197. } else {
  1198. output_grad_tws.append(py::none());
  1199. }
  1200. }
  1201. py::tuple input_grad_tws = backward_fn(*output_grad_tws);
  1202. ValueRefList input_grads(input_grad_tws.size());
  1203. for (size_t i = 0; i < input_grad_tws.size(); ++i) {
  1204. auto input_grad_tw = input_grad_tws[i];
  1205. if (!input_grad_tw.is_none()) {
  1206. input_grads[i] =
  1207. py::cast<TensorWrapper>(input_grad_tw).m_tensor->data();
  1208. } else {
  1209. input_grads[i] = {};
  1210. }
  1211. }
  1212. return input_grads;
  1213. };
  1214. SmallVector<ValueRef> values(inputs.size() + outputs.size());
  1215. for (size_t i = 0; i < inputs.size(); ++i) {
  1216. values[i] = inputs[i].cast<TensorWrapper>().m_tensor->data();
  1217. }
  1218. for (size_t i = 0; i < outputs.size(); ++i) {
  1219. values[i + inputs.size()] =
  1220. outputs[i].cast<TensorWrapper>().m_tensor->data();
  1221. }
  1222. auto wrapped_output_values =
  1223. imperative::apply(SetGrad(generic_backward_fn, inputs.size()), values);
  1224. std::vector<py::object> wrapped_outputs;
  1225. mgb_assert(wrapped_output_values.size() == outputs.size());
  1226. for (auto&& output_value : wrapped_output_values) {
  1227. wrapped_outputs.push_back(
  1228. TensorWrapper::make(py_tensor_type, output_value));
  1229. }
  1230. return wrapped_outputs;
  1231. });
  1232. static py::function module_trace_hook;
  1233. static auto get_module_trace = [] {
  1234. static std::shared_ptr<ModuleTraceTransformation> module_trace_transformation;
  1235. if (!module_trace_transformation) {
  1236. mgb_assert(module_trace_hook);
  1237. module_trace_transformation =
  1238. std::make_shared<ModuleTraceTransformation>(module_trace_hook);
  1239. MGB_MARK_USED_VAR(transformations
  1240. .register_at<Segment::ModuleTrace>(
  1241. module_trace_transformation)
  1242. .release());
  1243. }
  1244. return module_trace_transformation;
  1245. };
  1246. m.def("set_cpp_use_symbolic_shape", &set_cpp_use_symbolic_shape);
  1247. m.def("set_module_tracing", [=] { get_module_trace()->enable(); });
  1248. m.def("unset_module_tracing", [=] { get_module_trace()->disable(); });
  1249. m.def("is_tracing_module", [=] { return get_module_trace()->enabled(); });
  1250. m.def("set_module_trace_hook", [](py::function function) {
  1251. module_trace_hook = function;
  1252. module_trace_hook.inc_ref();
  1253. });
  1254. auto atexit = py::module::import("atexit");
  1255. atexit.attr("register")(py::cpp_function([]() { module_trace_hook = {}; }));
  1256. m.def("begin_record_values", [] { Value::begin_record_values(); });
  1257. m.def("end_record_values", [] {
  1258. std::vector<std::pair<size_t, std::string>> reprs;
  1259. auto values = Value::end_record_values();
  1260. for (auto&& value : values) {
  1261. reprs.push_back({value.id(), value.to_string()});
  1262. }
  1263. return reprs;
  1264. });
  1265. m.def("print_stats", [] { Stats::print(); });
  1266. m.def("reset_stats", [] { Stats::reset(); });
  1267. m.def("_get_convert_inputs",
  1268. []() -> bool { return DTypePromoteCfg::convert_input_enabled; });
  1269. m.def("_set_convert_inputs", [](bool flag) -> bool {
  1270. bool ret = DTypePromoteCfg::convert_input_enabled;
  1271. DTypePromoteCfg::convert_input_enabled = flag;
  1272. return ret;
  1273. });
  1274. m.def("_get_amp_dtype_autocast",
  1275. []() -> bool { return DTypePromoteCfg::amp_dtype_autocast_enabled; });
  1276. m.def("_set_amp_dtype_autocast", [](bool flag) -> bool {
  1277. bool ret = DTypePromoteCfg::amp_dtype_autocast_enabled;
  1278. DTypePromoteCfg::amp_dtype_autocast_enabled = flag;
  1279. return ret;
  1280. });
  1281. static auto get_amp_prec_dtype = [](bool is_high) -> std::string {
  1282. DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
  1283. : DTypePromoteCfg::amp_low_prec_dtype;
  1284. mgb_assert(target.category() == DTypeCategory::FLOAT);
  1285. std::string ret = target.name();
  1286. transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
  1287. return ret;
  1288. };
  1289. static auto set_amp_prec_dtype = [](bool is_high,
  1290. std::string dtype_name) -> std::string {
  1291. DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
  1292. : DTypePromoteCfg::amp_low_prec_dtype;
  1293. std::string ret = target.name();
  1294. if (dtype_name == "float32") {
  1295. target = dtype::Float32();
  1296. } else if (dtype_name == "float16") {
  1297. target = dtype::Float16();
  1298. } else if (dtype_name == "bfloat16") {
  1299. target = dtype::BFloat16();
  1300. } else {
  1301. mgb_assert(
  1302. false, "casted type of amp should be float, but you give %s\n",
  1303. dtype_name.c_str());
  1304. }
  1305. transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
  1306. return ret;
  1307. };
  1308. m.def("_get_amp_high_prec_dtype",
  1309. []() -> std::string { return get_amp_prec_dtype(true); });
  1310. m.def("_set_amp_high_prec_dtype", [](std::string dtype_name) -> std::string {
  1311. return set_amp_prec_dtype(true, dtype_name);
  1312. });
  1313. m.def("_get_amp_low_prec_dtype",
  1314. []() -> std::string { return get_amp_prec_dtype(false); });
  1315. m.def("_set_amp_low_prec_dtype", [](std::string dtype_name) -> std::string {
  1316. return set_amp_prec_dtype(false, dtype_name);
  1317. });
  1318. m.def("_clear_algorithm_cache", [] { megdnn::AlgorithmCache::instance().clear(); });
  1319. py::register_exception<TraceError>(m, "TraceError");
  1320. }
  1321. #undef MGE_PY_INTERFACE
  1322. } // namespace mgb::imperative::python