
tensor.h

/**
 * \file imperative/python/src/tensor.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include <variant>

#include "megbrain/imperative/interpreter.h"
#include "pybind11/pybind11.h"

#include "./pyext17.h"

namespace mgb::imperative::python {

template <typename T, typename B = pybind11::object>
struct ObjectPtr : B {
    using B::B;
    T& operator*() { return reinterpret_cast<T&>(*B::ptr()); }
    T* operator->() { return reinterpret_cast<T*>(B::ptr()); }
};

} // namespace mgb::imperative::python

#include "./grad_info.h"  // for struct GradInfo
#include "./trace_info.h" // for struct TraceInfo

namespace mgb::imperative::python {

extern std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;

// Shared ownership of an interpreter handle; the handle is released via
// interpreter_for_py->del() when the last copy goes out of scope.
class SharedHandle {
    using Handle = interpreter::Interpreter::Handle;
    static_assert(std::is_pointer_v<Handle>);
    std::shared_ptr<std::remove_pointer_t<Handle>> holder;

public:
    inline explicit SharedHandle(Handle handle) : holder(handle, [](auto* h) {
        if (h) {
            interpreter_for_py->del(h);
        }
    }) {}
    SharedHandle(const SharedHandle&) = default;
    SharedHandle& operator=(const SharedHandle&) = default;
    SharedHandle(SharedHandle&&) = default;
    SharedHandle& operator=(SharedHandle&&) = default;

    inline Handle get() { return holder.get(); }
};

struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
    using flags_t = uint64_t;

    struct Flags {
        static constexpr flags_t SCALAR = 1;
        static constexpr flags_t GRAD = 1 << 1;
        static constexpr flags_t TRACE = 1 << 2;
    };

    flags_t m_flags = 0;

    GradInfo m_grad_info;
    TraceInfo m_trace_info;
    SharedHandle m_handle;
    cg::VarNode* m_var;

    using Handle = interpreter::Interpreter::Handle;

    inline Tensor() : m_handle(nullptr), m_var(nullptr) {}
    inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {}
    inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)), m_var(nullptr) {}
    inline explicit Tensor(cg::VarNode* var) : m_handle(nullptr), m_var(var) {}

    ~Tensor() = default;

    inline std::shared_ptr<Tensor> copy() {
        auto ret = std::make_shared<Tensor>(m_handle);
        ret->m_flags = m_flags;
        ret->m_grad_info = m_grad_info;
        ret->m_trace_info = m_trace_info;
        ret->m_var = m_var;
        return ret;
    }

    inline DType dtype() {
        if (m_var) {
            return m_var->dtype();
        }
        return interpreter_for_py->get_dtype(m_handle.get());
    }

    inline CompNode comp_node() {
        if (m_var) {
            return m_var->comp_node();
        }
        return interpreter_for_py->get_device(m_handle.get());
    }

    inline TensorShape shape() {
        if (m_var) {
            return m_var->shape();
        }
        return interpreter_for_py->get_shape(m_handle.get());
    }
};

// Python-facing wrapper around Tensor, exposed through the pyext17 binding helper.
struct TensorWrapper {
    std::shared_ptr<Tensor> m_tensor;

    inline TensorWrapper(std::shared_ptr<Tensor> tensor = {}) : m_tensor(std::move(tensor)) {}
    TensorWrapper(PyObject* args, PyObject* kwargs);
    ~TensorWrapper() = default;

    static constexpr auto tp_name = pybind11::detail::_("Tensor");

    using wrap_t = pyext17::wrap<TensorWrapper>;
    friend wrap_t;

    inline static TensorWrapper* cast(PyObject* op) { return reinterpret_cast<wrap_t*>(op)->inst(); }
    inline static TensorWrapper* cast_safe(PyObject* op) {
        if (!wrap_t::type().isinstance(op)) return nullptr;
        return cast(op);
    }
    inline ObjectPtr<TensorWrapper, pybind11::handle> self() { return wrap_t::pycast(this); }

    template <typename... Args>
    static ObjectPtr<Tensor> make(Args&&... args) {
        auto* op = wrap_t::cnew(std::forward<Args>(args)...);
        return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
    }

    template <typename... Args>
    static ObjectPtr<Tensor> make(PyTypeObject* pytype, Args&&... args) {
        auto* op = wrap_t::cnew_with_type(pytype, std::forward<Args>(args)...);
        return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
    }

    PyObject* shape();
    PyObject* dtype();
    PyObject* device();
    PyObject* numpy();
    void reset(PyObject*);
    PyObject* detach();
    PyObject* isscalar();
    void setscalar();
    PyObject* _dev_tensor();
    void _swap_in();
    void _swap_out();
    void _drop();
    PyObject* varnode();
    PyObject* handle();
    void set_handle(PyObject*);

    PyObject* data_read();
    PyObject* value_read();
    PyObject* shape_read();
    PyObject* mixin_handle();

    void set_data_read(PyObject*);
    void set_value_read(PyObject*);
    void set_shape_read(PyObject*);
    void set_mixin_handle(PyObject*);
};

PyObject* py_apply(PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */);

struct ApplyContext {
    Tensor::flags_t flags;
    std::shared_ptr<OpDef> op;
    Tensor* const* args;
    size_t nargs;
    bool backward = false;
};

using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;

apply_result_t apply(ApplyContext& ctx);

void init_tensor(pybind11::module);

extern bool is_tracing;
extern bool is_symbolic;
extern bool is_compiled;

extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode;
extern pybind11::object cpp_apply_backward_varnode;

} // namespace mgb::imperative::python

namespace pybind11::detail {

template <>
struct type_caster<mgb::imperative::python::TensorWrapper>
        : mgb::imperative::python::TensorWrapper::wrap_t::caster {};

} // namespace pybind11::detail
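The TensorWrapper methods declared above (shape, dtype, device, numpy, detach, and so on) are the C++ entry points behind the Python-level Tensor object that init_tensor registers. A minimal sketch of how this surfaces on the Python side (assuming a working megengine installation; the attribute-to-method mapping shown in the comments is the natural one suggested by the names, not something this header spells out):

import numpy as np
import megengine as mge

x = mge.Tensor(np.ones((2, 3), dtype="float32"))  # wraps a device tensor handle
print(x.shape)    # cf. TensorWrapper::shape()
print(x.dtype)    # cf. TensorWrapper::dtype()
print(x.device)   # cf. TensorWrapper::device()
print(x.numpy())  # cf. TensorWrapper::numpy(); copies the value back into a numpy.ndarray
y = x.detach()    # cf. TensorWrapper::detach(); a tensor detached from gradient tracking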

The MegEngine installation package already bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build to choose between. If you want to run GPU programs, make sure the machine has a GPU device and that its driver is installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
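Because the same wheel serves both CPU-only and GPU machines, GPU support can be probed at runtime. A minimal sketch (assuming megengine was installed from the official wheels, e.g. via pip):

import megengine as mge

print(mge.__version__)
print("CUDA available:", mge.is_cuda_available())  # True only when a GPU and its driver are present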