You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

autogen.cpp 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625
  1. #include <iostream>
  2. #include <unordered_map>
  3. #include <functional>
  4. #include "./helper.h"
  5. using llvm::raw_ostream;
  6. using llvm::RecordKeeper;
  7. enum ActionType {
  8. None,
  9. CppHeader,
  10. CppBody,
  11. Pybind,
  12. CPython
  13. };
  14. // NOLINTNEXTLINE
  15. llvm::cl::opt<ActionType> action(
  16. llvm::cl::desc("Action to perform:"),
  17. llvm::cl::values(clEnumValN(CppHeader, "gen-cpp-header",
  18. "Generate operator cpp header"),
  19. clEnumValN(CppBody, "gen-cpp-body",
  20. "Generate operator cpp body"),
  21. clEnumValN(Pybind, "gen-python-binding",
  22. "Generate pybind11 python bindings"),
  23. clEnumValN(CPython, "gen-python-c-extension",
  24. "Generate python c extensions")));
  25. using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
  26. using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
  27. using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin;
  28. using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin;
  29. using MgbOp = mlir::tblgen::MgbOpBase;
  30. using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin;
  31. llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
  32. // Note: we have already registered the corresponding attr wrappers
  33. // for following basic ctypes so we needn't handle them here
  34. /* auto&& attr_type_name = attr.getAttrDefName();
  35. if (attr_type_name == "UI32Attr") {
  36. return "uint32_t";
  37. }
  38. if (attr_type_name == "UI64Attr") {
  39. return "uint64_t";
  40. }
  41. if (attr_type_name == "I32Attr") {
  42. return "int32_t";
  43. }
  44. if (attr_type_name == "F32Attr") {
  45. return "float";
  46. }
  47. if (attr_type_name == "F64Attr") {
  48. return "double";
  49. }
  50. if (attr_type_name == "StrAttr") {
  51. return "std::string";
  52. }
  53. if (attr_type_name == "BoolAttr") {
  54. return "bool";
  55. }*/
  56. auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
  57. if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
  58. return e->getEnumName();
  59. }
  60. return attr.getUnderlyingType();
  61. }
  62. static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
  63. os << formatv(
  64. "class {0} : public OpDefImplBase<{0}> {{\n"
  65. " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
  66. "public:\n",
  67. op.getCppClassName()
  68. );
  69. // handle enum alias
  70. for (auto &&i : op.getMgbAttributes()) {
  71. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  72. os << formatv(
  73. " using {0} = {1};\n",
  74. attr->getEnumName(), attr->getUnderlyingType()
  75. );
  76. }
  77. }
  78. for (auto &&i : op.getMgbAttributes()) {
  79. auto defaultValue = i.attr.getDefaultValue().str();
  80. if (!defaultValue.empty()) {
  81. defaultValue = formatv(" = {0}", defaultValue);
  82. }
  83. os << formatv(
  84. " {0} {1}{2};\n",
  85. attr_to_ctype(i.attr), i.name, defaultValue
  86. );
  87. }
  88. auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
  89. os << formatv(
  90. " {0}({1}){2}{3}\n",
  91. op.getCppClassName(), paramList, memInitList, body
  92. );
  93. };
  94. gen_ctor("", "", " = default;");
  95. if (!op.getMgbAttributes().empty()) {
  96. std::vector<std::string> paramList, initList;
  97. for (auto &&i : op.getMgbAttributes()) {
  98. paramList.push_back(formatv(
  99. "{0} {1}_", attr_to_ctype(i.attr), i.name
  100. ));
  101. initList.push_back(formatv(
  102. "{0}({0}_)", i.name
  103. ));
  104. }
  105. paramList.push_back("std::string scope_ = {}");
  106. gen_ctor(llvm::join(paramList, ", "),
  107. ": " + llvm::join(initList, ", "),
  108. " { set_scope(scope_); }");
  109. }
  110. auto packedParams = op.getPackedParams();
  111. if (!packedParams.empty()) {
  112. std::vector<std::string> paramList, initList;
  113. for (auto &&p : packedParams) {
  114. auto&& paramFields = p.getFields();
  115. auto&& paramType = p.getFullName();
  116. auto&& paramName = formatv("packed_param_{0}", paramList.size());
  117. paramList.push_back(
  118. paramFields.empty() ? paramType.str()
  119. : formatv("{0} {1}", paramType, paramName)
  120. );
  121. for (auto&& i : paramFields) {
  122. initList.push_back(formatv(
  123. "{0}({1}.{0})", i.name, paramName
  124. ));
  125. }
  126. }
  127. for (auto&& i : op.getExtraArguments()) {
  128. paramList.push_back(formatv(
  129. "{0} {1}_", attr_to_ctype(i.attr), i.name
  130. ));
  131. initList.push_back(formatv(
  132. "{0}({0}_)", i.name
  133. ));
  134. }
  135. gen_ctor(llvm::join(paramList, ", "),
  136. initList.empty() ? "" : ": " + llvm::join(initList, ", "),
  137. " {}");
  138. }
  139. if (!packedParams.empty()) {
  140. for (auto&& p : packedParams) {
  141. auto accessor = p.getAccessor();
  142. if (!accessor.empty()) {
  143. os << formatv(
  144. " {0} {1}() const {{\n",
  145. p.getFullName(), accessor
  146. );
  147. std::vector<llvm::StringRef> fields;
  148. for (auto&& i : p.getFields()) {
  149. fields.push_back(i.name);
  150. }
  151. os << formatv(
  152. " return {{{0}};\n",
  153. llvm::join(fields, ", ")
  154. );
  155. os << " }\n";
  156. }
  157. }
  158. }
  159. if (auto decl = op.getExtraOpdefDecl()) {
  160. os << decl.getValue();
  161. }
  162. os << formatv(
  163. "};\n\n"
  164. );
  165. }
  166. static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
  167. auto&& className = op.getCppClassName();
  168. os << formatv(
  169. "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className
  170. );
  171. auto formatMethImpl = [&](auto&& meth) {
  172. return formatv(
  173. "{0}_{1}_impl", className, meth
  174. );
  175. };
  176. std::vector<std::string> methods;
  177. if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
  178. os << "namespace {\n";
  179. // generate hash()
  180. mlir::tblgen::FmtContext ctx;
  181. os << formatv(
  182. "size_t {0}(const OpDef& def_) {{\n",
  183. formatMethImpl("hash")
  184. );
  185. os << formatv(
  186. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  187. " static_cast<void>(op_);\n",
  188. className
  189. );
  190. ctx.withSelf("op_");
  191. os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
  192. os << "}\n";
  193. // generate is_same_st()
  194. os << formatv(
  195. "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
  196. formatMethImpl("is_same_st")
  197. );
  198. os << formatv(
  199. " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
  200. " &&b_ = rhs_.cast_final_safe<{0}>();\n"
  201. " static_cast<void>(a_);\n"
  202. " static_cast<void>(b_);\n",
  203. className
  204. );
  205. os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
  206. os << "}\n";
  207. // generate props()
  208. os << formatv(
  209. "std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n",
  210. formatMethImpl("props")
  211. );
  212. os << formatv(
  213. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  214. " static_cast<void>(op_);\n",
  215. className
  216. );
  217. ctx.withSelf("op_");
  218. os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
  219. os << "}\n";
  220. // generate make_name()
  221. os << formatv(
  222. "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name")
  223. );
  224. os << mlir::tblgen::tgfmt(hashable->getNameFunctionTemplate(), &ctx);
  225. os << "}\n";
  226. os << "} // anonymous namespace\n";
  227. methods.push_back("hash");
  228. methods.push_back("is_same_st");
  229. methods.push_back("props");
  230. methods.push_back("make_name");
  231. }
  232. if (!methods.empty()) {
  233. os << formatv(
  234. "OP_TRAIT_REG({0}, {0})", op.getCppClassName()
  235. );
  236. for (auto&& i : methods) {
  237. os << formatv(
  238. "\n .{0}({1})", i, formatMethImpl(i)
  239. );
  240. }
  241. os << ";\n\n";
  242. }
  243. }
  244. struct EnumContext {
  245. std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias;
  246. };
  247. static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
  248. auto className = op.getCppClassName();
  249. os << formatv(
  250. "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
  251. className
  252. );
  253. for (auto&& i : op.getMgbAttributes()) {
  254. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  255. unsigned int enumID;
  256. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  257. auto&& aliasBase = alias->getAliasBase();
  258. enumID =
  259. llvm::cast<MgbEnumAttr>(aliasBase)
  260. .getBaseRecord()->getID();
  261. } else {
  262. enumID = attr->getBaseRecord()->getID();
  263. }
  264. auto&& enumAlias = ctx.enumAlias;
  265. auto&& iter = enumAlias.find(enumID);
  266. if (iter == enumAlias.end()) {
  267. os << formatv(
  268. "py::enum_<{0}::{1}>({0}Inst, \"{1}\")",
  269. className, attr->getEnumName()
  270. );
  271. std::vector<std::string> body;
  272. for (auto&& i: attr->getEnumMembers()) {
  273. os << formatv(
  274. "\n .value(\"{2}\", {0}::{1}::{2})",
  275. className, attr->getEnumName(), i
  276. );
  277. body.push_back(formatv(
  278. "if (str == \"{2}\") return {0}::{1}::{2};",
  279. className, attr->getEnumName(), i
  280. ));
  281. }
  282. os << formatv(
  283. "\n .def(py::init([](const std::string& in) {"
  284. "\n auto&& str = normalize_enum(in);"
  285. "\n {0}"
  286. "\n throw py::cast_error(\"invalid enum value \" + in);"
  287. "\n }));\n",
  288. llvm::join(body, "\n ")
  289. );
  290. os << formatv(
  291. "py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
  292. className, attr->getEnumName()
  293. );
  294. enumAlias.emplace(enumID,
  295. std::make_pair(className, attr->getEnumName()));
  296. } else {
  297. os << formatv(
  298. "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n",
  299. className, attr->getEnumName(),
  300. iter->second.first, iter->second.second
  301. );
  302. }
  303. }
  304. }
  305. // generate op class binding
  306. os << formatv("{0}Inst", className);
  307. bool hasDefaultCtor = op.getMgbAttributes().empty();
  308. if (!hasDefaultCtor) {
  309. os << "\n .def(py::init<";
  310. std::vector<llvm::StringRef> targs;
  311. for (auto &&i : op.getMgbAttributes()) {
  312. targs.push_back(i.attr.getReturnType());
  313. }
  314. os << llvm::join(targs, ", ");
  315. os << ", std::string>()";
  316. for (auto &&i : op.getMgbAttributes()) {
  317. os << formatv(", py::arg(\"{0}\")", i.name);
  318. auto defaultValue = i.attr.getDefaultValue();
  319. if (!defaultValue.empty()) {
  320. os << formatv(" = {0}", defaultValue);
  321. } else {
  322. hasDefaultCtor = true;
  323. }
  324. }
  325. os << ", py::arg(\"scope\") = {})";
  326. }
  327. if (hasDefaultCtor) {
  328. os << "\n .def(py::init<>())";
  329. }
  330. for (auto &&i : op.getMgbAttributes()) {
  331. os << formatv(
  332. "\n .def_readwrite(\"{0}\", &{1}::{0})",
  333. i.name, className
  334. );
  335. }
  336. os << ";\n\n";
  337. }
  338. static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
  339. auto className = op.getCppClassName();
  340. std::string body;
  341. // generate PyType for enum class member
  342. for (auto&& i : op.getMgbAttributes()) {
  343. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  344. unsigned int enumID;
  345. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  346. auto&& aliasBase = alias->getAliasBase();
  347. enumID =
  348. llvm::cast<MgbEnumAttr>(aliasBase)
  349. .getBaseRecord()->getID();
  350. } else {
  351. enumID = attr->getBaseRecord()->getID();
  352. }
  353. auto&& enumAlias = ctx.enumAlias;
  354. auto&& iter = enumAlias.find(enumID);
  355. auto enumName = attr->getEnumName();
  356. body += "{\n";
  357. body += formatv(
  358. "auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName
  359. );
  360. if (iter == enumAlias.end()) {
  361. os << formatv(
  362. "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n",
  363. className, enumName);
  364. os << formatv(
  365. "template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n",
  366. className, enumName);
  367. std::vector<std::string> pairStr;
  368. for (auto&& i: attr->getEnumMembers()) {
  369. pairStr.push_back(formatv(
  370. "{{normalize_enum(\"{2}\"), {0}::{1}::{2}}",
  371. className, enumName, i));
  372. }
  373. os << formatv(R"(
  374. template<> std::unordered_map<std::string, {0}::{1}>
  375. EnumWrapper<{0}::{1}>::str2type = {{
  376. {2}
  377. };
  378. )", className, enumName, llvm::join(pairStr, ", "));
  379. pairStr.clear();
  380. for (auto&& i: attr->getEnumMembers()) {
  381. pairStr.push_back(formatv(
  382. "{{{0}::{1}::{2}, normalize_enum(\"{2}\")}",
  383. className, enumName, i));
  384. }
  385. os << formatv(R"(
  386. template<> std::unordered_map<{0}::{1}, std::string>
  387. EnumWrapper<{0}::{1}>::type2str = {{
  388. {2}
  389. };
  390. )", className, enumName, llvm::join(pairStr, ", "));
  391. body += formatv(R"(
  392. e_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
  393. e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}";
  394. e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>);
  395. e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  396. e_type.tp_doc = "{0}.{1}";
  397. e_type.tp_base = &PyBaseObject_Type;
  398. e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr;
  399. e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare;
  400. mgb_assert(PyType_Ready(&e_type) >= 0);
  401. )", className, enumName);
  402. for (auto&& i: attr->getEnumMembers()) {
  403. body += formatv(R"({{
  404. PyObject* inst = e_type.tp_alloc(&e_type, 0);
  405. reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2};
  406. mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0);
  407. })", className, enumName, i);
  408. }
  409. enumAlias.emplace(enumID, std::make_pair(className, enumName));
  410. }
  411. body += formatv(R"(
  412. PyType_Modified(&e_type);
  413. mgb_assert(PyDict_SetItemString(
  414. py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0);
  415. )", enumName);
  416. body += "}\n";
  417. }
  418. }
  419. // generate getsetters
  420. std::vector<std::string> getsetters;
  421. for (auto &&i : op.getMgbAttributes()) {
  422. getsetters.push_back(formatv(
  423. "{{const_cast<char*>(\"{1}\"), py_get_generic({0}, {1}), py_set_generic({0}, {1}), const_cast<char*>(\"{1}\"), NULL},",
  424. className, i.name));
  425. }
  426. getsetters.push_back(formatv(
  427. "{{\"scope\", py_get_scope({0}), py_set_scope({0}), \"scope\", NULL},",
  428. className));
  429. // generate tp_init
  430. std::string initBody;
  431. if (!op.getMgbAttributes().empty()) {
  432. initBody += "static const char* kwlist[] = {";
  433. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  434. initBody += formatv("\"{0}\", ", attr.name);
  435. });
  436. initBody += "\"scope\", ";
  437. initBody += "NULL};\n";
  438. initBody += " PyObject ";
  439. std::vector<std::string> attrs;
  440. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  441. attrs.push_back(formatv("*{0} = NULL", attr.name));
  442. });
  443. initBody += llvm::join(attrs, ", ") + ";\n";
  444. initBody += " PyObject *scope = NULL;\n";
  445. initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
  446. // an extra slot created for name
  447. initBody += std::string(op.getMgbAttributes().size() + 1, 'O');
  448. initBody += "\", const_cast<char**>(kwlist)";
  449. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  450. initBody += formatv(", &{0}", attr.name);
  451. });
  452. initBody += ", &scope";
  453. initBody += "))\n";
  454. initBody += " return -1;\n";
  455. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  456. initBody += formatv(R"(
  457. if ({1}) {{
  458. try {{
  459. reinterpret_cast<PyOp({0})*>(self)->inst().{1} =
  460. pyobj_convert_generic<decltype({0}::{1})>::from({1});
  461. } catch(py::error_already_set& e) {{
  462. e.restore();
  463. return -1;
  464. } catch(py::builtin_exception& e) {{
  465. e.set_error();
  466. return -1;
  467. } catch(...) {{
  468. PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
  469. return -1;
  470. }
  471. }
  472. )", className, attr.name);
  473. });
  474. initBody += formatv(R"(
  475. if (scope) {{
  476. try {{
  477. reinterpret_cast<PyOp({0})*>(self)->inst().set_scope(
  478. pyobj_convert_generic<std::string>::from(scope));
  479. } catch(py::error_already_set& e) {{
  480. e.restore();
  481. return -1;
  482. } catch(py::builtin_exception& e) {{
  483. e.set_error();
  484. return -1;
  485. } catch(...) {{
  486. PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
  487. return -1;
  488. }
  489. }
  490. )", className);
  491. }
  492. initBody += "\n return 0;";
  493. os << formatv(R"(
  494. PyOpDefBegin({0}) // {{
  495. static PyGetSetDef py_getsetters[];
  496. static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
  497. // };
  498. PyOpDefEnd({0})
  499. PyGetSetDef PyOp({0})::py_getsetters[] = {{
  500. {1}
  501. {{NULL} /* Sentinel */
  502. };
  503. int PyOp({0})::py_init(PyObject *self, PyObject *args, PyObject *kwds) {{
  504. {2}
  505. }
  506. void _init_py_{0}(py::module m) {{
  507. using py_op = PyOp({0});
  508. auto& py_type = PyOpType({0});
  509. py_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
  510. py_type.tp_name = "megengine.core._imperative_rt.ops.{0}";
  511. py_type.tp_basicsize = sizeof(PyOp({0}));
  512. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  513. py_type.tp_doc = "{0}";
  514. py_type.tp_base = &PyOpType(OpDef);
  515. py_type.tp_dealloc = py_dealloc_generic<py_op>;
  516. py_type.tp_new = py_new_generic<py_op>;
  517. py_type.tp_init = py_op::py_init;
  518. py_type.tp_getset = py_op::py_getsetters;
  519. mgb_assert(PyType_Ready(&py_type) >= 0);
  520. {3}
  521. PyType_Modified(&py_type);
  522. m.add_object("{0}", reinterpret_cast<PyObject*>(&py_type));
  523. mgb_assert(PyOp(OpDef)::ctype2pytype.emplace({0}::typeinfo(), &py_type).second);
  524. }
  525. )",
  526. op.getCppClassName(), llvm::join(getsetters, "\n "), initBody, body);
  527. }
  528. static void for_each_operator(raw_ostream &os, RecordKeeper &keeper,
  529. std::function<void(raw_ostream&, MgbOp&)> callback) {
  530. auto op_base_class = keeper.getClass("Op");
  531. ASSERT(op_base_class, "could not find base class Op");
  532. for (auto&& i: keeper.getDefs()) {
  533. auto&& r = i.second;
  534. if (r->isSubClassOf(op_base_class)) {
  535. auto op = mlir::tblgen::Operator(r.get());
  536. if (op.getDialectName().str() == "mgb") {
  537. std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl;
  538. callback(os, llvm::cast<MgbOp>(op));
  539. }
  540. }
  541. }
  542. }
  543. static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) {
  544. for_each_operator(os, keeper, gen_op_def_c_header_single);
  545. return false;
  546. }
  547. static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) {
  548. for_each_operator(os, keeper, gen_op_def_c_body_single);
  549. return false;
  550. }
  551. static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) {
  552. EnumContext ctx;
  553. using namespace std::placeholders;
  554. for_each_operator(os, keeper,
  555. std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx)));
  556. return false;
  557. }
  558. static bool gen_op_def_python_c_extension(raw_ostream &os, RecordKeeper &keeper) {
  559. EnumContext ctx;
  560. using namespace std::placeholders;
  561. for_each_operator(os, keeper,
  562. std::bind(gen_op_def_python_c_extension_single, _1, _2, std::ref(ctx)));
  563. os << "#define INIT_ALL_OP(m)";
  564. for_each_operator(os, keeper, [&](raw_ostream& os, MgbOp& op) {
  565. os << formatv(" \\\n _init_py_{0}(m);", op.getCppClassName());
  566. });
  567. os << "\n";
  568. return false;
  569. }
  570. int main(int argc, char **argv) {
  571. llvm::InitLLVM y(argc, argv);
  572. llvm::cl::ParseCommandLineOptions(argc, argv);
  573. if (action == ActionType::CppHeader) {
  574. return TableGenMain(argv[0], &gen_op_def_c_header);
  575. }
  576. if (action == ActionType::CppBody) {
  577. return TableGenMain(argv[0], &gen_op_def_c_body);
  578. }
  579. if (action == ActionType::Pybind) {
  580. return TableGenMain(argv[0], &gen_op_def_pybind11);
  581. }
  582. if (action == ActionType::CPython) {
  583. return TableGenMain(argv[0], &gen_op_def_python_c_extension);
  584. }
  585. return -1;
  586. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。