You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

python_c_extension.cpp 16 kB


  1. /**
  2. * \file imperative/tablegen/targets/python_c_extension.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "python_c_extension.h"
  12. #include "../emitter.h"
  13. namespace mlir::tblgen {
  14. namespace {
  15. struct Initproc {
  16. std::string func;
  17. Initproc(std::string&& s) : func(std::move(s)) {}
  18. std::string operator()(std::string argument) {
  19. return formatv("{0}({1})", func, argument);
  20. }
  21. };
  22. class OpDefEmitter : public EmitterBase {
  23. public:
  24. OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_)
  25. : EmitterBase(os_, env_), op(op_) {
  26. ctx.withSelf(op.getCppClassName());
  27. }
  28. Initproc emit();
  29. private:
  30. void emit_class();
  31. void emit_py_init();
  32. void emit_py_getsetters();
  33. void emit_py_methods();
  34. Initproc emit_initproc();
  35. MgbOp& op;
  36. std::vector<Initproc> subclasses;
  37. mlir::tblgen::FmtContext ctx;
  38. };
  39. class EnumAttrEmitter : public EmitterBase {
  40. public:
  41. EnumAttrEmitter(
  42. llvm::StringRef parent, MgbEnumAttr* attr_, raw_ostream& os_,
  43. Environment& env_)
  44. : EmitterBase(os_, env_), attr(attr_) {
  45. unsigned int enumID;
  46. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  47. auto&& aliasBase = alias->getAliasBase();
  48. enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
  49. } else {
  50. enumID = attr->getBaseRecord()->getID();
  51. }
  52. ctx.addSubst(
  53. "enumTpl",
  54. attr->getEnumCombinedFlag() ? "BitCombinedEnumWrapper" : "EnumWrapper");
  55. ctx.addSubst("opClass", parent);
  56. ctx.addSubst("enumClass", attr->getEnumName());
  57. firstOccur =
  58. env().enumAlias
  59. .emplace(enumID, std::make_pair(parent, attr->getEnumName()))
  60. .second;
  61. }
  62. Initproc emit();
  63. protected:
  64. void emit_trait();
  65. void emit_tpl_spl();
  66. Initproc emit_initproc();
  67. MgbEnumAttr* attr;
  68. bool firstOccur;
  69. mlir::tblgen::FmtContext ctx;
  70. };
  71. Initproc EnumAttrEmitter::emit() {
  72. emit_trait();
  73. emit_tpl_spl();
  74. return emit_initproc();
  75. }
  76. void EnumAttrEmitter::emit_trait() {
  77. if (!firstOccur)
  78. return;
  79. auto enumMax = [&] {
  80. if (attr->getEnumCombinedFlag()) {
  81. return formatv("(1llu << {0}) - 1", attr->getEnumMembers().size());
  82. } else {
  83. return formatv("{0} - 1", attr->getEnumMembers().size());
  84. }
  85. };
  86. os << tgfmt(
  87. R"(
  88. template<> struct EnumTrait<$opClass::$enumClass> {
  89. static constexpr const char *name = "$opClass.$enumClass";
  90. static constexpr std::underlying_type_t<$opClass::$enumClass> max = $0;
  91. };
  92. )",
  93. &ctx, enumMax());
  94. }
  95. void EnumAttrEmitter::emit_tpl_spl() {
  96. if (!firstOccur)
  97. return;
  98. os << tgfmt(
  99. "template<> PyTypeObject* $enumTpl<$opClass::$enumClass>::type = "
  100. "nullptr;\n",
  101. &ctx);
  102. auto quote = [&](auto&& i) -> std::string {
  103. size_t d1 = i.find(' ');
  104. size_t d2 = i.find('=');
  105. size_t d = d1 <= d2 ? d1 : d2;
  106. return formatv("\"{0}\"", i.substr(0, d));
  107. };
  108. os << tgfmt(
  109. R"(
  110. template<> const char*
  111. $enumTpl<$opClass::$enumClass>::members[] = {$0};
  112. )",
  113. &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), quote), ", "));
  114. auto mem2value = [&](auto&& i) -> std::string {
  115. size_t d1 = i.find(' ');
  116. size_t d2 = i.find('=');
  117. size_t d = d1 <= d2 ? d1 : d2;
  118. return tgfmt(
  119. "{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx,
  120. i.substr(0, d));
  121. };
  122. os << tgfmt(
  123. R"(
  124. template<> std::unordered_map<std::string, $opClass::$enumClass>
  125. $enumTpl<$opClass::$enumClass>::mem2value = {$0};
  126. )",
  127. &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), mem2value), ", "));
  128. os << tgfmt(
  129. "template<> PyObject* "
  130. "$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};\n",
  131. &ctx, attr->getEnumMembers().size());
  132. }
  133. Initproc EnumAttrEmitter::emit_initproc() {
  134. std::string initproc =
  135. formatv("_init_py_{0}_{1}", ctx.getSubstFor("opClass"),
  136. ctx.getSubstFor("enumClass"));
  137. os << tgfmt(
  138. R"(
  139. void $0(PyTypeObject& py_type) {
  140. auto& e_type = $enumTpl<$opClass::$enumClass>::type;
  141. )",
  142. &ctx, initproc);
  143. if (firstOccur) {
  144. os << tgfmt(
  145. R"(
  146. static PyMethodDef tp_methods[] = {
  147. {const_cast<char*>("dump"), (PyCFunction)$enumTpl<$opClass::$enumClass>::py_dump, METH_NOARGS, NULL},
  148. {NULL} /* Sentinel */
  149. };
  150. )",
  151. &ctx);
  152. os << tgfmt(
  153. R"(
  154. static PyType_Slot slots[] = {
  155. {Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr},
  156. {Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare},
  157. {Py_tp_methods, tp_methods},
  158. )",
  159. &ctx);
  160. if (attr->getEnumCombinedFlag()) {
  161. // only bit combined enum could new instance because bitwise operation,
  162. // others should always use singleton
  163. os << tgfmt(
  164. R"(
  165. {Py_tp_new, (void*)$enumTpl<$opClass::$enumClass>::py_new_combined_enum},
  166. {Py_nb_or, (void*)$enumTpl<$opClass::$enumClass>::py_or},
  167. {Py_nb_and, (void*)$enumTpl<$opClass::$enumClass>::py_and},
  168. )",
  169. &ctx);
  170. }
  171. os << R"(
  172. {0, NULL}
  173. };)";
  174. os << tgfmt(
  175. R"(
  176. static PyType_Spec spec = {
  177. // name
  178. "megengine.core._imperative_rt.ops.$opClass.$enumClass",
  179. // basicsize
  180. sizeof($enumTpl<$opClass::$enumClass>),
  181. // itemsize
  182. 0,
  183. // flags
  184. Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
  185. // slots
  186. slots
  187. };)",
  188. &ctx);
  189. os << tgfmt(
  190. R"(
  191. e_type = reinterpret_cast<PyTypeObject*>(PyType_FromSpec(&spec));
  192. )",
  193. &ctx);
  194. for (auto&& i :
  195. {std::pair<std::string, std::string>{
  196. "__name__", tgfmt("$enumClass", &ctx)},
  197. {"__module__", "megengine.core._imperative_rt.ops"},
  198. {"__qualname__", tgfmt("$opClass.$enumClass", &ctx)}}) {
  199. os << formatv(
  200. R"(
  201. mgb_assert(
  202. e_type->tp_setattro(
  203. reinterpret_cast<PyObject*>(e_type),
  204. py::cast("{0}").release().ptr(),
  205. py::cast("{1}").release().ptr()) >= 0);
  206. )",
  207. i.first, i.second);
  208. }
  209. auto&& members = attr->getEnumMembers();
  210. for (size_t idx = 0; idx < members.size(); ++idx) {
  211. size_t d1 = members[idx].find(' ');
  212. size_t d2 = members[idx].find('=');
  213. size_t d = d1 <= d2 ? d1 : d2;
  214. os << tgfmt(
  215. R"({
  216. PyObject* inst = e_type->tp_alloc(e_type, 0);
  217. reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
  218. mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
  219. $enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
  220. })",
  221. &ctx, members[idx].substr(0, d), idx);
  222. }
  223. }
  224. os << tgfmt(
  225. R"(
  226. Py_INCREF(e_type);
  227. mgb_assert(PyDict_SetItemString(
  228. py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(e_type)) >= 0);
  229. )",
  230. &ctx);
  231. os << "}\n";
  232. return initproc;
  233. }
  234. Initproc OpDefEmitter::emit() {
  235. for (auto&& i : op.getMgbAttributes()) {
  236. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  237. subclasses.push_back(
  238. EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit());
  239. }
  240. }
  241. emit_class();
  242. emit_py_init();
  243. emit_py_getsetters();
  244. emit_py_methods();
  245. return emit_initproc();
  246. }
  247. void OpDefEmitter::emit_class() {
  248. auto&& className = op.getCppClassName();
  249. std::string method_defs;
  250. std::vector<std::string> body;
  251. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  252. body.push_back(
  253. formatv(R"(
  254. {{"{0}", serialization<decltype(opdef.{0})>::dump(opdef.{0})})",
  255. attr.name));
  256. });
  257. method_defs +=
  258. formatv(R"(
  259. static PyObject* getstate(PyObject* self, PyObject*) {{
  260. auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst();
  261. static_cast<void>(opdef);
  262. std::unordered_map<std::string, py::object> state {{
  263. {1}
  264. };
  265. return py::cast(state).release().ptr();
  266. })",
  267. className, llvm::join(body, ","));
  268. body.clear();
  269. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  270. body.push_back(
  271. formatv(R"(
  272. {{
  273. auto&& iter = state.find("{0}");
  274. if (iter != state.end()) {
  275. opdef.{0} = serialization<decltype(opdef.{0})>::load(iter->second);
  276. }
  277. })",
  278. attr.name));
  279. });
  280. method_defs +=
  281. formatv(R"(
  282. static PyObject* setstate(PyObject* self, PyObject* args) {{
  283. PyObject* dict = PyTuple_GetItem(args, 0);
  284. if (!dict) return NULL;
  285. auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
  286. auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst();
  287. static_cast<void>(opdef);
  288. {1}
  289. Py_RETURN_NONE;
  290. })",
  291. className, llvm::join(body, "\n"));
  292. os << tgfmt(
  293. R"(
  294. PyOpDefBegin($_self) // {
  295. static PyGetSetDef py_getsetters[];
  296. static PyMethodDef tp_methods[];
  297. $0
  298. static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
  299. // };
  300. PyOpDefEnd($_self)
  301. )",
  302. &ctx, method_defs);
  303. }
  304. void OpDefEmitter::emit_py_init() {
  305. std::string initBody;
  306. if (!op.getMgbAttributes().empty()) {
  307. initBody += "static const char* kwlist[] = {";
  308. std::vector<llvm::StringRef> attr_name_list;
  309. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  310. attr_name_list.push_back(attr.name);
  311. });
  312. attr_name_list.push_back("scope");
  313. llvm::for_each(attr_name_list, [&](auto&& attr) {
  314. initBody += formatv("\"{0}\", ", attr);
  315. });
  316. initBody += "NULL};\n";
  317. initBody += " PyObject ";
  318. auto initializer = [&](auto&& attr) -> std::string {
  319. return formatv("*{0} = NULL", attr);
  320. };
  321. initBody +=
  322. llvm::join(llvm::map_range(attr_name_list, initializer), ", ") + ";\n";
  323. initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
  324. // an extra slot created for name
  325. initBody += std::string(attr_name_list.size(), 'O');
  326. initBody += "\", const_cast<char**>(kwlist)";
  327. llvm::for_each(attr_name_list, [&](auto&& attr) {
  328. initBody += formatv(", &{0}", attr);
  329. });
  330. initBody += "))\n";
  331. initBody += " return -1;\n";
  332. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  333. initBody +=
  334. tgfmt(R"(
  335. if ($0) {
  336. try {
  337. // TODO: remove this guard which is used for pybind11 implicit conversion
  338. py::detail::loader_life_support guard{};
  339. reinterpret_cast<PyOp($_self)*>(self)->inst().$0 =
  340. py::cast<decltype($_self::$0)>(py::handle($0));
  341. } CATCH_ALL(-1)
  342. }
  343. )",
  344. &ctx, attr.name);
  345. });
  346. initBody +=
  347. tgfmt(R"(
  348. if (scope) {
  349. try {
  350. reinterpret_cast<PyOp(OpDef)*>(self)->op
  351. ->set_scope(py::cast<std::string>(py::handle(scope)));
  352. } CATCH_ALL(-1)
  353. }
  354. )",
  355. &ctx);
  356. }
  357. initBody += "\n return 0;";
  358. os << tgfmt(
  359. R"(
  360. int PyOp($_self)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
  361. $0
  362. }
  363. )",
  364. &ctx, initBody);
  365. }
  366. void OpDefEmitter::emit_py_getsetters() {
  367. auto f = [&](auto&& attr) -> std::string {
  368. return tgfmt(
  369. "{const_cast<char*>(\"$0\"), py_get_generic($_self, $0), "
  370. "py_set_generic($_self, $0), const_cast<char*>(\"$0\"), NULL},",
  371. &ctx, attr.name);
  372. };
  373. os << tgfmt(
  374. R"(
  375. PyGetSetDef PyOp($_self)::py_getsetters[] = {
  376. $0
  377. {NULL} /* Sentinel */
  378. };
  379. )",
  380. &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n "));
  381. }
  382. void OpDefEmitter::emit_py_methods() {
  383. // generate methods
  384. std::string method_defs;
  385. std::vector<std::string> method_items;
  386. {
  387. auto&& className = op.getCppClassName();
  388. // generate getstate
  389. method_items.push_back(
  390. formatv("{{const_cast<char*>(\"__getstate__\"), PyOp({0})::getstate, "
  391. "METH_NOARGS, \"{0} getstate\"},",
  392. className));
  393. // generate setstate
  394. method_items.push_back(
  395. formatv("{{const_cast<char*>(\"__setstate__\"), PyOp({0})::setstate, "
  396. "METH_VARARGS, \"{0} setstate\"},",
  397. className));
  398. }
  399. os << tgfmt(
  400. R"(
  401. PyMethodDef PyOp($_self)::tp_methods[] = {
  402. $0
  403. {NULL} /* Sentinel */
  404. };
  405. )",
  406. &ctx, llvm::join(method_items, "\n "));
  407. }
  408. Initproc OpDefEmitter::emit_initproc() {
  409. std::string initproc = formatv("_init_py_{0}", op.getCppClassName());
  410. std::string subclass_init_call;
  411. for (auto&& i : subclasses) {
  412. subclass_init_call += formatv(" {0};\n", i("py_type"));
  413. }
  414. os << tgfmt(
  415. R"(
  416. void $0(py::module m) {
  417. using py_op = PyOp($_self);
  418. auto& py_type = PyOpType($_self);
  419. py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
  420. py_type.tp_name = "megengine.core._imperative_rt.ops.$_self";
  421. py_type.tp_basicsize = sizeof(PyOp($_self));
  422. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  423. py_type.tp_doc = "$_self";
  424. py_type.tp_base = &PyOpType(OpDef);
  425. py_type.tp_dealloc = py_dealloc_generic<py_op>;
  426. py_type.tp_new = py_new_generic<py_op>;
  427. py_type.tp_init = py_op::py_init;
  428. py_type.tp_methods = py_op::tp_methods;
  429. py_type.tp_getset = py_op::py_getsetters;
  430. mgb_assert(PyType_Ready(&py_type) >= 0);
  431. $1
  432. PyType_Modified(&py_type);
  433. m.add_object("$_self", reinterpret_cast<PyObject*>(&py_type));
  434. mgb_assert(PyOp(OpDef)::ctype2pytype.emplace($_self::typeinfo(), &py_type).second);
  435. }
  436. )",
  437. &ctx, initproc, subclass_init_call);
  438. return initproc;
  439. }
  440. } // namespace
  441. bool gen_op_def_python_c_extension(raw_ostream& os, llvm::RecordKeeper& keeper) {
  442. Environment env;
  443. using namespace std::placeholders;
  444. std::vector<Initproc> initprocs;
  445. foreach_operator(keeper, [&](MgbOp& op) {
  446. initprocs.emplace_back(OpDefEmitter(op, os, env).emit());
  447. });
  448. os << "#define INIT_ALL_OP(m)";
  449. for (auto&& init : initprocs) {
  450. os << formatv(" \\\n {0};", init("m"));
  451. }
  452. os << "\n";
  453. return false;
  454. }
  455. } // namespace mlir::tblgen