You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

helper.h 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. /**
  2. * \file imperative/tablegen/helper.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  #include <cstdlib>
  #include <functional>
  #include <iostream>
  #include <string>
  #include <vector>
  #include "llvm/Support/CommandLine.h"
  #include "llvm/Support/FormatVariadic.h"
  #include "llvm/Support/InitLLVM.h"
  #include "llvm/Support/Signals.h"
  #include "llvm/TableGen/Main.h"
  #include "llvm/TableGen/Record.h"
  #include "llvm/TableGen/TableGenBackend.h"
  #include "mlir/TableGen/Attribute.h"
  #include "mlir/TableGen/Format.h"
  #include "mlir/TableGen/Operator.h"
  25. using llvm::formatv;
  26. using llvm::Record;
  27. using llvm::StringRef;
  // Aborts the tablegen run with exit code 1, after printing a red
  // "tablegen autogen abort due to: <msg>" line to std::cerr, when `stmt` is
  // false.
  //
  // `msg` is spliced unparenthesized into a `std::cerr <<` chain, so callers may
  // pass a single value or several `<<`-separated operands.
  //
  // The body is wrapped in do { ... } while (0) so that `ASSERT(...);` is exactly
  // one statement and stays correct inside an unbraced if/else.
  #define ASSERT(stmt, msg)                                              \
      do {                                                               \
          if (!(stmt)) {                                                 \
              std::cerr << "\033[1;31m"                                  \
                        << "tablegen autogen abort due to: " << msg      \
                        << "\033[0m" << std::endl;                       \
              std::exit(1);                                              \
          }                                                              \
      } while (0)
  35. namespace mlir {
  36. namespace tblgen {
  /// Never-instantiated mixin that layers extra accessors on top of an existing
  /// mlir::tblgen type (`ConcreteType`).
  ///
  /// The default, copy and move constructors and the destructor are all
  /// deleted, so objects of this type are never created directly. Instead, an
  /// existing `ConcreteType` object is viewed through a derived wrapper with
  /// `llvm::cast`, which consults that wrapper's `classof`.
  template <typename ConcreteType>
  struct MgbInterface : public ConcreteType {
      ~MgbInterface() = delete;
      MgbInterface() = delete;
      MgbInterface(MgbInterface&&) = delete;
      MgbInterface(const MgbInterface&) = delete;
  };
  44. struct MgbAttrWrapperBase : public MgbInterface<Attribute> {
  45. private:
  46. struct RecordVisitor : public MgbInterface<Constraint> {
  47. public:
  48. static bool classof(const Constraint*) { return true; }
  49. const llvm::Record* getDef() const { return def; }
  50. };
  51. public:
  52. static bool classof(const Attribute* attr) {
  53. return attr->isSubClassOf("MgbAttrWrapperBase");
  54. }
  55. const llvm::Record* getBaseRecord() const {
  56. auto baseAttr = getBaseAttr();
  57. return llvm::cast<RecordVisitor>(baseAttr).getDef();
  58. }
  59. llvm::StringRef getUnderlyingType() const {
  60. return def->getValueAsString("underlyingType");
  61. }
  62. };
  63. struct MgbEnumAttrMixin : public MgbAttrWrapperBase {
  64. static bool classof(const Attribute* attr) {
  65. return attr->getBaseAttr().isSubClassOf("MgbEnumAttrMixin");
  66. }
  67. llvm::StringRef getParentNamespace() const {
  68. return getBaseRecord()->getValueAsString("parentNamespace");
  69. }
  70. llvm::StringRef getEnumName() const {
  71. return getBaseRecord()->getValueAsString("enumName");
  72. }
  73. std::vector<StringRef> getEnumMembers() const {
  74. return getBaseRecord()->getValueAsListOfStrings("enumMembers");
  75. }
  76. bool supportToString() const {
  77. return getBaseRecord()->getValueAsBit("supportToString");
  78. }
  79. bool getEnumCombinedFlag() const {
  80. return getBaseRecord()->getValueAsBit("enumCombined");
  81. }
  82. };
  83. struct MgbHashableAttrMixin : public MgbAttrWrapperBase {
  84. static bool classof(const Attribute* attr) {
  85. return attr->getBaseAttr().isSubClassOf("MgbHashableAttrMixin");
  86. }
  87. llvm::StringRef getHashFunctionTemplate() const {
  88. return getBaseRecord()->getValueAsString("hashFunction");
  89. }
  90. llvm::StringRef getCmpFunctionTemplate() const {
  91. return getBaseRecord()->getValueAsString("cmpFunction");
  92. }
  93. llvm::StringRef getReprFunctionTemplate() const {
  94. return getBaseRecord()->getValueAsString("reprFunction");
  95. }
  96. };
  97. struct MgbAliasAttrMixin : public MgbAttrWrapperBase {
  98. static bool classof(const Attribute* attr) {
  99. return attr->getBaseAttr().isSubClassOf("MgbAliasAttrMixin");
  100. }
  101. Attribute getAliasBase() const {
  102. return Attribute(getBaseRecord()->getValueAsDef("aliasBase"));
  103. }
  104. };
  105. class MgbPackedParam {
  106. public:
  107. MgbPackedParam(Record* def_) : def(def_) {
  108. auto&& dag = def->getValueAsDag("fields");
  109. for (size_t i = 0; i < dag->getNumArgs(); ++i) {
  110. fields.push_back(
  111. {dag->getArgNameStr(i),
  112. Attribute(llvm::cast<llvm::DefInit>(dag->getArg(i)))});
  113. }
  114. }
  115. llvm::StringRef getFullName() const { return def->getValueAsString("fullName"); }
  116. std::vector<NamedAttribute> getFields() const { return fields; }
  117. llvm::StringRef getAccessor() const {
  118. return def->getValueAsString("paramAccessor");
  119. }
  120. private:
  121. std::vector<NamedAttribute> fields;
  122. Record* def;
  123. };
  124. struct MgbOpBase : public MgbInterface<Operator> {
  125. static bool isPackedParam(Record* def) {
  126. return def->isSubClassOf("MgbPackedParamBase");
  127. }
  128. public:
  129. static bool classof(const Operator* op) {
  130. return op->getDef().isSubClassOf("MgbOp");
  131. }
  132. std::vector<NamedAttribute> getMgbAttributes() const {
  133. std::vector<NamedAttribute> ret;
  134. for (auto&& i : getAttributes()) {
  135. if (isa<MgbAttrWrapperBase>(i.attr)) {
  136. ret.push_back(i);
  137. }
  138. }
  139. return ret;
  140. }
  141. std::vector<NamedAttribute> getExtraArguments() const {
  142. std::vector<NamedAttribute> ret;
  143. auto&& dag = getDef().getValueAsDag("extraArguments");
  144. for (size_t i = 0; i < dag->getNumArgs(); ++i) {
  145. ret.push_back(
  146. {dag->getArgNameStr(i),
  147. Attribute(llvm::cast<llvm::DefInit>(dag->getArg(i)))});
  148. }
  149. return ret;
  150. }
  151. llvm::Optional<StringRef> getExtraOpdefDecl() const {
  152. return getDef().getValueAsOptionalString("extraOpdefDecl");
  153. }
  154. std::vector<MgbPackedParam> getPackedParams() const {
  155. std::vector<MgbPackedParam> ret;
  156. for (auto&& i : getDef().getValueAsListOfDefs("dnnParams")) {
  157. if (isPackedParam(i)) {
  158. ret.emplace_back(i);
  159. }
  160. }
  161. return ret;
  162. }
  163. std::string getNameFunctionTemplate() const {
  164. if (auto f = getDef().getValueAsOptionalString("nameFunction")) {
  165. return f.getValue().str();
  166. }
  167. return formatv(" return \"{0}\";\n", getCppClassName());
  168. }
  169. };
  170. struct MgbHashableOpMixin : public MgbOpBase {
  171. private:
  172. std::string getDefaultHashFunction() const {
  173. std::string body = " size_t val = mgb::hash($_self.dyn_typeinfo());\n";
  174. if (!getMgbAttributes().empty()) {
  175. auto getHashFunc = [&](auto&& iter) {
  176. auto&& attr = llvm::cast<MgbHashableAttrMixin>(iter.attr);
  177. return attr.getHashFunctionTemplate();
  178. };
  179. mlir::tblgen::FmtContext ctx;
  180. for (auto&& it : getMgbAttributes()) {
  181. body +=
  182. formatv(" val = mgb::hash_pair_combine(val, {0});\n",
  183. mlir::tblgen::tgfmt(
  184. getHashFunc(it), &ctx, "$_self." + it.name));
  185. }
  186. }
  187. body += " return val;\n";
  188. return body;
  189. }
  190. std::string getDefaultCmpFunction() const {
  191. std::string body;
  192. if (!getMgbAttributes().empty()) {
  193. mlir::tblgen::FmtContext ctx;
  194. for (auto&& it : getMgbAttributes()) {
  195. auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr);
  196. body +=
  197. formatv(" if ({0}) return false;\n",
  198. mlir::tblgen::tgfmt(
  199. attr.getCmpFunctionTemplate(), &ctx,
  200. "$0." + it.name, "$1." + it.name));
  201. }
  202. }
  203. body += " return true;\n";
  204. return body;
  205. }
  206. std::string getDefaultPropsFunction() const {
  207. std::string body =
  208. " std::vector<std::pair<const char*, std::string>> props_;\n";
  209. if (!getMgbAttributes().empty()) {
  210. mlir::tblgen::FmtContext ctx;
  211. for (auto&& it : getMgbAttributes()) {
  212. if (auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr)) {
  213. body += formatv(" switch ({0}){{\n", "$_self." + it.name);
  214. for (auto&& enumMember : enumAttr->getEnumMembers()) {
  215. size_t d1 = enumMember.find(' ');
  216. size_t d2 = enumMember.find('=');
  217. size_t d = d1 <= d2 ? d1 : d2;
  218. body += formatv(
  219. " case {0}::{1}::{2}:\n", getCppClassName(),
  220. enumAttr->getEnumName(), enumMember.substr(0, d));
  221. body +=
  222. formatv(" props_.emplace_back(\"{0}\", "
  223. "\"{1}\");\n",
  224. it.name, enumMember.substr(0, d));
  225. body += " break;\n";
  226. }
  227. body += " default: break;\n";
  228. body += " }\n";
  229. } else {
  230. auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr);
  231. body +=
  232. formatv(" props_.emplace_back(\"{0}\", {1});\n", it.name,
  233. mlir::tblgen::tgfmt(
  234. attr.getReprFunctionTemplate(), &ctx,
  235. "$_self." + it.name));
  236. }
  237. }
  238. }
  239. body += " return props_;\n";
  240. return body;
  241. }
  242. public:
  243. static bool classof(const Operator* op) {
  244. return op->getDef().isSubClassOf("MgbHashableOpMixin");
  245. }
  246. std::string getHashFunctionTemplate() const {
  247. if (auto f = getDef().getValueAsOptionalString("hashFunction")) {
  248. return f.getValue().str();
  249. }
  250. return getDefaultHashFunction();
  251. }
  252. std::string getCmpFunctionTemplate() const {
  253. if (auto f = getDef().getValueAsOptionalString("cmpFunction")) {
  254. return f.getValue().str();
  255. }
  256. return getDefaultCmpFunction();
  257. }
  258. std::string getPropsFunctionTemplate() const {
  259. if (auto f = getDef().getValueAsOptionalString("propsFunction")) {
  260. return f.getValue().str();
  261. }
  262. return getDefaultPropsFunction();
  263. }
  264. };
  265. using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
  266. using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
  267. using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin;
  268. using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin;
  269. using MgbOp = mlir::tblgen::MgbOpBase;
  270. using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin;
  271. static inline void foreach_operator(
  272. llvm::RecordKeeper& keeper, std::function<void(MgbOp&)> callback) {
  273. auto op_base_class = keeper.getClass("Op");
  274. ASSERT(op_base_class, "could not find base class Op");
  275. for (auto&& i : keeper.getDefs()) {
  276. auto&& r = i.second;
  277. if (r->isSubClassOf(op_base_class)) {
  278. auto op = mlir::tblgen::Operator(r.get());
  279. if (op.getDialectName().str() == "mgb") {
  280. std::cerr << "\033[34;15m"
  281. << "Generating " << r->getName().str() << "\033[0m"
  282. << std::endl;
  283. callback(llvm::cast<MgbOp>(op));
  284. }
  285. }
  286. }
  287. }
  288. } // namespace tblgen
  289. } // namespace mlir