
tensor.h 10 kB

/**
 * \file imperative/python/src/tensor.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#pragma GCC diagnostic ignored "-Wmissing-field-initializers"

#include <variant>
#include <string>
#include <unordered_map>

#include "megbrain/imperative/interpreter.h"
#include "pybind11/pybind11.h"

#include "./pyext17.h"

namespace mgb::imperative::python {

template <typename T, typename B = pybind11::object>
struct ObjectPtr : B {
    using B::B;
    T& operator*() { return reinterpret_cast<T&>(*B::ptr()); }
    T* operator->() { return reinterpret_cast<T*>(B::ptr()); }
};
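
// Usage sketch (illustrative, not part of the original header): ObjectPtr
// keeps the ownership semantics of its pybind11 base B (pybind11::object by
// default) while exposing the underlying PyObject* as a T. For example,
// given `wrapper`, a hypothetical TensorWrapper* (defined further below):
//
//     ObjectPtr<TensorWrapper, pybind11::handle> self = wrapper->self();
//     self->m_tensor;  // reinterprets the wrapped PyObject* as TensorWrapper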

}  // namespace mgb::imperative::python

#include "./grad_info.h"   // for struct GradInfo
#include "./trace_info.h"  // for struct TraceInfo

namespace mgb::imperative::python {

struct GradKey;

extern interpreter::Interpreter::Channel* interpreter_for_py;

class SharedHandle {
    using Handle = interpreter::Interpreter::Handle;
    static_assert(std::is_pointer_v<Handle>);
    std::shared_ptr<std::remove_pointer_t<Handle>> holder;

public:
    inline explicit SharedHandle(Handle handle)
            : holder(handle, [](auto* h) {
                  if (h) {
                      interpreter_for_py->del(h);
                  }
              }) {}
    SharedHandle(const SharedHandle&) = default;
    SharedHandle& operator=(const SharedHandle&) = default;
    SharedHandle(SharedHandle&&) = default;
    SharedHandle& operator=(SharedHandle&&) = default;

    inline Handle get() { return holder.get(); }
};
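
// Usage sketch (illustrative; `raw_handle` stands for a hypothetical,
// freshly-created interpreter handle, not something defined here).
// SharedHandle ref-counts the handle and routes the final release through
// the interpreter channel exactly once:
//
//     SharedHandle h{raw_handle};  // takes ownership
//     SharedHandle h2 = h;         // copies share ownership
//     // interpreter_for_py->del(raw_handle) runs when the last copy dies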

// impl in grad.cpp
class GradInfoCollection {
private:
    SmallVector<GradInfo> m_storage;

protected:
    void _shrink();

public:
    bool contains(GradKey* key);
    GradInfo& operator[](GradKey* key);
    GradInfo& at(GradKey* key);
    bool empty() {
        _shrink();
        return m_storage.empty();
    }
    auto begin() {
        _shrink();
        return m_storage.begin();
    }
    auto end() {
        _shrink();
        return m_storage.end();
    }
    size_t count(GradKey* key) { return contains(key) ? 1 : 0; }
};
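
// Usage sketch (illustrative): GradInfoCollection acts as a small map keyed
// by GradKey*. Each accessor above calls _shrink() first, which presumably
// prunes entries whose key has expired (implementation in grad.cpp), so
// iteration only visits live grad records:
//
//     GradInfoCollection& grads = tensor->m_grad_info_dict;
//     if (grads.contains(key)) {
//         GradInfo& info = grads.at(key);
//     }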

struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
    using flags_t = uint64_t;

    struct Flags {
        static constexpr flags_t SCALAR = 1;
        static constexpr flags_t GRAD = 1 << 1;
        static constexpr flags_t TRACE = 1 << 2;
        static constexpr flags_t MODULE_TRACE = 1 << 3;
    };

    flags_t m_flags = 0;

    GradInfoCollection m_grad_info_dict;
    TraceInfo m_trace_info;
    SharedHandle m_handle;
    std::string user_custom_name;
    std::string automatic_name;
    cg::VarNode* m_var;
    pybind11::object m_module_trace_info;

    using Handle = interpreter::Interpreter::Handle;

    inline Tensor() : m_handle(nullptr), m_var(nullptr) {}
    inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {}
    inline explicit Tensor(SharedHandle handle)
            : m_handle(std::move(handle)), m_var(nullptr) {}
    inline explicit Tensor(cg::VarNode* var) : m_handle(nullptr), m_var(var) {}

    ~Tensor() = default;

    inline std::shared_ptr<Tensor> copy() {
        auto ret = std::make_shared<Tensor>(m_handle);
        ret->m_flags = m_flags;
        ret->m_grad_info_dict = m_grad_info_dict;
        ret->m_trace_info = m_trace_info;
        ret->m_var = m_var;
        return ret;
    }

    inline DType dtype() {
        if (m_var) {
            return m_var->dtype();
        }
        return interpreter_for_py->get_dtype(m_handle.get());
    }
    inline CompNode comp_node() {
        if (m_var) {
            return m_var->comp_node();
        }
        return interpreter_for_py->get_device(m_handle.get());
    }
    inline TensorShape shape() {
        if (m_var) {
            return m_var->shape();
        }
        return interpreter_for_py->get_shape(m_handle.get());
    }
};
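
// Usage sketch (illustrative): a Tensor is backed either by an interpreter
// handle (eager execution) or by a cg::VarNode (graph tracing); dtype(),
// comp_node() and shape() above dispatch to whichever backing exists.
// Flags combine as a bitmask:
//
//     auto t = std::make_shared<Tensor>(std::move(handle));  // handle: SharedHandle
//     t->m_flags |= Tensor::Flags::GRAD;
//     if (t->m_flags & Tensor::Flags::GRAD) {
//         // gradient tracking requested for t
//     }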

struct TensorWrapper {
    std::shared_ptr<Tensor> m_tensor;

    inline TensorWrapper(std::shared_ptr<Tensor> tensor = {})
            : m_tensor(std::move(tensor)) {}
    TensorWrapper(PyObject* args, PyObject* kwargs);
    ~TensorWrapper() = default;

    static constexpr auto tp_name = pybind11::detail::_("Tensor");

    using wrap_t = pyext17::wrap<TensorWrapper>;
    friend wrap_t;

    inline static TensorWrapper* cast(PyObject* obj) {
        return reinterpret_cast<wrap_t*>(obj)->inst();
    }
    inline static TensorWrapper* try_cast(PyObject* obj) {
        if (!wrap_t::type().isinstance(obj))
            return nullptr;
        return cast(obj);
    }

    inline ObjectPtr<TensorWrapper, pybind11::handle> self() {
        return wrap_t::pycast(this);
    }

    template <typename... Args>
    static ObjectPtr<Tensor> make(Args&&... args) {
        auto* op = wrap_t::cnew(std::forward<Args>(args)...);
        return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
    }

    template <typename... Args>
    static ObjectPtr<Tensor> make(PyTypeObject* pytype, Args&&... args) {
        auto* op = wrap_t::cnew_with_type(pytype, std::forward<Args>(args)...);
        return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
    }

    PyObject* shape();
    PyObject* dtype();
    PyObject* device();
    PyObject* numpy();
    void reset(PyObject*);
    PyObject* detach();
    PyObject* isscalar();
    void setscalar();
    void unsetscalar();
    PyObject* _dev_tensor();
    void _swap_in();
    void _swap_out();
    void _drop();
    PyObject* varnode();
    void reset_varnode();
    PyObject* handle();
    void set_handle(PyObject*);

    PyObject* mixin_handle();
    PyObject* recording();
    PyObject* copied();
    void set_mixin_handle(PyObject*);
    void set_recording(PyObject*);

    PyObject* compiled_info();
    void set_compiled_info(PyObject*);
    PyObject* trace_mixin_info();
    void set_trace_mixin_info(PyObject*);
    PyObject* module_trace_info();
    void set_module_trace_info(PyObject*);
    PyObject* user_custom_name();
    void set_user_custom_name(PyObject*);
    PyObject* automatic_name();
    void set_automatic_name(PyObject*);
    PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }
};
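
// Usage sketch (illustrative): TensorWrapper is the Python-facing object
// around Tensor, implemented via pyext17. make() allocates a new Python
// object and returns it as a smart handle; try_cast() safely recovers the
// C++ wrapper from an arbitrary PyObject*:
//
//     auto obj = TensorWrapper::make(std::move(tensor));  // tensor: shared_ptr<Tensor>
//     if (auto* tw = TensorWrapper::try_cast(obj.ptr())) {
//         PyObject* shape = tw->shape();
//     }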

struct PySymbolVar {
    cg::VarNode* m_node = nullptr;
    bool is_scalar = false;
    PySymbolVar() = default;
    PySymbolVar(VarNode* m) : m_node(m) {}
};

PyObject* py_apply(
        PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */);

struct ApplyContext {
    static Tensor::flags_t global_disable;
    static Tensor::flags_t global_enable;

    Tensor::flags_t flags = 0;
    std::shared_ptr<OpDef> op;
    Tensor* const* args;
    size_t nargs;
    PyTypeObject* pytype = nullptr;
    bool backward = false;

    class scoped_disable : NonCopyableObj {
        Tensor::flags_t saved_flags;

    public:
        scoped_disable(Tensor::flags_t flags)
                : saved_flags(ApplyContext::global_disable) {
            ApplyContext::global_disable |= flags;
        }
        ~scoped_disable() { ApplyContext::global_disable = saved_flags; }
    };
};
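
// Usage sketch (illustrative): scoped_disable is an RAII mask over
// ApplyContext::global_disable; the bits passed in stay disabled for the
// guard's lifetime and the previous mask is restored on scope exit:
//
//     {
//         ApplyContext::scoped_disable guard(Tensor::Flags::TRACE);
//         // dispatch in this scope sees TRACE as globally disabled
//     }  // prior global_disable restored here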

using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;

apply_result_t apply(ApplyContext& ctx);

template <typename T>
decltype(auto) resolve_arrow(T&& p) {
    if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) {
        auto* ret = p;
        return ret;
    } else {
        // probe is never called; it only checks that p.operator->() is
        // well-formed, so smart-pointer-like arguments recurse one level
        auto probe = [](auto&& p) -> decltype(p.operator->()) {};
        if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) {
            return resolve_arrow(p.operator->());
        } else {
            return std::forward<T>(p);
        }
    }
}

template <typename... Args>
constexpr bool is_all_tensor_ptr =
        (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>);

template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0>
apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
    ApplyContext ctx;
    Tensor* arg_arr[] = {resolve_arrow(args)...};
    ctx.flags = (0 | ... | args->m_flags);
    ctx.args = arg_arr;
    ctx.nargs = sizeof...(args);
    ctx.op = std::move(op);
    return apply(ctx);
}
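
// Usage sketch (illustrative): via resolve_arrow, the variadic overload above
// accepts any mix of arguments whose operator-> chain bottoms out at Tensor*
// (raw Tensor*, std::shared_ptr<Tensor>, ...). Assuming `a`, `b` and `op`
// are already constructed:
//
//     std::shared_ptr<Tensor> a, b;
//     std::shared_ptr<OpDef> op;
//     apply_result_t outputs = apply(op, a, b);
//     apply_result_t outputs2 = apply(op, a.get(), b.get());  // equivalent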

inline auto apply(std::shared_ptr<OpDef> op, Tensor* const* args, size_t nargs) {
    ApplyContext ctx;
    ctx.op = std::move(op);
    ctx.nargs = nargs;
    ctx.args = args;
    for (size_t i = 0; i < nargs; ++i) {
        ctx.flags |= args[i]->m_flags;
    }
    return apply(ctx);
}

template <typename T>
auto apply(std::shared_ptr<OpDef> op, T&& tensors) -> std::enable_if_t<
        std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>, apply_result_t> {
    size_t nargs = tensors.size();
    Tensor* args[nargs];  // note: a variable-length array, a GCC/Clang extension
    for (size_t i = 0; i < nargs; ++i) {
        args[i] = resolve_arrow(tensors[i]);
    }
    return apply(op, args, nargs);
}

std::shared_ptr<Tensor> make_const(imperative::TensorPtr value);

inline auto apply(Subgraph graph, Tensor* const* args, size_t nargs) {
    SmallVector<std::shared_ptr<Tensor>> inputs;
    for (size_t i = 0; i < nargs; ++i) {
        inputs.push_back(args[i]->shared_from_this());
    }
    auto apply_functor = [](std::shared_ptr<OpDef> op,
                            SmallVector<std::shared_ptr<Tensor>> inputs,
                            size_t) { return apply(op, std::move(inputs)); };
    return graph.apply(inputs, apply_functor, &make_const);
}

template <typename T>
auto apply(Subgraph graph, T&& tensors) -> std::enable_if_t<
        std::is_same_v<std::decay_t<decltype(tensors[0])>, Tensor*>, apply_result_t> {
    size_t nargs = tensors.size();
    Tensor* args[nargs];  // note: a variable-length array, a GCC/Clang extension
    for (size_t i = 0; i < nargs; ++i) {
        args[i] = resolve_arrow(tensors[i]);
    }
    return apply(graph, args, nargs);
}

void init_tensor(pybind11::module);

extern PyObject* cpp_apply_with_tracing;
extern PyObject* cpp_apply_backward_varnode;
extern PyObject* cpp_apply_module_trace;

}  // namespace mgb::imperative::python

namespace pybind11::detail {

template <>
struct type_caster<mgb::imperative::python::TensorWrapper>
        : mgb::imperative::python::TensorWrapper::wrap_t::caster {};

}  // namespace pybind11::detail
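
// With the caster above registered, functions bound through pybind11 can
// take or return TensorWrapper directly. A hypothetical binding
// (illustrative only; `m` stands for a pybind11::module):
//
//     m.def("use_cnt", [](mgb::imperative::python::TensorWrapper& tw) {
//         return tw.m_tensor.use_count();
//     });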

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and the driver installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.