You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

utils.cuh 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. /**
  2. * \file dnn/src/cuda/utils.cuh
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
#pragma once
#include "include/megdnn/dtype.h"
#include "src/common/utils.cuh"
#include <cstdio>
#include <stdint.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cusolverDn.h>
#include "cuda.h"
#include "cutlass/cutlass.h"
#include "src/cuda/atomic_add.cuh"
#include "src/cuda/cudnn_with_check.h"
  23. #define cuda_check(_x) \
  24. do { \
  25. cudaError_t _err = (_x); \
  26. if (_err != cudaSuccess) { \
  27. std::string x = std::string(#_x); \
  28. char line[10]; \
  29. sprintf(line, "%d", __LINE__); \
  30. ::megdnn::cuda::__throw_cuda_error__( \
  31. _err, (x + " error file:" + std::string(__FILE__) + ":" + \
  32. std::string(line)) \
  33. .c_str()); \
  34. } \
  35. } while (0)
//! Check a cuBLAS call; throws ::megdnn::cuda::__throw_cublas_error__ on
//! failure, passing the stringified expression as the message.
#define cublas_check(_x)                                       \
    do {                                                       \
        cublasStatus_t _err = (_x);                            \
        if (_err != CUBLAS_STATUS_SUCCESS) {                   \
            ::megdnn::cuda::__throw_cublas_error__(_err, #_x); \
        }                                                      \
    } while (0)

//! Check a cuDNN call; throws ::megdnn::cuda::__throw_cudnn_error__ on failure.
#define cudnn_check(_x)                                       \
    do {                                                      \
        cudnnStatus_t _err = (_x);                            \
        if (_err != CUDNN_STATUS_SUCCESS) {                   \
            ::megdnn::cuda::__throw_cudnn_error__(_err, #_x); \
        }                                                     \
    } while (0)

//! Check a cuSOLVER call; throws ::megdnn::cuda::__throw_cusolver_error__ on
//! failure.
#define cusolver_check(_x)                                       \
    do {                                                         \
        cusolverStatus_t _err = (_x);                            \
        if (_err != CUSOLVER_STATUS_SUCCESS) {                   \
            ::megdnn::cuda::__throw_cusolver_error__(_err, #_x); \
        }                                                        \
    } while (0)

//! Check a CUDA driver API call (CUresult); throws
//! ::megdnn::cuda::__throw_cuda_driver_error__ on failure.
#define cucheck(_x)                                                 \
    do {                                                            \
        CUresult _err = (_x);                                       \
        if (_err != CUDA_SUCCESS) {                                 \
            ::megdnn::cuda::__throw_cuda_driver_error__(_err, #_x); \
        }                                                           \
    } while (0)

//! Check a cutlass operation status; throws
//! ::megdnn::cuda::__throw_cutlass_error__ on failure.
#define cutlass_check(_x)                                        \
    do {                                                         \
        cutlass::Status _err = (_x);                             \
        if (_err != cutlass::Status::kSuccess) {                 \
            ::megdnn::cuda::__throw_cutlass_error__(_err, #_x);  \
        }                                                        \
    } while (0)
//! Surface kernel-launch errors (bad configuration, etc.); call right after a
//! <<<...>>> launch. Execution errors inside the kernel only appear at the
//! next synchronizing call.
#define after_kernel_launch()           \
    do {                                \
        cuda_check(cudaGetLastError()); \
    } while (0)

//! default launch configuration: total threads per block (NR_THREADS) and the
//! 2D split used by 2D kernels (NR_THREADS_X * NR_THREADS_Y == NR_THREADS)
#if MEGDNN_THREADS_512
#define NR_THREADS 512
#define NR_THREADS_X 32
#define NR_THREADS_Y 16
#else
#define NR_THREADS 1024
#define NR_THREADS_X 32
#define NR_THREADS_Y 32
#endif

//! DIVUP: ceiling division; ROUNDUP: round x up to the next multiple of y
#define DIVUP(x, y) (((x) + (y)-1) / (y))
#define ROUNDUP(x, y) (DIVUP(x, y) * (y))

//! 1D grid-stride loop: i ranges over [0, n) across all launched threads, so
//! correctness does not depend on the grid exactly covering n elements
#define KERN_FOR(i, n)                                              \
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
         i += blockDim.x * gridDim.x)
namespace megdnn {
namespace cuda {

//! Error handling functions: each formats a message from the status code plus
//! \p msg and throws; none of them returns (MEGDNN_NORETURN).
MEGDNN_NORETURN void __throw_cuda_error__(cudaError_t err, const char* msg);
MEGDNN_NORETURN void __throw_cudnn_error__(cudnnStatus_t err, const char* msg);
MEGDNN_NORETURN void __throw_cublas_error__(cublasStatus_t err, const char* msg);
MEGDNN_NORETURN void __throw_cusolver_error__(cusolverStatus_t err, const char* msg);
MEGDNN_NORETURN void __throw_cuda_driver_error__(CUresult err, const char* msg);
MEGDNN_NORETURN void __throw_cutlass_error__(cutlass::Status status, const char* msg);
//! throw a generic megdnn error with the given message; never returns
MEGDNN_NORETURN void report_error(const char* msg);
  99. template <typename T, size_t N>
  100. struct array_wrapper {
  101. T data[N];
  102. MEGDNN_DEVICE __forceinline__ T& operator[](size_t pos) {
  103. return reinterpret_cast<T&>(data[pos]);
  104. }
  105. MEGDNN_DEVICE __forceinline__ T const& operator[](size_t pos) const {
  106. return reinterpret_cast<T const&>(data[pos]);
  107. }
  108. };
/*!
 * \brief convert \p size to uint32_t, checking that it does not overflow
 *
 * Throws an exception with a human readable message if \p size is not in the
 * interval (0, Uint32Fastdiv::MAX_DIVIDEND).
 */
