You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

helper.h 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  #pragma once

  #include <cstdlib>
  #include <functional>
  #include <iostream>
  #include <string>
  #include <vector>

  #include "llvm/Support/CommandLine.h"
  #include "llvm/Support/FormatVariadic.h"
  #include "llvm/Support/InitLLVM.h"
  #include "llvm/Support/Signals.h"
  #include "llvm/TableGen/Main.h"
  #include "llvm/TableGen/Record.h"
  #include "llvm/TableGen/TableGenBackend.h"
  #include "mlir/TableGen/Attribute.h"
  #include "mlir/TableGen/Format.h"
  #include "mlir/TableGen/Operator.h"

  using llvm::formatv;
  using llvm::Record;
  using llvm::StringRef;
  // Aborts the TableGen backend when `stmt` evaluates to false.
  //
  // On failure it prints "tablegen autogen abort due to: " followed by `msg` to
  // std::cerr, wrapped in the ANSI bold-red escape "\033[1;31m" ... "\033[0m".
  // It then terminates the process with exit status 1.
  //
  // `msg` is streamed unparenthesized on purpose, so callers may pass a stream
  // chain such as ASSERT(ok, "bad value " << v).
  //
  // The do { ... } while (0) wrapper makes `ASSERT(cond, msg);` a single
  // statement. It is therefore safe as the body of an unbraced if/else. The
  // previous bare `if (...) { ... }` expansion turned a following `else` into a
  // compile error.
  #define ASSERT(stmt, msg)                                              \
      do {                                                               \
          if (!(stmt)) {                                                 \
              std::cerr << "\033[1;31m"                                  \
                        << "tablegen autogen abort due to: " << msg      \
                        << "\033[0m" << std::endl;                       \
              std::exit(1);                                              \
          }                                                              \
      } while (0)
  25. namespace mlir {
  26. namespace tblgen {
  // Non-instantiable "view" mixin layered over ConcreteType.
  //
  // Every special member is deleted, so a MgbInterface<T> can never be
  // created, copied, moved, assigned, or destroyed on its own. Instead, this
  // header obtains views by llvm::cast-ing existing ConcreteType objects.
  //
  // The derived view types in this header declare no data members of their
  // own. Their layout is therefore that of ConcreteType.
  template <typename ConcreteType>
  struct MgbInterface : public ConcreteType {
      ~MgbInterface() = delete;

      MgbInterface() = delete;
      MgbInterface(const MgbInterface&) = delete;
      MgbInterface(MgbInterface&&) = delete;

      // Already implicitly deleted (a move constructor is user-declared);
      // spelled out to make the intent explicit.
      MgbInterface& operator=(const MgbInterface&) = delete;
      MgbInterface& operator=(MgbInterface&&) = delete;
  };
  34. struct MgbAttrWrapperBase : public MgbInterface<Attribute> {
  35. private:
  36. struct RecordVisitor : public MgbInterface<Constraint> {
  37. public:
  38. static bool classof(const Constraint*) { return true; }
  39. const llvm::Record* getDef() const { return def; }
  40. };
  41. public:
  42. static bool classof(const Attribute* attr) {
  43. return attr->isSubClassOf("MgbAttrWrapperBase");
  44. }
  45. const llvm::Record* getBaseRecord() const {
  46. auto baseAttr = getBaseAttr();
  47. return llvm::cast<RecordVisitor>(baseAttr).getDef();
  48. }
  49. llvm::StringRef getUnderlyingType() const {
  50. return def->getValueAsString("underlyingType");
  51. }
  52. };
  53. struct MgbEnumAttrMixin : public MgbAttrWrapperBase {
  54. static bool classof(const Attribute* attr) {
  55. return attr->getBaseAttr().isSubClassOf("MgbEnumAttrMixin");
  56. }
  57. llvm::StringRef getParentNamespace() const {
  58. return getBaseRecord()->getValueAsString("parentNamespace");
  59. }
  60. llvm::StringRef getEnumName() const {
  61. return getBaseRecord()->getValueAsString("enumName");
  62. }
  63. std::vector<StringRef> getEnumMembers() const {
  64. return getBaseRecord()->getValueAsListOfStrings("enumMembers");
  65. }
  66. bool supportToString() const {
  67. return getBaseRecord()->getValueAsBit("supportToString");
  68. }
  69. bool getEnumCombinedFlag() const {
  70. return getBaseRecord()->getValueAsBit("enumCombined");
  71. }
  72. };
  73. struct MgbHashableAttrMixin : public MgbAttrWrapperBase {
  74. static bool classof(const Attribute* attr) {
  75. return attr->getBaseAttr().isSubClassOf("MgbHashableAttrMixin");
  76. }
  77. llvm::StringRef getHashFunctionTemplate() const {
  78. return getBaseRecord()->getValueAsString("hashFunction");
  79. }
  80. llvm::StringRef getCmpFunctionTemplate() const {
  81. return getBaseRecord()->getValueAsString("cmpFunction");
  82. }
  83. llvm::StringRef getReprFunctionTemplate() const {
  84. return getBaseRecord()->getValueAsString("reprFunction");
  85. }
  86. };
  87. struct MgbAliasAttrMixin : public MgbAttrWrapperBase {
  88. static bool classof(const Attribute* attr) {
  89. return attr->getBaseAttr().isSubClassOf("MgbAliasAttrMixin");
  90. }
  91. Attribute getAliasBase() const {
  92. return Attribute(getBaseRecord()->getValueAsDef("aliasBase"));
  93. }
  94. };
  95. class MgbPackedParam {
  96. public:
  97. MgbPackedParam(Record* def_) : def(def_) {
  98. auto&& dag = def->getValueAsDag("fields");
  99. for (size_t i = 0; i < dag->getNumArgs(); ++i) {
  100. fields.push_back(
  101. {dag->getArgNameStr(i),
  102. Attribute(llvm::cast<llvm::DefInit>(dag->getArg(i)))});
  103. }
  104. }
  105. llvm::StringRef getFullName() const { return def->getValueAsString("fullName"); }
  106. std::vector<NamedAttribute> getFields() const { return fields; }
  107. llvm::StringRef getAccessor() const {
  108. return def->getValueAsString("paramAccessor");
  109. }
  110. private:
  111. std::vector<NamedAttribute> fields;
  112. Record* def;
  113. };
  114. struct MgbOpBase : public MgbInterface<Operator> {
  115. static bool isPackedParam(Record* def) {
  116. return def->isSubClassOf("MgbPackedParamBase");
  117. }
  118. public:
  119. static bool classof(const Operator* op) {
  120. return op->getDef().isSubClassOf("MgbOp");
  121. }
  122. std::vector<NamedAttribute> getMgbAttributes() const {
  123. std::vector<NamedAttribute> ret;
  124. for (auto&& i : getAttributes()) {
  125. if (isa<MgbAttrWrapperBase>(i.attr)) {
  126. ret.push_back(i);
  127. }
  128. }
  129. return ret;
  130. }
  131. std::vector<NamedAttribute> getExtraArguments() const {
  132. std::vector<NamedAttribute> ret;
  133. auto&& dag = getDef().getValueAsDag("extraArguments");
  134. for (size_t i = 0; i < dag->getNumArgs(); ++i) {
  135. ret.push_back(
  136. {dag->getArgNameStr(i),
  137. Attribute(llvm::cast<llvm::DefInit>(dag->getArg(i)))});
  138. }
  139. return ret;
  140. }
  141. llvm::Optional<StringRef> getExtraOpdefDecl() const {
  142. return getDef().getValueAsOptionalString("extraOpdefDecl");
  143. }
  144. std::vector<MgbPackedParam> getPackedParams() const {
  145. std::vector<MgbPackedParam> ret;
  146. for (auto&& i : getDef().getValueAsListOfDefs("dnnParams")) {
  147. if (isPackedParam(i)) {
  148. ret.emplace_back(i);
  149. }
  150. }
  151. return ret;
  152. }
  153. std::string getNameFunctionTemplate() const {
  154. if (auto f = getDef().getValueAsOptionalString("nameFunction")) {
  155. return f.getValue().str();
  156. }
  157. return formatv(" return \"{0}\";\n", getCppClassName());
  158. }
  159. };
  160. struct MgbHashableOpMixin : public MgbOpBase {
  161. private:
  162. std::string getDefaultHashFunction() const {
  163. std::string body = " size_t val = mgb::hash($_self.dyn_typeinfo());\n";
  164. if (!getMgbAttributes().empty()) {
  165. auto getHashFunc = [&](auto&& iter) {
  166. auto&& attr = llvm::cast<MgbHashableAttrMixin>(iter.attr);
  167. return attr.getHashFunctionTemplate();
  168. };
  169. mlir::tblgen::FmtContext ctx;
  170. for (auto&& it : getMgbAttributes()) {
  171. body +=
  172. formatv(" val = mgb::hash_pair_combine(val, {0});\n",
  173. mlir::tblgen::tgfmt(
  174. getHashFunc(it), &ctx, "$_self." + it.name));
  175. }
  176. }
  177. body += " return val;\n";
  178. return body;
  179. }
  180. std::string getDefaultCmpFunction() const {
  181. std::string body;
  182. if (!getMgbAttributes().empty()) {
  183. mlir::tblgen::FmtContext ctx;
  184. for (auto&& it : getMgbAttributes()) {
  185. auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr);
  186. body +=
  187. formatv(" if ({0}) return false;\n",
  188. mlir::tblgen::tgfmt(
  189. attr.getCmpFunctionTemplate(), &ctx,
  190. "$0." + it.name, "$1." + it.name));
  191. }
  192. }
  193. body += " return true;\n";
  194. return body;
  195. }
  196. std::string getDefaultPropsFunction() const {
  197. std::string body =
  198. " std::vector<std::pair<const char*, std::string>> props_;\n";
  199. if (!getMgbAttributes().empty()) {
  200. mlir::tblgen::FmtContext ctx;
  201. for (auto&& it : getMgbAttributes()) {
  202. if (auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr)) {
  203. body += formatv(" switch ({0}){{\n", "$_self." + it.name);
  204. for (auto&& enumMember : enumAttr->getEnumMembers()) {
  205. size_t d1 = enumMember.find(' ');
  206. size_t d2 = enumMember.find('=');
  207. size_t d = d1 <= d2 ? d1 : d2;
  208. body += formatv(
  209. " case {0}::{1}::{2}:\n", getCppClassName(),
  210. enumAttr->getEnumName(), enumMember.substr(0, d));
  211. body +=
  212. formatv(" props_.emplace_back(\"{0}\", "
  213. "\"{1}\");\n",
  214. it.name, enumMember.substr(0, d));
  215. body += " break;\n";
  216. }
  217. body += " default:\n";
  218. body +=
  219. formatv(" props_.emplace_back(\"{0}\", "
  220. "\"INVALID\");\n",
  221. it.name);
  222. body += " break;\n";
  223. body += " }\n";
  224. } else {
  225. auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr);
  226. body +=
  227. formatv(" props_.emplace_back(\"{0}\", {1});\n", it.name,
  228. mlir::tblgen::tgfmt(
  229. attr.getReprFunctionTemplate(), &ctx,
  230. "$_self." + it.name));
  231. }
  232. }
  233. }
  234. body += " return props_;\n";
  235. return body;
  236. }
  237. public:
  238. static bool classof(const Operator* op) {
  239. return op->getDef().isSubClassOf("MgbHashableOpMixin");
  240. }
  241. std::string getHashFunctionTemplate() const {
  242. if (auto f = getDef().getValueAsOptionalString("hashFunction")) {
  243. return f.getValue().str();
  244. }
  245. return getDefaultHashFunction();
  246. }
  247. std::string getCmpFunctionTemplate() const {
  248. if (auto f = getDef().getValueAsOptionalString("cmpFunction")) {
  249. return f.getValue().str();
  250. }
  251. return getDefaultCmpFunction();
  252. }
  253. std::string getPropsFunctionTemplate() const {
  254. if (auto f = getDef().getValueAsOptionalString("propsFunction")) {
  255. return f.getValue().str();
  256. }
  257. return getDefaultPropsFunction();
  258. }
  259. };
  260. using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
  261. using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
  262. using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin;
  263. using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin;
  264. using MgbOp = mlir::tblgen::MgbOpBase;
  265. using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin;
  266. static inline void foreach_operator(
  267. llvm::RecordKeeper& keeper, std::function<void(MgbOp&)> callback) {
  268. auto op_base_class = keeper.getClass("Op");
  269. ASSERT(op_base_class, "could not find base class Op");
  270. for (auto&& i : keeper.getDefs()) {
  271. auto&& r = i.second;
  272. if (r->isSubClassOf(op_base_class)) {
  273. auto op = mlir::tblgen::Operator(r.get());
  274. if (op.getDialectName().str() == "mgb") {
  275. std::cerr << "\033[34;15m"
  276. << "Generating " << r->getName().str() << "\033[0m"
  277. << std::endl;
  278. callback(llvm::cast<MgbOp>(op));
  279. }
  280. }
  281. }
  282. }
  283. } // namespace tblgen
  284. } // namespace mlir