You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

ops.cpp 26 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747
  1. /**
  2. * \file imperative/python/src/ops.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./ops.h"
  12. #include "./helper.h"
  13. #include "./tensor.h"
  14. #include "megbrain/common.h"
  15. #include "megbrain/imperative.h"
  16. #include "megbrain/imperative/graph_builder.h"
  17. #include "megbrain/imperative/ops/backward_graph.h"
  18. #include "megbrain/imperative/ops/opr_attr.h"
  19. #include "megbrain/imperative/ops/utility.h"
  20. #include "megbrain/imperative/ops/autogen.h"
  21. #include "megbrain/imperative/ops/rng.h"
  22. #include <Python.h>
  23. #include <unordered_map>
  24. namespace py = pybind11;
  25. using namespace mgb::imperative;
  26. namespace {
  27. auto normalize_enum(const std::string& in) {
  28. std::string ret;
  29. for (auto&& c : in) {
  30. ret += toupper(c);
  31. }
  32. return ret;
  33. }
  34. } // anonymous namespace
  35. #define CATCH_ALL(RETVAL) \
  36. catch(py::error_already_set& e) { \
  37. e.restore(); \
  38. return RETVAL; \
  39. } catch(py::builtin_exception& e) { \
  40. e.set_error(); \
  41. return RETVAL; \
  42. } catch(std::exception& e) { \
  43. PyErr_SetString(PyExc_RuntimeError, e.what()); \
  44. return RETVAL; \
  45. } \
namespace {
// Py<Name> is the CPython wrapper struct generated for op type <Name>.
#define PyOp(name) Py##name
// The static PyTypeObject backing Py<Name>.
#define PyOpType(name) PyOp(name)::py_type
// Open the wrapper struct for op `name`: derives from PyOpDef and exposes
// inst() to downcast the shared OpDef to the concrete op type.
#define PyOpDefBegin(name) \
struct PyOp(name) : PyOpDef { \
    using Ty = name; \
    Ty& inst() { return op->cast_final_safe<Ty>(); } \
    static PyTypeObject py_type;
// Close the wrapper struct and define its static type object.
#define PyOpDefEnd(name) \
}; \
PyTypeObject PyOpType(name);
// Implement a tp_richcompare body: compare val1/val2 with the comparison
// selected by `op` (Py_EQ/Py_NE/Py_LT/...) and return Py_True/Py_False.
// Each case returns, so no fallthrough; unknown ops are fatal.
#define RETURN_RICHCOMPARE(val1, val2, op) \
do { \
    switch (op) { \
    case Py_EQ: if ((val1) == (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \
    case Py_NE: if ((val1) != (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \
    case Py_LT: if ((val1) < (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \
    case Py_GT: if ((val1) > (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \
    case Py_LE: if ((val1) <= (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \
    case Py_GE: if ((val1) >= (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \
    default: \
        Py_FatalError("Unreachable C code path reached"); \
    } \
} while (0)
  70. template <typename T>
  71. PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
  72. PyObject* obj = type->tp_alloc(type, 0);
  73. T* self = reinterpret_cast<T*>(obj);
  74. if (self != NULL) {
  75. self->op = T::Ty::make();
  76. }
  77. return obj;
  78. }
// Fallback (de)serialization of op attributes: round-trip through
// pybind11's default casters. Enum types are specialized further below.
// (Parameter name "SNIFAE" is a typo of "SFINAE", kept for compatibility.)
template<typename T, typename SNIFAE=void>
struct serialization {
    // Rebuild a T from its Python representation.
    static T load(py::object obj) {
        return py::cast<T>(obj);
    }
    // Convert a T (lvalue or rvalue of exactly type T) to a Python object.
    template<typename U,
             typename = std::enable_if_t<std::is_same_v<T, std::decay_t<U>>>>
    static py::object dump(U&& t) {
        return py::cast(std::forward<U>(t));
    }
};
// Generic tp_dealloc slot: release the shared OpDef before freeing the
// CPython object storage (tp_alloc'ed memory is never destructed by C++).
template<typename T>
void py_dealloc_generic(PyObject* obj) {
    reinterpret_cast<T*>(obj)->op.reset();
    Py_TYPE(obj)->tp_free(obj);
}
// Generic getter slot: read data member `attr` of the concrete op and
// return it as a new Python reference.
template<typename T, typename U, U T::Ty::*attr>
PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
    auto& op = reinterpret_cast<T*>(obj)->inst();
    return py::cast(op.*attr).release().ptr();
}
// Instantiate the getter for op `name`, member `attr`.
#define py_get_generic(name, attr) \
py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
// Generic setter slot: convert `value` from Python and assign it to data
// member `attr` of the concrete op. Returns 0 on success, -1 with a
// Python exception set on failure (including attribute deletion).
template<typename T, typename U, U T::Ty::*attr>
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
    if (value == NULL) {  // `del obj.attr` — deletion is not supported
        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
        return -1;
    }
    auto& op = reinterpret_cast<T*>(obj)->inst();
    try {
        // TODO: remove this guard which is used for pybind11 implicit conversion
        py::detail::loader_life_support guard{};
        op.*attr = py::cast<U>(py::handle(value));
    } CATCH_ALL(-1)
    return 0;
}
// Instantiate the setter for op `name`, member `attr`.
#define py_set_generic(name, attr) \
py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
// Base CPython wrapper for all op definitions: holds the shared OpDef and
// the machinery shared by the generated per-op wrapper structs.
struct PyOpDef {
    PyObject_HEAD
    std::shared_ptr<OpDef> op;  // the wrapped imperative op
    static PyTypeObject py_type;
    // Maps a C++ op typeinfo to its Python type object; populated by the
    // generated per-op init code so casts can pick the right subtype.
    static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype;
    static PyGetSetDef py_getsetters[];
    static Py_hash_t tp_hash(PyObject *obj);
    static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op);
    // repr(): delegates to OpDef::make_name().
    static PyObject* py_repr(PyObject* self) {
        return py::cast(
            reinterpret_cast<PyOpDef*>(self)->op->make_name())
            .release()
            .ptr();
    }
};
PyTypeObject PyOpType(OpDef);
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;
// Getter for the `scope` attribute: forwards to OpDef::scope().
PyObject* py_get_scope(PyObject* obj, void* /* closure */) {
    return py::cast(
        reinterpret_cast<PyOp(OpDef)*>(obj)->op->scope()).release().ptr();
}
// Setter for the `scope` attribute: forwards to OpDef::set_scope().
// Returns 0 on success, -1 with a Python exception set on failure.
int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
    if (value == NULL) {  // `del obj.scope` — deletion is not supported
        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
        return -1;
    }
    try {
        reinterpret_cast<PyOp(OpDef)*>(obj)->op
            ->set_scope(py::cast<std::string>(py::handle(value)));
    } CATCH_ALL(-1)
    return 0;
}
// Attribute table for the OpDef base type: only `scope` is exposed.
PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
    {const_cast<char*>("scope"), py_get_scope, py_set_scope, "scope", NULL},
    {NULL}  // sentinel
};
// hash(): forwards to OpDef::hash().
Py_hash_t PyOp(OpDef)::tp_hash(PyObject *obj) {
    return static_cast<Py_hash_t>(
        reinterpret_cast<PyOp(OpDef)*>(obj)->op->hash());
}
// ==/!= compare two wrapped ops via OpDef::is_same(); all other
// comparison operators are unsupported.
// NOTE(review): `other` is cast without a type check — presumably callers
// only compare OpDef instances against each other; confirm before relying
// on comparisons with arbitrary Python objects.
PyObject* PyOp(OpDef)::tp_richcompare(PyObject *self, PyObject *other, int op) {
    bool same = reinterpret_cast<PyOp(OpDef)*>(self)->op->is_same(
        *reinterpret_cast<PyOp(OpDef)*>(other)->op);
    if (op == Py_EQ || op == Py_NE) {
        RETURN_RICHCOMPARE(same, true, op);
    }
    Py_RETURN_NOTIMPLEMENTED;
}
// Per-enum metadata (name, member table, max value); specializations are
// supplied by the generated code in opdef.cpy.inl.
template<typename T>
struct EnumTrait;
// Common fields shared by the Python enum wrapper structs below:
// the stored enum value plus static member-name/lookup tables and the
// pre-built singleton instances.
#define PyEnumHead \
    static_assert(std::is_enum_v<T>); \
    PyObject_HEAD \
    T value; \
    constexpr static const char *name = EnumTrait<T>::name; \
    static PyTypeObject* type; \
    static const char* members[]; \
    static std::unordered_map<std::string, T> mem2value; \
    static PyObject* pyobj_insts[];
