Browse Source

feat(mge/opr): add silu and gelu

GitOrigin-RevId: 75aa42947e
release-1.5
Megvii Engine Team 3 years ago
parent
commit
f76a2cc2c6
55 changed files with 414 additions and 8 deletions
  1. +2
    -2
      dnn/scripts/gen_elemwise_multi_type_utils.py
  2. +2
    -2
      dnn/scripts/gen_elemwise_utils.py
  3. +5
    -1
      dnn/scripts/opr_param_defs.py
  4. +4
    -0
      dnn/src/common/elemwise/each_mode.inl
  5. +29
    -0
      dnn/src/common/elemwise/kern_defs.cuh
  6. +7
    -0
      dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cu
  7. +7
    -0
      dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float16.cu
  8. +5
    -0
      dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float32.cu
  9. +7
    -0
      dnn/src/cuda/elemwise/kimpl/GELU_dt_bfloat16.cu
  10. +7
    -0
      dnn/src/cuda/elemwise/kimpl/GELU_dt_float16.cu
  11. +5
    -0
      dnn/src/cuda/elemwise/kimpl/GELU_dt_float32.cu
  12. +7
    -0
      dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cu
  13. +7
    -0
      dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float16.cu
  14. +5
    -0
      dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float32.cu
  15. +7
    -0
      dnn/src/cuda/elemwise/kimpl/SILU_dt_bfloat16.cu
  16. +7
    -0
      dnn/src/cuda/elemwise/kimpl/SILU_dt_float16.cu
  17. +5
    -0
      dnn/src/cuda/elemwise/kimpl/SILU_dt_float32.cu
  18. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/GELU_GRAD_dt_qint8_dt_qint8.cu
  19. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/GELU_dt_qint8_dt_qint8.cu
  20. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/SILU_GRAD_dt_qint8_dt_qint8.cu
  21. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/SILU_dt_qint8_dt_qint8.cu
  22. +7
    -0
      dnn/src/naive/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cpp
  23. +7
    -0
      dnn/src/naive/elemwise/kimpl/GELU_GRAD_dt_float16.cpp
  24. +5
    -0
      dnn/src/naive/elemwise/kimpl/GELU_GRAD_dt_float32.cpp
  25. +7
    -0
      dnn/src/naive/elemwise/kimpl/GELU_dt_bfloat16.cpp
  26. +7
    -0
      dnn/src/naive/elemwise/kimpl/GELU_dt_float16.cpp
  27. +5
    -0
      dnn/src/naive/elemwise/kimpl/GELU_dt_float32.cpp
  28. +7
    -0
      dnn/src/naive/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cpp
  29. +7
    -0
      dnn/src/naive/elemwise/kimpl/SILU_GRAD_dt_float16.cpp
  30. +5
    -0
      dnn/src/naive/elemwise/kimpl/SILU_GRAD_dt_float32.cpp
  31. +7
    -0
      dnn/src/naive/elemwise/kimpl/SILU_dt_bfloat16.cpp
  32. +7
    -0
      dnn/src/naive/elemwise/kimpl/SILU_dt_float16.cpp
  33. +5
    -0
      dnn/src/naive/elemwise/kimpl/SILU_dt_float32.cpp
  34. +7
    -0
      dnn/src/rocm/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cpp.hip
  35. +7
    -0
      dnn/src/rocm/elemwise/kimpl/GELU_GRAD_dt_float16.cpp.hip
  36. +5
    -0
      dnn/src/rocm/elemwise/kimpl/GELU_GRAD_dt_float32.cpp.hip
  37. +7
    -0
      dnn/src/rocm/elemwise/kimpl/GELU_dt_bfloat16.cpp.hip
  38. +7
    -0
      dnn/src/rocm/elemwise/kimpl/GELU_dt_float16.cpp.hip
  39. +5
    -0
      dnn/src/rocm/elemwise/kimpl/GELU_dt_float32.cpp.hip
  40. +7
    -0
      dnn/src/rocm/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cpp.hip
  41. +7
    -0
      dnn/src/rocm/elemwise/kimpl/SILU_GRAD_dt_float16.cpp.hip
  42. +5
    -0
      dnn/src/rocm/elemwise/kimpl/SILU_GRAD_dt_float32.cpp.hip
  43. +7
    -0
      dnn/src/rocm/elemwise/kimpl/SILU_dt_bfloat16.cpp.hip
  44. +7
    -0
      dnn/src/rocm/elemwise/kimpl/SILU_dt_float16.cpp.hip
  45. +5
    -0
      dnn/src/rocm/elemwise/kimpl/SILU_dt_float32.cpp.hip
  46. +21
    -0
      imperative/python/megengine/functional/nn.py
  47. +1
    -1
      imperative/python/megengine/module/__init__.py
  48. +69
    -1
      imperative/python/megengine/module/activation.py
  49. +2
    -0
      imperative/python/megengine/module/elemwise.py
  50. +7
    -0
      imperative/python/test/unit/functional/test_elemwise.py
  51. +1
    -1
      src/jit/impl/ast_c.cpp
  52. +4
    -0
      src/opr/impl/basic_arith.cpp
  53. +9
    -0
      src/opr/test/basic_arith/elemwise.cpp
  54. +4
    -0
      src/opr/test/basic_arith/elemwise_binary_trait_def.inl
  55. +2
    -0
      src/opr/test/basic_arith/elemwise_unary_trait_def.inl

