diff --git a/dnn/scripts/gen_elemwise_multi_type_utils.py b/dnn/scripts/gen_elemwise_multi_type_utils.py index 6de5a512..39aec818 100755 --- a/dnn/scripts/gen_elemwise_multi_type_utils.py +++ b/dnn/scripts/gen_elemwise_multi_type_utils.py @@ -14,23 +14,27 @@ MODES = { 1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', 'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', - 'ERFCINV', 'H_SWISH', 'SILU', 'GELU'], + 'ERFCINV', 'H_SWISH', 'SILU', 'GELU', 'SINH', 'COSH', + 'ASINH', 'ACOSH', 'ATANH', 'TAN', 'SOFTPLUS', 'RELU6', + 'HSIGMOID', 'LOGSIGMOID', 'SQRT', 'SQUARE', 'SIGN'], 2: ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', - 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], - 3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], + 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD', 'PRELU', + 'ASINH_GRAD', 'ACOSH_GRAD', 'ATANH_GRAD', 'SOFTPLUS_GRAD', + 'RELU6_GRAD', 'HSIGMOID_GRAD'], + 3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3', 'CLIP', 'PRELU_GRAD'], } QINT4_MODES = { 1: ['RELU', 'ABS', 'NEGATE', 'CEIL', 'FLOOR', 'SIGMOID', 'TANH', 'FAST_TANH', 'ROUND', 'H_SWISH'], - 2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0', - 'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH', - 'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH'], - 3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], + 2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0', + 'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH', + 'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH', 'PRELU'], + 3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3', 'CLIP'], } QINT32_MODES = { diff --git a/dnn/scripts/gen_elemwise_utils.py b/dnn/scripts/gen_elemwise_utils.py index 84bc541f..52f28bb7 100755 --- a/dnn/scripts/gen_elemwise_utils.py +++ 
b/dnn/scripts/gen_elemwise_utils.py @@ -12,23 +12,27 @@ DTYPES = {'dt_int32': ('Int32', 'INT'), } MODES = { - (1, 'INT'): ['RELU', 'ABS', 'NEGATE'], + (1, 'INT'): ['RELU', 'ABS', 'NEGATE', 'RELU6', 'SQUARE', 'SIGN'], (2, 'INT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ', - 'EQ', 'FUSE_ADD_RELU', 'SHL', 'SHR', 'RMULH'], - (3, 'INT'): ['COND_LEQ_MOV', 'COND_LT_MOV'], + 'EQ', 'FUSE_ADD_RELU', 'SHL', 'SHR', 'RMULH', 'PRELU'], + (3, 'INT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'CLIP'], (1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', 'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', - 'ERFCINV', 'H_SWISH', 'SILU', 'GELU'], + 'ERFCINV', 'H_SWISH', 'SILU', 'GELU', 'SINH', 'COSH', + 'ASINH', 'ACOSH', 'ATANH', 'TAN', 'SOFTPLUS', 'RELU6', + 'HSIGMOID', 'LOGSIGMOID', 'SQRT', 'SQUARE', 'SIGN'], (2, 'FLOAT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', - 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], - (3, 'FLOAT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], + 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD', 'PRELU', + 'ASINH_GRAD', 'ACOSH_GRAD', 'ATANH_GRAD', 'SOFTPLUS_GRAD', + 'RELU6_GRAD', 'HSIGMOID_GRAD'], + (3, 'FLOAT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3', 'CLIP', 'PRELU_GRAD'], (1, 'BOOL'): ['NOT'], (2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], (3, 'BOOL'): [] diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 1c7dd193..0e901b57 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -421,9 +421,31 @@ pdef('Elemwise').add_enum( Doc('GELU = 58', 'unary: x Phi(x)'), Doc('GELU_GRAD = 59', 'binary: grad(x Phi(x))'), 
Doc('COND_LT_MOV = 60', 'ternary: x < y ? z : 0'), - Doc('NEQ = 61', 'binary: x != y'), - Doc('ISNAN = 62', 'unary: isnan(x)'), - Doc('ISINF = 63', 'unary: isinf(x)'), + Doc('SINH = 61', 'unary: sinh(x)'), + Doc('COSH = 62', 'unary: cosh(x)'), + Doc('ASINH = 63', 'unary: asinh(x)'), + Doc('ACOSH = 64', 'unary: acosh(x)'), + Doc('ATANH = 65', 'unary: atanh(x)'), + Doc('TAN = 66', 'unary: tan(x)'), + Doc('ASINH_GRAD = 67', 'binary: y / sqrt(x^2 + 1)'), + Doc('ACOSH_GRAD = 68', 'binary: y / sqrt(x^2 - 1) (x > 1)'), + Doc('ATANH_GRAD = 69', 'binary: y / (1 - x^2) (|x| < 1)'), + Doc('PRELU = 70', 'binary: x > 0 ? x : x * y'), + Doc('CLIP = 71', 'ternary: x <= y ? y : (x <= z ? x : z)'), + Doc('PRELU_GRAD = 72', 'ternary: x > 0 ? y : y * z'), + Doc('SOFTPLUS = 73', 'unary: log(1 + e^x)'), + Doc('SOFTPLUS_GRAD = 74', 'binary: y * e^x / (1 + e^x)'), + Doc('RELU6 = 75', 'unary: min(max(0, x), 6)'), + Doc('RELU6_GRAD = 76', 'binary: x <= 0 ? 0 : (x >= 6 ? 0 : y)'), + Doc('HSIGMOID = 77', 'unary: relu6(x + 3) / 6'), + Doc('HSIGMOID_GRAD = 78', 'binary: x <= -3 ? 0 : (x >= 3 ? 
0 : y / 6)'), + Doc('LOGSIGMOID = 79', 'unary: -log(1 + e^(-x))'), + Doc('SQRT = 80', 'unary: x^(1/2)'), + Doc('SQUARE = 81', 'unary: x^2'), + Doc('SIGN = 82', 'unary: sgn(x)'), + Doc('NEQ = 83', 'binary: x != y'), + Doc('ISNAN = 84', 'unary: isnan(x)'), + Doc('ISINF = 85', 'unary: isinf(x)'), ) pdef('ElemwiseMultiType').add_enum( diff --git a/dnn/src/common/elemwise/each_mode.inl b/dnn/src/common/elemwise/each_mode.inl index 67a144ca..48ca51dc 100644 --- a/dnn/src/common/elemwise/each_mode.inl +++ b/dnn/src/common/elemwise/each_mode.inl @@ -25,12 +25,28 @@ MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) + MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) + MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \ @@ -66,7 +82,14 @@ MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) + 
MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ @@ -86,15 +109,19 @@ MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) + MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_BOOL(cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) + MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_INT(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) + MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) diff --git a/dnn/src/common/elemwise/kern_defs.cuh b/dnn/src/common/elemwise/kern_defs.cuh index 87788c0d..95105470 100644 --- a/dnn/src/common/elemwise/kern_defs.cuh +++ b/dnn/src/common/elemwise/kern_defs.cuh @@ -154,11 +154,18 @@ struct ElemwiseKern; // int and float DEF_KERN_ALL(NEGATE, -x); +DEF_KERN_ALL(SQUARE, x* x); #if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__) DEF_KERN_INT(RELU, x <= ctype(0) ? ctype(0) : x); +DEF_KERN_INT(RELU6, x <= ctype(0) ? ctype(0) : (x <= ctype(6) ? x : ctype(6))); +DEF_KERN_INT(SIGN, x < ctype(0) ? ctype(-1) : (x > ctype(0) ? 
ctype(1) : ctype(0))); DEF_KERN_FLOAT(RELU, x <= 0.f ? ctype(0) : x); +DEF_KERN_FLOAT(RELU6, x <= 0.f ? ctype(0) : (x <= 6.f ? x : ctype(6))); +DEF_KERN_FLOAT(SIGN, x < 0.f ? -1.f : (x > 0.f ? 1.f : 0.f)); #else DEF_KERN_ALL(RELU, x <= ctype(0) ? ctype(0) : x); +DEF_KERN_ALL(RELU6, x <= ctype(0) ? ctype(0) : (x <= ctype(6) ? x : ctype(6))); +DEF_KERN_ALL(SIGN, x < ctype(0) ? ctype(-1) : (x > ctype(0) ? ctype(1) : ctype(0))); #endif DEF_KERN_INT(ABS, abs(int(x))); // DEF_KERN_INT(ABS, x > ctype(0) ? x : -x); @@ -186,6 +193,18 @@ DEF_KERN_FLOAT(ERFCINV, erfcinvf(x)); DEF_KERN_FLOAT(H_SWISH, x* min(max(x + 3, 0.f), 6.f) * (1.f / 6.f)); DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f)); DEF_KERN_FLOAT(GELU, x* normcdf(x)); +DEF_KERN_FLOAT(SINH, sinhf(x)); +DEF_KERN_FLOAT(COSH, coshf(x)); +DEF_KERN_FLOAT(ASINH, asinhf(x)); +DEF_KERN_FLOAT(ACOSH, acoshf(x)); +DEF_KERN_FLOAT(ATANH, atanhf(x)); +DEF_KERN_FLOAT(TAN, tanf(x)); +DEF_KERN_FLOAT(SOFTPLUS, log1pf(expf(-fabsf(x))) + (x <= ctype(0) ? ctype(0) : x)); +DEF_KERN_FLOAT( + HSIGMOID, + x <= ctype(-3) ? ctype(0) : (x >= ctype(3) ? ctype(1) : ((x + 3.f) / 6.f))); +DEF_KERN_FLOAT(SQRT, sqrtf(x)); +DEF_KERN_FLOAT(LOGSIGMOID, -log1pf(expf(-fabsf(x))) + (x >= ctype(0) ? ctype(0) : x)); // int only DEF_KERN(dt_bool, NOT, x ^ 1); @@ -240,6 +259,12 @@ DEF_KERN_FLOAT(FUSE_ADD_RELU, (x + y) <= 0.f ? ctype(0) : (x + y)); #else DEF_KERN_ALL(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y)); #endif +#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__) +DEF_KERN_INT(PRELU, x > ctype(0) ? x : (x * y)); +DEF_KERN_FLOAT(PRELU, x > 0.f ? x : (x * y)); +#else +DEF_KERN_ALL(PRELU, x > ctype(0) ? 
x : (x * y)); +#endif // float only DEF_KERN_FLOAT(TRUE_DIV, x / y); @@ -259,6 +284,14 @@ DEF_KERN_FLOAT( DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y)); DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y)); DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y)); +DEF_KERN_FLOAT(ASINH_GRAD, y / sqrt(x * x + 1.f)); +DEF_KERN_FLOAT(ACOSH_GRAD, y / sqrt(x * x - 1.f)); +DEF_KERN_FLOAT(ATANH_GRAD, y / (1.f - x * x)); +DEF_KERN_FLOAT(SOFTPLUS_GRAD, y / (1.f + expf(-x))); +DEF_KERN_FLOAT(RELU6_GRAD, x <= ctype(0) ? ctype(0) : (x >= ctype(6) ? ctype(0) : y)); +DEF_KERN_FLOAT( + HSIGMOID_GRAD, + x <= ctype(-3) ? ctype(0) : (x >= ctype(3) ? ctype(0) : (y / 6.f))); #undef KERN_SIG /* ================== ternary kernels ================== */ @@ -268,6 +301,8 @@ DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y)); DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0)); DEF_KERN_ALL(COND_LT_MOV, x < y ? z : ctype(0)); DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z); +DEF_KERN_ALL(CLIP, x <= y ? y : (x <= z ? x : z)); +DEF_KERN_FLOAT(PRELU_GRAD, x > 0.f ? 
y : (y * z)); #undef KERN_SIG diff --git a/dnn/src/common/elemwise/opr_impl.cpp b/dnn/src/common/elemwise/opr_impl.cpp index ed994866..2f0a1d5a 100644 --- a/dnn/src/common/elemwise/opr_impl.cpp +++ b/dnn/src/common/elemwise/opr_impl.cpp @@ -220,6 +220,28 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { CB_MODE(Mode::GELU); CB_MODE(Mode::GELU_GRAD); CB_MODE(Mode::COND_LT_MOV); + CB_MODE(Mode::SINH); + CB_MODE(Mode::COSH); + CB_MODE(Mode::ASINH); + CB_MODE(Mode::ACOSH); + CB_MODE(Mode::ATANH); + CB_MODE(Mode::TAN); + CB_MODE(Mode::ASINH_GRAD); + CB_MODE(Mode::ACOSH_GRAD); + CB_MODE(Mode::ATANH_GRAD); + CB_MODE(Mode::PRELU); + CB_MODE(Mode::PRELU_GRAD); + CB_MODE(Mode::CLIP); + CB_MODE(Mode::SOFTPLUS); + CB_MODE(Mode::SOFTPLUS_GRAD); + CB_MODE(Mode::RELU6); + CB_MODE(Mode::RELU6_GRAD); + CB_MODE(Mode::HSIGMOID); + CB_MODE(Mode::HSIGMOID_GRAD); + CB_MODE(Mode::LOGSIGMOID); + CB_MODE(Mode::SQRT); + CB_MODE(Mode::SQUARE); + CB_MODE(Mode::SIGN); default: megdnn_assert( 0, diff --git a/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cu new file mode 100644 index 00000000..fc3af7a7 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_float16.cu new file mode 100644 index 00000000..1c4f89c8 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff 
--git a/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_float32.cu new file mode 100644 index 00000000..7674459b --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_bfloat16.cu new file mode 100644 index 00000000..80411b71 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_float16.cu new file mode 100644 index 00000000..aa417709 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_float32.cu new file mode 100644 index 00000000..cfcf7ad3 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cu new 
file mode 100644 index 00000000..393728a4 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_float16.cu new file mode 100644 index 00000000..807bedca --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_float32.cu new file mode 100644 index 00000000..7a3d5a3e --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ASINH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ASINH_dt_bfloat16.cu new file mode 100644 index 00000000..7a912381 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ASINH_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ASINH_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ASINH_dt_float16.cu new file mode 100644 index 00000000..29c9307d --- /dev/null 
+++ b/dnn/src/cuda/elemwise/kimpl/ASINH_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ASINH_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ASINH_dt_float32.cu new file mode 100644 index 00000000..b6fc8ab6 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ASINH_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cu new file mode 100644 index 00000000..cf7afe1b --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_float16.cu new file mode 100644 index 00000000..333b8f60 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_float32.cu new file mode 100644 index 00000000..be794c32 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_float32.cu 
@@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ATANH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ATANH_dt_bfloat16.cu new file mode 100644 index 00000000..fa6683df --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ATANH_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ATANH_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ATANH_dt_float16.cu new file mode 100644 index 00000000..804e5361 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ATANH_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ATANH_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ATANH_dt_float32.cu new file mode 100644 index 00000000..5fd222b3 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ATANH_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/CLIP_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_bfloat16.cu new file mode 100644 index 00000000..de73cd15 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) 
MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/CLIP_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_float16.cu new file mode 100644 index 00000000..25e351c1 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/CLIP_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_float32.cu new file mode 100644 index 00000000..cf8dc776 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int16.cu new file mode 100644 index 00000000..f60b5c4c --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int16.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int32.cu new file mode 100644 index 00000000..c003f595 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int8.cu 
b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int8.cu new file mode 100644 index 00000000..cb0ec046 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/CLIP_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_uint8.cu new file mode 100644 index 00000000..b0198d93 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_uint8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/COSH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/COSH_dt_bfloat16.cu new file mode 100644 index 00000000..7cb17527 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/COSH_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/COSH_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/COSH_dt_float16.cu new file mode 100644 index 00000000..5f42f235 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/COSH_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/COSH_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/COSH_dt_float32.cu new file mode 100644 index 00000000..94ea1870 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/COSH_dt_float32.cu @@ -0,0 +1,5 @@ 
+// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cu new file mode 100644 index 00000000..a8115bff --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cu new file mode 100644 index 00000000..a1fb7ee3 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cu new file mode 100644 index 00000000..9c0a4aeb --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_bfloat16.cu new file mode 100644 index 00000000..28a83976 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by 
gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_float16.cu new file mode 100644 index 00000000..cdb77455 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_float32.cu new file mode 100644 index 00000000..528f944c --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cu new file mode 100644 index 00000000..06322df6 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_float16.cu new file mode 100644 index 00000000..d0b6c026 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if 
!MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_float32.cu new file mode 100644 index 00000000..ea1bcf1a --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cu new file mode 100644 index 00000000..dc331500 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_float16.cu new file mode 100644 index 00000000..34411818 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_float32.cu new file mode 100644 index 00000000..1fedc850 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) 
MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_bfloat16.cu new file mode 100644 index 00000000..78c18f9d --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_float16.cu new file mode 100644 index 00000000..33e6ce73 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_float32.cu new file mode 100644 index 00000000..46f2d367 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int16.cu new file mode 100644 index 00000000..d1dfa9ac --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int16.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git 
a/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int32.cu new file mode 100644 index 00000000..d6d7332f --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int8.cu new file mode 100644 index 00000000..621a7dd3 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_uint8.cu new file mode 100644 index 00000000..86ff475d --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_uint8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cu new file mode 100644 index 00000000..90699665 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_float16.cu new file mode 100644 index 00000000..efb61fa6 --- /dev/null +++ 
b/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_float32.cu new file mode 100644 index 00000000..6088f41d --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_bfloat16.cu new file mode 100644 index 00000000..cf79b7e9 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_float16.cu new file mode 100644 index 00000000..0646045d --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_float32.cu new file mode 100644 index 00000000..2fe7746f --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_float32.cu @@ -0,0 +1,5 @@ +// 
generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int16.cu new file mode 100644 index 00000000..32c2dab3 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int16.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int32.cu new file mode 100644 index 00000000..e59877c3 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int8.cu new file mode 100644 index 00000000..6f6f7741 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_uint8.cu new file mode 100644 index 00000000..60812b55 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_uint8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIGN_dt_bfloat16.cu 
b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_bfloat16.cu new file mode 100644 index 00000000..34316156 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SIGN_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_float16.cu new file mode 100644 index 00000000..04ac0b86 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SIGN_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_float32.cu new file mode 100644 index 00000000..0402184f --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int16.cu new file mode 100644 index 00000000..0a854c23 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int16.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int32.cu new file mode 100644 index 00000000..5f3aa927 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int32.cu @@ -0,0 +1,5 
@@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int8.cu new file mode 100644 index 00000000..c0d44608 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIGN_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_uint8.cu new file mode 100644 index 00000000..37f4b4b2 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_uint8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SINH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SINH_dt_bfloat16.cu new file mode 100644 index 00000000..19b3b24d --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SINH_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SINH_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SINH_dt_float16.cu new file mode 100644 index 00000000..0298140e --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SINH_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include 
"../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SINH_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SINH_dt_float32.cu new file mode 100644 index 00000000..d781a287 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SINH_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cu new file mode 100644 index 00000000..9769ef87 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cu new file mode 100644 index 00000000..694fea1b --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cu new file mode 100644 index 00000000..05710880 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git 
a/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cu new file mode 100644 index 00000000..7df279c7 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_float16.cu new file mode 100644 index 00000000..98d84dad --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_float32.cu new file mode 100644 index 00000000..898996df --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SQRT_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SQRT_dt_bfloat16.cu new file mode 100644 index 00000000..b6483dfe --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQRT_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SQRT_dt_float16.cu 
b/dnn/src/cuda/elemwise/kimpl/SQRT_dt_float16.cu new file mode 100644 index 00000000..262e68d4 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQRT_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SQRT_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SQRT_dt_float32.cu new file mode 100644 index 00000000..1c6aa2af --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQRT_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_bfloat16.cu new file mode 100644 index 00000000..59f5383e --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_float16.cu new file mode 100644 index 00000000..c53551bb --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_float32.cu new file mode 100644 index 00000000..4282b479 --- 
/dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int16.cu new file mode 100644 index 00000000..5d7bec08 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int16.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int32.cu new file mode 100644 index 00000000..c1237c37 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int8.cu new file mode 100644 index 00000000..8c01483e --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_uint8.cu new file mode 100644 index 00000000..094c4191 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_uint8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define 
KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/TAN_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/TAN_dt_bfloat16.cu new file mode 100644 index 00000000..415ed3fc --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/TAN_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/TAN_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/TAN_dt_float16.cu new file mode 100644 index 00000000..dd5339a3 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/TAN_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/TAN_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/TAN_dt_float32.cu new file mode 100644 index 00000000..796c4e65 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/TAN_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ACOSH_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ACOSH_GRAD_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..74e69bb6 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ACOSH_GRAD_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include 
"../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ACOSH_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ACOSH_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..6734fd67 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ACOSH_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ASINH_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ASINH_GRAD_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..6af7f406 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ASINH_GRAD_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ASINH_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ASINH_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..9b080b73 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ASINH_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ATANH_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ATANH_GRAD_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..1d477ec4 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ATANH_GRAD_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by 
gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ATANH_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ATANH_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..934699b3 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ATANH_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_qint4_dt_qint4.cu new file mode 100644 index 00000000..accf56fa --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_qint4_dt_qint4.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 +#include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..78de5e8a --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_quint4_dt_quint4.cu 
b/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_quint4_dt_quint4.cu new file mode 100644 index 00000000..754de3e6 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_quint4_dt_quint4.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 +#include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/COSH_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/COSH_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..2bcc45ab --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/COSH_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/HSIGMOID_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/HSIGMOID_GRAD_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..a50f44c9 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/HSIGMOID_GRAD_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/HSIGMOID_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/HSIGMOID_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..5cac0b94 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/HSIGMOID_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) 
MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LOGSIGMOID_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LOGSIGMOID_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..bb2abf6a --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LOGSIGMOID_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/PRELU_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/PRELU_GRAD_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..1518830e --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/PRELU_GRAD_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/PRELU_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/PRELU_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..60fa38fa --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/PRELU_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU6_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU6_GRAD_dt_qint8_dt_qint8.cu new file mode 
100644 index 00000000..49ef8a94 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU6_GRAD_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU6_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU6_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..7cf2b512 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU6_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGN_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGN_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..279bb2bb --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGN_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SINH_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SINH_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..68f6f9a7 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SINH_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include 
"../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SOFTPLUS_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SOFTPLUS_GRAD_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..2c8aa8b6 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SOFTPLUS_GRAD_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SOFTPLUS_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SOFTPLUS_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..7c250442 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SOFTPLUS_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SQRT_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SQRT_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..48290ef9 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SQRT_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SQUARE_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SQUARE_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..a24e0457 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SQUARE_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by 
gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/TAN_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/TAN_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..ecf0bcb6 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/TAN_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp b/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp index 64a90258..2f45e524 100644 --- a/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp +++ b/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp @@ -267,7 +267,12 @@ IMPL_MODE_DISPATCHER(2, dt_qint4, dt_qint4); IMPL_MODE_DISPATCHER(2, dt_quint4, dt_quint4); #undef FOREACH -#define FOREACH MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT +// qint4/quint4 ternary dispatch: modes are listed explicitly; keep in sync with QINT4_MODES[3] in gen_elemwise_multi_type_utils.py +#define FOREACH(cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) IMPL_MODE_DISPATCHER(3, dt_qint4, dt_qint4); IMPL_MODE_DISPATCHER(3, dt_quint4, dt_quint4); #undef FOREACH diff --git a/dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp b/dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp index 4197a2da..f0cf6c13 100644 --- a/dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp +++ b/dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp @@ -228,6 +228,7 @@ INST(Mode::SHL); INST(Mode::SHR); INST(Mode::FUSE_ADD_RELU); INST(Mode::RMULH); +INST(Mode::PRELU); #undef INST #define INST(mode) \ @@ -258,6 +259,13 @@ INST(Mode::H_SWISH_GRAD); INST(Mode::FUSE_ADD_H_SWISH);
INST(Mode::SILU_GRAD); INST(Mode::GELU_GRAD); +INST(Mode::PRELU); +INST(Mode::ASINH_GRAD); +INST(Mode::ACOSH_GRAD); +INST(Mode::ATANH_GRAD); +INST(Mode::SOFTPLUS_GRAD); +INST(Mode::RELU6_GRAD); +INST(Mode::HSIGMOID_GRAD); #undef INST } // namespace fallback } // namespace megdnn diff --git a/dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp b/dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp index b36e7dae..bf7acf09 100644 --- a/dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp +++ b/dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp @@ -77,6 +77,9 @@ using Mode = param_enumv::Elemwise::Mode; INST(Mode::RELU); INST(Mode::ABS); INST(Mode::NEGATE); +INST(Mode::RELU6); +INST(Mode::SQUARE); +INST(Mode::SIGN); #undef INST #define INST(mode) \ @@ -105,6 +108,19 @@ INST(Mode::ERFCINV); INST(Mode::H_SWISH); INST(Mode::SILU); INST(Mode::GELU); +INST(Mode::SINH); +INST(Mode::COSH); +INST(Mode::ASINH); +INST(Mode::ACOSH); +INST(Mode::ATANH); +INST(Mode::TAN); +INST(Mode::SOFTPLUS); +INST(Mode::RELU6); +INST(Mode::HSIGMOID); +INST(Mode::LOGSIGMOID); +INST(Mode::SQRT); +INST(Mode::SQUARE); +INST(Mode::SIGN); #undef INST } // namespace fallback } // namespace megdnn diff --git a/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cpp new file mode 100644 index 00000000..fc3af7a7 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_float16.cpp new file mode 100644 index 00000000..1c4f89c8 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by 
gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_float32.cpp new file mode 100644 index 00000000..7674459b --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/ACOSH_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ACOSH_dt_bfloat16.cpp new file mode 100644 index 00000000..80411b71 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ACOSH_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ACOSH_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/ACOSH_dt_float16.cpp new file mode 100644 index 00000000..aa417709 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ACOSH_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ACOSH_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/ACOSH_dt_float32.cpp new file mode 100644 index 00000000..cfcf7ad3 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ACOSH_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) 
MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cpp new file mode 100644 index 00000000..393728a4 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_float16.cpp new file mode 100644 index 00000000..807bedca --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_float32.cpp new file mode 100644 index 00000000..7a3d5a3e --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/ASINH_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ASINH_dt_bfloat16.cpp new file mode 100644 index 00000000..7a912381 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ASINH_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) 
MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ASINH_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/ASINH_dt_float16.cpp new file mode 100644 index 00000000..29c9307d --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ASINH_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ASINH_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/ASINH_dt_float32.cpp new file mode 100644 index 00000000..b6fc8ab6 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ASINH_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cpp new file mode 100644 index 00000000..cf7afe1b --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_float16.cpp new file mode 100644 index 00000000..333b8f60 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) 
+#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_float32.cpp new file mode 100644 index 00000000..be794c32 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/ATANH_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ATANH_dt_bfloat16.cpp new file mode 100644 index 00000000..fa6683df --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ATANH_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ATANH_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/ATANH_dt_float16.cpp new file mode 100644 index 00000000..804e5361 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ATANH_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ATANH_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/ATANH_dt_float32.cpp new file mode 100644 index 00000000..5fd222b3 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ATANH_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" 
diff --git a/dnn/src/naive/elemwise/kimpl/CLIP_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/CLIP_dt_bfloat16.cpp new file mode 100644 index 00000000..de73cd15 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/CLIP_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/CLIP_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/CLIP_dt_float16.cpp new file mode 100644 index 00000000..25e351c1 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/CLIP_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/CLIP_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/CLIP_dt_float32.cpp new file mode 100644 index 00000000..cf8dc776 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/CLIP_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/CLIP_dt_int16.cpp b/dnn/src/naive/elemwise/kimpl/CLIP_dt_int16.cpp new file mode 100644 index 00000000..f60b5c4c --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/CLIP_dt_int16.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/CLIP_dt_int32.cpp b/dnn/src/naive/elemwise/kimpl/CLIP_dt_int32.cpp new file mode 100644 index 
00000000..c003f595 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/CLIP_dt_int32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/CLIP_dt_int8.cpp b/dnn/src/naive/elemwise/kimpl/CLIP_dt_int8.cpp new file mode 100644 index 00000000..cb0ec046 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/CLIP_dt_int8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/CLIP_dt_uint8.cpp b/dnn/src/naive/elemwise/kimpl/CLIP_dt_uint8.cpp new file mode 100644 index 00000000..b0198d93 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/CLIP_dt_uint8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/COSH_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/COSH_dt_bfloat16.cpp new file mode 100644 index 00000000..7cb17527 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/COSH_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/COSH_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/COSH_dt_float16.cpp new file mode 100644 index 00000000..5f42f235 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/COSH_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define 
KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/COSH_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/COSH_dt_float32.cpp new file mode 100644 index 00000000..94ea1870 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/COSH_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cpp new file mode 100644 index 00000000..a8115bff --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cpp new file mode 100644 index 00000000..a1fb7ee3 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cpp new file mode 100644 index 00000000..9c0a4aeb --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) 
MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_bfloat16.cpp new file mode 100644 index 00000000..28a83976 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_float16.cpp new file mode 100644 index 00000000..cdb77455 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_float32.cpp new file mode 100644 index 00000000..528f944c --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cpp new file mode 100644 index 00000000..06322df6 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) 
MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_float16.cpp new file mode 100644 index 00000000..d0b6c026 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_float32.cpp new file mode 100644 index 00000000..ea1bcf1a --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cpp new file mode 100644 index 00000000..dc331500 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_float16.cpp new file mode 100644 index 00000000..34411818 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) 
MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_float32.cpp new file mode 100644 index 00000000..1fedc850 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_dt_bfloat16.cpp new file mode 100644 index 00000000..78c18f9d --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_dt_float16.cpp new file mode 100644 index 00000000..33e6ce73 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_dt_float32.cpp new file mode 100644 index 00000000..46f2d367 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define 
KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_dt_int16.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_dt_int16.cpp new file mode 100644 index 00000000..d1dfa9ac --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_dt_int16.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_dt_int32.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_dt_int32.cpp new file mode 100644 index 00000000..d6d7332f --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_dt_int32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_dt_int8.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_dt_int8.cpp new file mode 100644 index 00000000..621a7dd3 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_dt_int8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_dt_uint8.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_dt_uint8.cpp new file mode 100644 index 00000000..86ff475d --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_dt_uint8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cpp new file mode 100644 index 00000000..90699665 --- 
/dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_float16.cpp new file mode 100644 index 00000000..efb61fa6 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_float32.cpp new file mode 100644 index 00000000..6088f41d --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_dt_bfloat16.cpp new file mode 100644 index 00000000..cf79b7e9 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_dt_float16.cpp new file mode 100644 index 00000000..0646045d --- /dev/null +++ 
b/dnn/src/naive/elemwise/kimpl/RELU6_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_dt_float32.cpp new file mode 100644 index 00000000..2fe7746f --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_dt_int16.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_dt_int16.cpp new file mode 100644 index 00000000..32c2dab3 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_dt_int16.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_dt_int32.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_dt_int32.cpp new file mode 100644 index 00000000..e59877c3 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_dt_int32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_dt_int8.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_dt_int8.cpp new file mode 100644 index 00000000..6f6f7741 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_dt_int8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define 
KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_dt_uint8.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_dt_uint8.cpp new file mode 100644 index 00000000..60812b55 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_dt_uint8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SIGN_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SIGN_dt_bfloat16.cpp new file mode 100644 index 00000000..34316156 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIGN_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SIGN_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/SIGN_dt_float16.cpp new file mode 100644 index 00000000..04ac0b86 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIGN_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SIGN_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/SIGN_dt_float32.cpp new file mode 100644 index 00000000..0402184f --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIGN_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SIGN_dt_int16.cpp 
b/dnn/src/naive/elemwise/kimpl/SIGN_dt_int16.cpp new file mode 100644 index 00000000..0a854c23 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIGN_dt_int16.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SIGN_dt_int32.cpp b/dnn/src/naive/elemwise/kimpl/SIGN_dt_int32.cpp new file mode 100644 index 00000000..5f3aa927 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIGN_dt_int32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SIGN_dt_int8.cpp b/dnn/src/naive/elemwise/kimpl/SIGN_dt_int8.cpp new file mode 100644 index 00000000..c0d44608 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIGN_dt_int8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SIGN_dt_uint8.cpp b/dnn/src/naive/elemwise/kimpl/SIGN_dt_uint8.cpp new file mode 100644 index 00000000..37f4b4b2 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIGN_dt_uint8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SINH_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SINH_dt_bfloat16.cpp new file mode 100644 index 00000000..19b3b24d --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SINH_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if 
!MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SINH_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/SINH_dt_float16.cpp new file mode 100644 index 00000000..0298140e --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SINH_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SINH_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/SINH_dt_float32.cpp new file mode 100644 index 00000000..d781a287 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SINH_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cpp new file mode 100644 index 00000000..9769ef87 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cpp new file mode 100644 index 00000000..694fea1b --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define 
KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cpp new file mode 100644 index 00000000..05710880 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cpp new file mode 100644 index 00000000..7df279c7 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_float16.cpp new file mode 100644 index 00000000..98d84dad --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_float32.cpp new file mode 100644 index 00000000..898996df --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) 
MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SQRT_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SQRT_dt_bfloat16.cpp new file mode 100644 index 00000000..b6483dfe --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQRT_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SQRT_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/SQRT_dt_float16.cpp new file mode 100644 index 00000000..262e68d4 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQRT_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SQRT_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/SQRT_dt_float32.cpp new file mode 100644 index 00000000..1c6aa2af --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQRT_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SQUARE_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_bfloat16.cpp new file mode 100644 index 00000000..59f5383e --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE 
dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SQUARE_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_float16.cpp new file mode 100644 index 00000000..c53551bb --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SQUARE_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_float32.cpp new file mode 100644 index 00000000..4282b479 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int16.cpp b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int16.cpp new file mode 100644 index 00000000..5d7bec08 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int16.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int32.cpp b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int32.cpp new file mode 100644 index 00000000..c1237c37 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int8.cpp b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int8.cpp 
new file mode 100644 index 00000000..8c01483e --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SQUARE_dt_uint8.cpp b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_uint8.cpp new file mode 100644 index 00000000..094c4191 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_uint8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/TAN_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/TAN_dt_bfloat16.cpp new file mode 100644 index 00000000..415ed3fc --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/TAN_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/TAN_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/TAN_dt_float16.cpp new file mode 100644 index 00000000..dd5339a3 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/TAN_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/TAN_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/TAN_dt_float32.cpp new file mode 100644 index 00000000..796c4e65 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/TAN_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by 
gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..a231d0a3 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..ca103ee3 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..11f768c9 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..28c3f173 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by 
gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_float16.cpp.hip new file mode 100644 index 00000000..06d2c12e --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_float32.cpp.hip new file mode 100644 index 00000000..9e969396 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..d0f61ca2 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..0840a54c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by 
gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..c239ddb2 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ASINH_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ASINH_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..ea0f11ec --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ASINH_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ASINH_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ASINH_dt_float16.cpp.hip new file mode 100644 index 00000000..cc5d7302 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ASINH_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ASINH_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ASINH_dt_float32.cpp.hip new file mode 100644 index 00000000..96076302 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ASINH_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by 
gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..3b0496f4 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..1a0c841a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..ce400c03 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ATANH_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ATANH_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..9ac3b63a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ATANH_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by 
gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ATANH_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ATANH_dt_float16.cpp.hip new file mode 100644 index 00000000..56902e43 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ATANH_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ATANH_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ATANH_dt_float32.cpp.hip new file mode 100644 index 00000000..048cab3a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ATANH_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/CLIP_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..91c018ae --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/CLIP_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_float16.cpp.hip new file mode 100644 index 00000000..1d06d8d1 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if 
!MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/CLIP_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_float32.cpp.hip new file mode 100644 index 00000000..346efcbd --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int16.cpp.hip new file mode 100644 index 00000000..4394848f --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int32.cpp.hip new file mode 100644 index 00000000..ed51ebd8 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int8.cpp.hip new file mode 100644 index 00000000..dea82a24 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git 
a/dnn/src/rocm/elemwise/kimpl/CLIP_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_uint8.cpp.hip new file mode 100644 index 00000000..0d798b15 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/COSH_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COSH_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..25c504a3 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COSH_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/COSH_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COSH_dt_float16.cpp.hip new file mode 100644 index 00000000..a5e92a13 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COSH_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/COSH_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COSH_dt_float32.cpp.hip new file mode 100644 index 00000000..2d4c2784 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COSH_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cpp.hip 
b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..54b03d90 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..eaa54cca --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..a000bb10 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..fdb642b6 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git 
a/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_float16.cpp.hip new file mode 100644 index 00000000..94e88cd6 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_float32.cpp.hip new file mode 100644 index 00000000..8e13dd53 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..fa85681c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_float16.cpp.hip new file mode 100644 index 00000000..4992d81d --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" 
+#endif diff --git a/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_float32.cpp.hip new file mode 100644 index 00000000..72e6e3a9 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..d24ceb4a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..4665a277 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..023f6fe1 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include 
"../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..2ae7e683 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_float16.cpp.hip new file mode 100644 index 00000000..1e1253ca --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_float32.cpp.hip new file mode 100644 index 00000000..9d3b9676 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int16.cpp.hip new file mode 100644 index 00000000..03846baf --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int32.cpp.hip 
b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int32.cpp.hip new file mode 100644 index 00000000..41f41670 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int8.cpp.hip new file mode 100644 index 00000000..88c2bbb6 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_uint8.cpp.hip new file mode 100644 index 00000000..a9febead --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..f2099c6b --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..b46f1c52 --- /dev/null +++ 
b/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..37b7d1ba --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..bb39a1a8 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_float16.cpp.hip new file mode 100644 index 00000000..f84eac71 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_float32.cpp.hip new file mode 100644 index 00000000..fd3fe60f --- /dev/null +++ 
b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int16.cpp.hip new file mode 100644 index 00000000..66a41225 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int32.cpp.hip new file mode 100644 index 00000000..1a5eed82 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int8.cpp.hip new file mode 100644 index 00000000..f2ecc40a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_uint8.cpp.hip new file mode 100644 index 00000000..b13aad0d --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define 
KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SIGN_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..ee5373bf --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SIGN_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_float16.cpp.hip new file mode 100644 index 00000000..69277cd5 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SIGN_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_float32.cpp.hip new file mode 100644 index 00000000..718709a1 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int16.cpp.hip new file mode 100644 index 00000000..d1cdfbd5 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git 
a/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int32.cpp.hip new file mode 100644 index 00000000..2955aef0 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int8.cpp.hip new file mode 100644 index 00000000..e6cc0849 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SIGN_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_uint8.cpp.hip new file mode 100644 index 00000000..b39727a4 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SINH_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SINH_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..7b0d4740 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SINH_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SINH_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SINH_dt_float16.cpp.hip new file mode 100644 index 00000000..070496fb --- 
/dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SINH_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SINH_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SINH_dt_float32.cpp.hip new file mode 100644 index 00000000..d295a78f --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SINH_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..2de73538 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..da7dabfd --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cpp.hip new file mode 100644 
index 00000000..f123cd7e --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..f86b264d --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_float16.cpp.hip new file mode 100644 index 00000000..991ba230 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_float32.cpp.hip new file mode 100644 index 00000000..55716f22 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SQRT_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQRT_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..9c4b9211 --- 
/dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQRT_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SQRT_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQRT_dt_float16.cpp.hip new file mode 100644 index 00000000..ad05f946 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQRT_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SQRT_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQRT_dt_float32.cpp.hip new file mode 100644 index 00000000..448f9a0c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQRT_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..aa6d5912 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_float16.cpp.hip new file mode 100644 index 00000000..1db3690d --- /dev/null +++ 
b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_float32.cpp.hip new file mode 100644 index 00000000..da4072ae --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int16.cpp.hip new file mode 100644 index 00000000..487b9824 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int32.cpp.hip new file mode 100644 index 00000000..ae617e29 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int8.cpp.hip new file mode 100644 index 00000000..9ebe2f85 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define 
KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_uint8.cpp.hip new file mode 100644 index 00000000..b01411f1 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/TAN_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/TAN_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..9a274f50 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/TAN_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/TAN_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/TAN_dt_float16.cpp.hip new file mode 100644 index 00000000..522e02f8 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/TAN_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/TAN_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/TAN_dt_float32.cpp.hip new file mode 100644 index 00000000..42a0c3b1 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/TAN_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE 
dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/test/common/elemwise.cpp b/dnn/test/common/elemwise.cpp index 31f9bcf0..2ef46962 100644 --- a/dnn/test/common/elemwise.cpp +++ b/dnn/test/common/elemwise.cpp @@ -744,8 +744,8 @@ DEF_TEST(all_modes) { TensorShapeArray shapes; UniformFloatRNG default_rng_f32{-100.f, 100.f}, pos_rng_f32{.1f, 1000.f}, small_pos_rng_f32{.1f, .10f}, small_rng_f32{-3.f, 3.f}, - abslt1_rng_f32{-1.f, 1.f}, uniform_0_2_rng{0.f, 2.f}, - tanh_rng_f32{-5.f, 5.f}; + abslt1_rng_f32{-0.95f, 0.95f}, uniform_0_2_rng{0.f, 2.f}, + tanh_rng_f32{-5.f, 5.f}, lt1_rng_f32{1.f, 10.f}; UniformFloatNonZeroRNG nonzero_rng_f32{.1f, 1000.f}, big_nonzero_rng_f32{100.f, 1000.f}; UniformIntRNG default_rng_i32{-100, 100}, small_rng_i32{-2, 2}, @@ -786,12 +786,14 @@ DEF_TEST(all_modes) { shapes[shapes.size() - 1] = {}; auto do_run = [&](DType dtype, float eps = 1e-3) { // limit value ranges for some modes - if (mode == Mode::LOG || mode == Mode::LOG1P) { + if (mode == Mode::LOG || mode == Mode::LOG1P || mode == Mode::SQRT) { checker.set_rng(0, &pos_rng_f32); - } else if (mode == Mode::POW) { + } else if (mode == Mode::POW || mode == Mode::SOFTPLUS_GRAD) { checker.set_rng(0, &small_pos_rng_f32); checker.set_rng(1, &small_rng_f32); - } else if (mode == Mode::EXP || mode == Mode::EXPM1) { + } else if ( + mode == Mode::EXP || mode == Mode::EXPM1 || mode == Mode::SINH || + mode == Mode::COSH) { checker.set_rng(0, &small_rng_f32); } else if (mode == Mode::FAST_TANH) { checker.set_rng(0, &tanh_rng_f32); @@ -807,6 +809,10 @@ DEF_TEST(all_modes) { checker.set_rng(1, &default_rng_f32); } else if (mode == Mode::ERFCINV) { checker.set_rng(0, &uniform_0_2_rng); + } else if (mode == Mode::ACOSH_GRAD || mode == Mode::ACOSH) { + checker.set_rng(0, <1_rng_f32); + } else if (mode == Mode::ATANH_GRAD || mode == Mode::ATANH) { + checker.set_rng(0, &abslt1_rng_f32); } else if ( mode == Mode::MOD || mode == Mode::TRUE_DIV || mode == Mode::FLOOR_DIV) { diff --git 
a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index e92cb2d9..b4568c6e 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -467,12 +467,12 @@ def log1p(x): def sqrt(x: Tensor) -> Tensor: r"""Element-wise `sqrt`.""" - return x ** 0.5 + return _elwise(x, mode=Elemwise.Mode.SQRT) def square(x: Tensor) -> Tensor: r"""Element-wise `square`.""" - return x ** 2 + return _elwise(x, mode=Elemwise.Mode.SQUARE) def round(x): @@ -515,7 +515,7 @@ def sin(x): def tan(x): r"""Element-wise `tangent`.""" - return sin(x) / cos(x) + return _elwise(x, mode=Elemwise.Mode.TAN) def acos(x): @@ -544,13 +544,12 @@ def atan2(y, x): def cosh(x): r"""Element-wise `hyperbolic cosine`.""" - return 0.5 * (exp(x) + exp(-x)) + return _elwise(x, mode=Elemwise.Mode.COSH) def sinh(x): r"""Element-wise `hyperbolic sine`.""" - u = expm1(x) - return 0.5 * u / (u + 1) * (u + 2) + return _elwise(x, mode=Elemwise.Mode.SINH) def tanh(x): @@ -560,17 +559,17 @@ def tanh(x): def asinh(x): r"""Element-wise `inverse hyperbolic sine`.""" - return log(x + (x ** 2 + 1) ** 0.5) + return _elwise(x, mode=Elemwise.Mode.ASINH) def acosh(x): r"""Element-wise `inverse hyperbolic cosine`.""" - return log(x + (x ** 2 - 1) ** 0.5) + return _elwise(x, mode=Elemwise.Mode.ACOSH) def atanh(x): r"""Element-wise `inverse hyperbolic tangent`.""" - return log1p(2 * x / (1 - x)) / 2 + return _elwise(x, mode=Elemwise.Mode.ATANH) # bit-twiddling functions @@ -680,7 +679,7 @@ def clip(x: Tensor, lower=None, upper=None) -> Tensor: ), "At least one of 'lower' or 'upper' must not be None" if lower is not None: if upper is not None: - return minimum(maximum(x, lower), upper) + return _elwise(x, lower, upper, mode=Elemwise.Mode.CLIP) else: return maximum(x, lower) else: diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index 978221ac..42fa014e 100644 --- 
a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -6,7 +6,7 @@ from typing import Iterable, Optional, Sequence, Tuple, Union from ..core._imperative_rt.core2 import Const, apply from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder from ..core.ops import builtin -from ..core.tensor.array_method import _matmul +from ..core.tensor.array_method import _elwise, _matmul from ..core.tensor.utils import _normalize_axis from ..tensor import Tensor from ..utils.deprecation import deprecated_kwargs_default @@ -86,7 +86,7 @@ def sign(inp: Tensor): >>> F.sign(x) Tensor([ 1 -1 0], dtype=int32, device=xpux:0) """ - return (inp > 0).astype(inp.dtype) - (inp < 0).astype(inp.dtype) + return _elwise(inp, mode=builtin.Elemwise.Mode.SIGN) def sum( diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 58a60b4d..8a42e906 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -753,37 +753,9 @@ def sigmoid(x): return _elwise(x, mode=Elemwise.Mode.SIGMOID) -@lru_cache(maxsize=None) -def _get_hsigmoid_op(dtype=None, device=None): - @subgraph_fn( - "Hsigmoid", - dtype=dtype, - device=device, - nr_inputs=1, - jit_fusion=True, - custom_grad=True, - ) - def hsigmoid(inputs, f, c): - (inp,) = inputs[0:1] - inp = f("+", inp, c(3)) - max_0 = f("max", inp, c(0)) - min_6 = f("min", max_0, c(6)) - oup = f("/", min_6, c(6)) - (oup_grad,) = yield (oup,) - inp_grad = f("/", oup_grad, c(6)) - inp_grad = f("cond_leq_mov", max_0, c(6), inp_grad) - inp_grad = f("cond_leq_mov", c(0), inp, inp_grad) - yield (inp_grad,) - - return hsigmoid - - def hsigmoid(x): r"""Element-wise `relu6(x + 3) / 6`.""" - hsigmoid = _get_hsigmoid_op(x.dtype, x.device) - (x,) = hsigmoid(x) - return x - # return relu6(x + 3) / 6 + return _elwise(x, mode=Elemwise.Mode.HSIGMOID) def relu(x): @@ -791,95 +763,14 @@ def relu(x): return _elwise(x, 
mode=Elemwise.Mode.RELU) -@lru_cache(maxsize=None) -def _get_relu6_op(dtype=None, device=None): - @subgraph_fn( - "ReLU6", - dtype=dtype, - device=device, - nr_inputs=1, - jit_fusion=True, - custom_grad=True, - ) - def relu6(inputs, f, c): - (inp,) = inputs[0:1] - max_0 = f("max", inp, c(0)) - min_6 = f("min", max_0, c(6)) - oup = min_6 - (oup_grad,) = yield (oup,) - inp_grad = f("cond_leq_mov", max_0, c(6), oup_grad) - inp_grad = f("cond_leq_mov", c(0), inp, inp_grad) - yield (inp_grad,) - - return relu6 - - def relu6(x): r"""Element-wise `min(max(x, 0), 6)`.""" - relu6 = _get_relu6_op(x.dtype, x.device) - (x,) = relu6(x) - return x - - -@lru_cache(maxsize=None) -def _get_prelu_op(dtype=None, device=None): - @subgraph_fn( - "PReLU", - dtype=dtype, - device=device, - nr_inputs=2, - jit_fusion=True, - custom_grad=True, - ) - def prelu(inputs, f, c): - (inp, weight) = inputs[0:2] - max_0 = f("max", inp, c(0)) - min_0 = f("min", inp, c(0)) - oup = f("fma3", min_0, weight, max_0) - (oup_grad,) = yield (oup,) - inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad) - inp_grad_1 = f("*", oup_grad, weight) - inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1) - inp_grad = f("+", inp_grad_0, inp_grad_1) - weight_grad = f("*", oup_grad, min_0) - yield (inp_grad, weight_grad) - - return prelu - - -def prelu(inp: Tensor, weight: Tensor) -> Tensor: - r"""Element-wise PReLU function. - - Refer to :class:`~.PReLU` for more information. 
- """ - prelu = _get_prelu_op(dtype=inp.dtype, device=inp.device) - (oup,) = prelu(inp, broadcast_to(weight, inp.shape)) - return oup + return _elwise(x, mode=Elemwise.Mode.RELU6) -@lru_cache(maxsize=None) -def _get_leaky_relu_op(negative_slope, *, dtype=None, device=None): - @subgraph_fn( - "LeakyReLU", - dtype=dtype, - device=device, - nr_inputs=1, - jit_fusion=True, - custom_grad=True, - ) - def leakyReLU(inputs, f, c): - (inp,) = inputs[0:1] - max_0 = f("max", inp, c(0)) - min_0 = f("min", inp, c(0)) - oup = f("+", max_0, f("*", min_0, c(negative_slope))) - (oup_grad,) = yield (oup,) - inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad) - inp_grad_1 = f("*", oup_grad, c(negative_slope)) - inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1) - inp_grad = f("+", inp_grad_0, inp_grad_1) - yield (inp_grad,) - - return leakyReLU +def prelu(x, y): + r"""Element-wise `max(x, 0) + y * min(x, 0)`.""" + return _elwise(x, y, mode=Elemwise.Mode.PRELU) def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor: @@ -887,9 +778,7 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor: Refer to :class:`~.LeakyReLU` for more information. 
""" - leakyReLU = _get_leaky_relu_op(negative_slope, dtype=inp.dtype, device=inp.device) - (oup,) = leakyReLU(inp) - return oup + return _elwise(inp, negative_slope, mode=Elemwise.Mode.PRELU) def silu(x): @@ -908,36 +797,6 @@ def gelu(x): return _elwise(x, mode=Elemwise.Mode.GELU) -@lru_cache(maxsize=None) -def _get_softplus_op(dtype=None, device=None): - @subgraph_fn( - "Softplus", - dtype=dtype, - device=device, - nr_inputs=1, - jit_fusion=True, - custom_grad=True, - ) - def softplus(inputs, f, c): - (inp,) = inputs[0:1] - neg_abs = f("-", f("abs", inp)) - exp = f("exp", neg_abs) - oup0 = f("log1p", exp) - oup1 = f("relu", inp) - oup = f("+", oup0, oup1) - (oup_grad,) = yield (oup,) - inp_grad_0 = f("switch_gt0", oup1, oup_grad) - inp_grad_1 = oup_grad - inp_grad_1 = f("/", oup_grad, f("+", exp, c(1))) - inp_grad_1 = f("*", inp_grad_1, exp) - inp_grad_1 = f("-", inp_grad_1) - inp_grad_1 = f("abs_grad", inp, inp_grad_1) - inp_grad = f("+", inp_grad_0, inp_grad_1) - yield (inp_grad,) - - return softplus - - def softplus(inp: Tensor) -> Tensor: r"""Applies the element-wise function: @@ -960,9 +819,7 @@ def softplus(inp: Tensor) -> Tensor: >>> y.numpy().round(decimals=4) array([0.0486, 0.1269, 0.3133, 0.6931, 1.3133, 2.1269], dtype=float32) """ - softplus = _get_softplus_op(inp.dtype, inp.device) - (oup,) = softplus(inp) - return oup + return _elwise(inp, mode=Elemwise.Mode.SOFTPLUS) def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: @@ -991,39 +848,6 @@ def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: return inp - logsumexp(inp, axis, keepdims=True) -@lru_cache(maxsize=None) -def _get_logsigmoid_op(dtype=None, device=None): - @subgraph_fn( - "LogSigmoid", - dtype=dtype, - device=device, - nr_inputs=1, - jit_fusion=True, - custom_grad=True, - ) - def logsigmoid(inputs, f, c): - (inp,) = inputs[0:1] - neg_abs = f("-", f("abs", inp)) - exp = f("exp", neg_abs) - oup0 = f("log1p", exp) - oup1 = f("relu", f("-", inp)) - oup = 
f("+", oup0, oup1) - oup = f("-", oup) - (oup_grad,) = yield (oup,) - oup_grad = f("-", oup_grad) - inp_grad_0 = f("switch_gt0", oup1, oup_grad) - inp_grad_0 = f("-", inp_grad_0) - inp_grad_1 = oup_grad - inp_grad_1 = f("/", inp_grad_1, f("+", exp, c(1))) - inp_grad_1 = f("*", inp_grad_1, exp) - inp_grad_1 = f("-", inp_grad_1) - inp_grad_1 = f("abs_grad", inp, inp_grad_1) - inp_grad = f("+", inp_grad_0, inp_grad_1) - yield (inp_grad,) - - return logsigmoid - - def logsigmoid(inp: Tensor) -> Tensor: r"""Applies the element-wise function: @@ -1041,9 +865,7 @@ def logsigmoid(inp: Tensor) -> Tensor: array([-5.0067, -4.0182, -3.0486, -2.1269, -1.3133, -0.6931, -0.3133, -0.1269, -0.0486, -0.0181], dtype=float32) """ - logsigmoid = _get_logsigmoid_op(inp.dtype, inp.device) - (oup,) = logsigmoid(inp) - return oup + return _elwise(inp, mode=Elemwise.Mode.LOGSIGMOID) def logsumexp( diff --git a/imperative/src/impl/transformations/dtype_promote.cpp b/imperative/src/impl/transformations/dtype_promote.cpp index 58de880a..5e7743d7 100644 --- a/imperative/src/impl/transformations/dtype_promote.cpp +++ b/imperative/src/impl/transformations/dtype_promote.cpp @@ -116,12 +116,17 @@ ValueRefList elemwise_rule(const OpDef& op, Span inputs) { } static std::unordered_set cast_case1 = { - Elemwise::Mode::TRUE_DIV, Elemwise::Mode::EXP, - Elemwise::Mode::POW, Elemwise::Mode::LOG, - Elemwise::Mode::EXPM1, Elemwise::Mode::LOG1P, - Elemwise::Mode::ACOS, Elemwise::Mode::ASIN, - Elemwise::Mode::ATAN2, Elemwise::Mode::COS, - Elemwise::Mode::SIN, Elemwise::Mode::LOG_SUM_EXP, + Elemwise::Mode::TRUE_DIV, Elemwise::Mode::EXP, + Elemwise::Mode::POW, Elemwise::Mode::LOG, + Elemwise::Mode::EXPM1, Elemwise::Mode::LOG1P, + Elemwise::Mode::ACOS, Elemwise::Mode::ASIN, + Elemwise::Mode::ATAN2, Elemwise::Mode::COS, + Elemwise::Mode::SIN, Elemwise::Mode::LOG_SUM_EXP, + Elemwise::Mode::TAN, Elemwise::Mode::ASINH, + Elemwise::Mode::ACOSH, Elemwise::Mode::ATANH, + Elemwise::Mode::SINH, Elemwise::Mode::COSH, + 
Elemwise::Mode::SOFTPLUS, Elemwise::Mode::HSIGMOID, + Elemwise::Mode::LOGSIGMOID, Elemwise::Mode::SQRT, }; static std::unordered_set cast_case2 = { diff --git a/src/jit/impl/ast_c.cpp b/src/jit/impl/ast_c.cpp index edce0007..eb1dc794 100644 --- a/src/jit/impl/ast_c.cpp +++ b/src/jit/impl/ast_c.cpp @@ -133,7 +133,7 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() { 0.f}) / 6.f), }; - mgb_assert(map.size() + 19 == opr::Elemwise::Param::MODE_NR_MEMBER); + mgb_assert(map.size() + 41 == opr::Elemwise::Param::MODE_NR_MEMBER); // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH, // ERFINV, ERFCINV, NOT, AND, OR, XOR, NEQ, ISNAN, ISINF return map; diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index 8111b4e0..9924f747 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -543,6 +543,34 @@ MGB_IMPL_OPR_GRAD(Elemwise) { RET(EL2(SILU_GRAD, i0, og)); case Mode::GELU: RET(EL2(GELU_GRAD, i0, og)); + case Mode::SINH: + RET(EL1(COSH, i0) * og); + case Mode::COSH: + RET(EL1(SINH, i0) * og); + case Mode::ASINH: + RET(EL2(ASINH_GRAD, i0, og)); + case Mode::ACOSH: + RET(EL2(ACOSH_GRAD, i0, og)); + case Mode::ATANH: + RET(EL2(ATANH_GRAD, i0, og)); + case Mode::TAN: { + auto two = i0.make_scalar_dt(2); + RET(og / (EL2(POW, EL1(COS, i0), two))); + } + case Mode::RELU6: + RET(EL2(RELU6_GRAD, i0, og)); + case Mode::SOFTPLUS: + RET(EL2(SOFTPLUS_GRAD, i0, og)); + case Mode::HSIGMOID: + RET(EL2(HSIGMOID_GRAD, i0, og)); + case Mode::LOGSIGMOID: + RET(EL2(SOFTPLUS_GRAD, EL1(NEGATE, i0), og)); + case Mode::SQRT: + RET(og / EL1(SQRT, i0) / 2); + case Mode::SQUARE: + RET(og * 2 * i0); + case Mode::SIGN: + RET(i0.make_scalar_dt(0).broadcast(i0.symshape())); // binary case Mode::ABS_GRAD: @@ -617,6 +645,11 @@ MGB_IMPL_OPR_GRAD(Elemwise) { case Mode::XOR: case Mode::AND: return nullptr; + case Mode::PRELU: + if (wrt_idx == 0) { + RET(EL3(PRELU_GRAD, i0, og, i1)); + } + RET(EL2(SWITCH_GT0, -i0, og * i0)); // ternary case 
Mode::COND_LEQ_MOV: @@ -627,6 +660,15 @@ MGB_IMPL_OPR_GRAD(Elemwise) { if (wrt_idx <= 1) return nullptr; RET(EL3(COND_LT_MOV, i0, i1, og)); + case Mode::CLIP: + if (wrt_idx == 0) { + RET(EL3(COND_LEQ_MOV, i1, i0, EL3(COND_LEQ_MOV, i0, i2, og))); + } + if (wrt_idx == 1) { + RET(EL3(COND_LEQ_MOV, i0, i1, og)); + } + RET(EL3(COND_LEQ_MOV, i2, i0, og)); + // fuse oprs case Mode::FUSE_MUL_ADD3: if (wrt_idx < 2) { diff --git a/src/opr/test/basic_arith/elemwise.cpp b/src/opr/test/basic_arith/elemwise.cpp index b4bab90d..89ae9ca3 100644 --- a/src/opr/test/basic_arith/elemwise.cpp +++ b/src/opr/test/basic_arith/elemwise.cpp @@ -349,6 +349,99 @@ struct CheckerConfig : public CheckerConfig {}; template <> struct CheckerConfig : public NoGradCheckerConfig {}; +template <> +struct CheckerConfig : public NoGradCheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(-1.2, 1.2); + } +}; +template <> +struct CheckerConfig : public CheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(-5, 5); + } + template + static void update_opt(Opt& opt) { + opt.numdiff_eps = 1e-2; + opt.numdiff_max_err = 0.1; + } +}; +template <> +struct CheckerConfig : public CheckerConfig {}; +template <> +struct CheckerConfig : public CheckerConfig { + template + static void update_opt(Opt& opt) { + opt.numdiff_eps = 1e-2; + opt.numdiff_max_err = 0.1; + } +}; +template <> +struct CheckerConfig : public CheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(1.05, 5); + } +}; +template <> +struct CheckerConfig : public CheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(-0.95, 0.95); + } +}; +template <> +struct CheckerConfig : public CheckerConfig {}; +template <> +struct CheckerConfig : public CheckerConfig {}; +template <> +struct CheckerConfig : public CheckerConfig {}; +template <> +struct CheckerConfig : 
public CheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(0.05, 5); + } + template + static void update_opt(Opt& opt) { + opt.numdiff_eps = 1e-2; + opt.numdiff_max_err = 0.1; + } +}; +template <> +struct CheckerConfig : public CheckerConfig { + template + static void do_update_checker(Checker& checker) { + auto icoord = [](const typename Checker::NumInpArray& inp) { + auto p0 = inp[0]->template ptr(); + for (size_t i = 0, it = inp[0]->shape().total_nr_elems(); i < it; ++i) { + if (std::abs(p0[i]) < 1) { + p0[i] += 2; + } else if (std::abs(p0[i] - 6) < 1) { + p0[i] += 2; + } + } + }; + checker.set_input_coordinator(icoord); + } + template + static void update_checker(Checker& checker) { + using ctype = typename Checker::ctype; + return do_update_checker(checker); + } +}; +template <> +struct CheckerConfig : public CheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(-2.95, 2.95); + } +}; +template <> +struct CheckerConfig : public NoZeroCheckerConfig<0> {}; + /* ======================= binary config ======================= */ template struct BinaryInputMinGap : public CheckerConfig { @@ -567,13 +660,85 @@ template <> struct CheckerConfig : public NoGradCheckerConfig {}; template <> struct CheckerConfig : public NoGradCheckerConfig {}; +template <> +struct CheckerConfig : public NoZeroCheckerConfig<0> {}; +template <> +struct CheckerConfig : public NoGradCheckerConfig {}; +template <> +struct CheckerConfig : public NoGradCheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(1.05, 5); + } +}; +template <> +struct CheckerConfig : public NoGradCheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(-0.95, 0.95); + } +}; +template <> +struct CheckerConfig : public NoGradCheckerConfig {}; +template <> +struct CheckerConfig : public NoGradCheckerConfig {}; +template <> 
+struct CheckerConfig : public NoGradCheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(-2.95, 2.95); + } +}; /* ======================= ternary config ======================= */ template <> struct CheckerConfig : public BinaryInputMinGap {}; template <> struct CheckerConfig : public BinaryInputMinGap {}; +struct CheckerConfig : public NoGradCheckerConfig {}; +template <> +struct CheckerConfig : public CheckerConfig { + template + static void do_update_checker(Checker& checker) { + auto icoord = [](const typename Checker::NumInpArray& inp) { + auto p0 = inp[0]->template ptr(), p1 = inp[1]->template ptr(), + p2 = inp[2]->template ptr(); + for (size_t i = 0, it = inp[0]->shape().total_nr_elems(); i < it; ++i) { + if (p1[i] > p2[i]) { + std::swap(p1[i], p2[i]); + } + if (p1[i] + 1 > p2[i]) { + p2[i] = p1[i] + 1; + } + if (std::abs(p1[i] - p0[i]) < 1) { + if (p1[i] < p0[i]) + p0[i] += 1; + else + p0[i] -= 1; + } + if (std::abs(p2[i] - p0[i]) < 1) { + if (p2[i] < p0[i]) + p0[i] += 1; + else + p0[i] -= 1; + } + } + }; + checker.set_input_coordinator(icoord); + } + + template + static void update_checker(Checker& checker) { + using ctype = typename Checker::ctype; + return do_update_checker(checker); + } + template + static void update_opt(Opt& opt) { + opt.numdiff_eps = 1e-3; + opt.numdiff_max_err = 0.1; + } +}; /* ======================= test runner ======================= */ namespace detail { template diff --git a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl index 0663986a..1ed742db 100644 --- a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl @@ -41,6 +41,7 @@ DEF_TRAIT(SWITCH_GT0, x > 0 ? y : 0) DEF_TRAIT(TANH_GRAD, (1 - x * x) * y) DEF_TRAIT(FUSE_ADD_RELU, std::max(x + y, 0)) +DEF_TRAIT(PRELU, (x > 0) ? 
x : (x* y)) #undef _ALLOW_INT #define _ALLOW_INT false @@ -57,6 +58,12 @@ DEF_TRAIT( SILU_GRAD, y*(1 + std::exp(-x) + x * std::exp(-x)) / (1 + std::exp(-x)) / (1 + std::exp(-x))) DEF_TRAIT(GELU_GRAD, do_gelu_grad(x, y)) +DEF_TRAIT(ASINH_GRAD, y / std::sqrt(x * x + 1)) +DEF_TRAIT(ACOSH_GRAD, y / std::sqrt(x * x - 1)) +DEF_TRAIT(ATANH_GRAD, y / (1 - x * x)) +DEF_TRAIT(SOFTPLUS_GRAD, y* std::exp(x) / (1.f + std::exp(x))) +DEF_TRAIT(RELU6_GRAD, x <= 0.f ? 0.f : (x >= 6.f ? 0.f : y)) +DEF_TRAIT(HSIGMOID_GRAD, x <= -3.f ? 0.f : (x >= 3.f ? 0.f : (y / 6.f))) #undef _ALLOW_INT #undef _ALLOW_FLOAT diff --git a/src/opr/test/basic_arith/elemwise_ternary_trait_def.inl b/src/opr/test/basic_arith/elemwise_ternary_trait_def.inl index e9e5cd8e..a9bf8af1 100644 --- a/src/opr/test/basic_arith/elemwise_ternary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_ternary_trait_def.inl @@ -15,6 +15,10 @@ DEF_TRAIT(COND_LEQ_MOV, x <= y ? z : 0) DEF_TRAIT(COND_LT_MOV, x < y ? z : 0) DEF_TRAIT(FUSE_MUL_ADD3, x* y + z) +DEF_TRAIT(CLIP, x < y ? y : (x < z ? x : z)) +#undef _ALLOW_INT +#define _ALLOW_INT false +DEF_TRAIT(PRELU_GRAD, x > 0 ? y : (y * z)) #undef _ALLOW_INT #undef _ALLOW_FLOAT diff --git a/src/opr/test/basic_arith/elemwise_unary_trait_def.inl b/src/opr/test/basic_arith/elemwise_unary_trait_def.inl index a6e64ecd..edacc035 100644 --- a/src/opr/test/basic_arith/elemwise_unary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_unary_trait_def.inl @@ -22,6 +22,9 @@ DEF_TRAIT(NOT, !x) DEF_TRAIT(ABS, std::abs(x)) DEF_TRAIT(NEGATE, -x) DEF_TRAIT(RELU, std::max(x, 0)) +DEF_TRAIT(RELU6, std::min(std::max(x, 0), 6)) +DEF_TRAIT(SQUARE, x* x) +DEF_TRAIT(SIGN, x < 0 ? -1 : (x > 0 ? 
1 : 0)) #undef _ALLOW_INT #define _ALLOW_INT false @@ -46,6 +49,16 @@ DEF_TRAIT(ERFCINV, do_erfcinv(x)) DEF_TRAIT(H_SWISH, do_h_swish(x)) DEF_TRAIT(SILU, x / (1 + std::exp(-x))) DEF_TRAIT(GELU, x*(0.5f * (1.f + std::erf(x / std::sqrt(2.f))))) +DEF_TRAIT(SINH, std::sinh(x)) +DEF_TRAIT(COSH, std::cosh(x)) +DEF_TRAIT(ASINH, std::asinh(x)) +DEF_TRAIT(ACOSH, std::acosh(x)) +DEF_TRAIT(ATANH, std::atanh(x)) +DEF_TRAIT(TAN, std::tan(x)) +DEF_TRAIT(SOFTPLUS, std::log1p(std::exp(-std::abs(x))) + std::max(x, 0)) +DEF_TRAIT(HSIGMOID, x <= -3.f ? 0.f : (x >= 3.f ? 1.f : ((x + 3.f) / 6.f))) +DEF_TRAIT(SQRT, std::sqrt(x)) +DEF_TRAIT(LOGSIGMOID, -std::log1p(std::exp(-std::abs(x))) - std::max(-x, 0)) #undef _ALLOW_INT #undef _ALLOW_FLOAT