// Python wrapper for a plain (non-flag) C++ enum. One singleton Python
// object exists per member (pyobj_insts); cast() always hands those out.
template<typename T>
struct EnumWrapper {
    PyEnumHead
    // Member name of the current value (indexes the `members` table).
    std::string to_string() const {
        return members[static_cast<size_t>(value)];
    }
    // repr(): "<EnumName>.<MEMBER>".
    static PyObject* py_repr(PyObject* self) {
        return py::cast(
            std::string(name) + "." + reinterpret_cast<EnumWrapper*>(self)->to_string())
            .release().ptr();
    }
    // dump(): just the member name, used for serialization.
    static PyObject* py_dump(PyObject* self) {
        return py::cast(reinterpret_cast<EnumWrapper*>(self)->to_string())
            .release()
            .ptr();
    }
    // ==/!= against another wrapper instance or a (case-insensitive)
    // member-name string; values that cannot be converted compare unequal.
    static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) {
        if (op == Py_EQ || op == Py_NE) {
            T lhs, rhs;
            if (load(other, rhs) && load(self, lhs)) {
                RETURN_RICHCOMPARE(lhs, rhs, op);
            } else {
                RETURN_RICHCOMPARE(0, 1, op);  // unconvertible -> unequal
            }
        }
        Py_RETURN_NOTIMPLEMENTED;
    }
    // Convert a Python object (wrapper instance or member-name string)
    // into an enum value; returns false if conversion is impossible.
    static bool load(py::handle src, T& value) {
        PyObject* obj = src.ptr();
        if (PyObject_TypeCheck(obj, type)) {
            value = reinterpret_cast<EnumWrapper*>(obj)->value;
            return true;
        }
        if (py::isinstance<py::str>(src)) {
            // Member names are matched case-insensitively via normalize_enum.
            auto&& iter = mem2value.find(
                normalize_enum(py::cast<std::string>(src)));
            if (iter != mem2value.end()) {
                value = iter->second;
                return true;
            } else {
                return false;
            }
        }
        return false;
    }
    // Return the pre-built singleton for `value` as a new reference.
    static PyObject* cast(const T& value) {
        auto v = static_cast<std::underlying_type_t<T>>(value);
        mgb_assert(v <= EnumTrait<T>::max);
        PyObject* obj = pyobj_insts[v];
        Py_INCREF(obj);
        return obj;
    }
};
// Python wrapper for a bit-flag enum whose values may be OR-combinations
// of members. Singletons exist only for single-bit values; combined (or
// zero) values get a fresh wrapper object per cast().
template<typename T>
struct BitCombinedEnumWrapper {
    PyEnumHead
    // "None" for 0, otherwise "Name.A + Name.B" for each set bit.
    std::string to_string() const {
        uint32_t value_int = static_cast<uint32_t>(value);
        if (value_int == 0) {
            return "None";
        } else {
            std::string ret;
            bool first = true;
            for (uint32_t i = 0; i < 32; i++) {
                if (value_int >> i & 1) {
                    if (!first) {
                        ret += " + ";
                    } else {
                        first = false;
                    }
                    ret += (std::string(name) + "." + members[i]);
                }
            }
            return ret;
        }
    }
    // tp_new: no args -> zero-valued instance; one arg -> convert it via
    // load() (wrapper, string, tuple of strings, or int).
    static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject* args, PyObject*) {
        if (!PyTuple_Size(args)) {
            PyObject* obj = type->tp_alloc(type, 0);
            reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T();
            return obj;
        }
        else {
            PyObject* input;
            if (!PyArg_ParseTuple(args, "|O", &input)) {
                return nullptr;
            }
            T value;
            if (load(input, value)) {
                return cast(value);
            } else {
                PyErr_SetString(PyExc_RuntimeError,
                    mgb::ssprintf("Cannot convert type %s to type %s\n",
                        input->ob_type->tp_name, name).c_str());
                return nullptr;
            }
        }
    }
    static PyObject* py_repr(PyObject* self) {
        return py::cast(
            reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string())
            .release().ptr();
    }
    // dump(): tuple of the member names of all set bits, for serialization.
    static PyObject* py_dump(PyObject* self) {
        std::vector<std::string> result;
        auto value = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value;
        uint32_t value_int = static_cast<uint32_t>(value);
        for (uint32_t i = 0; i < 32; i++) {
            if (value_int >> i & 1) {
                result.push_back(members[i]);
            }
        }
        return py::tuple(py::cast(result)).release().ptr();
    }
    // nb_or: bitwise OR of two wrappers of the same type.
    static PyObject* py_or(PyObject* self, PyObject* other) {
        if(!(self->ob_type == other->ob_type)){
            return PyErr_Format(
                PyExc_RuntimeError,
                "Operand in or operator must be the same type.");
        }
        T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
          rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
        return cast(lhs | rhs);
    }
    // nb_and: bitwise AND of two wrappers of the same type.
    static PyObject* py_and(PyObject* self, PyObject* other) {
        if (!(self->ob_type == other->ob_type)) {
            return PyErr_Format(
                PyExc_RuntimeError,
                "Operand in and operator must be the same type.");
        }
        T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
          rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
        return cast(lhs & rhs);
    }
    // ==/!= via load(); values that cannot be converted compare unequal.
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
        if (op == Py_EQ || op == Py_NE) {
            T lhs, rhs;
            if (load(other, rhs) && load(self, lhs)) {
                RETURN_RICHCOMPARE(lhs, rhs, op);
            } else {
                RETURN_RICHCOMPARE(0, 1, op);  // unconvertible -> unequal
            }
        }
        Py_RETURN_NOTIMPLEMENTED;
    }
    // Convert wrapper instance / member-name string / tuple of member
    // names (OR-ed together) / plain int into an enum value.
    static bool load(py::handle src, T& value) {
        PyObject* obj = src.ptr();
        if (PyObject_TypeCheck(obj, type)) {
            value = reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value;
            return true;
        }
        if (py::isinstance<py::str>(src)) {
            auto&& iter = mem2value.find(
                normalize_enum(py::cast<std::string>(src)));
            if (iter != mem2value.end()) {
                value = iter->second;
                return true;
            } else {
                return false;
            }
        }
        if (py::isinstance<py::tuple>(src)) {
            auto params = py::cast<std::vector<std::string>>(src);
            bool first = true;
            for (auto s : params){
                auto&& iter = mem2value.find(normalize_enum(s));
                if (iter != mem2value.end()) {
                    if (first) {
                        value = iter->second;
                        first = false;
                    } else {
                        value |= iter->second;  // OR subsequent flags in
                    }
                } else {
                    return false;  // any unknown name rejects the whole tuple
                }
            }
            return true;
        }
        if (py::isinstance<py::int_>(obj)) {
            auto v = py::cast<std::underlying_type_t<T>>(src);
            if(v > EnumTrait<T>::max) {
                return false;
            }
            value = static_cast<T>(v);
            return true;
        }
        return false;
    }
    // Single-bit values reuse the pre-built singletons; zero and combined
    // values allocate a fresh wrapper. Always returns a new reference.
    static PyObject* cast(const T& value) {
        auto v = static_cast<std::underlying_type_t<T>>(value);
        mgb_assert(v <= EnumTrait<T>::max);
        if ((!v) || (v & (v - 1))) {  // zero or more than one bit set
            PyObject* obj = type->tp_alloc(type, 0);
            reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
            return obj;
        } else {
            // exactly one bit set: index of that bit selects the singleton
            // (__builtin_ctz is gcc/clang-specific)
            PyObject* obj = pyobj_insts[__builtin_ctz(v)];
            Py_INCREF(obj);
            return obj;
        }
    }
};
  380. template<typename T>
  381. struct serialization<T,
  382. std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> {
  383. static T load(py::object obj) {
  384. auto caster = pybind11::detail::type_caster<T>();
  385. if (caster.load(obj, true)) {
  386. return caster;
  387. } else {
  388. PyErr_SetString(PyExc_RuntimeError,
  389. "load faild \n");
  390. return caster;
  391. }
  392. }
  393. static py::object dump(T t) {
  394. return py::cast(t).attr("dump")();
  395. }
  396. };