+ 2
- 2
dnn/scripts/gen_elemwise_multi_type_utils.py View File

@@ -10,13 +10,13 @@ MODES = {
1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', 1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS',
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN',
'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', 'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC',
'ERFCINV', 'H_SWISH'],
'ERFCINV', 'H_SWISH', 'SILU', 'GELU'],
2: ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', 2: ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL',
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT',
'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW',
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD',
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD',
'FUSE_ADD_H_SWISH'],
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'],
3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], 3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'],
} }




+ 2
- 2
dnn/scripts/gen_elemwise_utils.py View File

@@ -21,13 +21,13 @@ MODES = {
(1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', (1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS',
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN',
'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', 'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC',
'ERFCINV', 'H_SWISH'],
'ERFCINV', 'H_SWISH', 'SILU', 'GELU'],
(2, 'FLOAT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', (2, 'FLOAT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL',
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT',
'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW',
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD',
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD',
'FUSE_ADD_H_SWISH'],
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'],
(3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], (3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'],
(1, 'BOOL'): ['NOT'], (1, 'BOOL'): ['NOT'],
(2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], (2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'],


+ 5
- 1
dnn/scripts/opr_param_defs.py View File

@@ -410,7 +410,11 @@ pdef('Elemwise').add_enum(
Doc('NOT', 'unary: !x'), Doc('NOT', 'unary: !x'),
Doc('AND', 'binary: x && y'), Doc('AND', 'binary: x && y'),
Doc('OR', 'binary: x || y'), Doc('OR', 'binary: x || y'),
Doc('XOR', 'binary: x ^ y')
Doc('XOR', 'binary: x ^ y'),
Doc('SILU', 'unary: x / (1 + exp(-x))'),
Doc('SILU_GRAD', 'binary: grad(x / (1 + exp(-x)))'),
Doc('GELU', 'unary: x Phi(x)'),
Doc('GELU_GRAD', 'binary: grad(x Phi(x))'),
) )


pdef('ElemwiseMultiType').add_enum( pdef('ElemwiseMultiType').add_enum(


+ 4
- 0
dnn/src/common/elemwise/each_mode.inl View File

@@ -25,6 +25,8 @@
MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) \


#define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb) \ #define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \
@@ -64,6 +66,8 @@
MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) \


#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \ #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \


+ 29
- 0
dnn/src/common/elemwise/kern_defs.cuh View File

@@ -69,6 +69,31 @@ namespace megdnn {
return ((-48.f * x_pow2) / deno + 27.f + x_pow2) / (deno * 9.f) * dx; return ((-48.f * x_pow2) / deno + 27.f + x_pow2) / (deno * 9.f) * dx;
} }


//! grad of silu: dy * s * (1 + x * (1 - s)) with s = sigmoid(x)
__device__ __host__ inline float silu_grad(float x, float dy) {
    const float sig = 1.0f / (1.0f + expf(-x));
    return dy * sig * (1.0f + x * (1.0f - sig));
}

//! Phi(x); which branch is compiled depends on MEGDNN_CC_HOST
__device__ __host__ inline float normcdf(float x) {
#if !MEGDNN_CC_HOST
    //! use cuda built-in math
    return ::normcdff(x);
#else
    //! host build: Phi(x) = (1 + erf(x / sqrt(2))) / 2
    return 0.5f * (1.f + erff(x / sqrtf(2.f)));
#endif
}

//! grad of gelu: dy * (Phi(x) + x * phi(x)), Phi taken from normcdf()
__device__ __host__ inline float gelu_grad(float x, float dy) {
    //! 1 / sqrt(2 * pi)
    const float inv_sqrt_2pi = 0.3989422804014327f;
    const float cdf = normcdf(x);
    const float pdf = inv_sqrt_2pi * expf(-0.5f * x * x);
    return dy * (cdf + x * pdf);
}

#include "src/common/elemwise/each_mode.inl" #include "src/common/elemwise/each_mode.inl"


template<megcorePlatform_t plat, uint32_t mode, typename dtype> template<megcorePlatform_t plat, uint32_t mode, typename dtype>
@@ -137,6 +162,8 @@ namespace megdnn {
DEF_KERN_FLOAT(ERFC, erfcf(x)); DEF_KERN_FLOAT(ERFC, erfcf(x));
DEF_KERN_FLOAT(ERFCINV, erfcinvf(x)); DEF_KERN_FLOAT(ERFCINV, erfcinvf(x));
DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f)); DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f));
//! silu(x) = x / (1 + exp(-x))
DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f));
//! gelu(x) = x * normcdf(x)
DEF_KERN_FLOAT(GELU, x * normcdf(x));


