You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

cpp_class.cpp 9.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. /**
  2. * \file imperative/tablegen/targets/cpp_class.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./cpp_class.h"
  12. #include "../emitter.h"
  13. namespace mlir::tblgen {
  14. namespace {
  15. llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
  16. // Note: we have already registered the corresponding attr wrappers
  17. // for following basic ctypes so we needn't handle them here
  18. /* auto&& attr_type_name = attr.getAttrDefName();
  19. if (attr_type_name == "UI32Attr") {
  20. return "uint32_t";
  21. }
  22. if (attr_type_name == "UI64Attr") {
  23. return "uint64_t";
  24. }
  25. if (attr_type_name == "I32Attr") {
  26. return "int32_t";
  27. }
  28. if (attr_type_name == "F32Attr") {
  29. return "float";
  30. }
  31. if (attr_type_name == "F64Attr") {
  32. return "double";
  33. }
  34. if (attr_type_name == "StrAttr") {
  35. return "std::string";
  36. }
  37. if (attr_type_name == "BoolAttr") {
  38. return "bool";
  39. }*/
  40. auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
  41. if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
  42. return e->getEnumName();
  43. }
  44. return attr.getUnderlyingType();
  45. }
  46. class OpDefEmitter final : public EmitterBase {
  47. public:
  48. OpDefEmitter(MgbOp& op_, raw_ostream& os_) : EmitterBase(os_), op(op_) {}
  49. void emit_header();
  50. void emit_tpl_spl();
  51. void emit_body();
  52. private:
  53. MgbOp& op;
  54. };
  55. void OpDefEmitter::emit_header() {
  56. os << formatv(
  57. "class {0} : public OpDefImplBase<{0}> {{\n"
  58. " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
  59. "public:\n",
  60. op.getCppClassName());
  61. // handle enum alias
  62. for (auto&& i : op.getMgbAttributes()) {
  63. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  64. os << formatv(
  65. " using {0} = {1};\n", attr->getEnumName(),
  66. attr->getUnderlyingType());
  67. }
  68. }
  69. for (auto&& i : op.getMgbAttributes()) {
  70. auto defaultValue = i.attr.getDefaultValue().str();
  71. if (!defaultValue.empty()) {
  72. defaultValue = formatv(" = {0}", defaultValue);
  73. }
  74. os << formatv(" {0} {1}{2};\n", attr_to_ctype(i.attr), i.name, defaultValue);
  75. }
  76. auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
  77. os << formatv(
  78. " {0}({1}){2}{3}\n", op.getCppClassName(), paramList, memInitList,
  79. body);
  80. };
  81. gen_ctor("", "", " = default;");
  82. if (!op.getMgbAttributes().empty()) {
  83. std::vector<std::string> paramList, initList;
  84. for (auto&& i : op.getMgbAttributes()) {
  85. paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name));
  86. initList.push_back(formatv("{0}({0}_)", i.name));
  87. }
  88. paramList.push_back("std::string scope_ = {}");
  89. gen_ctor(
  90. llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "),
  91. " { set_scope(scope_); }");
  92. }
  93. auto packedParams = op.getPackedParams();
  94. if (!packedParams.empty()) {
  95. std::vector<std::string> paramList, initList;
  96. for (auto&& p : packedParams) {
  97. auto&& paramFields = p.getFields();
  98. auto&& paramType = p.getFullName();
  99. auto&& paramName = formatv("packed_param_{0}", paramList.size());
  100. paramList.push_back(
  101. paramFields.empty() ? paramType.str()
  102. : formatv("{0} {1}", paramType, paramName));
  103. for (auto&& i : paramFields) {
  104. initList.push_back(formatv("{0}({1}.{0})", i.name, paramName));
  105. }
  106. }
  107. for (auto&& i : op.getExtraArguments()) {
  108. paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name));
  109. initList.push_back(formatv("{0}({0}_)", i.name));
  110. }
  111. gen_ctor(
  112. llvm::join(paramList, ", "),
  113. initList.empty() ? "" : ": " + llvm::join(initList, ", "), " {}");
  114. }
  115. if (!packedParams.empty()) {
  116. for (auto&& p : packedParams) {
  117. auto accessor = p.getAccessor();
  118. if (!accessor.empty()) {
  119. os << formatv(" {0} {1}() const {{\n", p.getFullName(), accessor);
  120. std::vector<llvm::StringRef> fields;
  121. for (auto&& i : p.getFields()) {
  122. fields.push_back(i.name);
  123. }
  124. os << formatv(" return {{{0}};\n", llvm::join(fields, ", "));
  125. os << " }\n";
  126. }
  127. }
  128. }
  129. if (auto decl = op.getExtraOpdefDecl()) {
  130. os << decl.getValue();
  131. }
  132. os << formatv("};\n\n");
  133. }
  134. void OpDefEmitter::emit_tpl_spl() {
  135. for (auto&& i : op.getMgbAttributes()) {
  136. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  137. if (attr->supportToString()) {
  138. std::vector<std::string> case_body;
  139. std::string ename =
  140. formatv("{0}::{1}", op.getCppClassName(), attr->getEnumName());
  141. llvm::for_each(attr->getEnumMembers(), [&](auto&& v) {
  142. size_t d1 = v.find(' ');
  143. size_t d2 = v.find('=');
  144. size_t d = d1 <= d2 ? d1 : d2;
  145. case_body.push_back(formatv(
  146. "case {0}::{1}: return \"{1}\";", ename, v.substr(0, d)));
  147. });
  148. os << formatv(
  149. R"(
  150. template <>
  151. struct ToStringTrait<{0}> {
  152. std::string operator()({0} e) const {
  153. switch (e) {
  154. {1}
  155. default:
  156. return "{0}::Unknown";
  157. }
  158. }
  159. };
  160. )",
  161. ename, llvm::join(case_body, "\n"));
  162. }
  163. }
  164. }
  165. }
  166. void OpDefEmitter::emit_body() {
  167. auto&& className = op.getCppClassName();
  168. os << formatv("MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className);
  169. auto formatMethImpl = [&](auto&& meth) {
  170. return formatv("{0}_{1}_impl", className, meth);
  171. };
  172. std::vector<std::string> methods;
  173. if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
  174. os << "namespace {\n";
  175. // generate hash()
  176. mlir::tblgen::FmtContext ctx;
  177. os << formatv("size_t {0}(const OpDef& def_) {{\n", formatMethImpl("hash"));
  178. os << formatv(
  179. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  180. " static_cast<void>(op_);\n",
  181. className);
  182. ctx.withSelf("op_");
  183. os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
  184. os << "}\n";
  185. // generate is_same_st()
  186. os << formatv(
  187. "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
  188. formatMethImpl("is_same_st"));
  189. os << formatv(
  190. " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
  191. " &&b_ = rhs_.cast_final_safe<{0}>();\n"
  192. " static_cast<void>(a_);\n"
  193. " static_cast<void>(b_);\n",
  194. className);
  195. os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
  196. os << "}\n";
  197. // generate props()
  198. os << formatv(
  199. "std::vector<std::pair<const char*, std::string>> {0}(const OpDef& "
  200. "def_) {{\n",
  201. formatMethImpl("props"));
  202. os << formatv(
  203. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  204. " static_cast<void>(op_);\n",
  205. className);
  206. ctx.withSelf("op_");
  207. os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
  208. os << "}\n";
  209. // generate make_name()
  210. os << formatv(
  211. "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name"));
  212. os << formatv(
  213. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  214. " static_cast<void>(op_);\n",
  215. className);
  216. ctx.withSelf("op_");
  217. os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx);
  218. os << "}\n";
  219. os << "} // anonymous namespace\n";
  220. methods.push_back("hash");
  221. methods.push_back("is_same_st");
  222. methods.push_back("props");
  223. methods.push_back("make_name");
  224. }
  225. if (!methods.empty()) {
  226. os << formatv("OP_TRAIT_REG({0}, {0})", op.getCppClassName());
  227. for (auto&& i : methods) {
  228. os << formatv("\n .{0}({1})", i, formatMethImpl(i));
  229. }
  230. os << ";\n\n";
  231. }
  232. }
  233. } // namespace
  234. bool gen_op_def_c_header(raw_ostream& os, llvm::RecordKeeper& keeper) {
  235. foreach_operator(keeper, [&](MgbOp& op) {
  236. OpDefEmitter emitter(op, os);
  237. emitter.emit_header();
  238. emitter.emit_tpl_spl();
  239. });
  240. return false;
  241. }
  242. bool gen_op_def_c_body(raw_ostream& os, llvm::RecordKeeper& keeper) {
  243. foreach_operator(keeper, [&](MgbOp& op) {
  244. OpDefEmitter emitter(op, os);
  245. emitter.emit_body();
  246. });
  247. return false;
  248. }
  249. } // namespace mlir::tblgen