You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

interpreter_impl.cpp 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. /**
  2. * \file imperative/src/impl/interpreter_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./interpreter_impl.h"
  12. using namespace mgb;
  13. using namespace imperative;
  14. using namespace interpreter;
  15. using namespace interpreter::intl;
  16. std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
  17. return std::make_unique<ChannelImpl>();
  18. }
  19. Interpreter& Interpreter::inst() {
  20. static InterpreterImpl inst_;
  21. return inst_;
  22. }
  23. void* ChannelImpl::put(const HostTensorND& value) {
  24. auto info = alloc();
  25. info->desc.layout = value.layout();
  26. info->desc.comp_node = value.comp_node();
  27. info->desc.value = value.proxy_to_default_cpu();
  28. m_valid_handle.insert(info);
  29. m_worker.add_task(Put{info, value});
  30. return info;
  31. }
  32. void* ChannelImpl::put(const DeviceTensorND& data) {
  33. auto info = alloc();
  34. info->desc.layout = data.layout();
  35. info->desc.comp_node = data.comp_node();
  36. info->ptr = Tensor::make(data);
  37. m_valid_handle.insert(info);
  38. return info;
  39. }
  40. void ChannelImpl::del(void* handle) {
  41. mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
  42. m_worker.add_task(Del{reinterpret_cast<TensorInfo*>(handle)});
  43. }
// Asynchronously apply `op` to `inputs`, returning handles to the outputs.
// Output descriptors are inferred eagerly on the caller thread so that
// shape/dtype/device queries usually need not wait for the worker; the
// actual computation is queued as an ApplyOp task.
SmallVector<void*> ChannelImpl::apply_op(
        std::shared_ptr<OpDef> op,
        const SmallVector<void*>& inputs) {
    SmallVector<TensorInfo*> input_infos;
    input_infos.reserve(inputs.size());
    SmallVector<LogicalTensorDesc> input_descs;
    input_descs.reserve(inputs.size());
    for (auto i : inputs) {
        auto info = reinterpret_cast<TensorInfo*>(i);
        input_infos.push_back(info);
        input_descs.push_back(info->desc);
    }
    // "Fallible": inference may yield descs with ndim == 0 when the output
    // shape cannot be deduced without actually running the op.
    auto output_descs = OpDef::infer_output_attrs_fallible(*op, input_descs);
    ApplyOp cmd{std::move(op)};
    cmd.inputs = std::move(input_infos);
    cmd.outputs.reserve(output_descs.size());
    SmallVector<void*> outputs;
    bool is_fallible = false;
    for (auto&& desc : output_descs) {
        if (desc.layout.ndim == 0) {
            // Shape unknown until execution; remember to synchronize below.
            is_fallible = true;
        }
        auto info = alloc();
        info->desc = desc;
        m_valid_handle.insert(info);
        cmd.outputs.push_back(info);
        outputs.push_back(info);
    }
    m_worker.add_task(std::move(cmd));
    if (is_fallible && m_async_level <= 1) {
        // At async level 0/1, do not let unknown-shape handles escape:
        // block until the worker has produced them (also surfaces errors).
        sync();
    }
    return outputs;
}
// Block until the host-side value of `handle` is available, then return it.
// If not yet fetched, queues a GetValue task and waits on m_cv using the
// single-waiter m_waitee protocol; the worker wakes us from
// process_one_task. Worker exceptions are re-raised on this thread.
HostTensorND ChannelImpl::get_value(void* handle) {
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee);  // only one blocking getter supported at a time
    if (!info->value_fetched) {
        m_waitee = info;  // set before enqueuing so the worker's wake-up is not missed
        m_worker.add_task(GetValue{info});
        m_cv.wait(lock, [&]() {
            // Re-raise any exception recorded by the worker thread.
            check_worker_exc_unsafe();
            return info->value_fetched;
        });
        m_waitee = nullptr;
    }
    mgb_assert(info->ptr->value_fetched());
    return info->ptr->get_value();
}
// Return the shape of `handle`, blocking until it is known.
// Fast path: shape was already inferred eagerly at put/apply_op time.
// Slow path: wait (m_waitee protocol, same as get_value/get_dev_tensor)
// for the worker to produce the concrete tensor.
TensorShape ChannelImpl::get_shape(void* handle) {
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    if (info->desc.layout.ndim != 0) {
        // Shape already known from eager inference -- no synchronization.
        return info->desc.layout;
    }
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee);  // single outstanding waiter supported
    m_waitee = info;
    // NOTE(review): unlike get_value, no task is enqueued here; this relies
    // on an already-queued ApplyOp/Put eventually setting info->ptr --
    // confirm that every ndim==0 handle has such a pending task.
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
        return bool(info->ptr);
    });
    m_waitee = nullptr;
    TensorShape ret = info->ptr->layout();
    mgb_assert(ret.ndim != 0);
    return ret;
}
  115. DType ChannelImpl::get_dtype(void* handle) {
  116. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  117. "invalid handle: %p", handle);
  118. auto info = reinterpret_cast<TensorInfo*>(handle);
  119. auto ret = info->desc.layout.dtype;
  120. mgb_assert(ret.valid());
  121. return ret;
  122. }
  123. CompNode ChannelImpl::get_device(void* handle) {
  124. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  125. "invalid handle: %p", handle);
  126. auto info = reinterpret_cast<TensorInfo*>(handle);
  127. auto ret = info->desc.comp_node;
  128. mgb_assert(ret.valid());
  129. return ret;
  130. }
// Return the underlying device tensor of `handle`, blocking until the
// worker has produced it (m_waitee protocol, same as get_value/get_shape).
// No task is enqueued: relies on an already-queued ApplyOp/Put setting
// info->ptr. Worker exceptions are re-raised on this thread.
DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) {
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee);  // only one blocking getter supported at a time
    m_waitee = info;
    m_cv.wait(lock, [&]() {
        // Re-raise any exception recorded by the worker thread.
        check_worker_exc_unsafe();
        return bool(info->ptr);
    });
    m_waitee = nullptr;
    return info->ptr->dev_tensor();
}
// Block the caller until every queued command has been processed, then
// re-raise any exception the worker captured in the meantime.
void ChannelImpl::sync() {
    m_worker.wait_all_task_finish();
    MGB_LOCK_GUARD(m_mutex);  // m_worker_exc is guarded by m_mutex
    check_worker_exc_unsafe();
}
// Shut down the channel. Currently this only drains pending work via
// sync(); actual worker teardown is presumably handled by member
// destructors -- TODO confirm against interpreter_impl.h.
void ChannelImpl::close() {
    sync();
}
  153. void ChannelImpl::config_async_level(int level) {
  154. mgb_assert(level <= 2 and level >= 0, "async_level should be 0, 1 or 2");
  155. m_async_level = level;
  156. }
// Current asynchrony level, as last set by config_async_level (its default
// value is defined in the class declaration, not visible here).
int ChannelImpl::get_async_level() {
    return m_async_level;
}
  160. TensorInfo* ChannelImpl::alloc() {
  161. MGB_LOCK_GUARD(m_mutex);
  162. return m_pool.alloc();
  163. }
// Thread-safe return of a TensorInfo record to the shared pool; called from
// the worker thread when a Del command is processed.
void ChannelImpl::free(TensorInfo* ptr) {
    MGB_LOCK_GUARD(m_mutex);
    m_pool.free(ptr);
}
// Drain outstanding work before members are destroyed, so queued tasks do
// not outlive the channel state they reference.
ChannelImpl::~ChannelImpl() {
    close();
}
// Worker-side: publish a computed tensor into `dest` and wake a blocked
// getter, if any. Takes m_mutex so ptr/value_fetched updates are consistent
// with the predicates checked in get_value/get_shape/get_dev_tensor.
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
    MGB_LOCK_GUARD(m_mutex);
    dest->value_fetched = ptr->value_fetched();
    dest->ptr = std::move(ptr);
    if (m_waitee == dest) {
        // Someone is blocked waiting on exactly this tensor -- wake them.
        m_cv.notify_all();
    }
}
// Worker-thread dispatcher: execute one queued Command (a std::variant of
// Put/ApplyOp/Del/GetValue). Any exception is captured into m_worker_exc
// and re-raised on the caller thread by check_worker_exc_unsafe() (reached
// via sync() or a blocked getter's wait predicate).
void ChannelImpl::process_one_task(Command& cmd) {
    //TODO: remove std::visit for support osx 10.12
    std::visit([this](auto& cmd) {
        using T = std::remove_reference_t<decltype(cmd)>;
        try {
            if constexpr (std::is_same_v<T, Put>) {
                // Materialize the host value on device and publish it.
                produce_tensor(cmd.dest, Tensor::make(cmd.value));
            } else if constexpr (std::is_same_v<T, ApplyOp>) {
                SmallVector<TensorPtr> tensor_inputs;
                tensor_inputs.reserve(cmd.inputs.size());
                for (auto i : cmd.inputs) {
                    // Inputs were produced by earlier commands on this
                    // single worker queue, so i->ptr is set by now.
                    tensor_inputs.push_back(i->ptr);
                }
                auto tensor_outputs = OpDef::apply_on_physical_tensor(*cmd.op, tensor_inputs);
                mgb_assert(tensor_outputs.size() == cmd.outputs.size());
                for (size_t i = 0; i < tensor_outputs.size(); ++i) {
                    produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i]));
                }
            } else if constexpr (std::is_same_v<T, Del>) {
                free(cmd.dest);
            } else if constexpr (std::is_same_v<T, GetValue>) {
                // Synchronous device-to-host copy, then flag the value as
                // fetched and wake a getter blocked in get_value().
                cmd.dest->ptr->fetch_value();
                MGB_LOCK_GUARD(m_mutex);
                cmd.dest->value_fetched = true;
                if (m_waitee == cmd.dest) {
                    m_cv.notify_all();
                }
            } else {
                // Compile-time exhaustiveness check over Command's variants.
                static_assert(!std::is_same_v<T, T>);
            }
        } catch (...) {
            MGB_LOCK_GUARD(m_mutex);
            m_worker_exc = std::current_exception();
            // Wake any blocked getter so it can observe the exception.
            m_cv.notify_all();
        }
    }, cmd);
}
  216. void ChannelImpl::check_worker_exc_unsafe() {
  217. if (m_worker_exc) {
  218. std::exception_ptr exc;
  219. std::swap(exc, m_worker_exc);
  220. std::rethrow_exception(exc);
  221. }
  222. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台