You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

autogen.cpp 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574
  1. #include <iostream>
  2. #include <unordered_map>
  3. #include <functional>
  4. #include "./helper.h"
  5. using llvm::raw_ostream;
  6. using llvm::RecordKeeper;
  7. enum ActionType {
  8. None,
  9. CppHeader,
  10. CppBody,
  11. Pybind,
  12. CPython
  13. };
  14. // NOLINTNEXTLINE
  15. llvm::cl::opt<ActionType> action(
  16. llvm::cl::desc("Action to perform:"),
  17. llvm::cl::values(clEnumValN(CppHeader, "gen-cpp-header",
  18. "Generate operator cpp header"),
  19. clEnumValN(CppBody, "gen-cpp-body",
  20. "Generate operator cpp body"),
  21. clEnumValN(Pybind, "gen-python-binding",
  22. "Generate pybind11 python bindings"),
  23. clEnumValN(CPython, "gen-python-c-extension",
  24. "Generate python c extensions")));
  25. using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
  26. using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
  27. using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin;
  28. using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin;
  29. using MgbOp = mlir::tblgen::MgbOpBase;
  30. using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin;
  31. llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
  32. // Note: we have already registered the corresponding attr wrappers
  33. // for following basic ctypes so we needn't handle them here
  34. /* auto&& attr_type_name = attr.getAttrDefName();
  35. if (attr_type_name == "UI32Attr") {
  36. return "uint32_t";
  37. }
  38. if (attr_type_name == "UI64Attr") {
  39. return "uint64_t";
  40. }
  41. if (attr_type_name == "I32Attr") {
  42. return "int32_t";
  43. }
  44. if (attr_type_name == "F32Attr") {
  45. return "float";
  46. }
  47. if (attr_type_name == "F64Attr") {
  48. return "double";
  49. }
  50. if (attr_type_name == "StrAttr") {
  51. return "std::string";
  52. }
  53. if (attr_type_name == "BoolAttr") {
  54. return "bool";
  55. }*/
  56. auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
  57. if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
  58. return e->getEnumName();
  59. }
  60. return attr.getUnderlyingType();
  61. }
  62. static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
  63. os << formatv(
  64. "class {0} : public OpDefImplBase<{0}> {{\n"
  65. " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
  66. "public:\n",
  67. op.getCppClassName()
  68. );
  69. // handle enum alias
  70. for (auto &&i : op.getMgbAttributes()) {
  71. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  72. os << formatv(
  73. " using {0} = {1};\n",
  74. attr->getEnumName(), attr->getUnderlyingType()
  75. );
  76. }
  77. }
  78. for (auto &&i : op.getMgbAttributes()) {
  79. auto defaultValue = i.attr.getDefaultValue().str();
  80. if (!defaultValue.empty()) {
  81. defaultValue = formatv(" = {0}", defaultValue);
  82. }
  83. os << formatv(
  84. " {0} {1}{2};\n",
  85. attr_to_ctype(i.attr), i.name, defaultValue
  86. );
  87. }
  88. auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
  89. os << formatv(
  90. " {0}({1}){2}{3}\n",
  91. op.getCppClassName(), paramList, memInitList, body
  92. );
  93. };
  94. gen_ctor("", "", " = default;");
  95. if (!op.getMgbAttributes().empty()) {
  96. std::vector<std::string> paramList, initList;
  97. for (auto &&i : op.getMgbAttributes()) {
  98. paramList.push_back(formatv(
  99. "{0} {1}_", attr_to_ctype(i.attr), i.name
  100. ));
  101. initList.push_back(formatv(
  102. "{0}({0}_)", i.name
  103. ));
  104. }
  105. gen_ctor(llvm::join(paramList, ", "),
  106. ": " + llvm::join(initList, ", "),
  107. " {}");
  108. }
  109. auto packedParams = op.getPackedParams();
  110. if (!packedParams.empty()) {
  111. std::vector<std::string> paramList, initList;
  112. for (auto &&p : packedParams) {
  113. auto&& paramFields = p.getFields();
  114. auto&& paramType = p.getFullName();
  115. auto&& paramName = formatv("packed_param_{0}", paramList.size());
  116. paramList.push_back(
  117. paramFields.empty() ? paramType.str()
  118. : formatv("{0} {1}", paramType, paramName)
  119. );
  120. for (auto&& i : paramFields) {
  121. initList.push_back(formatv(
  122. "{0}({1}.{0})", i.name, paramName
  123. ));
  124. }
  125. }
  126. for (auto&& i : op.getExtraArguments()) {
  127. paramList.push_back(formatv(
  128. "{0} {1}_", attr_to_ctype(i.attr), i.name
  129. ));
  130. initList.push_back(formatv(
  131. "{0}({0}_)", i.name
  132. ));
  133. }
  134. gen_ctor(llvm::join(paramList, ", "),
  135. initList.empty() ? "" : ": " + llvm::join(initList, ", "),
  136. " {}");
  137. }
  138. if (!packedParams.empty()) {
  139. for (auto&& p : packedParams) {
  140. auto accessor = p.getAccessor();
  141. if (!accessor.empty()) {
  142. os << formatv(
  143. " {0} {1}() const {{\n",
  144. p.getFullName(), accessor
  145. );
  146. std::vector<llvm::StringRef> fields;
  147. for (auto&& i : p.getFields()) {
  148. fields.push_back(i.name);
  149. }
  150. os << formatv(
  151. " return {{{0}};\n",
  152. llvm::join(fields, ", ")
  153. );
  154. os << " }\n";
  155. }
  156. }
  157. }
  158. if (auto decl = op.getExtraOpdefDecl()) {
  159. os << decl.getValue();
  160. }
  161. os << formatv(
  162. "};\n\n"
  163. );
  164. }
  165. static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
  166. auto&& className = op.getCppClassName();
  167. os << formatv(
  168. "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className
  169. );
  170. auto formatMethImpl = [&](auto&& meth) {
  171. return formatv(
  172. "{0}_{1}_impl", className, meth
  173. );
  174. };
  175. std::vector<std::string> methods;
  176. if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
  177. os << "namespace {\n";
  178. // generate hash()
  179. mlir::tblgen::FmtContext ctx;
  180. os << formatv(
  181. "size_t {0}(const OpDef& def_) {{\n",
  182. formatMethImpl("hash")
  183. );
  184. os << formatv(
  185. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  186. " static_cast<void>(op_);\n",
  187. className
  188. );
  189. ctx.withSelf("op_");
  190. os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
  191. os << "}\n";
  192. // generate is_same_st()
  193. os << formatv(
  194. "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
  195. formatMethImpl("is_same_st")
  196. );
  197. os << formatv(
  198. " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
  199. " &&b_ = rhs_.cast_final_safe<{0}>();\n"
  200. " static_cast<void>(a_);\n"
  201. " static_cast<void>(b_);\n",
  202. className
  203. );
  204. os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
  205. os << "}\n";
  206. os << "} // anonymous namespace\n";
  207. methods.push_back("hash");
  208. methods.push_back("is_same_st");
  209. }
  210. if (!methods.empty()) {
  211. os << formatv(
  212. "OP_TRAIT_REG({0}, {0})", op.getCppClassName()
  213. );
  214. for (auto&& i : methods) {
  215. os << formatv(
  216. "\n .{0}({1})", i, formatMethImpl(i)
  217. );
  218. }
  219. os << ";\n\n";
  220. }
  221. }
  222. struct EnumContext {
  223. std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias;
  224. };
  225. static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
  226. auto className = op.getCppClassName();
  227. os << formatv(
  228. "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
  229. className
  230. );
  231. for (auto&& i : op.getMgbAttributes()) {
  232. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  233. unsigned int enumID;
  234. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  235. auto&& aliasBase = alias->getAliasBase();
  236. enumID =
  237. llvm::cast<MgbEnumAttr>(aliasBase)
  238. .getBaseRecord()->getID();
  239. } else {
  240. enumID = attr->getBaseRecord()->getID();
  241. }
  242. auto&& enumAlias = ctx.enumAlias;
  243. auto&& iter = enumAlias.find(enumID);
  244. if (iter == enumAlias.end()) {
  245. os << formatv(
  246. "py::enum_<{0}::{1}>({0}Inst, \"{1}\")",
  247. className, attr->getEnumName()
  248. );
  249. std::vector<std::string> body;
  250. for (auto&& i: attr->getEnumMembers()) {
  251. os << formatv(
  252. "\n .value(\"{2}\", {0}::{1}::{2})",
  253. className, attr->getEnumName(), i
  254. );
  255. body.push_back(formatv(
  256. "if (str == \"{2}\") return {0}::{1}::{2};",
  257. className, attr->getEnumName(), i
  258. ));
  259. }
  260. os << formatv(
  261. "\n .def(py::init([](const std::string& in) {"
  262. "\n auto&& str = normalize_enum(in);"
  263. "\n {0}"
  264. "\n throw py::cast_error(\"invalid enum value \" + in);"
  265. "\n }));\n",
  266. llvm::join(body, "\n ")
  267. );
  268. os << formatv(
  269. "py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
  270. className, attr->getEnumName()
  271. );
  272. enumAlias.emplace(enumID,
  273. std::make_pair(className, attr->getEnumName()));
  274. } else {
  275. os << formatv(
  276. "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n",
  277. className, attr->getEnumName(),
  278. iter->second.first, iter->second.second
  279. );
  280. }
  281. }
  282. }
  283. // generate op class binding
  284. os << formatv("{0}Inst", className);
  285. bool hasDefaultCtor = op.getMgbAttributes().empty();
  286. if (!hasDefaultCtor) {
  287. os << "\n .def(py::init<";
  288. std::vector<llvm::StringRef> targs;
  289. for (auto &&i : op.getMgbAttributes()) {
  290. targs.push_back(i.attr.getReturnType());
  291. }
  292. os << llvm::join(targs, ", ");
  293. os << ">()";
  294. for (auto &&i : op.getMgbAttributes()) {
  295. os << formatv(", py::arg(\"{0}\")", i.name);
  296. auto defaultValue = i.attr.getDefaultValue();
  297. if (!defaultValue.empty()) {
  298. os << formatv(" = {0}", defaultValue);
  299. } else {
  300. hasDefaultCtor = true;
  301. }
  302. }
  303. os << ")";
  304. }
  305. if (hasDefaultCtor) {
  306. os << "\n .def(py::init<>())";
  307. }
  308. for (auto &&i : op.getMgbAttributes()) {
  309. os << formatv(
  310. "\n .def_readwrite(\"{0}\", &{1}::{0})",
  311. i.name, className
  312. );
  313. }
  314. os << ";\n\n";
  315. }
  316. static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
  317. auto className = op.getCppClassName();
  318. std::string body;
  319. // generate PyType for enum class member
  320. for (auto&& i : op.getMgbAttributes()) {
  321. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  322. unsigned int enumID;
  323. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  324. auto&& aliasBase = alias->getAliasBase();
  325. enumID =
  326. llvm::cast<MgbEnumAttr>(aliasBase)
  327. .getBaseRecord()->getID();
  328. } else {
  329. enumID = attr->getBaseRecord()->getID();
  330. }
  331. auto&& enumAlias = ctx.enumAlias;
  332. auto&& iter = enumAlias.find(enumID);
  333. auto enumName = attr->getEnumName();
  334. body += "{\n";
  335. body += formatv(
  336. "auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName
  337. );
  338. if (iter == enumAlias.end()) {
  339. os << formatv(
  340. "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n",
  341. className, enumName);
  342. os << formatv(
  343. "template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n",
  344. className, enumName);
  345. std::vector<std::string> pairStr;
  346. for (auto&& i: attr->getEnumMembers()) {
  347. pairStr.push_back(formatv(
  348. "{{normalize_enum(\"{2}\"), {0}::{1}::{2}}",
  349. className, enumName, i));
  350. }
  351. os << formatv(R"(
  352. template<> std::unordered_map<std::string, {0}::{1}>
  353. EnumWrapper<{0}::{1}>::str2type = {{
  354. {2}
  355. };
  356. )", className, enumName, llvm::join(pairStr, ", "));
  357. pairStr.clear();
  358. for (auto&& i: attr->getEnumMembers()) {
  359. pairStr.push_back(formatv(
  360. "{{{0}::{1}::{2}, normalize_enum(\"{2}\")}",
  361. className, enumName, i));
  362. }
  363. os << formatv(R"(
  364. template<> std::unordered_map<{0}::{1}, std::string>
  365. EnumWrapper<{0}::{1}>::type2str = {{
  366. {2}
  367. };
  368. )", className, enumName, llvm::join(pairStr, ", "));
  369. body += formatv(R"(
  370. e_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
  371. e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}";
  372. e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>);
  373. e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  374. e_type.tp_doc = "{0}.{1}";
  375. e_type.tp_base = &PyBaseObject_Type;
  376. e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr;
  377. e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare;
  378. mgb_assert(PyType_Ready(&e_type) >= 0);
  379. )", className, enumName);
  380. for (auto&& i: attr->getEnumMembers()) {
  381. body += formatv(R"({{
  382. PyObject* inst = e_type.tp_alloc(&e_type, 0);
  383. reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2};
  384. mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0);
  385. })", className, enumName, i);
  386. }
  387. enumAlias.emplace(enumID, std::make_pair(className, enumName));
  388. }
  389. body += formatv(R"(
  390. PyType_Modified(&e_type);
  391. mgb_assert(PyDict_SetItemString(
  392. py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0);
  393. )", enumName);
  394. body += "}\n";
  395. }
  396. }
  397. // generate getsetters
  398. std::vector<std::string> getsetters;
  399. for (auto &&i : op.getMgbAttributes()) {
  400. getsetters.push_back(formatv(
  401. "{{\"{1}\", py_get_generic({0}, {1}), py_set_generic({0}, {1}), \"{1}\", NULL},",
  402. className, i.name));
  403. }
  404. // generate tp_init
  405. std::string initBody;
  406. if (!op.getMgbAttributes().empty()) {
  407. initBody += "static const char* kwlist[] = {";
  408. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  409. initBody += formatv("\"{0}\", ", attr.name);
  410. });
  411. initBody += "NULL};\n";
  412. initBody += " PyObject ";
  413. std::vector<std::string> attrs;
  414. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  415. attrs.push_back(formatv("*{0} = NULL", attr.name));
  416. });
  417. initBody += llvm::join(attrs, ", ") + ";\n";
  418. initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
  419. initBody += std::string(op.getMgbAttributes().size(), 'O');
  420. initBody += "\", const_cast<char**>(kwlist)";
  421. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  422. initBody += formatv(" ,&{0}", attr.name);
  423. });
  424. initBody += "))\n";
  425. initBody += " return -1;\n";
  426. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  427. initBody += formatv(R"(
  428. if ({1}) {{
  429. try {{
  430. reinterpret_cast<PyOp({0})*>(self)->inst().{1} =
  431. pyobj_convert_generic<decltype({0}::{1})>::from({1});
  432. } catch(py::error_already_set& e) {{
  433. e.restore();
  434. return -1;
  435. } catch(py::builtin_exception& e) {{
  436. e.set_error();
  437. return -1;
  438. } catch(...) {{
  439. PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
  440. return -1;
  441. }
  442. }
  443. )", className, attr.name);
  444. });
  445. }
  446. initBody += "\n return 0;";
  447. os << formatv(R"(
  448. PyOpDefBegin({0}) // {{
  449. static PyGetSetDef py_getsetters[];
  450. static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
  451. // };
  452. PyOpDefEnd({0})
  453. PyGetSetDef PyOp({0})::py_getsetters[] = {{
  454. {1}
  455. {{NULL} /* Sentinel */
  456. };
  457. int PyOp({0})::py_init(PyObject *self, PyObject *args, PyObject *kwds) {{
  458. {2}
  459. }
  460. void _init_py_{0}(py::module m) {{
  461. using py_op = PyOp({0});
  462. auto& py_type = PyOpType({0});
  463. py_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
  464. py_type.tp_name = "megengine.core._imperative_rt.ops.{0}";
  465. py_type.tp_basicsize = sizeof(PyOp({0}));
  466. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  467. py_type.tp_doc = "{0}";
  468. py_type.tp_base = &PyOpType(OpDef);
  469. py_type.tp_dealloc = py_dealloc_generic<py_op>;
  470. py_type.tp_new = py_new_generic<py_op>;
  471. py_type.tp_init = py_op::py_init;
  472. py_type.tp_getset = py_op::py_getsetters;
  473. mgb_assert(PyType_Ready(&py_type) >= 0);
  474. {3}
  475. PyType_Modified(&py_type);
  476. m.add_object("{0}", reinterpret_cast<PyObject*>(&py_type));
  477. mgb_assert(PyOp(OpDef)::ctype2pytype.emplace({0}::typeinfo(), &py_type).second);
  478. }
  479. )",
  480. op.getCppClassName(), llvm::join(getsetters, "\n "), initBody, body);
  481. }
  482. static void for_each_operator(raw_ostream &os, RecordKeeper &keeper,
  483. std::function<void(raw_ostream&, MgbOp&)> callback) {
  484. auto op_base_class = keeper.getClass("Op");
  485. ASSERT(op_base_class, "could not find base class Op");
  486. for (auto&& i: keeper.getDefs()) {
  487. auto&& r = i.second;
  488. if (r->isSubClassOf(op_base_class)) {
  489. auto op = mlir::tblgen::Operator(r.get());
  490. if (op.getDialectName().str() == "mgb") {
  491. std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl;
  492. callback(os, llvm::cast<MgbOp>(op));
  493. }
  494. }
  495. }
  496. }
  497. static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) {
  498. for_each_operator(os, keeper, gen_op_def_c_header_single);
  499. return false;
  500. }
  501. static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) {
  502. for_each_operator(os, keeper, gen_op_def_c_body_single);
  503. return false;
  504. }
  505. static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) {
  506. EnumContext ctx;
  507. using namespace std::placeholders;
  508. for_each_operator(os, keeper,
  509. std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx)));
  510. return false;
  511. }
  512. static bool gen_op_def_python_c_extension(raw_ostream &os, RecordKeeper &keeper) {
  513. EnumContext ctx;
  514. using namespace std::placeholders;
  515. for_each_operator(os, keeper,
  516. std::bind(gen_op_def_python_c_extension_single, _1, _2, std::ref(ctx)));
  517. os << "#define INIT_ALL_OP(m)";
  518. for_each_operator(os, keeper, [&](raw_ostream& os, MgbOp& op) {
  519. os << formatv(" \\\n _init_py_{0}(m);", op.getCppClassName());
  520. });
  521. os << "\n";
  522. return false;
  523. }
  524. int main(int argc, char **argv) {
  525. llvm::InitLLVM y(argc, argv);
  526. llvm::cl::ParseCommandLineOptions(argc, argv);
  527. if (action == ActionType::CppHeader) {
  528. return TableGenMain(argv[0], &gen_op_def_c_header);
  529. }
  530. if (action == ActionType::CppBody) {
  531. return TableGenMain(argv[0], &gen_op_def_c_body);
  532. }
  533. if (action == ActionType::Pybind) {
  534. return TableGenMain(argv[0], &gen_op_def_pybind11);
  535. }
  536. if (action == ActionType::CPython) {
  537. return TableGenMain(argv[0], &gen_op_def_python_c_extension);
  538. }
  539. return -1;
  540. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。