// int only // int only
DEF_KERN(dt_bool, NOT, x ^ 1); DEF_KERN(dt_bool, NOT, x ^ 1);
@@ -207,6 +234,8 @@ namespace megdnn {
x < -3.f ? (ctype)0.f : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y))); x < -3.f ? (ctype)0.f : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y)));


DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y)); DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y));
//! binary grad kernels: y is forwarded as the dy argument of the helper
DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y));
DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y));
#undef KERN_SIG #undef KERN_SIG


/* ================== ternary kernels ================== */ /* ================== ternary kernels ================== */


+ 7
- 0
dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/GELU_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/GELU_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/GELU_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/SILU_dt_bfloat16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/cuda/elemwise/kimpl/SILU_dt_float16.cu View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/cuda/elemwise/kimpl/SILU_dt_float32.cu View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/GELU_GRAD_dt_qint8_dt_qint8.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_qint8
#define KERN_IMPL_DTYPE dt_qint8
#include "../kern_impl.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/GELU_dt_qint8_dt_qint8.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_qint8
#define KERN_IMPL_DTYPE dt_qint8
#include "../kern_impl.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/SILU_GRAD_dt_qint8_dt_qint8.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_qint8
#define KERN_IMPL_DTYPE dt_qint8
#include "../kern_impl.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/SILU_dt_qint8_dt_qint8.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_qint8
#define KERN_IMPL_DTYPE dt_qint8
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/naive/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cpp View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/naive/elemwise/kimpl/GELU_GRAD_dt_float16.cpp View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/naive/elemwise/kimpl/GELU_GRAD_dt_float32.cpp View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/naive/elemwise/kimpl/GELU_dt_bfloat16.cpp View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/naive/elemwise/kimpl/GELU_dt_float16.cpp View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/naive/elemwise/kimpl/GELU_dt_float32.cpp View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/naive/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cpp View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/naive/elemwise/kimpl/SILU_GRAD_dt_float16.cpp View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/naive/elemwise/kimpl/SILU_GRAD_dt_float32.cpp View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/naive/elemwise/kimpl/SILU_dt_bfloat16.cpp View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/naive/elemwise/kimpl/SILU_dt_float16.cpp View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/naive/elemwise/kimpl/SILU_dt_float32.cpp View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/rocm/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cpp.hip View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/rocm/elemwise/kimpl/GELU_GRAD_dt_float16.cpp.hip View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/rocm/elemwise/kimpl/GELU_GRAD_dt_float32.cpp.hip View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/rocm/elemwise/kimpl/GELU_dt_bfloat16.cpp.hip View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/rocm/elemwise/kimpl/GELU_dt_float16.cpp.hip View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/rocm/elemwise/kimpl/GELU_dt_float32.cpp.hip View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/rocm/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cpp.hip View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/rocm/elemwise/kimpl/SILU_GRAD_dt_float16.cpp.hip View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/rocm/elemwise/kimpl/SILU_GRAD_dt_float32.cpp.hip View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/rocm/elemwise/kimpl/SILU_dt_bfloat16.cpp.hip View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 7
- 0
dnn/src/rocm/elemwise/kimpl/SILU_dt_float16.cpp.hip View File

