You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

autogen.cpp 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657
  1. #include <iostream>
  2. #include <unordered_map>
  3. #include <functional>
  4. #include "./helper.h"
  5. using llvm::raw_ostream;
  6. using llvm::RecordKeeper;
  7. enum ActionType {
  8. None,
  9. CppHeader,
  10. CppBody,
  11. Pybind,
  12. CPython
  13. };
  14. // NOLINTNEXTLINE
  15. llvm::cl::opt<ActionType> action(
  16. llvm::cl::desc("Action to perform:"),
  17. llvm::cl::values(clEnumValN(CppHeader, "gen-cpp-header",
  18. "Generate operator cpp header"),
  19. clEnumValN(CppBody, "gen-cpp-body",
  20. "Generate operator cpp body"),
  21. clEnumValN(Pybind, "gen-python-binding",
  22. "Generate pybind11 python bindings"),
  23. clEnumValN(CPython, "gen-python-c-extension",
  24. "Generate python c extensions")));
  25. using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
  26. using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
  27. using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin;
  28. using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin;
  29. using MgbOp = mlir::tblgen::MgbOpBase;
  30. using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin;
  31. llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
  32. // Note: we have already registered the corresponding attr wrappers
  33. // for following basic ctypes so we needn't handle them here
  34. /* auto&& attr_type_name = attr.getAttrDefName();
  35. if (attr_type_name == "UI32Attr") {
  36. return "uint32_t";
  37. }
  38. if (attr_type_name == "UI64Attr") {
  39. return "uint64_t";
  40. }
  41. if (attr_type_name == "I32Attr") {
  42. return "int32_t";
  43. }
  44. if (attr_type_name == "F32Attr") {
  45. return "float";
  46. }
  47. if (attr_type_name == "F64Attr") {
  48. return "double";
  49. }
  50. if (attr_type_name == "StrAttr") {
  51. return "std::string";
  52. }
  53. if (attr_type_name == "BoolAttr") {
  54. return "bool";
  55. }*/
  56. auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
  57. if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
  58. return e->getEnumName();
  59. }
  60. return attr.getUnderlyingType();
  61. }
  62. static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
  63. os << formatv(
  64. "class {0} : public OpDefImplBase<{0}> {{\n"
  65. " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
  66. "public:\n",
  67. op.getCppClassName()
  68. );
  69. // handle enum alias
  70. for (auto &&i : op.getMgbAttributes()) {
  71. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  72. os << formatv(
  73. " using {0} = {1};\n",
  74. attr->getEnumName(), attr->getUnderlyingType()
  75. );
  76. }
  77. }
  78. for (auto &&i : op.getMgbAttributes()) {
  79. auto defaultValue = i.attr.getDefaultValue().str();
  80. if (!defaultValue.empty()) {
  81. defaultValue = formatv(" = {0}", defaultValue);
  82. }
  83. os << formatv(
  84. " {0} {1}{2};\n",
  85. attr_to_ctype(i.attr), i.name, defaultValue
  86. );
  87. }
  88. auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
  89. os << formatv(
  90. " {0}({1}){2}{3}\n",
  91. op.getCppClassName(), paramList, memInitList, body
  92. );
  93. };
  94. gen_ctor("", "", " = default;");
  95. if (!op.getMgbAttributes().empty()) {
  96. std::vector<std::string> paramList, initList;
  97. for (auto &&i : op.getMgbAttributes()) {
  98. paramList.push_back(formatv(
  99. "{0} {1}_", attr_to_ctype(i.attr), i.name
  100. ));
  101. initList.push_back(formatv(
  102. "{0}({0}_)", i.name
  103. ));
  104. }
  105. paramList.push_back("std::string scope_ = {}");
  106. gen_ctor(llvm::join(paramList, ", "),
  107. ": " + llvm::join(initList, ", "),
  108. " { set_scope(scope_); }");
  109. }
  110. auto packedParams = op.getPackedParams();
  111. if (!packedParams.empty()) {
  112. std::vector<std::string> paramList, initList;
  113. for (auto &&p : packedParams) {
  114. auto&& paramFields = p.getFields();
  115. auto&& paramType = p.getFullName();
  116. auto&& paramName = formatv("packed_param_{0}", paramList.size());
  117. paramList.push_back(
  118. paramFields.empty() ? paramType.str()
  119. : formatv("{0} {1}", paramType, paramName)
  120. );
  121. for (auto&& i : paramFields) {
  122. initList.push_back(formatv(
  123. "{0}({1}.{0})", i.name, paramName
  124. ));
  125. }
  126. }
  127. for (auto&& i : op.getExtraArguments()) {
  128. paramList.push_back(formatv(
  129. "{0} {1}_", attr_to_ctype(i.attr), i.name
  130. ));
  131. initList.push_back(formatv(
  132. "{0}({0}_)", i.name
  133. ));
  134. }
  135. gen_ctor(llvm::join(paramList, ", "),
  136. initList.empty() ? "" : ": " + llvm::join(initList, ", "),
  137. " {}");
  138. }
  139. if (!packedParams.empty()) {
  140. for (auto&& p : packedParams) {
  141. auto accessor = p.getAccessor();
  142. if (!accessor.empty()) {
  143. os << formatv(
  144. " {0} {1}() const {{\n",
  145. p.getFullName(), accessor
  146. );
  147. std::vector<llvm::StringRef> fields;
  148. for (auto&& i : p.getFields()) {
  149. fields.push_back(i.name);
  150. }
  151. os << formatv(
  152. " return {{{0}};\n",
  153. llvm::join(fields, ", ")
  154. );
  155. os << " }\n";
  156. }
  157. }
  158. }
  159. if (auto decl = op.getExtraOpdefDecl()) {
  160. os << decl.getValue();
  161. }
  162. os << formatv(
  163. "};\n\n"
  164. );
  165. }
  166. static void gen_to_string_trait_for_enum(raw_ostream &os, MgbOp& op) {
  167. for (auto &&i : op.getMgbAttributes()) {
  168. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  169. if (attr->supportToString()) {
  170. std::vector<std::string> case_body;
  171. std::string ename = formatv("{0}::{1}",
  172. op.getCppClassName(), attr->getEnumName());
  173. llvm::for_each(attr->getEnumMembers(), [&](auto&& v){
  174. case_body.push_back(formatv(
  175. "case {0}::{1}: return \"{1}\";", ename, v));
  176. });
  177. os << formatv(R"(
  178. template <>
  179. struct ToStringTrait<{0}> {
  180. std::string operator()({0} e) const {
  181. switch (e) {
  182. {1}
  183. default:
  184. return "{0}::Unknown";
  185. }
  186. }
  187. };
  188. )", ename, llvm::join(case_body, "\n"));
  189. }
  190. }
  191. }
  192. }
  193. static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
  194. auto&& className = op.getCppClassName();
  195. os << formatv(
  196. "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className
  197. );
  198. auto formatMethImpl = [&](auto&& meth) {
  199. return formatv(
  200. "{0}_{1}_impl", className, meth
  201. );
  202. };
  203. std::vector<std::string> methods;
  204. if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
  205. os << "namespace {\n";
  206. // generate hash()
  207. mlir::tblgen::FmtContext ctx;
  208. os << formatv(
  209. "size_t {0}(const OpDef& def_) {{\n",
  210. formatMethImpl("hash")
  211. );
  212. os << formatv(
  213. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  214. " static_cast<void>(op_);\n",
  215. className
  216. );
  217. ctx.withSelf("op_");
  218. os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
  219. os << "}\n";
  220. // generate is_same_st()
  221. os << formatv(
  222. "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
  223. formatMethImpl("is_same_st")
  224. );
  225. os << formatv(
  226. " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
  227. " &&b_ = rhs_.cast_final_safe<{0}>();\n"
  228. " static_cast<void>(a_);\n"
  229. " static_cast<void>(b_);\n",
  230. className
  231. );
  232. os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
  233. os << "}\n";
  234. // generate props()
  235. os << formatv(
  236. "std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n",
  237. formatMethImpl("props")
  238. );
  239. os << formatv(
  240. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  241. " static_cast<void>(op_);\n",
  242. className
  243. );
  244. ctx.withSelf("op_");
  245. os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
  246. os << "}\n";
  247. // generate make_name()
  248. os << formatv(
  249. "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name")
  250. );
  251. os << formatv(
  252. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  253. " static_cast<void>(op_);\n",
  254. className
  255. );
  256. ctx.withSelf("op_");
  257. os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx);
  258. os << "}\n";
  259. os << "} // anonymous namespace\n";
  260. methods.push_back("hash");
  261. methods.push_back("is_same_st");
  262. methods.push_back("props");
  263. methods.push_back("make_name");
  264. }
  265. if (!methods.empty()) {
  266. os << formatv(
  267. "OP_TRAIT_REG({0}, {0})", op.getCppClassName()
  268. );
  269. for (auto&& i : methods) {
  270. os << formatv(
  271. "\n .{0}({1})", i, formatMethImpl(i)
  272. );
  273. }
  274. os << ";\n\n";
  275. }
  276. }
  277. struct EnumContext {
  278. std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias;
  279. };
  280. static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
  281. auto className = op.getCppClassName();
  282. os << formatv(
  283. "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
  284. className
  285. );
  286. for (auto&& i : op.getMgbAttributes()) {
  287. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  288. unsigned int enumID;
  289. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  290. auto&& aliasBase = alias->getAliasBase();
  291. enumID =
  292. llvm::cast<MgbEnumAttr>(aliasBase)
  293. .getBaseRecord()->getID();
  294. } else {
  295. enumID = attr->getBaseRecord()->getID();
  296. }
  297. auto&& enumAlias = ctx.enumAlias;
  298. auto&& iter = enumAlias.find(enumID);
  299. if (iter == enumAlias.end()) {
  300. os << formatv(
  301. "py::enum_<{0}::{1}>({0}Inst, \"{1}\")",
  302. className, attr->getEnumName()
  303. );
  304. std::vector<std::string> body;
  305. for (auto&& i: attr->getEnumMembers()) {
  306. os << formatv(
  307. "\n .value(\"{2}\", {0}::{1}::{2})",
  308. className, attr->getEnumName(), i
  309. );
  310. body.push_back(formatv(
  311. "if (str == \"{2}\") return {0}::{1}::{2};",
  312. className, attr->getEnumName(), i
  313. ));
  314. }
  315. if (attr->getEnumCombinedFlag()) {
  316. //! define operator |
  317. os << formatv(
  318. "\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ "
  319. "\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));"
  320. "\n })",
  321. className, attr->getEnumName());
  322. //! define operator &
  323. os << formatv(
  324. "\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{"
  325. "\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));"
  326. "\n })",
  327. className, attr->getEnumName());
  328. }
  329. os << formatv(
  330. "\n .def(py::init([](const std::string& in) {"
  331. "\n auto&& str = normalize_enum(in);"
  332. "\n {0}"
  333. "\n throw py::cast_error(\"invalid enum value \" + in);"
  334. "\n }));\n",
  335. llvm::join(body, "\n ")
  336. );
  337. os << formatv(
  338. "py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
  339. className, attr->getEnumName()
  340. );
  341. enumAlias.emplace(enumID,
  342. std::make_pair(className, attr->getEnumName()));
  343. } else {
  344. os << formatv(
  345. "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n",
  346. className, attr->getEnumName(),
  347. iter->second.first, iter->second.second
  348. );
  349. }
  350. }
  351. }
  352. // generate op class binding
  353. os << formatv("{0}Inst", className);
  354. bool hasDefaultCtor = op.getMgbAttributes().empty();
  355. if (!hasDefaultCtor) {
  356. os << "\n .def(py::init<";
  357. std::vector<llvm::StringRef> targs;
  358. for (auto &&i : op.getMgbAttributes()) {
  359. targs.push_back(i.attr.getReturnType());
  360. }
  361. os << llvm::join(targs, ", ");
  362. os << ", std::string>()";
  363. for (auto &&i : op.getMgbAttributes()) {
  364. os << formatv(", py::arg(\"{0}\")", i.name);
  365. auto defaultValue = i.attr.getDefaultValue();
  366. if (!defaultValue.empty()) {
  367. os << formatv(" = {0}", defaultValue);
  368. } else {
  369. hasDefaultCtor = true;
  370. }
  371. }
  372. os << ", py::arg(\"scope\") = {})";
  373. }
  374. if (hasDefaultCtor) {
  375. os << "\n .def(py::init<>())";
  376. }
  377. for (auto &&i : op.getMgbAttributes()) {
  378. os << formatv(
  379. "\n .def_readwrite(\"{0}\", &{1}::{0})",
  380. i.name, className
  381. );
  382. }
  383. os << ";\n\n";
  384. }
  385. static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
  386. auto className = op.getCppClassName();
  387. std::string body;
  388. // generate PyType for enum class member
  389. for (auto&& i : op.getMgbAttributes()) {
  390. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  391. unsigned int enumID;
  392. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  393. auto&& aliasBase = alias->getAliasBase();
  394. enumID =
  395. llvm::cast<MgbEnumAttr>(aliasBase)
  396. .getBaseRecord()->getID();
  397. } else {
  398. enumID = attr->getBaseRecord()->getID();
  399. }
  400. auto&& enumAlias = ctx.enumAlias;
  401. auto&& iter = enumAlias.find(enumID);
  402. auto enumName = attr->getEnumName();
  403. body += "{\n";
  404. body += formatv(
  405. "auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName
  406. );
  407. if (iter == enumAlias.end()) {
  408. os << formatv(
  409. "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n",
  410. className, enumName);
  411. os << formatv(
  412. "template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n",
  413. className, enumName);
  414. std::vector<std::string> pairStr;
  415. for (auto&& i: attr->getEnumMembers()) {
  416. pairStr.push_back(formatv(
  417. "{{normalize_enum(\"{2}\"), {0}::{1}::{2}}",
  418. className, enumName, i));
  419. }
  420. os << formatv(R"(
  421. template<> std::unordered_map<std::string, {0}::{1}>
  422. EnumWrapper<{0}::{1}>::str2type = {{
  423. {2}
  424. };
  425. )", className, enumName, llvm::join(pairStr, ", "));
  426. pairStr.clear();
  427. for (auto&& i: attr->getEnumMembers()) {
  428. pairStr.push_back(formatv(
  429. "{{{0}::{1}::{2}, normalize_enum(\"{2}\")}",
  430. className, enumName, i));
  431. }
  432. os << formatv(R"(
  433. template<> std::unordered_map<{0}::{1}, std::string>
  434. EnumWrapper<{0}::{1}>::type2str = {{
  435. {2}
  436. };
  437. )", className, enumName, llvm::join(pairStr, ", "));
  438. body += formatv(R"(
  439. e_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
  440. e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}";
  441. e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>);
  442. e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  443. e_type.tp_doc = "{0}.{1}";
  444. e_type.tp_base = &PyBaseObject_Type;
  445. e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr;
  446. e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare;
  447. mgb_assert(PyType_Ready(&e_type) >= 0);
  448. )", className, enumName);
  449. for (auto&& i: attr->getEnumMembers()) {
  450. body += formatv(R"({{
  451. PyObject* inst = e_type.tp_alloc(&e_type, 0);
  452. reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2};
  453. mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0);
  454. })", className, enumName, i);
  455. }
  456. enumAlias.emplace(enumID, std::make_pair(className, enumName));
  457. }
  458. body += formatv(R"(
  459. PyType_Modified(&e_type);
  460. mgb_assert(PyDict_SetItemString(
  461. py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0);
  462. )", enumName);
  463. body += "}\n";
  464. }
  465. }
  466. // generate getsetters
  467. std::vector<std::string> getsetters;
  468. for (auto &&i : op.getMgbAttributes()) {
  469. getsetters.push_back(formatv(
  470. "{{const_cast<char*>(\"{1}\"), py_get_generic({0}, {1}), py_set_generic({0}, {1}), const_cast<char*>(\"{1}\"), NULL},",
  471. className, i.name));
  472. }
  473. // generate tp_init
  474. std::string initBody;
  475. if (!op.getMgbAttributes().empty()) {
  476. initBody += "static const char* kwlist[] = {";
  477. std::vector<llvm::StringRef> attr_name_list;
  478. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  479. attr_name_list.push_back(attr.name);
  480. });
  481. attr_name_list.push_back("scope");
  482. llvm::for_each(attr_name_list, [&](auto&& attr) {
  483. initBody += formatv("\"{0}\", ", attr);
  484. });
  485. initBody += "NULL};\n";
  486. initBody += " PyObject ";
  487. std::vector<std::string> attr_init;
  488. llvm::for_each(attr_name_list, [&](auto&& attr) {
  489. attr_init.push_back(formatv("*{0} = NULL", attr));
  490. });
  491. initBody += llvm::join(attr_init, ", ") + ";\n";
  492. initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
  493. // an extra slot created for name
  494. initBody += std::string(attr_name_list.size(), 'O');
  495. initBody += "\", const_cast<char**>(kwlist)";
  496. llvm::for_each(attr_name_list, [&](auto&& attr) {
  497. initBody += formatv(", &{0}", attr);
  498. });
  499. initBody += "))\n";
  500. initBody += " return -1;\n";
  501. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  502. initBody += formatv(R"(
  503. if ({1}) {{
  504. try {{
  505. reinterpret_cast<PyOp({0})*>(self)->inst().{1} =
  506. pyobj_convert_generic<decltype({0}::{1})>::from({1});
  507. } CATCH_ALL(-1)
  508. }
  509. )", className, attr.name);
  510. });
  511. initBody += formatv(R"(
  512. if (scope) {{
  513. try {{
  514. reinterpret_cast<PyOp(OpDef)*>(self)->op
  515. ->set_scope(pyobj_convert_generic<std::string>::from(scope));
  516. } CATCH_ALL(-1)
  517. }
  518. )", className);
  519. }
  520. initBody += "\n return 0;";
  521. os << formatv(R"(
  522. PyOpDefBegin({0}) // {{
  523. static PyGetSetDef py_getsetters[];
  524. static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
  525. // };
  526. PyOpDefEnd({0})
  527. PyGetSetDef PyOp({0})::py_getsetters[] = {{
  528. {1}
  529. {{NULL} /* Sentinel */
  530. };
  531. int PyOp({0})::py_init(PyObject *self, PyObject *args, PyObject *kwds) {{
  532. {2}
  533. }
  534. void _init_py_{0}(py::module m) {{
  535. using py_op = PyOp({0});
  536. auto& py_type = PyOpType({0});
  537. py_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
  538. py_type.tp_name = "megengine.core._imperative_rt.ops.{0}";
  539. py_type.tp_basicsize = sizeof(PyOp({0}));
  540. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  541. py_type.tp_doc = "{0}";
  542. py_type.tp_base = &PyOpType(OpDef);
  543. py_type.tp_dealloc = py_dealloc_generic<py_op>;
  544. py_type.tp_new = py_new_generic<py_op>;
  545. py_type.tp_init = py_op::py_init;
  546. py_type.tp_getset = py_op::py_getsetters;
  547. mgb_assert(PyType_Ready(&py_type) >= 0);
  548. {3}
  549. PyType_Modified(&py_type);
  550. m.add_object("{0}", reinterpret_cast<PyObject*>(&py_type));
  551. mgb_assert(PyOp(OpDef)::ctype2pytype.emplace({0}::typeinfo(), &py_type).second);
  552. }
  553. )",
  554. op.getCppClassName(), llvm::join(getsetters, "\n "), initBody, body);
  555. }
  556. static void for_each_operator(raw_ostream &os, RecordKeeper &keeper,
  557. std::function<void(raw_ostream&, MgbOp&)> callback) {
  558. auto op_base_class = keeper.getClass("Op");
  559. ASSERT(op_base_class, "could not find base class Op");
  560. for (auto&& i: keeper.getDefs()) {
  561. auto&& r = i.second;
  562. if (r->isSubClassOf(op_base_class)) {
  563. auto op = mlir::tblgen::Operator(r.get());
  564. if (op.getDialectName().str() == "mgb") {
  565. std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl;
  566. callback(os, llvm::cast<MgbOp>(op));
  567. }
  568. }
  569. }
  570. }
  571. static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) {
  572. for_each_operator(os, keeper, gen_op_def_c_header_single);
  573. for_each_operator(os, keeper, gen_to_string_trait_for_enum);
  574. return false;
  575. }
  576. static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) {
  577. for_each_operator(os, keeper, gen_op_def_c_body_single);
  578. return false;
  579. }
  580. static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) {
  581. EnumContext ctx;
  582. using namespace std::placeholders;
  583. for_each_operator(os, keeper,
  584. std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx)));
  585. return false;
  586. }
  587. static bool gen_op_def_python_c_extension(raw_ostream &os, RecordKeeper &keeper) {
  588. EnumContext ctx;
  589. using namespace std::placeholders;
  590. for_each_operator(os, keeper,
  591. std::bind(gen_op_def_python_c_extension_single, _1, _2, std::ref(ctx)));
  592. os << "#define INIT_ALL_OP(m)";
  593. for_each_operator(os, keeper, [&](raw_ostream& os, MgbOp& op) {
  594. os << formatv(" \\\n _init_py_{0}(m);", op.getCppClassName());
  595. });
  596. os << "\n";
  597. return false;
  598. }
  599. int main(int argc, char **argv) {
  600. llvm::InitLLVM y(argc, argv);
  601. llvm::cl::ParseCommandLineOptions(argc, argv);
  602. if (action == ActionType::CppHeader) {
  603. return TableGenMain(argv[0], &gen_op_def_c_header);
  604. }
  605. if (action == ActionType::CppBody) {
  606. return TableGenMain(argv[0], &gen_op_def_c_body);
  607. }
  608. if (action == ActionType::Pybind) {
  609. return TableGenMain(argv[0], &gen_op_def_pybind11);
  610. }
  611. if (action == ActionType::CPython) {
  612. return TableGenMain(argv[0], &gen_op_def_python_c_extension);
  613. }
  614. return -1;
  615. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。