You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

cpp_class.cpp 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. #include "./cpp_class.h"
  2. #include "../emitter.h"
  3. namespace mlir::tblgen {
  4. namespace {
  5. llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
  6. // Note: we have already registered the corresponding attr wrappers
  7. // for following basic ctypes so we needn't handle them here
  8. /* auto&& attr_type_name = attr.getAttrDefName();
  9. if (attr_type_name == "UI32Attr") {
  10. return "uint32_t";
  11. }
  12. if (attr_type_name == "UI64Attr") {
  13. return "uint64_t";
  14. }
  15. if (attr_type_name == "I32Attr") {
  16. return "int32_t";
  17. }
  18. if (attr_type_name == "F32Attr") {
  19. return "float";
  20. }
  21. if (attr_type_name == "F64Attr") {
  22. return "double";
  23. }
  24. if (attr_type_name == "StrAttr") {
  25. return "std::string";
  26. }
  27. if (attr_type_name == "BoolAttr") {
  28. return "bool";
  29. }*/
  30. auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
  31. if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
  32. return e->getEnumName();
  33. }
  34. return attr.getUnderlyingType();
  35. }
  36. class OpDefEmitter final : public EmitterBase {
  37. public:
  38. OpDefEmitter(MgbOp& op_, raw_ostream& os_) : EmitterBase(os_), op(op_) {}
  39. void emit_header();
  40. void emit_tpl_spl();
  41. void emit_body();
  42. private:
  43. MgbOp& op;
  44. };
  45. void OpDefEmitter::emit_header() {
  46. os << formatv(
  47. "class {0} : public OpDefImplBase<{0}> {{\n"
  48. " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
  49. "public:\n",
  50. op.getCppClassName());
  51. // handle enum alias
  52. for (auto&& i : op.getMgbAttributes()) {
  53. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  54. os << formatv(
  55. " using {0} = {1};\n", attr->getEnumName(),
  56. attr->getUnderlyingType());
  57. }
  58. }
  59. for (auto&& i : op.getMgbAttributes()) {
  60. auto defaultValue = i.attr.getDefaultValue().str();
  61. if (!defaultValue.empty()) {
  62. defaultValue = formatv(" = {0}", defaultValue);
  63. }
  64. os << formatv(" {0} {1}{2};\n", attr_to_ctype(i.attr), i.name, defaultValue);
  65. }
  66. auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
  67. os << formatv(
  68. " {0}({1}){2}{3}\n", op.getCppClassName(), paramList, memInitList,
  69. body);
  70. };
  71. gen_ctor("", "", " = default;");
  72. if (!op.getMgbAttributes().empty()) {
  73. std::string strategy_val = "";
  74. std::vector<std::string> paramList, initList;
  75. for (auto&& i : op.getMgbAttributes()) {
  76. if (attr_to_ctype(i.attr).compare("Strategy") == 0) {
  77. strategy_val = i.name;
  78. }
  79. paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name));
  80. initList.push_back(formatv("{0}({0}_)", i.name));
  81. }
  82. paramList.push_back("std::string scope_ = {}");
  83. if (!strategy_val.empty()) {
  84. gen_ctor(
  85. llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "),
  86. formatv(" {"
  87. "\n set_scope(scope_);"
  88. "\n mgb_assert(static_cast<uint32_t>({0}) <= "
  89. "uint32_t(8));"
  90. "\n }",
  91. strategy_val));
  92. } else {
  93. gen_ctor(
  94. llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "),
  95. " { set_scope(scope_); }");
  96. }
  97. }
  98. auto packedParams = op.getPackedParams();
  99. if (!packedParams.empty()) {
  100. std::string strategy_val = "";
  101. std::vector<std::string> paramList, initList;
  102. for (auto&& p : packedParams) {
  103. auto&& paramFields = p.getFields();
  104. auto&& paramType = p.getFullName();
  105. auto&& paramName = formatv("packed_param_{0}", paramList.size());
  106. paramList.push_back(
  107. paramFields.empty() ? paramType.str()
  108. : formatv("{0} {1}", paramType, paramName));
  109. for (auto&& i : paramFields) {
  110. if (i.name.compare("strategy") == 0) {
  111. strategy_val = i.name;
  112. }
  113. initList.push_back(formatv("{0}({1}.{0})", i.name, paramName));
  114. }
  115. }
  116. for (auto&& i : op.getExtraArguments()) {
  117. paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name));
  118. initList.push_back(formatv("{0}({0}_)", i.name));
  119. }
  120. if (!strategy_val.empty()) {
  121. gen_ctor(
  122. llvm::join(paramList, ", "),
  123. initList.empty() ? "" : ": " + llvm::join(initList, ", "),
  124. formatv(" {"
  125. "\n mgb_assert(static_cast<uint32_t>({0}) <= "
  126. "uint32_t(8));"
  127. "\n }",
  128. strategy_val));
  129. } else {
  130. gen_ctor(
  131. llvm::join(paramList, ", "),
  132. initList.empty() ? "" : ": " + llvm::join(initList, ", "), " {}");
  133. }
  134. }
  135. if (!packedParams.empty()) {
  136. for (auto&& p : packedParams) {
  137. auto accessor = p.getAccessor();
  138. if (!accessor.empty()) {
  139. os << formatv(" {0} {1}() const {{\n", p.getFullName(), accessor);
  140. std::vector<llvm::StringRef> fields;
  141. for (auto&& i : p.getFields()) {
  142. fields.push_back(i.name);
  143. }
  144. os << formatv(" return {{{0}};\n", llvm::join(fields, ", "));
  145. os << " }\n";
  146. }
  147. }
  148. }
  149. if (auto decl = op.getExtraOpdefDecl()) {
  150. os << decl.getValue();
  151. }
  152. os << formatv("};\n\n");
  153. }
  154. void OpDefEmitter::emit_tpl_spl() {
  155. for (auto&& i : op.getMgbAttributes()) {
  156. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  157. if (attr->supportToString()) {
  158. std::vector<std::string> case_body;
  159. std::string ename =
  160. formatv("{0}::{1}", op.getCppClassName(), attr->getEnumName());
  161. llvm::for_each(attr->getEnumMembers(), [&](auto&& v) {
  162. size_t d1 = v.find(' ');
  163. size_t d2 = v.find('=');
  164. size_t d = d1 <= d2 ? d1 : d2;
  165. case_body.push_back(formatv(
  166. "case {0}::{1}: return \"{1}\";", ename, v.substr(0, d)));
  167. });
  168. os << formatv(
  169. R"(
  170. template <>
  171. struct ToStringTrait<{0}> {
  172. std::string operator()({0} e) const {
  173. switch (e) {
  174. {1}
  175. default:
  176. return "{0}::Unknown";
  177. }
  178. }
  179. };
  180. )",
  181. ename, llvm::join(case_body, "\n"));
  182. }
  183. }
  184. }
  185. }
  186. void OpDefEmitter::emit_body() {
  187. auto&& className = op.getCppClassName();
  188. os << formatv("MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className);
  189. auto formatMethImpl = [&](auto&& meth) {
  190. return formatv("{0}_{1}_impl", className, meth);
  191. };
  192. std::vector<std::string> methods;
  193. if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
  194. os << "namespace {\n";
  195. // generate hash()
  196. mlir::tblgen::FmtContext ctx;
  197. os << formatv("size_t {0}(const OpDef& def_) {{\n", formatMethImpl("hash"));
  198. os << formatv(
  199. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  200. " static_cast<void>(op_);\n",
  201. className);
  202. ctx.withSelf("op_");
  203. os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
  204. os << "}\n";
  205. // generate is_same_st()
  206. os << formatv(
  207. "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
  208. formatMethImpl("is_same_st"));
  209. os << formatv(
  210. " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
  211. " &&b_ = rhs_.cast_final_safe<{0}>();\n"
  212. " static_cast<void>(a_);\n"
  213. " static_cast<void>(b_);\n",
  214. className);
  215. os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
  216. os << "}\n";
  217. // generate props()
  218. os << formatv(
  219. "std::vector<std::pair<const char*, std::string>> {0}(const OpDef& "
  220. "def_) {{\n",
  221. formatMethImpl("props"));
  222. os << formatv(
  223. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  224. " static_cast<void>(op_);\n",
  225. className);
  226. ctx.withSelf("op_");
  227. os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
  228. os << "}\n";
  229. // generate make_name()
  230. os << formatv(
  231. "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name"));
  232. os << formatv(
  233. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  234. " static_cast<void>(op_);\n",
  235. className);
  236. ctx.withSelf("op_");
  237. os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx);
  238. os << "}\n";
  239. os << "} // anonymous namespace\n";
  240. methods.push_back("hash");
  241. methods.push_back("is_same_st");
  242. methods.push_back("props");
  243. methods.push_back("make_name");
  244. }
  245. if (!methods.empty()) {
  246. os << formatv("OP_TRAIT_REG({0}, {0})", op.getCppClassName());
  247. for (auto&& i : methods) {
  248. os << formatv("\n .{0}({1})", i, formatMethImpl(i));
  249. }
  250. os << ";\n\n";
  251. }
  252. }
  253. } // namespace
  254. bool gen_op_def_c_header(raw_ostream& os, llvm::RecordKeeper& keeper) {
  255. foreach_operator(keeper, [&](MgbOp& op) {
  256. OpDefEmitter emitter(op, os);
  257. emitter.emit_header();
  258. emitter.emit_tpl_spl();
  259. });
  260. return false;
  261. }
  262. bool gen_op_def_c_body(raw_ostream& os, llvm::RecordKeeper& keeper) {
  263. foreach_operator(keeper, [&](MgbOp& op) {
  264. OpDefEmitter emitter(op, os);
  265. emitter.emit_body();
  266. });
  267. return false;
  268. }
  269. } // namespace mlir::tblgen