You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kern_defs.cuh 7.2 kB


  1. /**
  2. * \file dnn/src/common/elemwise/kern_defs.cuh
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "src/common/opr_param_defs_enumv.cuh"
  13. #include "src/common/elemwise_helper.cuh"
  14. #include "src/common/utils.cuh"
  15. #include "src/common/elemwise/erfinv.h"
  16. #include "megcore_cdefs.h"
  17. #include "megdnn/dtype.h"
  18. #include <cmath>
  19. #include <cstdlib>
  20. #if MEGDNN_CC_HOST
  21. #include <algorithm>
  22. using std::max;
  23. using std::min;
  24. #endif
  //! Default mode filter. Unless the includer already provided
  //! MEGDNN_ELEMWISE_MODE_ENABLE, every elemwise mode is enabled: the
  //! callback is simply invoked on the mode.
  #if !defined(MEGDNN_ELEMWISE_MODE_ENABLE)
  #define MEGDNN_ELEMWISE_MODE_ENABLE_ALL 1
  #define MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb) _cb(_mode)
  #endif

  //! Host-only build (MEGDNN_CC_HOST set) where __host__ is not already a
  //! macro: define __host__/__device__ as empty so the qualified helpers and
  //! kernels below still compile. MEGDNN_HOST_DEVICE_SELF_DEFINE records that
  //! this header introduced them, so they can be undone later without touching
  //! definitions that existed before.
  #if MEGDNN_CC_HOST
  #if !defined(__host__)
  #define MEGDNN_HOST_DEVICE_SELF_DEFINE
  #define __device__
  #define __host__
  #endif
  #endif
  34. namespace megdnn {
  //! Numerically stable log(exp(x) + exp(y)).
  //!
  //! The larger operand b is factored out so only exp(a - b) <= 1 is formed:
  //!     log(exp(a) + exp(b)) = b + log1p(exp(a - b)).
  //!
  //! Equal operands are handled separately. When both are the same infinity,
  //! a - b would be inf - inf = NaN. Using exp(0) == 1 instead gives
  //! b + log(2), which is the correct -inf for (-inf, -inf) and +inf for
  //! (+inf, +inf). For equal finite operands a - b is 0 anyway, so their
  //! result is unchanged. NaN compares unequal and still propagates.
  template <typename T>
  __device__ __host__ inline T log_sum_exp(T x, T y) {
      T a = x < y ? x : y;  // smaller operand
      T b = x < y ? y : x;  // larger operand, factored out
      if (a == b) {
          return T(b + log1pf(1.f));
      }
      return T(b + log1pf(exp(a - b)));
  }
  //! Rational approximation of tanh:
  //!     fast_tanh(x) = x * (27 + x^2) / (27 + 9 * x^2)
  //! It is 0 at x = 0 and exactly +/-1 at x = +/-3. No clamping is applied:
  //! for |x| > 3 the value exceeds 1 in magnitude and grows roughly like x / 9.
  __device__ __host__ inline float fast_tanh(float x) {
      const float numer = x * (27.f + x * x);
      const float denom = 27.f + 9.f * x * x;
      return numer / denom;
  }
  //! hswish(x + y), where hswish(z) = z * clamp(z + 3, 0, 6) / 6.
  //!
  //! The division by 6 is written as a multiplication by (1.f / 6.f). nvcc is
  //! not passed --use_fast_math (which would turn on the --prec-div=false
  //! optimization), so a literal division would be compiled as a precise
  //! division; per the original note, that caused a performance drop on the
  //! Turing architecture.
  __device__ __host__ inline float fuse_add_hswish(float x, float y) {
      const float sum = x + y;
      const float clipped = min(max(sum + 3, 0.f), 6.f);
      return sum * clipped * (1.f / 6.f);
  }
  //! Derivative of fast_tanh evaluated at x, multiplied by dx.
  //!
  //! Algebraically this equals (9 - x^2)^2 / (9 * (3 + x^2)^2) * dx, the exact
  //! derivative of x * (27 + x^2) / (27 + 9 * x^2). It is 1 * dx at x = 0 and
  //! 0 at x = +/-3.
  __device__ __host__ inline float fast_tanh_grad(float x, float dx) {
      const float sq = x * x;
      const float base = 3.f + sq;
      const float numer = (-48.f * sq) / base + 27.f + sq;
      return numer / (base * 9.f) * dx;
  }
  58. #include "src/common/elemwise/each_mode.inl"
  //! Primary template, declared only. Each supported (mode, ctype) pair gets a
  //! partial specialization from the DEF_KERN* macros below.
  template <megcorePlatform_t plat, uint32_t mode, typename CType>
  struct ElemwiseKern;

  //! Define ElemwiseKern<plat, Mode::_mode, _ctype> for a single ctype.
  //! KERN_SIG must be #defined by the user of this macro; it becomes the
  //! parameter list of apply(). _imp is an expression over those parameters,
  //! and its value is converted to ctype. The expansion ends at the closing
  //! brace, so each use must be followed by ';'.
  #define DEF_KERN(_ctype, _mode, _imp)                                       \
      template <megcorePlatform_t plat>                                       \
      struct ElemwiseKern<plat, param_enumv::Elemwise::Mode::_mode, _ctype> { \
          typedef _ctype ctype;                                               \
          static __host__ __device__ ctype apply(KERN_SIG) {                  \
              return ctype(_imp);                                             \
          }                                                                   \
      }
  //! Define the kernel for every floating-point ctype: dt_float32
  //! unconditionally, and dt_float16 / dt_bfloat16 wrapped in DNN_INC_FLOAT16
  //! (presumably a build switch for half-precision support; confirm where it
  //! is defined).
  #define DEF_KERN_FLOAT(_mode, _imp)                    \
      DEF_KERN(dt_float32, _mode, _imp);                 \
      DNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \
      DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);)
  //! Define the kernel for every integer ctype: int32, int16, int8 and uint8.
  //! The expansion ends without ';', so the caller's trailing ';' completes
  //! the last specialization and no stray empty declaration is produced. The
  //! last line deliberately has no '\': a continuation there would splice
  //! whatever source line follows into this macro's definition.
  #define DEF_KERN_INT(_mode, _imp)    \
      DEF_KERN(dt_int32, _mode, _imp); \
      DEF_KERN(dt_int16, _mode, _imp); \
      DEF_KERN(dt_int8, _mode, _imp);  \
      DEF_KERN(dt_uint8, _mode, _imp)
  //! define kernel for all ctypes: every integer and every floating-point type.
  //! The caller's trailing ';' terminates the last specialization. The last
  //! line has no '\', so the next source line is not spliced into the macro.
  #define DEF_KERN_ALL(_mode, _imp) \
      DEF_KERN_INT(_mode, _imp);    \
      DEF_KERN_FLOAT(_mode, _imp)
  /* ================== unary kernels ================== */
  // Every kernel in this section receives a single operand x of type ctype.
  #define KERN_SIG ctype x

  // ---- integer and floating-point ----
  DEF_KERN_ALL(NEGATE, -x);

  #if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
  // HIP-HCC: the float variants compare against a 0.f literal rather than
  // ctype(0) (presumably a half-precision comparison workaround; TODO confirm).
  DEF_KERN_INT(RELU, x <= ctype(0) ? ctype(0) : x);
  DEF_KERN_FLOAT(RELU, x <= 0.f ? ctype(0) : x);
  #else
  DEF_KERN_ALL(RELU, x <= ctype(0) ? ctype(0) : x);
  #endif

  // Integer ABS goes through int so the narrow types use the int overload.
  DEF_KERN_INT(ABS, abs(int(x)));
  DEF_KERN_FLOAT(ABS, fabsf(x));

  // ---- floating-point only (alphabetical) ----
  DEF_KERN_FLOAT(ACOS, acosf(x));
  DEF_KERN_FLOAT(ASIN, asinf(x));
  DEF_KERN_FLOAT(CEIL, ceilf(x));
  DEF_KERN_FLOAT(COS, cosf(x));
  DEF_KERN_FLOAT(ERF, erff(x));
  DEF_KERN_FLOAT(ERFC, erfcf(x));
  DEF_KERN_FLOAT(ERFCINV, erfcinvf(x));
  DEF_KERN_FLOAT(ERFINV, erfinvf(x));
  DEF_KERN_FLOAT(EXP, expf(x));
  DEF_KERN_FLOAT(EXPM1, expm1f(x));
  DEF_KERN_FLOAT(FAST_TANH, fast_tanh(x));
  DEF_KERN_FLOAT(FLOOR, floorf(x));
  // x * clamp(x + 3, 0, 6) / 6, with the division written as a multiplication.
  DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f));
  DEF_KERN_FLOAT(LOG, logf(x));
  DEF_KERN_FLOAT(LOG1P, log1pf(x));
  DEF_KERN_FLOAT(ROUND, roundf(x));
  DEF_KERN_FLOAT(SIGMOID, 1.f / (expf(-x) + 1.f));
  DEF_KERN_FLOAT(SIN, sinf(x));
  DEF_KERN_FLOAT(TANH, tanhf(x));

  // ---- bool only ----
  DEF_KERN(dt_bool, NOT, x ^ 1);

  #undef KERN_SIG
  /* ================== binary kernels ================== */
  // Every kernel in this section receives two operands x and y of type ctype.
  #define KERN_SIG ctype x, ctype y
  // int and float
  #if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
  DEF_KERN_INT(ABS_GRAD, x > ctype(0) ? y : -y);
  DEF_KERN_FLOAT(ABS_GRAD, x > 0.f ? y : -y);
  #else
  DEF_KERN_ALL(ABS_GRAD, x > ctype(0) ? y : -y);
  #endif
  DEF_KERN_ALL(ADD, x + y);
  DEF_KERN_ALL(MAX, x > y ? x : y);
  DEF_KERN_ALL(MIN, x < y ? x : y);
  DEF_KERN_ALL(MUL, x * y);
  DEF_KERN(dt_bool, AND, x && y);
  DEF_KERN(dt_bool, OR, x || y);
  DEF_KERN(dt_bool, XOR, x ^ y);
  DEF_KERN_INT(RMULH, round_mulh_saturate(x, y));
  DEF_KERN_ALL(SIGMOID_GRAD, x * (ctype(1) - x) * y);
  DEF_KERN_ALL(SUB, x - y);
  #if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
  DEF_KERN_INT(SWITCH_GT0, x > ctype(0) ? y : ctype(0));
  DEF_KERN_FLOAT(SWITCH_GT0, x > 0.f ? y : ctype(0));
  #else
  DEF_KERN_ALL(SWITCH_GT0, x > ctype(0) ? y : ctype(0));
  #endif
  DEF_KERN_ALL(TANH_GRAD, (ctype(1) - x * x) * y);
  DEF_KERN_ALL(LT, x < y);
  DEF_KERN_ALL(LEQ, x <= y);
  DEF_KERN_ALL(EQ, x == y);
  DEF_KERN(dt_bool, LT, x < y);
  DEF_KERN(dt_bool, LEQ, x <= y);
  DEF_KERN(dt_bool, EQ, x == y);

  //! Integer floor division.
  //!
  //! C++ '/' truncates toward zero. When the operands have opposite signs and
  //! the division is inexact, the truncated quotient is one more than the
  //! floor (e.g. -7 / 2 == -3, but floor(-3.5) == -4). This helper corrects
  //! that case so integer FLOOR_DIV agrees with the floorf(x / y) used for
  //! floating-point types. For narrow types, x ^ y is computed after
  //! promotion to int and is negative exactly when the signs differ; for
  //! unsigned types it is never negative. y == 0 stays undefined, exactly as
  //! with plain '/'.
  template <typename T>
  __device__ __host__ inline T dispatch_floordiv_int(T x, T y) {
      if ((x ^ y) < 0) {
          const T quot = x / y;
          const T rem = x % y;
          return rem ? T(quot - 1) : quot;
      }
      return x / y;
  }
  DEF_KERN_INT(FLOOR_DIV, dispatch_floordiv_int(x, y));
  DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y));
  // MOD truncates for both kinds: '%' and fmodf give a result with the sign of x.
  DEF_KERN_INT(MOD, x % y);
  DEF_KERN_FLOAT(MOD, fmodf(x, y));
  DEF_KERN_INT(SHL, x << y);
  DEF_KERN_INT(SHR, x >> y);
  #if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
  DEF_KERN_INT(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y));
  DEF_KERN_FLOAT(FUSE_ADD_RELU, (x + y) <= 0.f ? ctype(0) : (x + y));
  #else
  DEF_KERN_ALL(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y));
  #endif
  // float only
  DEF_KERN_FLOAT(TRUE_DIV, x / y);
  DEF_KERN_FLOAT(POW, powf(x, y));
  DEF_KERN_FLOAT(LOG_SUM_EXP, log_sum_exp(x, y));
  DEF_KERN_FLOAT(FAST_TANH_GRAD, fast_tanh_grad(x, y));
  DEF_KERN_FLOAT(FUSE_ADD_TANH, tanhf(x + y));
  DEF_KERN_FLOAT(FUSE_ADD_SIGMOID, 1.f / (expf(-(x + y)) + 1.f));
  DEF_KERN_FLOAT(ATAN2, atan2f(x, y));
  // Piecewise derivative of hswish(x), scaled by y:
  // 0 for x < -3, y for x > 3, and (2x + 3) / 6 * y in between.
  DEF_KERN_FLOAT(H_SWISH_GRAD,
                 x < -3.f ? (ctype)0.f
                          : (ctype)(x > 3.f ? (ctype)y
                                            : (ctype)((2.f * x + 3.f) / 6.f * y)));
  DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y));
  #undef KERN_SIG
  /* ================== ternary kernels ================== */
  // Every kernel in this section receives three operands x, y and z of type ctype.
  #define KERN_SIG ctype x, ctype y, ctype z
  // int and float
  DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0));
  DEF_KERN_ALL(FUSE_MUL_ADD3, x * y + z);
  #undef KERN_SIG

  // Remove every helper macro this header defined so none leaks into files
  // that include it. The DEF_KERN_* wrappers all expand to DEF_KERN, which is
  // removed here as well, so they could not be used afterwards anyway.
  #undef DEF_KERN_ALL
  #undef DEF_KERN_INT
  #undef DEF_KERN_FLOAT
  #undef DEF_KERN
  186. } // namespace megdnn
  //! Undo the empty __host__/__device__ definitions, but only when this header
  //! introduced them (signalled by MEGDNN_HOST_DEVICE_SELF_DEFINE). Definitions
  //! that already existed before this header are left untouched.
  #if MEGDNN_CC_HOST
  #ifdef MEGDNN_HOST_DEVICE_SELF_DEFINE
  #undef __device__
  #undef __host__
  #undef MEGDNN_HOST_DEVICE_SELF_DEFINE
  #endif
  #endif
  192. // vim: ft=cpp syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。