You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

python_c_extension.cpp 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. /**
  2. * \file imperative/tablegen/targets/python_c_extension.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "python_c_extension.h"
  12. #include "../emitter.h"
  13. namespace mlir::tblgen {
  14. namespace {
  15. struct Initproc {
  16. std::string func;
  17. Initproc(std::string&& s): func(std::move(s)) {}
  18. std::string operator()(std::string argument) {
  19. return formatv("{0}({1})", func, argument);
  20. }
  21. };
  22. class OpDefEmitter: public EmitterBase {
  23. public:
  24. OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_):
  25. EmitterBase(os_, env_), op(op_) {
  26. ctx.withSelf(op.getCppClassName());
  27. }
  28. Initproc emit();
  29. private:
  30. void emit_class();
  31. void emit_py_init();
  32. void emit_py_getsetters();
  33. Initproc emit_initproc();
  34. MgbOp& op;
  35. std::vector<Initproc> subclasses;
  36. mlir::tblgen::FmtContext ctx;
  37. };
  38. class EnumAttrEmitter: public EmitterBase {
  39. public:
  40. EnumAttrEmitter(llvm::StringRef parent, MgbEnumAttr* attr_, raw_ostream& os_, Environment& env_):
  41. EmitterBase(os_, env_), attr(attr_) {
  42. unsigned int enumID;
  43. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  44. auto&& aliasBase = alias->getAliasBase();
  45. enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
  46. } else {
  47. enumID = attr->getBaseRecord()->getID();
  48. }
  49. ctx.addSubst("enumTpl", attr->getEnumCombinedFlag() ? "BitCombinedEnumWrapper" : "EnumWrapper");
  50. ctx.addSubst("opClass", parent);
  51. ctx.addSubst("enumClass", attr->getEnumName());
  52. firstOccur = env().enumAlias.emplace(enumID, std::make_pair(parent, attr->getEnumName())).second;
  53. }
  54. Initproc emit();
  55. protected:
  56. void emit_trait();
  57. void emit_tpl_spl();
  58. Initproc emit_initproc();
  59. MgbEnumAttr* attr;
  60. bool firstOccur;
  61. mlir::tblgen::FmtContext ctx;
  62. };
  63. Initproc EnumAttrEmitter::emit() {
  64. emit_trait();
  65. emit_tpl_spl();
  66. return emit_initproc();
  67. }
  68. void EnumAttrEmitter::emit_trait() {
  69. if (!firstOccur) return;
  70. auto enumMax = [&] {
  71. if (attr->getEnumCombinedFlag()) {
  72. return formatv("(1llu << {0}) - 1", attr->getEnumMembers().size());
  73. } else {
  74. return formatv("{0} - 1", attr->getEnumMembers().size());
  75. }
  76. };
  77. os << tgfmt(R"(
  78. template<> struct EnumTrait<$opClass::$enumClass> {
  79. static constexpr const char *name = "$opClass.$enumClass";
  80. static constexpr std::underlying_type_t<$opClass::$enumClass> max = $0;
  81. };
  82. )", &ctx, enumMax());
  83. }
  84. void EnumAttrEmitter::emit_tpl_spl() {
  85. if (!firstOccur) return;
  86. os << tgfmt(
  87. "template<> PyTypeObject* $enumTpl<$opClass::$enumClass>::type = nullptr;\n",
  88. &ctx);
  89. auto quote = [&](auto&& i) -> std::string {
  90. return formatv("\"{0}\"", i);
  91. };
  92. os << tgfmt(R"(
  93. template<> const char*
  94. $enumTpl<$opClass::$enumClass>::members[] = {$0};
  95. )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), quote), ", "));
  96. auto mem2value = [&](auto&& i) -> std::string {
  97. return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i);
  98. };
  99. os << tgfmt(R"(
  100. template<> std::unordered_map<std::string, $opClass::$enumClass>
  101. $enumTpl<$opClass::$enumClass>::mem2value = {$0};
  102. )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), mem2value), ", "));
  103. os << tgfmt(
  104. "template<> PyObject* "
  105. "$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};\n",
  106. &ctx, attr->getEnumMembers().size());
  107. }
  108. Initproc EnumAttrEmitter::emit_initproc() {
  109. std::string initproc = formatv("_init_py_{0}_{1}",
  110. ctx.getSubstFor("opClass"), ctx.getSubstFor("enumClass"));
  111. os << tgfmt(R"(
  112. void $0(PyTypeObject& py_type) {
  113. auto& e_type = $enumTpl<$opClass::$enumClass>::type;
  114. )", &ctx, initproc);
  115. if (firstOccur) {
  116. os << tgfmt(R"(
  117. static PyType_Slot slots[] = {
  118. {Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr},
  119. {Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare},
  120. )", &ctx);
  121. if (attr->getEnumCombinedFlag()) {
  122. // only bit combined enum could new instance because bitwise operation,
  123. // others should always use singleton
  124. os << tgfmt(R"(
  125. {Py_tp_new, (void*)$enumTpl<$opClass::$enumClass>::py_new_combined_enum},
  126. {Py_nb_or, (void*)$enumTpl<$opClass::$enumClass>::py_or},
  127. {Py_nb_and, (void*)$enumTpl<$opClass::$enumClass>::py_and},
  128. )", &ctx);
  129. }
  130. os << R"(
  131. {0, NULL}
  132. };)";
  133. os << tgfmt(R"(
  134. static PyType_Spec spec = {
  135. // name
  136. "megengine.core._imperative_rt.ops.$opClass.$enumClass",
  137. // basicsize
  138. sizeof($enumTpl<$opClass::$enumClass>),
  139. // itemsize
  140. 0,
  141. // flags
  142. Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
  143. // slots
  144. slots
  145. };)", &ctx);
  146. os << tgfmt(R"(
  147. e_type = reinterpret_cast<PyTypeObject*>(PyType_FromSpec(&spec));
  148. )", &ctx);
  149. for (auto&& i : {
  150. std::pair<std::string, std::string>{"__name__", tgfmt("$enumClass", &ctx)},
  151. {"__module__", "megengine.core._imperative_rt.ops"},
  152. {"__qualname__", tgfmt("$opClass.$enumClass", &ctx)}}) {
  153. os << formatv(R"(
  154. mgb_assert(
  155. e_type->tp_setattro(
  156. reinterpret_cast<PyObject*>(e_type),
  157. py::cast("{0}").release().ptr(),
  158. py::cast("{1}").release().ptr()) >= 0);
  159. )", i.first, i.second);
  160. }
  161. auto&& members = attr->getEnumMembers();
  162. for (size_t idx = 0; idx < members.size(); ++ idx) {
  163. os << tgfmt(R"({
  164. PyObject* inst = e_type->tp_alloc(e_type, 0);
  165. reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
  166. mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
  167. $enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
  168. })", &ctx, members[idx], idx);
  169. }
  170. }
  171. os << tgfmt(R"(
  172. Py_INCREF(e_type);
  173. mgb_assert(PyDict_SetItemString(
  174. py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(e_type)) >= 0);
  175. )", &ctx);
  176. os << "}\n";
  177. return initproc;
  178. }
  179. Initproc OpDefEmitter::emit() {
  180. for (auto&& i : op.getMgbAttributes()) {
  181. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  182. subclasses.push_back(EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit());
  183. }
  184. }
  185. emit_class();
  186. emit_py_init();
  187. emit_py_getsetters();
  188. return emit_initproc();
  189. }
  190. void OpDefEmitter::emit_class() {
  191. os << tgfmt(R"(
  192. PyOpDefBegin($_self) // {
  193. static PyGetSetDef py_getsetters[];
  194. static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
  195. // };
  196. PyOpDefEnd($_self)
  197. )", &ctx);
  198. }
  199. void OpDefEmitter::emit_py_init() {
  200. std::string initBody;
  201. if (!op.getMgbAttributes().empty()) {
  202. initBody += "static const char* kwlist[] = {";
  203. std::vector<llvm::StringRef> attr_name_list;
  204. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  205. attr_name_list.push_back(attr.name);
  206. });
  207. attr_name_list.push_back("scope");
  208. llvm::for_each(attr_name_list, [&](auto&& attr) {
  209. initBody += formatv("\"{0}\", ", attr);
  210. });
  211. initBody += "NULL};\n";
  212. initBody += " PyObject ";
  213. auto initializer = [&](auto&& attr) -> std::string {
  214. return formatv("*{0} = NULL", attr);
  215. };
  216. initBody += llvm::join(llvm::map_range(attr_name_list, initializer), ", ") + ";\n";
  217. initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
  218. // an extra slot created for name
  219. initBody += std::string(attr_name_list.size(), 'O');
  220. initBody += "\", const_cast<char**>(kwlist)";
  221. llvm::for_each(attr_name_list, [&](auto&& attr) {
  222. initBody += formatv(", &{0}", attr);
  223. });
  224. initBody += "))\n";
  225. initBody += " return -1;\n";
  226. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  227. initBody += tgfmt(R"(
  228. if ($0) {
  229. try {
  230. // TODO: remove this guard which is used for pybind11 implicit conversion
  231. py::detail::loader_life_support guard{};
  232. reinterpret_cast<PyOp($_self)*>(self)->inst().$0 =
  233. py::cast<decltype($_self::$0)>(py::handle($0));
  234. } CATCH_ALL(-1)
  235. }
  236. )", &ctx, attr.name);
  237. });
  238. initBody += tgfmt(R"(
  239. if (scope) {
  240. try {
  241. reinterpret_cast<PyOp(OpDef)*>(self)->op
  242. ->set_scope(py::cast<std::string>(py::handle(scope)));
  243. } CATCH_ALL(-1)
  244. }
  245. )", &ctx);
  246. }
  247. initBody += "\n return 0;";
  248. os << tgfmt(R"(
  249. int PyOp($_self)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
  250. $0
  251. }
  252. )", &ctx, initBody);
  253. }
  254. void OpDefEmitter::emit_py_getsetters() {
  255. auto f = [&](auto&& attr) -> std::string {
  256. return tgfmt(
  257. "{const_cast<char*>(\"$0\"), py_get_generic($_self, $0), py_set_generic($_self, $0), const_cast<char*>(\"$0\"), NULL},",
  258. &ctx, attr.name);
  259. };
  260. os << tgfmt(R"(
  261. PyGetSetDef PyOp($_self)::py_getsetters[] = {
  262. $0
  263. {NULL} /* Sentinel */
  264. };
  265. )", &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n "));
  266. }
  267. Initproc OpDefEmitter::emit_initproc() {
  268. std::string initproc = formatv("_init_py_{0}", op.getCppClassName());
  269. std::string subclass_init_call;
  270. for (auto&& i : subclasses) {
  271. subclass_init_call += formatv(" {0};\n", i("py_type"));
  272. }
  273. os << tgfmt(R"(
  274. void $0(py::module m) {
  275. using py_op = PyOp($_self);
  276. auto& py_type = PyOpType($_self);
  277. py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
  278. py_type.tp_name = "megengine.core._imperative_rt.ops.$_self";
  279. py_type.tp_basicsize = sizeof(PyOp($_self));
  280. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  281. py_type.tp_doc = "$_self";
  282. py_type.tp_base = &PyOpType(OpDef);
  283. py_type.tp_dealloc = py_dealloc_generic<py_op>;
  284. py_type.tp_new = py_new_generic<py_op>;
  285. py_type.tp_init = py_op::py_init;
  286. py_type.tp_getset = py_op::py_getsetters;
  287. mgb_assert(PyType_Ready(&py_type) >= 0);
  288. $1
  289. PyType_Modified(&py_type);
  290. m.add_object("$_self", reinterpret_cast<PyObject*>(&py_type));
  291. mgb_assert(PyOp(OpDef)::ctype2pytype.emplace($_self::typeinfo(), &py_type).second);
  292. }
  293. )", &ctx, initproc, subclass_init_call);
  294. return initproc;
  295. }
  296. } // namespace
  297. bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper) {
  298. Environment env;
  299. using namespace std::placeholders;
  300. std::vector<Initproc> initprocs;
  301. foreach_operator(keeper, [&](MgbOp& op) {
  302. initprocs.emplace_back(OpDefEmitter(op, os, env).emit());
  303. });
  304. os << "#define INIT_ALL_OP(m)";
  305. for(auto&& init : initprocs) {
  306. os << formatv(" \\\n {0};", init("m"));
  307. }
  308. os << "\n";
  309. return false;
  310. }
  311. } // namespace mlir::tblgen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台