
// tensor.cpp

#include "megbrain/common.h"
#include "megbrain/dtype.h"
#include "megbrain/imperative/cpp_cupti.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/transformations/dim_expansion.h"
#include "megbrain/imperative/transformations/dtype_promote.h"
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/format.h"
#include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/transformations/trace.h"
#include "megbrain/imperative/utils/map.h"
#include "megbrain/opr/io.h"
#include "megbrain/plugin/profiler.h"
#include "megbrain/utils/stats.h"
#include "megdnn/algorithm_cache.h"

#include "./common.h"
#include "./grad.h"
#include "./graph_rt.h"
#include "./helper.h"
#include "./module_trace.h"
#include "./numpy_dtypes.h"
#include "./tensor.h"
#include "./tensor_utils.h"
#include "./transformation.h"

#include <object.h>
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <pybind11/pytypes.h>
#include <pyerrors.h>

#include <iterator>
#include <range/v3/all.hpp>
#include <string>
#include <unordered_map>

#include "../../src/impl/mgb_cg_impl.h"
namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {

interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
PyTypeObject* py_tensor_type = nullptr;
PyTypeObject* py_varnode_type = nullptr;
pybind11::handle py_device_type = nullptr;
PyObject* cpp_use_symbolic_shape;

#define REGISTE_APPLY_FUNC(mode) \
    void set_##mode(py::object pyf) { mode = pyf.ptr(); }

REGISTE_APPLY_FUNC(cpp_use_symbolic_shape)

#undef REGISTE_APPLY_FUNC

PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs);
CompNode _get_device(PyObject* const* args, size_t nargs);
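
// py_apply(op, *tensors): applies an OpDef to the given tensor arguments and
// returns a tuple of result tensors. Non-tensor Python inputs are converted to
// const tensors only when dtype promotion of elemwise inputs is enabled; if any
// input is a VarNode wrapper, the outputs use the VarNode tensor type as well.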
PyObject* py_apply(
        PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) {
    try {
        // if (kwnames && PyTuple_GET_SIZE(kwnames)) {
        //     PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
        //     return nullptr;
        // }
        if (nargs < 2) {
            PyErr_SetString(
                    PyExc_TypeError,
                    "py_apply expects one Op and at least one tensor "
                    "as argument");
            return nullptr;
        }

        auto* py_op = args[0];

        ++args;
        --nargs;

        auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
        SmallVector<ValueRef, 8> tensors(nargs);

        mgb::CompNode target_cn;
        mgb::DType target_dtype;

        auto convert_pyinput_to_tensor = [&](size_t i) -> ValueRef {
            if (!target_dtype.valid()) {
                target_dtype = npy::dtype_np2mgb_descr(_dtype_promotion(args, nargs));
                target_cn = _get_device(args, nargs);
            }
            HostTensorND ht(target_cn);
            ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype);
            if (PyArray_Check(args[i]) || PyList_Check(args[i])) {  // non scalar
                // py_tuple is not allowed here because of tracing
                return imperative::apply(
                        CreateTensor(CreateTensor::Const, target_cn, ht.layout()),
                        HostStorage::make(ht.storage()))[0];
            } else {  // scalar
                return imperative::apply(
                        CreateTensor(CreateTensor::Const, target_cn, target_dtype, {}),
                        HostStorage::make(ht.storage()))[0];
            }
        };

        bool is_varnode_apply = false;
        for (size_t i = 0; i < nargs; ++i) {
            if (PyObject_TypeCheck(args[i], py_varnode_type)) {
                is_varnode_apply = true;
            }
            if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
                tensors[i] = tw->m_tensor->data();
            } else if (
                    DTypePromoteCfg::convert_input_enabled &&
                    op->same_type<Elemwise>()) {
                tensors[i] = convert_pyinput_to_tensor(i);
            } else {
                PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs");
                return nullptr;
            }
        }

        auto outputs = [&] { return imperative::apply(*op, tensors); }();
        size_t nout = outputs.size();
        auto ret = py::tuple(nout);
        PyTypeObject* py_type = is_varnode_apply ? py_varnode_type : py_tensor_type;
        for (size_t i = 0; i < nout; ++i) {
            ret[i] = TensorWrapper::make(py_type, std::move(outputs[i]));
        }
        return ret.release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
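
// Helpers in the anonymous namespace below build host tensor storage directly from
// Python scalars and sequences (bypassing numpy where possible) and parse the
// arguments of the Tensor constructor.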
namespace {

template <typename T>
py::handle py_type() {
    if constexpr (std::is_same_v<T, py::int_>) {
        return (PyObject*)&PyLong_Type;
    } else if constexpr (std::is_same_v<T, py::float_>) {
        return (PyObject*)&PyFloat_Type;
    } else if constexpr (std::is_same_v<T, py::tuple>) {
        return (PyObject*)&PyTuple_Type;
    } else if constexpr (std::is_same_v<T, py::list>) {
        return (PyObject*)&PyList_Type;
    } else {
        static_assert(std::is_same_v<T, T>);
    }
}

template <typename T>
auto scalar2storage(T val, CompNode cn, DType dtype) {
    using max_ctype_t = DTypeScalar::max_ctype;
    DTypeScalar scalar(dtype);
    scalar.set_retain_dtype(val);
    HostTensorStorage storage(cn);
    auto* raw_ptr = reinterpret_cast<dt_byte*>(new max_ctype_t());
    std::shared_ptr<dt_byte> raw_storage = {
            raw_ptr, [](dt_byte* ptr) { delete reinterpret_cast<max_ctype_t*>(ptr); }};
    storage.only_reset_raw_storage(cn, dtype.size(), raw_storage, 0);
    std::memcpy(storage.ptr(), scalar.storage(), dtype.size());
    return HostStorage::make(std::move(storage));
}

template <typename ctype>
auto vec2storage(Span<DTypeScalar> vec, CompNode cn, DType dtype) {
    mgb_assert(vec.size() <= MEGDNN_MAX_NDIM);
    // TODO: use storage cache and modify ConstTensorCache to return (Host, Device)
    auto* raw_ptr = new ctype[MEGDNN_MAX_NDIM];
    for (size_t i = 0; i < vec.size(); ++i) {
        raw_ptr[i] = vec[i].get_cast<ctype>();
    }
    mgb_assert(sizeof(ctype) == dtype.size());
    std::shared_ptr<dt_byte> raw_storage = {
            reinterpret_cast<dt_byte*>(raw_ptr),
            [](dt_byte* ptr) { delete[] reinterpret_cast<ctype*>(ptr); }};
    HostTensorStorage storage(cn);
    storage.only_reset_raw_storage(cn, sizeof(ctype) * vec.size(), raw_storage, 0);
    return HostStorage::make(std::move(storage));
}
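
// HostTensorArgs carries the shape/dtype/storage triple consumed by CreateTensor;
// the py*2hval functions below fill it from ints, floats, tuples, lists and arrays.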
struct HostTensorArgs {
    ValueShape shape;
    DType dtype;
    HostStorage::ref_t storage;

    HostTensorND as_tensor_nd() const {
        HostTensorND ret(CompNode::default_cpu(), shape.as_tensor_shape(), dtype);
        ret.only_reset_raw_storage(*storage);
        return ret;
    }
};

template <typename seq_type, typename ctype>
bool pyseq2hval(seq_type obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    auto size = obj.size();
    if (size > MEGDNN_MAX_NDIM) {
        return false;
    }
    ctype items[size];
    for (size_t i = 0; i < size; ++i) {
        py::handle item = obj[i];
        if (item.get_type().is(py_type<py::int_>())) {
            items[i] = (ctype)(dt_int32)item.template cast<py::int_>();
        } else if (item.get_type().is(py_type<py::float_>())) {
            items[i] = (ctype)(dt_float32)item.template cast<py::float_>();
        } else {
            return false;
        }
    }
    mgb_assert(sizeof(ctype) == dtype.size());
    auto* raw_ptr = new ctype[size];
    std::shared_ptr<dt_byte> raw_storage = {
            reinterpret_cast<dt_byte*>(raw_ptr),
            [](dt_byte* ptr) { delete[] reinterpret_cast<ctype*>(ptr); }};
    HostTensorStorage storage(cn);
    storage.only_reset_raw_storage(cn, sizeof(ctype) * size, raw_storage, 0);
    std::memcpy(storage.ptr(), items, sizeof(ctype) * size);
    ret.dtype = dtype;
    ret.shape = {size};
    ret.storage = HostStorage::make(std::move(storage));
    return true;
}

template <typename seq_type>
bool pyseq2hval(seq_type obj, CompNode cn, HostTensorArgs& ret) {
    auto size = obj.size();
    if (size > MEGDNN_MAX_NDIM) {
        return false;
    }
    DTypeScalar items[size];
    DType dtype;
    for (size_t i = 0; i < size; ++i) {
        auto&& item = obj[i];
        if (item.get_type().is(py_type<py::int_>())) {
            items[i] = (dt_int32)item.template cast<py::int_>();
            if (!dtype.valid()) {
                dtype = dtype::Int32();
            } else if (dtype != dtype::Int32() && dtype != dtype::Float32()) {
                return false;
            }
        } else if (item.get_type().is(py_type<py::float_>())) {
            items[i] = (dt_float32)item.template cast<py::float_>();
            if (!dtype.valid()) {
                dtype = dtype::Float32();
            } else if (dtype == dtype::Int32()) {
                dtype = dtype::Float32();
            } else if (dtype != dtype::Float32()) {
                return false;
            }
        } else {
            return false;
        }
    }
    if (!dtype.valid()) {
        dtype = dtype::Float32();
    }
    ret.dtype = dtype;
    ret.shape = {size};
    if (dtype == dtype::Int32()) {
        ret.storage = vec2storage<dt_int32>({items, size}, cn, dtype);
    } else if (dtype == dtype::Float32()) {
        ret.storage = vec2storage<dt_float32>({items, size}, cn, dtype);
    } else {
        mgb_assert(false);
    }
    return true;
}

template <typename seq_type>
bool pyseq2hval(seq_type obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    if (dtype == dtype::Int32()) {
        return pyseq2hval<seq_type, dt_int32>(obj, cn, dtype, ret);
    } else if (dtype == dtype::Float32()) {
        return pyseq2hval<seq_type, dt_float32>(obj, cn, dtype, ret);
    } else if (!dtype.valid()) {
        return pyseq2hval<seq_type>(obj, cn, ret);
    } else {
        return false;
    }
}

bool pyarr2hval(py::array obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    auto data = obj.cast<py::array>();
    auto strides = data.strides();
    bool need_squeeze = false;
    for (size_t i = 0; i < data.ndim(); ++i) {
        if (strides[i] == 0) {
            need_squeeze = true;
            break;
        }
    }
    if (need_squeeze) {
        std::vector<size_t> shape;
        for (size_t i = 0; i < data.ndim(); ++i) {
            shape.push_back(data.shape(i));
        }
        data = data.squeeze();
        data.resize(shape);
    }
    HostTensorND retnd(cn);
    retnd = npy::np2tensor(data.ptr(), npy::Meth::copy_into(&retnd), dtype);
    if (!dtype.valid()) {
        dtype = retnd.dtype();
    }
    mgb_assert(
            retnd.layout().is_empty() || retnd.layout().is_contiguous(),
            "host value should be contiguous");
    for (size_t i = 0; i < data.ndim(); ++i) {
        ret.shape[ret.shape.ndim++] = data.shape(i);
    }
    ret.dtype = dtype;
    ret.storage = HostStorage::make(retnd.storage());
    return true;
}

bool pyint2hval(py::int_ obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    if (!dtype.valid()) {
        dtype = dtype::Int32();
    }
    ret.dtype = dtype;
    ret.storage = scalar2storage((dt_int32)obj, cn, dtype);
    return true;
}

bool pyfloat2hval(py::float_ obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    if (!dtype.valid()) {
        dtype = dtype::Float32();
    }
    ret.dtype = dtype;
    ret.storage = scalar2storage((dt_float32)obj, cn, dtype);
    return true;
}

HostTensorArgs pyobj2hval(py::object obj, CompNode cn, DType dtype) {
    HostTensorArgs ret;
    bool success = false;
    // check order: float -> int -> tuple(int -> float) -> list(int -> float)
    // only handle `exact` pytype, isinstance also accepts subtype
    // for example, isinstance(True, int) == True
    if (obj.get_type().is(py_type<py::float_>())) {
        success = pyfloat2hval(py::float_(obj), cn, dtype, ret);
    } else if (obj.get_type().is(py_type<py::int_>())) {  // py::bool_ is py::int_
        success = pyint2hval(py::int_(obj), cn, dtype, ret);
    } else if (obj.get_type().is(py_type<py::tuple>())) {
        success = pyseq2hval<py::tuple>(py::tuple(obj), cn, dtype, ret);
    } else if (obj.get_type().is(py_type<py::list>())) {
        success = pyseq2hval<py::list>(py::list(obj), cn, dtype, ret);
    } else if (obj.is_none()) {
        obj = py::list(0);
    }
    if (!success) {
        success = pyarr2hval(obj, cn, dtype, ret);
    }
    mgb_assert(success);
    return ret;
}
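
// PyArgDesc/PyArgDescs describe the Tensor constructor signature. parse_args and
// parse_args_and_kwargs normalize positional and keyword arguments into a tuple of
// fixed size, filling in the declared defaults for missing entries.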
struct PyArgDesc {
    const char* name;
    py::object (*default_value)();
};

struct PyArgDescs {
    std::vector<PyArgDesc> items;
    ssize_t (*name2idx)(const char* name);
};

py::tuple parse_args(py::tuple args, const PyArgDescs& descs) {
    size_t nr_args = args.size();
    size_t nr_items = descs.items.size();
    mgb_assert(nr_args <= nr_items, "too many args");
    if (nr_args == nr_items) {
        return args;
    }
    py::tuple ret(nr_items);
    for (size_t i = 0; i < nr_args; ++i) {
        ret[i] = args[i];
    }
    for (size_t i = nr_args; i < nr_items; ++i) {
        ret[i] = descs.items[i].default_value();
    }
    return ret;
}

py::tuple parse_args_and_kwargs(
        py::tuple args, py::dict kwargs, const PyArgDescs& descs) {
    size_t nr_args = args.size();
    size_t nr_kwargs = kwargs.size();
    size_t nr_items = descs.items.size();
    mgb_assert(nr_args + nr_kwargs <= nr_items, "too many args");
    if (nr_args == nr_items) {
        return args;
    }
    py::tuple ret(nr_items);
    for (size_t i = 0; i < nr_args; ++i) {
        ret[i] = args[i];
    }
    bool has_value[nr_items - nr_args];
    for (size_t i = nr_args; i < nr_items; ++i) {
        has_value[i - nr_args] = false;
    }
    for (auto&& [k, v] : kwargs) {
        auto key = py::str(k).cast<std::string>();
        ssize_t index = descs.name2idx(key.c_str());
        mgb_assert(index >= nr_args);
        ret[index] = v;
        has_value[index - nr_args] = true;
    }
    for (size_t i = nr_args; i < nr_items; ++i) {
        if (!has_value[i - nr_args]) {
            ret[i] = descs.items[i].default_value();
        }
    }
    return ret;
}
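
// as_comp_node resolves a device argument (None, string, Device object or CompNode)
// to a CompNode, caching the most recent string lookup per thread and honoring the
// Tensor.dmap_callback device-mapping hook.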
CompNode as_comp_node(const std::string& name) {
    thread_local struct {
        std::string name;
        CompNode cn;
    } cached;
    if (cached.name != name) {
        cached.name = name;
        cached.cn = CompNode::load(name);
    }
    return cached.cn;
}

CompNode as_comp_node(py::object py_device) {
    std::optional<std::string> device_name;
    if (py_device.is_none() || py::str::check_(py_device)) {
        auto cls = py::handle(reinterpret_cast<PyObject*>(py_tensor_type));
        auto dmap_callback = cls.attr("dmap_callback");
        std::string name;
        if (dmap_callback.is_none() && py_device.is_none()) {
            name = get_default_device();
        } else {
            if (py_device.is_none()) {
                py_device = py::str(get_default_device());
            }
            if (!dmap_callback.is_none()) {
                py_device = dmap_callback(py_device);
            }
            name = py::str(py_device).cast<std::string>();
        }
        return as_comp_node(name);
    } else {
        if (py::isinstance(py_device, py_device_type)) {
            py_device = py_device.attr("_cn");
        }
        mgb_assert(py::isinstance(py_device, py_comp_node_type));
        return py_device.cast<CompNode>();
    }
}

template <char... Chars>
bool compare_cstr(const char* cstr) {
    return (((*cstr++) == Chars) && ...) && *cstr == '\0';
}

ssize_t name2idx(const char* name) {
    const char* ch = name;
    // TODO: trie
    // clang-format off
    switch (*ch++) {
    case 'd':
        switch (*ch++) {
        // data
        case 'a': return compare_cstr<'t', 'a'>(ch) ? 0 : -1;
        // dtype
        case 't': return compare_cstr<'y', 'p', 'e'>(ch) ? 1 : -1;
        // device
        case 'e': return compare_cstr<'v', 'i', 'c', 'e'>(ch) ? 2 : -1;
        }
    case 'i':
        // is_const
        return compare_cstr<'s', '_', 'c', 'o', 'n', 's', 't'>(ch) ? 3 : -1;
    case 'n':
        switch (*ch++) {
        // no_cache
        case 'o': return compare_cstr<'_', 'c', 'a', 'c', 'h', 'e'>(ch) ? 4 : -1;
        // name
        case 'a': return compare_cstr<'m', 'e'>(ch) ? 5 : -1;
        }
    case 'f':
        // format
        return compare_cstr<'o', 'r', 'm', 'a', 't'>(ch) ? 6 : -1;
    }
    // clang-format on
    return -1;
}

}  // namespace
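
// TensorWrapper(data, dtype, device, is_const, no_cache, name, format): either
// re-wraps an existing Tensor (verifying that any requested dtype/device/format
// matches) or builds a new tensor from a Python object or a VarNode.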
TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
    static PyArgDescs descs = {
            {
                    {"data", []() -> py::object { return py::none(); }},
                    {"dtype", []() -> py::object { return py::none(); }},
                    {"device", []() -> py::object { return py::none(); }},
                    {"is_const", []() -> py::object { return py::bool_(false); }},
                    {"no_cache", []() -> py::object { return py::bool_(false); }},
                    {"name", []() -> py::object { return py::none(); }},
                    {"format", []() -> py::object { return py::none(); }},
            },
            name2idx};
    py::detail::loader_life_support life_sup;  // FIXME!!! required to cast DType
    auto tup = py::reinterpret_borrow<py::tuple>(args);
    if (kwargs) {
        tup = parse_args_and_kwargs(
                tup, py::reinterpret_borrow<py::dict>(kwargs), descs);
    } else {
        tup = parse_args(tup, descs);
    }
    mgb_assert(tup.size() == 7);
    if (auto* t = try_cast(tup[0].ptr())) {
        m_tensor = t->m_tensor;
        // TODO: merge the two paths in arg parse
        if (!tup[1].is_none()) {
            auto dtype = tup[1].cast<DType>();
            mgb_assert(
                    dtype == m_tensor->dtype(), "dtype mismatch: %s vs %s",
                    dtype.name(), m_tensor->dtype().name());
        }
        if (!tup[2].is_none()) {
            auto device = as_comp_node(tup[2]);
            mgb_assert(
                    device == m_tensor->comp_node(), "device mismatch: %s vs %s",
                    device.to_string().c_str(),
                    m_tensor->comp_node().to_string().c_str());
        }
        mgb_assert(!tup[3].cast<bool>(), "expect is_const == False, got True");
        bool no_cache = tup[4].cast<bool>();
        if (no_cache) {
            // always copy because it's hard to tell whether this tensor is cached
            m_tensor = m_tensor->copy();
        }
        // ignore name
        if (!tup[6].is_none()) {
            Format format = tup[6].cast<std::string>();
            mgb_assert(
                    format == m_tensor->format(), "format mismatch: %s vs %s",
                    format.to_string().c_str(), m_tensor->format().to_string().c_str());
        }
    } else {
        auto data = tup[0];
        DType dtype = tup[1].cast<DType>();
        CompNode cn = as_comp_node(tup[2]);
        bool is_const = tup[3].cast<bool>();
        bool no_cache = tup[4].cast<bool>();
        std::string name;
        if (!tup[5].is_none()) {
            name = tup[5].cast<std::string>();
        }
        Format format;
        if (!tup[6].is_none()) {
            format = tup[6].cast<std::string>();
        }
        {
            CreateTensor::Kind kind = is_const ? CreateTensor::Const
                                    : no_cache ? CreateTensor::Unique
                                               : CreateTensor::Common;
            ValueRef val;
            if (py::isinstance(data, Py_Varnode)) {
                cg::VarNode* m_node = py::handle(data).cast<cg::VarNode*>();
                val = imperative::apply(
                        CreateNode(m_node), Span<ValueRef>(nullptr, nullptr))[0];
            } else {
                auto&& hval = pyobj2hval(data, cn, dtype);
                val = imperative::apply(
                        CreateTensor(kind, cn, hval.dtype, hval.shape, format),
                        hval.storage)[0];
            }
            m_tensor.emplace(val);
        }
        if (!name.empty()) {
            m_tensor->reset(imperative::apply(RenameValue(name), m_tensor->data())[0]);
        }
    }
    mgb_assert(m_tensor->data());
}
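
// Attribute getters/setters and helper methods exposed on the Python Tensor type;
// they are registered on the type object in init_tensor() below.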
PyObject* TensorWrapper::module_trace_info() {
    if (auto module_trace_info =
                ModuleTraceTransformation::module_trace_info_map.try_get(
                        m_tensor->data())) {
        if (module_trace_info->ptr()) {
            return module_trace_info->inc_ref().ptr();
        }
    }
    PyErr_SetString(
            PyExc_AttributeError,
            "Has no attribute named \'_NodeMixin__node\', please "
            "set it first");
    return nullptr;
}

void TensorWrapper::set_module_trace_info(PyObject* obj) {
    // TODO: erase when obj == nullptr
    ModuleTraceTransformation::module_trace_info_map[m_tensor->data()] =
            py::reinterpret_borrow<py::object>(obj);
}

void TensorWrapper::_set_format(PyObject* dest) {
    auto py_dest = py::reinterpret_borrow<py::object>(dest);
    auto format = py_dest.cast<std::string>();
    m_tensor->set_format(format);
}

void TensorWrapper::_set_name(PyObject* dest) {
    auto py_dest = py::reinterpret_borrow<py::object>(dest);
    auto name = py_dest.cast<std::string>();
    m_tensor->set_name(name);
}

PyObject* TensorWrapper::_detail() {
    return py::str(m_tensor->data().unwrap().to_string()).release().ptr();
}

void TensorWrapper::_watch() {
    m_tensor->data().watch();
}

PyObject* TensorWrapper::shape() {
    auto shape = m_tensor->shape();
    if (!shape) {
        Py_RETURN_NONE;
    }
    py::tuple ret(shape->ndim);
    for (size_t i = 0; i < shape->ndim; ++i) {
        ret[i] = shape->at(i);
    }
    return ret.release().ptr();
}

PyObject* TensorWrapper::dtype() {
    return py::cast(m_tensor->dtype()).release().ptr();
}

PyObject* TensorWrapper::device() {
    return py::cast(m_tensor->comp_node()).release().ptr();
}

PyObject* TensorWrapper::format() {
    return py::cast(m_tensor->format().to_string()).release().ptr();
}

PyObject* TensorWrapper::numpy() {
    auto hv = m_tensor->numpy();
    if (!hv) {
        PyErr_SetString(PyExc_ValueError, "tensor invalid");
        return nullptr;
    }
    auto arr = py::reinterpret_steal<py::array>(
            npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE));
    if (hv->shape().is_scalar()) {
        mgb_assert(PyArray_Check(arr.ptr()));
        return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
    }
    return arr.release().ptr();
}

void TensorWrapper::reset(PyObject* tensor) {
    TensorWrapper* t = TensorWrapper::try_cast(tensor);
    if (!t) {
        throw py::type_error("expect Tensor");
    }
    m_tensor->reset(t->m_tensor->data());
}

PyObject* TensorWrapper::detach() {
    auto detached = imperative::apply(DetachGrad(), m_tensor->data())[0];
    return TensorWrapper::make(py_tensor_type, detached).release().ptr();
}

PyObject* TensorWrapper::_dev_tensor() {
    auto dv = m_tensor->data().dev_tensor();
    // TODO: handle scalar
    return py::cast(dv->as_nd(true)).release().ptr();
}

void TensorWrapper::_drop() {
    imperative::apply(DTRCommand(DTRCommand::Drop), m_tensor->data());
}

PyObject* TensorWrapper::isscalar() {
    if (m_tensor->is_scalar()) {
        Py_RETURN_TRUE;
    } else {
        Py_RETURN_FALSE;
    }
}

PyObject* TensorWrapper::_var() {
    TypedValueRef<NodeValue> value =
            imperative::apply(GetVarVal(), m_tensor->data())[0].as_ref<NodeValue>();
    auto* node = value->node();
    return py::cast(node).release().ptr();
}

PyObject* TensorWrapper::_graph() {
    TypedValueRef<NodeValue> value =
            imperative::apply(GetVarVal(), m_tensor->data())[0].as_ref<NodeValue>();
    auto* graph = value->graph();
    return py::cast(graph).release().ptr();
}

struct TensorWeakRef {
    ValueWeakRef data;

    TensorWeakRef(const TensorWrapper& tw) : data(tw.m_tensor->data()) {}

    py::object operator()() {
        if (auto p = data.lock()) {
            return TensorWrapper::make(py_tensor_type, p);
        }
        return py::none();
    }
};
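
// MGE_PY_INTERFACE builds PyMethodDef entries for the fastcall-style functions.
// When METH_FASTCALL is unavailable, each function is wrapped so that it can be
// registered with METH_VARARGS instead.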
#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
#else
#define WRAP_FUNC_PY35(FUNC)                                \
    PyObject* py35_##FUNC(PyObject* self, PyObject* args) { \
        auto* arr = &PyTuple_GET_ITEM(args, 0);             \
        auto size = PyTuple_GET_SIZE(args);                 \
        return FUNC(self, arr, size);                       \
    }
WRAP_FUNC_PY35(py_apply);
WRAP_FUNC_PY35(dtype_promotion);
WRAP_FUNC_PY35(get_device);
WRAP_FUNC_PY35(make_shape_tuple);
WRAP_FUNC_PY35(getitem_cpp);
WRAP_FUNC_PY35(setitem_cpp);
WRAP_FUNC_PY35(split_cpp);
WRAP_FUNC_PY35(expand_dims_cpp);
WRAP_FUNC_PY35(squeeze_cpp);
WRAP_FUNC_PY35(transpose_cpp);
WRAP_FUNC_PY35(broadcast_cpp);
WRAP_FUNC_PY35(reshape_cpp);
WRAP_FUNC_PY35(adaptive_pool2d_cpp);
WRAP_FUNC_PY35(Const);
WRAP_FUNC_PY35(astype_cpp);
WRAP_FUNC_PY35(matmul_cpp);
WRAP_FUNC_PY35(batched_matmul_cpp);
WRAP_FUNC_PY35(convert_single_value_cpp);
WRAP_FUNC_PY35(convert_inputs_cpp);
WRAP_FUNC_PY35(astensor1d_cpp);
WRAP_FUNC_PY35(pixel_shuffle_cpp);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif
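
// init_tensor wires everything into the Python module: it creates the interpreter
// channel, installs the default transformation stack, registers the Tensor, GradKey
// and Trace types, and exposes the module-level helper functions defined above.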
void init_tensor(py::module m) {
    imperative::Tensor::static_initialize();

    // Transformations
    static auto& transformations = TransformationManager::get_instance();

    using Segment = TransformationManager::Segment;
    using Channel = interpreter::Interpreter::Channel;

    auto* channel =
            imperative::ResourceManager::create_global<std::unique_ptr<Channel>>(
                    interpreter::Interpreter::inst().create_channel())
                    ->get();
    interpreter_for_py = channel;
    MGB_MARK_USED_VAR(
            transformations
                    .register_at<Segment::Eval>(
                            std::make_shared<InterpreterTransformation>(
                                    std::shared_ptr<Channel>(channel, [](Channel*) {})))
                    .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::Scalar>(
                                      std::make_shared<ScalarTransformation>())
                              .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::Symbol>(
                                      std::make_shared<SymbolTransformation>())
                              .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::DTypePromote>(
                                      std::make_shared<DTypePromoteTransformation>())
                              .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::DimExpansion>(
                                      std::make_shared<DimExpansionTransformation>())
                              .release());
    auto format_trans = std::make_shared<FormatTransformation>();
    MGB_MARK_USED_VAR(
            transformations.register_at<Segment::Format>(format_trans).release());

    static py::exception<interpreter::AsyncError> py_async_error(
            m, "AsyncError", PyExc_RuntimeError);
    py::register_exception_translator([](std::exception_ptr p) {
        try {
            if (p)
                std::rethrow_exception(p);
        } catch (const interpreter::AsyncError& e) {
            pyext17::pybind11_translate_exception(e.nested_ptr());
            if (PyErr_Occurred()) {
                PyObject *exc, *val, *tb;
                PyErr_Fetch(&exc, &val, &tb);
                PyErr_NormalizeException(&exc, &val, &tb);
                if (tb) {
                    PyException_SetTraceback(val, tb);
                }
                auto val2 = py_async_error.py::object::operator()(
                        "An async error is reported. See above for the actual cause."
                        " Hint: This is where it is reported, not where it happened."
                        " You may call `megengine.config.async_level = 0` "
                        "to get better error reporting.");
                PyException_SetCause(
                        val2.ptr(), val);  // PyException_SetCause steals reference
                Py_XDECREF(exc);
                Py_XDECREF(tb);
                PyErr_Restore(
                        py_async_error.inc_ref().ptr(), val2.release().ptr(), nullptr);
            } else {
                py_async_error("Unknown async error");
            }
        }
    });
    // Tensor
    auto* tensor_type =
            TensorWrapper::wrap_t::type()
                    .def<&TensorWrapper::numpy>("numpy")
                    .def_getset<&TensorWrapper::shape>("shape")
                    .def_getset<&TensorWrapper::dtype>("dtype")
                    .def_getset<&TensorWrapper::device>("device")
                    .def<&TensorWrapper::format>("format")
                    .def<&TensorWrapper::reset>("_reset")
                    .def<&TensorWrapper::isscalar>("_isscalar")
                    .def<&TensorWrapper::detach>("detach")
                    // TODO: remove this
                    .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
                    .def<&TensorWrapper::_drop>("_drop")
                    .def<&TensorWrapper::_detail>("_detail")
                    .def<&TensorWrapper::_set_format>("_set_format")
                    .def<&TensorWrapper::_set_name>("_set_name")
                    .def<&TensorWrapper::_watch>("_watch")
                    .def<&TensorWrapper::_var>("var")
                    .def<&TensorWrapper::_graph>("graph")
                    .def_getset<
                            &TensorWrapper::module_trace_info,
                            &TensorWrapper::set_module_trace_info>("_NodeMixin__node")
                    .finalize();
    if (!tensor_type)
        throw py::error_already_set();
    py::setattr(m, "Tensor", tensor_type);

    py::enum_<Format::Type>(m, "FormatType")
            .value("DEFAULT", Format::Type::DEFAULT)
            .value("NCHW", Format::Type::NCHW)
            .value("NHWC", Format::Type::NHWC)
            .export_values();

    py::class_<TensorWeakRef>(m, "TensorWeakRef")
            .def(py::init<const TensorWrapper&>())
            .def("__call__", &TensorWeakRef::operator());

    static PyMethodDef method_defs[] = {
            MGE_PY_INTERFACE(apply, py_apply),
            MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
            MGE_PY_INTERFACE(get_device, get_device),
            MGE_PY_INTERFACE(make_shape_tuple, make_shape_tuple),
            MGE_PY_INTERFACE(getitem_cpp, getitem_cpp),
            MGE_PY_INTERFACE(setitem_cpp, setitem_cpp),
            MGE_PY_INTERFACE(split_cpp, split_cpp),
            MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp),
            MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp),
            MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
            MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
            MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
            MGE_PY_INTERFACE(adaptive_pool2d_cpp, adaptive_pool2d_cpp),
            MGE_PY_INTERFACE(Const, Const),
            MGE_PY_INTERFACE(astype_cpp, astype_cpp),
            MGE_PY_INTERFACE(matmul_cpp, matmul_cpp),
            MGE_PY_INTERFACE(batched_matmul_cpp, batched_matmul_cpp),
            MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
            MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
            MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp),
            MGE_PY_INTERFACE(pixel_shuffle_cpp, pixel_shuffle_cpp),
            {nullptr, nullptr, 0, nullptr}};
    for (auto&& def : method_defs) {
        if (def.ml_meth != nullptr) {
            auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
            if (!func)
                throw py::error_already_set();
            py::setattr(m, def.ml_name, func);
        }
    }
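
    // Channel control and synchronization helpers exposed to Python (options,
    // scopes, profiling, sync/full_sync/close).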
    static constexpr auto sync_py_task_q = [] {
        py::gil_scoped_release _;
        py_task_q.wait_all_task_finish();
    };

    m.def("clear_candidates", [channel]() { channel->clear_candidates(); });
    m.def("set_option", [channel](std::string name, size_t value) {
        channel->set_option(name, value);
    });
    m.def("get_option",
          [channel](std::string name) { return channel->get_option(name); });
    m.def("push_scope", [channel](std::string name) {
        Transformation::push_scope(name);
        channel->push_scope(name);
    });
    m.def("pop_scope", [channel](std::string name) {
        channel->pop_scope(name);
        Transformation::pop_scope(name);
    });
    m.def("start_profile", [channel](imperative::Profiler::options_t options) {
        channel->sync();
        imperative::Profiler::load_options(std::move(options));
        imperative::Profiler::start_profile();
        channel->start_profile();
    });
    m.def("stop_profile", [channel]() -> std::function<void(std::string, std::string)> {
        channel->stop_profile();
        channel->sync();
        CompNode::sync_all();
        imperative::Profiler::stop_profile();
        auto results = std::make_shared<imperative::Profiler::bundle_t>(
                imperative::Profiler::collect());
        return [results = results](std::string basename, std::string format) mutable {
            imperative::Profiler::dump_profile(basename, format, std::move(*results));
            results = nullptr;
        };
    });
    m.def("enable_cupti", &cupti::enable);
    m.def("disable_cupti", &cupti::disable);
    m.def("cupti_available", &cupti::available);
    m.def("sync", [channel]() {
        if (channel->check_available()) {
            channel->sync();
        }
        sync_py_task_q();
    });
    m.def("full_sync", [channel]() {
        if (channel->check_available()) {
            channel->sync();
        }
        CompNode::sync_all();
        CompNode::foreach ([](CompNode cn) {
            auto err = cn.check_async_error();
            mgb_assert(!err, "%s", err->what());
        });
        sync_py_task_q();
    });
    m.def("close", [channel]() {
        channel->close();
        sync_py_task_q();
    });
    // GradTransformation
    py::handle grad_key_type =
            GradKeyWrapper::wrap_t::type()
                    .def<&GradKeyWrapper::attach>("attach")
                    .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
                    .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>(
                            "name")
                    .def<&GradKeyWrapper::enter>("enter")
                    .def<&GradKeyWrapper::exit>("exit")
                    .def<&GradKeyWrapper::suppress>("suppress")
                    .def<&GradKeyWrapper::resume>("resume")
                    .finalize();
    if (!grad_key_type)
        throw py::error_already_set();
    py::setattr(m, "GradKey", grad_key_type);
    m.def("backward", &GradKeyWrapper::backward);
    m.def("get_backward_closure", &GradKeyWrapper::get_backward_closure);

    m.def("set_py_tensor_type", [](py::object type_obj) {
        py_tensor_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
    });
    m.def("set_py_varnode_type", [](py::object type_obj) {
        py_varnode_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
    });
    m.def("set_py_device_type",
          [](py::object type_obj) { py_device_type = type_obj.inc_ref(); });
    /**
     * \brief trace proxy
     */
    struct Trace {
        bool symbolic = false;
        bool no_exec = false;
        bool capture_as_const = false;
        bool profile = false;
        bool record_input_shapes = false;
        py::function options_visitor;
        std::shared_ptr<TracingTransformation> tracing;
        std::shared_ptr<CompiledTransformation> compiled;
        std::shared_ptr<LazyEvalTransformation> lazy_eval;
        std::pair<size_t, std::shared_ptr<GraphProfiler>> profiler;
        std::optional<TraceResult> trace_result;
        std::function<bool(py::object, py::object)> array_comparator;
        std::unique_ptr<CleanupGuard<>> tracing_guard;
        std::unique_ptr<CleanupGuard<>> compiled_guard;
        std::unique_ptr<CleanupGuard<>> lazy_eval_guard;

        bool compare_value(ValueRef lhs, ValueRef rhs) {
            auto lvalue = lhs.cast_ref<HostValue>();
            auto rvalue = rhs.cast_ref<HostValue>();
            if (lvalue->shape() != rvalue->shape()) {
                return false;
            }
            if (lvalue->shape().total_nr_elems() == 1) {
                return lvalue->item() == rvalue->item();
            }
            HostTensorND lnd = lvalue->as_nd(true);
            HostTensorND rnd = rvalue->as_nd(true);
            auto larr = py::reinterpret_steal<py::array>(
                    npy::ndarray_from_tensor(lnd, npy::ShareType::TRY_SHARE));
            auto rarr = py::reinterpret_steal<py::array>(
                    npy::ndarray_from_tensor(rnd, npy::ShareType::TRY_SHARE));
            return array_comparator(larr, rarr);
        }

        void enter() {
            auto& self = *this;
            if (!self.trace_result) {  // untraced
                self.tracing = std::make_shared<TracingTransformation>(
                        self.capture_as_const, self.record_input_shapes);
                if (self.symbolic) {
                    self.lazy_eval =
                            std::make_shared<LazyEvalTransformation>(self.no_exec);
                    self.options_visitor(py::cast(&self.lazy_eval->options()));
                }
            } else if (!self.compiled) {  // traced but not compiled
                using namespace std::placeholders;
                self.compiled = std::make_shared<CompiledTransformation>(
                        *self.trace_result, self.record_input_shapes);
                self.compiled->set_value_comparator(
                        std::bind(&Trace::compare_value, this, _1, _2));
                self.options_visitor(py::cast(&self.compiled->options()));
                try {
                    self.compiled->compile();
                } catch (const std::exception& e) {
                    mgb_log_error("error in trace: %s", e.what());
                }
            }
            // register transformations
            if (self.compiled) {
                if (self.profile) {
                    auto& current_graph = self.compiled->graph();
                    if (self.profiler.first != self.compiled->graph().id()) {
                        // graph changed
                        self.profiler = std::make_pair(
                                current_graph.id(),
                                std::make_shared<GraphProfiler>(&current_graph));
                    }
                }
                compiled_guard =
                        transformations.register_at<Segment::Trace>(self.compiled);
                // start executing now because InputCallback depends on it
                self.compiled->execute();
            } else if (self.tracing) {
                tracing_guard =
                        transformations.register_at<Segment::Trace>(self.tracing);
                if (self.lazy_eval) {
                    lazy_eval_guard =
                            transformations.register_at<Segment::Eval>(self.lazy_eval);
                }
            } else {
                mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
            }
        }

        void exit() {
            auto& self = *this;
            if (self.tracing) {
                tracing_guard.reset();
                self.trace_result = self.tracing->get_result();
                self.tracing.reset();
                if (self.lazy_eval) {
                    auto lazy_eval = std::move(self.lazy_eval);
                    lazy_eval_guard.reset();
                    lazy_eval->check_exception();
                }
            } else if (self.compiled) {
                compiled_guard.reset();
                self.compiled->wait();
            } else {
                mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
            }
        }

        VarNodeArray dump(
                std::shared_ptr<ComputingGraph> graph,
                std::vector<std::tuple<std::string, std::string, TensorShape>> inputs,
                std::vector<std::pair<std::string, std::string>> outputs,
                bool prefer_input_names) {
            auto& self = *this;
            mgb_assert(self.trace_result);
            // marks look like "arg_0", "kwarg_xxx", "output_0" ...
            std::unordered_map<std::string, size_t> mark2var;
            for (size_t i = 0; i < self.trace_result->vars.size(); ++i) {
                auto& name = self.trace_result->vars[i].mark;
                if (!name.empty()) {
                    mark2var[name] = i;
                }
            }
            std::vector<std::tuple<size_t, std::string, TensorShape>> input_vars;
            std::vector<std::pair<size_t, std::string>> output_vars;
            for (auto&& [input_mark, input_name, input_shape] : inputs) {
                mgb_assert(input_shape.ndim, "input shape invalid");
                input_vars.push_back(
                        {mark2var.at(input_mark), input_name, input_shape});
            }
            for (auto&& [output_name, repr] : outputs) {
                output_vars.push_back({mark2var.at(output_name), repr});
            }
            self.options_visitor(py::cast(&graph->options()));
            auto vars = self.trace_result->dump(
                    *graph, input_vars, output_vars, prefer_input_names);
            return vars;
        }
    };
    py::class_<Trace>(m, "Trace")
            .def(py::init<>())
            .def_readwrite("record_input_shapes", &Trace::record_input_shapes)
            .def_readwrite("array_comparator", &Trace::array_comparator)
            .def_readwrite("profile", &Trace::profile)
            .def_property_readonly(
                    "options",
                    [](Trace& self) {
                        if (self.compiled) {
                            return &self.compiled->options();
                        } else {
                            return (ComputingGraph::Options*)nullptr;
                        }
                    })
            .def("get_profile",
                 [](Trace& self) -> py::object {
                     if (self.profiler.second && self.compiled) {
                         auto json = self.profiler.second->to_json_full(
                                 self.compiled->graph().current_comp_seq());
                         return py::str(json->to_string());
                     } else {
                         return py::none();
                     }
                 })
            .def_readwrite("symbolic", &Trace::symbolic)
            .def_readwrite("capture_as_const", &Trace::capture_as_const)
            .def_readwrite("no_exec", &Trace::no_exec)
            .def_readwrite("options_visitor", &Trace::options_visitor)
            .def("enter", &Trace::enter)
            .def("exit", &Trace::exit)
            .def("dump", &Trace::dump)
            .def("begin_excluded_region",
                 [](Trace& self) {
                     mgb_assert(bool(self.tracing) ^ bool(self.compiled));
                     if (self.tracing) {
                         self.tracing_guard.reset();
                     } else if (self.compiled) {
                         self.compiled_guard.reset();
                     }
                 })
            .def("end_excluded_region", [](Trace& self) {
                mgb_assert(bool(self.tracing) ^ bool(self.compiled));
                if (self.tracing) {
                    self.tracing_guard =
                            transformations.register_at<Segment::Trace>(self.tracing);
                } else if (self.compiled) {
                    self.compiled_guard =
                            transformations.register_at<Segment::Trace>(self.compiled);
                }
            });
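
    // Tracing and autograd helpers: name_tensor marks a traced value,
    // is_grad_attached/get_grad_key query the active GradKey, and set_grad installs
    // a custom backward function for a set of inputs/outputs.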
    m.def("name_tensor", [](std::string name, py::object tensor) {
        auto* tw = TensorWrapper::try_cast(tensor.ptr());
        auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0];
        tw->m_tensor->reset(output);
    });

    m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool {
        SmallVector<ValueRef> values(tensors.size());
        for (size_t i = 0; i < tensors.size(); ++i) {
            values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
        }
        auto outputs = imperative::apply(GetGradKey(), values);
        if (outputs[0].is<GradKeyValue>()) {
            return true;
        } else {
            return false;
        }
    });

    m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object {
        SmallVector<ValueRef> values(tensors.size());
        for (size_t i = 0; i < tensors.size(); ++i) {
            values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
        }
        auto output = imperative::apply(GetGradKey(), values)[0];
        if (!output) {
            return py::none();
        }
        return py::reinterpret_borrow<py::object>(GradKeyWrapper::wrap_t::pycast(
                GradKeyWrapper::get(output.cast<GradKeyValue>())));
    });

    m.def("set_grad", [](py::function backward_fn, std::vector<py::object> inputs,
                         std::vector<py::object> outputs) {
        GenericFunction generic_backward_fn =
                [backward_fn](Span<ValueRef> output_grads) -> ValueRefList {
            py::list output_grad_tws;
            for (auto&& output_grad : output_grads) {
                if (output_grad) {
                    output_grad_tws.append(
                            TensorWrapper::make(py_tensor_type, output_grad));
                } else {
                    output_grad_tws.append(py::none());
                }
            }
            py::tuple input_grad_tws = backward_fn(*output_grad_tws);
            ValueRefList input_grads(input_grad_tws.size());
            for (size_t i = 0; i < input_grad_tws.size(); ++i) {
                auto input_grad_tw = input_grad_tws[i];
                if (!input_grad_tw.is_none()) {
                    input_grads[i] =
                            py::cast<TensorWrapper>(input_grad_tw).m_tensor->data();
                } else {
                    input_grads[i] = {};
                }
            }
            return input_grads;
        };
        SmallVector<ValueRef> values(inputs.size() + outputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            values[i] = inputs[i].cast<TensorWrapper>().m_tensor->data();
        }
        for (size_t i = 0; i < outputs.size(); ++i) {
            values[i + inputs.size()] =
                    outputs[i].cast<TensorWrapper>().m_tensor->data();
        }
        auto wrapped_output_values =
                imperative::apply(SetGrad(generic_backward_fn, inputs.size()), values);
        std::vector<py::object> wrapped_outputs;
        mgb_assert(wrapped_output_values.size() == outputs.size());
        for (auto&& output_value : wrapped_output_values) {
            wrapped_outputs.push_back(
                    TensorWrapper::make(py_tensor_type, output_value));
        }
        return wrapped_outputs;
    });
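
    // The module-trace transformation is created lazily on first use and driven by
    // a hook function installed from Python via set_module_trace_hook.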
    // ModuleTraceTransformation
    static py::function module_trace_hook;
    static auto get_module_trace = [] {
        static std::shared_ptr<ModuleTraceTransformation> module_trace_transformation;
        if (!module_trace_transformation) {
            mgb_assert(module_trace_hook);
            module_trace_transformation =
                    std::make_shared<ModuleTraceTransformation>(module_trace_hook);
            MGB_MARK_USED_VAR(transformations
                                      .register_at<Segment::ModuleTrace>(
                                              module_trace_transformation)
                                      .release());
        }
        return module_trace_transformation;
    };

    m.def("set_cpp_use_symbolic_shape", &set_cpp_use_symbolic_shape);

    m.def("set_module_tracing", [=] { get_module_trace()->enable(); });
    m.def("unset_module_tracing", [=] { get_module_trace()->disable(); });
    m.def("is_tracing_module", [=] { return get_module_trace()->enabled(); });
    m.def("set_module_trace_hook", [](py::function function) {
        module_trace_hook = function;
        module_trace_hook.inc_ref();
    });

    auto atexit = py::module::import("atexit");
    atexit.attr("register")(py::cpp_function([]() { module_trace_hook = {}; }));
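
    // Remaining configuration and debugging helpers: value recording, stats,
    // dtype-promotion and AMP dtype settings, algorithm cache clearing, and
    // automatic format conversion.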
    m.def("begin_record_values", [] { Value::begin_record_values(); });
    m.def("end_record_values", [] {
        std::vector<std::pair<size_t, std::string>> reprs;
        auto values = Value::end_record_values();
        for (auto&& value : values) {
            reprs.push_back({value.id(), value.to_string()});
        }
        return reprs;
    });
    m.def("print_stats", [] { Stats::print(); });
    m.def("reset_stats", [] { Stats::reset(); });

    m.def("_get_convert_inputs",
          []() -> bool { return DTypePromoteCfg::convert_input_enabled; });
    m.def("_set_convert_inputs", [](bool flag) -> bool {
        bool ret = DTypePromoteCfg::convert_input_enabled;
        DTypePromoteCfg::convert_input_enabled = flag;
        return ret;
    });
    m.def("_get_amp_dtype_autocast",
          []() -> bool { return DTypePromoteCfg::amp_dtype_autocast_enabled; });
    m.def("_set_amp_dtype_autocast", [](bool flag) -> bool {
        bool ret = DTypePromoteCfg::amp_dtype_autocast_enabled;
        DTypePromoteCfg::amp_dtype_autocast_enabled = flag;
        return ret;
    });

    static auto get_amp_prec_dtype = [](bool is_high) -> std::string {
        DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
                                : DTypePromoteCfg::amp_low_prec_dtype;
        mgb_assert(target.category() == DTypeCategory::FLOAT);
        std::string ret = target.name();
        transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
        return ret;
    };

    static auto set_amp_prec_dtype = [](bool is_high,
                                        std::string dtype_name) -> std::string {
        DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
                                : DTypePromoteCfg::amp_low_prec_dtype;
        std::string ret = target.name();
        if (dtype_name == "float32") {
            target = dtype::Float32();
        } else if (dtype_name == "float16") {
            target = dtype::Float16();
        } else if (dtype_name == "bfloat16") {
            target = dtype::BFloat16();
        } else {
            mgb_assert(
                    false, "cast type of amp should be float, but got %s\n",
                    dtype_name.c_str());
        }
        transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
        return ret;
    };

    m.def("_get_amp_high_prec_dtype",
          []() -> std::string { return get_amp_prec_dtype(true); });
    m.def("_set_amp_high_prec_dtype", [](std::string dtype_name) -> std::string {
        return set_amp_prec_dtype(true, dtype_name);
    });
    m.def("_get_amp_low_prec_dtype",
          []() -> std::string { return get_amp_prec_dtype(false); });
    m.def("_set_amp_low_prec_dtype", [](std::string dtype_name) -> std::string {
        return set_amp_prec_dtype(false, dtype_name);
    });

    m.def("_clear_algorithm_cache", [] { megdnn::AlgorithmCache::instance().clear(); });

    // FormatTransformation
    m.def("set_auto_format_convert",
          [format_trans](bool enabled) { format_trans->set_auto_convert(enabled); });
    m.def("get_auto_format_convert",
          [format_trans]() { return format_trans->get_auto_convert(); });

    py::register_exception<TraceError>(m, "TraceError");
}

#undef MGE_PY_INTERFACE

}  // namespace mgb::imperative::python