diff --git a/dnn/scripts/gen_elemwise_multi_type_utils.py b/dnn/scripts/gen_elemwise_multi_type_utils.py index ffccb1c8..f97505c6 100755 --- a/dnn/scripts/gen_elemwise_multi_type_utils.py +++ b/dnn/scripts/gen_elemwise_multi_type_utils.py @@ -10,13 +10,13 @@ MODES = { 1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', 'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', - 'ERFCINV', 'H_SWISH'], + 'ERFCINV', 'H_SWISH', 'SILU', 'GELU'], 2: ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', - 'FUSE_ADD_H_SWISH'], + 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], 3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], } diff --git a/dnn/scripts/gen_elemwise_utils.py b/dnn/scripts/gen_elemwise_utils.py index 5b48a7d0..5d744e74 100755 --- a/dnn/scripts/gen_elemwise_utils.py +++ b/dnn/scripts/gen_elemwise_utils.py @@ -21,13 +21,13 @@ MODES = { (1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', 'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', - 'ERFCINV', 'H_SWISH'], + 'ERFCINV', 'H_SWISH', 'SILU', 'GELU'], (2, 'FLOAT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', - 'FUSE_ADD_H_SWISH'], + 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], (3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], (1, 'BOOL'): ['NOT'], (2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 26738684..38dfe14b 100755 --- a/dnn/scripts/opr_param_defs.py +++ 
b/dnn/scripts/opr_param_defs.py @@ -410,7 +410,11 @@ pdef('Elemwise').add_enum( Doc('NOT', 'unary: !x'), Doc('AND', 'binary: x && y'), Doc('OR', 'binary: x || y'), - Doc('XOR', 'binary: x ^ y') + Doc('XOR', 'binary: x ^ y'), + Doc('SILU', 'unary: x / (1 + exp(-x))'), + Doc('SILU_GRAD', 'binary: grad(x / (1 + exp(-x)))'), + Doc('GELU', 'unary: x Phi(x)'), + Doc('GELU_GRAD', 'binary: grad(x Phi(x))'), ) pdef('ElemwiseMultiType').add_enum( diff --git a/dnn/src/common/elemwise/each_mode.inl b/dnn/src/common/elemwise/each_mode.inl index 2a6424d5..703cf8f8 100644 --- a/dnn/src/common/elemwise/each_mode.inl +++ b/dnn/src/common/elemwise/each_mode.inl @@ -25,6 +25,8 @@ MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) \ #define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ @@ -64,6 +66,8 @@ MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) \ #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ diff --git a/dnn/src/common/elemwise/kern_defs.cuh b/dnn/src/common/elemwise/kern_defs.cuh index 43313a10..f0a590ff 100644 --- a/dnn/src/common/elemwise/kern_defs.cuh +++ b/dnn/src/common/elemwise/kern_defs.cuh @@ -69,6 +69,31 @@ namespace megdnn { return ((-48.f * x_pow2) / deno + 27.f + x_pow2) / (deno * 9.f) * dx; } + //! grad of silu + __device__ __host__ inline float silu_grad(float x, float dy) { + const float one = 1.0; + float sigmoid = one / (one + expf(-x)); + return dy * sigmoid * (one + x * (one - sigmoid)); + } + + __device__ __host__ inline float normcdf(float x) { +#if MEGDNN_CC_HOST + return 0.5f * (1.f + erff(x / sqrtf(2.f))); +#else + //! 
use cuda built-in math + return ::normcdff(x); +#endif + } + + //! grad of gelu + __device__ __host__ inline float gelu_grad(float x, float dy) { + //! 1/ sqrt(2 * pi) + const float coeff = 0.3989422804014327f; + float phi = coeff * expf(-0.5f * x * x); + float normcdf_v = normcdf(x); + return dy * (normcdf_v + x * phi); + } + #include "src/common/elemwise/each_mode.inl" template @@ -137,6 +162,8 @@ namespace megdnn { DEF_KERN_FLOAT(ERFC, erfcf(x)); DEF_KERN_FLOAT(ERFCINV, erfcinvf(x)); DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f)); + DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f)); + DEF_KERN_FLOAT(GELU, x * normcdf(x)); // int only DEF_KERN(dt_bool, NOT, x ^ 1); @@ -207,6 +234,8 @@ namespace megdnn { x < -3.f ? (ctype)0.f : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y))); DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y)); + DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y)); + DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y)); #undef KERN_SIG /* ================== ternary kernels ================== */ diff --git a/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cu new file mode 100644 index 00000000..32a7d825 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float16.cu new file mode 100644 index 00000000..f3481781 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE 
dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float32.cu new file mode 100644 index 00000000..6bff61fe --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/GELU_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/GELU_dt_bfloat16.cu new file mode 100644 index 00000000..862473da --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/GELU_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/GELU_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/GELU_dt_float16.cu new file mode 100644 index 00000000..2df85208 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/GELU_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/GELU_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/GELU_dt_float32.cu new file mode 100644 index 00000000..5896ef8b --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/GELU_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cu 
b/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cu new file mode 100644 index 00000000..eb3018bc --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float16.cu new file mode 100644 index 00000000..e9960d0d --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float32.cu new file mode 100644 index 00000000..004a7e5c --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SILU_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SILU_dt_bfloat16.cu new file mode 100644 index 00000000..27009b26 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SILU_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SILU_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SILU_dt_float16.cu new file mode 
100644 index 00000000..1fd1dd0c --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SILU_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SILU_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SILU_dt_float32.cu new file mode 100644 index 00000000..c66df4bb --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SILU_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/GELU_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/GELU_GRAD_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..a200df94 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/GELU_GRAD_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/GELU_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/GELU_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..8cc57ad7 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/GELU_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SILU_GRAD_dt_qint8_dt_qint8.cu 
b/dnn/src/cuda/elemwise_multi_type/kimpl/SILU_GRAD_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..b6b198b8 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SILU_GRAD_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SILU_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SILU_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..f2127860 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SILU_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cpp new file mode 100644 index 00000000..32a7d825 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/GELU_GRAD_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/GELU_GRAD_dt_float16.cpp new file mode 100644 index 00000000..f3481781 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/GELU_GRAD_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include 
"../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/GELU_GRAD_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/GELU_GRAD_dt_float32.cpp new file mode 100644 index 00000000..6bff61fe --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/GELU_GRAD_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/GELU_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/GELU_dt_bfloat16.cpp new file mode 100644 index 00000000..862473da --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/GELU_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/GELU_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/GELU_dt_float16.cpp new file mode 100644 index 00000000..2df85208 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/GELU_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/GELU_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/GELU_dt_float32.cpp new file mode 100644 index 00000000..5896ef8b --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/GELU_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cpp 
b/dnn/src/naive/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cpp new file mode 100644 index 00000000..eb3018bc --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SILU_GRAD_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/SILU_GRAD_dt_float16.cpp new file mode 100644 index 00000000..e9960d0d --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SILU_GRAD_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SILU_GRAD_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/SILU_GRAD_dt_float32.cpp new file mode 100644 index 00000000..004a7e5c --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SILU_GRAD_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SILU_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SILU_dt_bfloat16.cpp new file mode 100644 index 00000000..27009b26 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SILU_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SILU_dt_float16.cpp 
b/dnn/src/naive/elemwise/kimpl/SILU_dt_float16.cpp new file mode 100644 index 00000000..1fd1dd0c --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SILU_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SILU_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/SILU_dt_float32.cpp new file mode 100644 index 00000000..c66df4bb --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SILU_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..32a7d825 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/GELU_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/GELU_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..f3481781 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/GELU_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/GELU_GRAD_dt_float32.cpp.hip 
b/dnn/src/rocm/elemwise/kimpl/GELU_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..6bff61fe --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/GELU_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/GELU_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/GELU_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..862473da --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/GELU_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/GELU_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/GELU_dt_float16.cpp.hip new file mode 100644 index 00000000..2df85208 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/GELU_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/GELU_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/GELU_dt_float32.cpp.hip new file mode 100644 index 00000000..5896ef8b --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/GELU_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cpp.hip new file mode 100644 
index 00000000..eb3018bc --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SILU_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SILU_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..e9960d0d --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SILU_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SILU_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SILU_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..004a7e5c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SILU_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SILU_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SILU_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..27009b26 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SILU_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SILU_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SILU_dt_float16.cpp.hip new file mode 100644 index 
00000000..1fd1dd0c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SILU_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SILU_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SILU_dt_float32.cpp.hip new file mode 100644 index 00000000..c66df4bb --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SILU_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index e506b03f..79ca60a0 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -48,6 +48,7 @@ __all__ = [ "deformable_psroi_pooling", "dropout", "embedding", + "gelu", "hsigmoid", "hswish", "indexing_one_hot", @@ -67,6 +68,7 @@ __all__ = [ "sigmoid", "sliding_window", "sliding_window_transpose", + "silu", "softmax", "softplus", "sync_batch_norm", @@ -766,6 +768,25 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor: return maximum(inp, 0) + negative_slope * minimum(inp, 0) +def silu(x): + r""" + Applies the element-wise Sigmoid Linear Unit function, i.e. `x * sigmoid(x)`. + """ + return _elwise(x, mode=Elemwise.Mode.SILU) + + +def gelu(x): + r""" + Applies the element-wise function: + + .. math:: + \text{gelu}(x) = x\Phi(x) + + where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. 
+ """ + return _elwise(x, mode=Elemwise.Mode.GELU) + + def softplus(inp: Tensor) -> Tensor: r""" Applies the element-wise function: diff --git a/imperative/python/megengine/module/__init__.py b/imperative/python/megengine/module/__init__.py index bacb61bb..e8b51a53 100644 --- a/imperative/python/megengine/module/__init__.py +++ b/imperative/python/megengine/module/__init__.py @@ -7,7 +7,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax +from .activation import GELU, LeakyReLU, PReLU, ReLU, Sigmoid, SiLU, Softmax from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d from .batch_matmul_activation import BatchMatMulActivation from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm diff --git a/imperative/python/megengine/module/activation.py b/imperative/python/megengine/module/activation.py index 072d7be0..659d56b4 100644 --- a/imperative/python/megengine/module/activation.py +++ b/imperative/python/megengine/module/activation.py @@ -8,7 +8,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import numpy as np -from ..functional import leaky_relu, prelu, relu, sigmoid, softmax +from ..functional import gelu, leaky_relu, prelu, relu, sigmoid, silu, softmax from ..tensor import Parameter from .module import Module @@ -92,6 +92,74 @@ class Sigmoid(Module): return sigmoid(inputs) +class SiLU(Module): + r""" + Applies the element-wise function: + + .. math:: + \text{SiLU}(x) = \frac{x}{1 + \exp(-x)} + + Examples: + + .. testcode:: + + import numpy as np + import megengine as mge + import megengine.module as M + + data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32)) + silu = M.SiLU() + output = silu(data) + with np.printoptions(precision=6): + print(output.numpy()) + + Outputs: + + .. testoutput:: + + [-0.238406 -0.268941 0. 
0.731059 1.761594] + + """ + + def forward(self, inputs): + return silu(inputs) + + +class GELU(Module): + r""" + Applies the element-wise function: + + .. math:: + \text{GELU}(x) = x\Phi(x) + + where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. + + Examples: + + .. testcode:: + + import numpy as np + import megengine as mge + import megengine.module as M + + data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32)) + gelu = M.GELU() + output = gelu(data) + with np.printoptions(precision=4): + print(output.numpy()) + + Outputs: + + .. testoutput:: + + [-0.0455 -0.1587 0. 0.8413 1.9545] + + """ + + def forward(self, inputs): + return gelu(inputs) + + class ReLU(Module): r""" Applies the element-wise function: diff --git a/imperative/python/megengine/module/elemwise.py b/imperative/python/megengine/module/elemwise.py index 8f935227..007c54b8 100644 --- a/imperative/python/megengine/module/elemwise.py +++ b/imperative/python/megengine/module/elemwise.py @@ -28,6 +28,8 @@ class Elemwise(Module): * "fuse_add_sigmoid": sigmoid(x + y) * "fuse_add_tanh": tanh(x + y) * "relu": x > 0 ? x : 0 + * "silu": silu(x) + * "gelu": gelu(x) * "abs": x > 0 ? 
x : -x * "sigmoid": sigmoid(x) * "exp": exp(x) diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py index cc45dc5f..7e1c2680 100644 --- a/imperative/python/test/unit/functional/test_elemwise.py +++ b/imperative/python/test/unit/functional/test_elemwise.py @@ -144,6 +144,13 @@ def test_hswish(): np.testing.assert_almost_equal(y_np, y_mge, decimal=6) +def test_silu(): + x = np.array([-1.5, 0.0, 1.0, 1.5]).astype("float32") + y_np = x / (1 + np.exp(-x)) + y_mge = F.silu(tensor(x)).numpy() + np.testing.assert_almost_equal(y_np, y_mge, decimal=6) + + def test_hsigmoid(): np.random.seed(42) x = np.random.randn(100).astype("float32") diff --git a/src/jit/impl/ast_c.cpp b/src/jit/impl/ast_c.cpp index d70d0b33..9e39c274 100644 --- a/src/jit/impl/ast_c.cpp +++ b/src/jit/impl/ast_c.cpp @@ -145,7 +145,7 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() { 0.f}) / 6.f), }; - mgb_assert(map.size() + 12 == opr::Elemwise::Param::MODE_NR_MEMBER); + mgb_assert(map.size() + 16 == opr::Elemwise::Param::MODE_NR_MEMBER); // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH, - // ERFINV, ERFCINV, NOT, AND, OR, XOR + // ERFINV, ERFCINV, NOT, AND, OR, XOR, SILU, SILU_GRAD, GELU, GELU_GRAD return map; diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index c87e7a11..d7554fb6 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -613,6 +613,10 @@ MGB_IMPL_OPR_GRAD(Elemwise) { RET(EL2(H_SWISH_GRAD, (i0 + i1), og)); case Mode::NOT: return nullptr; + case Mode::SILU: + RET(EL2(SILU_GRAD, i0, og)); + case Mode::GELU: + RET(EL2(GELU_GRAD, i0, og)); // binary case Mode::ABS_GRAD: diff --git a/src/opr/test/basic_arith/elemwise.cpp b/src/opr/test/basic_arith/elemwise.cpp index 8f8d233d..acee9f89 100644 --- a/src/opr/test/basic_arith/elemwise.cpp +++ b/src/opr/test/basic_arith/elemwise.cpp @@ -131,6 +131,12 @@ namespace { std::numeric_limits::digits)); } + float do_gelu_grad(float x, float y) { + float phi = 1.f / sqrtf(2.0 * 
M_PI) * expf(-0.5f * x * x); + float normcdf_v = 0.5f * (1.f + erff(x / sqrtf(2.f))); + return y * (normcdf_v + x * phi); + } + /* ======================= basic framework ======================= */ template @@ -563,6 +569,9 @@ namespace { } }; + template<> struct CheckerConfig: public NoGradCheckerConfig {}; + template<> struct CheckerConfig: public NoGradCheckerConfig {}; + /* ======================= ternary config ======================= */ template<> struct CheckerConfig: public BinaryInputMinGap {}; diff --git a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl index 481667d9..5fa82a8c 100644 --- a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl @@ -64,6 +64,10 @@ DEF_TRAIT(FUSE_ADD_H_SWISH, do_fuse_add_h_swish(x, y)) DEF_TRAIT(FAST_TANH_GRAD, do_fast_tanh_grad(x, y)) DEF_TRAIT(ATAN2, std::atan2(x, y)) DEF_TRAIT(H_SWISH_GRAD, do_h_swish_grad(x, y)) +DEF_TRAIT(SILU_GRAD, y*(1 + std::exp(-x) + x * std::exp(-x)) / + (1 + std::exp(-x)) / (1 + std::exp(-x))) +DEF_TRAIT(GELU_GRAD, do_gelu_grad(x, y)) + #undef _ALLOW_INT #undef _ALLOW_FLOAT diff --git a/src/opr/test/basic_arith/elemwise_unary_trait_def.inl b/src/opr/test/basic_arith/elemwise_unary_trait_def.inl index 1d1d64c6..778648b6 100644 --- a/src/opr/test/basic_arith/elemwise_unary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_unary_trait_def.inl @@ -56,6 +56,8 @@ DEF_TRAIT(ERFINV, do_erfinv(x)) DEF_TRAIT(ERFC, std::erfc(x)) DEF_TRAIT(ERFCINV, do_erfcinv(x)) DEF_TRAIT(H_SWISH, do_h_swish(x)) +DEF_TRAIT(SILU, x / (1 + std::exp(-x))) +DEF_TRAIT(GELU, x*(0.5f * (1.f + std::erf(x / std::sqrt(2.f))))) #undef _ALLOW_INT #undef _ALLOW_FLOAT