You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

autogen.cpp 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660
  1. #include <iostream>
  2. #include <unordered_map>
  3. #include <functional>
  4. #include "./helper.h"
  5. using llvm::raw_ostream;
  6. using llvm::RecordKeeper;
  // Code-generation modes selectable on the command line via `action`.
  // `None` has no command-line spelling; main() returns -1 unless one of the
  // other modes is selected.
  enum ActionType {
      None = 0,
      CppHeader = 1,
      CppBody = 2,
      Pybind = 3,
      CPython = 4,
  };
  14. // NOLINTNEXTLINE
  15. llvm::cl::opt<ActionType> action(
  16. llvm::cl::desc("Action to perform:"),
  17. llvm::cl::values(clEnumValN(CppHeader, "gen-cpp-header",
  18. "Generate operator cpp header"),
  19. clEnumValN(CppBody, "gen-cpp-body",
  20. "Generate operator cpp body"),
  21. clEnumValN(Pybind, "gen-python-binding",
  22. "Generate pybind11 python bindings"),
  23. clEnumValN(CPython, "gen-python-c-extension",
  24. "Generate python c extensions")));
  25. using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
  26. using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
  27. using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin;
  28. using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin;
  29. using MgbOp = mlir::tblgen::MgbOpBase;
  30. using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin;
  31. llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
  32. // Note: we have already registered the corresponding attr wrappers
  33. // for following basic ctypes so we needn't handle them here
  34. /* auto&& attr_type_name = attr.getAttrDefName();
  35. if (attr_type_name == "UI32Attr") {
  36. return "uint32_t";
  37. }
  38. if (attr_type_name == "UI64Attr") {
  39. return "uint64_t";
  40. }
  41. if (attr_type_name == "I32Attr") {
  42. return "int32_t";
  43. }
  44. if (attr_type_name == "F32Attr") {
  45. return "float";
  46. }
  47. if (attr_type_name == "F64Attr") {
  48. return "double";
  49. }
  50. if (attr_type_name == "StrAttr") {
  51. return "std::string";
  52. }
  53. if (attr_type_name == "BoolAttr") {
  54. return "bool";
  55. }*/
  56. auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
  57. if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
  58. return e->getEnumName();
  59. }
  60. return attr.getUnderlyingType();
  61. }
  62. static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
  63. os << formatv(
  64. "class {0} : public OpDefImplBase<{0}> {{\n"
  65. " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
  66. "public:\n",
  67. op.getCppClassName()
  68. );
  69. // handle enum alias
  70. for (auto &&i : op.getMgbAttributes()) {
  71. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  72. os << formatv(
  73. " using {0} = {1};\n",
  74. attr->getEnumName(), attr->getUnderlyingType()
  75. );
  76. }
  77. }
  78. for (auto &&i : op.getMgbAttributes()) {
  79. auto defaultValue = i.attr.getDefaultValue().str();
  80. if (!defaultValue.empty()) {
  81. defaultValue = formatv(" = {0}", defaultValue);
  82. }
  83. os << formatv(
  84. " {0} {1}{2};\n",
  85. attr_to_ctype(i.attr), i.name, defaultValue
  86. );
  87. }
  88. auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
  89. os << formatv(
  90. " {0}({1}){2}{3}\n",
  91. op.getCppClassName(), paramList, memInitList, body
  92. );
  93. };
  94. gen_ctor("", "", " = default;");
  95. if (!op.getMgbAttributes().empty()) {
  96. std::vector<std::string> paramList, initList;
  97. for (auto &&i : op.getMgbAttributes()) {
  98. paramList.push_back(formatv(
  99. "{0} {1}_", attr_to_ctype(i.attr), i.name
  100. ));
  101. initList.push_back(formatv(
  102. "{0}({0}_)", i.name
  103. ));
  104. }
  105. paramList.push_back("std::string scope_ = {}");
  106. gen_ctor(llvm::join(paramList, ", "),
  107. ": " + llvm::join(initList, ", "),
  108. " { set_scope(scope_); }");
  109. }
  110. auto packedParams = op.getPackedParams();
  111. if (!packedParams.empty()) {
  112. std::vector<std::string> paramList, initList;
  113. for (auto &&p : packedParams) {
  114. auto&& paramFields = p.getFields();
  115. auto&& paramType = p.getFullName();
  116. auto&& paramName = formatv("packed_param_{0}", paramList.size());
  117. paramList.push_back(
  118. paramFields.empty() ? paramType.str()
  119. : formatv("{0} {1}", paramType, paramName)
  120. );
  121. for (auto&& i : paramFields) {
  122. initList.push_back(formatv(
  123. "{0}({1}.{0})", i.name, paramName
  124. ));
  125. }
  126. }
  127. for (auto&& i : op.getExtraArguments()) {
  128. paramList.push_back(formatv(
  129. "{0} {1}_", attr_to_ctype(i.attr), i.name
  130. ));
  131. initList.push_back(formatv(
  132. "{0}({0}_)", i.name
  133. ));
  134. }
  135. gen_ctor(llvm::join(paramList, ", "),
  136. initList.empty() ? "" : ": " + llvm::join(initList, ", "),
  137. " {}");
  138. }
  139. if (!packedParams.empty()) {
  140. for (auto&& p : packedParams) {
  141. auto accessor = p.getAccessor();
  142. if (!accessor.empty()) {
  143. os << formatv(
  144. " {0} {1}() const {{\n",
  145. p.getFullName(), accessor
  146. );
  147. std::vector<llvm::StringRef> fields;
  148. for (auto&& i : p.getFields()) {
  149. fields.push_back(i.name);
  150. }
  151. os << formatv(
  152. " return {{{0}};\n",
  153. llvm::join(fields, ", ")
  154. );
  155. os << " }\n";
  156. }
  157. }
  158. }
  159. if (auto decl = op.getExtraOpdefDecl()) {
  160. os << decl.getValue();
  161. }
  162. os << formatv(
  163. "};\n\n"
  164. );
  165. }
  166. static void gen_to_string_trait_for_enum(raw_ostream &os, MgbOp& op) {
  167. for (auto &&i : op.getMgbAttributes()) {
  168. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  169. if (attr->supportToString()) {
  170. std::vector<std::string> case_body;
  171. std::string ename = formatv("{0}::{1}",
  172. op.getCppClassName(), attr->getEnumName());
  173. llvm::for_each(attr->getEnumMembers(), [&](auto&& v){
  174. case_body.push_back(formatv(
  175. "case {0}::{1}: return \"{1}\";", ename, v));
  176. });
  177. os << formatv(R"(
  178. template <>
  179. struct ToStringTrait<{0}> {
  180. std::string operator()({0} e) const {
  181. switch (e) {
  182. {1}
  183. default:
  184. return "{0}::Unknown";
  185. }
  186. }
  187. };
  188. )", ename, llvm::join(case_body, "\n"));
  189. }
  190. }
  191. }
  192. }
  193. static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
  194. auto&& className = op.getCppClassName();
  195. os << formatv(
  196. "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className
  197. );
  198. auto formatMethImpl = [&](auto&& meth) {
  199. return formatv(
  200. "{0}_{1}_impl", className, meth
  201. );
  202. };
  203. std::vector<std::string> methods;
  204. if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
  205. os << "namespace {\n";
  206. // generate hash()
  207. mlir::tblgen::FmtContext ctx;
  208. os << formatv(
  209. "size_t {0}(const OpDef& def_) {{\n",
  210. formatMethImpl("hash")
  211. );
  212. os << formatv(
  213. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  214. " static_cast<void>(op_);\n",
  215. className
  216. );
  217. ctx.withSelf("op_");
  218. os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
  219. os << "}\n";
  220. // generate is_same_st()
  221. os << formatv(
  222. "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
  223. formatMethImpl("is_same_st")
  224. );
  225. os << formatv(
  226. " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
  227. " &&b_ = rhs_.cast_final_safe<{0}>();\n"
  228. " static_cast<void>(a_);\n"
  229. " static_cast<void>(b_);\n",
  230. className
  231. );
  232. os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
  233. os << "}\n";
  234. // generate props()
  235. os << formatv(
  236. "std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n",
  237. formatMethImpl("props")
  238. );
  239. os << formatv(
  240. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  241. " static_cast<void>(op_);\n",
  242. className
  243. );
  244. ctx.withSelf("op_");
  245. os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
  246. os << "}\n";
  247. // generate make_name()
  248. os << formatv(
  249. "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name")
  250. );
  251. os << formatv(
  252. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  253. " static_cast<void>(op_);\n",
  254. className
  255. );
  256. ctx.withSelf("op_");
  257. os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx);
  258. os << "}\n";
  259. os << "} // anonymous namespace\n";
  260. methods.push_back("hash");
  261. methods.push_back("is_same_st");
  262. methods.push_back("props");
  263. methods.push_back("make_name");
  264. }
  265. if (!methods.empty()) {
  266. os << formatv(
  267. "OP_TRAIT_REG({0}, {0})", op.getCppClassName()
  268. );
  269. for (auto&& i : methods) {
  270. os << formatv(
  271. "\n .{0}({1})", i, formatMethImpl(i)
  272. );
  273. }
  274. os << ";\n\n";
  275. }
  276. }
  277. struct EnumContext {
  278. std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias;
  279. };
  280. static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
  281. auto className = op.getCppClassName();
  282. os << formatv(
  283. "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
  284. className
  285. );
  286. for (auto&& i : op.getMgbAttributes()) {
  287. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  288. unsigned int enumID;
  289. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  290. auto&& aliasBase = alias->getAliasBase();
  291. enumID =
  292. llvm::cast<MgbEnumAttr>(aliasBase)
  293. .getBaseRecord()->getID();
  294. } else {
  295. enumID = attr->getBaseRecord()->getID();
  296. }
  297. auto&& enumAlias = ctx.enumAlias;
  298. auto&& iter = enumAlias.find(enumID);
  299. if (iter == enumAlias.end()) {
  300. os << formatv(
  301. "py::enum_<{0}::{1}>({0}Inst, \"{1}\")",
  302. className, attr->getEnumName()
  303. );
  304. std::vector<std::string> body;
  305. for (auto&& i: attr->getEnumMembers()) {
  306. os << formatv(
  307. "\n .value(\"{2}\", {0}::{1}::{2})",
  308. className, attr->getEnumName(), i
  309. );
  310. body.push_back(formatv(
  311. "if (str == \"{2}\") return {0}::{1}::{2};",
  312. className, attr->getEnumName(), i
  313. ));
  314. }
  315. os << formatv(
  316. "\n .def(py::init([](const std::string& in) {"
  317. "\n auto&& str = normalize_enum(in);"
  318. "\n {0}"
  319. "\n throw py::cast_error(\"invalid enum value \" + in);"
  320. "\n }));\n",
  321. llvm::join(body, "\n ")
  322. );
  323. os << formatv(
  324. "py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
  325. className, attr->getEnumName()
  326. );
  327. enumAlias.emplace(enumID,
  328. std::make_pair(className, attr->getEnumName()));
  329. } else {
  330. os << formatv(
  331. "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n",
  332. className, attr->getEnumName(),
  333. iter->second.first, iter->second.second
  334. );
  335. }
  336. }
  337. }
  338. // generate op class binding
  339. os << formatv("{0}Inst", className);
  340. bool hasDefaultCtor = op.getMgbAttributes().empty();
  341. if (!hasDefaultCtor) {
  342. os << "\n .def(py::init<";
  343. std::vector<llvm::StringRef> targs;
  344. for (auto &&i : op.getMgbAttributes()) {
  345. targs.push_back(i.attr.getReturnType());
  346. }
  347. os << llvm::join(targs, ", ");
  348. os << ", std::string>()";
  349. for (auto &&i : op.getMgbAttributes()) {
  350. os << formatv(", py::arg(\"{0}\")", i.name);
  351. auto defaultValue = i.attr.getDefaultValue();
  352. if (!defaultValue.empty()) {
  353. os << formatv(" = {0}", defaultValue);
  354. } else {
  355. hasDefaultCtor = true;
  356. }
  357. }
  358. os << ", py::arg(\"scope\") = {})";
  359. }
  360. if (hasDefaultCtor) {
  361. os << "\n .def(py::init<>())";
  362. }
  363. for (auto &&i : op.getMgbAttributes()) {
  364. os << formatv(
  365. "\n .def_readwrite(\"{0}\", &{1}::{0})",
  366. i.name, className
  367. );
  368. }
  369. os << ";\n\n";
  370. }
  371. static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
  372. auto className = op.getCppClassName();
  373. std::string body;
  374. // generate PyType for enum class member
  375. for (auto&& i : op.getMgbAttributes()) {
  376. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  377. unsigned int enumID;
  378. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  379. auto&& aliasBase = alias->getAliasBase();
  380. enumID =
  381. llvm::cast<MgbEnumAttr>(aliasBase)
  382. .getBaseRecord()->getID();
  383. } else {
  384. enumID = attr->getBaseRecord()->getID();
  385. }
  386. auto&& enumAlias = ctx.enumAlias;
  387. auto&& iter = enumAlias.find(enumID);
  388. auto enumName = attr->getEnumName();
  389. body += "{\n";
  390. body += formatv(
  391. "auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName
  392. );
  393. if (iter == enumAlias.end()) {
  394. os << formatv(
  395. "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n",
  396. className, enumName);
  397. os << formatv(
  398. "template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n",
  399. className, enumName);
  400. std::vector<std::string> pairStr;
  401. for (auto&& i: attr->getEnumMembers()) {
  402. pairStr.push_back(formatv(
  403. "{{normalize_enum(\"{2}\"), {0}::{1}::{2}}",
  404. className, enumName, i));
  405. }
  406. os << formatv(R"(
  407. template<> std::unordered_map<std::string, {0}::{1}>
  408. EnumWrapper<{0}::{1}>::str2type = {{
  409. {2}
  410. };
  411. )", className, enumName, llvm::join(pairStr, ", "));
  412. pairStr.clear();
  413. for (auto&& i: attr->getEnumMembers()) {
  414. pairStr.push_back(formatv(
  415. "{{{0}::{1}::{2}, normalize_enum(\"{2}\")}",
  416. className, enumName, i));
  417. }
  418. os << formatv(R"(
  419. template<> std::unordered_map<{0}::{1}, std::string>
  420. EnumWrapper<{0}::{1}>::type2str = {{
  421. {2}
  422. };
  423. )", className, enumName, llvm::join(pairStr, ", "));
  424. body += formatv(R"(
  425. e_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
  426. e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}";
  427. e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>);
  428. e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  429. e_type.tp_doc = "{0}.{1}";
  430. e_type.tp_base = &PyBaseObject_Type;
  431. e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr;
  432. e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare;
  433. mgb_assert(PyType_Ready(&e_type) >= 0);
  434. )", className, enumName);
  435. for (auto&& i: attr->getEnumMembers()) {
  436. body += formatv(R"({{
  437. PyObject* inst = e_type.tp_alloc(&e_type, 0);
  438. reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2};
  439. mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0);
  440. })", className, enumName, i);
  441. }
  442. enumAlias.emplace(enumID, std::make_pair(className, enumName));
  443. }
  444. body += formatv(R"(
  445. PyType_Modified(&e_type);
  446. mgb_assert(PyDict_SetItemString(
  447. py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0);
  448. )", enumName);
  449. body += "}\n";
  450. }
  451. }
  452. // generate getsetters
  453. std::vector<std::string> getsetters;
  454. for (auto &&i : op.getMgbAttributes()) {
  455. getsetters.push_back(formatv(
  456. "{{const_cast<char*>(\"{1}\"), py_get_generic({0}, {1}), py_set_generic({0}, {1}), const_cast<char*>(\"{1}\"), NULL},",
  457. className, i.name));
  458. }
  459. getsetters.push_back(formatv(
  460. "{{\"scope\", py_get_scope({0}), py_set_scope({0}), \"scope\", NULL},",
  461. className));
  462. // generate tp_init
  463. std::string initBody;
  464. if (!op.getMgbAttributes().empty()) {
  465. initBody += "static const char* kwlist[] = {";
  466. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  467. initBody += formatv("\"{0}\", ", attr.name);
  468. });
  469. initBody += "\"scope\", ";
  470. initBody += "NULL};\n";
  471. initBody += " PyObject ";
  472. std::vector<std::string> attrs;
  473. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  474. attrs.push_back(formatv("*{0} = NULL", attr.name));
  475. });
  476. initBody += llvm::join(attrs, ", ") + ";\n";
  477. initBody += " PyObject *scope = NULL;\n";
  478. initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
  479. // an extra slot created for name
  480. initBody += std::string(op.getMgbAttributes().size() + 1, 'O');
  481. initBody += "\", const_cast<char**>(kwlist)";
  482. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  483. initBody += formatv(", &{0}", attr.name);
  484. });
  485. initBody += ", &scope";
  486. initBody += "))\n";
  487. initBody += " return -1;\n";
  488. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  489. initBody += formatv(R"(
  490. if ({1}) {{
  491. try {{
  492. reinterpret_cast<PyOp({0})*>(self)->inst().{1} =
  493. pyobj_convert_generic<decltype({0}::{1})>::from({1});
  494. } catch(py::error_already_set& e) {{
  495. e.restore();
  496. return -1;
  497. } catch(py::builtin_exception& e) {{
  498. e.set_error();
  499. return -1;
  500. } catch(...) {{
  501. PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
  502. return -1;
  503. }
  504. }
  505. )", className, attr.name);
  506. });
  507. initBody += formatv(R"(
  508. if (scope) {{
  509. try {{
  510. reinterpret_cast<PyOp({0})*>(self)->inst().set_scope(
  511. pyobj_convert_generic<std::string>::from(scope));
  512. } catch(py::error_already_set& e) {{
  513. e.restore();
  514. return -1;
  515. } catch(py::builtin_exception& e) {{
  516. e.set_error();
  517. return -1;
  518. } catch(...) {{
  519. PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
  520. return -1;
  521. }
  522. }
  523. )", className);
  524. }
  525. initBody += "\n return 0;";
  526. os << formatv(R"(
  527. PyOpDefBegin({0}) // {{
  528. static PyGetSetDef py_getsetters[];
  529. static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
  530. // };
  531. PyOpDefEnd({0})
  532. PyGetSetDef PyOp({0})::py_getsetters[] = {{
  533. {1}
  534. {{NULL} /* Sentinel */
  535. };
  536. int PyOp({0})::py_init(PyObject *self, PyObject *args, PyObject *kwds) {{
  537. {2}
  538. }
  539. void _init_py_{0}(py::module m) {{
  540. using py_op = PyOp({0});
  541. auto& py_type = PyOpType({0});
  542. py_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
  543. py_type.tp_name = "megengine.core._imperative_rt.ops.{0}";
  544. py_type.tp_basicsize = sizeof(PyOp({0}));
  545. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  546. py_type.tp_doc = "{0}";
  547. py_type.tp_base = &PyOpType(OpDef);
  548. py_type.tp_dealloc = py_dealloc_generic<py_op>;
  549. py_type.tp_new = py_new_generic<py_op>;
  550. py_type.tp_init = py_op::py_init;
  551. py_type.tp_getset = py_op::py_getsetters;
  552. mgb_assert(PyType_Ready(&py_type) >= 0);
  553. {3}
  554. PyType_Modified(&py_type);
  555. m.add_object("{0}", reinterpret_cast<PyObject*>(&py_type));
  556. mgb_assert(PyOp(OpDef)::ctype2pytype.emplace({0}::typeinfo(), &py_type).second);
  557. }
  558. )",
  559. op.getCppClassName(), llvm::join(getsetters, "\n "), initBody, body);
  560. }
  561. static void for_each_operator(raw_ostream &os, RecordKeeper &keeper,
  562. std::function<void(raw_ostream&, MgbOp&)> callback) {
  563. auto op_base_class = keeper.getClass("Op");
  564. ASSERT(op_base_class, "could not find base class Op");
  565. for (auto&& i: keeper.getDefs()) {
  566. auto&& r = i.second;
  567. if (r->isSubClassOf(op_base_class)) {
  568. auto op = mlir::tblgen::Operator(r.get());
  569. if (op.getDialectName().str() == "mgb") {
  570. std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl;
  571. callback(os, llvm::cast<MgbOp>(op));
  572. }
  573. }
  574. }
  575. }
  576. static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) {
  577. for_each_operator(os, keeper, gen_op_def_c_header_single);
  578. for_each_operator(os, keeper, gen_to_string_trait_for_enum);
  579. return false;
  580. }
  581. static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) {
  582. for_each_operator(os, keeper, gen_op_def_c_body_single);
  583. return false;
  584. }
  585. static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) {
  586. EnumContext ctx;
  587. using namespace std::placeholders;
  588. for_each_operator(os, keeper,
  589. std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx)));
  590. return false;
  591. }
  592. static bool gen_op_def_python_c_extension(raw_ostream &os, RecordKeeper &keeper) {
  593. EnumContext ctx;
  594. using namespace std::placeholders;
  595. for_each_operator(os, keeper,
  596. std::bind(gen_op_def_python_c_extension_single, _1, _2, std::ref(ctx)));
  597. os << "#define INIT_ALL_OP(m)";
  598. for_each_operator(os, keeper, [&](raw_ostream& os, MgbOp& op) {
  599. os << formatv(" \\\n _init_py_{0}(m);", op.getCppClassName());
  600. });
  601. os << "\n";
  602. return false;
  603. }
  604. int main(int argc, char **argv) {
  605. llvm::InitLLVM y(argc, argv);
  606. llvm::cl::ParseCommandLineOptions(argc, argv);
  607. if (action == ActionType::CppHeader) {
  608. return TableGenMain(argv[0], &gen_op_def_c_header);
  609. }
  610. if (action == ActionType::CppBody) {
  611. return TableGenMain(argv[0], &gen_op_def_c_body);
  612. }
  613. if (action == ActionType::Pybind) {
  614. return TableGenMain(argv[0], &gen_op_def_pybind11);
  615. }
  616. if (action == ActionType::CPython) {
  617. return TableGenMain(argv[0], &gen_op_def_python_c_extension);
  618. }
  619. return -1;
  620. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。