@@ -0,0 +1,7 @@
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif

+ 5
- 0
dnn/src/rocm/elemwise/kimpl/SILU_dt_float32.cpp.hip View File

@@ -0,0 +1,5 @@
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"

+ 21
- 0
imperative/python/megengine/functional/nn.py View File

@@ -48,6 +48,7 @@ __all__ = [
"deformable_psroi_pooling", "deformable_psroi_pooling",
"dropout", "dropout",
"embedding", "embedding",
"gelu",
"hsigmoid", "hsigmoid",
"hswish", "hswish",
"indexing_one_hot", "indexing_one_hot",
@@ -67,6 +68,7 @@ __all__ = [
"sigmoid", "sigmoid",
"sliding_window", "sliding_window",
"sliding_window_transpose", "sliding_window_transpose",
"silu",
"softmax", "softmax",
"softplus", "softplus",
"sync_batch_norm", "sync_batch_norm",
@@ -766,6 +768,25 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
return maximum(inp, 0) + negative_slope * minimum(inp, 0) return maximum(inp, 0) + negative_slope * minimum(inp, 0)




def silu(x: Tensor) -> Tensor:
    r"""
    Applies the element-wise Sigmoid Linear Unit function:

    .. math::
        \text{silu}(x) = x \cdot \text{sigmoid}(x) = \frac{x}{1 + \exp(-x)}

    :param x: input tensor.
    :return: result of the ``SILU`` elemwise mode applied to ``x``.
    """
    return _elwise(x, mode=Elemwise.Mode.SILU)


def gelu(x: Tensor) -> Tensor:
    r"""
    Applies the element-wise function:

    .. math::
        \text{gelu}(x) = x\Phi(x)

    where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.

    :param x: input tensor.
    :return: result of the ``GELU`` elemwise mode applied to ``x``.
    """
    return _elwise(x, mode=Elemwise.Mode.GELU)


def softplus(inp: Tensor) -> Tensor: def softplus(inp: Tensor) -> Tensor:
r""" r"""
Applies the element-wise function: Applies the element-wise function:


+ 1
- 1
imperative/python/megengine/module/__init__.py View File

@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.


from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
from .activation import GELU, LeakyReLU, PReLU, ReLU, Sigmoid, SiLU, Softmax
from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d
from .batch_matmul_activation import BatchMatMulActivation from .batch_matmul_activation import BatchMatMulActivation
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm


+ 69
- 1
imperative/python/megengine/module/activation.py View File

@@ -8,7 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np import numpy as np


from ..functional import leaky_relu, prelu, relu, sigmoid, softmax
from ..functional import gelu, leaky_relu, prelu, relu, sigmoid, silu, softmax
from ..tensor import Parameter from ..tensor import Parameter
from .module import Module from .module import Module


@@ -92,6 +92,74 @@ class Sigmoid(Module):
return sigmoid(inputs) return sigmoid(inputs)




class SiLU(Module):
    r"""
    Sigmoid Linear Unit activation, applied element-wise:

    .. math::
        \text{SiLU}(x) = \frac{x}{1 + \exp(-x)}

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32))
        silu = M.SiLU()
        output = silu(data)
        with np.printoptions(precision=6):
            print(output.numpy())

    Outputs:

    .. testoutput::

        [-0.238406 -0.268941 0. 0.731059 1.761594]

    """

    def forward(self, inputs):
        # The module holds no state; it only forwards to the functional op.
        activated = silu(inputs)
        return activated


class GELU(Module):
    r"""
    Gaussian Error Linear Unit activation, applied element-wise:

    .. math::
        \text{GELU}(x) = x\Phi(x)

    where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32))
        gelu = M.GELU()
        output = gelu(data)
        with np.printoptions(precision=4):
            print(output.numpy())

    Outputs:

    .. testoutput::

        [-0.0455 -0.1587 0. 0.8413 1.9545]

    """

    def forward(self, inputs):
        # The module holds no state; it only forwards to the functional op.
        activated = gelu(inputs)
        return activated


class ReLU(Module): class ReLU(Module):
r""" r"""
Applies the element-wise function: Applies the element-wise function:


+ 2
- 0
imperative/python/megengine/module/elemwise.py View File

@@ -28,6 +28,8 @@ class Elemwise(Module):
* "fuse_add_sigmoid": sigmoid(x + y) * "fuse_add_sigmoid": sigmoid(x + y)
* "fuse_add_tanh": tanh(x + y) * "fuse_add_tanh": tanh(x + y)
* "relu": x > 0 ? x : 0 * "relu": x > 0 ? x : 0
* "silu": silu(x)
* "gelu": gelu(x)
* "abs": x > 0 ? x : -x * "abs": x > 0 ? x : -x
* "sigmoid": sigmoid(x) * "sigmoid": sigmoid(x)
* "exp": exp(x) * "exp": exp(x)


+ 7
- 0
imperative/python/test/unit/functional/test_elemwise.py View File

@@ -144,6 +144,13 @@ def test_hswish():
np.testing.assert_almost_equal(y_np, y_mge, decimal=6) np.testing.assert_almost_equal(y_np, y_mge, decimal=6)




def test_silu():
    """Compare ``F.silu`` with a NumPy evaluation of x / (1 + exp(-x))."""
    inp = np.array([-1.5, 0.0, 1.0, 1.5]).astype("float32")
    expected = inp / (1 + np.exp(-inp))
    actual = F.silu(tensor(inp)).numpy()
    np.testing.assert_almost_equal(expected, actual, decimal=6)


def test_hsigmoid(): def test_hsigmoid():
np.random.seed(42) np.random.seed(42)
x = np.random.randn(100).astype("float32") x = np.random.randn(100).astype("float32")


+ 1
- 1
src/jit/impl/ast_c.cpp View File

@@ -145,7 +145,7 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() {
0.f}) / 0.f}) /
6.f), 6.f),
}; };
mgb_assert(map.size() + 12 == opr::Elemwise::Param::MODE_NR_MEMBER);
mgb_assert(map.size() + 16 == opr::Elemwise::Param::MODE_NR_MEMBER);
// unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH, // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH,
// ERFINV, ERFCINV, NOT, AND, OR, XOR // ERFINV, ERFCINV, NOT, AND, OR, XOR
return map; return map;


