You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

autogen.cpp 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589
  1. #include <iostream>
  2. #include <unordered_map>
  3. #include <functional>
  4. #include "./helper.h"
  5. using llvm::raw_ostream;
  6. using llvm::RecordKeeper;
  7. enum ActionType {
  8. None,
  9. CppHeader,
  10. CppBody,
  11. Pybind,
  12. CPython
  13. };
  14. // NOLINTNEXTLINE
  15. llvm::cl::opt<ActionType> action(
  16. llvm::cl::desc("Action to perform:"),
  17. llvm::cl::values(clEnumValN(CppHeader, "gen-cpp-header",
  18. "Generate operator cpp header"),
  19. clEnumValN(CppBody, "gen-cpp-body",
  20. "Generate operator cpp body"),
  21. clEnumValN(Pybind, "gen-python-binding",
  22. "Generate pybind11 python bindings"),
  23. clEnumValN(CPython, "gen-python-c-extension",
  24. "Generate python c extensions")));
  25. using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
  26. using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
  27. using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin;
  28. using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin;
  29. using MgbOp = mlir::tblgen::MgbOpBase;
  30. using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin;
  31. llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
  32. // Note: we have already registered the corresponding attr wrappers
  33. // for following basic ctypes so we needn't handle them here
  34. /* auto&& attr_type_name = attr.getAttrDefName();
  35. if (attr_type_name == "UI32Attr") {
  36. return "uint32_t";
  37. }
  38. if (attr_type_name == "UI64Attr") {
  39. return "uint64_t";
  40. }
  41. if (attr_type_name == "I32Attr") {
  42. return "int32_t";
  43. }
  44. if (attr_type_name == "F32Attr") {
  45. return "float";
  46. }
  47. if (attr_type_name == "F64Attr") {
  48. return "double";
  49. }
  50. if (attr_type_name == "StrAttr") {
  51. return "std::string";
  52. }
  53. if (attr_type_name == "BoolAttr") {
  54. return "bool";
  55. }*/
  56. auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
  57. if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
  58. return e->getEnumName();
  59. }
  60. return attr.getUnderlyingType();
  61. }
  62. static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
  63. os << formatv(
  64. "class {0} : public OpDefImplBase<{0}> {{\n"
  65. " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
  66. "public:\n",
  67. op.getCppClassName()
  68. );
  69. // handle enum alias
  70. for (auto &&i : op.getMgbAttributes()) {
  71. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  72. os << formatv(
  73. " using {0} = {1};\n",
  74. attr->getEnumName(), attr->getUnderlyingType()
  75. );
  76. }
  77. }
  78. for (auto &&i : op.getMgbAttributes()) {
  79. auto defaultValue = i.attr.getDefaultValue().str();
  80. if (!defaultValue.empty()) {
  81. defaultValue = formatv(" = {0}", defaultValue);
  82. }
  83. os << formatv(
  84. " {0} {1}{2};\n",
  85. attr_to_ctype(i.attr), i.name, defaultValue
  86. );
  87. }
  88. auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
  89. os << formatv(
  90. " {0}({1}){2}{3}\n",
  91. op.getCppClassName(), paramList, memInitList, body
  92. );
  93. };
  94. gen_ctor("", "", " = default;");
  95. if (!op.getMgbAttributes().empty()) {
  96. std::vector<std::string> paramList, initList;
  97. for (auto &&i : op.getMgbAttributes()) {
  98. paramList.push_back(formatv(
  99. "{0} {1}_", attr_to_ctype(i.attr), i.name
  100. ));
  101. initList.push_back(formatv(
  102. "{0}({0}_)", i.name
  103. ));
  104. }
  105. gen_ctor(llvm::join(paramList, ", "),
  106. ": " + llvm::join(initList, ", "),
  107. " {}");
  108. }
  109. auto packedParams = op.getPackedParams();
  110. if (!packedParams.empty()) {
  111. std::vector<std::string> paramList, initList;
  112. for (auto &&p : packedParams) {
  113. auto&& paramFields = p.getFields();
  114. auto&& paramType = p.getFullName();
  115. auto&& paramName = formatv("packed_param_{0}", paramList.size());
  116. paramList.push_back(
  117. paramFields.empty() ? paramType.str()
  118. : formatv("{0} {1}", paramType, paramName)
  119. );
  120. for (auto&& i : paramFields) {
  121. initList.push_back(formatv(
  122. "{0}({1}.{0})", i.name, paramName
  123. ));
  124. }
  125. }
  126. for (auto&& i : op.getExtraArguments()) {
  127. paramList.push_back(formatv(
  128. "{0} {1}_", attr_to_ctype(i.attr), i.name
  129. ));
  130. initList.push_back(formatv(
  131. "{0}({0}_)", i.name
  132. ));
  133. }
  134. gen_ctor(llvm::join(paramList, ", "),
  135. initList.empty() ? "" : ": " + llvm::join(initList, ", "),
  136. " {}");
  137. }
  138. if (!packedParams.empty()) {
  139. for (auto&& p : packedParams) {
  140. auto accessor = p.getAccessor();
  141. if (!accessor.empty()) {
  142. os << formatv(
  143. " {0} {1}() const {{\n",
  144. p.getFullName(), accessor
  145. );
  146. std::vector<llvm::StringRef> fields;
  147. for (auto&& i : p.getFields()) {
  148. fields.push_back(i.name);
  149. }
  150. os << formatv(
  151. " return {{{0}};\n",
  152. llvm::join(fields, ", ")
  153. );
  154. os << " }\n";
  155. }
  156. }
  157. }
  158. if (auto decl = op.getExtraOpdefDecl()) {
  159. os << decl.getValue();
  160. }
  161. os << formatv(
  162. "};\n\n"
  163. );
  164. }
  165. static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
  166. auto&& className = op.getCppClassName();
  167. os << formatv(
  168. "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className
  169. );
  170. auto formatMethImpl = [&](auto&& meth) {
  171. return formatv(
  172. "{0}_{1}_impl", className, meth
  173. );
  174. };
  175. std::vector<std::string> methods;
  176. if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
  177. os << "namespace {\n";
  178. // generate hash()
  179. mlir::tblgen::FmtContext ctx;
  180. os << formatv(
  181. "size_t {0}(const OpDef& def_) {{\n",
  182. formatMethImpl("hash")
  183. );
  184. os << formatv(
  185. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  186. " static_cast<void>(op_);\n",
  187. className
  188. );
  189. ctx.withSelf("op_");
  190. os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
  191. os << "}\n";
  192. // generate is_same_st()
  193. os << formatv(
  194. "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
  195. formatMethImpl("is_same_st")
  196. );
  197. os << formatv(
  198. " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
  199. " &&b_ = rhs_.cast_final_safe<{0}>();\n"
  200. " static_cast<void>(a_);\n"
  201. " static_cast<void>(b_);\n",
  202. className
  203. );
  204. os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
  205. os << "}\n";
  206. // generate props()
  207. os << formatv(
  208. "std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n",
  209. formatMethImpl("props")
  210. );
  211. os << formatv(
  212. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  213. " static_cast<void>(op_);\n",
  214. className
  215. );
  216. ctx.withSelf("op_");
  217. os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
  218. os << "}\n";
  219. os << "} // anonymous namespace\n";
  220. methods.push_back("hash");
  221. methods.push_back("is_same_st");
  222. methods.push_back("props");
  223. }
  224. if (!methods.empty()) {
  225. os << formatv(
  226. "OP_TRAIT_REG({0}, {0})", op.getCppClassName()
  227. );
  228. for (auto&& i : methods) {
  229. os << formatv(
  230. "\n .{0}({1})", i, formatMethImpl(i)
  231. );
  232. }
  233. os << ";\n\n";
  234. }
  235. }
  236. struct EnumContext {
  237. std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias;
  238. };
  239. static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
  240. auto className = op.getCppClassName();
  241. os << formatv(
  242. "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
  243. className
  244. );
  245. for (auto&& i : op.getMgbAttributes()) {
  246. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  247. unsigned int enumID;
  248. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  249. auto&& aliasBase = alias->getAliasBase();
  250. enumID =
  251. llvm::cast<MgbEnumAttr>(aliasBase)
  252. .getBaseRecord()->getID();
  253. } else {
  254. enumID = attr->getBaseRecord()->getID();
  255. }
  256. auto&& enumAlias = ctx.enumAlias;
  257. auto&& iter = enumAlias.find(enumID);
  258. if (iter == enumAlias.end()) {
  259. os << formatv(
  260. "py::enum_<{0}::{1}>({0}Inst, \"{1}\")",
  261. className, attr->getEnumName()
  262. );
  263. std::vector<std::string> body;
  264. for (auto&& i: attr->getEnumMembers()) {
  265. os << formatv(
  266. "\n .value(\"{2}\", {0}::{1}::{2})",
  267. className, attr->getEnumName(), i
  268. );
  269. body.push_back(formatv(
  270. "if (str == \"{2}\") return {0}::{1}::{2};",
  271. className, attr->getEnumName(), i
  272. ));
  273. }
  274. os << formatv(
  275. "\n .def(py::init([](const std::string& in) {"
  276. "\n auto&& str = normalize_enum(in);"
  277. "\n {0}"
  278. "\n throw py::cast_error(\"invalid enum value \" + in);"
  279. "\n }));\n",
  280. llvm::join(body, "\n ")
  281. );
  282. os << formatv(
  283. "py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
  284. className, attr->getEnumName()
  285. );
  286. enumAlias.emplace(enumID,
  287. std::make_pair(className, attr->getEnumName()));
  288. } else {
  289. os << formatv(
  290. "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n",
  291. className, attr->getEnumName(),
  292. iter->second.first, iter->second.second
  293. );
  294. }
  295. }
  296. }
  297. // generate op class binding
  298. os << formatv("{0}Inst", className);
  299. bool hasDefaultCtor = op.getMgbAttributes().empty();
  300. if (!hasDefaultCtor) {
  301. os << "\n .def(py::init<";
  302. std::vector<llvm::StringRef> targs;
  303. for (auto &&i : op.getMgbAttributes()) {
  304. targs.push_back(i.attr.getReturnType());
  305. }
  306. os << llvm::join(targs, ", ");
  307. os << ">()";
  308. for (auto &&i : op.getMgbAttributes()) {
  309. os << formatv(", py::arg(\"{0}\")", i.name);
  310. auto defaultValue = i.attr.getDefaultValue();
  311. if (!defaultValue.empty()) {
  312. os << formatv(" = {0}", defaultValue);
  313. } else {
  314. hasDefaultCtor = true;
  315. }
  316. }
  317. os << ")";
  318. }
  319. if (hasDefaultCtor) {
  320. os << "\n .def(py::init<>())";
  321. }
  322. for (auto &&i : op.getMgbAttributes()) {
  323. os << formatv(
  324. "\n .def_readwrite(\"{0}\", &{1}::{0})",
  325. i.name, className
  326. );
  327. }
  328. os << ";\n\n";
  329. }
  330. static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
  331. auto className = op.getCppClassName();
  332. std::string body;
  333. // generate PyType for enum class member
  334. for (auto&& i : op.getMgbAttributes()) {
  335. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  336. unsigned int enumID;
  337. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  338. auto&& aliasBase = alias->getAliasBase();
  339. enumID =
  340. llvm::cast<MgbEnumAttr>(aliasBase)
  341. .getBaseRecord()->getID();
  342. } else {
  343. enumID = attr->getBaseRecord()->getID();
  344. }
  345. auto&& enumAlias = ctx.enumAlias;
  346. auto&& iter = enumAlias.find(enumID);
  347. auto enumName = attr->getEnumName();
  348. body += "{\n";
  349. body += formatv(
  350. "auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName
  351. );
  352. if (iter == enumAlias.end()) {
  353. os << formatv(
  354. "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n",
  355. className, enumName);
  356. os << formatv(
  357. "template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n",
  358. className, enumName);
  359. std::vector<std::string> pairStr;
  360. for (auto&& i: attr->getEnumMembers()) {
  361. pairStr.push_back(formatv(
  362. "{{normalize_enum(\"{2}\"), {0}::{1}::{2}}",
  363. className, enumName, i));
  364. }
  365. os << formatv(R"(
  366. template<> std::unordered_map<std::string, {0}::{1}>
  367. EnumWrapper<{0}::{1}>::str2type = {{
  368. {2}
  369. };
  370. )", className, enumName, llvm::join(pairStr, ", "));
  371. pairStr.clear();
  372. for (auto&& i: attr->getEnumMembers()) {
  373. pairStr.push_back(formatv(
  374. "{{{0}::{1}::{2}, normalize_enum(\"{2}\")}",
  375. className, enumName, i));
  376. }
  377. os << formatv(R"(
  378. template<> std::unordered_map<{0}::{1}, std::string>
  379. EnumWrapper<{0}::{1}>::type2str = {{
  380. {2}
  381. };
  382. )", className, enumName, llvm::join(pairStr, ", "));
  383. body += formatv(R"(
  384. e_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
  385. e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}";
  386. e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>);
  387. e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  388. e_type.tp_doc = "{0}.{1}";
  389. e_type.tp_base = &PyBaseObject_Type;
  390. e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr;
  391. e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare;
  392. mgb_assert(PyType_Ready(&e_type) >= 0);
  393. )", className, enumName);
  394. for (auto&& i: attr->getEnumMembers()) {
  395. body += formatv(R"({{
  396. PyObject* inst = e_type.tp_alloc(&e_type, 0);
  397. reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2};
  398. mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0);
  399. })", className, enumName, i);
  400. }
  401. enumAlias.emplace(enumID, std::make_pair(className, enumName));
  402. }
  403. body += formatv(R"(
  404. PyType_Modified(&e_type);
  405. mgb_assert(PyDict_SetItemString(
  406. py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0);
  407. )", enumName);
  408. body += "}\n";
  409. }
  410. }
  411. // generate getsetters
  412. std::vector<std::string> getsetters;
  413. for (auto &&i : op.getMgbAttributes()) {
  414. getsetters.push_back(formatv(
  415. "{{const_cast<char*>(\"{1}\"), py_get_generic({0}, {1}), py_set_generic({0}, {1}), const_cast<char*>(\"{1}\"), NULL},",
  416. className, i.name));
  417. }
  418. // generate tp_init
  419. std::string initBody;
  420. if (!op.getMgbAttributes().empty()) {
  421. initBody += "static const char* kwlist[] = {";
  422. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  423. initBody += formatv("\"{0}\", ", attr.name);
  424. });
  425. initBody += "NULL};\n";
  426. initBody += " PyObject ";
  427. std::vector<std::string> attrs;
  428. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  429. attrs.push_back(formatv("*{0} = NULL", attr.name));
  430. });
  431. initBody += llvm::join(attrs, ", ") + ";\n";
  432. initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
  433. initBody += std::string(op.getMgbAttributes().size(), 'O');
  434. initBody += "\", const_cast<char**>(kwlist)";
  435. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  436. initBody += formatv(" ,&{0}", attr.name);
  437. });
  438. initBody += "))\n";
  439. initBody += " return -1;\n";
  440. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  441. initBody += formatv(R"(
  442. if ({1}) {{
  443. try {{
  444. reinterpret_cast<PyOp({0})*>(self)->inst().{1} =
  445. pyobj_convert_generic<decltype({0}::{1})>::from({1});
  446. } catch(py::error_already_set& e) {{
  447. e.restore();
  448. return -1;
  449. } catch(py::builtin_exception& e) {{
  450. e.set_error();
  451. return -1;
  452. } catch(...) {{
  453. PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
  454. return -1;
  455. }
  456. }
  457. )", className, attr.name);
  458. });
  459. }
  460. initBody += "\n return 0;";
  461. os << formatv(R"(
  462. PyOpDefBegin({0}) // {{
  463. static PyGetSetDef py_getsetters[];
  464. static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
  465. // };
  466. PyOpDefEnd({0})
  467. PyGetSetDef PyOp({0})::py_getsetters[] = {{
  468. {1}
  469. {{NULL} /* Sentinel */
  470. };
  471. int PyOp({0})::py_init(PyObject *self, PyObject *args, PyObject *kwds) {{
  472. {2}
  473. }
  474. void _init_py_{0}(py::module m) {{
  475. using py_op = PyOp({0});
  476. auto& py_type = PyOpType({0});
  477. py_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
  478. py_type.tp_name = "megengine.core._imperative_rt.ops.{0}";
  479. py_type.tp_basicsize = sizeof(PyOp({0}));
  480. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  481. py_type.tp_doc = "{0}";
  482. py_type.tp_base = &PyOpType(OpDef);
  483. py_type.tp_dealloc = py_dealloc_generic<py_op>;
  484. py_type.tp_new = py_new_generic<py_op>;
  485. py_type.tp_init = py_op::py_init;
  486. py_type.tp_getset = py_op::py_getsetters;
  487. mgb_assert(PyType_Ready(&py_type) >= 0);
  488. {3}
  489. PyType_Modified(&py_type);
  490. m.add_object("{0}", reinterpret_cast<PyObject*>(&py_type));
  491. mgb_assert(PyOp(OpDef)::ctype2pytype.emplace({0}::typeinfo(), &py_type).second);
  492. }
  493. )",
  494. op.getCppClassName(), llvm::join(getsetters, "\n "), initBody, body);
  495. }
  496. static void for_each_operator(raw_ostream &os, RecordKeeper &keeper,
  497. std::function<void(raw_ostream&, MgbOp&)> callback) {
  498. auto op_base_class = keeper.getClass("Op");
  499. ASSERT(op_base_class, "could not find base class Op");
  500. for (auto&& i: keeper.getDefs()) {
  501. auto&& r = i.second;
  502. if (r->isSubClassOf(op_base_class)) {
  503. auto op = mlir::tblgen::Operator(r.get());
  504. if (op.getDialectName().str() == "mgb") {
  505. std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl;
  506. callback(os, llvm::cast<MgbOp>(op));
  507. }
  508. }
  509. }
  510. }
  511. static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) {
  512. for_each_operator(os, keeper, gen_op_def_c_header_single);
  513. return false;
  514. }
  515. static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) {
  516. for_each_operator(os, keeper, gen_op_def_c_body_single);
  517. return false;
  518. }
  519. static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) {
  520. EnumContext ctx;
  521. using namespace std::placeholders;
  522. for_each_operator(os, keeper,
  523. std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx)));
  524. return false;
  525. }
  526. static bool gen_op_def_python_c_extension(raw_ostream &os, RecordKeeper &keeper) {
  527. EnumContext ctx;
  528. using namespace std::placeholders;
  529. for_each_operator(os, keeper,
  530. std::bind(gen_op_def_python_c_extension_single, _1, _2, std::ref(ctx)));
  531. os << "#define INIT_ALL_OP(m)";
  532. for_each_operator(os, keeper, [&](raw_ostream& os, MgbOp& op) {
  533. os << formatv(" \\\n _init_py_{0}(m);", op.getCppClassName());
  534. });
  535. os << "\n";
  536. return false;
  537. }
  538. int main(int argc, char **argv) {
  539. llvm::InitLLVM y(argc, argv);
  540. llvm::cl::ParseCommandLineOptions(argc, argv);
  541. if (action == ActionType::CppHeader) {
  542. return TableGenMain(argv[0], &gen_op_def_c_header);
  543. }
  544. if (action == ActionType::CppBody) {
  545. return TableGenMain(argv[0], &gen_op_def_c_body);
  546. }
  547. if (action == ActionType::Pybind) {
  548. return TableGenMain(argv[0], &gen_op_def_pybind11);
  549. }
  550. if (action == ActionType::CPython) {
  551. return TableGenMain(argv[0], &gen_op_def_python_c_extension);
  552. }
  553. return -1;
  554. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台