// Build and register the OpDef base Python type on module `m`.
void _init_py_op_def(py::module m) {
    using py_op = PyOp(OpDef);
    auto& py_type = PyOpType(OpDef);
    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
    py_type.tp_name = "megengine.core._imperative_rt.OpDef";
    py_type.tp_basicsize = sizeof(PyOp(OpDef));
    // BASETYPE: generated op wrappers and PyOpBase subclass this type.
    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    py_type.tp_doc = "OpDef";
    py_type.tp_base = &PyBaseObject_Type;
    py_type.tp_hash = PyOp(OpDef)::tp_hash;
    py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare;
    py_type.tp_getset = py_op::py_getsetters;
    py_type.tp_repr = py_op::py_repr;
    mgb_assert(PyType_Ready(&py_type) >= 0);
    m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type));
}
  413. /*********** begin of hand-write opdefs **************/
// Base type for op definitions written in Python (subclassed from the
// Python side); the wrapped OpDef stays empty until first use.
struct PyOpBase : PyOpDef {
    static PyTypeObject py_type;
    // tp_new: tp_alloc zero-fills the slot, so placement-new the
    // shared_ptr member to establish a valid (empty) C++ object.
    static PyObject* tp_new(PyTypeObject* type, PyObject*, PyObject*) {
        auto* obj = type->tp_alloc(type, 0);
        if (obj) {
            auto* self = reinterpret_cast<PyOpBase*>(obj);
            new(&self->op) decltype(self->op);
        }
        return obj;
    }
};
PyTypeObject PyOpBase::py_type;
// Build and register the PyOpBase Python type (base for Python-defined
// ops) on the `ops` submodule `m`.
void _init_py_op_base(py::module m) {
    using py_op = PyOpBase;
    auto& py_type = PyOpBase::py_type;
    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
    py_type.tp_name = "megengine.core._imperative_rt.ops.PyOpBase";
    py_type.tp_basicsize = sizeof(py_op);
    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    py_type.tp_doc = "PyOpBase";
    py_type.tp_base = &PyOpType(OpDef);
    // dealloc must release the shared_ptr constructed in tp_new.
    py_type.tp_dealloc = py_dealloc_generic<py_op>;
    py_type.tp_new = py_op::tp_new;
    mgb_assert(PyType_Ready(&py_type) >= 0);
    m.add_object("PyOpBase", reinterpret_cast<PyObject*>(&py_type));
}
  440. /*********** end of hand-write opdefs **************/
  441. // auto generated opdefs
  442. #include "opdef.cpy.inl"
  443. #undef CATCH_ALL
  444. } // anonymous namespace
