You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

pyext17.h 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. /**
  2. * \file imperative/python/src/pyext17.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include <stdexcept>
  13. #include <vector>
  14. #include <utility>
  15. #include <Python.h>
  16. #include <pybind11/pybind11.h>
  17. namespace pyext17 {
  18. #ifdef METH_FASTCALL
  19. constexpr bool has_fastcall = true;
  20. #else
  21. constexpr bool has_fastcall = false;
  22. #endif
  23. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  24. constexpr bool has_vectorcall = true;
  25. #else
  26. constexpr bool has_vectorcall = false;
  27. #endif
  28. template<typename... Args>
  29. struct invocable_with {
  30. template<typename T>
  31. constexpr bool operator()(T&& lmb) {
  32. return std::is_invocable_v<T, Args...>;
  33. }
  34. };
  35. #define HAS_MEMBER_TYPE(T, U) invocable_with<T>{}([](auto&& x) -> typename std::decay_t<decltype(x)>::U {})
  36. #define HAS_MEMBER(T, m) invocable_with<T>{}([](auto&& x) -> decltype(&std::decay_t<decltype(x)>::m) {})
  37. inline PyObject* cvt_retval(PyObject* rv) {
  38. return rv;
  39. }
  40. #define CVT_RET_PYOBJ(...) \
  41. if constexpr (std::is_same_v<decltype(__VA_ARGS__), void>) { \
  42. __VA_ARGS__; \
  43. Py_RETURN_NONE; \
  44. } else { \
  45. return cvt_retval(__VA_ARGS__); \
  46. }
  47. inline int cvt_retint(int ret) {
  48. return ret;
  49. }
  50. #define CVT_RET_INT(...) \
  51. if constexpr (std::is_same_v<decltype(__VA_ARGS__), void>) { \
  52. __VA_ARGS__; \
  53. return 0; \
  54. } else { \
  55. return cvt_retint(__VA_ARGS__); \
  56. }
  57. struct py_err_set : std::exception {};
  58. #define HANDLE_ALL_EXC(RET) catch(py_err_set&) {return RET;} \
  59. catch(pybind11::error_already_set& e) {e.restore(); return RET;} \
  60. catch(pybind11::builtin_exception& e) {e.set_error(); return RET;} \
  61. catch(std::exception& e) {PyErr_SetString(PyExc_RuntimeError, e.what()); return RET;}
  62. template <typename T>
  63. struct wrap {
  64. private:
  65. typedef wrap<T> wrap_t;
  66. public:
  67. PyObject_HEAD
  68. std::aligned_storage_t<sizeof(T), alignof(T)> storage;
  69. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  70. PyObject* (*vectorcall_slot)(PyObject*, PyObject*const*, size_t, PyObject*);
  71. #endif
  72. inline T* inst() {
  73. return reinterpret_cast<T*>(&storage);
  74. }
  75. inline static PyObject* pycast(T* ptr) {
  76. return (PyObject*)((char*)ptr - offsetof(wrap_t, storage));
  77. }
  78. private:
  79. // method wrapper
  80. enum struct meth_type {
  81. noarg,
  82. varkw,
  83. fastcall,
  84. singarg
  85. };
  86. template<auto f>
  87. struct detect_meth_type {
  88. static constexpr meth_type value = []() {
  89. using F = decltype(f);
  90. static_assert(std::is_member_function_pointer_v<F>);
  91. if constexpr (std::is_invocable_v<F, T>) {
  92. return meth_type::noarg;
  93. } else if constexpr (std::is_invocable_v<F, T, PyObject*, PyObject*>) {
  94. return meth_type::varkw;
  95. } else if constexpr (std::is_invocable_v<F, T, PyObject*const*, Py_ssize_t>) {
  96. return meth_type::fastcall;
  97. } else if constexpr (std::is_invocable_v<F, T, PyObject*>) {
  98. return meth_type::singarg;
  99. } else {
  100. static_assert(!std::is_same_v<F, F>);
  101. }
  102. }();
  103. };
  104. template<meth_type, auto f>
  105. struct meth {};
  106. template<auto f>
  107. struct meth<meth_type::noarg, f> {
  108. static constexpr int flags = METH_NOARGS;
  109. static PyObject* impl(PyObject* self, PyObject*) {
  110. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  111. try {
  112. CVT_RET_PYOBJ((inst->*f)());
  113. } HANDLE_ALL_EXC(nullptr)
  114. }
  115. };
  116. template<auto f>
  117. struct meth<meth_type::varkw, f> {
  118. static constexpr int flags = METH_VARARGS | METH_KEYWORDS;
  119. static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
  120. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  121. try {
  122. CVT_RET_PYOBJ((inst->*f)(args, kwargs));
  123. } HANDLE_ALL_EXC(nullptr)
  124. }
  125. };
  126. template<auto f>
  127. struct meth<meth_type::fastcall, f> {
  128. #ifdef METH_FASTCALL
  129. static constexpr int flags = METH_FASTCALL;
  130. static PyObject* impl(PyObject* self, PyObject*const* args, Py_ssize_t nargs) {
  131. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  132. try {
  133. CVT_RET_PYOBJ((inst->*f)(args, nargs));
  134. } HANDLE_ALL_EXC(nullptr)
  135. }
  136. #else
  137. static constexpr int flags = METH_VARARGS;
  138. static PyObject* impl(PyObject* self, PyObject* args) {
  139. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  140. auto* arr = &PyTuple_GET_ITEM(args, 0);
  141. auto size = PyTuple_GET_SIZE(args);
  142. try {
  143. CVT_RET_PYOBJ((inst->*f)(arr, size));
  144. } HANDLE_ALL_EXC(nullptr)
  145. }
  146. #endif
  147. };
  148. template<auto f>
  149. struct meth<meth_type::singarg, f> {
  150. static constexpr int flags = METH_O;
  151. static PyObject* impl(PyObject* self, PyObject* obj) {
  152. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  153. try {
  154. CVT_RET_PYOBJ((inst->*f)(obj));
  155. } HANDLE_ALL_EXC(nullptr)
  156. }
  157. };
  158. template<auto f>
  159. static constexpr PyMethodDef make_meth_def(const char* name, const char* doc = nullptr) {
  160. using M = meth<detect_meth_type<f>::value, f>;
  161. return {name, (PyCFunction)M::impl, M::flags, doc};
  162. }
  163. template<auto f>
  164. struct getter {
  165. using F = decltype(f);
  166. static PyObject* impl(PyObject* self, void* closure) {
  167. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  168. try {
  169. if constexpr (std::is_invocable_v<F, PyObject*, void*>) {
  170. CVT_RET_PYOBJ(f(self, closure));
  171. } else if constexpr (std::is_invocable_v<F, T, void*>) {
  172. CVT_RET_PYOBJ((inst->*f)(closure));
  173. } else if constexpr (std::is_invocable_v<F, T>) {
  174. CVT_RET_PYOBJ((inst->*f)());
  175. } else {
  176. static_assert(!std::is_same_v<F, F>);
  177. }
  178. } HANDLE_ALL_EXC(nullptr)
  179. }
  180. };
  181. template<auto f>
  182. struct setter {
  183. using F = decltype(f);
  184. template<typename = void>
  185. static int impl_(PyObject* self, PyObject* val, void* closure) {
  186. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  187. try {
  188. if constexpr (std::is_invocable_v<F, PyObject*, PyObject*, void*>) {
  189. CVT_RET_INT(f(self, val, closure));
  190. } else if constexpr (std::is_invocable_v<F, T, PyObject*, void*>) {
  191. CVT_RET_INT((inst->*f)(val, closure));
  192. } else if constexpr (std::is_invocable_v<F, T, PyObject*>) {
  193. CVT_RET_INT((inst->*f)(val));
  194. } else {
  195. static_assert(!std::is_same_v<F, F>);
  196. }
  197. } HANDLE_ALL_EXC(-1)
  198. }
  199. static constexpr auto impl = []() {if constexpr (std::is_same_v<F, std::nullptr_t>) return nullptr;
  200. else return impl_<>;}();
  201. };
  202. template<auto get, auto set = nullptr>
  203. static constexpr PyGetSetDef make_getset_def(const char* name, const char* doc = nullptr, void* closure = nullptr) {
  204. return {const_cast<char *>(name), getter<get>::impl, setter<set>::impl, const_cast<char *>(doc), closure};
  205. }
  206. // polyfills
  207. struct tp_vectorcall {
  208. static constexpr bool valid = HAS_MEMBER(T, tp_vectorcall);
  209. static constexpr bool haskw = [](){if constexpr (valid)
  210. if constexpr (std::is_invocable_v<decltype(&T::tp_vectorcall), T, PyObject*const*, size_t, PyObject*>)
  211. return true;
  212. return false;}();
  213. template<typename = void>
  214. static PyObject* impl(PyObject* self, PyObject*const* args, size_t nargsf, PyObject *kwnames) {
  215. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  216. if constexpr (haskw) {
  217. CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf, kwnames));
  218. } else {
  219. if (kwnames && PyTuple_GET_SIZE(kwnames)) {
  220. PyErr_SetString(PyExc_TypeError, "expect no keyword argument");
  221. return nullptr;
  222. }
  223. CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf));
  224. }
  225. }
  226. static constexpr Py_ssize_t offset = []() {if constexpr (valid) return offsetof(wrap_t, vectorcall_slot);
  227. else return 0;}();
  228. };
  229. struct tp_call {
  230. static constexpr bool provided = HAS_MEMBER(T, tp_call);
  231. static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}(
  232. [](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {});
  233. static constexpr bool valid = provided || tp_vectorcall::valid;
  234. template<typename = void>
  235. static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
  236. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  237. CVT_RET_PYOBJ(inst->tp_call(args, kwargs));
  238. }
  239. static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call;
  240. else if constexpr (provided) return impl<>;
  241. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  242. else if constexpr (valid) return PyVectorcall_Call;
  243. #endif
  244. else return nullptr;}();
  245. };
  246. struct tp_new {
  247. static constexpr bool provided = HAS_MEMBER(T, tp_new);
  248. static constexpr bool varkw = std::is_constructible_v<T, PyObject*, PyObject*>;
  249. static constexpr bool noarg = std::is_default_constructible_v<T>;
  250. template<typename = void>
  251. static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
  252. struct FreeGuard {
  253. PyObject* self;
  254. PyTypeObject* type;
  255. ~FreeGuard() {if (self) type->tp_free(self);}
  256. };
  257. auto* self = type->tp_alloc(type, 0);
  258. FreeGuard free_guard{self, type};
  259. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  260. if constexpr (has_vectorcall && tp_vectorcall::valid) {
  261. reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
  262. }
  263. try {
  264. if constexpr (varkw) {
  265. new(inst) T(args, kwargs);
  266. } else {
  267. new(inst) T();
  268. }
  269. } HANDLE_ALL_EXC(nullptr)
  270. free_guard.self = nullptr;
  271. return self;
  272. }
  273. static constexpr newfunc value = []() {if constexpr (provided) return T::tp_new;
  274. else if constexpr (varkw || noarg) return impl<>;
  275. else return nullptr;}();
  276. };
  277. struct tp_dealloc {
  278. static constexpr bool provided = HAS_MEMBER(T, tp_dealloc);
  279. template<typename = void>
  280. static void impl(PyObject* self) {
  281. reinterpret_cast<wrap_t*>(self)->inst()->~T();
  282. Py_TYPE(self)->tp_free(self);
  283. }
  284. static constexpr destructor value = []() {if constexpr (provided) return T::tp_dealloc;
  285. else return impl<>;}();
  286. };
  287. public:
  288. class TypeBuilder {
  289. std::vector<PyMethodDef> m_methods;
  290. std::vector<PyGetSetDef> m_getsets;
  291. PyTypeObject m_type;
  292. bool m_finalized = false;
  293. bool m_ready = false;
  294. void check_finalized() {
  295. if (m_finalized) {
  296. throw std::runtime_error("type is already finalized");
  297. }
  298. }
  299. static const char* to_c_str(const char* s) {return s;}
  300. template <size_t N, typename... Ts>
  301. static const char* to_c_str(const pybind11::detail::descr<N, Ts...>& desc) {
  302. return desc.text;
  303. }
  304. public:
  305. TypeBuilder(const TypeBuilder&) = delete;
  306. TypeBuilder& operator=(const TypeBuilder&) = delete;
  307. TypeBuilder() : m_type{PyVarObject_HEAD_INIT(nullptr, 0)} {
  308. constexpr auto has_tp_name = HAS_MEMBER(T, tp_name);
  309. if constexpr (has_tp_name) {
  310. m_type.tp_name = to_c_str(T::tp_name);
  311. }
  312. m_type.tp_dealloc = tp_dealloc::value;
  313. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  314. m_type.tp_vectorcall_offset = tp_vectorcall::offset;
  315. #endif
  316. m_type.tp_call = tp_call::value;
  317. m_type.tp_basicsize = sizeof(wrap_t);
  318. m_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  319. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  320. if constexpr (tp_vectorcall::valid) {
  321. m_type.tp_flags |= _Py_TPFLAGS_HAVE_VECTORCALL;
  322. }
  323. #endif
  324. m_type.tp_new = tp_new::value;
  325. }
  326. PyTypeObject* operator->() {
  327. return &m_type;
  328. }
  329. bool ready() const {
  330. return m_ready;
  331. }
  332. bool isinstance(PyObject* op) {
  333. return PyObject_TypeCheck(op, &m_type);
  334. }
  335. bool isexact(PyObject* op) {
  336. return Py_TYPE(op) == &m_type;
  337. }
  338. PyObject* finalize() {
  339. if (!m_finalized) {
  340. m_finalized = true;
  341. if (m_methods.size()) {
  342. m_methods.push_back({0});
  343. if (m_type.tp_methods) {
  344. PyErr_SetString(PyExc_SystemError, "tp_method is already set");
  345. return nullptr;
  346. }
  347. m_type.tp_methods = &m_methods[0];
  348. }
  349. if (m_getsets.size()) {
  350. m_getsets.push_back({0});
  351. if (m_type.tp_getset) {
  352. PyErr_SetString(PyExc_SystemError, "tp_getset is already set");
  353. return nullptr;
  354. }
  355. m_type.tp_getset = &m_getsets[0];
  356. }
  357. if (PyType_Ready(&m_type)) {
  358. return nullptr;
  359. }
  360. m_ready = true;
  361. }
  362. return (PyObject*)&m_type;
  363. }
  364. template<auto f>
  365. TypeBuilder& def(const char* name, const char* doc = nullptr) {
  366. check_finalized();
  367. m_methods.push_back(make_meth_def<f>(name, doc));
  368. return *this;
  369. }
  370. template<auto get, auto set = nullptr>
  371. TypeBuilder& def_getset(const char* name, const char* doc = nullptr, void* closure = nullptr) {
  372. check_finalized();
  373. m_getsets.push_back(make_getset_def<get, set>(name, doc, closure));
  374. return *this;
  375. }
  376. };
  377. static TypeBuilder& type() {
  378. static TypeBuilder type_helper;
  379. return type_helper;
  380. }
  381. template<typename... Args>
  382. static PyObject* cnew(Args&&... args) {
  383. auto* pytype = type().operator->();
  384. auto* self = pytype->tp_alloc(pytype, 0);
  385. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  386. if constexpr (has_vectorcall && tp_vectorcall::valid) {
  387. reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
  388. }
  389. new(inst) T(std::forward<Args>(args)...);
  390. return self;
  391. }
  392. template<typename... Args>
  393. static PyObject* cnew_with_type(PyTypeObject* pytype, Args&&... args) {
  394. auto* self = pytype->tp_alloc(pytype, 0);
  395. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  396. if constexpr (has_vectorcall && tp_vectorcall::valid) {
  397. reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
  398. }
  399. new(inst) T(std::forward<Args>(args)...);
  400. return self;
  401. }
  402. struct caster {
  403. static constexpr auto name = T::tp_name;
  404. T* value;
  405. bool load(pybind11::handle src, bool convert) {
  406. if (wrap_t::type().isinstance(src.ptr())) {
  407. value = reinterpret_cast<wrap_t*>(src.ptr())->inst();
  408. return true;
  409. }
  410. return false;
  411. }
  412. template <typename U> using cast_op_type = pybind11::detail::cast_op_type<U>;
  413. operator T*() { return value; }
  414. operator T&() { return *value; }
  415. };
  416. };
  417. } // namespace pyext17
  418. #undef HAS_MEMBER_TYPE
  419. #undef HAS_MEMBER
  420. #undef CVT_RET_PYOBJ
  421. #undef CVT_RET_INT
  422. #undef HANDLE_ALL_EXC

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。