You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

python_c_extension.cpp 26 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792
  1. #include <cctype>
  2. #include <functional>
  3. #include <iostream>
  4. #include <sstream>
  5. #include <string>
  6. #include <tuple>
  7. #include <unordered_map>
  8. #include <vector>
  9. #include "../emitter.h"
  10. #include "python_c_extension.h"
  11. namespace mlir::tblgen {
  12. namespace {
  // Forward declarations: TypeInfo and these two parsers are referenced
  // before their definitions further down in this file.
  class TypeInfo;
  std::pair<TypeInfo, int> parse_type(const std::string&, const int);
  std::pair<std::vector<std::string>, int> parse_namespace(const std::string&, const int);
  // Empty result type for parsers that only consume input.
  struct Unit {};
  // Shared Unit value returned by those parsers.
  Unit unit;
  // Exception type thrown by the parsers in this file when input does not match.
  struct ParseError {};
  19. class TypeInfo {
  20. public:
  21. TypeInfo(std::string name) : name(name) {}
  22. std::string to_python_type_string() {
  23. std::stringstream ss;
  24. ss << translate_type_name(name);
  25. if (params.size() > 0) {
  26. ss << "[" << params[0].to_python_type_string();
  27. for (auto i = 1; i < params.size(); i++) {
  28. ss << ", " << params[i].to_python_type_string();
  29. }
  30. ss << "]";
  31. }
  32. return ss.str();
  33. }
  34. std::string translate_type_name(const std::string& cppTypeName) {
  35. auto res = translation.find(cppTypeName);
  36. if (res != translation.end())
  37. return res->second;
  38. try {
  39. auto segments = parse_namespace(cppTypeName, 0).first;
  40. // special rules
  41. if (segments.size() > 3 && segments[0] == "megdnn" &&
  42. segments[1] == "param") {
  43. segments.erase(segments.begin(), segments.begin() + 3);
  44. } else if (
  45. segments.size() == 2 && segments[0] == "megdnn" &&
  46. segments[1] == "DType") {
  47. segments.erase(segments.begin(), segments.begin() + 1);
  48. segments[0] = "str";
  49. } else if (
  50. segments.size() == 2 && segments[0] == "mgb" &&
  51. segments[1] == "CompNode") {
  52. segments.erase(segments.begin(), segments.begin() + 1);
  53. segments[0] = "str";
  54. }
  55. std::stringstream joined;
  56. joined << segments[0];
  57. for (auto i = 1; i < segments.size(); i++) {
  58. joined << "." << segments[i];
  59. }
  60. return joined.str();
  61. } catch (ParseError) {
  62. return cppTypeName;
  63. }
  64. }
  65. std::string name;
  66. std::vector<TypeInfo> params;
  67. private:
  68. static const std::unordered_map<std::string, std::string> translation;
  69. };
  70. const std::unordered_map<std::string, std::string> TypeInfo::translation = {
  71. {"bool", "bool"}, {"double", "float"}, {"float", "float"},
  72. {"int32_t", "int"}, {"int8_t", "int"}, {"size_t", "int"},
  73. {"std::string", "str"}, {"std::tuple", "tuple"}, {"std::vector", "list"},
  74. {"uint32_t", "int"}, {"uint64_t", "int"},
  75. };
  // A Parser<T> is a callable that takes:
  //   1. the string to parse
  //   2. the location to parse from (index of a character)
  // and returns a pair of:
  //   1. the parsing result (type T)
  //   2. the end location of the substring consumed by parsing
  // It throws an exception (ParseError) when it fails to parse.
  template <typename T>
  using Parser = std::function<std::pair<T, int>(const std::string&, const int)>;
  85. std::pair<Unit, int> parse_blank(const std::string& text, const int begin) {
  86. auto now = begin;
  87. while (now < text.length() && isblank(text[now]))
  88. now += 1;
  89. return {unit, now};
  90. }
  91. Parser<Unit> parse_non_blank_char(char ch) {
  92. return [=](const std::string& text, const int begin) -> std::pair<Unit, int> {
  93. auto blankEnd = parse_blank(text, begin).second;
  94. if (blankEnd >= text.length() || text[blankEnd] != ch)
  95. throw ParseError{};
  96. return {unit, blankEnd + 1};
  97. };
  98. }
  99. Parser<std::string> parse_allowed_chars(std::function<bool(char)> allow) {
  100. return [=](const std::string& text,
  101. const int begin) -> std::pair<std::string, int> {
  102. auto now = begin;
  103. while (now < text.length() && allow(text[now]))
  104. now += 1;
  105. return {text.substr(begin, now - begin), now};
  106. };
  107. }
  108. template <typename T>
  109. Parser<std::tuple<T>> parse_seq(Parser<T> only) {
  110. return [=](const std::string& text,
  111. const int begin) -> std::pair<std::tuple<T>, int> {
  112. auto res = only(text, begin);
  113. return {{res.first}, res.second};
  114. };
  115. }
  116. template <typename Head, typename... Tail>
  117. Parser<std::tuple<Head, Tail...>> parse_seq(Parser<Head> head, Parser<Tail>... tail) {
  118. return [=](const std::string& text,
  119. const int begin) -> std::pair<std::tuple<Head, Tail...>, int> {
  120. std::pair<Head, int> headRes = head(text, begin);
  121. std::pair<std::tuple<Tail...>, int> tailRes =
  122. parse_seq(tail...)(text, headRes.second);
  123. return {std::tuple_cat(std::tuple<Head>(headRes.first), tailRes.first),
  124. tailRes.second};
  125. };
  126. }
  127. template <typename T>
  128. Parser<std::vector<T>> parse_many_at_least0(Parser<T> one) {
  129. return [=](const std::string& text,
  130. const int begin) -> std::pair<std::vector<T>, int> {
  131. std::vector<T> ret;
  132. auto now = begin;
  133. try {
  134. while (true) {
  135. auto oneRes = one(text, now);
  136. ret.emplace_back(oneRes.first);
  137. now = oneRes.second;
  138. }
  139. } catch (ParseError) {
  140. }
  141. return {ret, now};
  142. };
  143. }
  144. template <typename C>
  145. Parser<std::vector<C>> parse_sep_by_at_least1(
  146. Parser<Unit> separator, Parser<C> component) {
  147. return [=](const std::string& text,
  148. const int begin) -> std::pair<std::vector<C>, int> {
  149. std::vector<C> ret;
  150. auto headRes = component(text, begin);
  151. ret.emplace_back(headRes.first);
  152. auto tailRes = parse_many_at_least0(parse_seq(separator, component))(
  153. text, headRes.second);
  154. for (const auto& elem : tailRes.first) {
  155. ret.emplace_back(std::get<1>(elem));
  156. }
  157. return {ret, tailRes.second};
  158. };
  159. }
  160. std::pair<std::string, int> parse_identifier(const std::string& text, const int begin) {
  161. auto blankEnd = parse_blank(text, begin).second;
  162. auto indentRes = parse_allowed_chars(
  163. [](char ch) { return std::isalnum(ch) || ch == '_'; })(text, blankEnd);
  164. if (indentRes.first.empty())
  165. throw ParseError{};
  166. return indentRes;
  167. };
  168. std::pair<std::string, int> parse_qualified(const std::string& text, const int begin) {
  169. auto blankEnd = parse_blank(text, begin).second;
  170. auto indentRes = parse_allowed_chars([](char ch) {
  171. return std::isalnum(ch) || ch == '_' || ch == ':';
  172. })(text, blankEnd);
  173. if (indentRes.first.empty())
  174. throw ParseError{};
  175. return indentRes;
  176. };
  177. std::pair<std::vector<std::string>, int> parse_namespace(
  178. const std::string& text, const int begin) {
  179. auto res = parse_many_at_least0(parse_seq(
  180. parse_non_blank_char(':'), parse_non_blank_char(':'),
  181. Parser<std::string>(parse_identifier)))(text, begin);
  182. std::vector<std::string> ret;
  183. for (const auto& elem : res.first) {
  184. ret.emplace_back(std::get<2>(elem));
  185. }
  186. return {ret, res.second};
  187. }
  188. std::pair<TypeInfo, int> parse_leaf_type(const std::string& text, const int begin) {
  189. auto ret = parse_qualified(text, begin);
  190. return {TypeInfo(ret.first), ret.second};
  191. };
  192. std::pair<TypeInfo, int> parse_node_type(const std::string& text, const int begin) {
  193. auto nameRes = parse_qualified(text, begin);
  194. auto ret = TypeInfo(nameRes.first);
  195. auto now = parse_non_blank_char('<')(text, nameRes.second).second;
  196. auto argsRes = parse_sep_by_at_least1(
  197. parse_non_blank_char(','), Parser<TypeInfo>(parse_type))(text, now);
  198. ret.params = argsRes.first;
  199. now = parse_non_blank_char('>')(text, argsRes.second).second;
  200. return {ret, now};
  201. };
  202. std::pair<TypeInfo, int> parse_type(const std::string& text, const int begin) {
  203. try {
  204. return parse_node_type(text, begin);
  205. } catch (ParseError) {
  206. }
  207. return parse_leaf_type(text, begin);
  208. };
  209. std::string cpp_type_to_python_type(const std::string& input) {
  210. auto res = parse_type(input, 0);
  211. return res.first.to_python_type_string();
  212. }
  213. struct Initproc {
  214. std::string func;
  215. Initproc(std::string&& s) : func(std::move(s)) {}
  216. std::string operator()(std::string argument) {
  217. return formatv("{0}({1})", func, argument);
  218. }
  219. };
  // Emits the Python C-extension binding code for one MgbOp to the output
  // stream.
  class OpDefEmitter : public EmitterBase {
  public:
      OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_)
              : EmitterBase(os_, env_), op(op_) {
          // Register the op's C++ class name as the context's "self"; the
          // templates in this file refer to it as $_self.
          ctx.withSelf(op.getCppClassName());
      }
      // Emits every piece for this op and returns the Initproc produced by
      // emit_initproc().
      Initproc emit();

  private:
      void emit_class();
      void emit_py_init();
      void emit_py_getsetters();
      void emit_py_methods();
      void emit_py_init_proxy();
      // enum_attr_members maps the Python type string of each enum attribute
      // to that enum's member names.
      void emit_py_init_methoddef(
              const std::unordered_map<std::string, std::vector<std::string>>&
                      enum_attr_members);
      Initproc emit_initproc();
      MgbOp& op;
      // Init procs of the nested enum classes, filled in by emit().
      std::vector<Initproc> subclasses;
      mlir::tblgen::FmtContext ctx;
  };
  // Emits the Python binding code for one enum attribute of an op.
  class EnumAttrEmitter : public EmitterBase {
  public:
      EnumAttrEmitter(
              llvm::StringRef parent, MgbEnumAttr* attr_, raw_ostream& os_,
              Environment& env_)
              : EmitterBase(os_, env_), attr(attr_) {
          // An alias attribute is identified by the record ID of the enum it
          // aliases; any other enum by its own base record ID.
          unsigned int enumID;
          if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
              auto&& aliasBase = alias->getAliasBase();
              enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
          } else {
              enumID = attr->getBaseRecord()->getID();
          }
          // Substitutions used by the tgfmt templates of this emitter.
          ctx.addSubst(
                  "enumTpl",
                  attr->getEnumCombinedFlag() ? "BitCombinedEnumWrapper" : "EnumWrapper");
          ctx.addSubst("opClass", parent);
          ctx.addSubst("enumClass", attr->getEnumName());
          // emplace(...).second is true only when enumID was not yet present
          // in env().enumAlias, i.e. this is the first emitter for that enum.
          firstOccur =
                  env().enumAlias
                          .emplace(enumID, std::make_pair(parent, attr->getEnumName()))
                          .second;
      }
      // Runs emit_trait(), emit_tpl_spl() and emit_initproc() in that order
      // and returns the Initproc from the last one.
      Initproc emit();

  protected:
      void emit_trait();
      void emit_tpl_spl();
      Initproc emit_initproc();
      MgbEnumAttr* attr;
      // True when this emitter was the first to register its enum ID.
      bool firstOccur;
      mlir::tblgen::FmtContext ctx;
  };
  273. Initproc EnumAttrEmitter::emit() {
  274. emit_trait();
  275. emit_tpl_spl();
  276. return emit_initproc();
  277. }
  278. void EnumAttrEmitter::emit_trait() {
  279. if (!firstOccur)
  280. return;
  281. auto enumMax = [&] {
  282. if (attr->getEnumCombinedFlag()) {
  283. return formatv("(1llu << {0}) - 1", attr->getEnumMembers().size());
  284. } else {
  285. return formatv("{0} - 1", attr->getEnumMembers().size());
  286. }
  287. };
  288. os << tgfmt(
  289. R"(
  290. template<> struct EnumTrait<$opClass::$enumClass> {
  291. static constexpr const char *name = "$opClass.$enumClass";
  292. static constexpr std::underlying_type_t<$opClass::$enumClass> max = $0;
  293. };
  294. )",
  295. &ctx, enumMax());
  296. }
  297. void EnumAttrEmitter::emit_tpl_spl() {
  298. if (!firstOccur)
  299. return;
  300. os << tgfmt(
  301. "template<> PyTypeObject* $enumTpl<$opClass::$enumClass>::type = "
  302. "nullptr;\n",
  303. &ctx);
  304. auto quote = [&](auto&& i) -> std::string {
  305. size_t d1 = i.find(' ');
  306. size_t d2 = i.find('=');
  307. size_t d = d1 <= d2 ? d1 : d2;
  308. return formatv("\"{0}\"", i.substr(0, d));
  309. };
  310. os << tgfmt(
  311. R"(
  312. template<> const char*
  313. $enumTpl<$opClass::$enumClass>::members[] = {$0};
  314. )",
  315. &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), quote), ", "));
  316. auto mem2value = [&](auto&& i) -> std::string {
  317. size_t d1 = i.find(' ');
  318. size_t d2 = i.find('=');
  319. size_t d = d1 <= d2 ? d1 : d2;
  320. return tgfmt(
  321. "{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx,
  322. i.substr(0, d));
  323. };
  324. os << tgfmt(
  325. R"(
  326. template<> std::unordered_map<std::string, $opClass::$enumClass>
  327. $enumTpl<$opClass::$enumClass>::mem2value = {$0};
  328. )",
  329. &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), mem2value), ", "));
  330. os << tgfmt(
  331. "template<> PyObject* "
  332. "$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};\n",
  333. &ctx, attr->getEnumMembers().size());
  334. }
  // Emits the function `_init_py_<opClass>_<enumClass>(PyTypeObject& py_type)`
  // and returns its name as an Initproc.
  //
  // On the first occurrence of the enum the generated function also creates
  // the Python type from a PyType_Spec, sets its __name__/__module__/
  // __qualname__ and creates one instance per enum member. In every case the
  // generated function stores the enum type in py_type.tp_dict under the enum
  // class name.
  Initproc EnumAttrEmitter::emit_initproc() {
      std::string initproc =
              formatv("_init_py_{0}_{1}", ctx.getSubstFor("opClass"),
                      ctx.getSubstFor("enumClass"));
      // Function header; e_type refers to the wrapper's static `type` member.
      os << tgfmt(
              R"(
  void $0(PyTypeObject& py_type) {
  auto& e_type = $enumTpl<$opClass::$enumClass>::type;
  )",
              &ctx, initproc);
      if (firstOccur) {
          // Method table of the enum type (only "dump").
          os << tgfmt(
                  R"(
  static PyMethodDef tp_methods[] = {
  {const_cast<char*>("dump"), (PyCFunction)$enumTpl<$opClass::$enumClass>::py_dump, METH_NOARGS, NULL},
  {NULL} /* Sentinel */
  };
  )",
                  &ctx);
          // Type slots; the array is left open here and closed further down.
          os << tgfmt(
                  R"(
  static PyType_Slot slots[] = {
  {Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr},
  {Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare},
  {Py_tp_methods, tp_methods},
  )",
                  &ctx);
          if (attr->getEnumCombinedFlag()) {
              // only bit combined enum could new instance because bitwise operation,
              // others should always use singleton
              os << tgfmt(
                      R"(
  {Py_tp_new, (void*)$enumTpl<$opClass::$enumClass>::py_new_combined_enum},
  {Py_nb_or, (void*)$enumTpl<$opClass::$enumClass>::py_or},
  {Py_nb_and, (void*)$enumTpl<$opClass::$enumClass>::py_and},
  )",
                      &ctx);
          }
          // Sentinel entry, closing the slots array.
          os << R"(
  {0, NULL}
  };)";
          os << tgfmt(
                  R"(
  static PyType_Spec spec = {
  // name
  "megengine.core._imperative_rt.ops.$opClass.$enumClass",
  // basicsize
  sizeof($enumTpl<$opClass::$enumClass>),
  // itemsize
  0,
  // flags
  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
  // slots
  slots
  };)",
                  &ctx);
          os << tgfmt(
                  R"(
  e_type = reinterpret_cast<PyTypeObject*>(PyType_FromSpec(&spec));
  )",
                  &ctx);
          // One generated tp_setattro call per (attribute, value) pair below.
          for (auto&& i :
               {std::pair<std::string, std::string>{
                        "__name__", tgfmt("$enumClass", &ctx)},
                {"__module__", "megengine.core._imperative_rt.ops"},
                {"__qualname__", tgfmt("$opClass.$enumClass", &ctx)}}) {
              os << formatv(
                      R"(
  mgb_assert(
  e_type->tp_setattro(
  reinterpret_cast<PyObject*>(e_type),
  py::cast("{0}").release().ptr(),
  py::cast("{1}").release().ptr()) >= 0);
  )",
                      i.first, i.second);
          }
          auto&& members = attr->getEnumMembers();
          for (size_t idx = 0; idx < members.size(); ++idx) {
              // The member name is the text before the first ' ' or '='.
              size_t d1 = members[idx].find(' ');
              size_t d2 = members[idx].find('=');
              size_t d = d1 <= d2 ? d1 : d2;
              // Generated code: allocate an instance for this member, store it
              // in the type dict under the member name and in pyobj_insts[idx].
              os << tgfmt(
                      R"({
  PyObject* inst = e_type->tp_alloc(e_type, 0);
  reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
  mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
  $enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
  })",
                      &ctx, members[idx].substr(0, d), idx);
          }
      }
      // Emitted for every occurrence: register the enum type on the op type.
      os << tgfmt(
              R"(
  Py_INCREF(e_type);
  mgb_assert(PyDict_SetItemString(
  py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(e_type)) >= 0);
  )",
              &ctx);
      os << "}\n";
      return initproc;
  }
  436. Initproc OpDefEmitter::emit() {
  437. std::unordered_map<std::string, std::vector<std::string>> enum_attr_members;
  438. for (auto&& i : op.getMgbAttributes()) {
  439. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  440. subclasses.push_back(
  441. EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit());
  442. auto retType = cpp_type_to_python_type(std::string(attr->getReturnType()));
  443. enum_attr_members[retType] = std::vector<std::string>();
  444. for (const auto& member : attr->getEnumMembers()) {
  445. enum_attr_members[retType].emplace_back(member);
  446. }
  447. }
  448. }
  449. emit_class();
  450. emit_py_init();
  451. emit_py_getsetters();
  452. emit_py_methods();
  453. emit_py_init_proxy();
  454. emit_py_init_methoddef(enum_attr_members);
  455. return emit_initproc();
  456. }
  // Emits the PyOpDefBegin/PyOpDefEnd class block for this op, including the
  // definitions of its static getstate and setstate functions and the
  // declarations of the members that are defined by the other emit_* methods.
  void OpDefEmitter::emit_class() {
      auto&& className = op.getCppClassName();
      std::string method_defs;
      std::vector<std::string> body;
      // getstate: one {"<attr>", serialization<...>::dump(opdef.<attr>)}
      // initializer per attribute.
      llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
          body.push_back(
                  formatv(R"(
  {{"{0}", serialization<decltype(opdef.{0})>::dump(opdef.{0})})",
                          attr.name));
      });
      // In formatv templates "{{" produces a literal '{'.
      method_defs +=
              formatv(R"(
  static PyObject* getstate(PyObject* self, PyObject*) {{
  auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst();
  static_cast<void>(opdef);
  std::unordered_map<std::string, py::object> state {{
  {1}
  };
  return py::cast(state).release().ptr();
  })",
                      className, llvm::join(body, ","));
      body.clear();
      // setstate: per attribute, load the value from the state map when the
      // key is present.
      llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
          body.push_back(
                  formatv(R"(
  {{
  auto&& iter = state.find("{0}");
  if (iter != state.end()) {
  opdef.{0} = serialization<decltype(opdef.{0})>::load(iter->second);
  }
  })",
                          attr.name));
      });
      method_defs +=
              formatv(R"(
  static PyObject* setstate(PyObject* self, PyObject* args) {{
  PyObject* dict = PyTuple_GetItem(args, 0);
  if (!dict) return NULL;
  auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
  auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst();
  static_cast<void>(opdef);
  {1}
  Py_RETURN_NONE;
  })",
                      className, llvm::join(body, "\n"));
      // $0 receives the getstate/setstate definitions built above.
      os << tgfmt(
              R"(
  PyOpDefBegin($_self) // {
  static PyGetSetDef py_getsetters[];
  static PyMethodDef tp_methods[];
  $0
  static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
  static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds);
  static PyMethodDef py_init_methoddef;
  // };
  PyOpDefEnd($_self)
  )",
              &ctx, method_defs);
  }
  // Emits PyOp(<op>)::py_init. When the op has attributes, the generated body
  // parses every attribute plus "scope" as optional keyword arguments and
  // assigns each one that was supplied; it returns -1 on failure. Without
  // attributes the generated body is just "return 0;".
  void OpDefEmitter::emit_py_init() {
      std::string initBody;
      if (!op.getMgbAttributes().empty()) {
          initBody += "static const char* kwlist[] = {";
          // Keyword names: all attribute names followed by "scope".
          std::vector<llvm::StringRef> attr_name_list;
          llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
              attr_name_list.push_back(attr.name);
          });
          attr_name_list.push_back("scope");
          llvm::for_each(attr_name_list, [&](auto&& attr) {
              initBody += formatv("\"{0}\", ", attr);
          });
          initBody += "NULL};\n";
          // One PyObject* local per keyword, initialised to NULL.
          initBody += " PyObject ";
          auto initializer = [&](auto&& attr) -> std::string {
              return formatv("*{0} = NULL", attr);
          };
          initBody +=
                  llvm::join(llvm::map_range(attr_name_list, initializer), ", ") + ";\n";
          initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
          // one 'O' per entry of attr_name_list, which includes the extra
          // "scope" slot added above
          initBody += std::string(attr_name_list.size(), 'O');
          initBody += "\", const_cast<char**>(kwlist)";
          llvm::for_each(attr_name_list, [&](auto&& attr) {
              initBody += formatv(", &{0}", attr);
          });
          initBody += "))\n";
          initBody += " return -1;\n";
          // Per attribute: when the argument was supplied, cast it and store
          // it in the op instance.
          llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
              initBody +=
                      tgfmt(R"(
  if ($0) {
  try {
  // TODO: remove this guard which is used for pybind11 implicit conversion
  py::detail::loader_life_support guard{};
  reinterpret_cast<PyOp($_self)*>(self)->inst().$0 =
  py::cast<decltype($_self::$0)>(py::handle($0));
  } CATCH_ALL(-1)
  }
  )",
                            &ctx, attr.name);
          });
          // "scope" is not an attribute; it is passed to set_scope on the op.
          initBody +=
                  tgfmt(R"(
  if (scope) {
  try {
  reinterpret_cast<PyOp(OpDef)*>(self)->op
  ->set_scope(py::cast<std::string>(py::handle(scope)));
  } CATCH_ALL(-1)
  }
  )",
                        &ctx);
      }
      initBody += "\n return 0;";
      os << tgfmt(
              R"(
  int PyOp($_self)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
  $0
  }
  )",
              &ctx, initBody);
  }
  578. void OpDefEmitter::emit_py_getsetters() {
  579. auto f = [&](auto&& attr) -> std::string {
  580. return tgfmt(
  581. "{const_cast<char*>(\"$0\"), py_get_generic($_self, $0), "
  582. "py_set_generic($_self, $0), const_cast<char*>(\"$0\"), NULL},",
  583. &ctx, attr.name);
  584. };
  585. os << tgfmt(
  586. R"(
  587. PyGetSetDef PyOp($_self)::py_getsetters[] = {
  588. $0
  589. {NULL} /* Sentinel */
  590. };
  591. )",
  592. &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n "));
  593. }
  594. void OpDefEmitter::emit_py_methods() {
  595. // generate methods
  596. std::string method_defs;
  597. std::vector<std::string> method_items;
  598. {
  599. auto&& className = op.getCppClassName();
  600. // generate getstate
  601. method_items.push_back(
  602. formatv("{{const_cast<char*>(\"__getstate__\"), PyOp({0})::getstate, "
  603. "METH_NOARGS, \"{0} getstate\"},",
  604. className));
  605. // generate setstate
  606. method_items.push_back(
  607. formatv("{{const_cast<char*>(\"__setstate__\"), PyOp({0})::setstate, "
  608. "METH_VARARGS, \"{0} setstate\"},",
  609. className));
  610. }
  611. os << tgfmt(
  612. R"(
  613. PyMethodDef PyOp($_self)::tp_methods[] = {
  614. $0
  615. {NULL} /* Sentinel */
  616. };
  617. )",
  618. &ctx, llvm::join(method_items, "\n "));
  619. }
  620. void OpDefEmitter::emit_py_init_proxy() {
  621. os << tgfmt(
  622. R"(
  623. PyObject *PyOp($_self)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) {
  624. if (PyOp($_self)::py_init(self, args, kwds) < 0) {
  625. return NULL;
  626. }
  627. Py_RETURN_NONE;
  628. }
  629. )",
  630. &ctx);
  631. }
  632. void OpDefEmitter::emit_py_init_methoddef(
  633. const std::unordered_map<std::string, std::vector<std::string>>&
  634. enum_attr_members) {
  635. std::string docstring = "__init__(self";
  636. for (const auto& attr : op.getMgbAttributes()) {
  637. if (attr.name == "workspace_limit")
  638. continue;
  639. auto pyType = cpp_type_to_python_type(std::string(attr.attr.getReturnType()));
  640. auto findRes = enum_attr_members.find(pyType);
  641. if (findRes != enum_attr_members.end()) {
  642. pyType = formatv("Union[str, {0}]", pyType);
  643. // TODO stubgen cannot handle Literal strings for now
  644. // auto members = findRes->second;
  645. // std::string enumTypeString = "Literal[";
  646. // enumTypeString += formatv("'{0}'", lowercase(members[0]));
  647. // for (auto i = 1; i < members.size(); i++) {
  648. // enumTypeString += formatv(", '{0}'", lowercase(members[i]));
  649. // }
  650. // enumTypeString += "]";
  651. // pyType = enumTypeString;
  652. }
  653. docstring += formatv(", {0}: {1} = ...", attr.name, pyType);
  654. }
  655. docstring += ") -> None\\n";
  656. os << tgfmt(
  657. R"(
  658. PyMethodDef PyOp($_self)::py_init_methoddef = {
  659. "__init__",
  660. (PyCFunction)PyOp($_self)::py_init_proxy,
  661. METH_VARARGS | METH_KEYWORDS,
  662. "$0"
  663. };
  664. )",
  665. &ctx, docstring);
  666. }
  667. Initproc OpDefEmitter::emit_initproc() {
  668. std::string initproc = formatv("_init_py_{0}", op.getCppClassName());
  669. std::string subclass_init_call;
  670. for (auto&& i : subclasses) {
  671. subclass_init_call += formatv(" {0};\n", i("py_type"));
  672. }
  673. os << tgfmt(
  674. R"(
  675. void $0(py::module m) {
  676. using py_op = PyOp($_self);
  677. auto& py_type = PyOpType($_self);
  678. py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
  679. py_type.tp_name = "megengine.core._imperative_rt.ops.$_self";
  680. py_type.tp_basicsize = sizeof(PyOp($_self));
  681. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  682. py_type.tp_doc = "$_self";
  683. py_type.tp_base = &PyOpType(OpDef);
  684. py_type.tp_dealloc = py_dealloc_generic<py_op>;
  685. py_type.tp_new = py_new_generic<py_op>;
  686. py_type.tp_init = py_op::py_init;
  687. py_type.tp_methods = py_op::tp_methods;
  688. py_type.tp_getset = py_op::py_getsetters;
  689. py_type.tp_dict = PyDict_New();
  690. PyObject* descr = PyDescr_NewMethod(&PyOpType($_self), &PyOp($_self)::py_init_methoddef);
  691. PyDict_SetItemString(py_type.tp_dict, "__init__", descr);
  692. mgb_assert(PyType_Ready(&py_type) >= 0);
  693. $1
  694. PyType_Modified(&py_type);
  695. m.add_object("$_self", reinterpret_cast<PyObject*>(&py_type));
  696. mgb_assert(PyOp(OpDef)::ctype2pytype.emplace($_self::typeinfo(), &py_type).second);
  697. }
  698. )",
  699. &ctx, initproc, subclass_init_call);
  700. return initproc;
  701. }
  702. } // namespace
  703. bool gen_op_def_python_c_extension(raw_ostream& os, llvm::RecordKeeper& keeper) {
  704. Environment env;
  705. using namespace std::placeholders;
  706. std::vector<Initproc> initprocs;
  707. foreach_operator(keeper, [&](MgbOp& op) {
  708. initprocs.emplace_back(OpDefEmitter(op, os, env).emit());
  709. });
  710. os << "#define INIT_ALL_OP(m)";
  711. for (auto&& init : initprocs) {
  712. os << formatv(" \\\n {0};", init("m"));
  713. }
  714. os << "\n";
  715. return false;
  716. }
  717. } // namespace mlir::tblgen