+ 4
- 0
src/opr/impl/basic_arith.cpp View File

@@ -613,6 +613,10 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
RET(EL2(H_SWISH_GRAD, (i0 + i1), og)); RET(EL2(H_SWISH_GRAD, (i0 + i1), og));
case Mode::NOT: case Mode::NOT:
return nullptr; return nullptr;
case Mode::SILU:
RET(EL2(SILU_GRAD, i0, og));
case Mode::GELU:
RET(EL2(GELU_GRAD, i0, og));


// binary // binary
case Mode::ABS_GRAD: case Mode::ABS_GRAD:


+ 9
- 0
src/opr/test/basic_arith/elemwise.cpp View File

@@ -131,6 +131,12 @@ namespace {
std::numeric_limits<T>::digits)); std::numeric_limits<T>::digits));
} }


//! host reference for GELU_GRAD: y * (Phi(x) + x * phi(x))
float do_gelu_grad(float x, float y) {
    const float cdf = 0.5f * (1.f + erff(x / sqrtf(2.f)));
    const float pdf = 1.f / sqrtf(2.0 * M_PI) * expf(-0.5f * x * x);
    return y * (cdf + x * pdf);
}

/* ======================= basic framework ======================= */ /* ======================= basic framework ======================= */


template<typename ctype, bool stable_sign = false> template<typename ctype, bool stable_sign = false>
@@ -563,6 +569,9 @@ namespace {
} }
}; };


