You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

handle.cpp 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. /**
  2. * \file dnn/src/common/handle.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/basic_types.h"
  12. #include "src/common/handle_impl.h"
  13. #include "src/common/utils.h"
  14. #include "src/fallback/handle.h"
  15. #include "src/naive/handle.h"
  16. #include "midout.h"
  17. #if MEGDNN_X86
  18. #include "src/x86/handle.h"
  19. #endif
  20. #if MEGDNN_ARMV7
  21. #include "src/armv7/handle.h"
  22. #endif
  23. #if MEGDNN_AARCH64
  24. #include "src/aarch64/handle.h"
  25. #endif
  26. #if MEGDNN_WITH_CUDA
  27. #include "src/cuda/handle.h"
  28. #endif
  29. #if MEGDNN_WITH_CAMBRICON
  30. #include "src/cambricon/handle.h"
  31. #endif
  32. #ifdef MEGDNN_WITH_ATLAS
  33. #include "src/atlas/handle.h"
  34. #endif
  35. using namespace megdnn;
  36. MIDOUT_DECL(HandlePlatform);
  37. MIDOUT_DECL(HandleOpr);
  38. Handle::Handle(megcoreComputingHandle_t computing_handle, HandleType type)
  39. : m_computing_handle(computing_handle), m_handle_type(type) {}
  40. std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle,
  41. int debug_level) {
  42. (void)debug_level;
  43. megcoreDeviceHandle_t device_handle;
  44. megcorePlatform_t platform;
  45. megcoreGetDeviceHandle(computing_handle, &device_handle);
  46. megcoreGetPlatform(device_handle, &platform);
  47. if (platform == megcorePlatformCPU) {
  48. // only enable midout for CPU, becuase CPU might be unused when some
  49. // other platforms are used
  50. MIDOUT_BEGIN(HandlePlatform, midout_iv(megcorePlatformCPU)) {
  51. // CPU
  52. #if MEGDNN_NAIVE
  53. return make_unique<naive::HandleImpl>(computing_handle);
  54. #else
  55. if (debug_level == 0) {
  56. #if MEGDNN_X86
  57. // Because of ICC bug, we cannot use make_unique here. It will
  58. // trigger an internal compiler error.
  59. return std::unique_ptr<x86::HandleImpl>(
  60. new x86::HandleImpl(computing_handle));
  61. // return make_unique<x86::HandleImpl>(computing_handle);
  62. #elif MEGDNN_ARMV7
  63. return make_unique<armv7::HandleImpl>(computing_handle);
  64. #elif MEGDNN_AARCH64
  65. return make_unique<aarch64::HandleImpl>(computing_handle);
  66. #else
  67. return make_unique<fallback::HandleImpl>(computing_handle);
  68. #endif
  69. } else if (debug_level == 1) {
  70. return make_unique<fallback::HandleImpl>(computing_handle);
  71. } else if (debug_level == 2) {
  72. return make_unique<naive::HandleImpl>(computing_handle);
  73. } else {
  74. megdnn_throw(megdnn_mangle("Debug level must be 0/1/2."));
  75. }
  76. }
  77. MIDOUT_END();
  78. #endif
  79. }
  80. else if (platform == megcorePlatformROCM) {
  81. #if MEGDNN_WITH_ROCM
  82. return make_rocm_handle(computing_handle);
  83. #else
  84. return nullptr;
  85. #endif
  86. }
  87. else if (platform == megcorePlatformCambricon) {
  88. #if MEGDNN_WITH_CAMBRICON
  89. return make_unique<cambricon::HandleImpl>(computing_handle);
  90. #else
  91. return nullptr;
  92. #endif
  93. }
  94. else if (platform == megcorePlatformAtlas) {
  95. #if MEGDNN_WITH_ATLAS
  96. return make_unique<atlas::HandleImpl>(computing_handle);
  97. #else
  98. return nullptr;
  99. #endif
  100. }
  101. else {
  102. // CUDA
  103. megdnn_assert_internal(platform == megcorePlatformCUDA);
  104. #if MEGDNN_WITH_CUDA
  105. return make_unique<cuda::HandleImpl>(computing_handle);
  106. #else
  107. return nullptr;
  108. #endif
  109. }
  110. return nullptr;
  111. }
  112. void Handle::set_destructor(const thin_function<void()>& d) {
  113. megdnn_assert(!m_destructor, "destructor can be set only once");
  114. m_destructor = d;
  115. }
  116. Handle::~Handle() {
  117. if (m_destructor)
  118. m_destructor();
  119. m_alive_magic = 0;
  120. }
  121. size_t Handle::alignment_requirement() const {
  122. // default to 32
  123. return 32;
  124. }
  125. size_t Handle::image2d_pitch_alignment() const {
  126. megdnn_throw("image2d tensor format not supported on this handle");
  127. }
  128. bool Handle::check_cross_dev_copy_constraint(const TensorLayout& src) {
  129. return src.is_contiguous();
  130. }
  131. void Handle::on_opr_destructed(OperatorBase * opr) {
  132. if (m_alive_magic != ALIVE_MAGIC) {
  133. megdnn_log_error(
  134. "Handle is destructed before opr gets destructed. "
  135. "Please fix the destruction order as this would cause "
  136. "undefined memory access. "
  137. "Abort now to avoid further problems.");
  138. abort();
  139. }
  140. if (m_on_opr_destructed) {
  141. m_on_opr_destructed(opr);
  142. }
  143. }
  144. OperatorBase::~OperatorBase() { m_handle->on_opr_destructed(this); }
  145. template <typename Opr>
  146. std::unique_ptr<Opr> Handle::create_operator() {
  147. #define CASE(etype, nm) \
  148. case HandleType::etype: { \
  149. MIDOUT_BEGIN(HandleOpr, Opr, midout_iv(HandleType::etype)) { \
  150. return static_cast<nm::HandleImpl*>(this)->create_operator<Opr>(); \
  151. } \
  152. MIDOUT_END(); \
  153. }
  154. switch (m_handle_type) {
  155. CASE(NAIVE, naive);
  156. #if !MEGDNN_NAIVE
  157. CASE(FALLBACK, fallback);
  158. #if MEGDNN_X86
  159. CASE(X86, x86);
  160. #endif
  161. #if MEGDNN_ARMV7
  162. CASE(ARMV7, armv7);
  163. #endif
  164. #if MEGDNN_AARCH64
  165. CASE(AARCH64, aarch64);
  166. #endif
  167. #if MEGDNN_ARMV7 || MEGDNN_AARCH64
  168. CASE(ARM_COMMON, arm_common);
  169. #endif
  170. #endif // !MEGDNN_NAIVE
  171. #if MEGDNN_WITH_CUDA
  172. CASE(CUDA,cuda);
  173. #endif
  174. #if MEGDNN_WITH_ATLAS
  175. CASE(ATLAS, atlas);
  176. #endif
  177. #if MEGDNN_WITH_ROCM
  178. case HandleType::ROCM: {
  179. MIDOUT_BEGIN(HandleOpr, Opr, midout_iv(HandleType::ROCM)) {
  180. return create_rocm_operator<Opr>();
  181. }
  182. MIDOUT_END();
  183. }
  184. #endif
  185. #if MEGDNN_WITH_CAMBRICON
  186. CASE(CAMBRICON, cambricon);
  187. #endif
  188. default:
  189. megdnn_throw(megdnn_mangle("bad handle type"));
  190. }
  191. #undef CASE
  192. }
  193. #define INST(opr) template std::unique_ptr<opr> Handle::create_operator();
  194. MEGDNN_FOREACH_OPR_CLASS(INST)
  195. #undef INST
  196. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。