You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

python_c_extension.cpp 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. #include "python_c_extension.h"
  2. #include "../emitter.h"
  3. namespace mlir::tblgen {
  4. namespace {
  5. struct Initproc {
  6. std::string func;
  7. Initproc(std::string&& s) : func(std::move(s)) {}
  8. std::string operator()(std::string argument) {
  9. return formatv("{0}({1})", func, argument);
  10. }
  11. };
  12. class OpDefEmitter : public EmitterBase {
  13. public:
  14. OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_)
  15. : EmitterBase(os_, env_), op(op_) {
  16. ctx.withSelf(op.getCppClassName());
  17. }
  18. Initproc emit();
  19. private:
  20. void emit_class();
  21. void emit_py_init();
  22. void emit_py_getsetters();
  23. void emit_py_methods();
  24. Initproc emit_initproc();
  25. MgbOp& op;
  26. std::vector<Initproc> subclasses;
  27. mlir::tblgen::FmtContext ctx;
  28. };
  29. class EnumAttrEmitter : public EmitterBase {
  30. public:
  31. EnumAttrEmitter(
  32. llvm::StringRef parent, MgbEnumAttr* attr_, raw_ostream& os_,
  33. Environment& env_)
  34. : EmitterBase(os_, env_), attr(attr_) {
  35. unsigned int enumID;
  36. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  37. auto&& aliasBase = alias->getAliasBase();
  38. enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
  39. } else {
  40. enumID = attr->getBaseRecord()->getID();
  41. }
  42. ctx.addSubst(
  43. "enumTpl",
  44. attr->getEnumCombinedFlag() ? "BitCombinedEnumWrapper" : "EnumWrapper");
  45. ctx.addSubst("opClass", parent);
  46. ctx.addSubst("enumClass", attr->getEnumName());
  47. firstOccur =
  48. env().enumAlias
  49. .emplace(enumID, std::make_pair(parent, attr->getEnumName()))
  50. .second;
  51. }
  52. Initproc emit();
  53. protected:
  54. void emit_trait();
  55. void emit_tpl_spl();
  56. Initproc emit_initproc();
  57. MgbEnumAttr* attr;
  58. bool firstOccur;
  59. mlir::tblgen::FmtContext ctx;
  60. };
  61. Initproc EnumAttrEmitter::emit() {
  62. emit_trait();
  63. emit_tpl_spl();
  64. return emit_initproc();
  65. }
  66. void EnumAttrEmitter::emit_trait() {
  67. if (!firstOccur)
  68. return;
  69. auto enumMax = [&] {
  70. if (attr->getEnumCombinedFlag()) {
  71. return formatv("(1llu << {0}) - 1", attr->getEnumMembers().size());
  72. } else {
  73. return formatv("{0} - 1", attr->getEnumMembers().size());
  74. }
  75. };
  76. os << tgfmt(
  77. R"(
  78. template<> struct EnumTrait<$opClass::$enumClass> {
  79. static constexpr const char *name = "$opClass.$enumClass";
  80. static constexpr std::underlying_type_t<$opClass::$enumClass> max = $0;
  81. };
  82. )",
  83. &ctx, enumMax());
  84. }
  85. void EnumAttrEmitter::emit_tpl_spl() {
  86. if (!firstOccur)
  87. return;
  88. os << tgfmt(
  89. "template<> PyTypeObject* $enumTpl<$opClass::$enumClass>::type = "
  90. "nullptr;\n",
  91. &ctx);
  92. auto quote = [&](auto&& i) -> std::string {
  93. size_t d1 = i.find(' ');
  94. size_t d2 = i.find('=');
  95. size_t d = d1 <= d2 ? d1 : d2;
  96. return formatv("\"{0}\"", i.substr(0, d));
  97. };
  98. os << tgfmt(
  99. R"(
  100. template<> const char*
  101. $enumTpl<$opClass::$enumClass>::members[] = {$0};
  102. )",
  103. &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), quote), ", "));
  104. auto mem2value = [&](auto&& i) -> std::string {
  105. size_t d1 = i.find(' ');
  106. size_t d2 = i.find('=');
  107. size_t d = d1 <= d2 ? d1 : d2;
  108. return tgfmt(
  109. "{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx,
  110. i.substr(0, d));
  111. };
  112. os << tgfmt(
  113. R"(
  114. template<> std::unordered_map<std::string, $opClass::$enumClass>
  115. $enumTpl<$opClass::$enumClass>::mem2value = {$0};
  116. )",
  117. &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), mem2value), ", "));
  118. os << tgfmt(
  119. "template<> PyObject* "
  120. "$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};\n",
  121. &ctx, attr->getEnumMembers().size());
  122. }
  123. Initproc EnumAttrEmitter::emit_initproc() {
  124. std::string initproc =
  125. formatv("_init_py_{0}_{1}", ctx.getSubstFor("opClass"),
  126. ctx.getSubstFor("enumClass"));
  127. os << tgfmt(
  128. R"(
  129. void $0(PyTypeObject& py_type) {
  130. auto& e_type = $enumTpl<$opClass::$enumClass>::type;
  131. )",
  132. &ctx, initproc);
  133. if (firstOccur) {
  134. os << tgfmt(
  135. R"(
  136. static PyMethodDef tp_methods[] = {
  137. {const_cast<char*>("dump"), (PyCFunction)$enumTpl<$opClass::$enumClass>::py_dump, METH_NOARGS, NULL},
  138. {NULL} /* Sentinel */
  139. };
  140. )",
  141. &ctx);
  142. os << tgfmt(
  143. R"(
  144. static PyType_Slot slots[] = {
  145. {Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr},
  146. {Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare},
  147. {Py_tp_methods, tp_methods},
  148. )",
  149. &ctx);
  150. if (attr->getEnumCombinedFlag()) {
  151. // only bit combined enum could new instance because bitwise operation,
  152. // others should always use singleton
  153. os << tgfmt(
  154. R"(
  155. {Py_tp_new, (void*)$enumTpl<$opClass::$enumClass>::py_new_combined_enum},
  156. {Py_nb_or, (void*)$enumTpl<$opClass::$enumClass>::py_or},
  157. {Py_nb_and, (void*)$enumTpl<$opClass::$enumClass>::py_and},
  158. )",
  159. &ctx);
  160. }
  161. os << R"(
  162. {0, NULL}
  163. };)";
  164. os << tgfmt(
  165. R"(
  166. static PyType_Spec spec = {
  167. // name
  168. "megengine.core._imperative_rt.ops.$opClass.$enumClass",
  169. // basicsize
  170. sizeof($enumTpl<$opClass::$enumClass>),
  171. // itemsize
  172. 0,
  173. // flags
  174. Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
  175. // slots
  176. slots
  177. };)",
  178. &ctx);
  179. os << tgfmt(
  180. R"(
  181. e_type = reinterpret_cast<PyTypeObject*>(PyType_FromSpec(&spec));
  182. )",
  183. &ctx);
  184. for (auto&& i :
  185. {std::pair<std::string, std::string>{
  186. "__name__", tgfmt("$enumClass", &ctx)},
  187. {"__module__", "megengine.core._imperative_rt.ops"},
  188. {"__qualname__", tgfmt("$opClass.$enumClass", &ctx)}}) {
  189. os << formatv(
  190. R"(
  191. mgb_assert(
  192. e_type->tp_setattro(
  193. reinterpret_cast<PyObject*>(e_type),
  194. py::cast("{0}").release().ptr(),
  195. py::cast("{1}").release().ptr()) >= 0);
  196. )",
  197. i.first, i.second);
  198. }
  199. auto&& members = attr->getEnumMembers();
  200. for (size_t idx = 0; idx < members.size(); ++idx) {
  201. size_t d1 = members[idx].find(' ');
  202. size_t d2 = members[idx].find('=');
  203. size_t d = d1 <= d2 ? d1 : d2;
  204. os << tgfmt(
  205. R"({
  206. PyObject* inst = e_type->tp_alloc(e_type, 0);
  207. reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
  208. mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
  209. $enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
  210. })",
  211. &ctx, members[idx].substr(0, d), idx);
  212. }
  213. }
  214. os << tgfmt(
  215. R"(
  216. Py_INCREF(e_type);
  217. mgb_assert(PyDict_SetItemString(
  218. py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(e_type)) >= 0);
  219. )",
  220. &ctx);
  221. os << "}\n";
  222. return initproc;
  223. }
  224. Initproc OpDefEmitter::emit() {
  225. for (auto&& i : op.getMgbAttributes()) {
  226. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  227. subclasses.push_back(
  228. EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit());
  229. }
  230. }
  231. emit_class();
  232. emit_py_init();
  233. emit_py_getsetters();
  234. emit_py_methods();
  235. return emit_initproc();
  236. }
  237. void OpDefEmitter::emit_class() {
  238. auto&& className = op.getCppClassName();
  239. std::string method_defs;
  240. std::vector<std::string> body;
  241. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  242. body.push_back(
  243. formatv(R"(
  244. {{"{0}", serialization<decltype(opdef.{0})>::dump(opdef.{0})})",
  245. attr.name));
  246. });
  247. method_defs +=
  248. formatv(R"(
  249. static PyObject* getstate(PyObject* self, PyObject*) {{
  250. auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst();
  251. static_cast<void>(opdef);
  252. std::unordered_map<std::string, py::object> state {{
  253. {1}
  254. };
  255. return py::cast(state).release().ptr();
  256. })",
  257. className, llvm::join(body, ","));
  258. body.clear();
  259. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  260. body.push_back(
  261. formatv(R"(
  262. {{
  263. auto&& iter = state.find("{0}");
  264. if (iter != state.end()) {
  265. opdef.{0} = serialization<decltype(opdef.{0})>::load(iter->second);
  266. }
  267. })",
  268. attr.name));
  269. });
  270. method_defs +=
  271. formatv(R"(
  272. static PyObject* setstate(PyObject* self, PyObject* args) {{
  273. PyObject* dict = PyTuple_GetItem(args, 0);
  274. if (!dict) return NULL;
  275. auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
  276. auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst();
  277. static_cast<void>(opdef);
  278. {1}
  279. Py_RETURN_NONE;
  280. })",
  281. className, llvm::join(body, "\n"));
  282. os << tgfmt(
  283. R"(
  284. PyOpDefBegin($_self) // {
  285. static PyGetSetDef py_getsetters[];
  286. static PyMethodDef tp_methods[];
  287. $0
  288. static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
  289. // };
  290. PyOpDefEnd($_self)
  291. )",
  292. &ctx, method_defs);
  293. }
  294. void OpDefEmitter::emit_py_init() {
  295. std::string initBody;
  296. if (!op.getMgbAttributes().empty()) {
  297. initBody += "static const char* kwlist[] = {";
  298. std::vector<llvm::StringRef> attr_name_list;
  299. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  300. attr_name_list.push_back(attr.name);
  301. });
  302. attr_name_list.push_back("scope");
  303. llvm::for_each(attr_name_list, [&](auto&& attr) {
  304. initBody += formatv("\"{0}\", ", attr);
  305. });
  306. initBody += "NULL};\n";
  307. initBody += " PyObject ";
  308. auto initializer = [&](auto&& attr) -> std::string {
  309. return formatv("*{0} = NULL", attr);
  310. };
  311. initBody +=
  312. llvm::join(llvm::map_range(attr_name_list, initializer), ", ") + ";\n";
  313. initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
  314. // an extra slot created for name
  315. initBody += std::string(attr_name_list.size(), 'O');
  316. initBody += "\", const_cast<char**>(kwlist)";
  317. llvm::for_each(attr_name_list, [&](auto&& attr) {
  318. initBody += formatv(", &{0}", attr);
  319. });
  320. initBody += "))\n";
  321. initBody += " return -1;\n";
  322. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  323. initBody +=
  324. tgfmt(R"(
  325. if ($0) {
  326. try {
  327. // TODO: remove this guard which is used for pybind11 implicit conversion
  328. py::detail::loader_life_support guard{};
  329. reinterpret_cast<PyOp($_self)*>(self)->inst().$0 =
  330. py::cast<decltype($_self::$0)>(py::handle($0));
  331. } CATCH_ALL(-1)
  332. }
  333. )",
  334. &ctx, attr.name);
  335. });
  336. initBody +=
  337. tgfmt(R"(
  338. if (scope) {
  339. try {
  340. reinterpret_cast<PyOp(OpDef)*>(self)->op
  341. ->set_scope(py::cast<std::string>(py::handle(scope)));
  342. } CATCH_ALL(-1)
  343. }
  344. )",
  345. &ctx);
  346. }
  347. initBody += "\n return 0;";
  348. os << tgfmt(
  349. R"(
  350. int PyOp($_self)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
  351. $0
  352. }
  353. )",
  354. &ctx, initBody);
  355. }
  356. void OpDefEmitter::emit_py_getsetters() {
  357. auto f = [&](auto&& attr) -> std::string {
  358. return tgfmt(
  359. "{const_cast<char*>(\"$0\"), py_get_generic($_self, $0), "
  360. "py_set_generic($_self, $0), const_cast<char*>(\"$0\"), NULL},",
  361. &ctx, attr.name);
  362. };
  363. os << tgfmt(
  364. R"(
  365. PyGetSetDef PyOp($_self)::py_getsetters[] = {
  366. $0
  367. {NULL} /* Sentinel */
  368. };
  369. )",
  370. &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n "));
  371. }
  372. void OpDefEmitter::emit_py_methods() {
  373. // generate methods
  374. std::string method_defs;
  375. std::vector<std::string> method_items;
  376. {
  377. auto&& className = op.getCppClassName();
  378. // generate getstate
  379. method_items.push_back(
  380. formatv("{{const_cast<char*>(\"__getstate__\"), PyOp({0})::getstate, "
  381. "METH_NOARGS, \"{0} getstate\"},",
  382. className));
  383. // generate setstate
  384. method_items.push_back(
  385. formatv("{{const_cast<char*>(\"__setstate__\"), PyOp({0})::setstate, "
  386. "METH_VARARGS, \"{0} setstate\"},",
  387. className));
  388. }
  389. os << tgfmt(
  390. R"(
  391. PyMethodDef PyOp($_self)::tp_methods[] = {
  392. $0
  393. {NULL} /* Sentinel */
  394. };
  395. )",
  396. &ctx, llvm::join(method_items, "\n "));
  397. }
  398. Initproc OpDefEmitter::emit_initproc() {
  399. std::string initproc = formatv("_init_py_{0}", op.getCppClassName());
  400. std::string subclass_init_call;
  401. for (auto&& i : subclasses) {
  402. subclass_init_call += formatv(" {0};\n", i("py_type"));
  403. }
  404. os << tgfmt(
  405. R"(
  406. void $0(py::module m) {
  407. using py_op = PyOp($_self);
  408. auto& py_type = PyOpType($_self);
  409. py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
  410. py_type.tp_name = "megengine.core._imperative_rt.ops.$_self";
  411. py_type.tp_basicsize = sizeof(PyOp($_self));
  412. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  413. py_type.tp_doc = "$_self";
  414. py_type.tp_base = &PyOpType(OpDef);
  415. py_type.tp_dealloc = py_dealloc_generic<py_op>;
  416. py_type.tp_new = py_new_generic<py_op>;
  417. py_type.tp_init = py_op::py_init;
  418. py_type.tp_methods = py_op::tp_methods;
  419. py_type.tp_getset = py_op::py_getsetters;
  420. mgb_assert(PyType_Ready(&py_type) >= 0);
  421. $1
  422. PyType_Modified(&py_type);
  423. m.add_object("$_self", reinterpret_cast<PyObject*>(&py_type));
  424. mgb_assert(PyOp(OpDef)::ctype2pytype.emplace($_self::typeinfo(), &py_type).second);
  425. }
  426. )",
  427. &ctx, initproc, subclass_init_call);
  428. return initproc;
  429. }
  430. } // namespace
  431. bool gen_op_def_python_c_extension(raw_ostream& os, llvm::RecordKeeper& keeper) {
  432. Environment env;
  433. using namespace std::placeholders;
  434. std::vector<Initproc> initprocs;
  435. foreach_operator(keeper, [&](MgbOp& op) {
  436. initprocs.emplace_back(OpDefEmitter(op, os, env).emit());
  437. });
  438. os << "#define INIT_ALL_OP(m)";
  439. for (auto&& init : initprocs) {
  440. os << formatv(" \\\n {0};", init("m"));
  441. }
  442. os << "\n";
  443. return false;
  444. }
  445. } // namespace mlir::tblgen