//! SILU_GRAD and GELU_GRAD take NoGradCheckerConfig unchanged
template <>
struct CheckerConfig<SILU_GRAD> : NoGradCheckerConfig {};
template <>
struct CheckerConfig<GELU_GRAD> : NoGradCheckerConfig {};

/* ======================= ternary config ======================= */ /* ======================= ternary config ======================= */
template<> struct CheckerConfig<COND_LEQ_MOV>: template<> struct CheckerConfig<COND_LEQ_MOV>:
public BinaryInputMinGap<false> {}; public BinaryInputMinGap<false> {};


+ 4
- 0
src/opr/test/basic_arith/elemwise_binary_trait_def.inl View File

@@ -64,6 +64,10 @@ DEF_TRAIT(FUSE_ADD_H_SWISH, do_fuse_add_h_swish(x, y))
DEF_TRAIT(FAST_TANH_GRAD, do_fast_tanh_grad(x, y)) DEF_TRAIT(FAST_TANH_GRAD, do_fast_tanh_grad(x, y))
DEF_TRAIT(ATAN2, std::atan2(x, y)) DEF_TRAIT(ATAN2, std::atan2(x, y))
DEF_TRAIT(H_SWISH_GRAD, do_h_swish_grad(x, y)) DEF_TRAIT(H_SWISH_GRAD, do_h_swish_grad(x, y))
// SILU_GRAD reference: with e = exp(-x), y * (1 + e + x * e) / (1 + e)^2
DEF_TRAIT(SILU_GRAD, y*(1 + std::exp(-x) + x * std::exp(-x)) /
(1 + std::exp(-x)) / (1 + std::exp(-x)))
// GELU_GRAD reference is delegated to the do_gelu_grad() helper
DEF_TRAIT(GELU_GRAD, do_gelu_grad(x, y))

#undef _ALLOW_INT #undef _ALLOW_INT
#undef _ALLOW_FLOAT #undef _ALLOW_FLOAT




+ 2
- 0
src/opr/test/basic_arith/elemwise_unary_trait_def.inl View File

@@ -56,6 +56,8 @@ DEF_TRAIT(ERFINV, do_erfinv(x))
DEF_TRAIT(ERFC, std::erfc(x)) DEF_TRAIT(ERFC, std::erfc(x))
DEF_TRAIT(ERFCINV, do_erfcinv(x)) DEF_TRAIT(ERFCINV, do_erfcinv(x))
DEF_TRAIT(H_SWISH, do_h_swish(x)) DEF_TRAIT(H_SWISH, do_h_swish(x))
// SILU reference: x / (1 + exp(-x))
DEF_TRAIT(SILU, x / (1 + std::exp(-x)))
// GELU reference: x * Phi(x), with Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
DEF_TRAIT(GELU, x*(0.5f * (1.f + std::erf(x / std::sqrt(2.f)))))
#undef _ALLOW_INT #undef _ALLOW_INT


#undef _ALLOW_FLOAT #undef _ALLOW_FLOAT


Loading…
Cancel
Save