You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

interpreter_impl.cpp 7.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. /**
  2. * \file imperative/src/impl/interpreter_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./interpreter_impl.h"
  12. using namespace mgb;
  13. using namespace imperative;
  14. using namespace interpreter;
  15. using namespace interpreter::intl;
  16. std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
  17. return std::make_unique<ChannelImpl>();
  18. }
  19. Interpreter& Interpreter::inst() {
  20. static InterpreterImpl inst_;
  21. return inst_;
  22. }
  23. void* ChannelImpl::put(const HostTensorND& value) {
  24. auto info = alloc();
  25. info->desc.layout = value.layout();
  26. info->desc.comp_node = value.comp_node();
  27. info->desc.value = value.proxy_to_default_cpu();
  28. m_valid_handle.insert(info);
  29. m_worker.add_task(Put{info, value});
  30. return info;
  31. }
  32. void* ChannelImpl::put(const DeviceTensorND& data) {
  33. auto info = alloc();
  34. info->desc.layout = data.layout();
  35. info->desc.comp_node = data.comp_node();
  36. info->ptr = Tensor::make(data);
  37. m_valid_handle.insert(info);
  38. return info;
  39. }
  40. void ChannelImpl::del(void* handle) {
  41. mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
  42. m_worker.add_task(Del{reinterpret_cast<TensorInfo*>(handle)});
  43. }
  44. SmallVector<void*> ChannelImpl::apply_op(
  45. std::shared_ptr<OpDef> op,
  46. const SmallVector<void*>& inputs) {
  47. SmallVector<LogicalTensorDesc> input_descs;
  48. input_descs.reserve(inputs.size());
  49. for (auto h : inputs) {
  50. auto info = reinterpret_cast<TensorInfo*>(h);
  51. input_descs.push_back(info->desc);
  52. }
  53. auto output_descs = OpDef::infer_output_attrs_fallible(*op, input_descs);
  54. ApplyOp cmd{std::move(op)};
  55. cmd.inputs.reserve(inputs.size());
  56. for (auto i : inputs) {
  57. cmd.inputs.push_back(reinterpret_cast<TensorInfo*>(i));
  58. }
  59. cmd.outputs.reserve(output_descs.size());
  60. SmallVector<void*> outputs;
  61. for (auto&& desc : output_descs) {
  62. auto info = alloc();
  63. info->desc = desc;
  64. m_valid_handle.insert(info);
  65. cmd.outputs.push_back(info);
  66. outputs.push_back(info);
  67. }
  68. m_worker.add_task(std::move(cmd));
  69. return outputs;
  70. }
  71. HostTensorND ChannelImpl::get_value(void* handle) {
  72. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  73. "invalid handle: %p", handle);
  74. auto info = reinterpret_cast<TensorInfo*>(handle);
  75. std::unique_lock<decltype(m_mutex)> lock(m_mutex);
  76. mgb_assert(!m_waitee);
  77. if (!info->value_fetched) {
  78. m_waitee = info;
  79. m_worker.add_task(GetValue{info});
  80. m_cv.wait(lock, [&]() {
  81. check_worker_exc_unsafe();
  82. return info->value_fetched;
  83. });
  84. m_waitee = nullptr;
  85. }
  86. mgb_assert(info->ptr->value_fetched());
  87. return info->ptr->get_value();
  88. }
  89. TensorShape ChannelImpl::get_shape(void* handle) {
  90. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  91. "invalid handle: %p", handle);
  92. auto info = reinterpret_cast<TensorInfo*>(handle);
  93. if (info->desc.layout.ndim != 0) {
  94. return info->desc.layout;
  95. }
  96. std::unique_lock<decltype(m_mutex)> lock(m_mutex);
  97. mgb_assert(!m_waitee);
  98. m_waitee = info;
  99. m_cv.wait(lock, [&]() {
  100. check_worker_exc_unsafe();
  101. return bool(info->ptr);
  102. });
  103. m_waitee = nullptr;
  104. TensorShape ret = info->ptr->layout();
  105. mgb_assert(ret.ndim != 0);
  106. return ret;
  107. }
  108. DType ChannelImpl::get_dtype(void* handle) {
  109. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  110. "invalid handle: %p", handle);
  111. auto info = reinterpret_cast<TensorInfo*>(handle);
  112. auto ret = info->desc.layout.dtype;
  113. mgb_assert(ret.valid());
  114. return ret;
  115. }
  116. CompNode ChannelImpl::get_device(void* handle) {
  117. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  118. "invalid handle: %p", handle);
  119. auto info = reinterpret_cast<TensorInfo*>(handle);
  120. auto ret = info->desc.comp_node;
  121. mgb_assert(ret.valid());
  122. return ret;
  123. }
  124. DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) {
  125. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  126. "invalid handle: %p", handle);
  127. auto info = reinterpret_cast<TensorInfo*>(handle);
  128. std::unique_lock<decltype(m_mutex)> lock(m_mutex);
  129. mgb_assert(!m_waitee);
  130. m_waitee = info;
  131. m_cv.wait(lock, [&]() {
  132. check_worker_exc_unsafe();
  133. return bool(info->ptr);
  134. });
  135. m_waitee = nullptr;
  136. return info->ptr->dev_tensor();
  137. }
// Block until the worker has drained its task queue, then surface any
// exception the worker recorded while processing those tasks.
void ChannelImpl::sync() {
    m_worker.wait_all_task_finish();
    // m_worker_exc is guarded by m_mutex; "unsafe" means the callee
    // expects the lock to already be held.
    MGB_LOCK_GUARD(m_mutex);
    check_worker_exc_unsafe();
}
// Closing the channel currently just flushes pending work; no further
// teardown is performed here.
void ChannelImpl::close() {
    sync();
}
// Configuring the async level is not implemented for this channel;
// any call is treated as a programming error. `level` is intentionally
// unused until an implementation lands.
void ChannelImpl::config_async_level(int level) {
    mgb_assert(0);
}
// Allocate a TensorInfo from the shared pool. m_mutex serializes pool
// access against the worker thread, which frees infos via free().
TensorInfo* ChannelImpl::alloc() {
    MGB_LOCK_GUARD(m_mutex);
    return m_pool.alloc();
}
// Return a TensorInfo to the shared pool. m_mutex serializes pool
// access against alloc() running on the API thread.
void ChannelImpl::free(TensorInfo* ptr) {
    MGB_LOCK_GUARD(m_mutex);
    m_pool.free(ptr);
}
// Drain outstanding worker tasks before members are destroyed.
ChannelImpl::~ChannelImpl() {
    close();
}
// Publish a computed tensor into `dest` and wake any thread blocked on
// it. Runs on the worker thread; m_mutex guards dest's fields and
// m_waitee against the blocking getters.
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
    MGB_LOCK_GUARD(m_mutex);
    // Cache whether the host value is already available so readers can
    // check the flag without dereferencing the tensor pointer.
    dest->value_fetched = ptr->value_fetched();
    dest->ptr = std::move(ptr);
    // Only notify when someone is actually waiting on this tensor.
    if (m_waitee == dest) {
        m_cv.notify_all();
    }
}
// Worker-thread dispatcher: execute one queued command. Any exception
// thrown while processing is captured into m_worker_exc and later
// re-raised on the API thread by check_worker_exc_unsafe().
void ChannelImpl::process_one_task(Command& cmd) {
    //TODO: remove std::visit for support osx 10.12
    std::visit([this](auto& cmd) {
        using T = std::remove_reference_t<decltype(cmd)>;
        try {
            if constexpr (std::is_same_v<T, Put>) {
                // Materialize the host value and publish the result.
                produce_tensor(cmd.dest, Tensor::make(cmd.value));
            } else if constexpr (std::is_same_v<T, ApplyOp>) {
                SmallVector<TensorPtr> tensor_inputs;
                tensor_inputs.reserve(cmd.inputs.size());
                for (auto i : cmd.inputs) {
                    tensor_inputs.push_back(i->ptr);
                }
                auto tensor_outputs = OpDef::apply_on_physical_tensor(*cmd.op, tensor_inputs);
                mgb_assert(tensor_outputs.size() == cmd.outputs.size());
                // Publish each output; waiters are woken per tensor.
                for (size_t i = 0; i < tensor_outputs.size(); ++i) {
                    produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i]));
                }
            } else if constexpr (std::is_same_v<T, Del>) {
                free(cmd.dest);
            } else if constexpr (std::is_same_v<T, GetValue>) {
                // Fetch device->host outside the lock; set the flag and
                // notify under it so waiters observe a consistent state.
                cmd.dest->ptr->fetch_value();
                MGB_LOCK_GUARD(m_mutex);
                cmd.dest->value_fetched = true;
                if (m_waitee == cmd.dest) {
                    m_cv.notify_all();
                }
            } else {
                // Exhaustiveness guard: compilation fails if a Command
                // alternative is not handled above.
                static_assert(!std::is_same_v<T, T>);
            }
        } catch (...) {
            // Park the exception for the API thread and wake any waiter
            // so it can observe it via check_worker_exc_unsafe().
            MGB_LOCK_GUARD(m_mutex);
            m_worker_exc = std::current_exception();
            m_cv.notify_all();
        }
    }, cmd);
}
  205. void ChannelImpl::check_worker_exc_unsafe() {
  206. if (m_worker_exc) {
  207. std::exception_ptr exc;
  208. std::swap(exc, m_worker_exc);
  209. std::rethrow_exception(exc);
  210. }
  211. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台