You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pybind11.cpp 5.4 kB

(line numbers 1–142)
  1. /**
  2. * \file imperative/tablegen/targets/pybind11.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./pybind11.h"
  12. #include "../emitter.h"
  13. namespace mlir::tblgen {
  14. namespace {
  15. class OpDefEmitter final: public EmitterBase {
  16. public:
  17. OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_):
  18. EmitterBase(os_, env_), op(op_) {}
  19. void emit();
  20. private:
  21. MgbOp& op;
  22. };
  23. void OpDefEmitter::emit() {
  24. auto className = op.getCppClassName();
  25. os << formatv(
  26. "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
  27. className
  28. );
  29. for (auto&& i : op.getMgbAttributes()) {
  30. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  31. unsigned int enumID;
  32. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  33. auto&& aliasBase = alias->getAliasBase();
  34. enumID =
  35. llvm::cast<MgbEnumAttr>(aliasBase)
  36. .getBaseRecord()->getID();
  37. } else {
  38. enumID = attr->getBaseRecord()->getID();
  39. }
  40. auto&& enumAlias = env().enumAlias;
  41. auto&& iter = enumAlias.find(enumID);
  42. if (iter == enumAlias.end()) {
  43. os << formatv(
  44. "py::enum_<{0}::{1}>({0}Inst, \"{1}\")",
  45. className, attr->getEnumName()
  46. );
  47. std::vector<std::string> body;
  48. for (auto&& i: attr->getEnumMembers()) {
  49. os << formatv(
  50. "\n .value(\"{2}\", {0}::{1}::{2})",
  51. className, attr->getEnumName(), i
  52. );
  53. body.push_back(formatv(
  54. "if (str == \"{2}\") return {0}::{1}::{2};",
  55. className, attr->getEnumName(), i
  56. ));
  57. }
  58. if (attr->getEnumCombinedFlag()) {
  59. //! define operator |
  60. os << formatv(
  61. "\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ "
  62. "\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));"
  63. "\n })",
  64. className, attr->getEnumName());
  65. //! define operator &
  66. os << formatv(
  67. "\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{"
  68. "\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));"
  69. "\n })",
  70. className, attr->getEnumName());
  71. }
  72. os << formatv(
  73. "\n .def(py::init([](const std::string& in) {"
  74. "\n auto&& str = normalize_enum(in);"
  75. "\n {0}"
  76. "\n throw py::cast_error(\"invalid enum value \" + in);"
  77. "\n }));\n",
  78. llvm::join(body, "\n ")
  79. );
  80. os << formatv(
  81. "py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
  82. className, attr->getEnumName()
  83. );
  84. enumAlias.emplace(enumID,
  85. std::make_pair(className, attr->getEnumName()));
  86. } else {
  87. os << formatv(
  88. "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n",
  89. className, attr->getEnumName(),
  90. iter->second.first, iter->second.second
  91. );
  92. }
  93. }
  94. }
  95. // generate op class binding
  96. os << formatv("{0}Inst", className);
  97. bool hasDefaultCtor = op.getMgbAttributes().empty();
  98. if (!hasDefaultCtor) {
  99. os << "\n .def(py::init<";
  100. std::vector<llvm::StringRef> targs;
  101. for (auto &&i : op.getMgbAttributes()) {
  102. targs.push_back(i.attr.getReturnType());
  103. }
  104. os << llvm::join(targs, ", ");
  105. os << ", std::string>()";
  106. for (auto &&i : op.getMgbAttributes()) {
  107. os << formatv(", py::arg(\"{0}\")", i.name);
  108. auto defaultValue = i.attr.getDefaultValue();
  109. if (!defaultValue.empty()) {
  110. os << formatv(" = {0}", defaultValue);
  111. } else {
  112. hasDefaultCtor = true;
  113. }
  114. }
  115. os << ", py::arg(\"scope\") = {})";
  116. }
  117. if (hasDefaultCtor) {
  118. os << "\n .def(py::init<>())";
  119. }
  120. for (auto &&i : op.getMgbAttributes()) {
  121. os << formatv(
  122. "\n .def_readwrite(\"{0}\", &{1}::{0})",
  123. i.name, className
  124. );
  125. }
  126. os << ";\n\n";
  127. }
  128. } // namespace
  129. bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper) {
  130. Environment env;
  131. using namespace std::placeholders;
  132. foreach_operator(keeper, [&](MgbOp& op) {
  133. OpDefEmitter(op, os, env).emit();
  134. });
  135. return false;
  136. }
  137. } // namespace mlir::tblgen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。