Browse Source

feat(dnn): add elemwise modes

GitOrigin-RevId: edea1a48b0
try-import
Megvii Engine Team 3 years ago
parent
commit
042c7fd10a
100 changed files with 731 additions and 12 deletions
  1. +0
    -1
      dnn/include/megdnn/oprs/general.h
  2. +23
    -2
      dnn/scripts/gen_elemwise_multi_type_utils.py
  3. +30
    -3
      dnn/scripts/gen_elemwise_utils.py
  4. +22
    -0
      dnn/scripts/opr_param_defs.py
  5. +33
    -6
      dnn/src/common/elemwise/each_mode.inl
  6. +35
    -0
      dnn/src/common/elemwise/kern_defs.cuh
  7. +28
    -0
      dnn/src/common/elemwise/opr_impl.cpp
  8. +7
    -0
      dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cu
  9. +7
    -0
      dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_float16.cu
  10. +5
    -0
      dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_float32.cu
  11. +7
    -0
      dnn/src/cuda/elemwise/kimpl/ACOSH_dt_bfloat16.cu
  12. +7
    -0
      dnn/src/cuda/elemwise/kimpl/ACOSH_dt_float16.cu
  13. +5
    -0
      dnn/src/cuda/elemwise/kimpl/ACOSH_dt_float32.cu
  14. +7
    -0
      dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cu
  15. +7
    -0
      dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_float16.cu
  16. +5
    -0
      dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_float32.cu
  17. +7
    -0
      dnn/src/cuda/elemwise/kimpl/ASINH_dt_bfloat16.cu
  18. +7
    -0
      dnn/src/cuda/elemwise/kimpl/ASINH_dt_float16.cu
  19. +5
    -0
      dnn/src/cuda/elemwise/kimpl/ASINH_dt_float32.cu
  20. +7
    -0
      dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cu
  21. +7
    -0
      dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_float16.cu
  22. +5
    -0
      dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_float32.cu
  23. +7
    -0
      dnn/src/cuda/elemwise/kimpl/ATANH_dt_bfloat16.cu
  24. +7
    -0
      dnn/src/cuda/elemwise/kimpl/ATANH_dt_float16.cu
  25. +5
    -0
      dnn/src/cuda/elemwise/kimpl/ATANH_dt_float32.cu
  26. +7
    -0
      dnn/src/cuda/elemwise/kimpl/CLIP_dt_bfloat16.cu
  27. +7
    -0
      dnn/src/cuda/elemwise/kimpl/CLIP_dt_float16.cu
  28. +5
    -0
      dnn/src/cuda/elemwise/kimpl/CLIP_dt_float32.cu
  29. +5
    -0
      dnn/src/cuda/elemwise/kimpl/CLIP_dt_int16.cu
  30. +5
    -0
      dnn/src/cuda/elemwise/kimpl/CLIP_dt_int32.cu
  31. +5
    -0
      dnn/src/cuda/elemwise/kimpl/CLIP_dt_int8.cu
  32. +5
    -0
      dnn/src/cuda/elemwise/kimpl/CLIP_dt_uint8.cu
  33. +7
    -0
      dnn/src/cuda/elemwise/kimpl/COSH_dt_bfloat16.cu
  34. +7
    -0
      dnn/src/cuda/elemwise/kimpl/COSH_dt_float16.cu
  35. +5
    -0
      dnn/src/cuda/elemwise/kimpl/COSH_dt_float32.cu
  36. +7
    -0
      dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cu
  37. +7
    -0
      dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cu
  38. +5
    -0
      dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cu
  39. +7
    -0
      dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_bfloat16.cu
  40. +7
    -0
      dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_float16.cu
  41. +5
    -0
      dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_float32.cu
  42. +7
    -0
      dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cu
  43. +7
    -0
      dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_float16.cu
  44. +5
    -0
      dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_float32.cu
  45. +7
    -0
      dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cu
  46. +7
    -0
      dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_float16.cu
  47. +5
    -0
      dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_float32.cu
  48. +7
    -0
      dnn/src/cuda/elemwise/kimpl/PRELU_dt_bfloat16.cu
  49. +7
    -0
      dnn/src/cuda/elemwise/kimpl/PRELU_dt_float16.cu
  50. +5
    -0
      dnn/src/cuda/elemwise/kimpl/PRELU_dt_float32.cu
  51. +5
    -0
      dnn/src/cuda/elemwise/kimpl/PRELU_dt_int16.cu
  52. +5
    -0
      dnn/src/cuda/elemwise/kimpl/PRELU_dt_int32.cu
  53. +5
    -0
      dnn/src/cuda/elemwise/kimpl/PRELU_dt_int8.cu
  54. +5
    -0
      dnn/src/cuda/elemwise/kimpl/PRELU_dt_uint8.cu
  55. +7
    -0
      dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cu
  56. +7
    -0
      dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_float16.cu
  57. +5
    -0
      dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_float32.cu
  58. +7
    -0
      dnn/src/cuda/elemwise/kimpl/RELU6_dt_bfloat16.cu
  59. +7
    -0
      dnn/src/cuda/elemwise/kimpl/RELU6_dt_float16.cu
  60. +5
    -0
      dnn/src/cuda/elemwise/kimpl/RELU6_dt_float32.cu
  61. +5
    -0
      dnn/src/cuda/elemwise/kimpl/RELU6_dt_int16.cu
  62. +5
    -0
      dnn/src/cuda/elemwise/kimpl/RELU6_dt_int32.cu
  63. +5
    -0
      dnn/src/cuda/elemwise/kimpl/RELU6_dt_int8.cu
  64. +5
    -0
      dnn/src/cuda/elemwise/kimpl/RELU6_dt_uint8.cu
  65. +7
    -0
      dnn/src/cuda/elemwise/kimpl/SIGN_dt_bfloat16.cu
  66. +7
    -0
      dnn/src/cuda/elemwise/kimpl/SIGN_dt_float16.cu
  67. +5
    -0
      dnn/src/cuda/elemwise/kimpl/SIGN_dt_float32.cu
  68. +5
    -0
      dnn/src/cuda/elemwise/kimpl/SIGN_dt_int16.cu
  69. +5
    -0
      dnn/src/cuda/elemwise/kimpl/SIGN_dt_int32.cu
  70. +5
    -0
      dnn/src/cuda/elemwise/kimpl/SIGN_dt_int8.cu
  71. +5
    -0
      dnn/src/cuda/elemwise/kimpl/SIGN_dt_uint8.cu
  72. +7
    -0
      dnn/src/cuda/elemwise/kimpl/SINH_dt_bfloat16.cu
  73. +7
    -0
      dnn/src/cuda/elemwise/kimpl/SINH_dt_float16.cu
  74. +5
    -0
      dnn/src/cuda/elemwise/kimpl/SINH_dt_float32.cu
  75. +7
    -0
      dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cu
  76. +7
    -0
      dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cu
  77. +5
    -0
      dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cu
  78. +7
    -0
      dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cu
  79. +7
    -0
      dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_float16.cu
  80. +5
    -0
      dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_float32.cu
  81. +7
    -0
      dnn/src/cuda/elemwise/kimpl/SQRT_dt_bfloat16.cu
  82. +7
    -0
      dnn/src/cuda/elemwise/kimpl/SQRT_dt_float16.cu
  83. +5
    -0
      dnn/src/cuda/elemwise/kimpl/SQRT_dt_float32.cu
  84. +7
    -0
      dnn/src/cuda/elemwise/kimpl/SQUARE_dt_bfloat16.cu
  85. +7
    -0
      dnn/src/cuda/elemwise/kimpl/SQUARE_dt_float16.cu
  86. +5
    -0
      dnn/src/cuda/elemwise/kimpl/SQUARE_dt_float32.cu
  87. +5
    -0
      dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int16.cu
  88. +5
    -0
      dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int32.cu
  89. +5
    -0
      dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int8.cu
  90. +5
    -0
      dnn/src/cuda/elemwise/kimpl/SQUARE_dt_uint8.cu
  91. +7
    -0
      dnn/src/cuda/elemwise/kimpl/TAN_dt_bfloat16.cu
  92. +7
    -0
      dnn/src/cuda/elemwise/kimpl/TAN_dt_float16.cu
  93. +5
    -0
      dnn/src/cuda/elemwise/kimpl/TAN_dt_float32.cu
  94. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/ACOSH_GRAD_dt_qint8_dt_qint8.cu
  95. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/ACOSH_dt_qint8_dt_qint8.cu
  96. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/ASINH_GRAD_dt_qint8_dt_qint8.cu
  97. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/ASINH_dt_qint8_dt_qint8.cu
  98. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/ATANH_GRAD_dt_qint8_dt_qint8.cu
  99. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/ATANH_dt_qint8_dt_qint8.cu
  100. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_qint4_dt_qint4.cu

