You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

handle.cpp 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. /**
  2. * \file dnn/src/rocm/handle.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "hcc_detail/hcc_defs_prologue.h"
  12. #include "src/common/handle_impl.h"
  13. #include "src/common/version_symbol.h"
  14. #include "src/rocm/handle.h"
  15. #include "src/rocm/miopen_with_check.h"
  16. #include "src/rocm/utils.h"
  17. #include "src/rocm/checksum/opr_impl.h"
  18. #include "src/rocm/convolution/opr_impl.h"
  19. #include "src/rocm/elemwise/opr_impl.h"
  20. #include "src/rocm/eye/opr_impl.h"
  21. #include "src/rocm/pooling/opr_impl.h"
  22. #include "src/rocm/reduce/opr_impl.h"
  23. #include "src/rocm/type_cvt/opr_impl.h"
  24. #include "src/rocm/add_update/opr_impl.h"
  25. #include "src/rocm/matrix_mul/opr_impl.h"
  26. #include "src/rocm/batched_matrix_mul/opr_impl.h"
  27. #include "src/rocm/indexing_one_hot/opr_impl.h"
  28. #include "src/rocm/rng/opr_impl.h"
  29. #include "src/rocm/relayout/opr_impl.h"
  30. #include "src/rocm/powc/opr_impl.h"
  31. #include "src/rocm/indexing_multi_axis_vec/opr_impl.h"
  32. #include "src/rocm/linspace/opr_impl.h"
  33. #include "src/rocm/argmxx/opr_impl.h"
  34. #include "src/rocm/sleep/opr_impl.h"
  35. #include "src/rocm/batch_normalization/opr_impl.h"
  36. #include "src/rocm/param_pack/opr_impl.h"
  37. #include "src/rocm/fill/opr_impl.h"
  38. #include <miopen/version.h>
  39. #include <hip/hip_version.h>
  40. #include <cstring>
  // Two-level stringification: STR() lets its argument be macro-expanded
  // first, then STR_HELPER() turns the expanded value into a string literal.
  #define STR_HELPER(x) #x
  #define STR(x) STR_HELPER(x)
  // "<major>.<minor>.<patch>" assembled from the MIOPEN_VERSION_* macros
  // (presumably provided by <miopen/version.h>, which is included above).
  #define MIOPEN_VERSION_STR \
      STR(MIOPEN_VERSION_MAJOR) \
      "." STR(MIOPEN_VERSION_MINOR) "." STR(MIOPEN_VERSION_PATCH)
  // Emit a compile-time message showing which MIOpen version is built against.
  #pragma message "compile with MIOpen " MIOPEN_VERSION_STR " "
  // Drop the helpers so they do not leak into the rest of the file. Note that
  // MIOPEN_VERSION_STR stays defined but relies on STR, so it must not be
  // used below this point.
  #undef STR
  #undef STR_HELPER
  49. namespace megdnn {
  50. std::unique_ptr<Handle> Handle::make_rocm_handle(megcoreComputingHandle_t computing_handle) {
  51. return std::make_unique<rocm::HandleImpl>(computing_handle);
  52. }
  53. template <typename Opr>
  54. std::unique_ptr<Opr> Handle::create_rocm_operator() {
  55. return static_cast<rocm::HandleImpl*>(this)->create_operator<Opr>();
  56. }
  // Explicitly instantiate Handle::create_rocm_operator<opr>() for each
  // operator class; MEGDNN_FOREACH_OPR_CLASS presumably applies INST to every
  // operator class known to megdnn. INST is local to this expansion.
  #define INST(opr) \
      template std::unique_ptr<opr> Handle::create_rocm_operator();
  MEGDNN_FOREACH_OPR_CLASS(INST)
  #undef INST
  61. }
  62. namespace megdnn {
  63. namespace rocm {
  /*!
   * Construct a ROCm handle on top of a megcore computing handle.
   *
   * Steps, in order: resolve the device id and cache its properties, fetch the
   * ROCm context, create the rocBLAS and MIOpen handles and attach them to
   * stream(), then allocate the ConstScalars table on the device and upload
   * its initial values.
   *
   * \param comp_handle megcore computing handle this handle is built from
   */
  HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle)
          : HandleImplHelper(comp_handle, HandleType::ROCM) {
      // Get megcore device handle
      // NOTE(review): the results of megcoreGetDeviceHandle() and
      // megcoreGetDeviceID() are not checked here; if either can fail, dev_id
      // is read uninitialized below -- confirm.
      megcoreDeviceHandle_t dev_handle;
      megcoreGetDeviceHandle(comp_handle, &dev_handle);
      int dev_id;
      megcoreGetDeviceID(dev_handle, &dev_id);
      // A negative id is replaced by whatever device HIP reports as current.
      if (dev_id < 0) {
          hip_check(hipGetDevice(&dev_id));
      }
      m_device_id = dev_id;
      // Cache the device properties for later queries.
      hip_check(hipGetDeviceProperties(&m_device_prop, dev_id));
      // Get stream from MegCore computing handle.
      //! no version check
      megcore::getROCMContext(comp_handle, &m_megcore_context);
      rocblas_check(rocblas_create_handle(&m_rocblas_handle));
      //! must call miopenCreateWithStream() to create miopen handle, then the
      //! rocblas_handle of miopen will set to be the same stream , otherwise
      //! miopen create rocblas_handle with default stream
      miopen_check(miopenCreateWithStream(&m_miopen_handle, stream()));
      // Bind our own rocblas handle to the same stream (the miopen handle was
      // already created with it above).
      rocblas_check(rocblas_set_stream(m_rocblas_handle, stream()));
      // Note that all rocblas scalars (alpha, beta) and scalar results such as
      // dot output reside at device side.
      rocblas_check(rocblas_set_pointer_mode(m_rocblas_handle,
                                             rocblas_pointer_mode_device));
      // init const scalars: allocate the device-side table, fill a host copy
      // and upload it.
      hip_check(hipMalloc(&m_const_scalars, sizeof(ConstScalars)));
      ConstScalars const_scalars_val;
      const_scalars_val.init();
      hip_check(hipMemcpyAsync(m_const_scalars, &const_scalars_val,
                               sizeof(ConstScalars), hipMemcpyHostToDevice,
                               stream()));
      // The source of the async copy is a local variable, so wait for the
      // stream before it goes out of scope.
      hip_check(hipStreamSynchronize(stream()));
  }
  /*!
   * Release what the constructor acquired: the MIOpen handle, the rocBLAS
   * handle and the device-side ConstScalars buffer.
   *
   * NOTE(review): the destructor is noexcept; if the *_check macros throw on
   * failure, a failing call here would terminate the program -- confirm that
   * this is intended.
   */
  HandleImpl::~HandleImpl() noexcept {
      miopen_check(miopenDestroy(m_miopen_handle));
      rocblas_check(rocblas_destroy_handle(m_rocblas_handle));
      hip_check(hipFree(m_const_scalars));
  }
  104. void HandleImpl::ConstScalars::init() {
  105. #if !MEGDNN_DISABLE_FLOAT16
  106. f16[0].megdnn_x = 0;
  107. f16[1].megdnn_x = 1;
  108. #endif
  109. f32[0] = 0;
  110. f32[1] = 1;
  111. i32[0] = 0;
  112. i32[1] = 1;
  113. }
  114. template <typename Opr>
  115. std::unique_ptr<Opr> HandleImpl::create_operator() {
  116. megdnn_throw("unsupported rocm opr");
  117. return nullptr;
  118. }
  119. size_t HandleImpl::alignment_requirement() const {
  120. auto&& prop = m_device_prop;
  121. MEGDNN_MARK_USED_VAR(prop);
  122. //! for now, texture functions are not supported.
  123. return 1u;
  124. }
  125. bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) {
  126. // is contiguous or can be hold by
  127. // relayout::param::try_copy_2d/try_copy_last_contig
  128. return src.is_contiguous() || src.stride[src.ndim - 1] == 1;
  129. }
  // Operators with a ROCm implementation, one entry each.
  // MEGDNN_SPECIALIZE_CREATE_OPERATOR presumably specializes
  // HandleImpl::create_operator<Opr>() for the named operator (verify against
  // its definition); any operator not listed here falls back to the primary
  // template, which throws "unsupported rocm opr".
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionForward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardFilter);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(ElemwiseForward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(Eye);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(ChecksumForward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingForward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingBackward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(ReduceForward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(AddUpdateForward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(MatrixMulForward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingOneHotForward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingSetOneHotForward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(UniformRNG);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianRNG);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(RelayoutForward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(PowC);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingMultiAxisVec);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingSetMultiAxisVec);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingIncrMultiAxisVec);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgmaxForward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat);
  MEGDNN_SPECIALIZE_CREATE_OPERATOR(Fill);
  // Apply MEGDNN_INST_CREATE_OPERATOR to every operator class (presumably an
  // explicit instantiation of HandleImpl::create_operator<Opr>()). For the
  // operators specialized above this triggers
  // -Winstantiation-after-specialization, so that warning is silenced for this
  // expansion only. -Wpragmas is silenced as well so that a compiler which
  // does not know that warning name does not complain about the pragma itself.
  #pragma GCC diagnostic push
  #pragma GCC diagnostic ignored "-Wpragmas"
  #pragma GCC diagnostic ignored "-Winstantiation-after-specialization"
  MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
  #pragma GCC diagnostic pop
  165. } // namespace rocm
  166. } // namespace megdnn
  // Record the HIP and MIOpen versions this file was compiled against.
  // MEGDNN_VERSION_SYMBOL3 presumably comes from src/common/version_symbol.h
  // and emits a symbol carrying the three version components -- confirm
  // against its definition.
  MEGDNN_VERSION_SYMBOL3(HIP, HIP_VERSION_MAJOR, HIP_VERSION_MINOR, HIP_VERSION_PATCH);
  MEGDNN_VERSION_SYMBOL3(MIOPEN, MIOPEN_VERSION_MAJOR, MIOPEN_VERSION_MINOR,
                         MIOPEN_VERSION_PATCH);
  170. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台