uint32_t safe_size_in_kern(size_t size);
  116. #ifdef __CUDACC__
  117. template <typename T>
  118. inline __device__ void fill_shared_mem(T* shared, uint32_t n, const T& val) {
  119. uint32_t stride = blockDim.x * blockDim.y * blockDim.z;
  120. uint32_t i = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;
  121. for (; i < n; i += stride)
  122. shared[i] = val;
  123. }
  124. #endif
// ==========================DTypeParam wrapper=================================
// Division is inefficient in cuda, so we replace div scale with mul 1/scale,
// and we need a wrapper of DTypeParam to hold the reciprocal of scale.

//! primary template; only specialized for the quantized ctypes below
template <typename Type>
struct CudaDTypeParamImpl;

//! maps a megdnn DType to its CUDA param wrapper via the ctype trait
template <typename DType>
using CudaDTypeParam = CudaDTypeParamImpl<typename DTypeTrait<DType>::ctype>;
//! param wrapper for asymmetric uint8 quantization (scale + zero_point)
template <>
struct CudaDTypeParamImpl<dt_quint8> : DTypeParamImpl<dt_quint8> {
    //! reciprocal of scale, precomputed because division is slow on device
    float inv_scale;
    CudaDTypeParamImpl() = default;
    CudaDTypeParamImpl(float scale, uint8_t zero_point)
            : DTypeParamImpl<dt_quint8>(scale, zero_point), inv_scale(1.0f / scale) {}
    CudaDTypeParamImpl(const DTypeParamImpl<dt_quint8>& param)
            : CudaDTypeParamImpl(param.scale, param.zero_point) {}
    //! quantize: round(in / scale) + zero_point, clamped to [0, 255]
    __device__ dt_quint8 quantize(float in) const {
        float v = in * inv_scale;
        v = roundf(v);
        v = v + zero_point;
        v = fmin(fmax(0.f, v), 255.f);
        return static_cast<dt_quint8>(v);
    }
};
//! param wrapper for symmetric int8 quantization (scale only)
template <>
struct CudaDTypeParamImpl<dt_qint8> : DTypeParamImpl<dt_qint8> {
    //! reciprocal of scale, precomputed because division is slow on device
    float inv_scale;
    CudaDTypeParamImpl() = default;
    CudaDTypeParamImpl(float scale)
            : DTypeParamImpl<dt_qint8>(scale), inv_scale(1.0f / scale) {}
    CudaDTypeParamImpl(const DTypeParamImpl<dt_qint8>& param)
            : CudaDTypeParamImpl(param.scale) {}
    //! quantize: round(in / scale), clamped to [-128, 127]
    __device__ dt_qint8 quantize(float in) const {
        float v = in * inv_scale;
        v = roundf(v);
        v = fmin(fmax(-128.f, v), 127.f);
        return static_cast<dt_qint8>(v);
    }
};
//! param wrapper for symmetric int32 quantization (scale only)
template <>
struct CudaDTypeParamImpl<dt_qint32> : DTypeParamImpl<dt_qint32> {
    //! reciprocal of scale, precomputed because division is slow on device
    float inv_scale;
    CudaDTypeParamImpl() = default;
    CudaDTypeParamImpl(float scale)
            : DTypeParamImpl<dt_qint32>(scale), inv_scale(1.0f / scale) {}
    CudaDTypeParamImpl(const DTypeParamImpl<dt_qint32>& param)
            : CudaDTypeParamImpl(param.scale) {}
    //! quantize: round(in / scale), clamped so the float bound stays
    //! representable (see note below)
    __device__ dt_qint32 quantize(float in) const {
        float v = in * inv_scale;
        v = roundf(v);
        /*! \note: the maximal signed integer that can be correctly represented
         * as a single precision floating point number is 2147483520
         */
        v = fmin(fmax(-2147483648.f, v), 2147483520.f);
        return static_cast<dt_qint32>(v);
    }
};
//! param wrapper for asymmetric uint4 quantization (scale + zero_point)
template <>
struct CudaDTypeParamImpl<dt_quint4> : DTypeParamImpl<dt_quint4> {
    //! reciprocal of scale, precomputed because division is slow on device
    float inv_scale;
    CudaDTypeParamImpl() = default;
    CudaDTypeParamImpl(float scale, uint8_t zero_point)
            : DTypeParamImpl<dt_quint4>(scale, zero_point), inv_scale(1.0f / scale) {}
    CudaDTypeParamImpl(const DTypeParamImpl<dt_quint4>& param)
            : CudaDTypeParamImpl(param.scale, param.zero_point) {}
    //! quantize: round(in / scale) + zero_point, clamped to [0, 15]
    __device__ dt_quint4 quantize(float in) const {
        float v = in * inv_scale;
        v = roundf(v);
        v = v + zero_point;
        v = fmin(fmax(0.f, v), 15.f);
        return static_cast<dt_quint4>(v);
    }
};
//! param wrapper for symmetric int4 quantization (scale only)
template <>
struct CudaDTypeParamImpl<dt_qint4> : DTypeParamImpl<dt_qint4> {
    //! reciprocal of scale, precomputed because division is slow on device
    float inv_scale;
    CudaDTypeParamImpl() = default;
    CudaDTypeParamImpl(float scale)
            : DTypeParamImpl<dt_qint4>(scale), inv_scale(1.0f / scale) {}
    CudaDTypeParamImpl(const DTypeParamImpl<dt_qint4>& param)
            : CudaDTypeParamImpl(param.scale) {}
    //! quantize: round(in / scale), clamped to [-8, 7]
    __device__ dt_qint4 quantize(float in) const {
        float v = in * inv_scale;
        v = roundf(v);
        v = fmin(fmax(-8.f, v), 7.f);
        return static_cast<dt_qint4>(v);
    }
};
#if MEGDNN_CC_CUDA
/*!
 * \brief d = dot product of the 4 packed int8 lanes of \p a and \p b, plus \p c
 *
 * On SM61+ this maps to a single dp4a instruction; older architectures use a
 * scalar fallback that unpacks one byte per iteration.
 */
