You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

autogen.cpp 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. #include <iostream>
  2. #include <unordered_map>
  3. #include <functional>
  4. #include "./helper.h"
  5. using llvm::raw_ostream;
  6. using llvm::RecordKeeper;
  /// Code-generation mode selected through the `action` command-line option.
  /// The enumerators stay unscoped because the cl::values list refers to them
  /// without qualification.
  enum ActionType {
      None,       // no generator selected; not bound to any flag
      CppHeader,  // "gen-cpp-header": operator class declarations
      CppBody,    // "gen-cpp-body": operator out-of-line definitions
      Pybind      // "gen-python-binding": pybind11 bindings
  };
  13. // NOLINTNEXTLINE
  14. llvm::cl::opt<ActionType> action(
  15. llvm::cl::desc("Action to perform:"),
  16. llvm::cl::values(clEnumValN(CppHeader, "gen-cpp-header",
  17. "Generate operator cpp header"),
  18. clEnumValN(CppBody, "gen-cpp-body",
  19. "Generate operator cpp body"),
  20. clEnumValN(Pybind, "gen-python-binding",
  21. "Generate pybind11 python bindings")));
  22. using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
  23. using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
  24. using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin;
  25. using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin;
  26. using MgbOp = mlir::tblgen::MgbOpBase;
  27. using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin;
  28. llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
  29. // Note: we have already registered the corresponding attr wrappers
  30. // for following basic ctypes so we needn't handle them here
  31. /* auto&& attr_type_name = attr.getAttrDefName();
  32. if (attr_type_name == "UI32Attr") {
  33. return "uint32_t";
  34. }
  35. if (attr_type_name == "UI64Attr") {
  36. return "uint64_t";
  37. }
  38. if (attr_type_name == "I32Attr") {
  39. return "int32_t";
  40. }
  41. if (attr_type_name == "F32Attr") {
  42. return "float";
  43. }
  44. if (attr_type_name == "F64Attr") {
  45. return "double";
  46. }
  47. if (attr_type_name == "StrAttr") {
  48. return "std::string";
  49. }
  50. if (attr_type_name == "BoolAttr") {
  51. return "bool";
  52. }*/
  53. auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
  54. if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
  55. return e->getEnumName();
  56. }
  57. return attr.getUnderlyingType();
  58. }
  59. static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
  60. os << formatv(
  61. "class {0} : public OpDefImplBase<{0}> {{\n"
  62. " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
  63. "public:\n",
  64. op.getCppClassName()
  65. );
  66. // handle enum alias
  67. for (auto &&i : op.getMgbAttributes()) {
  68. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  69. os << formatv(
  70. " using {0} = {1};\n",
  71. attr->getEnumName(), attr->getUnderlyingType()
  72. );
  73. }
  74. }
  75. for (auto &&i : op.getMgbAttributes()) {
  76. auto defaultValue = i.attr.getDefaultValue().str();
  77. if (!defaultValue.empty()) {
  78. defaultValue = formatv(" = {0}", defaultValue);
  79. }
  80. os << formatv(
  81. " {0} {1}{2};\n",
  82. attr_to_ctype(i.attr), i.name, defaultValue
  83. );
  84. }
  85. auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
  86. os << formatv(
  87. " {0}({1}){2}{3}\n",
  88. op.getCppClassName(), paramList, memInitList, body
  89. );
  90. };
  91. gen_ctor("", "", " = default;");
  92. if (!op.getMgbAttributes().empty()) {
  93. std::vector<std::string> paramList, initList;
  94. for (auto &&i : op.getMgbAttributes()) {
  95. paramList.push_back(formatv(
  96. "{0} {1}_", attr_to_ctype(i.attr), i.name
  97. ));
  98. initList.push_back(formatv(
  99. "{0}({0}_)", i.name
  100. ));
  101. }
  102. gen_ctor(llvm::join(paramList, ", "),
  103. ": " + llvm::join(initList, ", "),
  104. " {}");
  105. }
  106. auto packedParams = op.getPackedParams();
  107. if (!packedParams.empty()) {
  108. std::vector<std::string> paramList, initList;
  109. for (auto &&p : packedParams) {
  110. auto&& paramFields = p.getFields();
  111. auto&& paramType = p.getFullName();
  112. auto&& paramName = formatv("packed_param_{0}", paramList.size());
  113. paramList.push_back(
  114. paramFields.empty() ? paramType.str()
  115. : formatv("{0} {1}", paramType, paramName)
  116. );
  117. for (auto&& i : paramFields) {
  118. initList.push_back(formatv(
  119. "{0}({1}.{0})", i.name, paramName
  120. ));
  121. }
  122. }
  123. for (auto&& i : op.getExtraArguments()) {
  124. paramList.push_back(formatv(
  125. "{0} {1}_", attr_to_ctype(i.attr), i.name
  126. ));
  127. initList.push_back(formatv(
  128. "{0}({0}_)", i.name
  129. ));
  130. }
  131. gen_ctor(llvm::join(paramList, ", "),
  132. initList.empty() ? "" : ": " + llvm::join(initList, ", "),
  133. " {}");
  134. }
  135. if (!packedParams.empty()) {
  136. for (auto&& p : packedParams) {
  137. auto accessor = p.getAccessor();
  138. if (!accessor.empty()) {
  139. os << formatv(
  140. " {0} {1}() const {{\n",
  141. p.getFullName(), accessor
  142. );
  143. std::vector<llvm::StringRef> fields;
  144. for (auto&& i : p.getFields()) {
  145. fields.push_back(i.name);
  146. }
  147. os << formatv(
  148. " return {{{0}};\n",
  149. llvm::join(fields, ", ")
  150. );
  151. os << " }\n";
  152. }
  153. }
  154. }
  155. if (auto decl = op.getExtraOpdefDecl()) {
  156. os << decl.getValue();
  157. }
  158. os << formatv(
  159. "};\n\n"
  160. );
  161. }
  162. static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
  163. auto&& className = op.getCppClassName();
  164. os << formatv(
  165. "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className
  166. );
  167. auto formatMethImpl = [&](auto&& meth) {
  168. return formatv(
  169. "{0}_{1}_impl", className, meth
  170. );
  171. };
  172. std::vector<std::string> methods;
  173. if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
  174. os << "namespace {\n";
  175. // generate hash()
  176. mlir::tblgen::FmtContext ctx;
  177. os << formatv(
  178. "size_t {0}(const OpDef& def_) {{\n",
  179. formatMethImpl("hash")
  180. );
  181. os << formatv(
  182. " auto op_ = def_.cast_final_safe<{0}>();\n"
  183. " static_cast<void>(op_);\n",
  184. className
  185. );
  186. ctx.withSelf("op_");
  187. os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
  188. os << "}\n";
  189. // generate is_same_st()
  190. os << formatv(
  191. "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
  192. formatMethImpl("is_same_st")
  193. );
  194. os << formatv(
  195. " auto a_ = lhs_.cast_final_safe<{0}>(),\n"
  196. " b_ = rhs_.cast_final_safe<{0}>();\n"
  197. " static_cast<void>(a_);\n"
  198. " static_cast<void>(b_);\n",
  199. className
  200. );
  201. os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
  202. os << "}\n";
  203. os << "} // anonymous namespace\n";
  204. methods.push_back("hash");
  205. methods.push_back("is_same_st");
  206. }
  207. if (!methods.empty()) {
  208. os << formatv(
  209. "OP_TRAIT_REG({0}, {0})", op.getCppClassName()
  210. );
  211. for (auto&& i : methods) {
  212. os << formatv(
  213. "\n .{0}({1})", i, formatMethImpl(i)
  214. );
  215. }
  216. os << ";\n\n";
  217. }
  218. }
  /// State shared across all ops while emitting pybind11 bindings.
  struct PybindContext {
      /// Maps a tablegen enum record ID to the Python expression
      /// (`<Class>Inst.attr("<Enum>")`) of the first binding emitted for that
      /// enum, so later ops can alias it instead of rebinding the type.
      std::unordered_map<unsigned int, std::string> enumAlias;
  };
  222. static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext& ctx) {
  223. auto class_name = op.getCppClassName();
  224. os << formatv(
  225. "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
  226. class_name
  227. );
  228. for (auto&& i : op.getMgbAttributes()) {
  229. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  230. unsigned int enumID;
  231. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  232. auto&& aliasBase = alias->getAliasBase();
  233. enumID =
  234. llvm::cast<MgbEnumAttr>(aliasBase)
  235. .getBaseRecord()->getID();
  236. } else {
  237. enumID = attr->getBaseRecord()->getID();
  238. }
  239. auto&& enumAlias = ctx.enumAlias;
  240. auto&& iter = enumAlias.find(enumID);
  241. if (iter == enumAlias.end()) {
  242. os << formatv(
  243. "py::enum_<{0}::{1}>({0}Inst, \"{1}\")",
  244. class_name, attr->getEnumName()
  245. );
  246. std::vector<std::string> body;
  247. for (auto&& i: attr->getEnumMembers()) {
  248. os << formatv(
  249. "\n .value(\"{2}\", {0}::{1}::{2})",
  250. class_name, attr->getEnumName(), i
  251. );
  252. body.push_back(formatv(
  253. "if (str == \"{2}\") return {0}::{1}::{2};",
  254. class_name, attr->getEnumName(), i
  255. ));
  256. }
  257. os << formatv(
  258. "\n .def(py::init([](const std::string& in) {"
  259. "\n auto&& str = normalize_enum(in);"
  260. "\n {0}"
  261. "\n throw py::cast_error(\"invalid enum value \" + in);"
  262. "\n }));\n",
  263. llvm::join(body, "\n ")
  264. );
  265. os << formatv(
  266. "py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
  267. class_name, attr->getEnumName()
  268. );
  269. enumAlias.emplace(enumID, formatv(
  270. "{0}Inst.attr(\"{1}\")", class_name, attr->getEnumName()
  271. ));
  272. } else {
  273. os << formatv(
  274. "{0}Inst.attr(\"{1}\") = {2};\n\n",
  275. class_name, attr->getEnumName(), iter->second
  276. );
  277. }
  278. }
  279. }
  280. // generate op class binding
  281. os << formatv("{0}Inst", class_name);
  282. bool hasDefaultCtor = op.getMgbAttributes().empty();
  283. if (!hasDefaultCtor) {
  284. os << "\n .def(py::init<";
  285. std::vector<llvm::StringRef> targs;
  286. for (auto &&i : op.getMgbAttributes()) {
  287. targs.push_back(i.attr.getReturnType());
  288. }
  289. os << llvm::join(targs, ", ");
  290. os << ">()";
  291. for (auto &&i : op.getMgbAttributes()) {
  292. os << formatv(", py::arg(\"{0}\")", i.name);
  293. auto defaultValue = i.attr.getDefaultValue();
  294. if (!defaultValue.empty()) {
  295. os << formatv(" = {0}", defaultValue);
  296. } else {
  297. hasDefaultCtor = true;
  298. }
  299. }
  300. os << ")";
  301. }
  302. if (hasDefaultCtor) {
  303. os << "\n .def(py::init<>())";
  304. }
  305. for (auto &&i : op.getMgbAttributes()) {
  306. os << formatv(
  307. "\n .def_readwrite(\"{0}\", &{1}::{0})",
  308. i.name, class_name
  309. );
  310. }
  311. os << ";\n\n";
  312. }
  313. static void for_each_operator(raw_ostream &os, RecordKeeper &keeper,
  314. std::function<void(raw_ostream&, MgbOp&)> callback) {
  315. auto op_base_class = keeper.getClass("Op");
  316. ASSERT(op_base_class, "could not find base class Op");
  317. for (auto&& i: keeper.getDefs()) {
  318. auto&& r = i.second;
  319. if (r->isSubClassOf(op_base_class)) {
  320. auto op = mlir::tblgen::Operator(r.get());
  321. if (op.getDialectName().str() == "mgb") {
  322. std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl;
  323. callback(os, llvm::cast<MgbOp>(op));
  324. }
  325. }
  326. }
  327. }
  328. static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) {
  329. for_each_operator(os, keeper, gen_op_def_c_header_single);
  330. return false;
  331. }
  332. static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) {
  333. for_each_operator(os, keeper, gen_op_def_c_body_single);
  334. return false;
  335. }
  336. static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) {
  337. PybindContext ctx;
  338. using namespace std::placeholders;
  339. for_each_operator(os, keeper,
  340. std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx)));
  341. return false;
  342. }
  343. int main(int argc, char **argv) {
  344. llvm::InitLLVM y(argc, argv);
  345. llvm::cl::ParseCommandLineOptions(argc, argv);
  346. if (action == ActionType::CppHeader) {
  347. return TableGenMain(argv[0], &gen_op_def_c_header);
  348. }
  349. if (action == ActionType::CppBody) {
  350. return TableGenMain(argv[0], &gen_op_def_c_body);
  351. }
  352. if (action == ActionType::Pybind) {
  353. return TableGenMain(argv[0], &gen_op_def_pybind11);
  354. }
  355. return -1;
  356. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。