You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

interpreter_impl.cpp 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. /**
  2. * \file imperative/src/impl/interpreter_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./interpreter_impl.h"
  12. using namespace mgb;
  13. using namespace imperative;
  14. using namespace interpreter;
  15. using namespace interpreter::intl;
  16. std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
  17. return std::make_unique<ChannelImpl>();
  18. }
  19. Interpreter& Interpreter::inst() {
  20. static InterpreterImpl inst_;
  21. return inst_;
  22. }
  23. void* ChannelImpl::put(const HostTensorND& value) {
  24. auto info = alloc();
  25. info->desc.layout = value.layout();
  26. info->desc.comp_node = value.comp_node();
  27. info->desc.value = value.proxy_to_default_cpu();
  28. m_valid_handle.insert(info);
  29. m_worker.add_task(Put{info, value});
  30. return info;
  31. }
  32. void ChannelImpl::del(void* handle) {
  33. mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
  34. m_worker.add_task(Del{reinterpret_cast<TensorInfo*>(handle)});
  35. }
  36. SmallVector<void*> ChannelImpl::apply_op(
  37. std::shared_ptr<OpDef> op,
  38. const SmallVector<void*>& inputs) {
  39. SmallVector<LogicalTensorDesc> input_descs;
  40. input_descs.reserve(inputs.size());
  41. for (auto h : inputs) {
  42. auto info = reinterpret_cast<TensorInfo*>(h);
  43. input_descs.push_back(info->desc);
  44. }
  45. auto output_descs = OpDef::infer_output_attrs_fallible(*op, input_descs);
  46. ApplyOp cmd{std::move(op)};
  47. cmd.inputs.reserve(inputs.size());
  48. for (auto i : inputs) {
  49. cmd.inputs.push_back(reinterpret_cast<TensorInfo*>(i));
  50. }
  51. cmd.outputs.reserve(output_descs.size());
  52. SmallVector<void*> outputs;
  53. for (auto&& desc : output_descs) {
  54. auto info = alloc();
  55. info->desc = desc;
  56. m_valid_handle.insert(info);
  57. cmd.outputs.push_back(info);
  58. outputs.push_back(info);
  59. }
  60. m_worker.add_task(std::move(cmd));
  61. return outputs;
  62. }
  63. HostTensorND ChannelImpl::get_value(void* handle) {
  64. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  65. "invalid handle: %p", handle);
  66. auto info = reinterpret_cast<TensorInfo*>(handle);
  67. std::unique_lock<decltype(m_mutex)> lock(m_mutex);
  68. mgb_assert(!m_waitee);
  69. if (!info->value_fetched) {
  70. m_waitee = info;
  71. m_worker.add_task(GetValue{info});
  72. m_cv.wait(lock, [&]() {
  73. check_worker_exc_unsafe();
  74. return info->value_fetched;
  75. });
  76. m_waitee = nullptr;
  77. }
  78. mgb_assert(info->ptr->value_fetched());
  79. return info->ptr->get_value();
  80. }
  81. TensorShape ChannelImpl::get_shape(void* handle) {
  82. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  83. "invalid handle: %p", handle);
  84. auto info = reinterpret_cast<TensorInfo*>(handle);
  85. if (info->desc.layout.ndim != 0) {
  86. return info->desc.layout;
  87. }
  88. std::unique_lock<decltype(m_mutex)> lock(m_mutex);
  89. mgb_assert(!m_waitee);
  90. m_waitee = info;
  91. m_cv.wait(lock, [&]() {
  92. check_worker_exc_unsafe();
  93. return bool(info->ptr);
  94. });
  95. m_waitee = nullptr;
  96. TensorShape ret = info->ptr->layout();
  97. mgb_assert(ret.ndim != 0);
  98. return ret;
  99. }
  100. DType ChannelImpl::get_dtype(void* handle) {
  101. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  102. "invalid handle: %p", handle);
  103. auto info = reinterpret_cast<TensorInfo*>(handle);
  104. auto ret = info->desc.layout.dtype;
  105. mgb_assert(ret.valid());
  106. return ret;
  107. }
  108. CompNode ChannelImpl::get_device(void* handle) {
  109. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  110. "invalid handle: %p", handle);
  111. auto info = reinterpret_cast<TensorInfo*>(handle);
  112. auto ret = info->desc.comp_node;
  113. mgb_assert(ret.valid());
  114. return ret;
  115. }
  116. DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) {
  117. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  118. "invalid handle: %p", handle);
  119. auto info = reinterpret_cast<TensorInfo*>(handle);
  120. std::unique_lock<decltype(m_mutex)> lock(m_mutex);
  121. mgb_assert(!m_waitee);
  122. m_waitee = info;
  123. m_cv.wait(lock, [&]() {
  124. check_worker_exc_unsafe();
  125. return bool(info->ptr);
  126. });
  127. m_waitee = nullptr;
  128. return info->ptr->dev_tensor();
  129. }
  130. void ChannelImpl::sync() {
  131. m_worker.wait_all_task_finish();
  132. MGB_LOCK_GUARD(m_mutex);
  133. check_worker_exc_unsafe();
  134. }
// Close the channel. Currently just a full synchronization, so pending work
// (and any worker exception) is flushed before teardown.
void ChannelImpl::close() {
    sync();
}
  138. void ChannelImpl::config_async_level(int level) {
  139. mgb_assert(0);
  140. }
  141. TensorInfo* ChannelImpl::alloc() {
  142. MGB_LOCK_GUARD(m_mutex);
  143. return m_pool.alloc();
  144. }
  145. void ChannelImpl::free(TensorInfo* ptr) {
  146. MGB_LOCK_GUARD(m_mutex);
  147. m_pool.free(ptr);
  148. }
  149. ChannelImpl::~ChannelImpl() {}
  150. void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
  151. MGB_LOCK_GUARD(m_mutex);
  152. dest->value_fetched = ptr->value_fetched();
  153. dest->ptr = std::move(ptr);
  154. if (m_waitee == dest) {
  155. m_cv.notify_all();
  156. }
  157. }
// Worker-thread dispatcher: executes one queued Command (a std::variant).
// Any exception escaping a handler is captured into m_worker_exc and later
// rethrown on the main thread by check_worker_exc_unsafe().
void ChannelImpl::process_one_task(Command& cmd) {
    std::visit([this](auto& cmd) {
        using T = std::remove_reference_t<decltype(cmd)>;
        try {
            if constexpr (std::is_same_v<T, Put>) {
                // Materialize the queued host value and publish it.
                produce_tensor(cmd.dest, Tensor::make(cmd.value));
            } else if constexpr (std::is_same_v<T, ApplyOp>) {
                SmallVector<TensorPtr> tensor_inputs;
                tensor_inputs.reserve(cmd.inputs.size());
                for (auto i : cmd.inputs) {
                    tensor_inputs.push_back(i->ptr);
                }
                auto tensor_outputs = OpDef::apply_on_physical_tensor(*cmd.op, tensor_inputs);
                mgb_assert(tensor_outputs.size() == cmd.outputs.size());
                for (size_t i = 0; i < tensor_outputs.size(); ++i) {
                    // produce_tensor also wakes any waiter blocked on this output.
                    produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i]));
                }
            } else if constexpr (std::is_same_v<T, Del>) {
                free(cmd.dest);
            } else if constexpr (std::is_same_v<T, GetValue>) {
                // Download the value, then flag completion under the lock and
                // wake the waiter if it is blocked on this exact info.
                cmd.dest->ptr->fetch_value();
                MGB_LOCK_GUARD(m_mutex);
                cmd.dest->value_fetched = true;
                if (m_waitee == cmd.dest) {
                    m_cv.notify_all();
                }
            } else {
                // Exhaustiveness guard: compilation fails if a Command
                // alternative is not handled above.
                static_assert(!std::is_same_v<T, T>);
            }
        } catch (...) {
            // Record the exception for the main thread and wake all waiters
            // so their wait-predicates can rethrow it.
            MGB_LOCK_GUARD(m_mutex);
            m_worker_exc = std::current_exception();
            m_cv.notify_all();
        }
    }, cmd);
}
  194. void ChannelImpl::check_worker_exc_unsafe() {
  195. if (m_worker_exc) {
  196. std::exception_ptr exc;
  197. std::swap(exc, m_worker_exc);
  198. std::rethrow_exception(exc);
  199. }
  200. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台