You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

helper.h 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. /**
  2. * \file imperative/tablegen/helper.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include <iostream>
  13. #include <string>
  14. #include <vector>
  15. #include "llvm/Support/CommandLine.h"
  16. #include "llvm/Support/FormatVariadic.h"
  17. #include "llvm/Support/InitLLVM.h"
  18. #include "llvm/Support/Signals.h"
  19. #include "llvm/TableGen/Main.h"
  20. #include "llvm/TableGen/Record.h"
  21. #include "llvm/TableGen/TableGenBackend.h"
  22. #include "mlir/TableGen/Attribute.h"
  23. #include "mlir/TableGen/Format.h"
  24. #include "mlir/TableGen/Operator.h"
  25. using llvm::formatv;
  26. using llvm::StringRef;
  27. using llvm::Record;
  /**
   * Abort code generation when `stmt` evaluates to false.
   *
   * On failure the macro writes "tablegen autogen abort due to: " followed by
   * `msg` to std::cerr, wrapped in the ANSI escape sequences "\033[1;31m" and
   * "\033[0m", and then calls exit(1).
   *
   * `msg` is spliced into the `<<` chain without parentheses, so it may itself
   * be a chain of streamable operands (e.g. `"bad value " << x`).
   *
   * The body is wrapped in do { ... } while (0) so that the expansion is exactly
   * one statement: `if (c) ASSERT(a, "m"); else ...` now parses as intended.
   * A trailing semicolon at the call site is therefore required.
   */
  #define ASSERT(stmt, msg)                                          \
      do {                                                           \
          if (!(stmt)) {                                             \
              std::cerr << "\033[1;31m"                              \
                        << "tablegen autogen abort due to: " << msg  \
                        << "\033[0m" << std::endl;                   \
              exit(1);                                               \
          }                                                          \
      } while (0)
  35. namespace mlir {
  36. namespace tblgen {
  /**
   * Thin adaptor that inherits from ConcreteType but can never be created,
   * copied, moved or destroyed as an object of its own: the default
   * constructor, copy constructor, move constructor and destructor are all
   * deleted.
   *
   * NOTE(review): presumably objects of derived types are only ever obtained by
   * casting an existing ConcreteType (see the llvm::cast usages in this file).
   */
  template <class ConcreteType>
  struct MgbInterface : public ConcreteType {
      ~MgbInterface() = delete;
      MgbInterface() = delete;
      MgbInterface(MgbInterface&&) = delete;
      MgbInterface(const MgbInterface&) = delete;
  };
  44. struct MgbAttrWrapperBase : public MgbInterface<Attribute> {
  45. private:
  46. struct RecordVisitor : public MgbInterface<Constraint> {
  47. public:
  48. static bool classof(const Constraint*) {
  49. return true;
  50. }
  51. const llvm::Record* getDef() const {
  52. return def;
  53. }
  54. };
  55. public:
  56. static bool classof(const Attribute* attr) {
  57. return attr->isSubClassOf("MgbAttrWrapperBase");
  58. }
  59. const llvm::Record* getBaseRecord() const {
  60. auto baseAttr = getBaseAttr();
  61. return llvm::cast<RecordVisitor>(baseAttr).getDef();
  62. }
  63. llvm::StringRef getUnderlyingType() const {
  64. return def->getValueAsString("underlyingType");
  65. }
  66. };
  67. struct MgbEnumAttrMixin : public MgbAttrWrapperBase {
  68. static bool classof(const Attribute* attr) {
  69. return attr->getBaseAttr().isSubClassOf("MgbEnumAttrMixin");
  70. }
  71. llvm::StringRef getParentNamespace() const {
  72. return getBaseRecord()->getValueAsString("parentNamespace");
  73. }
  74. llvm::StringRef getEnumName() const {
  75. return getBaseRecord()->getValueAsString("enumName");
  76. }
  77. std::vector<StringRef> getEnumMembers() const {
  78. return getBaseRecord()->getValueAsListOfStrings("enumMembers");
  79. }
  80. bool supportToString() const {
  81. return getBaseRecord()->getValueAsBit("supportToString");
  82. }
  83. bool getEnumCombinedFlag() const {
  84. return getBaseRecord()->getValueAsBit("enumCombined");
  85. }
  86. };
  87. struct MgbHashableAttrMixin : public MgbAttrWrapperBase {
  88. static bool classof(const Attribute* attr) {
  89. return attr->getBaseAttr().isSubClassOf("MgbHashableAttrMixin");
  90. }
  91. llvm::StringRef getHashFunctionTemplate() const {
  92. return getBaseRecord()->getValueAsString("hashFunction");
  93. }
  94. llvm::StringRef getCmpFunctionTemplate() const {
  95. return getBaseRecord()->getValueAsString("cmpFunction");
  96. }
  97. llvm::StringRef getReprFunctionTemplate() const {
  98. return getBaseRecord()->getValueAsString("reprFunction");
  99. }
  100. };
  101. struct MgbAliasAttrMixin : public MgbAttrWrapperBase {
  102. static bool classof(const Attribute* attr) {
  103. return attr->getBaseAttr().isSubClassOf("MgbAliasAttrMixin");
  104. }
  105. Attribute getAliasBase() const {
  106. return Attribute(getBaseRecord()->getValueAsDef("aliasBase"));
  107. }
  108. };
  109. class MgbPackedParam {
  110. public:
  111. MgbPackedParam(Record* def_): def(def_) {
  112. auto&& dag = def->getValueAsDag("fields");
  113. for (size_t i = 0; i < dag->getNumArgs(); ++ i) {
  114. fields.push_back({
  115. dag->getArgNameStr(i),
  116. Attribute(llvm::cast<llvm::DefInit>(dag->getArg(i)))
  117. });
  118. }
  119. }
  120. llvm::StringRef getFullName() const {
  121. return def->getValueAsString("fullName");
  122. }
  123. std::vector<NamedAttribute> getFields() const {
  124. return fields;
  125. }
  126. llvm::StringRef getAccessor() const {
  127. return def->getValueAsString("paramAccessor");
  128. }
  129. private:
  130. std::vector<NamedAttribute> fields;
  131. Record* def;
  132. };
  133. struct MgbOpBase : public MgbInterface<Operator> {
  134. static bool isPackedParam(Record* def) {
  135. return def->isSubClassOf("MgbPackedParamBase");
  136. }
  137. public:
  138. static bool classof(const Operator* op) {
  139. return op->getDef().isSubClassOf("MgbOp");
  140. }
  141. std::vector<NamedAttribute> getMgbAttributes() const {
  142. std::vector<NamedAttribute> ret;
  143. for (auto&& i: getAttributes()) {
  144. if (isa<MgbAttrWrapperBase>(i.attr)) {
  145. ret.push_back(i);
  146. }
  147. }
  148. return ret;
  149. }
  150. std::vector<NamedAttribute> getExtraArguments() const {
  151. std::vector<NamedAttribute> ret;
  152. auto&& dag = getDef().getValueAsDag("extraArguments");
  153. for (size_t i = 0; i < dag->getNumArgs(); ++ i) {
  154. ret.push_back({
  155. dag->getArgNameStr(i),
  156. Attribute(llvm::cast<llvm::DefInit>(dag->getArg(i)))
  157. });
  158. }
  159. return ret;
  160. }
  161. llvm::Optional<StringRef> getExtraOpdefDecl() const {
  162. return getDef().getValueAsOptionalString("extraOpdefDecl");
  163. }
  164. std::vector<MgbPackedParam> getPackedParams() const {
  165. std::vector<MgbPackedParam> ret;
  166. for (auto&& i : getDef().getValueAsListOfDefs("dnnParams")) {
  167. if (isPackedParam(i)) {
  168. ret.emplace_back(i);
  169. }
  170. }
  171. return ret;
  172. }
  173. std::string getNameFunctionTemplate() const {
  174. if (auto f = getDef().getValueAsOptionalString("nameFunction")) {
  175. return f.getValue().str();
  176. }
  177. return formatv(" return \"{0}\";\n", getCppClassName());
  178. }
  179. };
  180. struct MgbHashableOpMixin : public MgbOpBase {
  181. private:
  182. std::string getDefaultHashFunction() const {
  183. std::string body = " size_t val = mgb::hash($_self.dyn_typeinfo());\n";
  184. if (!getMgbAttributes().empty()) {
  185. auto getHashFunc = [&](auto&& iter) {
  186. auto&& attr = llvm::cast<MgbHashableAttrMixin>(iter.attr);
  187. return attr.getHashFunctionTemplate();
  188. };
  189. mlir::tblgen::FmtContext ctx;
  190. for (auto&& it: getMgbAttributes()) {
  191. body += formatv(
  192. " val = mgb::hash_pair_combine(val, {0});\n",
  193. mlir::tblgen::tgfmt(getHashFunc(it), &ctx, "$_self." + it.name)
  194. );
  195. }
  196. }
  197. body += " return val;\n";
  198. return body;
  199. }
  200. std::string getDefaultCmpFunction() const {
  201. std::string body;
  202. if (!getMgbAttributes().empty()) {
  203. mlir::tblgen::FmtContext ctx;
  204. for (auto&& it : getMgbAttributes()) {
  205. auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr);
  206. body += formatv(
  207. " if ({0}) return false;\n",
  208. mlir::tblgen::tgfmt(attr.getCmpFunctionTemplate(),
  209. &ctx, "$0." + it.name, "$1." + it.name)
  210. );
  211. }
  212. }
  213. body += " return true;\n";
  214. return body;
  215. }
  216. std::string getDefaultPropsFunction() const {
  217. std::string body = " std::vector<std::pair<const char*, std::string>> props_;\n";
  218. if (!getMgbAttributes().empty()) {
  219. mlir::tblgen::FmtContext ctx;
  220. for (auto&& it : getMgbAttributes()) {
  221. if (auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr)) {
  222. body += formatv(" switch ({0}){{\n", "$_self." + it.name);
  223. for (auto&& enumMember: enumAttr->getEnumMembers()) {
  224. body += formatv(
  225. " case {0}::{1}::{2}:\n",
  226. getCppClassName(), enumAttr->getEnumName(), enumMember
  227. );
  228. body += formatv(
  229. " props_.emplace_back(\"{0}\", \"{1}\");\n",
  230. it.name, enumMember
  231. );
  232. body += " break;\n";
  233. }
  234. body += " default: break;\n";
  235. body += " }\n";
  236. } else {
  237. auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr);
  238. body += formatv(
  239. " props_.emplace_back(\"{0}\", {1});\n", it.name,
  240. mlir::tblgen::tgfmt(attr.getReprFunctionTemplate(),
  241. &ctx, "$_self." + it.name)
  242. );
  243. }
  244. }
  245. }
  246. body += " return props_;\n";
  247. return body;
  248. }
  249. public:
  250. static bool classof(const Operator* op) {
  251. return op->getDef().isSubClassOf("MgbHashableOpMixin");
  252. }
  253. std::string getHashFunctionTemplate() const {
  254. if (auto f = getDef().getValueAsOptionalString("hashFunction")) {
  255. return f.getValue().str();
  256. }
  257. return getDefaultHashFunction();
  258. }
  259. std::string getCmpFunctionTemplate() const {
  260. if (auto f = getDef().getValueAsOptionalString("cmpFunction")) {
  261. return f.getValue().str();
  262. }
  263. return getDefaultCmpFunction();
  264. }
  265. std::string getPropsFunctionTemplate() const {
  266. if (auto f = getDef().getValueAsOptionalString("propsFunction")) {
  267. return f.getValue().str();
  268. }
  269. return getDefaultPropsFunction();
  270. }
  271. };
  272. using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
  273. using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
  274. using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin;
  275. using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin;
  276. using MgbOp = mlir::tblgen::MgbOpBase;
  277. using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin;
  278. static inline void foreach_operator(llvm::RecordKeeper &keeper,
  279. std::function<void(MgbOp&)> callback) {
  280. auto op_base_class = keeper.getClass("Op");
  281. ASSERT(op_base_class, "could not find base class Op");
  282. for (auto&& i: keeper.getDefs()) {
  283. auto&& r = i.second;
  284. if (r->isSubClassOf(op_base_class)) {
  285. auto op = mlir::tblgen::Operator(r.get());
  286. if (op.getDialectName().str() == "mgb") {
  287. std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl;
  288. callback(llvm::cast<MgbOp>(op));
  289. }
  290. }
  291. }
  292. }
  293. } // namespace tblgen
  294. } // namespace mlir

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台