+ 0
- 1
dnn/include/megdnn/oprs/general.h View File

@@ -313,7 +313,6 @@ protected:
size_t workspace_in_bytes); size_t workspace_in_bytes);
}; };
using Cumsum = CumsumForward; using Cumsum = CumsumForward;

// mxx can be max or min // mxx can be max or min
class ArgmxxBase : public OperatorBase { class ArgmxxBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(ArgmxxBase, OperatorBase); DEF_OPR_IMPL_CTOR(ArgmxxBase, OperatorBase);


+ 23
- 2
dnn/scripts/gen_elemwise_multi_type_utils.py View File

@@ -48,6 +48,19 @@ MODES = {
"H_SWISH", "H_SWISH",
"SILU", "SILU",
"GELU", "GELU",
"SINH",
"COSH",
"ASINH",
"ACOSH",
"ATANH",
"TAN",
"SOFTPLUS",
"RELU6",
"HSIGMOID",
"LOGSIGMOID",
"SQRT",
"SQUARE",
"SIGN",
], ],
2: [ 2: [
"ABS_GRAD", "ABS_GRAD",
@@ -76,8 +89,15 @@ MODES = {
"FUSE_ADD_H_SWISH", "FUSE_ADD_H_SWISH",
"SILU_GRAD", "SILU_GRAD",
"GELU_GRAD", "GELU_GRAD",
"PRELU",
"ASINH_GRAD",
"ACOSH_GRAD",
"ATANH_GRAD",
"SOFTPLUS_GRAD",
"RELU6_GRAD",
"HSIGMOID_GRAD",
], ],
3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"],
3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3", "CLIP", "PRELU_GRAD"],
} }


