You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

pybind11.cpp 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. /**
  2. * \file imperative/tablegen/targets/pybind11.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./pybind11.h"
  12. #include "../emitter.h"
  13. namespace mlir::tblgen {
  14. namespace {
  15. class OpDefEmitter final: public EmitterBase {
  16. public:
  17. OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_):
  18. EmitterBase(os_, env_), op(op_) {}
  19. void emit();
  20. private:
  21. MgbOp& op;
  22. };
  23. void OpDefEmitter::emit() {
  24. auto className = op.getCppClassName();
  25. os << formatv(
  26. "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
  27. className
  28. );
  29. for (auto&& i : op.getMgbAttributes()) {
  30. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  31. unsigned int enumID;
  32. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  33. auto&& aliasBase = alias->getAliasBase();
  34. enumID =
  35. llvm::cast<MgbEnumAttr>(aliasBase)
  36. .getBaseRecord()->getID();
  37. } else {
  38. enumID = attr->getBaseRecord()->getID();
  39. }
  40. auto&& enumAlias = env().enumAlias;
  41. auto&& iter = enumAlias.find(enumID);
  42. if (iter == enumAlias.end()) {
  43. os << formatv(
  44. "py::enum_<{0}::{1}>({0}Inst, \"{1}\")",
  45. className, attr->getEnumName()
  46. );
  47. std::vector<std::string> body;
  48. for (auto&& i: attr->getEnumMembers()) {
  49. size_t d1 = i.find(' ');
  50. size_t d2 = i.find('=');
  51. size_t d = d1 <= d2 ? d1 : d2;
  52. os << formatv("\n .value(\"{2}\", {0}::{1}::{2})",
  53. className, attr->getEnumName(),
  54. i.substr(0, d));
  55. body.push_back(formatv(
  56. "if (str == \"{2}\") return {0}::{1}::{2};",
  57. className, attr->getEnumName(), i.substr(0, d)));
  58. }
  59. if (attr->getEnumCombinedFlag()) {
  60. //! define operator |
  61. os << formatv(
  62. "\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ "
  63. "\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));"
  64. "\n })",
  65. className, attr->getEnumName());
  66. //! define operator &
  67. os << formatv(
  68. "\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{"
  69. "\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));"
  70. "\n })",
  71. className, attr->getEnumName());
  72. }
  73. os << formatv(
  74. "\n .def(py::init([](const std::string& in) {"
  75. "\n auto&& str = normalize_enum(in);"
  76. "\n {0}"
  77. "\n throw py::cast_error(\"invalid enum value \" + in);"
  78. "\n }));\n",
  79. llvm::join(body, "\n ")
  80. );
  81. os << formatv(
  82. "py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
  83. className, attr->getEnumName()
  84. );
  85. enumAlias.emplace(enumID,
  86. std::make_pair(className, attr->getEnumName()));
  87. } else {
  88. os << formatv(
  89. "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n",
  90. className, attr->getEnumName(),
  91. iter->second.first, iter->second.second
  92. );
  93. }
  94. }
  95. }
  96. // generate op class binding
  97. os << formatv("{0}Inst", className);
  98. bool hasDefaultCtor = op.getMgbAttributes().empty();
  99. if (!hasDefaultCtor) {
  100. os << "\n .def(py::init<";
  101. std::vector<llvm::StringRef> targs;
  102. for (auto &&i : op.getMgbAttributes()) {
  103. targs.push_back(i.attr.getReturnType());
  104. }
  105. os << llvm::join(targs, ", ");
  106. os << ", std::string>()";
  107. for (auto &&i : op.getMgbAttributes()) {
  108. os << formatv(", py::arg(\"{0}\")", i.name);
  109. auto defaultValue = i.attr.getDefaultValue();
  110. if (!defaultValue.empty()) {
  111. os << formatv(" = {0}", defaultValue);
  112. } else {
  113. hasDefaultCtor = true;
  114. }
  115. }
  116. os << ", py::arg(\"scope\") = {})";
  117. }
  118. if (hasDefaultCtor) {
  119. os << "\n .def(py::init<>())";
  120. }
  121. for (auto &&i : op.getMgbAttributes()) {
  122. os << formatv(
  123. "\n .def_readwrite(\"{0}\", &{1}::{0})",
  124. i.name, className
  125. );
  126. }
  127. os << ";\n\n";
  128. }
  129. } // namespace
  130. bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper) {
  131. Environment env;
  132. using namespace std::placeholders;
  133. foreach_operator(keeper, [&](MgbOp& op) {
  134. OpDefEmitter(op, os, env).emit();
  135. });
  136. return false;
  137. }
  138. } // namespace mlir::tblgen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。