You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

cpp_class.cpp 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. /**
  2. * \file imperative/tablegen/targets/cpp_class.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./cpp_class.h"
  12. #include "../emitter.h"
  13. namespace mlir::tblgen {
  14. namespace {
  15. llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
  16. // Note: we have already registered the corresponding attr wrappers
  17. // for following basic ctypes so we needn't handle them here
  18. /* auto&& attr_type_name = attr.getAttrDefName();
  19. if (attr_type_name == "UI32Attr") {
  20. return "uint32_t";
  21. }
  22. if (attr_type_name == "UI64Attr") {
  23. return "uint64_t";
  24. }
  25. if (attr_type_name == "I32Attr") {
  26. return "int32_t";
  27. }
  28. if (attr_type_name == "F32Attr") {
  29. return "float";
  30. }
  31. if (attr_type_name == "F64Attr") {
  32. return "double";
  33. }
  34. if (attr_type_name == "StrAttr") {
  35. return "std::string";
  36. }
  37. if (attr_type_name == "BoolAttr") {
  38. return "bool";
  39. }*/
  40. auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
  41. if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
  42. return e->getEnumName();
  43. }
  44. return attr.getUnderlyingType();
  45. }
  46. class OpDefEmitter final : public EmitterBase {
  47. public:
  48. OpDefEmitter(MgbOp& op_, raw_ostream& os_) : EmitterBase(os_), op(op_) {}
  49. void emit_header();
  50. void emit_tpl_spl();
  51. void emit_body();
  52. private:
  53. MgbOp& op;
  54. };
  55. void OpDefEmitter::emit_header() {
  56. os << formatv(
  57. "class {0} : public OpDefImplBase<{0}> {{\n"
  58. " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
  59. "public:\n",
  60. op.getCppClassName());
  61. // handle enum alias
  62. for (auto&& i : op.getMgbAttributes()) {
  63. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  64. os << formatv(
  65. " using {0} = {1};\n", attr->getEnumName(),
  66. attr->getUnderlyingType());
  67. }
  68. }
  69. for (auto&& i : op.getMgbAttributes()) {
  70. auto defaultValue = i.attr.getDefaultValue().str();
  71. if (!defaultValue.empty()) {
  72. defaultValue = formatv(" = {0}", defaultValue);
  73. }
  74. os << formatv(" {0} {1}{2};\n", attr_to_ctype(i.attr), i.name, defaultValue);
  75. }
  76. auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
  77. os << formatv(
  78. " {0}({1}){2}{3}\n", op.getCppClassName(), paramList, memInitList,
  79. body);
  80. };
  81. gen_ctor("", "", " = default;");
  82. if (!op.getMgbAttributes().empty()) {
  83. std::string strategy_val = "";
  84. std::vector<std::string> paramList, initList;
  85. for (auto&& i : op.getMgbAttributes()) {
  86. if (attr_to_ctype(i.attr).compare("Strategy") == 0) {
  87. strategy_val = i.name;
  88. }
  89. paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name));
  90. initList.push_back(formatv("{0}({0}_)", i.name));
  91. }
  92. paramList.push_back("std::string scope_ = {}");
  93. if (!strategy_val.empty()) {
  94. gen_ctor(
  95. llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "),
  96. formatv(" {"
  97. "\n set_scope(scope_);"
  98. "\n mgb_assert(static_cast<uint32_t>({0}) <= "
  99. "uint32_t(8));"
  100. "\n }",
  101. strategy_val));
  102. } else {
  103. gen_ctor(
  104. llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "),
  105. " { set_scope(scope_); }");
  106. }
  107. }
  108. auto packedParams = op.getPackedParams();
  109. if (!packedParams.empty()) {
  110. std::string strategy_val = "";
  111. std::vector<std::string> paramList, initList;
  112. for (auto&& p : packedParams) {
  113. auto&& paramFields = p.getFields();
  114. auto&& paramType = p.getFullName();
  115. auto&& paramName = formatv("packed_param_{0}", paramList.size());
  116. paramList.push_back(
  117. paramFields.empty() ? paramType.str()
  118. : formatv("{0} {1}", paramType, paramName));
  119. for (auto&& i : paramFields) {
  120. if (i.name.compare("strategy") == 0) {
  121. strategy_val = i.name;
  122. }
  123. initList.push_back(formatv("{0}({1}.{0})", i.name, paramName));
  124. }
  125. }
  126. for (auto&& i : op.getExtraArguments()) {
  127. paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name));
  128. initList.push_back(formatv("{0}({0}_)", i.name));
  129. }
  130. if (!strategy_val.empty()) {
  131. gen_ctor(
  132. llvm::join(paramList, ", "),
  133. initList.empty() ? "" : ": " + llvm::join(initList, ", "),
  134. formatv(" {"
  135. "\n mgb_assert(static_cast<uint32_t>({0}) <= "
  136. "uint32_t(8));"
  137. "\n }",
  138. strategy_val));
  139. } else {
  140. gen_ctor(
  141. llvm::join(paramList, ", "),
  142. initList.empty() ? "" : ": " + llvm::join(initList, ", "), " {}");
  143. }
  144. }
  145. if (!packedParams.empty()) {
  146. for (auto&& p : packedParams) {
  147. auto accessor = p.getAccessor();
  148. if (!accessor.empty()) {
  149. os << formatv(" {0} {1}() const {{\n", p.getFullName(), accessor);
  150. std::vector<llvm::StringRef> fields;
  151. for (auto&& i : p.getFields()) {
  152. fields.push_back(i.name);
  153. }
  154. os << formatv(" return {{{0}};\n", llvm::join(fields, ", "));
  155. os << " }\n";
  156. }
  157. }
  158. }
  159. if (auto decl = op.getExtraOpdefDecl()) {
  160. os << decl.getValue();
  161. }
  162. os << formatv("};\n\n");
  163. }
  164. void OpDefEmitter::emit_tpl_spl() {
  165. for (auto&& i : op.getMgbAttributes()) {
  166. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  167. if (attr->supportToString()) {
  168. std::vector<std::string> case_body;
  169. std::string ename =
  170. formatv("{0}::{1}", op.getCppClassName(), attr->getEnumName());
  171. llvm::for_each(attr->getEnumMembers(), [&](auto&& v) {
  172. size_t d1 = v.find(' ');
  173. size_t d2 = v.find('=');
  174. size_t d = d1 <= d2 ? d1 : d2;
  175. case_body.push_back(formatv(
  176. "case {0}::{1}: return \"{1}\";", ename, v.substr(0, d)));
  177. });
  178. os << formatv(
  179. R"(
  180. template <>
  181. struct ToStringTrait<{0}> {
  182. std::string operator()({0} e) const {
  183. switch (e) {
  184. {1}
  185. default:
  186. return "{0}::Unknown";
  187. }
  188. }
  189. };
  190. )",
  191. ename, llvm::join(case_body, "\n"));
  192. }
  193. }
  194. }
  195. }
  196. void OpDefEmitter::emit_body() {
  197. auto&& className = op.getCppClassName();
  198. os << formatv("MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className);
  199. auto formatMethImpl = [&](auto&& meth) {
  200. return formatv("{0}_{1}_impl", className, meth);
  201. };
  202. std::vector<std::string> methods;
  203. if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
  204. os << "namespace {\n";
  205. // generate hash()
  206. mlir::tblgen::FmtContext ctx;
  207. os << formatv("size_t {0}(const OpDef& def_) {{\n", formatMethImpl("hash"));
  208. os << formatv(
  209. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  210. " static_cast<void>(op_);\n",
  211. className);
  212. ctx.withSelf("op_");
  213. os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
  214. os << "}\n";
  215. // generate is_same_st()
  216. os << formatv(
  217. "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
  218. formatMethImpl("is_same_st"));
  219. os << formatv(
  220. " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
  221. " &&b_ = rhs_.cast_final_safe<{0}>();\n"
  222. " static_cast<void>(a_);\n"
  223. " static_cast<void>(b_);\n",
  224. className);
  225. os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
  226. os << "}\n";
  227. // generate props()
  228. os << formatv(
  229. "std::vector<std::pair<const char*, std::string>> {0}(const OpDef& "
  230. "def_) {{\n",
  231. formatMethImpl("props"));
  232. os << formatv(
  233. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  234. " static_cast<void>(op_);\n",
  235. className);
  236. ctx.withSelf("op_");
  237. os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
  238. os << "}\n";
  239. // generate make_name()
  240. os << formatv(
  241. "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name"));
  242. os << formatv(
  243. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  244. " static_cast<void>(op_);\n",
  245. className);
  246. ctx.withSelf("op_");
  247. os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx);
  248. os << "}\n";
  249. os << "} // anonymous namespace\n";
  250. methods.push_back("hash");
  251. methods.push_back("is_same_st");
  252. methods.push_back("props");
  253. methods.push_back("make_name");
  254. }
  255. if (!methods.empty()) {
  256. os << formatv("OP_TRAIT_REG({0}, {0})", op.getCppClassName());
  257. for (auto&& i : methods) {
  258. os << formatv("\n .{0}({1})", i, formatMethImpl(i));
  259. }
  260. os << ";\n\n";
  261. }
  262. }
  263. } // namespace
  264. bool gen_op_def_c_header(raw_ostream& os, llvm::RecordKeeper& keeper) {
  265. foreach_operator(keeper, [&](MgbOp& op) {
  266. OpDefEmitter emitter(op, os);
  267. emitter.emit_header();
  268. emitter.emit_tpl_spl();
  269. });
  270. return false;
  271. }
  272. bool gen_op_def_c_body(raw_ostream& os, llvm::RecordKeeper& keeper) {
  273. foreach_operator(keeper, [&](MgbOp& op) {
  274. OpDefEmitter emitter(op, os);
  275. emitter.emit_body();
  276. });
  277. return false;
  278. }
  279. } // namespace mlir::tblgen