
dispatcher.cpp 7.3 kB

/**
 * \file imperative/python/src/dispatcher.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./dispatcher.h"
#include "./pyext17.h"

#include "megbrain/exception.h"
#include "megbrain/utils/hash.h"
#include "megbrain/utils/small_vector.h"

#include <unordered_map>
#include <structmember.h>

namespace py = pybind11;
namespace pyx = pyext17;
namespace {

// A registered handler: the callable plus a flag that lets it be disabled
// temporarily without being removed from the registry.
struct Handler {
    PyObject* func;  // borrowed
    bool enabled;

    Handler() = default;
    Handler(PyObject* func_, bool enable = true) : func(func_), enabled(enable) {}
};

// Call signature used as cache key: the Py_TYPE pointer of each argument.
using FastSig = mgb::SmallVector<void*, 8>;

// Resolution order for one signature: handlers to try, in the order
// produced by dispatch_iter.
using MRO = std::vector<Handler*>;

// One dispatch in progress: the resolved MRO and the position reached in it.
struct Frame {
    MRO* mro;
    size_t mro_offset;

    Frame() = default;
    Frame(MRO* mro_, size_t mro_offset_ = 0) : mro(mro_), mro_offset(mro_offset_) {}
};

struct FastSigHash {
    size_t operator()(const FastSig& sig) const {
        auto* ptr = &sig.front();
        return mgb::XXHash()
                .update(ptr, sig.size() * sizeof(FastSig::value_type))
                .digest();
    }
};

// Hash a py::handle by object identity (its pointer value).
struct ObjectIdHash : std::hash<void*> {
    size_t operator()(const py::handle& h) const {
        return std::hash<void*>::operator()(h.ptr());
    }
};

namespace {

using Container = std::vector<Frame>;

// Stack of in-flight dispatch frames with a hard recursion limit.
struct DispatcherStack : Container {
    constexpr static size_t MAX_RECURSIVE_DEPTH = 1024u;

    DispatcherStack() { reserve(MAX_RECURSIVE_DEPTH); }

    template <typename... Args>
    auto&& emplace_back_safely(Args&&... args) {
        mgb_throw_if(size() >= MAX_RECURSIVE_DEPTH, mgb::MegBrainError,
                     "recursion depth %zu is greater than the MAX_RECURSIVE_DEPTH(%zu)",
                     size(), MAX_RECURSIVE_DEPTH);
        return emplace_back(std::forward<Args>(args)...);
    }
};

} // anonymous namespace
struct Dispatcher {
    // Maps a call signature (the argument types) to its resolved MRO.
    std::unordered_map<FastSig, std::unique_ptr<MRO>, FastSigHash> cache;
    DispatcherStack stack;
    // All registered handlers, keyed by the Python callable itself.
    std::unordered_map<py::object, std::unique_ptr<Handler>, ObjectIdHash> registry;

    inline py::handle self() {
        return pyx::wrap<Dispatcher>::pycast(this);
    }

    // Build the signature from the argument types, look up (or resolve and
    // cache) its MRO, and push a new frame onto the dispatch stack.
    bool prepare_call(PyObject* const* args, Py_ssize_t nargs) {
        FastSig sig(nargs);
        for (Py_ssize_t i = 0; i < nargs; ++i) {
            sig[i] = Py_TYPE(args[i]);
        }
        auto it = cache.find(sig);
        if (it == cache.end()) {
            if (auto mro = resolve(sig)) {
                it = cache.emplace(std::move(sig), std::move(mro)).first;
            } else {
                return false;
            }
        }
        stack.emplace_back_safely(it->second.get());
        return true;
    }

    // Walk the MRO of the current frame, invoking enabled handlers until one
    // returns something other than NotImplemented.
    template <typename T>
    PyObject* do_call(T&& caller) {
        auto& frame = stack.back();
        auto& mro = *frame.mro;
        auto& i = frame.mro_offset;
        if (!mro.size()) {
            PyErr_SetString(PyExc_NotImplementedError, "function not registered in dispatcher");
            return nullptr;
        }
        for (; i < mro.size(); ++i) {
            if (mro[i]->enabled) {
                auto ret = caller(mro[i]->func);
                if (ret != Py_NotImplemented) {
                    stack.pop_back();
                    return ret;
                }
                Py_DECREF(ret);
            }
        }
        PyErr_SetString(PyExc_NotImplementedError, "mro exhausted");
        stack.pop_back();
        return nullptr;
    }

    // Ask the Python side (dispatch_iter) which handlers match this signature
    // and translate them into registry entries.
    std::unique_ptr<MRO> resolve(const FastSig& sig) {
        try {
            py::tuple args(sig.size());
            for (size_t i = 0; i < sig.size(); ++i) {
                args[i] = (PyObject*)sig[i];
            }
            auto mro_iter = self().attr("dispatch_iter")(*args);
            auto ret = std::make_unique<MRO>();
            for (auto i : mro_iter) {
                auto it = registry.find(py::reinterpret_borrow<py::object>(i));
                if (it == registry.end()) {
                    PyErr_SetString(PyExc_RuntimeError, "resolved to unregistered function");
                    return nullptr;
                }
                ret->push_back(it->second.get());
            }
            return ret;
        } catch (py::error_already_set& e) {
            e.restore();
        } catch (std::runtime_error& e) {
            PyErr_SetString(PyExc_RuntimeError, e.what());
        }
        return nullptr;
    }

public:
    static constexpr auto tp_name = "Dispatcher";

    PyObject* tp_call(PyObject* args, PyObject* kwargs) {
        if (!prepare_call(&PyTuple_GET_ITEM(args, 0), PyTuple_GET_SIZE(args))) return nullptr;
        return do_call([=](PyObject* func) { return PyObject_Call(func, args, kwargs); });
    }

#if PY_MINOR_VERSION >= 6
    PyObject* tp_vectorcall(PyObject* const* args, Py_ssize_t nargs) {
        if (!prepare_call(args, nargs)) return nullptr;
        return do_call([=](PyObject* func) { return _PyObject_FastCall(func, const_cast<PyObject**>(args), nargs); });
    }
#endif

#if PY_MINOR_VERSION >= 6
    // Re-dispatch to the next handler in the current MRO (analogous to Python's super()).
    PyObject* super(PyObject* const* args, Py_ssize_t nargs) {
        if (stack.empty()) {
            PyErr_SetString(PyExc_RuntimeError, "super called at top level");
            return nullptr;
        }
        stack.emplace_back_safely(stack.back()).mro_offset++;
        return do_call([=](PyObject* func) { return _PyObject_FastCall(func, const_cast<PyObject**>(args), nargs); });
    }
#else
    PyObject* super(PyObject* args, PyObject* kwargs) {
        if (stack.empty()) {
            PyErr_SetString(PyExc_RuntimeError, "super called at top level");
            return nullptr;
        }
        stack.emplace_back_safely(stack.back()).mro_offset++;
        return do_call([=](PyObject* func) { return PyObject_Call(func, args, kwargs); });
    }
#endif

    // Register a handler, or re-enable one that was previously disabled.
    void enable(PyObject* func) {
        auto obj = py::reinterpret_borrow<py::object>(func);
        auto it = registry.find(obj);
        if (it != registry.end()) {
            it->second->enabled = true;
        } else {
            registry.emplace(std::move(obj), std::make_unique<Handler>(func));
        }
    }

    PyObject* disable(PyObject* func) {
        auto obj = py::reinterpret_borrow<py::object>(func);
        auto it = registry.find(obj);
        if (it == registry.end()) {
            PyErr_SetString(PyExc_ValueError, "function not registered");
            return nullptr;
        } else {
            it->second->enabled = false;
        }
        Py_RETURN_NONE;
    }

    void clear_cache() {
        cache.clear();
    }
};

} // namespace
void init_dispatcher(py::module m) {
    auto* dispatcher_type = pyx::wrap<Dispatcher>::type()
                                    .def<&Dispatcher::enable>("enable")
                                    .def<&Dispatcher::disable>("disable")
                                    .def<&Dispatcher::clear_cache>("clear_cache")
#if PY_MINOR_VERSION >= 6
                                    .def<&Dispatcher::tp_vectorcall>("call")
#else
                                    .def<&Dispatcher::tp_call>("call")
#endif
                                    .def<&Dispatcher::super>("super")
                                    .finalize();
    if (!dispatcher_type) throw py::error_already_set();
    m.attr("Dispatcher") = dispatcher_type;
}
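
The binding above exposes a Dispatcher type with enable, disable, clear_cache, call and super, and resolve() calls back into a Python-provided dispatch_iter with the type of each argument. The following Python sketch shows how that contract could be used; it is only an illustration under assumptions (that the bound type can be subclassed from Python and that dispatch_iter yields previously enabled callables), and the names MyDispatcher, register, generic_add and int_add are hypothetical.

handlers = []

class MyDispatcher(Dispatcher):
    def dispatch_iter(self, *arg_types):
        # Called from Dispatcher::resolve with the Py_TYPE of each argument;
        # must yield callables that were previously passed to enable().
        for func in reversed(handlers):
            yield func

disp = MyDispatcher()

def register(func):
    handlers.append(func)
    disp.enable(func)       # put func into the C++ registry
    disp.clear_cache()      # invalidate cached signature -> MRO entries
    return func

@register
def generic_add(x, y):
    return NotImplemented   # defer to the next handler in the MRO

@register
def int_add(x, y):
    if isinstance(x, int) and isinstance(y, int):
        return x + y
    return NotImplemented

print(disp.call(1, 2))      # resolved by the types of (1, 2); prints 3

A handler signals "not handled" by returning NotImplemented, which do_call skips over; calling disp.super(...) inside a handler would re-enter the same MRO one position further on.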

The MegEngine package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. If you want to run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
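
As a quick sanity check after installation, something along these lines can confirm whether the bundled CUDA environment actually sees a GPU (a minimal sketch; it assumes the megengine.is_cuda_available() helper from the public API, so check the documentation of your installed version):

import megengine as mge

# Execution falls back to CPU automatically when no GPU is visible.
if mge.is_cuda_available():
    print("GPU detected, CUDA will be used")
else:
    print("No GPU detected, running on CPU")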