namespace PYBIND11_NAMESPACE {
namespace detail {
// pybind11 caster: Python OpDef wrapper -> shared_ptr<OpDef>.
bool type_caster<OpDef>::load(handle src, bool convert) {
    PyObject* obj = src.ptr();
    if (!PyObject_TypeCheck(obj, &PyOpType(OpDef))) {
        return false;
    }
    value = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
    if (!value) {
        // opdef only defined in Python: wrap the Python object itself
        value = std::make_shared<GenericPyOp>(reinterpret_borrow<object>(src));
    }
    return true;
}
// pybind11 caster: OpDef -> Python wrapper object. Python-defined ops
// return their original Python object; C++ ops get a wrapper of the
// registered subtype (or the plain OpDef type when unregistered).
handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
    if (auto* pyop = op.try_cast_final<GenericPyOp>()) {
        return object(pyop->obj).release();
    }
    PyTypeObject* pytype;
    auto& c2p = PyOp(OpDef)::ctype2pytype;
    auto&& iter = c2p.find(op.dyn_typeinfo());
    if (iter != c2p.end()) { // FIXME: should always meet this condition
        pytype = iter->second;
    } else { // which means unregistered op type, jsut make it as an opaque op type
        // currently, only OprAttr goes into this branch
        pytype = &PyOpType(OpDef);
    }
    PyObject* obj = pytype->tp_alloc(pytype, 0);
    mgb_assert(PyObject_TypeCheck(obj, &PyOpType(OpDef)));
    // OpDef derives from enable_shared_from_this; share ownership with
    // the existing shared_ptr.
    reinterpret_cast<PyOp(OpDef)*>(obj)->op = const_cast<OpDef&>(op).shared_from_this();
    return py::handle(obj);
}
// Define pybind11 load/cast for every generated enum parameter type by
// delegating to the corresponding wrapper struct.
#define ENUM_CASTER_IMPL(T) \
bool type_caster<T>::load(handle src, bool) { \
    return EnumWrapper<T>::load(src, value); \
} \
handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
    return EnumWrapper<T>::cast(value); \
}
FOR_EACH_ENUM_PARAM(ENUM_CASTER_IMPL)
// Same for bit-combined (flag) enum parameter types.
#define BIT_COMBINED_ENUM_CASTER_IMPL(T) \
bool type_caster<T>::load(handle src, bool) { \
    return BitCombinedEnumWrapper<T>::load(src, value); \
} \
handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
    return BitCombinedEnumWrapper<T>::cast(value); \
}
FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL)
  493. } // detail
  494. } // PYBIND11_NAMESPACE
