You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

helper.h 9.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  #include <cstdlib>
  #include <iostream>
  #include <string>
  #include <vector>

  #include "llvm/Support/CommandLine.h"
  #include "llvm/Support/FormatVariadic.h"
  #include "llvm/Support/InitLLVM.h"
  #include "llvm/Support/Signals.h"
  #include "llvm/TableGen/Main.h"
  #include "llvm/TableGen/Record.h"
  #include "llvm/TableGen/TableGenBackend.h"
  #include "mlir/TableGen/Attribute.h"
  #include "mlir/TableGen/Format.h"
  #include "mlir/TableGen/Operator.h"

  using llvm::formatv;
  using llvm::StringRef;
  using llvm::Record;
  // ASSERT(stmt, msg): if `stmt` evaluates to false, print `msg` to stderr
  // between the ANSI escapes "\033[1;31m" / "\033[0m" (bold red) and terminate
  // the process with exit status 1.
  //
  // `msg` is streamed into std::cerr without parentheses on purpose, so a
  // caller may pass a stream chain such as `"bad value: " << v`.
  //
  // The body is wrapped in do { ... } while (0) so `ASSERT(...);` is exactly one
  // statement. The previous bare `if (...) { ... }` expansion broke unbraced
  // if/else: `if (c) ASSERT(x, m); else ...` left the `else` without an `if`.
  #define ASSERT(stmt, msg) \
      do { \
          if (!(stmt)) { \
              std::cerr << "\033[1;31m" \
                        << "tablegen autogen abort due to: " << msg \
                        << "\033[0m" << std::endl; \
              std::exit(1); \
          } \
      } while (0)
  23. namespace mlir {
  24. namespace tblgen {
  /// Cast-only extension of `ConcreteType`.
  ///
  /// The default, copy and move constructors and the destructor are all
  /// deleted, so an MgbInterface<T> can never be created, copied, moved or
  /// destroyed by itself. The helpers derived from it in this header add only
  /// member functions (no data members). They are reached by viewing an
  /// existing `ConcreteType` object through their `classof`-based
  /// llvm::cast / llvm::dyn_cast support.
  template <typename ConcreteType>
  struct MgbInterface : ConcreteType {
      ~MgbInterface() = delete;
      MgbInterface(MgbInterface&&) = delete;
      MgbInterface(const MgbInterface&) = delete;
      MgbInterface() = delete;
  };
  32. struct MgbAttrWrapperBase : public MgbInterface<Attribute> {
  33. private:
  34. struct RecordVisitor : public MgbInterface<Constraint> {
  35. public:
  36. static bool classof(const Constraint*) {
  37. return true;
  38. }
  39. const llvm::Record* getDef() const {
  40. return def;
  41. }
  42. };
  43. public:
  44. static bool classof(const Attribute* attr) {
  45. return attr->isSubClassOf("MgbAttrWrapperBase");
  46. }
  47. const llvm::Record* getBaseRecord() const {
  48. auto baseAttr = getBaseAttr();
  49. return llvm::cast<RecordVisitor>(baseAttr).getDef();
  50. }
  51. llvm::StringRef getUnderlyingType() const {
  52. return def->getValueAsString("underlyingType");
  53. }
  54. };
  55. struct MgbEnumAttrMixin : public MgbAttrWrapperBase {
  56. static bool classof(const Attribute* attr) {
  57. return attr->getBaseAttr().isSubClassOf("MgbEnumAttrMixin");
  58. }
  59. llvm::StringRef getParentNamespace() const {
  60. return getBaseRecord()->getValueAsString("parentNamespace");
  61. }
  62. llvm::StringRef getEnumName() const {
  63. return getBaseRecord()->getValueAsString("enumName");
  64. }
  65. std::vector<StringRef> getEnumMembers() const {
  66. return getBaseRecord()->getValueAsListOfStrings("enumMembers");
  67. }
  68. bool supportToString() const {
  69. return getBaseRecord()->getValueAsBit("supportToString");
  70. }
  71. bool getEnumCombinedFlag() const {
  72. return getBaseRecord()->getValueAsBit("enumCombined");
  73. }
  74. };
  75. struct MgbHashableAttrMixin : public MgbAttrWrapperBase {
  76. static bool classof(const Attribute* attr) {
  77. return attr->getBaseAttr().isSubClassOf("MgbHashableAttrMixin");
  78. }
  79. llvm::StringRef getHashFunctionTemplate() const {
  80. return getBaseRecord()->getValueAsString("hashFunction");
  81. }
  82. llvm::StringRef getCmpFunctionTemplate() const {
  83. return getBaseRecord()->getValueAsString("cmpFunction");
  84. }
  85. llvm::StringRef getReprFunctionTemplate() const {
  86. return getBaseRecord()->getValueAsString("reprFunction");
  87. }
  88. };
  89. struct MgbAliasAttrMixin : public MgbAttrWrapperBase {
  90. static bool classof(const Attribute* attr) {
  91. return attr->getBaseAttr().isSubClassOf("MgbAliasAttrMixin");
  92. }
  93. Attribute getAliasBase() const {
  94. return Attribute(getBaseRecord()->getValueAsDef("aliasBase"));
  95. }
  96. };
  97. class MgbPackedParam {
  98. public:
  99. MgbPackedParam(Record* def_): def(def_) {
  100. auto&& dag = def->getValueAsDag("fields");
  101. for (size_t i = 0; i < dag->getNumArgs(); ++ i) {
  102. fields.push_back({
  103. dag->getArgNameStr(i),
  104. Attribute(llvm::cast<llvm::DefInit>(dag->getArg(i)))
  105. });
  106. }
  107. }
  108. llvm::StringRef getFullName() const {
  109. return def->getValueAsString("fullName");
  110. }
  111. std::vector<NamedAttribute> getFields() const {
  112. return fields;
  113. }
  114. llvm::StringRef getAccessor() const {
  115. return def->getValueAsString("paramAccessor");
  116. }
  117. private:
  118. std::vector<NamedAttribute> fields;
  119. Record* def;
  120. };
  121. struct MgbOpBase : public MgbInterface<Operator> {
  122. static bool isPackedParam(Record* def) {
  123. return def->isSubClassOf("MgbPackedParamBase");
  124. }
  125. public:
  126. static bool classof(const Operator* op) {
  127. return op->getDef().isSubClassOf("MgbOp");
  128. }
  129. std::vector<NamedAttribute> getMgbAttributes() const {
  130. std::vector<NamedAttribute> ret;
  131. for (auto&& i: getAttributes()) {
  132. if (isa<MgbAttrWrapperBase>(i.attr)) {
  133. ret.push_back(i);
  134. }
  135. }
  136. return ret;
  137. }
  138. std::vector<NamedAttribute> getExtraArguments() const {
  139. std::vector<NamedAttribute> ret;
  140. auto&& dag = getDef().getValueAsDag("extraArguments");
  141. for (size_t i = 0; i < dag->getNumArgs(); ++ i) {
  142. ret.push_back({
  143. dag->getArgNameStr(i),
  144. Attribute(llvm::cast<llvm::DefInit>(dag->getArg(i)))
  145. });
  146. }
  147. return ret;
  148. }
  149. llvm::Optional<StringRef> getExtraOpdefDecl() const {
  150. return getDef().getValueAsOptionalString("extraOpdefDecl");
  151. }
  152. std::vector<MgbPackedParam> getPackedParams() const {
  153. std::vector<MgbPackedParam> ret;
  154. for (auto&& i : getDef().getValueAsListOfDefs("dnnParams")) {
  155. if (isPackedParam(i)) {
  156. ret.emplace_back(i);
  157. }
  158. }
  159. return ret;
  160. }
  161. std::string getNameFunctionTemplate() const {
  162. if (auto f = getDef().getValueAsOptionalString("nameFunction")) {
  163. return f.getValue().str();
  164. }
  165. return formatv(" return \"{0}\";\n", getCppClassName());
  166. }
  167. };
  168. struct MgbHashableOpMixin : public MgbOpBase {
  169. private:
  170. std::string getDefaultHashFunction() const {
  171. std::string body = " size_t val = mgb::hash($_self.dyn_typeinfo());\n";
  172. if (!getMgbAttributes().empty()) {
  173. auto getHashFunc = [&](auto&& iter) {
  174. auto&& attr = llvm::cast<MgbHashableAttrMixin>(iter.attr);
  175. return attr.getHashFunctionTemplate();
  176. };
  177. mlir::tblgen::FmtContext ctx;
  178. for (auto&& it: getMgbAttributes()) {
  179. body += formatv(
  180. " val = mgb::hash_pair_combine(val, {0});\n",
  181. mlir::tblgen::tgfmt(getHashFunc(it), &ctx, "$_self." + it.name)
  182. );
  183. }
  184. }
  185. body += " return val;\n";
  186. return body;
  187. }
  188. std::string getDefaultCmpFunction() const {
  189. std::string body;
  190. if (!getMgbAttributes().empty()) {
  191. mlir::tblgen::FmtContext ctx;
  192. for (auto&& it : getMgbAttributes()) {
  193. auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr);
  194. body += formatv(
  195. " if ({0}) return false;\n",
  196. mlir::tblgen::tgfmt(attr.getCmpFunctionTemplate(),
  197. &ctx, "$0." + it.name, "$1." + it.name)
  198. );
  199. }
  200. }
  201. body += " return true;\n";
  202. return body;
  203. }
  204. std::string getDefaultPropsFunction() const {
  205. std::string body = " std::vector<std::pair<const char*, std::string>> props_;\n";
  206. if (!getMgbAttributes().empty()) {
  207. mlir::tblgen::FmtContext ctx;
  208. for (auto&& it : getMgbAttributes()) {
  209. if (auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr)) {
  210. body += formatv(" switch ({0}){{\n", "$_self." + it.name);
  211. for (auto&& enumMember: enumAttr->getEnumMembers()) {
  212. body += formatv(
  213. " case {0}::{1}::{2}:\n",
  214. getCppClassName(), enumAttr->getEnumName(), enumMember
  215. );
  216. body += formatv(
  217. " props_.emplace_back(\"{0}\", \"{1}\");\n",
  218. it.name, enumMember
  219. );
  220. body += " break;\n";
  221. }
  222. body += " default: break;\n";
  223. body += " }\n";
  224. } else {
  225. auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr);
  226. body += formatv(
  227. " props_.emplace_back(\"{0}\", {1});\n", it.name,
  228. mlir::tblgen::tgfmt(attr.getReprFunctionTemplate(),
  229. &ctx, "$_self." + it.name)
  230. );
  231. }
  232. }
  233. }
  234. body += " return props_;\n";
  235. return body;
  236. }
  237. public:
  238. static bool classof(const Operator* op) {
  239. return op->getDef().isSubClassOf("MgbHashableOpMixin");
  240. }
  241. std::string getHashFunctionTemplate() const {
  242. if (auto f = getDef().getValueAsOptionalString("hashFunction")) {
  243. return f.getValue().str();
  244. }
  245. return getDefaultHashFunction();
  246. }
  247. std::string getCmpFunctionTemplate() const {
  248. if (auto f = getDef().getValueAsOptionalString("cmpFunction")) {
  249. return f.getValue().str();
  250. }
  251. return getDefaultCmpFunction();
  252. }
  253. std::string getPropsFunctionTemplate() const {
  254. if (auto f = getDef().getValueAsOptionalString("propsFunction")) {
  255. return f.getValue().str();
  256. }
  257. return getDefaultPropsFunction();
  258. }
  259. };
  260. } // namespace tblgen
  261. } // namespace mlir

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。