static inline MEGDNN_DEVICE void dot_prod(int a, int b, int c, int& d) {
#if __CUDA_ARCH__ >= 610
    // requires SM61+ (dp4a)
    // clang-format off
    asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
                 : "=r"(d)
                 : "r"(a), "r"(b), "r"(c));
    // clang-format on
#else
    d = 0;
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        // NOTE(review): relies on int -> int8_t narrowing to reinterpret each
        // byte as signed (implementation-defined pre-C++20, but holds on all
        // supported targets)
        int8_t val_a = (a & 0xff), val_b = (b & 0xff);
        d += static_cast<int>(val_a) * static_cast<int>(val_b);
        a = (a >> 8), b = (b >> 8);
    }
    d += c;
#endif
}
// the following code is taken from cutlass:
// https://github.com/NVIDIA/cutlass/blob/master/cutlass/gemm/igemm_epilogue.h
// Note: using .rni integer rounding modifier, i.e. rounding to nearest integer,
// choosing even integer if source is equidistant between two integers. The
// reason not use roundf is that roundf() maps to an 8-instruction sequence on
// the device, which causes significant performance drop in some cases. For
// details, refer to
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html

//! convert 4 floats to 4 int8 values packed into one 32-bit int, with val.x
//! in the least-significant byte; cvt.rni converts with round-to-nearest-even
//! into the int8 range, and the prmt instructions gather the low byte of each
//! intermediate register into the final word
MEGDNN_DEVICE __forceinline__ static int transform_float4_to_int8x4(float4 val) {
    int ix, iy, iz, iw;
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(ix) : "f"(val.x));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(iy) : "f"(val.y));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(iz) : "f"(val.z));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(iw) : "f"(val.w));
    // pack: ix <- {y, x} low bytes, iz <- {w, z} low bytes, then merge
    asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(ix) : "r"(iy));
    asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(iz) : "r"(iw));
    asm volatile("prmt.b32 %0, %0, %1, 0x5410;" : "+r"(ix) : "r"(iz));
    return ix;
}
//! inverse of transform_float4_to_int8x4: unpack 4 int8 values from a 32-bit
//! int (lowest byte -> result.x) and convert each to float
MEGDNN_DEVICE __forceinline__ static float4 transform_int8x4_to_float4(int val) {
    int ix, iy, iz, iw = val;
    // Extract the 4 bytes: each prmt selects one byte of val into the low
    // byte of its destination (the other bytes come from the 0x0 operand)
    asm volatile("prmt.b32 %0, %1, 0x0, 0x4440;" : "=r"(ix) : "r"(iw));
    asm volatile("prmt.b32 %0, %1, 0x0, 0x4441;" : "=r"(iy) : "r"(iw));
    asm volatile("prmt.b32 %0, %1, 0x0, 0x4442;" : "=r"(iz) : "r"(iw));
    asm volatile("prmt.b32 %0, %1, 0x0, 0x4443;" : "=r"(iw) : "r"(iw));
    // the floats
    float fx, fy, fz, fw;
    // convert to floats (make sure we generate I2F.F32.S8); cvt ... .s8
    // sign-interprets only the low byte of each source register
    asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fx) : "r"(ix));
    asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fy) : "r"(iy));
    asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fz) : "r"(iz));
    asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fw) : "r"(iw));
    return ::make_float4(fx, fy, fz, fw);
}
  266. MEGDNN_DEVICE __forceinline__ static float4 operator*(float scalar, float4 val) {
  267. return make_float4(scalar * val.x, scalar * val.y, scalar * val.z, scalar * val.w);
  268. }
  269. MEGDNN_DEVICE __forceinline__ static float4 operator+(float4 lval, float4 rval) {
  270. return make_float4(
  271. lval.x + rval.x, lval.y + rval.y, lval.z + rval.z, lval.w + rval.w);
  272. }
  273. #endif
  274. } // namespace cuda
  275. } // namespace megdnn
  276. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台