// Entry point: register all op types, RNG-handle helpers, the subgraph
// builder, and the _custom submodule on module `m`.
void init_ops(py::module m) {
    _init_py_op_def(m);
    _init_py_op_base(m);
    INIT_ALL_OP(m)  // generated per-op registrations from opdef.cpy.inl
    m.def("new_rng_handle", &rng::new_handle);
    m.def("delete_rng_handle", [](size_t handle){
        // RNG op might execute after handle released due to async dispatch, so
        // we need sync before delete a handle to avoid memory leak or use-after-free
        if(python::interpreter_for_py->check_available()){
            python::interpreter_for_py->sync();
        }
        mgb::CompNode::sync_all();
        py_task_q.wait_all_task_finish();
        rng::delete_handle(handle);
    }, py::call_guard<py::gil_scoped_release>());  // GIL released: sync may block
    m.def("set_global_rng_seed", &rng::set_global_rng_seed);
    m.def("get_global_rng_seed", &rng::get_global_rng_seed);
    m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode);
    // Incrementally assembles a Subgraph; vars are dense integer ids
    // allocated from next_var (0 is reserved).
    struct PySubgraphBuilder {
        explicit PySubgraphBuilder(std::string name) : name{name}{}
        std::string name;
        std::shared_ptr<Subgraph> graph_storage = std::make_shared<Subgraph>();
        std::shared_ptr<UniqueKey> graph_key = std::make_shared<UniqueKey>();
        Subgraph& graph = *graph_storage;
        mgb::SmallVector<bool> output_grad_mask;  // per-output grad flag
        Subgraph::var_t next_var = 1;
        // Freeze the current graph into a SubgraphOp.
        std::shared_ptr<OpDef> build() const {
            return SubgraphOp::make(name, graph_storage, output_grad_mask, graph_key);
        }
    };
    py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
        .def(py::init<std::string>())
        // Allocate a fresh input var id.
        .def("input", [](PySubgraphBuilder& self){
            auto var = self.next_var++;
            self.graph.inputs.push_back(var);
            return var;
        })
        // Record op(inputs) -> nr_outputs new vars; returns the new ids.
        .def("apply", [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op, Subgraph::vars_t inputs, size_t nr_outputs){
            Subgraph::vars_t outputs;
            for (size_t i = 0; i < nr_outputs; ++i) {
                outputs.push_back(self.next_var++);
            }
            self.graph.exprs.push_back({op, inputs, outputs});
            return outputs;
        })
        // Embed a constant tensor (converted from a numpy value).
        .def("apply_const", [](PySubgraphBuilder& self, py::object value, mgb::DType dtype, mgb::CompNode cn){
            auto var = self.next_var++;
            mgb::HostTensorND hvalue(cn);
            npy::np2tensor(value.cast<py::array>().ptr(), npy::Meth::copy_into(&hvalue), dtype);
            self.graph.constants.push_back({var, Tensor::make(hvalue)});
            return var;
        })
        // Declare the graph outputs; grad mask defaults to all-true.
        .def("outputs", [](PySubgraphBuilder& self, Subgraph::vars_t outputs){
            self.graph.outputs = outputs;
            self.output_grad_mask.resize(outputs.size(), true);
        })
        // Override the per-output grad mask (must match outputs length).
        .def("outputs_has_grad", [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad){
            mgb_assert(self.graph.outputs.size() == self.output_grad_mask.size());
            self.output_grad_mask = outputs_has_grad;
        })
        .def("get", [](PySubgraphBuilder& self){
            return (std::shared_ptr<OpDef>)self.build();
        })
        .def("compile", [](PySubgraphBuilder& self, int gopt_level){
            return (std::shared_ptr<OpDef>)CompiledOp::make(self.build(), gopt_level);
        });
    auto custom = submodule(m, "_custom");
    init_custom(custom);
}
// switch-case body: cast a scalar kwarg value to its C++ param type and
// store it into `param_val` (both names bound at the expansion site).
#define CUSTOM_CASE_TO_PARSE_NON_LIST(dyn_type, static_type) \
case custom::ParamDynType::dyn_type: { \
    param_val = py::handle(kv.second).cast<static_type>(); \
    break; \
}
// switch-case body: cast a list kwarg element-by-element into the C++
// vector param type and store it into `param_val`.
#define CUSTOM_CASE_TO_PARSE_LIST(dyn_type, static_type) \
case custom::ParamDynType::dyn_type: { \
    auto pyvals = py::handle(kv.second).cast<py::list>(); \
    static_type vals; \
    using basic_type = \
        custom::get_vector_template_arg_type<static_type>::type; \
    for (auto &pyval: pyvals) { \
        vals.push_back(py::handle(pyval).cast<basic_type>()); \
    } \
    param_val = vals; \
    break; \
}
// FASTCALL-style entry: args[0] = op name (str), args[1] = param dict.
// Creates a CustomOpDef, fills its params from the dict (unknown names
// are warned and skipped), and returns it wrapped in the base OpDef type.
PyObject *make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs) {
#if MGB_CUSTOM_OP
    auto op_name = py::handle(args[0]).cast<std::string>();
    auto kwargs = py::handle(args[1]).cast<py::dict>();
    std::shared_ptr<OpDef> opdef = CustomOpDefFactory::inst()->create_opdef(op_name);
    auto &custom_opdef = static_cast<mgb::imperative::CustomOpDef&>(*opdef);
    auto &param = custom_opdef.param();
    for (auto &&kv: kwargs) {
        std::string param_name = py::handle(kv.first).cast<std::string>();
        std::string type_name = py::handle(kv.second).ptr()->ob_type->tp_name;
        if (!param.exist(param_name)) {
            // tolerate extra Python-side kwargs: warn and skip
            mgb_log_warn(
                "op %s have no param named %s, ignore this param parsed from python",
                op_name.c_str(), param_name.c_str()
            );
            continue;
        }
        auto& param_val = param[param_name];
        // dispatch on the param's dynamic type; cases expand via the
        // CUSTOM_CASE_TO_PARSE_* macros above
        switch (param_val.type()) {
            CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST)
            CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST)
            CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
            CUSTOM_FOR_BOOL_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
            CUSTOM_FOR_STRING_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
            default: {
                mgb_assert(
                    false, "param dtype of %s:%s is invalid",
                    op_name.c_str(), param_name.c_str()
                );
            }
        }
    }
    // wrap in the plain OpDef type (custom ops have no generated subtype)
    PyTypeObject* pytype;
    pytype = &PyOpType(OpDef);
    PyObject* obj = pytype->tp_alloc(pytype, 0);
    reinterpret_cast<PyOp(OpDef)*>(obj)->op = opdef;
    return obj;
