You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

python_c_extension.cpp 14 kB


  1. /**
  2. * \file imperative/tablegen/targets/python_c_extension.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "python_c_extension.h"
  12. #include "../emitter.h"
  13. namespace mlir::tblgen {
  14. namespace {
  /// Deferred call to a generated init function.
  ///
  /// `func` is the name of a function emitted into the generated source
  /// (e.g. "_init_py_<Op>"); invoking the object with an argument expression
  /// renders the call text "func(argument)" for splicing into generated code.
  struct Initproc {
      std::string func;
      // Sink parameter: taken by value and moved into place, so both
      // temporaries and named strings can be used to build an Initproc.
      Initproc(std::string s): func(std::move(s)) {}
      // Renders "<func>(<argument>)". const: rendering does not modify the
      // object, so it can also be invoked through a const Initproc.
      std::string operator()(const std::string& argument) const {
          return func + "(" + argument + ")";
      }
  };
  22. class OpDefEmitter: public EmitterBase {
  23. public:
  24. OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_):
  25. EmitterBase(os_, env_), op(op_) {
  26. ctx.withSelf(op.getCppClassName());
  27. }
  28. Initproc emit();
  29. private:
  30. void emit_class();
  31. void emit_py_init();
  32. void emit_py_getsetters();
  33. void emit_py_methods();
  34. Initproc emit_initproc();
  35. MgbOp& op;
  36. std::vector<Initproc> subclasses;
  37. mlir::tblgen::FmtContext ctx;
  38. };
  39. class EnumAttrEmitter: public EmitterBase {
  40. public:
  41. EnumAttrEmitter(llvm::StringRef parent, MgbEnumAttr* attr_, raw_ostream& os_, Environment& env_):
  42. EmitterBase(os_, env_), attr(attr_) {
  43. unsigned int enumID;
  44. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  45. auto&& aliasBase = alias->getAliasBase();
  46. enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
  47. } else {
  48. enumID = attr->getBaseRecord()->getID();
  49. }
  50. ctx.addSubst("enumTpl", attr->getEnumCombinedFlag() ? "BitCombinedEnumWrapper" : "EnumWrapper");
  51. ctx.addSubst("opClass", parent);
  52. ctx.addSubst("enumClass", attr->getEnumName());
  53. firstOccur = env().enumAlias.emplace(enumID, std::make_pair(parent, attr->getEnumName())).second;
  54. }
  55. Initproc emit();
  56. protected:
  57. void emit_trait();
  58. void emit_tpl_spl();
  59. Initproc emit_initproc();
  60. MgbEnumAttr* attr;
  61. bool firstOccur;
  62. mlir::tblgen::FmtContext ctx;
  63. };
  64. Initproc EnumAttrEmitter::emit() {
  65. emit_trait();
  66. emit_tpl_spl();
  67. return emit_initproc();
  68. }
  69. void EnumAttrEmitter::emit_trait() {
  70. if (!firstOccur) return;
  71. auto enumMax = [&] {
  72. if (attr->getEnumCombinedFlag()) {
  73. return formatv("(1llu << {0}) - 1", attr->getEnumMembers().size());
  74. } else {
  75. return formatv("{0} - 1", attr->getEnumMembers().size());
  76. }
  77. };
  78. os << tgfmt(R"(
  79. template<> struct EnumTrait<$opClass::$enumClass> {
  80. static constexpr const char *name = "$opClass.$enumClass";
  81. static constexpr std::underlying_type_t<$opClass::$enumClass> max = $0;
  82. };
  83. )", &ctx, enumMax());
  84. }
  85. void EnumAttrEmitter::emit_tpl_spl() {
  86. if (!firstOccur) return;
  87. os << tgfmt(
  88. "template<> PyTypeObject* $enumTpl<$opClass::$enumClass>::type = nullptr;\n",
  89. &ctx);
  90. auto quote = [&](auto&& i) -> std::string {
  91. size_t d1 = i.find(' ');
  92. size_t d2 = i.find('=');
  93. size_t d = d1 <= d2 ? d1 : d2;
  94. return formatv("\"{0}\"", i.substr(0, d));
  95. };
  96. os << tgfmt(R"(
  97. template<> const char*
  98. $enumTpl<$opClass::$enumClass>::members[] = {$0};
  99. )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), quote), ", "));
  100. auto mem2value = [&](auto&& i) -> std::string {
  101. size_t d1 = i.find(' ');
  102. size_t d2 = i.find('=');
  103. size_t d = d1 <= d2 ? d1 : d2;
  104. return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx,
  105. i.substr(0, d));
  106. };
  107. os << tgfmt(R"(
  108. template<> std::unordered_map<std::string, $opClass::$enumClass>
  109. $enumTpl<$opClass::$enumClass>::mem2value = {$0};
  110. )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), mem2value), ", "));
  111. os << tgfmt(
  112. "template<> PyObject* "
  113. "$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};\n",
  114. &ctx, attr->getEnumMembers().size());
  115. }
  116. Initproc EnumAttrEmitter::emit_initproc() {
  117. std::string initproc = formatv("_init_py_{0}_{1}",
  118. ctx.getSubstFor("opClass"), ctx.getSubstFor("enumClass"));
  119. os << tgfmt(R"(
  120. void $0(PyTypeObject& py_type) {
  121. auto& e_type = $enumTpl<$opClass::$enumClass>::type;
  122. )", &ctx, initproc);
  123. if (firstOccur) {
  124. os << tgfmt(R"(
  125. static PyMethodDef tp_methods[] = {
  126. {const_cast<char*>("dump"), (PyCFunction)$enumTpl<$opClass::$enumClass>::py_dump, METH_NOARGS, NULL},
  127. {NULL} /* Sentinel */
  128. };
  129. )", &ctx);
  130. os << tgfmt(R"(
  131. static PyType_Slot slots[] = {
  132. {Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr},
  133. {Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare},
  134. {Py_tp_methods, tp_methods},
  135. )", &ctx);
  136. if (attr->getEnumCombinedFlag()) {
  137. // only bit combined enum could new instance because bitwise operation,
  138. // others should always use singleton
  139. os << tgfmt(R"(
  140. {Py_tp_new, (void*)$enumTpl<$opClass::$enumClass>::py_new_combined_enum},
  141. {Py_nb_or, (void*)$enumTpl<$opClass::$enumClass>::py_or},
  142. {Py_nb_and, (void*)$enumTpl<$opClass::$enumClass>::py_and},
  143. )", &ctx);
  144. }
  145. os << R"(
  146. {0, NULL}
  147. };)";
  148. os << tgfmt(R"(
  149. static PyType_Spec spec = {
  150. // name
  151. "megengine.core._imperative_rt.ops.$opClass.$enumClass",
  152. // basicsize
  153. sizeof($enumTpl<$opClass::$enumClass>),
  154. // itemsize
  155. 0,
  156. // flags
  157. Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
  158. // slots
  159. slots
  160. };)", &ctx);
  161. os << tgfmt(R"(
  162. e_type = reinterpret_cast<PyTypeObject*>(PyType_FromSpec(&spec));
  163. )", &ctx);
  164. for (auto&& i : {
  165. std::pair<std::string, std::string>{"__name__", tgfmt("$enumClass", &ctx)},
  166. {"__module__", "megengine.core._imperative_rt.ops"},
  167. {"__qualname__", tgfmt("$opClass.$enumClass", &ctx)}}) {
  168. os << formatv(R"(
  169. mgb_assert(
  170. e_type->tp_setattro(
  171. reinterpret_cast<PyObject*>(e_type),
  172. py::cast("{0}").release().ptr(),
  173. py::cast("{1}").release().ptr()) >= 0);
  174. )", i.first, i.second);
  175. }
  176. auto&& members = attr->getEnumMembers();
  177. for (size_t idx = 0; idx < members.size(); ++ idx) {
  178. size_t d1 = members[idx].find(' ');
  179. size_t d2 = members[idx].find('=');
  180. size_t d = d1 <= d2 ? d1 : d2;
  181. os << tgfmt(R"({
  182. PyObject* inst = e_type->tp_alloc(e_type, 0);
  183. reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
  184. mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
  185. $enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
  186. })", &ctx, members[idx].substr(0, d), idx);
  187. }
  188. }
  189. os << tgfmt(R"(
  190. Py_INCREF(e_type);
  191. mgb_assert(PyDict_SetItemString(
  192. py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(e_type)) >= 0);
  193. )", &ctx);
  194. os << "}\n";
  195. return initproc;
  196. }
  197. Initproc OpDefEmitter::emit() {
  198. for (auto&& i : op.getMgbAttributes()) {
  199. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  200. subclasses.push_back(EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit());
  201. }
  202. }
  203. emit_class();
  204. emit_py_init();
  205. emit_py_getsetters();
  206. emit_py_methods();
  207. return emit_initproc();
  208. }
  209. void OpDefEmitter::emit_class() {
  210. auto&& className = op.getCppClassName();
  211. std::string method_defs;
  212. std::vector<std::string> body;
  213. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  214. body.push_back(formatv(R"(
  215. {{"{0}", serialization<decltype(opdef.{0})>::dump(opdef.{0})})"
  216. , attr.name));
  217. });
  218. method_defs += formatv(R"(
  219. static PyObject* getstate(PyObject* self, PyObject*) {{
  220. auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst();
  221. static_cast<void>(opdef);
  222. std::unordered_map<std::string, py::object> state {{
  223. {1}
  224. };
  225. return py::cast(state).release().ptr();
  226. })", className, llvm::join(body, ","));
  227. body.clear();
  228. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  229. body.push_back(formatv(R"(
  230. {{
  231. auto&& iter = state.find("{0}");
  232. if (iter != state.end()) {
  233. opdef.{0} = serialization<decltype(opdef.{0})>::load(iter->second);
  234. }
  235. })", attr.name));
  236. });
  237. method_defs += formatv(R"(
  238. static PyObject* setstate(PyObject* self, PyObject* args) {{
  239. PyObject* dict = PyTuple_GetItem(args, 0);
  240. if (!dict) return NULL;
  241. auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
  242. auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst();
  243. static_cast<void>(opdef);
  244. {1}
  245. Py_RETURN_NONE;
  246. })", className, llvm::join(body, "\n"));
  247. os << tgfmt(R"(
  248. PyOpDefBegin($_self) // {
  249. static PyGetSetDef py_getsetters[];
  250. static PyMethodDef tp_methods[];
  251. $0
  252. static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
  253. // };
  254. PyOpDefEnd($_self)
  255. )", &ctx, method_defs);
  256. }
  257. void OpDefEmitter::emit_py_init() {
  258. std::string initBody;
  259. if (!op.getMgbAttributes().empty()) {
  260. initBody += "static const char* kwlist[] = {";
  261. std::vector<llvm::StringRef> attr_name_list;
  262. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  263. attr_name_list.push_back(attr.name);
  264. });
  265. attr_name_list.push_back("scope");
  266. llvm::for_each(attr_name_list, [&](auto&& attr) {
  267. initBody += formatv("\"{0}\", ", attr);
  268. });
  269. initBody += "NULL};\n";
  270. initBody += " PyObject ";
  271. auto initializer = [&](auto&& attr) -> std::string {
  272. return formatv("*{0} = NULL", attr);
  273. };
  274. initBody += llvm::join(llvm::map_range(attr_name_list, initializer), ", ") + ";\n";
  275. initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
  276. // an extra slot created for name
  277. initBody += std::string(attr_name_list.size(), 'O');
  278. initBody += "\", const_cast<char**>(kwlist)";
  279. llvm::for_each(attr_name_list, [&](auto&& attr) {
  280. initBody += formatv(", &{0}", attr);
  281. });
  282. initBody += "))\n";
  283. initBody += " return -1;\n";
  284. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  285. initBody += tgfmt(R"(
  286. if ($0) {
  287. try {
  288. // TODO: remove this guard which is used for pybind11 implicit conversion
  289. py::detail::loader_life_support guard{};
  290. reinterpret_cast<PyOp($_self)*>(self)->inst().$0 =
  291. py::cast<decltype($_self::$0)>(py::handle($0));
  292. } CATCH_ALL(-1)
  293. }
  294. )", &ctx, attr.name);
  295. });
  296. initBody += tgfmt(R"(
  297. if (scope) {
  298. try {
  299. reinterpret_cast<PyOp(OpDef)*>(self)->op
  300. ->set_scope(py::cast<std::string>(py::handle(scope)));
  301. } CATCH_ALL(-1)
  302. }
  303. )", &ctx);
  304. }
  305. initBody += "\n return 0;";
  306. os << tgfmt(R"(
  307. int PyOp($_self)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
  308. $0
  309. }
  310. )", &ctx, initBody);
  311. }
  312. void OpDefEmitter::emit_py_getsetters() {
  313. auto f = [&](auto&& attr) -> std::string {
  314. return tgfmt(
  315. "{const_cast<char*>(\"$0\"), py_get_generic($_self, $0), py_set_generic($_self, $0), const_cast<char*>(\"$0\"), NULL},",
  316. &ctx, attr.name);
  317. };
  318. os << tgfmt(R"(
  319. PyGetSetDef PyOp($_self)::py_getsetters[] = {
  320. $0
  321. {NULL} /* Sentinel */
  322. };
  323. )", &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n "));
  324. }
  325. void OpDefEmitter::emit_py_methods(){
  326. // generate methods
  327. std::string method_defs;
  328. std::vector<std::string> method_items;
  329. {
  330. auto&& className = op.getCppClassName();
  331. // generate getstate
  332. method_items.push_back(formatv(
  333. "{{const_cast<char*>(\"__getstate__\"), PyOp({0})::getstate, METH_NOARGS, \"{0} getstate\"},",
  334. className));
  335. // generate setstate
  336. method_items.push_back(formatv(
  337. "{{const_cast<char*>(\"__setstate__\"), PyOp({0})::setstate, METH_VARARGS, \"{0} setstate\"},",
  338. className));
  339. }
  340. os << tgfmt(R"(
  341. PyMethodDef PyOp($_self)::tp_methods[] = {
  342. $0
  343. {NULL} /* Sentinel */
  344. };
  345. )", &ctx, llvm::join(method_items, "\n "));
  346. }
  347. Initproc OpDefEmitter::emit_initproc() {
  348. std::string initproc = formatv("_init_py_{0}", op.getCppClassName());
  349. std::string subclass_init_call;
  350. for (auto&& i : subclasses) {
  351. subclass_init_call += formatv(" {0};\n", i("py_type"));
  352. }
  353. os << tgfmt(R"(
  354. void $0(py::module m) {
  355. using py_op = PyOp($_self);
  356. auto& py_type = PyOpType($_self);
  357. py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
  358. py_type.tp_name = "megengine.core._imperative_rt.ops.$_self";
  359. py_type.tp_basicsize = sizeof(PyOp($_self));
  360. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  361. py_type.tp_doc = "$_self";
  362. py_type.tp_base = &PyOpType(OpDef);
  363. py_type.tp_dealloc = py_dealloc_generic<py_op>;
  364. py_type.tp_new = py_new_generic<py_op>;
  365. py_type.tp_init = py_op::py_init;
  366. py_type.tp_methods = py_op::tp_methods;
  367. py_type.tp_getset = py_op::py_getsetters;
  368. mgb_assert(PyType_Ready(&py_type) >= 0);
  369. $1
  370. PyType_Modified(&py_type);
  371. m.add_object("$_self", reinterpret_cast<PyObject*>(&py_type));
  372. mgb_assert(PyOp(OpDef)::ctype2pytype.emplace($_self::typeinfo(), &py_type).second);
  373. }
  374. )", &ctx, initproc, subclass_init_call);
  375. return initproc;
  376. }
  377. } // namespace
  378. bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper) {
  379. Environment env;
  380. using namespace std::placeholders;
  381. std::vector<Initproc> initprocs;
  382. foreach_operator(keeper, [&](MgbOp& op) {
  383. initprocs.emplace_back(OpDefEmitter(op, os, env).emit());
  384. });
  385. os << "#define INIT_ALL_OP(m)";
  386. for(auto&& init : initprocs) {
  387. os << formatv(" \\\n {0};", init("m"));
  388. }
  389. os << "\n";
  390. return false;
  391. }
  392. } // namespace mlir::tblgen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。