You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cpp_class.cpp 9.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. /**
  2. * \file imperative/tablegen/targets/cpp_class.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./cpp_class.h"
  12. #include "../emitter.h"
  13. namespace mlir::tblgen {
  14. namespace {
  15. llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
  16. // Note: we have already registered the corresponding attr wrappers
  17. // for following basic ctypes so we needn't handle them here
  18. /* auto&& attr_type_name = attr.getAttrDefName();
  19. if (attr_type_name == "UI32Attr") {
  20. return "uint32_t";
  21. }
  22. if (attr_type_name == "UI64Attr") {
  23. return "uint64_t";
  24. }
  25. if (attr_type_name == "I32Attr") {
  26. return "int32_t";
  27. }
  28. if (attr_type_name == "F32Attr") {
  29. return "float";
  30. }
  31. if (attr_type_name == "F64Attr") {
  32. return "double";
  33. }
  34. if (attr_type_name == "StrAttr") {
  35. return "std::string";
  36. }
  37. if (attr_type_name == "BoolAttr") {
  38. return "bool";
  39. }*/
  40. auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
  41. if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
  42. return e->getEnumName();
  43. }
  44. return attr.getUnderlyingType();
  45. }
  46. class OpDefEmitter final: public EmitterBase {
  47. public:
  48. OpDefEmitter(MgbOp& op_, raw_ostream& os_):
  49. EmitterBase(os_), op(op_) {}
  50. void emit_header();
  51. void emit_tpl_spl();
  52. void emit_body();
  53. private:
  54. MgbOp& op;
  55. };
  56. void OpDefEmitter::emit_header() {
  57. os << formatv(
  58. "class {0} : public OpDefImplBase<{0}> {{\n"
  59. " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
  60. "public:\n",
  61. op.getCppClassName()
  62. );
  63. // handle enum alias
  64. for (auto &&i : op.getMgbAttributes()) {
  65. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  66. os << formatv(
  67. " using {0} = {1};\n",
  68. attr->getEnumName(), attr->getUnderlyingType()
  69. );
  70. }
  71. }
  72. for (auto &&i : op.getMgbAttributes()) {
  73. auto defaultValue = i.attr.getDefaultValue().str();
  74. if (!defaultValue.empty()) {
  75. defaultValue = formatv(" = {0}", defaultValue);
  76. }
  77. os << formatv(
  78. " {0} {1}{2};\n",
  79. attr_to_ctype(i.attr), i.name, defaultValue
  80. );
  81. }
  82. auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
  83. os << formatv(
  84. " {0}({1}){2}{3}\n",
  85. op.getCppClassName(), paramList, memInitList, body
  86. );
  87. };
  88. gen_ctor("", "", " = default;");
  89. if (!op.getMgbAttributes().empty()) {
  90. std::vector<std::string> paramList, initList;
  91. for (auto &&i : op.getMgbAttributes()) {
  92. paramList.push_back(formatv(
  93. "{0} {1}_", attr_to_ctype(i.attr), i.name
  94. ));
  95. initList.push_back(formatv(
  96. "{0}({0}_)", i.name
  97. ));
  98. }
  99. paramList.push_back("std::string scope_ = {}");
  100. gen_ctor(llvm::join(paramList, ", "),
  101. ": " + llvm::join(initList, ", "),
  102. " { set_scope(scope_); }");
  103. }
  104. auto packedParams = op.getPackedParams();
  105. if (!packedParams.empty()) {
  106. std::vector<std::string> paramList, initList;
  107. for (auto &&p : packedParams) {
  108. auto&& paramFields = p.getFields();
  109. auto&& paramType = p.getFullName();
  110. auto&& paramName = formatv("packed_param_{0}", paramList.size());
  111. paramList.push_back(
  112. paramFields.empty() ? paramType.str()
  113. : formatv("{0} {1}", paramType, paramName)
  114. );
  115. for (auto&& i : paramFields) {
  116. initList.push_back(formatv(
  117. "{0}({1}.{0})", i.name, paramName
  118. ));
  119. }
  120. }
  121. for (auto&& i : op.getExtraArguments()) {
  122. paramList.push_back(formatv(
  123. "{0} {1}_", attr_to_ctype(i.attr), i.name
  124. ));
  125. initList.push_back(formatv(
  126. "{0}({0}_)", i.name
  127. ));
  128. }
  129. gen_ctor(llvm::join(paramList, ", "),
  130. initList.empty() ? "" : ": " + llvm::join(initList, ", "),
  131. " {}");
  132. }
  133. if (!packedParams.empty()) {
  134. for (auto&& p : packedParams) {
  135. auto accessor = p.getAccessor();
  136. if (!accessor.empty()) {
  137. os << formatv(
  138. " {0} {1}() const {{\n",
  139. p.getFullName(), accessor
  140. );
  141. std::vector<llvm::StringRef> fields;
  142. for (auto&& i : p.getFields()) {
  143. fields.push_back(i.name);
  144. }
  145. os << formatv(
  146. " return {{{0}};\n",
  147. llvm::join(fields, ", ")
  148. );
  149. os << " }\n";
  150. }
  151. }
  152. }
  153. if (auto decl = op.getExtraOpdefDecl()) {
  154. os << decl.getValue();
  155. }
  156. os << formatv(
  157. "};\n\n"
  158. );
  159. }
  160. void OpDefEmitter::emit_tpl_spl() {
  161. for (auto &&i : op.getMgbAttributes()) {
  162. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  163. if (attr->supportToString()) {
  164. std::vector<std::string> case_body;
  165. std::string ename = formatv("{0}::{1}",
  166. op.getCppClassName(), attr->getEnumName());
  167. llvm::for_each(attr->getEnumMembers(), [&](auto&& v){
  168. case_body.push_back(formatv(
  169. "case {0}::{1}: return \"{1}\";", ename, v));
  170. });
  171. os << formatv(R"(
  172. template <>
  173. struct ToStringTrait<{0}> {
  174. std::string operator()({0} e) const {
  175. switch (e) {
  176. {1}
  177. default:
  178. return "{0}::Unknown";
  179. }
  180. }
  181. };
  182. )", ename, llvm::join(case_body, "\n"));
  183. }
  184. }
  185. }
  186. }
  187. void OpDefEmitter::emit_body() {
  188. auto&& className = op.getCppClassName();
  189. os << formatv(
  190. "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className
  191. );
  192. auto formatMethImpl = [&](auto&& meth) {
  193. return formatv(
  194. "{0}_{1}_impl", className, meth
  195. );
  196. };
  197. std::vector<std::string> methods;
  198. if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
  199. os << "namespace {\n";
  200. // generate hash()
  201. mlir::tblgen::FmtContext ctx;
  202. os << formatv(
  203. "size_t {0}(const OpDef& def_) {{\n",
  204. formatMethImpl("hash")
  205. );
  206. os << formatv(
  207. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  208. " static_cast<void>(op_);\n",
  209. className
  210. );
  211. ctx.withSelf("op_");
  212. os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
  213. os << "}\n";
  214. // generate is_same_st()
  215. os << formatv(
  216. "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
  217. formatMethImpl("is_same_st")
  218. );
  219. os << formatv(
  220. " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
  221. " &&b_ = rhs_.cast_final_safe<{0}>();\n"
  222. " static_cast<void>(a_);\n"
  223. " static_cast<void>(b_);\n",
  224. className
  225. );
  226. os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
  227. os << "}\n";
  228. // generate props()
  229. os << formatv(
  230. "std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n",
  231. formatMethImpl("props")
  232. );
  233. os << formatv(
  234. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  235. " static_cast<void>(op_);\n",
  236. className
  237. );
  238. ctx.withSelf("op_");
  239. os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
  240. os << "}\n";
  241. // generate make_name()
  242. os << formatv(
  243. "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name")
  244. );
  245. os << formatv(
  246. " auto&& op_ = def_.cast_final_safe<{0}>();\n"
  247. " static_cast<void>(op_);\n",
  248. className
  249. );
  250. ctx.withSelf("op_");
  251. os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx);
  252. os << "}\n";
  253. os << "} // anonymous namespace\n";
  254. methods.push_back("hash");
  255. methods.push_back("is_same_st");
  256. methods.push_back("props");
  257. methods.push_back("make_name");
  258. }
  259. if (!methods.empty()) {
  260. os << formatv(
  261. "OP_TRAIT_REG({0}, {0})", op.getCppClassName()
  262. );
  263. for (auto&& i : methods) {
  264. os << formatv(
  265. "\n .{0}({1})", i, formatMethImpl(i)
  266. );
  267. }
  268. os << ";\n\n";
  269. }
  270. }
  271. } // namespace
  272. bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper) {
  273. foreach_operator(keeper, [&](MgbOp& op) {
  274. OpDefEmitter emitter(op, os);
  275. emitter.emit_header();
  276. emitter.emit_tpl_spl();
  277. });
  278. return false;
  279. }
  280. bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper) {
  281. foreach_operator(keeper, [&](MgbOp& op) {
  282. OpDefEmitter emitter(op, os);
  283. emitter.emit_body();
  284. });
  285. return false;
  286. }
  287. } // namespace mlir::tblgen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台