You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

helper.h 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. /**
  2. * \file imperative/python/src/helper.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megbrain/common.h"
  13. #include "megbrain/utils/persistent_cache.h"
  14. #include "megbrain/imperative/op_def.h"
  15. #include <Python.h>
  16. #include <string>
  17. #include <iterator>
  18. #if __cplusplus > 201703L
  19. #include <ranges>
  20. #endif
  21. #include <pybind11/pybind11.h>
  22. #include <pybind11/stl.h>
  23. #include <pybind11/numpy.h>
  24. #include <pybind11/functional.h>
  25. #include "./numpy_dtypes.h"
  26. pybind11::module submodule(pybind11::module parent, const char* name, const char* doc = nullptr);
  27. pybind11::module rel_import(pybind11::str name, pybind11::module m, int level);
  // range_value_t<R>: the type obtained by dereferencing R's begin() iterator,
  // with the reference and any cv-qualifiers removed. Standards newer than
  // C++17 reuse std::ranges::range_value_t; otherwise it is computed by hand.
  #if !(__cplusplus > 201703L)
  template <typename Range>
  using range_value_t = std::remove_cv_t<
          std::remove_reference_t<decltype(*std::declval<Range>().begin())>>;
  #else
  using std::ranges::range_value_t;
  #endif
  34. template<typename T>
  35. auto to_list(const T& x) {
  36. using elem_t = range_value_t<T>;
  37. std::vector<elem_t> ret(x.begin(), x.end());
  38. return pybind11::cast(ret);
  39. }
  40. template<typename T>
  41. auto to_tuple(const T& x, pybind11::return_value_policy policy = pybind11::return_value_policy::automatic) {
  42. auto ret = pybind11::tuple(x.size());
  43. for (size_t i = 0; i < x.size(); ++i) {
  44. ret[i] = pybind11::cast(x[i], policy);
  45. }
  46. return ret;
  47. }
  48. template<typename T>
  49. auto to_tuple(T begin, T end, pybind11::return_value_policy policy = pybind11::return_value_policy::automatic) {
  50. auto ret = pybind11::tuple(end - begin);
  51. for (size_t i = 0; begin < end; ++begin, ++i) {
  52. ret[i] = pybind11::cast(*begin, policy);
  53. }
  54. return ret;
  55. }
  56. class PyTaskDipatcher {
  57. struct Queue : mgb::AsyncQueueSC<std::function<void(void)>, Queue> {
  58. using Task = std::function<void(void)>;
  59. // set max_spin=0 to prevent Queue fetch task in busy wait manner.
  60. // this won't affect throughput when python interpreter is sending enough task,
  61. // but will significantly save CPU time when waiting for task, e.g. wait for data input
  62. Queue() : mgb::AsyncQueueSC<std::function<void(void)>, Queue>(0) {}
  63. void process_one_task(Task& f) {
  64. if (!Py_IsInitialized()) return;
  65. pybind11::gil_scoped_acquire _;
  66. f();
  67. }
  68. void on_async_queue_worker_thread_start() override {
  69. mgb::sys::set_thread_name("py_task_worker");
  70. }
  71. };
  72. Queue queue;
  73. bool finalized = false;
  74. public:
  75. template<typename T>
  76. void add_task(T&& task) {
  77. // CPython never dlclose an extension so
  78. // finalized means the interpreter has been shutdown
  79. if (!finalized) {
  80. queue.add_task(std::forward<T>(task));
  81. }
  82. }
  83. void wait_all_task_finish() {
  84. queue.wait_all_task_finish();
  85. }
  86. ~PyTaskDipatcher() {
  87. finalized = true;
  88. queue.wait_all_task_finish();
  89. }
  90. };
  91. extern PyTaskDipatcher py_task_q;
  92. class GILManager {
  93. PyGILState_STATE gstate;
  94. public:
  95. GILManager():
  96. gstate(PyGILState_Ensure())
  97. {
  98. }
  99. ~GILManager() {
  100. PyGILState_Release(gstate);
  101. }
  102. };
  103. #define PYTHON_GIL GILManager __gil_manager
  104. //! wraps a shared_ptr and decr PyObject ref when destructed
  105. class PyObjRefKeeper {
  106. std::shared_ptr<PyObject> m_ptr;
  107. public:
  108. static void deleter(PyObject* p) {
  109. if (p) {
  110. py_task_q.add_task([p](){Py_DECREF(p);});
  111. }
  112. }
  113. PyObjRefKeeper() = default;
  114. PyObjRefKeeper(PyObject* p) : m_ptr{p, deleter} {}
  115. PyObject* get() const { return m_ptr.get(); }
  116. //! create a shared_ptr as an alias of the underlying ptr
  117. template <typename T>
  118. std::shared_ptr<T> make_shared(T* ptr) const {
  119. return {m_ptr, ptr};
  120. }
  121. };
  122. //! exception to be thrown when python callback fails
  123. class PyExceptionForward : public std::exception {
  124. PyObject *m_type, *m_value, *m_traceback;
  125. std::string m_msg;
  126. PyExceptionForward(PyObject* type, PyObject* value, PyObject* traceback,
  127. const std::string& msg)
  128. : m_type{type},
  129. m_value{value},
  130. m_traceback{traceback},
  131. m_msg{msg} {}
  132. public:
  133. PyExceptionForward(const PyExceptionForward&) = delete;
  134. PyExceptionForward& operator=(const PyExceptionForward&) = delete;
  135. ~PyExceptionForward();
  136. PyExceptionForward(PyExceptionForward&& rhs)
  137. : m_type{rhs.m_type},
  138. m_value{rhs.m_value},
  139. m_traceback{rhs.m_traceback},
  140. m_msg{std::move(rhs.m_msg)} {
  141. rhs.m_type = rhs.m_value = rhs.m_traceback = nullptr;
  142. }
  143. //! throw PyExceptionForward from current python error state
  144. static void throw_() __attribute__((noreturn));
  145. //! restore python error
  146. void restore();
  147. const char* what() const noexcept override { return m_msg.c_str(); }
  148. };
  149. //! numpy utils
  150. namespace npy {
  151. //! convert tensor shape to raw vector
  152. static inline std::vector<size_t> shape2vec(const mgb::TensorShape &shape) {
  153. return {shape.shape, shape.shape + shape.ndim};
  154. }
  155. //! change numpy dtype to megbrain supported dtype
  156. PyObject* to_mgb_supported_dtype(PyObject *dtype);
  157. //! convert raw vector to tensor shape
  158. mgb::TensorShape vec2shape(const std::vector<size_t> &vec);
  159. struct PyArrayDescrDeleter {
  160. void operator()(PyArray_Descr* obj) {
  161. Py_XDECREF(obj);
  162. }
  163. };
  164. //! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new
  165. //! reference to the descriptor.
  166. std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(mgb::DType dtype);
  167. mgb::DType dtype_np2mgb_descr(PyArray_Descr* descr);
  168. //! convert megbrain dtype to numpy dtype object; return new reference
  169. PyObject* dtype_mgb2np(mgb::DType dtype);
  170. //! convert numpy dtype object or string to megbrain dtype
  171. mgb::DType dtype_np2mgb(PyObject *obj);
  172. //! buffer sharing type
  173. enum class ShareType {
  174. MUST_SHARE, //!< must be shared
  175. MUST_UNSHARE, //!< must not be shared
  176. TRY_SHARE //!< share if possible
  177. };
  178. //! get ndarray from HostTensorND
  179. PyObject* ndarray_from_tensor(const mgb::HostTensorND &val,
  180. ShareType share_type);
  181. //! specify how to convert numpy array to tensor
  182. struct Meth {
  183. bool must_borrow_ = false;
  184. mgb::HostTensorND *dest_tensor_ = nullptr;
  185. mgb::CompNode dest_cn_;
  186. //! make a Meth that allows borrowing numpy array memory
  187. static Meth borrow(
  188. mgb::CompNode dest_cn = mgb::CompNode::default_cpu()) {
  189. return {false, nullptr, dest_cn};
  190. }
  191. //! make a Meth that requires the numpy array to be borrowed
  192. static Meth must_borrow(
  193. mgb::CompNode dest_cn = mgb::CompNode::default_cpu()) {
  194. return {true, nullptr, dest_cn};
  195. }
  196. //! make a Meth that requires copying the value into another
  197. //! tensor
  198. static Meth copy_into(mgb::HostTensorND *tensor) {
  199. return {false, tensor, tensor->comp_node()};
  200. }
  201. };
  202. /*!
  203. * \brief convert an object to megbrain tensor
  204. * \param meth specifies how the conversion should take place
  205. * \param dtype desired dtype; it can be set as invalid to allow arbitrary
  206. * dtype
  207. */
  208. mgb::HostTensorND np2tensor(PyObject *obj, const Meth &meth,
  209. mgb::DType dtype);
  210. }
  // Note: the macro below is a copy of the one in pybind11/detail/common.h.
  // Robust support for some features, and loading modules compiled against
  // different pybind versions, requires forcing hidden visibility on pybind
  // code; this is enforced by putting the attribute on the main `pybind11`
  // namespace.
  #ifndef PYBIND11_NAMESPACE
  #if defined(__GNUG__)
  #define PYBIND11_NAMESPACE pybind11 __attribute__((visibility("hidden")))
  #else
  #define PYBIND11_NAMESPACE pybind11
  #endif
  #endif
  222. namespace PYBIND11_NAMESPACE {
  223. namespace detail {
  224. template<typename T, unsigned N> struct type_caster<megdnn::SmallVector<T, N>>
  225. : list_caster<megdnn::SmallVector<T, N>, T> {};
  226. template <> struct type_caster<mgb::DType> {
  227. PYBIND11_TYPE_CASTER(mgb::DType, _("DType"));
  228. public:
  229. bool load(handle src, bool convert) {
  230. auto obj = reinterpret_borrow<object>(src);
  231. if (!convert && !isinstance<dtype>(obj)) {
  232. return false;
  233. }
  234. if (obj.is_none()) {
  235. return true;
  236. }
  237. try {
  238. obj = pybind11::dtype::from_args(obj);
  239. } catch (pybind11::error_already_set&) {
  240. return false;
  241. }
  242. try {
  243. value = npy::dtype_np2mgb(obj.ptr());
  244. } catch (...) {
  245. return false;
  246. }
  247. return true;
  248. }
  249. static handle cast(mgb::DType dt, return_value_policy /* policy */, handle /* parent */) {
  250. // ignore policy and parent because we always return a pure python object
  251. return npy::dtype_mgb2np(std::move(dt));
  252. }
  253. };
  254. template <> struct type_caster<mgb::TensorShape> {
  255. PYBIND11_TYPE_CASTER(mgb::TensorShape, _("TensorShape"));
  256. public:
  257. bool load(handle src, bool convert) {
  258. auto obj = reinterpret_borrow<object>(src);
  259. if (!convert && !isinstance<tuple>(obj)) {
  260. return false;
  261. }
  262. if (obj.is_none()) {
  263. return true;
  264. }
  265. value.ndim = len(obj);
  266. mgb_assert(value.ndim <= mgb::TensorShape::MAX_NDIM);
  267. size_t i = 0;
  268. for (auto v : obj) {
  269. mgb_assert(i < value.ndim);
  270. value.shape[i] = reinterpret_borrow<object>(v).cast<size_t>();
  271. ++i;
  272. }
  273. return true;
  274. }
  275. static handle cast(mgb::TensorShape shape, return_value_policy /* policy */, handle /* parent */) {
  276. // ignore policy and parent because we always return a pure python object
  277. return to_tuple(shape.shape, shape.shape + shape.ndim).release();
  278. }
  279. };
  280. // hack to make custom object implicitly convertible from None
  281. template <typename T> struct from_none_caster : public type_caster_base<T> {
  282. using base = type_caster_base<T>;
  283. bool load(handle src, bool convert) {
  284. if (!convert || !src.is_none()) {
  285. return base::load(src, convert);
  286. }
  287. // adapted from pybind11::implicitly_convertible
  288. auto temp = reinterpret_steal<object>(PyObject_Call(
  289. (PyObject*) this->typeinfo->type, tuple().ptr(), nullptr));
  290. if (!temp) {
  291. PyErr_Clear();
  292. return false;
  293. }
  294. // adapted from pybind11::detail::type_caster_generic
  295. if (base::load(temp, false)) {
  296. loader_life_support::add_patient(temp);
  297. return true;
  298. }
  299. return false;
  300. }
  301. };
  302. template<> struct type_caster<mgb::CompNode> : public from_none_caster<mgb::CompNode> {};
  303. template <> struct type_caster<mgb::PersistentCache::Blob> {
  304. PYBIND11_TYPE_CASTER(mgb::PersistentCache::Blob, _("Blob"));
  305. public:
  306. bool load(handle src, bool convert) {
  307. if (!isinstance<bytes>(src)) {
  308. return false;
  309. }
  310. value.ptr = PYBIND11_BYTES_AS_STRING(src.ptr());
  311. value.size = PYBIND11_BYTES_SIZE(src.ptr());
  312. return true;
  313. }
  314. static handle cast(mgb::PersistentCache::Blob blob, return_value_policy /* policy */, handle /* parent */) {
  315. return bytes((const char*)blob.ptr, blob.size).release();
  316. }
  317. };
  318. template <typename T> struct type_caster<mgb::Maybe<T>> {
  319. using value_conv = make_caster<T>;
  320. PYBIND11_TYPE_CASTER(mgb::Maybe<T>, _("Optional[") + value_conv::name + _("]"));
  321. public:
  322. bool load(handle src, bool convert) {
  323. if(!src) {
  324. return false;
  325. }
  326. if (src.is_none()) {
  327. return true;
  328. }
  329. value_conv inner_caster;
  330. if (!inner_caster.load(src, convert)) {
  331. return false;
  332. }
  333. value.emplace(cast_op<T&&>(std::move(inner_caster)));
  334. return true;
  335. }
  336. static handle cast(mgb::Maybe<T> src, return_value_policy policy, handle parent) {
  337. if(!src.valid()) {
  338. return none().inc_ref();
  339. }
  340. return pybind11::cast(src.val(), policy, parent);
  341. }
  342. };
  343. template<> struct type_caster<mgb::imperative::OpDef> {
  344. protected:
  345. std::shared_ptr<mgb::imperative::OpDef> value;
  346. public:
  347. static constexpr auto name = _("OpDef");
  348. operator mgb::imperative::OpDef&() { return *value; }
  349. operator const mgb::imperative::OpDef&() { return *value; }
  350. operator std::shared_ptr<mgb::imperative::OpDef>&() { return value; }
  351. operator std::shared_ptr<mgb::imperative::OpDef>&&() && { return std::move(value); }
  352. template <typename T> using cast_op_type = T;
  353. bool load(handle src, bool convert);
  354. static handle cast(const mgb::imperative::OpDef& op, return_value_policy /* policy */, handle /* parent */);
  355. static handle cast(std::shared_ptr<mgb::imperative::OpDef> op, return_value_policy policy, handle parent) {
  356. return cast(*op, policy, parent);
  357. }
  358. };
  359. template <> struct type_caster<std::shared_ptr<mgb::imperative::OpDef>> :
  360. public type_caster<mgb::imperative::OpDef> {
  361. template <typename T> using cast_op_type = pybind11::detail::movable_cast_op_type<T>;
  362. };
  363. } // detail
  364. } // PYBIND11_NAMESPACE
  365. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台