@@ -10,13 +10,13 @@ MODES = { | |||||
1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | 1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | ||||
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | ||||
'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', | 'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', | ||||
'ERFCINV', 'H_SWISH'], | |||||
'ERFCINV', 'H_SWISH', 'SILU', 'GELU'], | |||||
2: ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | 2: ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | ||||
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', | 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', | ||||
'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', | 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', | ||||
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | ||||
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | ||||
'FUSE_ADD_H_SWISH'], | |||||
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | |||||
3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | 3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | ||||
} | } | ||||
@@ -21,13 +21,13 @@ MODES = { | |||||
(1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | (1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | ||||
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | ||||
'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', | 'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', | ||||
'ERFCINV', 'H_SWISH'], | |||||
'ERFCINV', 'H_SWISH', 'SILU', 'GELU'], | |||||
(2, 'FLOAT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | (2, 'FLOAT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | ||||
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', | 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', | ||||
'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', | 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', | ||||
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | ||||
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | ||||
'FUSE_ADD_H_SWISH'], | |||||
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | |||||
(3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | (3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | ||||
(1, 'BOOL'): ['NOT'], | (1, 'BOOL'): ['NOT'], | ||||
(2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], | (2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], | ||||
@@ -410,7 +410,11 @@ pdef('Elemwise').add_enum( | |||||
Doc('NOT', 'unary: !x'), | Doc('NOT', 'unary: !x'), | ||||
Doc('AND', 'binary: x && y'), | Doc('AND', 'binary: x && y'), | ||||
Doc('OR', 'binary: x || y'), | Doc('OR', 'binary: x || y'), | ||||
Doc('XOR', 'binary: x ^ y') | |||||
Doc('XOR', 'binary: x ^ y'), | |||||
Doc('SILU', 'unary: x / (1 + exp(-x))'), | |||||
Doc('SILU_GRAD', 'binary: grad(x / (1 + exp(-x)))'), | |||||
Doc('GELU', 'unary: x Phi(x)'), | |||||
Doc('GELU_GRAD', 'binary: grad(x Phi(x))'), | |||||
) | ) | ||||
pdef('ElemwiseMultiType').add_enum( | pdef('ElemwiseMultiType').add_enum( | ||||
@@ -25,6 +25,8 @@ | |||||
MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) \ | ||||
MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \ | ||||
MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \ | ||||
MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) \ | |||||
MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) \ | |||||
#define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb) \ | #define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb) \ | ||||
MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ | ||||
@@ -64,6 +66,8 @@ | |||||
MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) \ | ||||
MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \ | ||||
MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) \ | ||||
MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) \ | |||||
MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) \ | |||||
#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \ | #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \ | ||||
MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ | ||||
@@ -69,6 +69,31 @@ namespace megdnn { | |||||
return ((-48.f * x_pow2) / deno + 27.f + x_pow2) / (deno * 9.f) * dx; | return ((-48.f * x_pow2) / deno + 27.f + x_pow2) / (deno * 9.f) * dx; | ||||
} | } | ||||
//! gradient of SiLU (x * sigmoid(x)) with respect to x, scaled by dy
//!
//! With s = sigmoid(x), d/dx (x * s) = s * (1 + x * (1 - s)).
//! \param x forward input
//! \param dy multiplier applied to the derivative (presumably the incoming
//!     output gradient; confirm against the caller)
__device__ __host__ inline float silu_grad(float x, float dy) {
    const float sig = 1.f / (1.f + expf(-x));
    const float slope = 1.f + x * (1.f - sig);
    return dy * sig * slope;
}
//! cumulative distribution function of the standard normal distribution
__device__ __host__ inline float normcdf(float x) {
#if MEGDNN_CC_HOST
    //! Phi(x) = (1 + erf(x / sqrt(2))) / 2
    const float scaled = x / sqrtf(2.f);
    return 0.5f * (1.f + erff(scaled));
#else
    //! use the CUDA built-in math function
    return ::normcdff(x);
#endif
}
//! gradient of GELU (x * Phi(x)) with respect to x, scaled by dy
//!
//! d/dx (x * Phi(x)) = Phi(x) + x * phi(x), where phi is the standard normal
//! density exp(-x^2 / 2) / sqrt(2 * pi).
__device__ __host__ inline float gelu_grad(float x, float dy) {
    //! 1 / sqrt(2 * pi)
    const float inv_sqrt_2pi = 0.3989422804014327f;
    const float cdf = normcdf(x);
    const float pdf = inv_sqrt_2pi * expf(-0.5f * x * x);
    return dy * (cdf + x * pdf);
}
#include "src/common/elemwise/each_mode.inl" | #include "src/common/elemwise/each_mode.inl" | ||||
template<megcorePlatform_t plat, uint32_t mode, typename dtype> | template<megcorePlatform_t plat, uint32_t mode, typename dtype> | ||||
@@ -137,6 +162,8 @@ namespace megdnn { | |||||
DEF_KERN_FLOAT(ERFC, erfcf(x)); | DEF_KERN_FLOAT(ERFC, erfcf(x)); | ||||
DEF_KERN_FLOAT(ERFCINV, erfcinvf(x)); | DEF_KERN_FLOAT(ERFCINV, erfcinvf(x)); | ||||
DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f)); | DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f)); | ||||
DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f)); | |||||
DEF_KERN_FLOAT(GELU, x * normcdf(x)); | |||||
// int only | // int only | ||||
DEF_KERN(dt_bool, NOT, x ^ 1); | DEF_KERN(dt_bool, NOT, x ^ 1); | ||||
@@ -207,6 +234,8 @@ namespace megdnn { | |||||
x < -3.f ? (ctype)0.f : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y))); | x < -3.f ? (ctype)0.f : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y))); | ||||
DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y)); | DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y)); | ||||
DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y)); | |||||
DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y)); | |||||
#undef KERN_SIG | #undef KERN_SIG | ||||
/* ================== ternary kernels ================== */ | /* ================== ternary kernels ================== */ | ||||
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_float16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_float32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_float16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_float32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_float16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_float32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_float16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_float32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint8 | |||||
#define KERN_IMPL_DTYPE dt_qint8 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint8 | |||||
#define KERN_IMPL_DTYPE dt_qint8 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint8 | |||||
#define KERN_IMPL_DTYPE dt_qint8 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint8 | |||||
#define KERN_IMPL_DTYPE dt_qint8 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_float16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_float32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_float16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_float32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_float16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_float32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_float16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_float32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_float16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_float32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_float16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_float32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_float16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_float32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_float16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_float32 | |||||
#include "../kern_impl.inl" |
@@ -48,6 +48,7 @@ __all__ = [ | |||||
"deformable_psroi_pooling", | "deformable_psroi_pooling", | ||||
"dropout", | "dropout", | ||||
"embedding", | "embedding", | ||||
"gelu", | |||||
"hsigmoid", | "hsigmoid", | ||||
"hswish", | "hswish", | ||||
"indexing_one_hot", | "indexing_one_hot", | ||||
@@ -67,6 +68,7 @@ __all__ = [ | |||||
"sigmoid", | "sigmoid", | ||||
"sliding_window", | "sliding_window", | ||||
"sliding_window_transpose", | "sliding_window_transpose", | ||||
"silu", | |||||
"softmax", | "softmax", | ||||
"softplus", | "softplus", | ||||
"sync_batch_norm", | "sync_batch_norm", | ||||
@@ -766,6 +768,25 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor: | |||||
return maximum(inp, 0) + negative_slope * minimum(inp, 0) | return maximum(inp, 0) + negative_slope * minimum(inp, 0) | ||||
def silu(x: Tensor) -> Tensor:
    r"""
    Applies the element-wise Sigmoid Linear Unit function:

    .. math::
        \text{silu}(x) = x \cdot \text{sigmoid}(x) = \frac{x}{1 + \exp(-x)}

    :param x: input tensor.
    :return: result of applying the ``SILU`` elemwise mode to ``x``.
    """
    return _elwise(x, mode=Elemwise.Mode.SILU)
def gelu(x: Tensor) -> Tensor:
    r"""
    Applies the element-wise function:

    .. math::
        \text{gelu}(x) = x\Phi(x)

    where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.

    :param x: input tensor.
    :return: result of applying the ``GELU`` elemwise mode to ``x``.
    """
    return _elwise(x, mode=Elemwise.Mode.GELU)
def softplus(inp: Tensor) -> Tensor: | def softplus(inp: Tensor) -> Tensor: | ||||
r""" | r""" | ||||
Applies the element-wise function: | Applies the element-wise function: | ||||
@@ -7,7 +7,7 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | |||||
from .activation import GELU, LeakyReLU, PReLU, ReLU, Sigmoid, SiLU, Softmax | |||||
from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d | from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d | ||||
from .batch_matmul_activation import BatchMatMulActivation | from .batch_matmul_activation import BatchMatMulActivation | ||||
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | ||||
@@ -8,7 +8,7 @@ | |||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import numpy as np | import numpy as np | ||||
from ..functional import leaky_relu, prelu, relu, sigmoid, softmax | |||||
from ..functional import gelu, leaky_relu, prelu, relu, sigmoid, silu, softmax | |||||
from ..tensor import Parameter | from ..tensor import Parameter | ||||
from .module import Module | from .module import Module | ||||
@@ -92,6 +92,74 @@ class Sigmoid(Module): | |||||
return sigmoid(inputs) | return sigmoid(inputs) | ||||
class SiLU(Module):
    r"""
    Applies the element-wise function:

    .. math::
        \text{SiLU}(x) = \frac{x}{1 + \exp(-x)}

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32))
        silu = M.SiLU()
        output = silu(data)
        with np.printoptions(precision=6):
            print(output.numpy())

    Outputs:

    .. testoutput::

        [-0.238406 -0.268941  0.        0.731059  1.761594]

    """

    def forward(self, inputs):
        # Delegates to the functional ``silu``.
        return silu(inputs)
class GELU(Module):
    r"""
    Applies the element-wise function:

    .. math::
        \text{GELU}(x) = x\Phi(x)

    where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32))
        gelu = M.GELU()
        output = gelu(data)
        with np.printoptions(precision=4):
            print(output.numpy())

    Outputs:

    .. testoutput::

        [-0.0455 -0.1587  0.      0.8413  1.9545]

    """

    def forward(self, inputs):
        # Delegates to the functional ``gelu``.
        return gelu(inputs)
class ReLU(Module): | class ReLU(Module): | ||||
r""" | r""" | ||||
Applies the element-wise function: | Applies the element-wise function: | ||||
@@ -28,6 +28,8 @@ class Elemwise(Module): | |||||
* "fuse_add_sigmoid": sigmoid(x + y) | * "fuse_add_sigmoid": sigmoid(x + y) | ||||
* "fuse_add_tanh": tanh(x + y) | * "fuse_add_tanh": tanh(x + y) | ||||
* "relu": x > 0 ? x : 0 | * "relu": x > 0 ? x : 0 | ||||
* "silu": silu(x) | |||||
* "gelu": gelu(x) | |||||
* "abs": x > 0 ? x : -x | * "abs": x > 0 ? x : -x | ||||
* "sigmoid": sigmoid(x) | * "sigmoid": sigmoid(x) | ||||
* "exp": exp(x) | * "exp": exp(x) | ||||
@@ -144,6 +144,13 @@ def test_hswish(): | |||||
np.testing.assert_almost_equal(y_np, y_mge, decimal=6) | np.testing.assert_almost_equal(y_np, y_mge, decimal=6) | ||||
def test_silu():
    # Compare F.silu against the closed form x / (1 + exp(-x)) computed in numpy.
    data = np.array([-1.5, 0.0, 1.0, 1.5]).astype("float32")
    expected = data / (1 + np.exp(-data))
    actual = F.silu(tensor(data)).numpy()
    np.testing.assert_almost_equal(expected, actual, decimal=6)
def test_hsigmoid(): | def test_hsigmoid(): | ||||
np.random.seed(42) | np.random.seed(42) | ||||
x = np.random.randn(100).astype("float32") | x = np.random.randn(100).astype("float32") | ||||
@@ -145,7 +145,7 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() { | |||||
0.f}) / | 0.f}) / | ||||
6.f), | 6.f), | ||||
}; | }; | ||||
mgb_assert(map.size() + 12 == opr::Elemwise::Param::MODE_NR_MEMBER); | |||||
mgb_assert(map.size() + 16 == opr::Elemwise::Param::MODE_NR_MEMBER); | |||||
// unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH, | // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH, | ||||
// ERFINV, ERFCINV, NOT, AND, OR, XOR, SILU, SILU_GRAD, GELU, GELU_GRAD | // ERFINV, ERFCINV, NOT, AND, OR, XOR, SILU, SILU_GRAD, GELU, GELU_GRAD | ||||
return map; | return map; | ||||
@@ -613,6 +613,10 @@ MGB_IMPL_OPR_GRAD(Elemwise) { | |||||
RET(EL2(H_SWISH_GRAD, (i0 + i1), og)); | RET(EL2(H_SWISH_GRAD, (i0 + i1), og)); | ||||
case Mode::NOT: | case Mode::NOT: | ||||
return nullptr; | return nullptr; | ||||
case Mode::SILU: | |||||
RET(EL2(SILU_GRAD, i0, og)); | |||||
case Mode::GELU: | |||||
RET(EL2(GELU_GRAD, i0, og)); | |||||
// binary | // binary | ||||
case Mode::ABS_GRAD: | case Mode::ABS_GRAD: | ||||
@@ -131,6 +131,12 @@ namespace { | |||||
std::numeric_limits<T>::digits)); | std::numeric_limits<T>::digits)); | ||||
} | } | ||||
//! Reference implementation of the GELU_GRAD binary mode:
//!     y * (Phi(x) + x * phi(x))
//! where Phi is the standard normal CDF and phi the standard normal density.
//! \param x forward input of gelu
//! \param y multiplier applied to the derivative (presumably the output
//!     gradient; confirm against the GELU_GRAD caller)
float do_gelu_grad(float x, float y) {
    //! 1 / sqrt(2 * pi) as a float literal: avoids the non-standard M_PI macro
    //! and the double-precision argument that was narrowed into sqrtf
    const float inv_sqrt_2pi = 0.3989422804014327f;
    const float pdf = inv_sqrt_2pi * expf(-0.5f * x * x);
    const float cdf = 0.5f * (1.f + erff(x / sqrtf(2.f)));
    return y * (cdf + x * pdf);
}
/* ======================= basic framework ======================= */ | /* ======================= basic framework ======================= */ | ||||
template<typename ctype, bool stable_sign = false> | template<typename ctype, bool stable_sign = false> | ||||
@@ -563,6 +569,9 @@ namespace { | |||||
} | } | ||||
}; | }; | ||||
template<> struct CheckerConfig<SILU_GRAD>: public NoGradCheckerConfig {}; | |||||
template<> struct CheckerConfig<GELU_GRAD>: public NoGradCheckerConfig {}; | |||||
/* ======================= ternary config ======================= */ | /* ======================= ternary config ======================= */ | ||||
template<> struct CheckerConfig<COND_LEQ_MOV>: | template<> struct CheckerConfig<COND_LEQ_MOV>: | ||||
public BinaryInputMinGap<false> {}; | public BinaryInputMinGap<false> {}; | ||||
@@ -64,6 +64,10 @@ DEF_TRAIT(FUSE_ADD_H_SWISH, do_fuse_add_h_swish(x, y)) | |||||
DEF_TRAIT(FAST_TANH_GRAD, do_fast_tanh_grad(x, y)) | DEF_TRAIT(FAST_TANH_GRAD, do_fast_tanh_grad(x, y)) | ||||
DEF_TRAIT(ATAN2, std::atan2(x, y)) | DEF_TRAIT(ATAN2, std::atan2(x, y)) | ||||
DEF_TRAIT(H_SWISH_GRAD, do_h_swish_grad(x, y)) | DEF_TRAIT(H_SWISH_GRAD, do_h_swish_grad(x, y)) | ||||
DEF_TRAIT(SILU_GRAD, y*(1 + std::exp(-x) + x * std::exp(-x)) / | |||||
(1 + std::exp(-x)) / (1 + std::exp(-x))) | |||||
DEF_TRAIT(GELU_GRAD, do_gelu_grad(x, y)) | |||||
#undef _ALLOW_INT | #undef _ALLOW_INT | ||||
#undef _ALLOW_FLOAT | #undef _ALLOW_FLOAT | ||||
@@ -56,6 +56,8 @@ DEF_TRAIT(ERFINV, do_erfinv(x)) | |||||
DEF_TRAIT(ERFC, std::erfc(x)) | DEF_TRAIT(ERFC, std::erfc(x)) | ||||
DEF_TRAIT(ERFCINV, do_erfcinv(x)) | DEF_TRAIT(ERFCINV, do_erfcinv(x)) | ||||
DEF_TRAIT(H_SWISH, do_h_swish(x)) | DEF_TRAIT(H_SWISH, do_h_swish(x)) | ||||
DEF_TRAIT(SILU, x / (1 + std::exp(-x))) | |||||
DEF_TRAIT(GELU, x*(0.5f * (1.f + std::erf(x / std::sqrt(2.f))))) | |||||
#undef _ALLOW_INT | #undef _ALLOW_INT | ||||
#undef _ALLOW_FLOAT | #undef _ALLOW_FLOAT | ||||