QINT4_MODES = { QINT4_MODES = {
@@ -107,8 +127,9 @@ QINT4_MODES = {
"FUSE_ADD_TANH", "FUSE_ADD_TANH",
"FUSE_ADD_SIGMOID", "FUSE_ADD_SIGMOID",
"FUSE_ADD_H_SWISH", "FUSE_ADD_H_SWISH",
"PRELU",
], ],
3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"],
3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3", "CLIP"],
} }


QINT32_MODES = { QINT32_MODES = {


+ 30
- 3
dnn/scripts/gen_elemwise_utils.py View File

@@ -12,7 +12,7 @@ DTYPES = {
} }


MODES = { MODES = {
(1, "INT"): ["RELU", "ABS", "NEGATE"],
(1, "INT"): ["RELU", "ABS", "NEGATE", "RELU6", "SQUARE", "SIGN"],
(2, "INT"): [ (2, "INT"): [
"ABS_GRAD", "ABS_GRAD",
"ADD", "ADD",
@@ -32,8 +32,9 @@ MODES = {
"SHL", "SHL",
"SHR", "SHR",
"RMULH", "RMULH",
"PRELU",
], ],
(3, "INT"): ["COND_LEQ_MOV", "COND_LT_MOV"],
(3, "INT"): ["COND_LEQ_MOV", "COND_LT_MOV", "CLIP"],
(1, "FLOAT"): [ (1, "FLOAT"): [
"RELU", "RELU",
"ABS", "ABS",
@@ -59,6 +60,19 @@ MODES = {
"H_SWISH", "H_SWISH",
"SILU", "SILU",
"GELU", "GELU",
"SINH",
"COSH",
"ASINH",
"ACOSH",
"ATANH",
"TAN",
"SOFTPLUS",
"RELU6",
"HSIGMOID",
"LOGSIGMOID",
"SQRT",
"SQUARE",
"SIGN",
], ],
(2, "FLOAT"): [ (2, "FLOAT"): [
"ABS_GRAD", "ABS_GRAD",
@@ -87,8 +101,21 @@ MODES = {
"FUSE_ADD_H_SWISH", "FUSE_ADD_H_SWISH",
"SILU_GRAD", "SILU_GRAD",
"GELU_GRAD", "GELU_GRAD",
"PRELU",
"ASINH_GRAD",
"ACOSH_GRAD",
"ATANH_GRAD",
"SOFTPLUS_GRAD",
"RELU6_GRAD",
"HSIGMOID_GRAD",
],
(3, "FLOAT"): [
"COND_LEQ_MOV",
"COND_LT_MOV",
"FUSE_MUL_ADD3",
"CLIP",
"PRELU_GRAD",
], ],
(3, "FLOAT"): ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"],
(1, "BOOL"): ["NOT"], (1, "BOOL"): ["NOT"],
(2, "BOOL"): ["AND", "OR", "XOR", "LT", "LEQ", "EQ"], (2, "BOOL"): ["AND", "OR", "XOR", "LT", "LEQ", "EQ"],
(3, "BOOL"): [], (3, "BOOL"): [],


+ 22
- 0
dnn/scripts/opr_param_defs.py View File

@@ -424,6 +424,28 @@ pdef('Elemwise').add_enum(
Doc('NEQ = 61', 'binary: x != y'), Doc('NEQ = 61', 'binary: x != y'),
Doc('ISNAN = 62', 'unary: isnan(x)'), Doc('ISNAN = 62', 'unary: isnan(x)'),
Doc('ISINF = 63', 'unary: isinf(x)'), Doc('ISINF = 63', 'unary: isinf(x)'),
Doc('SINH = 64', 'unary: sinh(x)'),
Doc('COSH = 65', 'unary: cosh(x)'),
Doc('ASINH = 66', 'unary: asinh(x)'),
Doc('ACOSH = 67', 'unary: acosh(x)'),
Doc('ATANH = 68', 'unary: atanh(x)'),
Doc('TAN = 69', 'unary: tan(x)'),
Doc('ASINH_GRAD = 70', 'binary: y / sqrt(x^2 + 1)'),
Doc('ACOSH_GRAD = 71', 'binary: y / sqrt(x^2 - 1) (x > 1)'),
Doc('ATANH_GRAD = 72', 'binary: y / (1 - x^2) (|x| < 1)'),
Doc('PRELU = 73', 'binary: x > 0 ? x : x * y'),
Doc('CLIP = 74', 'ternary: x <= y ? y : (x <= z ? x : z)'),
# NOTE: the *_GRAD descriptions below spell out the boundary comparisons
# exactly as the kernels in dnn/src/common/elemwise/kern_defs.cuh implement
# them (PRELU_GRAD treats x == 0 as the pass-through branch; RELU6_GRAD and
# HSIGMOID_GRAD yield 0 at both clamp boundaries).
Doc('PRELU_GRAD = 75', 'ternary: x >= 0 ? y : y * z'),
Doc('SOFTPLUS = 76', 'unary: log(1 + e^x)'),
Doc('SOFTPLUS_GRAD = 77', 'binary: y * e^x / (1 + e^x)'),
Doc('RELU6 = 78', 'unary: min(max(0, x), 6)'),
Doc('RELU6_GRAD = 79', 'binary: x <= 0 ? 0 : (x >= 6 ? 0 : y)'),
Doc('HSIGMOID = 80', 'unary: relu6(x + 3) / 6'),
Doc('HSIGMOID_GRAD = 81', 'binary: x <= -3 ? 0 : (x >= 3 ? 0 : y / 6)'),
Doc('LOGSIGMOID = 82', 'unary: -log(1 + e^(-x))'),
Doc('SQRT = 83', 'unary: x^(1/2)'),
Doc('SQUARE = 84', 'unary: x^2'),
Doc('SIGN = 85', 'unary: sgn(x)'),
) )


pdef('ElemwiseMultiType').add_enum( pdef('ElemwiseMultiType').add_enum(


+ 33
- 6
dnn/src/common/elemwise/each_mode.inl View File

@@ -25,12 +25,28 @@
MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)


#define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb) \ #define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb)
MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)


#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb) \ #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \
@@ -66,7 +82,14 @@
MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb)


#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \ #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \
@@ -86,15 +109,19 @@
MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb)
MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)


#define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_BOOL(cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_BOOL(cb)


#define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb) \ #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb)
MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb)


#define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_INT(cb) \ #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_INT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)

+ 35
- 0
dnn/src/common/elemwise/kern_defs.cuh View File

@@ -154,11 +154,18 @@ struct ElemwiseKern;


// int and float // int and float
DEF_KERN_ALL(NEGATE, -x); DEF_KERN_ALL(NEGATE, -x);
DEF_KERN_ALL(SQUARE, x* x);
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__) #if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(RELU, x <= ctype(0) ? ctype(0) : x); DEF_KERN_INT(RELU, x <= ctype(0) ? ctype(0) : x);
DEF_KERN_INT(RELU6, x <= ctype(0) ? ctype(0) : (x <= ctype(6) ? x : ctype(6)));
DEF_KERN_INT(SIGN, x < ctype(0) ? ctype(-1) : (x > ctype(0) ? ctype(1) : ctype(0)));
DEF_KERN_FLOAT(RELU, x <= 0.f ? ctype(0) : x); DEF_KERN_FLOAT(RELU, x <= 0.f ? ctype(0) : x);
// RELU6 (HIP float variant): clamp x to [0, 6]. The lower bound is tested
// against 0 so that values in (0, 6] pass through unchanged, matching the
// integer variant above and the non-HIP DEF_KERN_ALL definition below.
DEF_KERN_FLOAT(RELU6, x <= 0.f ? ctype(0) : (x <= 6.f ? x : ctype(6)));
DEF_KERN_FLOAT(SIGN, x < 0.f ? -1.f : (x > 0.f ? 1.f : 0.f));
#else #else
DEF_KERN_ALL(RELU, x <= ctype(0) ? ctype(0) : x); DEF_KERN_ALL(RELU, x <= ctype(0) ? ctype(0) : x);
DEF_KERN_ALL(RELU6, x <= ctype(0) ? ctype(0) : (x <= ctype(6) ? x : ctype(6)));
DEF_KERN_ALL(SIGN, x < ctype(0) ? ctype(-1) : (x > ctype(0) ? ctype(1) : ctype(0)));
#endif #endif
DEF_KERN_INT(ABS, abs(int(x))); DEF_KERN_INT(ABS, abs(int(x)));
// DEF_KERN_INT(ABS, x > ctype(0) ? x : -x); // DEF_KERN_INT(ABS, x > ctype(0) ? x : -x);
@@ -186,6 +193,18 @@ DEF_KERN_FLOAT(ERFCINV, erfcinvf(x));
DEF_KERN_FLOAT(H_SWISH, x* min(max(x + 3, 0.f), 6.f) * (1.f / 6.f)); DEF_KERN_FLOAT(H_SWISH, x* min(max(x + 3, 0.f), 6.f) * (1.f / 6.f));
DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f)); DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f));
DEF_KERN_FLOAT(GELU, x* normcdf(x)); DEF_KERN_FLOAT(GELU, x* normcdf(x));
// Hyperbolic / inverse-hyperbolic / tangent kernels: each forwards directly to
// the single-precision math function of the same name.
DEF_KERN_FLOAT(SINH, sinhf(x));
DEF_KERN_FLOAT(COSH, coshf(x));
DEF_KERN_FLOAT(ASINH, asinhf(x));
DEF_KERN_FLOAT(ACOSH, acoshf(x));
DEF_KERN_FLOAT(ATANH, atanhf(x));
DEF_KERN_FLOAT(TAN, tanf(x));
// SOFTPLUS: log(1 + e^x) written as log1p(e^-|x|) + max(x, 0), so expf only
// ever receives a non-positive argument.
DEF_KERN_FLOAT(SOFTPLUS, log1pf(expf(-fabsf(x))) + (x <= ctype(0) ? ctype(0) : x));
// HSIGMOID: 0 for x <= -3, 1 for x >= 3, and (x + 3) / 6 in between.
DEF_KERN_FLOAT(
HSIGMOID,
x <= ctype(-3) ? ctype(0) : (x >= ctype(3) ? ctype(1) : ((x + 3.f) / 6.f)));
DEF_KERN_FLOAT(SQRT, sqrtf(x));
// LOGSIGMOID: -log(1 + e^-x) written as -log1p(e^-|x|) + min(x, 0), using the
// same non-positive expf argument as SOFTPLUS.
DEF_KERN_FLOAT(LOGSIGMOID, -log1pf(expf(-fabsf(x))) + (x >= ctype(0) ? ctype(0) : x));


