You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pybind11.cpp 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. #include "./pybind11.h"
  2. #include "../emitter.h"
  3. namespace mlir::tblgen {
  4. namespace {
  5. class OpDefEmitter final : public EmitterBase {
  6. public:
  7. OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_)
  8. : EmitterBase(os_, env_), op(op_) {}
  9. void emit();
  10. private:
  11. MgbOp& op;
  12. };
  13. void OpDefEmitter::emit() {
  14. auto className = op.getCppClassName();
  15. os << formatv(
  16. "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
  17. className);
  18. for (auto&& i : op.getMgbAttributes()) {
  19. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  20. unsigned int enumID;
  21. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  22. auto&& aliasBase = alias->getAliasBase();
  23. enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
  24. } else {
  25. enumID = attr->getBaseRecord()->getID();
  26. }
  27. auto&& enumAlias = env().enumAlias;
  28. auto&& iter = enumAlias.find(enumID);
  29. if (iter == enumAlias.end()) {
  30. os << formatv(
  31. "py::enum_<{0}::{1}>({0}Inst, \"{1}\")", className,
  32. attr->getEnumName());
  33. std::vector<std::string> body;
  34. for (auto&& i : attr->getEnumMembers()) {
  35. size_t d1 = i.find(' ');
  36. size_t d2 = i.find('=');
  37. size_t d = d1 <= d2 ? d1 : d2;
  38. os << formatv(
  39. "\n .value(\"{2}\", {0}::{1}::{2})", className,
  40. attr->getEnumName(), i.substr(0, d));
  41. body.push_back(
  42. formatv("if (str == \"{2}\") return {0}::{1}::{2};",
  43. className, attr->getEnumName(), i.substr(0, d)));
  44. }
  45. if (attr->getEnumCombinedFlag()) {
  46. //! define operator |
  47. os << formatv(
  48. "\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ "
  49. "\n return static_cast<{0}::{1}>(uint32_t(s0) | "
  50. "uint32_t(s1));"
  51. "\n })",
  52. className, attr->getEnumName());
  53. //! define operator &
  54. os << formatv(
  55. "\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{"
  56. "\n return static_cast<{0}::{1}>(uint32_t(s0) & "
  57. "uint32_t(s1));"
  58. "\n })",
  59. className, attr->getEnumName());
  60. }
  61. os << formatv(
  62. "\n .def(py::init([](const std::string& in) {"
  63. "\n auto&& str = normalize_enum(in);"
  64. "\n {0}"
  65. "\n throw py::cast_error(\"invalid enum value \" + in);"
  66. "\n }));\n",
  67. llvm::join(body, "\n "));
  68. os << formatv(
  69. "py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
  70. className, attr->getEnumName());
  71. enumAlias.emplace(
  72. enumID, std::make_pair(className, attr->getEnumName()));
  73. } else {
  74. os << formatv(
  75. "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", className,
  76. attr->getEnumName(), iter->second.first, iter->second.second);
  77. }
  78. }
  79. }
  80. // generate op class binding
  81. os << formatv("{0}Inst", className);
  82. bool hasDefaultCtor = op.getMgbAttributes().empty();
  83. if (!hasDefaultCtor) {
  84. os << "\n .def(py::init<";
  85. std::vector<llvm::StringRef> targs;
  86. for (auto&& i : op.getMgbAttributes()) {
  87. targs.push_back(i.attr.getReturnType());
  88. }
  89. os << llvm::join(targs, ", ");
  90. os << ", std::string>()";
  91. for (auto&& i : op.getMgbAttributes()) {
  92. os << formatv(", py::arg(\"{0}\")", i.name);
  93. auto defaultValue = i.attr.getDefaultValue();
  94. if (!defaultValue.empty()) {
  95. os << formatv(" = {0}", defaultValue);
  96. } else {
  97. hasDefaultCtor = true;
  98. }
  99. }
  100. os << ", py::arg(\"scope\") = {})";
  101. }
  102. if (hasDefaultCtor) {
  103. os << "\n .def(py::init<>())";
  104. }
  105. for (auto&& i : op.getMgbAttributes()) {
  106. os << formatv("\n .def_readwrite(\"{0}\", &{1}::{0})", i.name, className);
  107. }
  108. os << ";\n\n";
  109. }
  110. } // namespace
  111. bool gen_op_def_pybind11(raw_ostream& os, llvm::RecordKeeper& keeper) {
  112. Environment env;
  113. using namespace std::placeholders;
  114. foreach_operator(keeper, [&](MgbOp& op) { OpDefEmitter(op, os, env).emit(); });
  115. return false;
  116. }
  117. } // namespace mlir::tblgen