#else
    mgb_assert(false, "Custom Op is disabled now, please build megengine with Custom Op open");
    return nullptr;
#endif
}
  623. #undef CUSTOM_CASE_TO_PARSE_LIST
  624. #undef CUSTOM_CASE_TO_PARSE_NON_LIST
  625. py::list install_custom(const std::string &name, const std::string &path) {
  626. #if MGB_CUSTOM_OP
  627. py::list ret;
  628. const auto &ops_in_lib = custom::LibManager::inst()->install(name, path);
  629. for (const auto &op: ops_in_lib) {
  630. ret.append(op);
  631. }
  632. return ret;
  633. #else
  634. mgb_assert(false, "Custom Op is disabled now, please build megengine with Custom Op open");
  635. py::list ret;
  636. return ret;
  637. #endif
  638. }
// Unload the custom-op library registered under `name`; returns the
// LibManager's success flag.
bool uninstall_custom(const std::string &name) {
#if MGB_CUSTOM_OP
    return custom::LibManager::inst()->uninstall(name);
#else
    mgb_assert(false, "Custom Op is disabled now, please build megengine with Custom Op open");
    return false;  // unreachable when asserts abort
#endif
}
  647. py::list get_custom_op_list(void) {
  648. #if MGB_CUSTOM_OP
  649. std::vector<std::string> all_ops = CustomOpDefFactory::inst()->op_list();
  650. py::list ret;
  651. for (auto &op: all_ops) {
  652. ret.append(op);
  653. }
  654. return ret;
  655. #else
  656. mgb_assert(false, "Custom Op is disabled now, please build megengine with Custom Op open");
  657. py::list ret;
  658. return ret;
  659. #endif
  660. }
  661. #ifndef METH_FASTCALL
  662. PyObject* py35_make_custom_op(PyObject* self, PyObject* args) {
  663. auto* arr = &PyTuple_GET_ITEM(args, 0);
  664. auto size = PyTuple_GET_SIZE(args);
  665. return make_custom_op(self, arr, size);
  666. };
  667. #endif
// Register the _custom submodule: library install/uninstall helpers and
// the raw _make_custom_op entry point.
void init_custom(pybind11::module m) {
    m.def("_install", &install_custom);
    m.def("_uninstall", &uninstall_custom);
    m.def("_get_custom_op_list", &get_custom_op_list);
    // _make_custom_op takes a raw argument array, so it is registered as a
    // plain PyMethodDef: FASTCALL where available, tuple shim otherwise.
    // `static` keeps the PyMethodDef alive for the lifetime of `func`.
    static PyMethodDef method_def = {
#ifdef METH_FASTCALL
        "_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, ""
#else
        "_make_custom_op", (PyCFunction)py35_make_custom_op, METH_VARARGS, ""
#endif
    };
    auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr);
    pybind11::setattr(m, method_def.ml_name, func);
}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台