// int only // int only
DEF_KERN(dt_bool, NOT, x ^ 1); DEF_KERN(dt_bool, NOT, x ^ 1);
@@ -240,6 +259,12 @@ DEF_KERN_FLOAT(FUSE_ADD_RELU, (x + y) <= 0.f ? ctype(0) : (x + y));
#else #else
DEF_KERN_ALL(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y)); DEF_KERN_ALL(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y));
#endif #endif
// PRELU (binary): returns x when x > 0, otherwise x * y.
// On HIP (non-NVCC) builds the int and float kernels are defined separately,
// with the float version comparing against the literal 0.f; all other builds
// share one definition that compares against ctype(0).
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(PRELU, x > ctype(0) ? x : (x * y));
DEF_KERN_FLOAT(PRELU, x > 0.f ? x : (x * y));
#else
DEF_KERN_ALL(PRELU, x > ctype(0) ? x : (x * y));
#endif


// float only // float only
DEF_KERN_FLOAT(TRUE_DIV, x / y); DEF_KERN_FLOAT(TRUE_DIV, x / y);
@@ -259,6 +284,14 @@ DEF_KERN_FLOAT(
DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y)); DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y));
DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y)); DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y));
DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y)); DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y));
// ASINH_GRAD: y / sqrt(x^2 + 1).
// Uses the single-precision sqrtf, like the other float kernels in this file,
// so no double-precision overload can be selected.
DEF_KERN_FLOAT(ASINH_GRAD, y / sqrtf(x * x + 1.f));
// ACOSH_GRAD: y / sqrt(x^2 - 1); the radicand is only positive for |x| > 1
// (the mode is documented for x > 1).
DEF_KERN_FLOAT(ACOSH_GRAD, y / sqrtf(x * x - 1.f));
DEF_KERN_FLOAT(ATANH_GRAD, y / (1.f - x * x));
// SOFTPLUS_GRAD: y * e^x / (1 + e^x), evaluated in the algebraically equal
// form y / (1 + e^-x). The direct form becomes inf / inf = NaN once expf(x)
// overflows for large positive x; this form yields y there and 0 for large
// negative x.
DEF_KERN_FLOAT(SOFTPLUS_GRAD, y / (1.f + expf(-x)));
// RELU6_GRAD: passes y through only for 0 < x < 6; yields 0 at and beyond
// both boundaries.
DEF_KERN_FLOAT(RELU6_GRAD, x <= ctype(0) ? ctype(0) : (x >= ctype(6) ? ctype(0) : y));
// HSIGMOID_GRAD: y / 6 for -3 < x < 3; yields 0 at and beyond both boundaries.
DEF_KERN_FLOAT(
HSIGMOID_GRAD,
x <= ctype(-3) ? ctype(0) : (x >= ctype(3) ? ctype(0) : (y / 6.f)));
#undef KERN_SIG #undef KERN_SIG


