You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

cpp_class.cpp 9.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. /**
  2. * \file imperative/tablegen/targets/cpp_class.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./cpp_class.h"
  12. #include "../emitter.h"
  13. namespace mlir::tblgen {
  14. namespace {
  15. llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
  16. // Note: we have already registered the corresponding attr wrappers
  17. // for following basic ctypes so we needn't handle them here
  18. /* auto&& attr_type_name = attr.getAttrDefName();
  19. if (attr_type_name == "UI32Attr") {
  20. return "uint32_t";
  21. }
  22. if (attr_type_name == "UI64Attr") {
  23. return "uint64_t";
  24. }
  25. if (attr_type_name == "I32Attr") {
  26. return "int32_t";
  27. }
  28. if (attr_type_name == "F32Attr") {
  29. return "float";
  30. }
  31. if (attr_type_name == "F64Attr") {
  32. return "double";
  33. }
  34. if (attr_type_name == "StrAttr") {
  35. return "std::string";
  36. }
  37. if (attr_type_name == "BoolAttr") {
  38. return "bool";
  39. }*/
  40. auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
  41. if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
  42. return e->getEnumName();
  43. }
  44. return attr.getUnderlyingType();
  45. }
// Emits the generated C++ for a single tablegen-described operator:
// the OpDef class declaration, ToStringTrait template specializations
// for its enum attributes, and the out-of-line trait-registration body.
class OpDefEmitter final: public EmitterBase {
public:
    OpDefEmitter(MgbOp& op_, raw_ostream& os_):
        EmitterBase(os_), op(op_) {}
    // Emit the class declaration (enum aliases, fields, ctors, accessors).
    void emit_header();
    // Emit ToStringTrait<> specializations for enum attrs that support it.
    void emit_tpl_spl();
    // Emit hash/is_same_st/props/make_name impls and OP_TRAIT_REG.
    void emit_body();
private:
    MgbOp& op;  // operator being emitted; must outlive this emitter
};
// Emit the C++ class declaration for `op`: enum type aliases, one field
// per attribute, a default ctor, a per-attribute ctor, a packed-param
// ctor with accessors, and any extra user-supplied declarations.
void OpDefEmitter::emit_header() {
    os << formatv(
        "class {0} : public OpDefImplBase<{0}> {{\n"
        " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
        "public:\n",
        op.getCppClassName()
    );
    // handle enum alias: re-export each enum attribute's underlying enum
    // type as a nested alias so callers can write Op::EnumName.
    for (auto &&i : op.getMgbAttributes()) {
        if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
            os << formatv(
                " using {0} = {1};\n",
                attr->getEnumName(), attr->getUnderlyingType()
            );
        }
    }
    // One data member per attribute, carrying the tablegen default value
    // (if any) as an in-class initializer.
    for (auto &&i : op.getMgbAttributes()) {
        auto defaultValue = i.attr.getDefaultValue().str();
        if (!defaultValue.empty()) {
            defaultValue = formatv(" = {0}", defaultValue);
        }
        os << formatv(
            " {0} {1}{2};\n",
            attr_to_ctype(i.attr), i.name, defaultValue
        );
    }
    // Helper: emit one constructor line for the generated class.
    auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
        os << formatv(
            " {0}({1}){2}{3}\n",
            op.getCppClassName(), paramList, memInitList, body
        );
    };
    // Defaulted default constructor.
    gen_ctor("", "", " = default;");
    // Constructor taking every attribute (plus a trailing scope string).
    if (!op.getMgbAttributes().empty()) {
        std::vector<std::string> paramList, initList;
        for (auto &&i : op.getMgbAttributes()) {
            paramList.push_back(formatv(
                "{0} {1}_", attr_to_ctype(i.attr), i.name
            ));
            initList.push_back(formatv(
                "{0}({0}_)", i.name
            ));
        }
        paramList.push_back("std::string scope_ = {}");
        gen_ctor(llvm::join(paramList, ", "),
                 ": " + llvm::join(initList, ", "),
                 " { set_scope(scope_); }");
    }
    // Constructor taking packed parameter structs (plus extra arguments):
    // each packed param is exploded into per-field member initializers.
    auto packedParams = op.getPackedParams();
    if (!packedParams.empty()) {
        std::vector<std::string> paramList, initList;
        for (auto &&p : packedParams) {
            auto&& paramFields = p.getFields();
            auto&& paramType = p.getFullName();
            // paramList.size() doubles as the running packed-param index.
            auto&& paramName = formatv("packed_param_{0}", paramList.size());
            // A fieldless param still appears in the signature (unnamed)
            // so overload resolution can distinguish the ctor.
            paramList.push_back(
                paramFields.empty() ? paramType.str()
                    : formatv("{0} {1}", paramType, paramName)
            );
            for (auto&& i : paramFields) {
                initList.push_back(formatv(
                    "{0}({1}.{0})", i.name, paramName
                ));
            }
        }
        for (auto&& i : op.getExtraArguments()) {
            paramList.push_back(formatv(
                "{0} {1}_", attr_to_ctype(i.attr), i.name
            ));
            initList.push_back(formatv(
                "{0}({0}_)", i.name
            ));
        }
        gen_ctor(llvm::join(paramList, ", "),
                 initList.empty() ? "" : ": " + llvm::join(initList, ", "),
                 " {}");
    }
    // Accessor methods that repack the individual fields back into the
    // packed parameter struct (e.g. `Param param() const`).
    if (!packedParams.empty()) {
        for (auto&& p : packedParams) {
            auto accessor = p.getAccessor();
            if (!accessor.empty()) {
                os << formatv(
                    " {0} {1}() const {{\n",
                    p.getFullName(), accessor
                );
                std::vector<llvm::StringRef> fields;
                for (auto&& i : p.getFields()) {
                    fields.push_back(i.name);
                }
                // "{{{0}}" renders as a braced init list: return {a, b};
                os << formatv(
                    " return {{{0}};\n",
                    llvm::join(fields, ", ")
                );
                os << " }\n";
            }
        }
    }
    // Verbatim extra declarations supplied by the tablegen record.
    if (auto decl = op.getExtraOpdefDecl()) {
        os << decl.getValue();
    }
    os << formatv(
        "};\n\n"
    );
}
// Emit a ToStringTrait<Op::Enum> specialization for every enum attribute
// that supports to-string conversion: a switch mapping each enumerator
// to its (unqualified) name, with an "Unknown" fallback.
void OpDefEmitter::emit_tpl_spl() {
    for (auto &&i : op.getMgbAttributes()) {
        if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
            if (attr->supportToString()) {
                std::vector<std::string> case_body;
                // Fully qualified enum name, e.g. "MyOp::Mode".
                std::string ename = formatv("{0}::{1}",
                    op.getCppClassName(), attr->getEnumName());
                llvm::for_each(attr->getEnumMembers(), [&](auto&& v) {
                    // Strip any trailing " = value" / annotation from the
                    // member spelling: cut at the earlier of ' ' or '='.
                    // If neither is present both are npos and substr
                    // keeps the whole string.
                    size_t d1 = v.find(' ');
                    size_t d2 = v.find('=');
                    size_t d = d1 <= d2 ? d1 : d2;
                    case_body.push_back(
                        formatv("case {0}::{1}: return \"{1}\";", ename,
                            v.substr(0, d)));
                });
                // NOTE(review): llvm::formatv treats '{' as the start of a
                // replacement sequence and literal braces normally need
                // "{{" escaping (as done elsewhere in this file). The bare
                // '{' characters in this raw string may be a transcription
                // artifact of this copy — verify against the original file
                // and llvm::formatv semantics.
                os << formatv(R"(
template <>
struct ToStringTrait<{0}> {
std::string operator()({0} e) const {
switch (e) {
{1}
default:
return "{0}::Unknown";
}
}
};
)", ename, llvm::join(case_body, "\n"));
            }
        }
    }
}
// Emit the out-of-line body for `op`: the dyn-type impl macro and, for
// hashable ops, file-local hash/is_same_st/props/make_name functions
// (filled in from the op's tablegen templates) plus the OP_TRAIT_REG
// registration that wires them up.
void OpDefEmitter::emit_body() {
    auto&& className = op.getCppClassName();
    os << formatv(
        "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className
    );
    // Naming scheme for the generated impl functions: <Op>_<meth>_impl.
    auto formatMethImpl = [&](auto&& meth) {
        return formatv(
            "{0}_{1}_impl", className, meth
        );
    };
    std::vector<std::string> methods;
    if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
        // All impl functions live in an anonymous namespace of the
        // generated file.
        os << "namespace {\n";
        // generate hash()
        mlir::tblgen::FmtContext ctx;
        os << formatv(
            "size_t {0}(const OpDef& def_) {{\n",
            formatMethImpl("hash")
        );
        // static_cast<void> silences unused-variable warnings when the
        // template body does not reference op_.
        os << formatv(
            " auto&& op_ = def_.cast_final_safe<{0}>();\n"
            " static_cast<void>(op_);\n",
            className
        );
        // `$_self` in the tablegen template expands to "op_".
        ctx.withSelf("op_");
        os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
        os << "}\n";
        // generate is_same_st()
        os << formatv(
            "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
            formatMethImpl("is_same_st")
        );
        os << formatv(
            " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
            " &&b_ = rhs_.cast_final_safe<{0}>();\n"
            " static_cast<void>(a_);\n"
            " static_cast<void>(b_);\n",
            className
        );
        // Positional params $0/$1 of the cmp template bind to a_/b_.
        os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
        os << "}\n";
        // generate props()
        os << formatv(
            "std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n",
            formatMethImpl("props")
        );
        os << formatv(
            " auto&& op_ = def_.cast_final_safe<{0}>();\n"
            " static_cast<void>(op_);\n",
            className
        );
        ctx.withSelf("op_");
        os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
        os << "}\n";
        // generate make_name()
        os << formatv(
            "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name")
        );
        os << formatv(
            " auto&& op_ = def_.cast_final_safe<{0}>();\n"
            " static_cast<void>(op_);\n",
            className
        );
        ctx.withSelf("op_");
        os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx);
        os << "}\n";
        os << "} // anonymous namespace\n";
        methods.push_back("hash");
        methods.push_back("is_same_st");
        methods.push_back("props");
        methods.push_back("make_name");
    }
    // Register every generated impl with the op-trait machinery.
    if (!methods.empty()) {
        os << formatv(
            "OP_TRAIT_REG({0}, {0})", op.getCppClassName()
        );
        for (auto&& i : methods) {
            os << formatv(
                "\n .{0}({1})", i, formatMethImpl(i)
            );
        }
        os << ";\n\n";
    }
}
  275. } // namespace
  276. bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper) {
  277. foreach_operator(keeper, [&](MgbOp& op) {
  278. OpDefEmitter emitter(op, os);
  279. emitter.emit_header();
  280. emitter.emit_tpl_spl();
  281. });
  282. return false;
  283. }
  284. bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper) {
  285. foreach_operator(keeper, [&](MgbOp& op) {
  286. OpDefEmitter emitter(op, os);
  287. emitter.emit_body();
  288. });
  289. return false;
  290. }
  291. } // namespace mlir::tblgen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台