You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pybind11.cpp 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. /**
  2. * \file imperative/tablegen/targets/pybind11.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./pybind11.h"
  12. #include "../emitter.h"
  13. namespace mlir::tblgen {
  14. namespace {
  15. class OpDefEmitter final : public EmitterBase {
  16. public:
  17. OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_)
  18. : EmitterBase(os_, env_), op(op_) {}
  19. void emit();
  20. private:
  21. MgbOp& op;
  22. };
  23. void OpDefEmitter::emit() {
  24. auto className = op.getCppClassName();
  25. os << formatv(
  26. "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
  27. className);
  28. for (auto&& i : op.getMgbAttributes()) {
  29. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  30. unsigned int enumID;
  31. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  32. auto&& aliasBase = alias->getAliasBase();
  33. enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
  34. } else {
  35. enumID = attr->getBaseRecord()->getID();
  36. }
  37. auto&& enumAlias = env().enumAlias;
  38. auto&& iter = enumAlias.find(enumID);
  39. if (iter == enumAlias.end()) {
  40. os << formatv(
  41. "py::enum_<{0}::{1}>({0}Inst, \"{1}\")", className,
  42. attr->getEnumName());
  43. std::vector<std::string> body;
  44. for (auto&& i : attr->getEnumMembers()) {
  45. size_t d1 = i.find(' ');
  46. size_t d2 = i.find('=');
  47. size_t d = d1 <= d2 ? d1 : d2;
  48. os << formatv(
  49. "\n .value(\"{2}\", {0}::{1}::{2})", className,
  50. attr->getEnumName(), i.substr(0, d));
  51. body.push_back(
  52. formatv("if (str == \"{2}\") return {0}::{1}::{2};",
  53. className, attr->getEnumName(), i.substr(0, d)));
  54. }
  55. if (attr->getEnumCombinedFlag()) {
  56. //! define operator |
  57. os << formatv(
  58. "\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ "
  59. "\n return static_cast<{0}::{1}>(uint32_t(s0) | "
  60. "uint32_t(s1));"
  61. "\n })",
  62. className, attr->getEnumName());
  63. //! define operator &
  64. os << formatv(
  65. "\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{"
  66. "\n return static_cast<{0}::{1}>(uint32_t(s0) & "
  67. "uint32_t(s1));"
  68. "\n })",
  69. className, attr->getEnumName());
  70. }
  71. os << formatv(
  72. "\n .def(py::init([](const std::string& in) {"
  73. "\n auto&& str = normalize_enum(in);"
  74. "\n {0}"
  75. "\n throw py::cast_error(\"invalid enum value \" + in);"
  76. "\n }));\n",
  77. llvm::join(body, "\n "));
  78. os << formatv(
  79. "py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
  80. className, attr->getEnumName());
  81. enumAlias.emplace(
  82. enumID, std::make_pair(className, attr->getEnumName()));
  83. } else {
  84. os << formatv(
  85. "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", className,
  86. attr->getEnumName(), iter->second.first, iter->second.second);
  87. }
  88. }
  89. }
  90. // generate op class binding
  91. os << formatv("{0}Inst", className);
  92. bool hasDefaultCtor = op.getMgbAttributes().empty();
  93. if (!hasDefaultCtor) {
  94. os << "\n .def(py::init<";
  95. std::vector<llvm::StringRef> targs;
  96. for (auto&& i : op.getMgbAttributes()) {
  97. targs.push_back(i.attr.getReturnType());
  98. }
  99. os << llvm::join(targs, ", ");
  100. os << ", std::string>()";
  101. for (auto&& i : op.getMgbAttributes()) {
  102. os << formatv(", py::arg(\"{0}\")", i.name);
  103. auto defaultValue = i.attr.getDefaultValue();
  104. if (!defaultValue.empty()) {
  105. os << formatv(" = {0}", defaultValue);
  106. } else {
  107. hasDefaultCtor = true;
  108. }
  109. }
  110. os << ", py::arg(\"scope\") = {})";
  111. }
  112. if (hasDefaultCtor) {
  113. os << "\n .def(py::init<>())";
  114. }
  115. for (auto&& i : op.getMgbAttributes()) {
  116. os << formatv("\n .def_readwrite(\"{0}\", &{1}::{0})", i.name, className);
  117. }
  118. os << ";\n\n";
  119. }
  120. } // namespace
  121. bool gen_op_def_pybind11(raw_ostream& os, llvm::RecordKeeper& keeper) {
  122. Environment env;
  123. using namespace std::placeholders;
  124. foreach_operator(keeper, [&](MgbOp& op) { OpDefEmitter(op, os, env).emit(); });
  125. return false;
  126. }
  127. } // namespace mlir::tblgen