/* ================== ternary kernels ================== */ /* ================== ternary kernels ================== */
@@ -268,6 +301,8 @@ DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y));
DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0)); DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0));
DEF_KERN_ALL(COND_LT_MOV, x < y ? z : ctype(0)); DEF_KERN_ALL(COND_LT_MOV, x < y ? z : ctype(0));
DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z); DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z);
// CLIP (ternary): y acts as the lower bound and z as the upper bound for x.
// The lower bound is tested first, so when x <= y the result is y regardless
// of z.
DEF_KERN_ALL(CLIP, x <= y ? y : (x <= z ? x : z));
// PRELU_GRAD (ternary, float only): returns y when x >= 0, otherwise y * z.
// NOTE(review): x == 0 selects the pass-through branch here, whereas the
// forward PRELU kernel uses a strict x > 0 test — confirm this is intended.
DEF_KERN_FLOAT(PRELU_GRAD, x >= 0.f ? y : (y * z));


#undef KERN_SIG #undef KERN_SIG




+ 28
- 0
dnn/src/common/elemwise/opr_impl.cpp View File

@@ -62,6 +62,9 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb); MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb);
MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb); MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb);
MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb); MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb);
cb(NEQ);
cb(ISNAN);
cb(ISINF);
#undef cb #undef cb


#define cb(_m) \ #define cb(_m) \
@@ -84,11 +87,14 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb); MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb);
MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb); MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb);
MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb); MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb);
cb(ISNAN);
cb(ISINF);
#undef _a #undef _a
#define _a 2 #define _a 2
MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb); MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb);
MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb); MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb);
MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb); MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb);
cb(NEQ);
#undef _a #undef _a
#define _a 3 #define _a 3
MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb); MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb);
@@ -223,6 +229,28 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
CB_MODE(Mode::GELU); CB_MODE(Mode::GELU);
CB_MODE(Mode::GELU_GRAD); CB_MODE(Mode::GELU_GRAD);
CB_MODE(Mode::COND_LT_MOV); CB_MODE(Mode::COND_LT_MOV);
CB_MODE(Mode::SINH);
CB_MODE(Mode::COSH);
CB_MODE(Mode::ASINH);
CB_MODE(Mode::ACOSH);
CB_MODE(Mode::ATANH);
CB_MODE(Mode::TAN);
CB_MODE(Mode::ASINH_GRAD);
CB_MODE(Mode::ACOSH_GRAD);
CB_MODE(Mode::ATANH_GRAD);
CB_MODE(Mode::PRELU);
CB_MODE(Mode::PRELU_GRAD);
CB_MODE(Mode::CLIP);
CB_MODE(Mode::SOFTPLUS);
CB_MODE(Mode::SOFTPLUS_GRAD);
CB_MODE(Mode::RELU6);
CB_MODE(Mode::RELU6_GRAD);
CB_MODE(Mode::HSIGMOID);
CB_MODE(Mode::HSIGMOID_GRAD);
CB_MODE(Mode::LOGSIGMOID);
CB_MODE(Mode::SQRT);
CB_MODE(Mode::SQUARE);
CB_MODE(Mode::SIGN);
default: default:
megdnn_assert( megdnn_assert(
0, 0,


+ 7
- 0
dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/ACOSH_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/ACOSH_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/ACOSH_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/ASINH_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/ASINH_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/ASINH_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/ATANH_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/ATANH_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/ATANH_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/CLIP_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/CLIP_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/CLIP_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/CLIP_dt_int16.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/CLIP_dt_int32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/CLIP_dt_int8.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/CLIP_dt_uint8.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/COSH_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/COSH_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/COSH_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/PRELU_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/PRELU_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/PRELU_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/PRELU_dt_int16.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/PRELU_dt_int32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/PRELU_dt_int8.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/PRELU_dt_uint8.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/RELU6_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/RELU6_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/RELU6_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/RELU6_dt_int16.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/RELU6_dt_int32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/RELU6_dt_int8.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/RELU6_dt_uint8.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/SIGN_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/SIGN_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/SIGN_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/SIGN_dt_int16.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/SIGN_dt_int32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/SIGN_dt_int8.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/SIGN_dt_uint8.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/SINH_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/SINH_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/SINH_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/SQRT_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/SQRT_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/SQRT_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/SQUARE_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/SQUARE_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/SQUARE_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int16.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int8.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/SQUARE_dt_uint8.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/TAN_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/TAN_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/TAN_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/ACOSH_GRAD_dt_qint8_dt_qint8.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_qint8
#define KERN_IMPL_DTYPE dt_qint8
#include "../kern_impl.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/ACOSH_dt_qint8_dt_qint8.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_qint8
#define KERN_IMPL_DTYPE dt_qint8
#include "../kern_impl.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/ASINH_GRAD_dt_qint8_dt_qint8.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_qint8
#define KERN_IMPL_DTYPE dt_qint8
#include "../kern_impl.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/ASINH_dt_qint8_dt_qint8.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_qint8
#define KERN_IMPL_DTYPE dt_qint8
#include "../kern_impl.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/ATANH_GRAD_dt_qint8_dt_qint8.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_qint8
#define KERN_IMPL_DTYPE dt_qint8
#include "../kern_impl.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/ATANH_dt_qint8_dt_qint8.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_qint8
#define KERN_IMPL_DTYPE dt_qint8
#include "../kern_impl.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_qint4_dt_qint4.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_STYPE dt_qint4
#define KERN_IMPL_DTYPE dt_qint4
#include "../kern_impl_q4.inl"

Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save