diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index 9b556bc3..b0301aa6 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -313,7 +313,6 @@ protected: size_t workspace_in_bytes); }; using Cumsum = CumsumForward; - // mxx can be max or min class ArgmxxBase : public OperatorBase { DEF_OPR_IMPL_CTOR(ArgmxxBase, OperatorBase); diff --git a/dnn/scripts/gen_elemwise_multi_type_utils.py b/dnn/scripts/gen_elemwise_multi_type_utils.py index 004738c0..837787d0 100755 --- a/dnn/scripts/gen_elemwise_multi_type_utils.py +++ b/dnn/scripts/gen_elemwise_multi_type_utils.py @@ -48,6 +48,19 @@ MODES = { "H_SWISH", "SILU", "GELU", + "SINH", + "COSH", + "ASINH", + "ACOSH", + "ATANH", + "TAN", + "SOFTPLUS", + "RELU6", + "HSIGMOID", + "LOGSIGMOID", + "SQRT", + "SQUARE", + "SIGN", ], 2: [ "ABS_GRAD", @@ -76,8 +89,15 @@ MODES = { "FUSE_ADD_H_SWISH", "SILU_GRAD", "GELU_GRAD", + "PRELU", + "ASINH_GRAD", + "ACOSH_GRAD", + "ATANH_GRAD", + "SOFTPLUS_GRAD", + "RELU6_GRAD", + "HSIGMOID_GRAD", ], - 3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], + 3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3", "CLIP", "PRELU_GRAD"], } QINT4_MODES = { @@ -107,8 +127,9 @@ QINT4_MODES = { "FUSE_ADD_TANH", "FUSE_ADD_SIGMOID", "FUSE_ADD_H_SWISH", + "PRELU", ], - 3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], + 3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3", "CLIP"], } QINT32_MODES = { diff --git a/dnn/scripts/gen_elemwise_utils.py b/dnn/scripts/gen_elemwise_utils.py index 3b947a92..d5aa1a3b 100755 --- a/dnn/scripts/gen_elemwise_utils.py +++ b/dnn/scripts/gen_elemwise_utils.py @@ -12,7 +12,7 @@ DTYPES = { } MODES = { - (1, "INT"): ["RELU", "ABS", "NEGATE"], + (1, "INT"): ["RELU", "ABS", "NEGATE", "RELU6", "SQUARE", "SIGN"], (2, "INT"): [ "ABS_GRAD", "ADD", @@ -32,8 +32,9 @@ MODES = { "SHL", "SHR", "RMULH", + "PRELU", ], - (3, "INT"): ["COND_LEQ_MOV", "COND_LT_MOV"], + (3, "INT"): ["COND_LEQ_MOV", "COND_LT_MOV", "CLIP"], (1, "FLOAT"): [ "RELU", "ABS", @@ -59,6 +60,19 @@ MODES = { "H_SWISH", "SILU", "GELU", + "SINH", + "COSH", + "ASINH", + "ACOSH", + "ATANH", + "TAN", + "SOFTPLUS", + "RELU6", + "HSIGMOID", + "LOGSIGMOID", + "SQRT", + "SQUARE", + "SIGN", ], (2, "FLOAT"): [ "ABS_GRAD", @@ -87,8 +101,21 @@ MODES = { "FUSE_ADD_H_SWISH", "SILU_GRAD", "GELU_GRAD", + "PRELU", + "ASINH_GRAD", + "ACOSH_GRAD", + "ATANH_GRAD", + "SOFTPLUS_GRAD", + "RELU6_GRAD", + "HSIGMOID_GRAD", + ], + (3, "FLOAT"): [ + "COND_LEQ_MOV", + "COND_LT_MOV", + "FUSE_MUL_ADD3", + "CLIP", + "PRELU_GRAD", ], - (3, "FLOAT"): ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], (1, "BOOL"): ["NOT"], (2, "BOOL"): ["AND", "OR", "XOR", "LT", "LEQ", "EQ"], (3, "BOOL"): [], diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 3ab046e4..d09dacde 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -424,6 +424,28 @@ pdef('Elemwise').add_enum( Doc('NEQ = 61', 'binary: x != y'), Doc('ISNAN = 62', 'unary: isnan(x)'), Doc('ISINF = 63', 'unary: isinf(x)'), + Doc('SINH = 64', 'unary: sinh(x)'), + Doc('COSH = 65', 'unary: cosh(x)'), + Doc('ASINH = 66', 'unary: asinh(x)'), + Doc('ACOSH = 67', 'unary: acosh(x)'), + Doc('ATANH = 68', 'unary: atanh(x)'), + Doc('TAN = 69', 'unary: tan(x)'), + Doc('ASINH_GRAD = 70', 'binary: y / sqrt(x^2 + 1)'), + Doc('ACOSH_GRAD = 71', 'binary: y / sqrt(x^2 - 1) (x > 1)'), + Doc('ATANH_GRAD = 72', 'binary: y / (1 - x^2) (|x| < 1)'), + Doc('PRELU = 73', 'binary: x > 0 ? x : x * y'), + Doc('CLIP = 74', 'ternary: x <= y ? y : (x <= z ? x : z)'), + Doc('PRELU_GRAD = 75', 'ternary: x > 0 ? y : y * z'), + Doc('SOFTPLUS = 76', 'unary: log(1 + e^x)'), + Doc('SOFTPLUS_GRAD = 77', 'binary: y * e^x / (1 + e^x)'), + Doc('RELU6 = 78', 'unary: min(max(0, x), 6)'), + Doc('RELU6_GRAD = 79', 'binary: x < 0 ? 0 : (x > 6 ? 0 : y)'), + Doc('HSIGMOID = 80', 'unary: relu6(x + 3) / 6'), + Doc('HSIGMOID_GRAD = 81', 'binary: x < -3 ? 0 : (x > 3 ? 0 : y / 6)'), + Doc('LOGSIGMOID = 82', 'unary: -log(1 + e^(-x))'), + Doc('SQRT = 83', 'unary: x^(1/2)'), + Doc('SQUARE = 84', 'unary: x^2'), + Doc('SIGN = 85', 'unary: sgn(x)'), ) pdef('ElemwiseMultiType').add_enum( diff --git a/dnn/src/common/elemwise/each_mode.inl b/dnn/src/common/elemwise/each_mode.inl index 67a144ca..48ca51dc 100644 --- a/dnn/src/common/elemwise/each_mode.inl +++ b/dnn/src/common/elemwise/each_mode.inl @@ -25,12 +25,28 @@ MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) + MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) + MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \ @@ -66,7 +82,14 @@ MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) + MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ @@ -86,15 +109,19 @@ MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) + MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_BOOL(cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) + MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_INT(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) + MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) diff --git a/dnn/src/common/elemwise/kern_defs.cuh b/dnn/src/common/elemwise/kern_defs.cuh index 53953f22..a7057e7f 100644 --- a/dnn/src/common/elemwise/kern_defs.cuh +++ b/dnn/src/common/elemwise/kern_defs.cuh @@ -154,11 +154,18 @@ struct ElemwiseKern; // int and float DEF_KERN_ALL(NEGATE, -x); +DEF_KERN_ALL(SQUARE, x* x); #if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__) DEF_KERN_INT(RELU, x <= ctype(0) ? ctype(0) : x); +DEF_KERN_INT(RELU6, x <= ctype(0) ? ctype(0) : (x <= ctype(6) ? x : ctype(6))); +DEF_KERN_INT(SIGN, x < ctype(0) ? ctype(-1) : (x > ctype(0) ? ctype(1) : ctype(0))); DEF_KERN_FLOAT(RELU, x <= 0.f ? ctype(0) : x); +DEF_KERN_FLOAT(RELU6, x <= 6.f ? ctype(0) : (x <= 6.f ? x : ctype(6))); +DEF_KERN_FLOAT(SIGN, x < 0.f ? -1.f : (x > 0.f ? 1.f : 0.f)); #else DEF_KERN_ALL(RELU, x <= ctype(0) ? ctype(0) : x); +DEF_KERN_ALL(RELU6, x <= ctype(0) ? ctype(0) : (x <= ctype(6) ? x : ctype(6))); +DEF_KERN_ALL(SIGN, x < ctype(0) ? ctype(-1) : (x > ctype(0) ? ctype(1) : ctype(0))); #endif DEF_KERN_INT(ABS, abs(int(x))); // DEF_KERN_INT(ABS, x > ctype(0) ? x : -x); @@ -186,6 +193,18 @@ DEF_KERN_FLOAT(ERFCINV, erfcinvf(x)); DEF_KERN_FLOAT(H_SWISH, x* min(max(x + 3, 0.f), 6.f) * (1.f / 6.f)); DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f)); DEF_KERN_FLOAT(GELU, x* normcdf(x)); +DEF_KERN_FLOAT(SINH, sinhf(x)); +DEF_KERN_FLOAT(COSH, coshf(x)); +DEF_KERN_FLOAT(ASINH, asinhf(x)); +DEF_KERN_FLOAT(ACOSH, acoshf(x)); +DEF_KERN_FLOAT(ATANH, atanhf(x)); +DEF_KERN_FLOAT(TAN, tanf(x)); +DEF_KERN_FLOAT(SOFTPLUS, log1pf(expf(-fabsf(x))) + (x <= ctype(0) ? ctype(0) : x)); +DEF_KERN_FLOAT( + HSIGMOID, + x <= ctype(-3) ? ctype(0) : (x >= ctype(3) ? ctype(1) : ((x + 3.f) / 6.f))); +DEF_KERN_FLOAT(SQRT, sqrtf(x)); +DEF_KERN_FLOAT(LOGSIGMOID, -log1pf(expf(-fabsf(x))) + (x >= ctype(0) ? ctype(0) : x)); // int only DEF_KERN(dt_bool, NOT, x ^ 1); @@ -240,6 +259,12 @@ DEF_KERN_FLOAT(FUSE_ADD_RELU, (x + y) <= 0.f ? ctype(0) : (x + y)); #else DEF_KERN_ALL(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y)); #endif +#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__) +DEF_KERN_INT(PRELU, x > ctype(0) ? x : (x * y)); +DEF_KERN_FLOAT(PRELU, x > 0.f ? x : (x * y)); +#else +DEF_KERN_ALL(PRELU, x > ctype(0) ? x : (x * y)); +#endif // float only DEF_KERN_FLOAT(TRUE_DIV, x / y); @@ -259,6 +284,14 @@ DEF_KERN_FLOAT( DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y)); DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y)); DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y)); +DEF_KERN_FLOAT(ASINH_GRAD, y / sqrt(x * x + 1.f)); +DEF_KERN_FLOAT(ACOSH_GRAD, y / sqrt(x * x - 1.f)); +DEF_KERN_FLOAT(ATANH_GRAD, y / (1.f - x * x)); +DEF_KERN_FLOAT(SOFTPLUS_GRAD, y* expf(x) / (1.f + expf(x))); +DEF_KERN_FLOAT(RELU6_GRAD, x <= ctype(0) ? ctype(0) : (x >= ctype(6) ? ctype(0) : y)); +DEF_KERN_FLOAT( + HSIGMOID_GRAD, + x <= ctype(-3) ? ctype(0) : (x >= ctype(3) ? ctype(0) : (y / 6.f))); #undef KERN_SIG /* ================== ternary kernels ================== */ @@ -268,6 +301,8 @@ DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y)); DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0)); DEF_KERN_ALL(COND_LT_MOV, x < y ? z : ctype(0)); DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z); +DEF_KERN_ALL(CLIP, x <= y ? y : (x <= z ? x : z)); +DEF_KERN_FLOAT(PRELU_GRAD, x >= 0.f ? y : (y * z)); #undef KERN_SIG diff --git a/dnn/src/common/elemwise/opr_impl.cpp b/dnn/src/common/elemwise/opr_impl.cpp index b5981b8c..752eb3e5 100644 --- a/dnn/src/common/elemwise/opr_impl.cpp +++ b/dnn/src/common/elemwise/opr_impl.cpp @@ -62,6 +62,9 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb); MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb); MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb); + cb(NEQ); + cb(ISNAN); + cb(ISINF); #undef cb #define cb(_m) \ @@ -84,11 +87,14 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb); MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb); MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb); + cb(ISNAN); + cb(ISINF); #undef _a #define _a 2 MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb); MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb); MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb); + cb(NEQ); #undef _a #define _a 3 MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb); @@ -223,6 +229,28 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { CB_MODE(Mode::GELU); CB_MODE(Mode::GELU_GRAD); CB_MODE(Mode::COND_LT_MOV); + CB_MODE(Mode::SINH); + CB_MODE(Mode::COSH); + CB_MODE(Mode::ASINH); + CB_MODE(Mode::ACOSH); + CB_MODE(Mode::ATANH); + CB_MODE(Mode::TAN); + CB_MODE(Mode::ASINH_GRAD); + CB_MODE(Mode::ACOSH_GRAD); + CB_MODE(Mode::ATANH_GRAD); + CB_MODE(Mode::PRELU); + CB_MODE(Mode::PRELU_GRAD); + CB_MODE(Mode::CLIP); + CB_MODE(Mode::SOFTPLUS); + CB_MODE(Mode::SOFTPLUS_GRAD); + CB_MODE(Mode::RELU6); + CB_MODE(Mode::RELU6_GRAD); + CB_MODE(Mode::HSIGMOID); + CB_MODE(Mode::HSIGMOID_GRAD); + CB_MODE(Mode::LOGSIGMOID); + CB_MODE(Mode::SQRT); + CB_MODE(Mode::SQUARE); + CB_MODE(Mode::SIGN); default: megdnn_assert( 0, diff --git a/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cu new file mode 100644 index 00000000..fc3af7a7 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_float16.cu new file mode 100644 index 00000000..1c4f89c8 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_float32.cu new file mode 100644 index 00000000..7674459b --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ACOSH_GRAD_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_bfloat16.cu new file mode 100644 index 00000000..80411b71 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_float16.cu new file mode 100644 index 00000000..aa417709 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_float32.cu new file mode 100644 index 00000000..cfcf7ad3 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ACOSH_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cu new file mode 100644 index 00000000..393728a4 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_float16.cu new file mode 100644 index 00000000..807bedca --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_float32.cu new file mode 100644 index 00000000..7a3d5a3e --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ASINH_GRAD_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ASINH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ASINH_dt_bfloat16.cu new file mode 100644 index 00000000..7a912381 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ASINH_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ASINH_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ASINH_dt_float16.cu new file mode 100644 index 00000000..29c9307d --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ASINH_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ASINH_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ASINH_dt_float32.cu new file mode 100644 index 00000000..b6fc8ab6 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ASINH_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cu new file mode 100644 index 00000000..cf7afe1b --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_float16.cu new file mode 100644 index 00000000..333b8f60 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_float32.cu new file mode 100644 index 00000000..be794c32 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ATANH_GRAD_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ATANH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ATANH_dt_bfloat16.cu new file mode 100644 index 00000000..fa6683df --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ATANH_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ATANH_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ATANH_dt_float16.cu new file mode 100644 index 00000000..804e5361 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ATANH_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ATANH_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ATANH_dt_float32.cu new file mode 100644 index 00000000..5fd222b3 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ATANH_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/CLIP_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_bfloat16.cu new file mode 100644 index 00000000..de73cd15 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/CLIP_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_float16.cu new file mode 100644 index 00000000..25e351c1 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/CLIP_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_float32.cu new file mode 100644 index 00000000..cf8dc776 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int16.cu new file mode 100644 index 00000000..f60b5c4c --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int16.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int32.cu new file mode 100644 index 00000000..c003f595 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int8.cu new file mode 100644 index 00000000..cb0ec046 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_int8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/CLIP_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_uint8.cu new file mode 100644 index 00000000..b0198d93 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/CLIP_dt_uint8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/COSH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/COSH_dt_bfloat16.cu new file mode 100644 index 00000000..7cb17527 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/COSH_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/COSH_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/COSH_dt_float16.cu new file mode 100644 index 00000000..5f42f235 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/COSH_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/COSH_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/COSH_dt_float32.cu new file mode 100644 index 00000000..94ea1870 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/COSH_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cu new file mode 100644 index 00000000..a8115bff --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cu new file mode 100644 index 00000000..a1fb7ee3 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cu new file mode 100644 index 00000000..9c0a4aeb --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_bfloat16.cu new file mode 100644 index 00000000..28a83976 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_float16.cu new file mode 100644 index 00000000..cdb77455 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_float32.cu new file mode 100644 index 00000000..528f944c --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/HSIGMOID_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cu new file mode 100644 index 00000000..06322df6 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_float16.cu new file mode 100644 index 00000000..d0b6c026 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_float32.cu new file mode 100644 index 00000000..ea1bcf1a --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/LOGSIGMOID_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cu new file mode 100644 index 00000000..dc331500 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_float16.cu new file mode 100644 index 00000000..34411818 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_float32.cu new file mode 100644 index 00000000..1fedc850 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_GRAD_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_bfloat16.cu new file mode 100644 index 00000000..78c18f9d --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_float16.cu new file mode 100644 index 00000000..33e6ce73 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_float32.cu new file mode 100644 index 00000000..46f2d367 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int16.cu new file mode 100644 index 00000000..d1dfa9ac --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int16.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int32.cu new file mode 100644 index 00000000..d6d7332f --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int8.cu new file mode 100644 index 00000000..621a7dd3 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_int8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/PRELU_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_uint8.cu new file mode 100644 index 00000000..86ff475d --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/PRELU_dt_uint8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cu new file mode 100644 index 00000000..90699665 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_float16.cu new file mode 100644 index 00000000..efb61fa6 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_float32.cu new file mode 100644 index 00000000..6088f41d --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_GRAD_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_bfloat16.cu new file mode 100644 index 00000000..cf79b7e9 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_float16.cu new file mode 100644 index 00000000..0646045d --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_float32.cu new file mode 100644 index 00000000..2fe7746f --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int16.cu new file mode 100644 index 00000000..32c2dab3 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int16.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int32.cu new file mode 100644 index 00000000..e59877c3 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int8.cu new file mode 100644 index 00000000..6f6f7741 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_int8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU6_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_uint8.cu new file mode 100644 index 00000000..60812b55 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU6_dt_uint8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIGN_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_bfloat16.cu new file mode 100644 index 00000000..34316156 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SIGN_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_float16.cu new file mode 100644 index 00000000..04ac0b86 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SIGN_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_float32.cu new file mode 100644 index 00000000..0402184f --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int16.cu new file mode 100644 index 00000000..0a854c23 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int16.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int32.cu new file mode 100644 index 00000000..5f3aa927 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int8.cu new file mode 100644 index 00000000..c0d44608 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_int8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIGN_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_uint8.cu new file mode 100644 index 00000000..37f4b4b2 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIGN_dt_uint8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SINH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SINH_dt_bfloat16.cu new file mode 100644 index 00000000..19b3b24d --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SINH_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SINH_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SINH_dt_float16.cu new file mode 100644 index 00000000..0298140e --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SINH_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SINH_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SINH_dt_float32.cu new file mode 100644 index 00000000..d781a287 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SINH_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cu new file mode 100644 index 00000000..9769ef87 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cu new file mode 100644 index 00000000..694fea1b --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cu new file mode 100644 index 00000000..05710880 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cu new file mode 100644 index 00000000..7df279c7 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_float16.cu new file mode 100644 index 00000000..98d84dad --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_float32.cu new file mode 100644 index 00000000..898996df --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SOFTPLUS_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SQRT_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SQRT_dt_bfloat16.cu new file mode 100644 index 00000000..b6483dfe --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQRT_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SQRT_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SQRT_dt_float16.cu new file mode 100644 index 00000000..262e68d4 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQRT_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SQRT_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SQRT_dt_float32.cu new file mode 100644 index 00000000..1c6aa2af --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQRT_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_bfloat16.cu new file mode 100644 index 00000000..59f5383e --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_float16.cu new file mode 100644 index 00000000..c53551bb --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_float32.cu new file mode 100644 index 00000000..4282b479 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int16.cu new file mode 100644 index 00000000..5d7bec08 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int16.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int32.cu new file mode 100644 index 00000000..c1237c37 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int8.cu new file mode 100644 index 00000000..8c01483e --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_int8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_uint8.cu new file mode 100644 index 00000000..094c4191 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SQUARE_dt_uint8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/TAN_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/TAN_dt_bfloat16.cu new file mode 100644 index 00000000..415ed3fc --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/TAN_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/TAN_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/TAN_dt_float16.cu new file mode 100644 index 00000000..dd5339a3 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/TAN_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/TAN_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/TAN_dt_float32.cu new file mode 100644 index 00000000..796c4e65 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/TAN_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ACOSH_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ACOSH_GRAD_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..74e69bb6 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ACOSH_GRAD_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ACOSH_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ACOSH_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..6734fd67 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ACOSH_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ASINH_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ASINH_GRAD_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..6af7f406 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ASINH_GRAD_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ASINH_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ASINH_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..9b080b73 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ASINH_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ATANH_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ATANH_GRAD_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..1d477ec4 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ATANH_GRAD_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ATANH_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ATANH_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..934699b3 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ATANH_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_qint4_dt_qint4.cu new file mode 100644 index 00000000..accf56fa --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_qint4_dt_qint4.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 +#include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..78de5e8a --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_quint4_dt_quint4.cu new file mode 100644 index 00000000..754de3e6 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/CLIP_dt_quint4_dt_quint4.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 +#include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/COSH_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/COSH_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..2bcc45ab --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/COSH_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/HSIGMOID_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/HSIGMOID_GRAD_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..a50f44c9 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/HSIGMOID_GRAD_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/HSIGMOID_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/HSIGMOID_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..5cac0b94 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/HSIGMOID_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LOGSIGMOID_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LOGSIGMOID_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..bb2abf6a --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LOGSIGMOID_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/PRELU_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/PRELU_GRAD_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..1518830e --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/PRELU_GRAD_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/PRELU_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/PRELU_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..60fa38fa --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/PRELU_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU6_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU6_GRAD_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..49ef8a94 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU6_GRAD_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU6_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU6_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..7cf2b512 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU6_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGN_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGN_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..279bb2bb --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGN_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SINH_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SINH_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..68f6f9a7 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SINH_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SOFTPLUS_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SOFTPLUS_GRAD_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..2c8aa8b6 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SOFTPLUS_GRAD_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SOFTPLUS_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SOFTPLUS_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..7c250442 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SOFTPLUS_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SQRT_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SQRT_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..48290ef9 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SQRT_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SQUARE_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SQUARE_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..a24e0457 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SQUARE_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/TAN_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/TAN_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..ecf0bcb6 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/TAN_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp b/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp index 64a90258..f381d7b5 100644 --- a/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp +++ b/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp @@ -267,7 +267,11 @@ IMPL_MODE_DISPATCHER(2, dt_qint4, dt_qint4); IMPL_MODE_DISPATCHER(2, dt_quint4, dt_quint4); #undef FOREACH -#define FOREACH MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT +#define FOREACH(cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) IMPL_MODE_DISPATCHER(3, dt_qint4, dt_qint4); IMPL_MODE_DISPATCHER(3, dt_quint4, dt_quint4); #undef FOREACH diff --git a/dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp b/dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp index 4197a2da..f0cf6c13 100644 --- a/dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp +++ b/dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp @@ -228,6 +228,7 @@ INST(Mode::SHL); INST(Mode::SHR); INST(Mode::FUSE_ADD_RELU); INST(Mode::RMULH); +INST(Mode::PRELU); #undef INST #define INST(mode) \ @@ -258,6 +259,13 @@ INST(Mode::H_SWISH_GRAD); INST(Mode::FUSE_ADD_H_SWISH); INST(Mode::SILU_GRAD); INST(Mode::GELU_GRAD); +INST(Mode::PRELU); +INST(Mode::ASINH_GRAD); +INST(Mode::ACOSH_GRAD); +INST(Mode::ATANH_GRAD); +INST(Mode::SOFTPLUS_GRAD); +INST(Mode::RELU6_GRAD); +INST(Mode::HSIGMOID_GRAD); #undef INST } // namespace fallback } // namespace megdnn diff --git a/dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp b/dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp index b36e7dae..bf7acf09 100644 --- a/dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp +++ b/dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp @@ -77,6 +77,9 @@ using Mode = param_enumv::Elemwise::Mode; INST(Mode::RELU); INST(Mode::ABS); INST(Mode::NEGATE); +INST(Mode::RELU6); +INST(Mode::SQUARE); +INST(Mode::SIGN); #undef INST #define INST(mode) \ @@ -105,6 +108,19 @@ INST(Mode::ERFCINV); INST(Mode::H_SWISH); INST(Mode::SILU); INST(Mode::GELU); +INST(Mode::SINH); +INST(Mode::COSH); +INST(Mode::ASINH); +INST(Mode::ACOSH); +INST(Mode::ATANH); +INST(Mode::TAN); +INST(Mode::SOFTPLUS); +INST(Mode::RELU6); +INST(Mode::HSIGMOID); +INST(Mode::LOGSIGMOID); +INST(Mode::SQRT); +INST(Mode::SQUARE); +INST(Mode::SIGN); #undef INST } // namespace fallback } // namespace megdnn diff --git a/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cpp new file mode 100644 index 00000000..fc3af7a7 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_float16.cpp new file mode 100644 index 00000000..1c4f89c8 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_float32.cpp new file mode 100644 index 00000000..7674459b --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ACOSH_GRAD_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/ACOSH_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ACOSH_dt_bfloat16.cpp new file mode 100644 index 00000000..80411b71 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ACOSH_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ACOSH_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/ACOSH_dt_float16.cpp new file mode 100644 index 00000000..aa417709 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ACOSH_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ACOSH_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/ACOSH_dt_float32.cpp new file mode 100644 index 00000000..cfcf7ad3 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ACOSH_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cpp new file mode 100644 index 00000000..393728a4 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_float16.cpp new file mode 100644 index 00000000..807bedca --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_float32.cpp new file mode 100644 index 00000000..7a3d5a3e --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ASINH_GRAD_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/ASINH_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ASINH_dt_bfloat16.cpp new file mode 100644 index 00000000..7a912381 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ASINH_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ASINH_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/ASINH_dt_float16.cpp new file mode 100644 index 00000000..29c9307d --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ASINH_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ASINH_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/ASINH_dt_float32.cpp new file mode 100644 index 00000000..b6fc8ab6 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ASINH_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cpp new file mode 100644 index 00000000..cf7afe1b --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_float16.cpp new file mode 100644 index 00000000..333b8f60 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_float32.cpp new file mode 100644 index 00000000..be794c32 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ATANH_GRAD_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/ATANH_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ATANH_dt_bfloat16.cpp new file mode 100644 index 00000000..fa6683df --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ATANH_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ATANH_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/ATANH_dt_float16.cpp new file mode 100644 index 00000000..804e5361 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ATANH_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ATANH_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/ATANH_dt_float32.cpp new file mode 100644 index 00000000..5fd222b3 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ATANH_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/CLIP_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/CLIP_dt_bfloat16.cpp new file mode 100644 index 00000000..de73cd15 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/CLIP_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/CLIP_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/CLIP_dt_float16.cpp new file mode 100644 index 00000000..25e351c1 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/CLIP_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/CLIP_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/CLIP_dt_float32.cpp new file mode 100644 index 00000000..cf8dc776 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/CLIP_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/CLIP_dt_int16.cpp b/dnn/src/naive/elemwise/kimpl/CLIP_dt_int16.cpp new file mode 100644 index 00000000..f60b5c4c --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/CLIP_dt_int16.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/CLIP_dt_int32.cpp b/dnn/src/naive/elemwise/kimpl/CLIP_dt_int32.cpp new file mode 100644 index 00000000..c003f595 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/CLIP_dt_int32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/CLIP_dt_int8.cpp b/dnn/src/naive/elemwise/kimpl/CLIP_dt_int8.cpp new file mode 100644 index 00000000..cb0ec046 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/CLIP_dt_int8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/CLIP_dt_uint8.cpp b/dnn/src/naive/elemwise/kimpl/CLIP_dt_uint8.cpp new file mode 100644 index 00000000..b0198d93 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/CLIP_dt_uint8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/COSH_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/COSH_dt_bfloat16.cpp new file mode 100644 index 00000000..7cb17527 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/COSH_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/COSH_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/COSH_dt_float16.cpp new file mode 100644 index 00000000..5f42f235 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/COSH_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/COSH_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/COSH_dt_float32.cpp new file mode 100644 index 00000000..94ea1870 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/COSH_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cpp new file mode 100644 index 00000000..a8115bff --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cpp new file mode 100644 index 00000000..a1fb7ee3 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cpp new file mode 100644 index 00000000..9c0a4aeb --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_bfloat16.cpp new file mode 100644 index 00000000..28a83976 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_float16.cpp new file mode 100644 index 00000000..cdb77455 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_float32.cpp new file mode 100644 index 00000000..528f944c --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/HSIGMOID_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cpp new file mode 100644 index 00000000..06322df6 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_float16.cpp new file mode 100644 index 00000000..d0b6c026 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_float32.cpp new file mode 100644 index 00000000..ea1bcf1a --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/LOGSIGMOID_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cpp new file mode 100644 index 00000000..dc331500 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_float16.cpp new file mode 100644 index 00000000..34411818 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_float32.cpp new file mode 100644 index 00000000..1fedc850 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_GRAD_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_dt_bfloat16.cpp new file mode 100644 index 00000000..78c18f9d --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_dt_float16.cpp new file mode 100644 index 00000000..33e6ce73 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_dt_float32.cpp new file mode 100644 index 00000000..46f2d367 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_dt_int16.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_dt_int16.cpp new file mode 100644 index 00000000..d1dfa9ac --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_dt_int16.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_dt_int32.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_dt_int32.cpp new file mode 100644 index 00000000..d6d7332f --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_dt_int32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_dt_int8.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_dt_int8.cpp new file mode 100644 index 00000000..621a7dd3 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_dt_int8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/PRELU_dt_uint8.cpp b/dnn/src/naive/elemwise/kimpl/PRELU_dt_uint8.cpp new file mode 100644 index 00000000..86ff475d --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/PRELU_dt_uint8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cpp new file mode 100644 index 00000000..90699665 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_float16.cpp new file mode 100644 index 00000000..efb61fa6 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_float32.cpp new file mode 100644 index 00000000..6088f41d --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_GRAD_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_dt_bfloat16.cpp new file mode 100644 index 00000000..cf79b7e9 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_dt_float16.cpp new file mode 100644 index 00000000..0646045d --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_dt_float32.cpp new file mode 100644 index 00000000..2fe7746f --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_dt_int16.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_dt_int16.cpp new file mode 100644 index 00000000..32c2dab3 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_dt_int16.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_dt_int32.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_dt_int32.cpp new file mode 100644 index 00000000..e59877c3 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_dt_int32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_dt_int8.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_dt_int8.cpp new file mode 100644 index 00000000..6f6f7741 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_dt_int8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/RELU6_dt_uint8.cpp b/dnn/src/naive/elemwise/kimpl/RELU6_dt_uint8.cpp new file mode 100644 index 00000000..60812b55 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU6_dt_uint8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SIGN_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SIGN_dt_bfloat16.cpp new file mode 100644 index 00000000..34316156 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIGN_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SIGN_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/SIGN_dt_float16.cpp new file mode 100644 index 00000000..04ac0b86 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIGN_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SIGN_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/SIGN_dt_float32.cpp new file mode 100644 index 00000000..0402184f --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIGN_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SIGN_dt_int16.cpp b/dnn/src/naive/elemwise/kimpl/SIGN_dt_int16.cpp new file mode 100644 index 00000000..0a854c23 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIGN_dt_int16.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SIGN_dt_int32.cpp b/dnn/src/naive/elemwise/kimpl/SIGN_dt_int32.cpp new file mode 100644 index 00000000..5f3aa927 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIGN_dt_int32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SIGN_dt_int8.cpp b/dnn/src/naive/elemwise/kimpl/SIGN_dt_int8.cpp new file mode 100644 index 00000000..c0d44608 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIGN_dt_int8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SIGN_dt_uint8.cpp b/dnn/src/naive/elemwise/kimpl/SIGN_dt_uint8.cpp new file mode 100644 index 00000000..37f4b4b2 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIGN_dt_uint8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SINH_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SINH_dt_bfloat16.cpp new file mode 100644 index 00000000..19b3b24d --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SINH_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SINH_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/SINH_dt_float16.cpp new file mode 100644 index 00000000..0298140e --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SINH_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SINH_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/SINH_dt_float32.cpp new file mode 100644 index 00000000..d781a287 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SINH_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cpp new file mode 100644 index 00000000..9769ef87 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cpp new file mode 100644 index 00000000..694fea1b --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cpp new file mode 100644 index 00000000..05710880 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cpp new file mode 100644 index 00000000..7df279c7 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_float16.cpp new file mode 100644 index 00000000..98d84dad --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_float32.cpp new file mode 100644 index 00000000..898996df --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SOFTPLUS_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SQRT_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SQRT_dt_bfloat16.cpp new file mode 100644 index 00000000..b6483dfe --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQRT_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SQRT_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/SQRT_dt_float16.cpp new file mode 100644 index 00000000..262e68d4 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQRT_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SQRT_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/SQRT_dt_float32.cpp new file mode 100644 index 00000000..1c6aa2af --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQRT_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SQUARE_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_bfloat16.cpp new file mode 100644 index 00000000..59f5383e --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SQUARE_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_float16.cpp new file mode 100644 index 00000000..c53551bb --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SQUARE_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_float32.cpp new file mode 100644 index 00000000..4282b479 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int16.cpp b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int16.cpp new file mode 100644 index 00000000..5d7bec08 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int16.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int32.cpp b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int32.cpp new file mode 100644 index 00000000..c1237c37 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int8.cpp b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int8.cpp new file mode 100644 index 00000000..8c01483e --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_int8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/SQUARE_dt_uint8.cpp b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_uint8.cpp new file mode 100644 index 00000000..094c4191 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SQUARE_dt_uint8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/TAN_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/TAN_dt_bfloat16.cpp new file mode 100644 index 00000000..415ed3fc --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/TAN_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/TAN_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/TAN_dt_float16.cpp new file mode 100644 index 00000000..dd5339a3 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/TAN_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/TAN_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/TAN_dt_float32.cpp new file mode 100644 index 00000000..796c4e65 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/TAN_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..a231d0a3 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..ca103ee3 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..11f768c9 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ACOSH_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..28c3f173 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_float16.cpp.hip new file mode 100644 index 00000000..06d2c12e --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_float32.cpp.hip new file mode 100644 index 00000000..9e969396 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ACOSH_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..d0f61ca2 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..0840a54c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..c239ddb2 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ASINH_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ASINH_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ASINH_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..ea0f11ec --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ASINH_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ASINH_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ASINH_dt_float16.cpp.hip new file mode 100644 index 00000000..cc5d7302 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ASINH_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ASINH_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ASINH_dt_float32.cpp.hip new file mode 100644 index 00000000..96076302 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ASINH_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..3b0496f4 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..1a0c841a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..ce400c03 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ATANH_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ATANH_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ATANH_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..9ac3b63a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ATANH_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ATANH_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ATANH_dt_float16.cpp.hip new file mode 100644 index 00000000..56902e43 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ATANH_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ATANH_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ATANH_dt_float32.cpp.hip new file mode 100644 index 00000000..048cab3a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ATANH_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/CLIP_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..91c018ae --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/CLIP_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_float16.cpp.hip new file mode 100644 index 00000000..1d06d8d1 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/CLIP_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_float32.cpp.hip new file mode 100644 index 00000000..346efcbd --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int16.cpp.hip new file mode 100644 index 00000000..4394848f --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int32.cpp.hip new file mode 100644 index 00000000..ed51ebd8 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int8.cpp.hip new file mode 100644 index 00000000..dea82a24 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/CLIP_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_uint8.cpp.hip new file mode 100644 index 00000000..0d798b15 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/CLIP_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/COSH_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COSH_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..25c504a3 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COSH_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/COSH_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COSH_dt_float16.cpp.hip new file mode 100644 index 00000000..a5e92a13 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COSH_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/COSH_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COSH_dt_float32.cpp.hip new file mode 100644 index 00000000..2d4c2784 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COSH_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..54b03d90 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..eaa54cca --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..a000bb10 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..fdb642b6 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_float16.cpp.hip new file mode 100644 index 00000000..94e88cd6 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_float32.cpp.hip new file mode 100644 index 00000000..8e13dd53 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/HSIGMOID_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..fa85681c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_float16.cpp.hip new file mode 100644 index 00000000..4992d81d --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_float32.cpp.hip new file mode 100644 index 00000000..72e6e3a9 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LOGSIGMOID_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..d24ceb4a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..4665a277 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..023f6fe1 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..2ae7e683 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_float16.cpp.hip new file mode 100644 index 00000000..1e1253ca --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_float32.cpp.hip new file mode 100644 index 00000000..9d3b9676 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int16.cpp.hip new file mode 100644 index 00000000..03846baf --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int32.cpp.hip new file mode 100644 index 00000000..41f41670 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int8.cpp.hip new file mode 100644 index 00000000..88c2bbb6 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/PRELU_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_uint8.cpp.hip new file mode 100644 index 00000000..a9febead --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/PRELU_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..f2099c6b --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..b46f1c52 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..37b7d1ba --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..bb39a1a8 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_float16.cpp.hip new file mode 100644 index 00000000..f84eac71 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_float32.cpp.hip new file mode 100644 index 00000000..fd3fe60f --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int16.cpp.hip new file mode 100644 index 00000000..66a41225 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int32.cpp.hip new file mode 100644 index 00000000..1a5eed82 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int8.cpp.hip new file mode 100644 index 00000000..f2ecc40a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU6_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_uint8.cpp.hip new file mode 100644 index 00000000..b13aad0d --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU6_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SIGN_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..ee5373bf --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SIGN_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_float16.cpp.hip new file mode 100644 index 00000000..69277cd5 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SIGN_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_float32.cpp.hip new file mode 100644 index 00000000..718709a1 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int16.cpp.hip new file mode 100644 index 00000000..d1cdfbd5 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int32.cpp.hip new file mode 100644 index 00000000..2955aef0 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int8.cpp.hip new file mode 100644 index 00000000..e6cc0849 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SIGN_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_uint8.cpp.hip new file mode 100644 index 00000000..b39727a4 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGN_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SINH_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SINH_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..7b0d4740 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SINH_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SINH_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SINH_dt_float16.cpp.hip new file mode 100644 index 00000000..070496fb --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SINH_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SINH_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SINH_dt_float32.cpp.hip new file mode 100644 index 00000000..d295a78f --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SINH_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..2de73538 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..da7dabfd --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..f123cd7e --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..f86b264d --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_float16.cpp.hip new file mode 100644 index 00000000..991ba230 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_float32.cpp.hip new file mode 100644 index 00000000..55716f22 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SOFTPLUS_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SQRT_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQRT_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..9c4b9211 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQRT_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SQRT_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQRT_dt_float16.cpp.hip new file mode 100644 index 00000000..ad05f946 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQRT_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SQRT_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQRT_dt_float32.cpp.hip new file mode 100644 index 00000000..448f9a0c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQRT_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..aa6d5912 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_float16.cpp.hip new file mode 100644 index 00000000..1db3690d --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_float32.cpp.hip new file mode 100644 index 00000000..da4072ae --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int16.cpp.hip new file mode 100644 index 00000000..487b9824 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int32.cpp.hip new file mode 100644 index 00000000..ae617e29 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int8.cpp.hip new file mode 100644 index 00000000..9ebe2f85 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_uint8.cpp.hip new file mode 100644 index 00000000..b01411f1 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SQUARE_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/TAN_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/TAN_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..9a274f50 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/TAN_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/TAN_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/TAN_dt_float16.cpp.hip new file mode 100644 index 00000000..522e02f8 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/TAN_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/TAN_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/TAN_dt_float32.cpp.hip new file mode 100644 index 00000000..42a0c3b1 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/TAN_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/test/common/elemwise.cpp b/dnn/test/common/elemwise.cpp index 602954c6..e8134bbd 100644 --- a/dnn/test/common/elemwise.cpp +++ b/dnn/test/common/elemwise.cpp @@ -744,8 +744,8 @@ DEF_TEST(all_modes) { TensorShapeArray shapes; UniformFloatRNG default_rng_f32{-100.f, 100.f}, pos_rng_f32{.1f, 1000.f}, small_pos_rng_f32{.1f, .10f}, small_rng_f32{-3.f, 3.f}, - abslt1_rng_f32{-1.f, 1.f}, uniform_0_2_rng{0.f, 2.f}, - tanh_rng_f32{-5.f, 5.f}; + abslt1_rng_f32{-0.95f, 0.95f}, uniform_0_2_rng{0.f, 2.f}, + tanh_rng_f32{-5.f, 5.f}, lt1_rng_f32{1.f, 10.f}; UniformFloatNonZeroRNG nonzero_rng_f32{.1f, 1000.f}, big_nonzero_rng_f32{100.f, 1000.f}; UniformIntRNG default_rng_i32{-100, 100}, small_rng_i32{-2, 2}, @@ -786,12 +786,14 @@ DEF_TEST(all_modes) { shapes[shapes.size() - 1] = {}; auto do_run = [&](DType dtype, float eps = 1e-3) { // limit value ranges for some modes - if (mode == Mode::LOG || mode == Mode::LOG1P) { + if (mode == Mode::LOG || mode == Mode::LOG1P || mode == Mode::SQRT) { checker.set_rng(0, &pos_rng_f32); - } else if (mode == Mode::POW) { + } else if (mode == Mode::POW || mode == Mode::SOFTPLUS_GRAD) { checker.set_rng(0, &small_pos_rng_f32); checker.set_rng(1, &small_rng_f32); - } else if (mode == Mode::EXP || mode == Mode::EXPM1) { + } else if ( + mode == Mode::EXP || mode == Mode::EXPM1 || mode == Mode::SINH || + mode == Mode::COSH) { checker.set_rng(0, &small_rng_f32); } else if (mode == Mode::FAST_TANH) { checker.set_rng(0, &tanh_rng_f32); @@ -807,6 +809,10 @@ DEF_TEST(all_modes) { checker.set_rng(1, &default_rng_f32); } else if (mode == Mode::ERFCINV) { checker.set_rng(0, &uniform_0_2_rng); + } else if (mode == Mode::ACOSH_GRAD || mode == Mode::ACOSH) { + checker.set_rng(0, <1_rng_f32); + } else if (mode == Mode::ATANH_GRAD || mode == Mode::ATANH) { + checker.set_rng(0, &abslt1_rng_f32); } else if ( mode == Mode::MOD || mode == Mode::TRUE_DIV || mode == Mode::FLOOR_DIV) { diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index 39c8606c..b6889956 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -773,7 +773,7 @@ def sqrt(x: Tensor) -> Tensor: Tensor([1. 2. 3. 4.], device=xpux:0) """ - return pow(x, 0.5) + return _elwise(x, mode=Elemwise.Mode.SQRT) def square(x: Tensor) -> Tensor: @@ -790,16 +790,16 @@ def square(x: Tensor) -> Tensor: Examples: >>> F.square(2) - Tensor(4.0, device=xpux:0) + Tensor(4, dtype=int32, device=xpux:0) Element-wise square: >>> x = Tensor([1, -2, -3, 4]) >>> F.square(x) - Tensor([ 1. 4. 9. 16.], device=xpux:0) + Tensor([ 1 4 9 16], dtype=int32, device=xpux:0) """ - return pow(x, 2) + return _elwise(x, mode=Elemwise.Mode.SQUARE) def logaddexp(x: Tensor, y: Tensor) -> Tensor: @@ -1053,7 +1053,7 @@ def clip(x: Tensor, lower=None, upper=None) -> Tensor: ), "At least one of 'lower' or 'upper' must not be None" if lower is not None: if upper is not None: - return minimum(maximum(x, lower), upper) + return _elwise(x, lower, upper, mode=Elemwise.Mode.CLIP) else: return maximum(x, lower) else: @@ -1157,7 +1157,7 @@ def tan(x): Tensor([0. 1. 0.], device=xpux:0) """ - return sin(x) / cos(x) + return _elwise(x, mode=Elemwise.Mode.TAN) def acos(x): @@ -1373,7 +1373,7 @@ def cosh(x): Tensor([1. 1.5431 1.5431], device=xpux:0) """ - return 0.5 * (exp(x) + exp(-x)) + return _elwise(x, mode=Elemwise.Mode.COSH) def sinh(x): @@ -1416,8 +1416,7 @@ def sinh(x): Tensor([ 0. 1.1752 -1.1752], device=xpux:0) """ - u = expm1(x) - return 0.5 * u / (u + 1) * (u + 2) + return _elwise(x, mode=Elemwise.Mode.SINH) def tanh(x): @@ -1498,7 +1497,7 @@ def asinh(x): Tensor([ 0. 0.8814 -0.8814], device=xpux:0) """ - return log(x + (x ** 2 + 1) ** 0.5) + return _elwise(x, mode=Elemwise.Mode.ASINH) def acosh(x): @@ -1534,7 +1533,7 @@ def acosh(x): >>> F.acosh(x) Tensor([0. 1.317 1.7627], device=xpux:0) """ - return log(x + (x ** 2 - 1) ** 0.5) + return _elwise(x, mode=Elemwise.Mode.ACOSH) def atanh(x): @@ -1572,7 +1571,7 @@ def atanh(x): Tensor([ 0. 0.5493 -0.5493], device=xpux:0) """ - return log1p(2 * x / (1 - x)) / 2 + return _elwise(x, mode=Elemwise.Mode.ATANH) # bit-twiddling functions diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index 35a49e25..4ae82c1e 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -8,7 +8,7 @@ import numpy as np from ..core._imperative_rt.core2 import Const, apply from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder from ..core.ops import builtin -from ..core.tensor.array_method import _matmul +from ..core.tensor.array_method import _elwise, _matmul from ..core.tensor.utils import _normalize_axis from ..tensor import Tensor from ..utils.deprecation import deprecated_kwargs_default @@ -133,10 +133,7 @@ def sign(x: Tensor): >>> F.sign(x) Tensor([ 1 -1 0], dtype=int32, device=xpux:0) """ - return (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype) - - -# statistical functions + return _elwise(x, mode=builtin.Elemwise.Mode.SIGN) def sum( diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index f5742d8b..0c3fbad0 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -835,37 +835,9 @@ def sigmoid(x): return _elwise(x, mode=Elemwise.Mode.SIGMOID) -@lru_cache(maxsize=None) -def _get_hsigmoid_op(dtype=None, device=None): - @subgraph_fn( - "Hsigmoid", - dtype=dtype, - device=device, - nr_inputs=1, - jit_fusion=True, - custom_grad=True, - ) - def hsigmoid(inputs, f, c): - (inp,) = inputs[0:1] - inp = f("+", inp, c(3)) - max_0 = f("max", inp, c(0)) - min_6 = f("min", max_0, c(6)) - oup = f("/", min_6, c(6)) - (oup_grad,) = yield (oup,) - inp_grad = f("/", oup_grad, c(6)) - inp_grad = f("cond_leq_mov", max_0, c(6), inp_grad) - inp_grad = f("cond_leq_mov", c(0), inp, inp_grad) - yield (inp_grad,) - - return hsigmoid - - def hsigmoid(x): r"""Element-wise `relu6(x + 3) / 6`.""" - hsigmoid = _get_hsigmoid_op(x.dtype, x.device) - (x,) = hsigmoid(x) - return x - # return relu6(x + 3) / 6 + return _elwise(x, mode=Elemwise.Mode.HSIGMOID) def relu(x): @@ -873,95 +845,14 @@ def relu(x): return _elwise(x, mode=Elemwise.Mode.RELU) -@lru_cache(maxsize=None) -def _get_relu6_op(dtype=None, device=None): - @subgraph_fn( - "ReLU6", - dtype=dtype, - device=device, - nr_inputs=1, - jit_fusion=True, - custom_grad=True, - ) - def relu6(inputs, f, c): - (inp,) = inputs[0:1] - max_0 = f("max", inp, c(0)) - min_6 = f("min", max_0, c(6)) - oup = min_6 - (oup_grad,) = yield (oup,) - inp_grad = f("cond_leq_mov", max_0, c(6), oup_grad) - inp_grad = f("cond_leq_mov", c(0), inp, inp_grad) - yield (inp_grad,) - - return relu6 - - def relu6(x): r"""Element-wise `min(max(x, 0), 6)`.""" - relu6 = _get_relu6_op(x.dtype, x.device) - (x,) = relu6(x) - return x - - -@lru_cache(maxsize=None) -def _get_prelu_op(dtype=None, device=None): - @subgraph_fn( - "PReLU", - dtype=dtype, - device=device, - nr_inputs=2, - jit_fusion=True, - custom_grad=True, - ) - def prelu(inputs, f, c): - (inp, weight) = inputs[0:2] - max_0 = f("max", inp, c(0)) - min_0 = f("min", inp, c(0)) - oup = f("fma3", min_0, weight, max_0) - (oup_grad,) = yield (oup,) - inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad) - inp_grad_1 = f("*", oup_grad, weight) - inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1) - inp_grad = f("+", inp_grad_0, inp_grad_1) - weight_grad = f("*", oup_grad, min_0) - yield (inp_grad, weight_grad) - - return prelu - - -def prelu(inp: Tensor, weight: Tensor) -> Tensor: - r"""Element-wise PReLU function. - - Refer to :class:`~.PReLU` for more information. - """ - prelu = _get_prelu_op(dtype=inp.dtype, device=inp.device) - (oup,) = prelu(inp, broadcast_to(weight, inp.shape)) - return oup + return _elwise(x, mode=Elemwise.Mode.RELU6) -@lru_cache(maxsize=None) -def _get_leaky_relu_op(negative_slope, *, dtype=None, device=None): - @subgraph_fn( - "LeakyReLU", - dtype=dtype, - device=device, - nr_inputs=1, - jit_fusion=True, - custom_grad=True, - ) - def leakyReLU(inputs, f, c): - (inp,) = inputs[0:1] - max_0 = f("max", inp, c(0)) - min_0 = f("min", inp, c(0)) - oup = f("+", max_0, f("*", min_0, c(negative_slope))) - (oup_grad,) = yield (oup,) - inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad) - inp_grad_1 = f("*", oup_grad, c(negative_slope)) - inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1) - inp_grad = f("+", inp_grad_0, inp_grad_1) - yield (inp_grad,) - - return leakyReLU +def prelu(x, y): + r"""Element-wise `max(x, 0) + y * min(x, 0)`.""" + return _elwise(x, y, mode=Elemwise.Mode.PRELU) def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor: @@ -969,9 +860,7 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor: Refer to :class:`~.LeakyReLU` for more information. """ - leakyReLU = _get_leaky_relu_op(negative_slope, dtype=inp.dtype, device=inp.device) - (oup,) = leakyReLU(inp) - return oup + return _elwise(inp, negative_slope, mode=Elemwise.Mode.PRELU) def silu(x): @@ -990,36 +879,6 @@ def gelu(x): return _elwise(x, mode=Elemwise.Mode.GELU) -@lru_cache(maxsize=None) -def _get_softplus_op(dtype=None, device=None): - @subgraph_fn( - "Softplus", - dtype=dtype, - device=device, - nr_inputs=1, - jit_fusion=True, - custom_grad=True, - ) - def softplus(inputs, f, c): - (inp,) = inputs[0:1] - neg_abs = f("-", f("abs", inp)) - exp = f("exp", neg_abs) - oup0 = f("log1p", exp) - oup1 = f("relu", inp) - oup = f("+", oup0, oup1) - (oup_grad,) = yield (oup,) - inp_grad_0 = f("switch_gt0", oup1, oup_grad) - inp_grad_1 = oup_grad - inp_grad_1 = f("/", oup_grad, f("+", exp, c(1))) - inp_grad_1 = f("*", inp_grad_1, exp) - inp_grad_1 = f("-", inp_grad_1) - inp_grad_1 = f("abs_grad", inp, inp_grad_1) - inp_grad = f("+", inp_grad_0, inp_grad_1) - yield (inp_grad,) - - return softplus - - def softplus(inp: Tensor) -> Tensor: r"""Applies the element-wise function: @@ -1042,9 +901,7 @@ def softplus(inp: Tensor) -> Tensor: >>> y.numpy().round(decimals=4) array([0.0486, 0.1269, 0.3133, 0.6931, 1.3133, 2.1269], dtype=float32) """ - softplus = _get_softplus_op(inp.dtype, inp.device) - (oup,) = softplus(inp) - return oup + return _elwise(inp, mode=Elemwise.Mode.SOFTPLUS) def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: @@ -1073,39 +930,6 @@ def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: return inp - logsumexp(inp, axis, keepdims=True) -@lru_cache(maxsize=None) -def _get_logsigmoid_op(dtype=None, device=None): - @subgraph_fn( - "LogSigmoid", - dtype=dtype, - device=device, - nr_inputs=1, - jit_fusion=True, - custom_grad=True, - ) - def logsigmoid(inputs, f, c): - (inp,) = inputs[0:1] - neg_abs = f("-", f("abs", inp)) - exp = f("exp", neg_abs) - oup0 = f("log1p", exp) - oup1 = f("relu", f("-", inp)) - oup = f("+", oup0, oup1) - oup = f("-", oup) - (oup_grad,) = yield (oup,) - oup_grad = f("-", oup_grad) - inp_grad_0 = f("switch_gt0", oup1, oup_grad) - inp_grad_0 = f("-", inp_grad_0) - inp_grad_1 = oup_grad - inp_grad_1 = f("/", inp_grad_1, f("+", exp, c(1))) - inp_grad_1 = f("*", inp_grad_1, exp) - inp_grad_1 = f("-", inp_grad_1) - inp_grad_1 = f("abs_grad", inp, inp_grad_1) - inp_grad = f("+", inp_grad_0, inp_grad_1) - yield (inp_grad,) - - return logsigmoid - - def logsigmoid(inp: Tensor) -> Tensor: r"""Applies the element-wise function: @@ -1123,9 +947,7 @@ def logsigmoid(inp: Tensor) -> Tensor: array([-5.0067, -4.0182, -3.0486, -2.1269, -1.3133, -0.6931, -0.3133, -0.1269, -0.0486, -0.0181], dtype=float32) """ - logsigmoid = _get_logsigmoid_op(inp.dtype, inp.device) - (oup,) = logsigmoid(inp) - return oup + return _elwise(inp, mode=Elemwise.Mode.LOGSIGMOID) def logsumexp( diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 57100bd5..effd496c 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -1203,7 +1203,6 @@ def cumsum(inp: Tensor, axis: int): [ 4 9 15]], dtype=int32, device=xpux:0) """ - assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor" op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False) return apply(op, inp)[0] diff --git a/imperative/python/megengine/traced_module/compat.py b/imperative/python/megengine/traced_module/compat.py index 3941b470..757b0d4e 100644 --- a/imperative/python/megengine/traced_module/compat.py +++ b/imperative/python/megengine/traced_module/compat.py @@ -207,3 +207,20 @@ def quantized_convbn2d_module_loader(expr): module = expr.inputs[0].owner if not hasattr(module, "padding_mode"): module.padding_mode = "zeros" + + +@register_functional_loader(("megengine.functional.elemwise", "square")) +def square_func_loader(expr): + import pkg_resources as pkg + + if not hasattr(expr, "version") or pkg.parse_version( + expr.version + ) <= pkg.parse_version("1.11.1"): + if expr.inputs[0].dtype != np.float32: + orig_oup = expr.outputs[0] + oup = TensorNode(expr, shape=orig_oup.shape, dtype=expr.inputs[0].dtype,) + expr.return_val = (oup,) + astype_expr = CallMethod(oup, "astype") + astype_expr.set_args_kwargs(oup, "float32") + orig_oup.expr = astype_expr + astype_expr.return_val = (orig_oup,) diff --git a/imperative/python/test/unit/traced_module/test_serialization.py b/imperative/python/test/unit/traced_module/test_serialization.py index 8b7363f3..23e21153 100644 --- a/imperative/python/test/unit/traced_module/test_serialization.py +++ b/imperative/python/test/unit/traced_module/test_serialization.py @@ -1,11 +1,13 @@ import pickle from collections import defaultdict +from functools import wraps from tempfile import TemporaryFile import numpy as np import megengine.functional as F import megengine.module as M +import megengine.traced_module.expr as Expr import megengine.traced_module.serialization as S from megengine import Tensor from megengine.core._imperative_rt.core2 import apply @@ -357,3 +359,35 @@ def test_opdef_serialization(): load_x = pickle.load(f) assert x.strategy == load_x.strategy assert x == load_x + + +def test_square_function_compat(): + @wraps(F.elemwise.square) + def origin_square(x): + return F.pow(x, 2) + + new_square = F.elemwise.square + F.elemwise.square = origin_square + current_version = Expr.__version__ + Expr.__version__ = "1.11.1" + + class old_square(M.Module): + def forward(self, x): + x = F.relu(x) + x = F.elemwise.square(x) + return x * 2 + + m = trace_module(old_square(), Tensor([1, 2, 4, 6])) + float_m = trace_module(old_square(), Tensor([1.0, 2.0, 4.0, 6.0])) + + # dump old version square + obj = pickle.dumps(m) + f_obj = pickle.dumps(float_m) + + # load in new version + F.elemwise.square = new_square + Expr.__version__ = current_version + new_m = pickle.loads(obj) + new_float_m = pickle.loads(f_obj) + assert len(new_m.graph._exprs) == 4 and len(new_float_m.graph._exprs) == 3 + assert new_m(Tensor([1, 2, 4, 6])).dtype == np.float32 diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index b049c432..4c4e6851 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -663,7 +663,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { OP_TRAIT_REG(Cumsum, Cumsum).apply_on_var_node(apply_on_var_node).fallback(); } // namespace cumsum } // namespace - namespace lrn { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); diff --git a/imperative/src/impl/transformations/dtype_promote.cpp b/imperative/src/impl/transformations/dtype_promote.cpp index 58de880a..5e7743d7 100644 --- a/imperative/src/impl/transformations/dtype_promote.cpp +++ b/imperative/src/impl/transformations/dtype_promote.cpp @@ -116,12 +116,17 @@ ValueRefList elemwise_rule(const OpDef& op, Span inputs) { } static std::unordered_set cast_case1 = { - Elemwise::Mode::TRUE_DIV, Elemwise::Mode::EXP, - Elemwise::Mode::POW, Elemwise::Mode::LOG, - Elemwise::Mode::EXPM1, Elemwise::Mode::LOG1P, - Elemwise::Mode::ACOS, Elemwise::Mode::ASIN, - Elemwise::Mode::ATAN2, Elemwise::Mode::COS, - Elemwise::Mode::SIN, Elemwise::Mode::LOG_SUM_EXP, + Elemwise::Mode::TRUE_DIV, Elemwise::Mode::EXP, + Elemwise::Mode::POW, Elemwise::Mode::LOG, + Elemwise::Mode::EXPM1, Elemwise::Mode::LOG1P, + Elemwise::Mode::ACOS, Elemwise::Mode::ASIN, + Elemwise::Mode::ATAN2, Elemwise::Mode::COS, + Elemwise::Mode::SIN, Elemwise::Mode::LOG_SUM_EXP, + Elemwise::Mode::TAN, Elemwise::Mode::ASINH, + Elemwise::Mode::ACOSH, Elemwise::Mode::ATANH, + Elemwise::Mode::SINH, Elemwise::Mode::COSH, + Elemwise::Mode::SOFTPLUS, Elemwise::Mode::HSIGMOID, + Elemwise::Mode::LOGSIGMOID, Elemwise::Mode::SQRT, }; static std::unordered_set cast_case2 = { diff --git a/imperative/tablegen/generated/hash.txt b/imperative/tablegen/generated/hash.txt index 104ec78f..511a1e43 100644 --- a/imperative/tablegen/generated/hash.txt +++ b/imperative/tablegen/generated/hash.txt @@ -1,7 +1,7 @@ -e38b68be4e2aaf3de2f22e3dddbeaac4 ../../dnn/scripts/opr_param_defs.py +8dd504f360fd3d3bfb560c970b568153 ../../dnn/scripts/opr_param_defs.py cf864561de125ab559c0035158656682 ../../src/core/include/megbrain/ir/ops.td -9248d42a9b3e770693306992156f6015 generated/opdef.h.inl -5c7e7ac49d1338d70ac84ba309e6732b generated/opdef.cpp.inl -30b669eec36876a65717e0c68dd76c83 generated/opdef.py.inl -4312de292a3d71f34a084bf43ea2ecec generated/opdef.cpy.inl +f27cdbb7926e0be9f5dabb8651d2e4da generated/opdef.h.inl +96817f709ee92c8e1eb7cb4168f28565 generated/opdef.cpp.inl +672668fa3ed11c27781f0fa380e6c8aa generated/opdef.py.inl +47511e3e7fed8c64a1c4fea48d79b3d1 generated/opdef.cpy.inl 71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h diff --git a/imperative/tablegen/generated/opdef.cpp.inl b/imperative/tablegen/generated/opdef.cpp.inl index b9c5878d..6ec6e918 100644 --- a/imperative/tablegen/generated/opdef.cpp.inl +++ b/imperative/tablegen/generated/opdef.cpp.inl @@ -3076,6 +3076,72 @@ std::vector> Elemwise_props_impl(const OpDef case Elemwise::Mode::ISINF: props_.emplace_back("mode", "ISINF"); break; + case Elemwise::Mode::SINH: + props_.emplace_back("mode", "SINH"); + break; + case Elemwise::Mode::COSH: + props_.emplace_back("mode", "COSH"); + break; + case Elemwise::Mode::ASINH: + props_.emplace_back("mode", "ASINH"); + break; + case Elemwise::Mode::ACOSH: + props_.emplace_back("mode", "ACOSH"); + break; + case Elemwise::Mode::ATANH: + props_.emplace_back("mode", "ATANH"); + break; + case Elemwise::Mode::TAN: + props_.emplace_back("mode", "TAN"); + break; + case Elemwise::Mode::ASINH_GRAD: + props_.emplace_back("mode", "ASINH_GRAD"); + break; + case Elemwise::Mode::ACOSH_GRAD: + props_.emplace_back("mode", "ACOSH_GRAD"); + break; + case Elemwise::Mode::ATANH_GRAD: + props_.emplace_back("mode", "ATANH_GRAD"); + break; + case Elemwise::Mode::PRELU: + props_.emplace_back("mode", "PRELU"); + break; + case Elemwise::Mode::CLIP: + props_.emplace_back("mode", "CLIP"); + break; + case Elemwise::Mode::PRELU_GRAD: + props_.emplace_back("mode", "PRELU_GRAD"); + break; + case Elemwise::Mode::SOFTPLUS: + props_.emplace_back("mode", "SOFTPLUS"); + break; + case Elemwise::Mode::SOFTPLUS_GRAD: + props_.emplace_back("mode", "SOFTPLUS_GRAD"); + break; + case Elemwise::Mode::RELU6: + props_.emplace_back("mode", "RELU6"); + break; + case Elemwise::Mode::RELU6_GRAD: + props_.emplace_back("mode", "RELU6_GRAD"); + break; + case Elemwise::Mode::HSIGMOID: + props_.emplace_back("mode", "HSIGMOID"); + break; + case Elemwise::Mode::HSIGMOID_GRAD: + props_.emplace_back("mode", "HSIGMOID_GRAD"); + break; + case Elemwise::Mode::LOGSIGMOID: + props_.emplace_back("mode", "LOGSIGMOID"); + break; + case Elemwise::Mode::SQRT: + props_.emplace_back("mode", "SQRT"); + break; + case Elemwise::Mode::SQUARE: + props_.emplace_back("mode", "SQUARE"); + break; + case Elemwise::Mode::SIGN: + props_.emplace_back("mode", "SIGN"); + break; default: props_.emplace_back("mode", "INVALID"); break; diff --git a/imperative/tablegen/generated/opdef.cpy.inl b/imperative/tablegen/generated/opdef.cpy.inl index 3ed1bc7d..42b83709 100644 --- a/imperative/tablegen/generated/opdef.cpy.inl +++ b/imperative/tablegen/generated/opdef.cpy.inl @@ -8770,16 +8770,16 @@ void _init_py_Dropout(py::module m) { template<> struct EnumTrait { static constexpr const char *name = "Elemwise.Mode"; - static constexpr std::underlying_type_t max = 64 - 1; + static constexpr std::underlying_type_t max = 86 - 1; }; template<> PyTypeObject* EnumWrapper::type = nullptr; template<> const char* -EnumWrapper::members[] = {"RELU", "ABS", "ACOS", "ASIN", "CEIL", "COS", "EXP", "EXPM1", "FLOOR", "LOG", "LOG1P", "NEGATE", "SIGMOID", "SIN", "TANH", "ABS_GRAD", "ADD", "FLOOR_DIV", "MAX", "MIN", "MOD", "MUL", "POW", "SIGMOID_GRAD", "SUB", "SWITCH_GT0", "TANH_GRAD", "TRUE_DIV", "LOG_SUM_EXP", "LT", "LEQ", "EQ", "SHL", "SHR", "COND_LEQ_MOV", "FUSE_MUL_ADD3", "FUSE_MUL_ADD4", "FUSE_ADD_RELU", "FUSE_ADD_SIGMOID", "FUSE_ADD_TANH", "FAST_TANH", "FAST_TANH_GRAD", "ROUND", "RMULH", "ATAN2", "ERF", "ERFINV", "ERFC", "ERFCINV", "H_SWISH", "H_SWISH_GRAD", "FUSE_ADD_H_SWISH", "NOT", "AND", "OR", "XOR", "SILU", "SILU_GRAD", "GELU", "GELU_GRAD", "COND_LT_MOV", "NEQ", "ISNAN", "ISINF"}; +EnumWrapper::members[] = {"RELU", "ABS", "ACOS", "ASIN", "CEIL", "COS", "EXP", "EXPM1", "FLOOR", "LOG", "LOG1P", "NEGATE", "SIGMOID", "SIN", "TANH", "ABS_GRAD", "ADD", "FLOOR_DIV", "MAX", "MIN", "MOD", "MUL", "POW", "SIGMOID_GRAD", "SUB", "SWITCH_GT0", "TANH_GRAD", "TRUE_DIV", "LOG_SUM_EXP", "LT", "LEQ", "EQ", "SHL", "SHR", "COND_LEQ_MOV", "FUSE_MUL_ADD3", "FUSE_MUL_ADD4", "FUSE_ADD_RELU", "FUSE_ADD_SIGMOID", "FUSE_ADD_TANH", "FAST_TANH", "FAST_TANH_GRAD", "ROUND", "RMULH", "ATAN2", "ERF", "ERFINV", "ERFC", "ERFCINV", "H_SWISH", "H_SWISH_GRAD", "FUSE_ADD_H_SWISH", "NOT", "AND", "OR", "XOR", "SILU", "SILU_GRAD", "GELU", "GELU_GRAD", "COND_LT_MOV", "NEQ", "ISNAN", "ISINF", "SINH", "COSH", "ASINH", "ACOSH", "ATANH", "TAN", "ASINH_GRAD", "ACOSH_GRAD", "ATANH_GRAD", "PRELU", "CLIP", "PRELU_GRAD", "SOFTPLUS", "SOFTPLUS_GRAD", "RELU6", "RELU6_GRAD", "HSIGMOID", "HSIGMOID_GRAD", "LOGSIGMOID", "SQRT", "SQUARE", "SIGN"}; template<> std::unordered_map -EnumWrapper::mem2value = {{normalize_enum("RELU"), Elemwise::Mode::RELU}, {normalize_enum("ABS"), Elemwise::Mode::ABS}, {normalize_enum("ACOS"), Elemwise::Mode::ACOS}, {normalize_enum("ASIN"), Elemwise::Mode::ASIN}, {normalize_enum("CEIL"), Elemwise::Mode::CEIL}, {normalize_enum("COS"), Elemwise::Mode::COS}, {normalize_enum("EXP"), Elemwise::Mode::EXP}, {normalize_enum("EXPM1"), Elemwise::Mode::EXPM1}, {normalize_enum("FLOOR"), Elemwise::Mode::FLOOR}, {normalize_enum("LOG"), Elemwise::Mode::LOG}, {normalize_enum("LOG1P"), Elemwise::Mode::LOG1P}, {normalize_enum("NEGATE"), Elemwise::Mode::NEGATE}, {normalize_enum("SIGMOID"), Elemwise::Mode::SIGMOID}, {normalize_enum("SIN"), Elemwise::Mode::SIN}, {normalize_enum("TANH"), Elemwise::Mode::TANH}, {normalize_enum("ABS_GRAD"), Elemwise::Mode::ABS_GRAD}, {normalize_enum("ADD"), Elemwise::Mode::ADD}, {normalize_enum("FLOOR_DIV"), Elemwise::Mode::FLOOR_DIV}, {normalize_enum("MAX"), Elemwise::Mode::MAX}, {normalize_enum("MIN"), Elemwise::Mode::MIN}, {normalize_enum("MOD"), Elemwise::Mode::MOD}, {normalize_enum("MUL"), Elemwise::Mode::MUL}, {normalize_enum("POW"), Elemwise::Mode::POW}, {normalize_enum("SIGMOID_GRAD"), Elemwise::Mode::SIGMOID_GRAD}, {normalize_enum("SUB"), Elemwise::Mode::SUB}, {normalize_enum("SWITCH_GT0"), Elemwise::Mode::SWITCH_GT0}, {normalize_enum("TANH_GRAD"), Elemwise::Mode::TANH_GRAD}, {normalize_enum("TRUE_DIV"), Elemwise::Mode::TRUE_DIV}, {normalize_enum("LOG_SUM_EXP"), Elemwise::Mode::LOG_SUM_EXP}, {normalize_enum("LT"), Elemwise::Mode::LT}, {normalize_enum("LEQ"), Elemwise::Mode::LEQ}, {normalize_enum("EQ"), Elemwise::Mode::EQ}, {normalize_enum("SHL"), Elemwise::Mode::SHL}, {normalize_enum("SHR"), Elemwise::Mode::SHR}, {normalize_enum("COND_LEQ_MOV"), Elemwise::Mode::COND_LEQ_MOV}, {normalize_enum("FUSE_MUL_ADD3"), Elemwise::Mode::FUSE_MUL_ADD3}, {normalize_enum("FUSE_MUL_ADD4"), Elemwise::Mode::FUSE_MUL_ADD4}, {normalize_enum("FUSE_ADD_RELU"), Elemwise::Mode::FUSE_ADD_RELU}, {normalize_enum("FUSE_ADD_SIGMOID"), Elemwise::Mode::FUSE_ADD_SIGMOID}, {normalize_enum("FUSE_ADD_TANH"), Elemwise::Mode::FUSE_ADD_TANH}, {normalize_enum("FAST_TANH"), Elemwise::Mode::FAST_TANH}, {normalize_enum("FAST_TANH_GRAD"), Elemwise::Mode::FAST_TANH_GRAD}, {normalize_enum("ROUND"), Elemwise::Mode::ROUND}, {normalize_enum("RMULH"), Elemwise::Mode::RMULH}, {normalize_enum("ATAN2"), Elemwise::Mode::ATAN2}, {normalize_enum("ERF"), Elemwise::Mode::ERF}, {normalize_enum("ERFINV"), Elemwise::Mode::ERFINV}, {normalize_enum("ERFC"), Elemwise::Mode::ERFC}, {normalize_enum("ERFCINV"), Elemwise::Mode::ERFCINV}, {normalize_enum("H_SWISH"), Elemwise::Mode::H_SWISH}, {normalize_enum("H_SWISH_GRAD"), Elemwise::Mode::H_SWISH_GRAD}, {normalize_enum("FUSE_ADD_H_SWISH"), Elemwise::Mode::FUSE_ADD_H_SWISH}, {normalize_enum("NOT"), Elemwise::Mode::NOT}, {normalize_enum("AND"), Elemwise::Mode::AND}, {normalize_enum("OR"), Elemwise::Mode::OR}, {normalize_enum("XOR"), Elemwise::Mode::XOR}, {normalize_enum("SILU"), Elemwise::Mode::SILU}, {normalize_enum("SILU_GRAD"), Elemwise::Mode::SILU_GRAD}, {normalize_enum("GELU"), Elemwise::Mode::GELU}, {normalize_enum("GELU_GRAD"), Elemwise::Mode::GELU_GRAD}, {normalize_enum("COND_LT_MOV"), Elemwise::Mode::COND_LT_MOV}, {normalize_enum("NEQ"), Elemwise::Mode::NEQ}, {normalize_enum("ISNAN"), Elemwise::Mode::ISNAN}, {normalize_enum("ISINF"), Elemwise::Mode::ISINF}}; -template<> PyObject* EnumWrapper::pyobj_insts[64] = {nullptr}; +EnumWrapper::mem2value = {{normalize_enum("RELU"), Elemwise::Mode::RELU}, {normalize_enum("ABS"), Elemwise::Mode::ABS}, {normalize_enum("ACOS"), Elemwise::Mode::ACOS}, {normalize_enum("ASIN"), Elemwise::Mode::ASIN}, {normalize_enum("CEIL"), Elemwise::Mode::CEIL}, {normalize_enum("COS"), Elemwise::Mode::COS}, {normalize_enum("EXP"), Elemwise::Mode::EXP}, {normalize_enum("EXPM1"), Elemwise::Mode::EXPM1}, {normalize_enum("FLOOR"), Elemwise::Mode::FLOOR}, {normalize_enum("LOG"), Elemwise::Mode::LOG}, {normalize_enum("LOG1P"), Elemwise::Mode::LOG1P}, {normalize_enum("NEGATE"), Elemwise::Mode::NEGATE}, {normalize_enum("SIGMOID"), Elemwise::Mode::SIGMOID}, {normalize_enum("SIN"), Elemwise::Mode::SIN}, {normalize_enum("TANH"), Elemwise::Mode::TANH}, {normalize_enum("ABS_GRAD"), Elemwise::Mode::ABS_GRAD}, {normalize_enum("ADD"), Elemwise::Mode::ADD}, {normalize_enum("FLOOR_DIV"), Elemwise::Mode::FLOOR_DIV}, {normalize_enum("MAX"), Elemwise::Mode::MAX}, {normalize_enum("MIN"), Elemwise::Mode::MIN}, {normalize_enum("MOD"), Elemwise::Mode::MOD}, {normalize_enum("MUL"), Elemwise::Mode::MUL}, {normalize_enum("POW"), Elemwise::Mode::POW}, {normalize_enum("SIGMOID_GRAD"), Elemwise::Mode::SIGMOID_GRAD}, {normalize_enum("SUB"), Elemwise::Mode::SUB}, {normalize_enum("SWITCH_GT0"), Elemwise::Mode::SWITCH_GT0}, {normalize_enum("TANH_GRAD"), Elemwise::Mode::TANH_GRAD}, {normalize_enum("TRUE_DIV"), Elemwise::Mode::TRUE_DIV}, {normalize_enum("LOG_SUM_EXP"), Elemwise::Mode::LOG_SUM_EXP}, {normalize_enum("LT"), Elemwise::Mode::LT}, {normalize_enum("LEQ"), Elemwise::Mode::LEQ}, {normalize_enum("EQ"), Elemwise::Mode::EQ}, {normalize_enum("SHL"), Elemwise::Mode::SHL}, {normalize_enum("SHR"), Elemwise::Mode::SHR}, {normalize_enum("COND_LEQ_MOV"), Elemwise::Mode::COND_LEQ_MOV}, {normalize_enum("FUSE_MUL_ADD3"), Elemwise::Mode::FUSE_MUL_ADD3}, {normalize_enum("FUSE_MUL_ADD4"), Elemwise::Mode::FUSE_MUL_ADD4}, {normalize_enum("FUSE_ADD_RELU"), Elemwise::Mode::FUSE_ADD_RELU}, {normalize_enum("FUSE_ADD_SIGMOID"), Elemwise::Mode::FUSE_ADD_SIGMOID}, {normalize_enum("FUSE_ADD_TANH"), Elemwise::Mode::FUSE_ADD_TANH}, {normalize_enum("FAST_TANH"), Elemwise::Mode::FAST_TANH}, {normalize_enum("FAST_TANH_GRAD"), Elemwise::Mode::FAST_TANH_GRAD}, {normalize_enum("ROUND"), Elemwise::Mode::ROUND}, {normalize_enum("RMULH"), Elemwise::Mode::RMULH}, {normalize_enum("ATAN2"), Elemwise::Mode::ATAN2}, {normalize_enum("ERF"), Elemwise::Mode::ERF}, {normalize_enum("ERFINV"), Elemwise::Mode::ERFINV}, {normalize_enum("ERFC"), Elemwise::Mode::ERFC}, {normalize_enum("ERFCINV"), Elemwise::Mode::ERFCINV}, {normalize_enum("H_SWISH"), Elemwise::Mode::H_SWISH}, {normalize_enum("H_SWISH_GRAD"), Elemwise::Mode::H_SWISH_GRAD}, {normalize_enum("FUSE_ADD_H_SWISH"), Elemwise::Mode::FUSE_ADD_H_SWISH}, {normalize_enum("NOT"), Elemwise::Mode::NOT}, {normalize_enum("AND"), Elemwise::Mode::AND}, {normalize_enum("OR"), Elemwise::Mode::OR}, {normalize_enum("XOR"), Elemwise::Mode::XOR}, {normalize_enum("SILU"), Elemwise::Mode::SILU}, {normalize_enum("SILU_GRAD"), Elemwise::Mode::SILU_GRAD}, {normalize_enum("GELU"), Elemwise::Mode::GELU}, {normalize_enum("GELU_GRAD"), Elemwise::Mode::GELU_GRAD}, {normalize_enum("COND_LT_MOV"), Elemwise::Mode::COND_LT_MOV}, {normalize_enum("NEQ"), Elemwise::Mode::NEQ}, {normalize_enum("ISNAN"), Elemwise::Mode::ISNAN}, {normalize_enum("ISINF"), Elemwise::Mode::ISINF}, {normalize_enum("SINH"), Elemwise::Mode::SINH}, {normalize_enum("COSH"), Elemwise::Mode::COSH}, {normalize_enum("ASINH"), Elemwise::Mode::ASINH}, {normalize_enum("ACOSH"), Elemwise::Mode::ACOSH}, {normalize_enum("ATANH"), Elemwise::Mode::ATANH}, {normalize_enum("TAN"), Elemwise::Mode::TAN}, {normalize_enum("ASINH_GRAD"), Elemwise::Mode::ASINH_GRAD}, {normalize_enum("ACOSH_GRAD"), Elemwise::Mode::ACOSH_GRAD}, {normalize_enum("ATANH_GRAD"), Elemwise::Mode::ATANH_GRAD}, {normalize_enum("PRELU"), Elemwise::Mode::PRELU}, {normalize_enum("CLIP"), Elemwise::Mode::CLIP}, {normalize_enum("PRELU_GRAD"), Elemwise::Mode::PRELU_GRAD}, {normalize_enum("SOFTPLUS"), Elemwise::Mode::SOFTPLUS}, {normalize_enum("SOFTPLUS_GRAD"), Elemwise::Mode::SOFTPLUS_GRAD}, {normalize_enum("RELU6"), Elemwise::Mode::RELU6}, {normalize_enum("RELU6_GRAD"), Elemwise::Mode::RELU6_GRAD}, {normalize_enum("HSIGMOID"), Elemwise::Mode::HSIGMOID}, {normalize_enum("HSIGMOID_GRAD"), Elemwise::Mode::HSIGMOID_GRAD}, {normalize_enum("LOGSIGMOID"), Elemwise::Mode::LOGSIGMOID}, {normalize_enum("SQRT"), Elemwise::Mode::SQRT}, {normalize_enum("SQUARE"), Elemwise::Mode::SQUARE}, {normalize_enum("SIGN"), Elemwise::Mode::SIGN}}; +template<> PyObject* EnumWrapper::pyobj_insts[86] = {nullptr}; void _init_py_Elemwise_Mode(PyTypeObject& py_type) { auto& e_type = EnumWrapper::type; @@ -9147,6 +9147,116 @@ void _init_py_Elemwise_Mode(PyTypeObject& py_type) { reinterpret_cast*>(inst)->value = Elemwise::Mode::ISINF; mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ISINF", inst) >= 0); EnumWrapper::pyobj_insts[63] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::SINH; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SINH", inst) >= 0); + EnumWrapper::pyobj_insts[64] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::COSH; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "COSH", inst) >= 0); + EnumWrapper::pyobj_insts[65] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::ASINH; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ASINH", inst) >= 0); + EnumWrapper::pyobj_insts[66] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::ACOSH; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ACOSH", inst) >= 0); + EnumWrapper::pyobj_insts[67] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::ATANH; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ATANH", inst) >= 0); + EnumWrapper::pyobj_insts[68] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::TAN; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "TAN", inst) >= 0); + EnumWrapper::pyobj_insts[69] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::ASINH_GRAD; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ASINH_GRAD", inst) >= 0); + EnumWrapper::pyobj_insts[70] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::ACOSH_GRAD; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ACOSH_GRAD", inst) >= 0); + EnumWrapper::pyobj_insts[71] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::ATANH_GRAD; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ATANH_GRAD", inst) >= 0); + EnumWrapper::pyobj_insts[72] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::PRELU; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "PRELU", inst) >= 0); + EnumWrapper::pyobj_insts[73] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::CLIP; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "CLIP", inst) >= 0); + EnumWrapper::pyobj_insts[74] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::PRELU_GRAD; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "PRELU_GRAD", inst) >= 0); + EnumWrapper::pyobj_insts[75] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::SOFTPLUS; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SOFTPLUS", inst) >= 0); + EnumWrapper::pyobj_insts[76] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::SOFTPLUS_GRAD; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SOFTPLUS_GRAD", inst) >= 0); + EnumWrapper::pyobj_insts[77] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::RELU6; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "RELU6", inst) >= 0); + EnumWrapper::pyobj_insts[78] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::RELU6_GRAD; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "RELU6_GRAD", inst) >= 0); + EnumWrapper::pyobj_insts[79] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::HSIGMOID; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "HSIGMOID", inst) >= 0); + EnumWrapper::pyobj_insts[80] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::HSIGMOID_GRAD; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "HSIGMOID_GRAD", inst) >= 0); + EnumWrapper::pyobj_insts[81] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::LOGSIGMOID; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "LOGSIGMOID", inst) >= 0); + EnumWrapper::pyobj_insts[82] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::SQRT; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SQRT", inst) >= 0); + EnumWrapper::pyobj_insts[83] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::SQUARE; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SQUARE", inst) >= 0); + EnumWrapper::pyobj_insts[84] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::SIGN; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SIGN", inst) >= 0); + EnumWrapper::pyobj_insts[85] = inst; } Py_INCREF(e_type); mgb_assert(PyDict_SetItemString( diff --git a/imperative/tablegen/generated/opdef.h.inl b/imperative/tablegen/generated/opdef.h.inl index b674b127..88fe7c61 100644 --- a/imperative/tablegen/generated/opdef.h.inl +++ b/imperative/tablegen/generated/opdef.h.inl @@ -783,6 +783,28 @@ case Elemwise::Mode::COND_LT_MOV: return "COND_LT_MOV"; case Elemwise::Mode::NEQ: return "NEQ"; case Elemwise::Mode::ISNAN: return "ISNAN"; case Elemwise::Mode::ISINF: return "ISINF"; +case Elemwise::Mode::SINH: return "SINH"; +case Elemwise::Mode::COSH: return "COSH"; +case Elemwise::Mode::ASINH: return "ASINH"; +case Elemwise::Mode::ACOSH: return "ACOSH"; +case Elemwise::Mode::ATANH: return "ATANH"; +case Elemwise::Mode::TAN: return "TAN"; +case Elemwise::Mode::ASINH_GRAD: return "ASINH_GRAD"; +case Elemwise::Mode::ACOSH_GRAD: return "ACOSH_GRAD"; +case Elemwise::Mode::ATANH_GRAD: return "ATANH_GRAD"; +case Elemwise::Mode::PRELU: return "PRELU"; +case Elemwise::Mode::CLIP: return "CLIP"; +case Elemwise::Mode::PRELU_GRAD: return "PRELU_GRAD"; +case Elemwise::Mode::SOFTPLUS: return "SOFTPLUS"; +case Elemwise::Mode::SOFTPLUS_GRAD: return "SOFTPLUS_GRAD"; +case Elemwise::Mode::RELU6: return "RELU6"; +case Elemwise::Mode::RELU6_GRAD: return "RELU6_GRAD"; +case Elemwise::Mode::HSIGMOID: return "HSIGMOID"; +case Elemwise::Mode::HSIGMOID_GRAD: return "HSIGMOID_GRAD"; +case Elemwise::Mode::LOGSIGMOID: return "LOGSIGMOID"; +case Elemwise::Mode::SQRT: return "SQRT"; +case Elemwise::Mode::SQUARE: return "SQUARE"; +case Elemwise::Mode::SIGN: return "SIGN"; default: return "Elemwise::Mode::Unknown"; } diff --git a/imperative/tablegen/generated/opdef.py.inl b/imperative/tablegen/generated/opdef.py.inl index 1788ff15..9857bf36 100644 --- a/imperative/tablegen/generated/opdef.py.inl +++ b/imperative/tablegen/generated/opdef.py.inl @@ -896,6 +896,28 @@ py::enum_(ElemwiseInst, "Mode") .value("NEQ", Elemwise::Mode::NEQ) .value("ISNAN", Elemwise::Mode::ISNAN) .value("ISINF", Elemwise::Mode::ISINF) + .value("SINH", Elemwise::Mode::SINH) + .value("COSH", Elemwise::Mode::COSH) + .value("ASINH", Elemwise::Mode::ASINH) + .value("ACOSH", Elemwise::Mode::ACOSH) + .value("ATANH", Elemwise::Mode::ATANH) + .value("TAN", Elemwise::Mode::TAN) + .value("ASINH_GRAD", Elemwise::Mode::ASINH_GRAD) + .value("ACOSH_GRAD", Elemwise::Mode::ACOSH_GRAD) + .value("ATANH_GRAD", Elemwise::Mode::ATANH_GRAD) + .value("PRELU", Elemwise::Mode::PRELU) + .value("CLIP", Elemwise::Mode::CLIP) + .value("PRELU_GRAD", Elemwise::Mode::PRELU_GRAD) + .value("SOFTPLUS", Elemwise::Mode::SOFTPLUS) + .value("SOFTPLUS_GRAD", Elemwise::Mode::SOFTPLUS_GRAD) + .value("RELU6", Elemwise::Mode::RELU6) + .value("RELU6_GRAD", Elemwise::Mode::RELU6_GRAD) + .value("HSIGMOID", Elemwise::Mode::HSIGMOID) + .value("HSIGMOID_GRAD", Elemwise::Mode::HSIGMOID_GRAD) + .value("LOGSIGMOID", Elemwise::Mode::LOGSIGMOID) + .value("SQRT", Elemwise::Mode::SQRT) + .value("SQUARE", Elemwise::Mode::SQUARE) + .value("SIGN", Elemwise::Mode::SIGN) .def(py::init([](const std::string& in) { auto&& str = normalize_enum(in); if (str == "RELU") return Elemwise::Mode::RELU; @@ -962,6 +984,28 @@ py::enum_(ElemwiseInst, "Mode") if (str == "NEQ") return Elemwise::Mode::NEQ; if (str == "ISNAN") return Elemwise::Mode::ISNAN; if (str == "ISINF") return Elemwise::Mode::ISINF; + if (str == "SINH") return Elemwise::Mode::SINH; + if (str == "COSH") return Elemwise::Mode::COSH; + if (str == "ASINH") return Elemwise::Mode::ASINH; + if (str == "ACOSH") return Elemwise::Mode::ACOSH; + if (str == "ATANH") return Elemwise::Mode::ATANH; + if (str == "TAN") return Elemwise::Mode::TAN; + if (str == "ASINH_GRAD") return Elemwise::Mode::ASINH_GRAD; + if (str == "ACOSH_GRAD") return Elemwise::Mode::ACOSH_GRAD; + if (str == "ATANH_GRAD") return Elemwise::Mode::ATANH_GRAD; + if (str == "PRELU") return Elemwise::Mode::PRELU; + if (str == "CLIP") return Elemwise::Mode::CLIP; + if (str == "PRELU_GRAD") return Elemwise::Mode::PRELU_GRAD; + if (str == "SOFTPLUS") return Elemwise::Mode::SOFTPLUS; + if (str == "SOFTPLUS_GRAD") return Elemwise::Mode::SOFTPLUS_GRAD; + if (str == "RELU6") return Elemwise::Mode::RELU6; + if (str == "RELU6_GRAD") return Elemwise::Mode::RELU6_GRAD; + if (str == "HSIGMOID") return Elemwise::Mode::HSIGMOID; + if (str == "HSIGMOID_GRAD") return Elemwise::Mode::HSIGMOID_GRAD; + if (str == "LOGSIGMOID") return Elemwise::Mode::LOGSIGMOID; + if (str == "SQRT") return Elemwise::Mode::SQRT; + if (str == "SQUARE") return Elemwise::Mode::SQUARE; + if (str == "SIGN") return Elemwise::Mode::SIGN; throw py::cast_error("invalid enum value " + in); })); py::implicitly_convertible(); diff --git a/src/gopt/test/basic_arith.cpp b/src/gopt/test/basic_arith.cpp index c61ce9eb..483cc7af 100644 --- a/src/gopt/test/basic_arith.cpp +++ b/src/gopt/test/basic_arith.cpp @@ -107,6 +107,12 @@ TEST(TestGoptBasicArithInplace, Absorbing) { ASSERT_EQ(y.as_immutable_scalar()->get_cast(), 0.f); } +auto gen_postive = [](HostTensorND& dest) { + HostTensorGenerator mask_generator{ + 2.f, 4.f}; + dest = *mask_generator(dest.shape(), dest.comp_node()); +}; + TEST(TestGoptBasicArithInplace, LogExpExpand) { // test log(exp(a) * (exp(b) / (exp(c) * d**2))) -> a + b - c - log(d**2) @@ -144,9 +150,13 @@ TEST(TestGoptBasicArithInplace, LogExpExpand) { opt.numdiff_eps_single_inp[3] = 1e-3; opt.numdiff_max_err_single_inp[3] = 1e-2; Checker{make_graph, fwd} - .run(ms({2, 3}, {2, 3}), opt) - .run(ms({1, 3}, {2, 3}), opt) - .run(ms({3, 2}, {1}), opt); + .set_input_generator(0, gen_postive) + .set_input_generator(1, gen_postive) + .set_input_generator(2, gen_postive) + .set_input_generator(3, gen_postive) + .run(ms({32, 1}, {32, 1}), opt) + .run(ms({2, 32}, {2, 32}), opt) + .run(ms({1, 32}, {1, 32}), opt); } TEST(TestGoptBasicArithInplace, LogSumExp) { diff --git a/src/jit/impl/ast_c.cpp b/src/jit/impl/ast_c.cpp index edce0007..eb1dc794 100644 --- a/src/jit/impl/ast_c.cpp +++ b/src/jit/impl/ast_c.cpp @@ -133,7 +133,7 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() { 0.f}) / 6.f), }; - mgb_assert(map.size() + 19 == opr::Elemwise::Param::MODE_NR_MEMBER); + mgb_assert(map.size() + 41 == opr::Elemwise::Param::MODE_NR_MEMBER); // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH, // ERFINV, ERFCINV, NOT, AND, OR, XOR, NEQ, ISNAN, ISINF return map; diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index 8111b4e0..9924f747 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -543,6 +543,34 @@ MGB_IMPL_OPR_GRAD(Elemwise) { RET(EL2(SILU_GRAD, i0, og)); case Mode::GELU: RET(EL2(GELU_GRAD, i0, og)); + case Mode::SINH: + RET(EL1(COSH, i0) * og); + case Mode::COSH: + RET(EL1(SINH, i0) * og); + case Mode::ASINH: + RET(EL2(ASINH_GRAD, i0, og)); + case Mode::ACOSH: + RET(EL2(ACOSH_GRAD, i0, og)); + case Mode::ATANH: + RET(EL2(ATANH_GRAD, i0, og)); + case Mode::TAN: { + auto two = i0.make_scalar_dt(2); + RET(og / (EL2(POW, EL1(COS, i0), two))); + } + case Mode::RELU6: + RET(EL2(RELU6_GRAD, i0, og)); + case Mode::SOFTPLUS: + RET(EL2(SOFTPLUS_GRAD, i0, og)); + case Mode::HSIGMOID: + RET(EL2(HSIGMOID_GRAD, i0, og)); + case Mode::LOGSIGMOID: + RET(EL2(SOFTPLUS_GRAD, EL1(NEGATE, i0), og)); + case Mode::SQRT: + RET(og / EL1(SQRT, i0) / 2); + case Mode::SQUARE: + RET(og * 2 * i0); + case Mode::SIGN: + RET(i0.make_scalar_dt(0).broadcast(i0.symshape())); // binary case Mode::ABS_GRAD: @@ -617,6 +645,11 @@ MGB_IMPL_OPR_GRAD(Elemwise) { case Mode::XOR: case Mode::AND: return nullptr; + case Mode::PRELU: + if (wrt_idx == 0) { + RET(EL3(PRELU_GRAD, i0, og, i1)); + } + RET(EL2(SWITCH_GT0, -i0, og * i0)); // ternary case Mode::COND_LEQ_MOV: @@ -627,6 +660,15 @@ MGB_IMPL_OPR_GRAD(Elemwise) { if (wrt_idx <= 1) return nullptr; RET(EL3(COND_LT_MOV, i0, i1, og)); + case Mode::CLIP: + if (wrt_idx == 0) { + RET(EL3(COND_LEQ_MOV, i1, i0, EL3(COND_LEQ_MOV, i0, i2, og))); + } + if (wrt_idx == 1) { + RET(EL3(COND_LEQ_MOV, i0, i1, og)); + } + RET(EL3(COND_LEQ_MOV, i2, i0, og)); + // fuse oprs case Mode::FUSE_MUL_ADD3: if (wrt_idx < 2) { diff --git a/src/opr/impl/basic_arith.sereg.h b/src/opr/impl/basic_arith.sereg.h index 8b760d1e..47d46917 100644 --- a/src/opr/impl/basic_arith.sereg.h +++ b/src/opr/impl/basic_arith.sereg.h @@ -1,5 +1,6 @@ #include "megbrain/opr/basic_arith.h" #include "megbrain/opr/internal/param_tag_defs.h" +#include "megbrain/opr/io.h" #include "megbrain/serialization/helper.h" #include "megbrain/serialization/sereg.h" @@ -140,7 +141,254 @@ struct ParamConverter { template <> struct OprMaker : public OprMakerVariadic {}; +template <> +struct OprLoadDumpImplV2 { + using Opr = opr::Elemwise; + using PersisParam = opr::Elemwise::Param; + using PersisElemwseiParam = opr::Elemwise::Param; + static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { + ctx.write_param(opr.cast_final_safe().param()); + } + + static cg::OperatorNodeBase* replace_opr( + cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { + auto mode = opr->cast_final_safe().param().mode; + auto astype = [](VarNode* inp, VarNode* ref) { + return opr::TypeCvt::make(inp, ref->dtype()).node(); + }; + auto make_const = [](DTypeScalar val, VarNode* ref) { + return opr::ImmutableTensor::make( + *ref->owner_graph(), val, ref->comp_node()) + .node(); + }; + auto float_half = DTypeScalar(static_cast(0.5)); + auto float_one = DTypeScalar(static_cast(1.0)); + auto float_two = DTypeScalar(static_cast(2.0)); + auto float_zero = DTypeScalar(static_cast(0.0)); + auto float_six = DTypeScalar(static_cast(6.0)); + auto float_three = DTypeScalar(static_cast(3.0)); + if (PersisParam::Mode::SQRT == mode) { + auto elemwise_mode = PersisParam::Mode::POW; + auto half_var = make_const(float_half, inputs[0]); + if (inputs[0]->dtype() != half_var->dtype()) { + half_var = astype(half_var, inputs[0]); + } + return opr::Elemwise::make({inputs[0], half_var}, elemwise_mode) + .node() + ->owner_opr(); + } else if (PersisParam::Mode::SQUARE == mode) { + auto elemwise_mode = PersisParam::Mode::POW; + auto two_var = make_const(float_two, inputs[0]); + if (inputs[0]->dtype() != two_var->dtype()) { + two_var = astype(two_var, inputs[0]); + } + return opr::Elemwise::make({inputs[0], two_var}, elemwise_mode) + .node() + ->owner_opr(); + } else if (PersisParam::Mode::TAN == mode) { + auto sin = opr::Elemwise::make({inputs[0]}, PersisParam::Mode::SIN).node(); + auto cos = opr::Elemwise::make({inputs[0]}, PersisParam::Mode::COS).node(); + return opr::Elemwise::make({sin, cos}, PersisParam::Mode::TRUE_DIV) + .node() + ->owner_opr(); + } else if (PersisParam::Mode::COSH == mode) { + auto half_var = make_const(float_half, inputs[0]); + if (inputs[0]->dtype() != half_var->dtype()) + half_var = astype(half_var, inputs[0]); + auto expx = opr::Elemwise::make({inputs[0]}, PersisParam::Mode::EXP).node(); + auto negatex = + opr::Elemwise::make({inputs[0]}, PersisParam::Mode::NEGATE).node(); + auto expnegatex = + opr::Elemwise::make({negatex}, PersisParam::Mode::EXP).node(); + return opr::Elemwise::make( + {half_var, + opr::Elemwise::make( + {expx, expnegatex}, PersisParam::Mode::ADD) + .node()}, + PersisParam::Mode::MUL) + .node() + ->owner_opr(); + } else if (PersisParam::Mode::SINH == mode) { + auto inp = inputs[0]; + auto two_var = make_const(float_two, inputs[0]); + auto half_var = make_const(float_half, inputs[0]); + auto one_var = make_const(float_one, inputs[0]); + if (inp->dtype() != two_var->dtype()) { + two_var = astype(two_var, inputs[0]); + half_var = astype(half_var, inputs[0]); + one_var = astype(one_var, inputs[0]); + } + auto u = opr::Elemwise::make({inp}, PersisParam::Mode::EXPM1).node(); + auto tmp1 = + opr::Elemwise::make({u, half_var}, PersisParam::Mode::MUL).node(); + auto uadd1 = + opr::Elemwise::make({u, one_var}, PersisParam::Mode::ADD).node(); + auto uadd2 = + opr::Elemwise::make({u, two_var}, PersisParam::Mode::ADD).node(); + auto tmp2 = opr::Elemwise::make({tmp1, uadd1}, PersisParam::Mode::TRUE_DIV) + .node(); + return opr::Elemwise::make({tmp2, uadd2}, PersisParam::Mode::MUL) + .node() + ->owner_opr(); + } else if (PersisParam::Mode::ASINH == mode) { + auto inp = inputs[0]; + auto two_var = make_const(float_two, inp); + auto half_var = make_const(float_half, inp); + auto one_var = make_const(float_one, inp); + if (inp->dtype() != two_var->dtype()) { + two_var = astype(two_var, inputs[0]); + half_var = astype(half_var, inputs[0]); + one_var = astype(one_var, inputs[0]); + } + auto inp2 = + opr::Elemwise::make({inp, two_var}, PersisParam::Mode::POW).node(); + auto inp2add1 = + opr::Elemwise::make({inp2, one_var}, PersisParam::Mode::ADD).node(); + auto inp2add1sqrt = + opr::Elemwise::make({inp2add1, half_var}, PersisParam::Mode::POW) + .node(); + auto tmp = opr::Elemwise::make({inp, inp2add1sqrt}, PersisParam::Mode::ADD) + .node(); + return opr::Elemwise::make({tmp}, PersisElemwseiParam::Mode::LOG) + .node() + ->owner_opr(); + } else if (PersisParam::Mode::ACOSH == mode) { + auto inp = inputs[0]; + auto two_var = make_const(float_two, inp); + auto half_var = make_const(float_half, inp); + auto one_var = make_const(float_one, inp); + if (inp->dtype() != two_var->dtype()) { + two_var = astype(two_var, inputs[0]); + half_var = astype(half_var, inputs[0]); + one_var = astype(one_var, inputs[0]); + } + auto inp2 = + opr::Elemwise::make({inp, two_var}, PersisParam::Mode::POW).node(); + auto inp2sub1 = + opr::Elemwise::make({inp2, one_var}, PersisParam::Mode::SUB).node(); + auto inp2sub1sqrt = + opr::Elemwise::make({inp2sub1, half_var}, PersisParam::Mode::POW) + .node(); + auto tmp = opr::Elemwise::make({inp, inp2sub1sqrt}, PersisParam::Mode::ADD) + .node(); + return opr::Elemwise::make({tmp}, PersisElemwseiParam::Mode::LOG) + .node() + ->owner_opr(); + } else if (PersisParam::Mode::ATANH == mode) { + auto inp = inputs[0]; + auto two_var = make_const(float_two, inp); + auto one_var = make_const(float_one, inp); + if (inp->dtype() != two_var->dtype()) { + two_var = astype(two_var, inputs[0]); + one_var = astype(one_var, inputs[0]); + } + auto tmp1 = + opr::Elemwise::make({two_var, inp}, PersisParam::Mode::MUL).node(); + auto tmp2 = + opr::Elemwise::make({one_var, inp}, PersisParam::Mode::SUB).node(); + auto tmp3 = opr::Elemwise::make({tmp1, tmp2}, PersisParam::Mode::TRUE_DIV) + .node(); + auto log1p = opr::Elemwise::make({tmp3}, PersisParam::Mode::LOG1P).node(); + return opr::Elemwise::make({log1p, two_var}, PersisParam::Mode::TRUE_DIV) + .node() + ->owner_opr(); + } else if (PersisParam::Mode::CLIP == mode) { + auto tmp = + opr::Elemwise::make({inputs[0], inputs[1]}, PersisParam::Mode::MAX) + .node(); + return opr::Elemwise::make({tmp, inputs[1]}, PersisParam::Mode::MIN) + .node() + ->owner_opr(); + } else if (PersisParam::Mode::SIGN == mode) { + auto zero_var = make_const(float_zero, inputs[0]); + zero_var = astype(zero_var, inputs[0]); + auto tmp1 = + opr::Elemwise::make({inputs[0], zero_var}, PersisParam::Mode::LT) + .node(); + auto tmp2 = + opr::Elemwise::make({zero_var, inputs[0]}, PersisParam::Mode::LT) + .node(); + return opr::Elemwise::make({tmp1, tmp2}, PersisParam::Mode::SUB) + .node() + ->owner_opr(); + } else if (PersisParam::Mode::HSIGMOID == mode) { + auto six_var = make_const(float_six, inputs[0]); + auto zero_var = make_const(float_zero, inputs[0]); + auto three_var = make_const(float_three, inputs[0]); + if (inputs[0]->dtype() != six_var->dtype()) { + six_var = astype(six_var, inputs[0]); + zero_var = astype(zero_var, inputs[0]); + three_var = astype(three_var, inputs[0]); + } + auto tmp1 = + opr::Elemwise::make({inputs[0], three_var}, PersisParam::Mode::ADD) + .node(); + auto tmp2 = opr::Elemwise::make({tmp1, zero_var}, PersisParam::Mode::MAX) + .node(); + auto tmp3 = + opr::Elemwise::make({tmp2, six_var}, PersisParam::Mode::MIN).node(); + return opr::Elemwise::make({tmp3, six_var}, PersisParam::Mode::TRUE_DIV) + .node() + ->owner_opr(); + } else if (PersisParam::Mode::RELU6 == mode) { + auto six_var = make_const(float_six, inputs[0]); + auto zero_var = make_const(float_zero, inputs[0]); + six_var = astype(six_var, inputs[0]); + zero_var = astype(zero_var, inputs[0]); + auto max_0 = + opr::Elemwise::make({inputs[0], zero_var}, PersisParam::Mode::MAX) + .node(); + return opr::Elemwise::make({max_0, six_var}, PersisParam::Mode::MIN) + .node() + ->owner_opr(); + } else if (PersisParam::Mode::PRELU == mode) { + auto zero_var = make_const(float_zero, inputs[0]); + auto inp = inputs[0]; + auto weight = inputs[1]; + if (inp->dtype() != zero_var->dtype()) { + zero_var = astype(zero_var, inp); + } + auto min_0 = + opr::Elemwise::make({inp, zero_var}, PersisParam::Mode::MIN).node(); + auto max_0 = + opr::Elemwise::make({inp, zero_var}, PersisParam::Mode::MAX).node(); + return opr::Elemwise::make( + {min_0, weight, max_0}, PersisParam::Mode::FUSE_MUL_ADD3) + .node() + ->owner_opr(); + } else if (PersisParam::Mode::SOFTPLUS == mode) { + auto inp = inputs[0]; + auto abs = opr::Elemwise::make({inp}, PersisParam::Mode::ABS).node(); + auto neg_abs = opr::Elemwise::make({abs}, PersisParam::Mode::NEGATE).node(); + auto exp = opr::Elemwise::make({neg_abs}, PersisParam::Mode::EXP).node(); + auto oup0 = opr::Elemwise::make({exp}, PersisParam::Mode::LOG1P).node(); + auto oup1 = opr::Elemwise::make({inp}, PersisParam::Mode::RELU).node(); + return opr::Elemwise::make({oup0, oup1}, PersisParam::Mode::ADD) + .node() + ->owner_opr(); + } else if (PersisParam::Mode::LOGSIGMOID == mode) { + auto inp = inputs[0]; + auto abs = opr::Elemwise::make({inp}, PersisParam::Mode::ABS).node(); + auto neg_abs = opr::Elemwise::make({abs}, PersisParam::Mode::NEGATE).node(); + auto exp = opr::Elemwise::make({neg_abs}, PersisParam::Mode::EXP).node(); + auto oup0 = opr::Elemwise::make({exp}, PersisParam::Mode::LOG1P).node(); + auto neg_inp = opr::Elemwise::make({inp}, PersisParam::Mode::NEGATE).node(); + auto oup1 = opr::Elemwise::make({neg_inp}, PersisParam::Mode::RELU).node(); + auto oup = opr::Elemwise::make({oup0, oup1}, PersisParam::Mode::ADD).node(); + return opr::Elemwise::make({oup}, PersisParam::Mode::NEGATE) + .node() + ->owner_opr(); + } + return opr; + } + static cg::OperatorNodeBase* load( + OprLoadContext& ctx, const cg::VarNodeArray& inputs, + const OperatorNodeConfig& config) { + return OprMaker::make( + ctx.read_param(), inputs, ctx.graph(), config); + } +}; template <> struct OprMaker { using Opr = opr::Reduce; @@ -177,7 +425,11 @@ cg::OperatorNodeBase* opr_shallow_copy_add_update( ->owner_opr(); } -MGB_SEREG_OPR(Elemwise, 0); +MGB_SEREG_OPR_CONDITION(Elemwise, 0, false); +MGB_SEREG_OPR_V2_HASH_WITHOUT_TAIL_0( + Elemwise, 0, + (mgb::serialization::OprLoadDumpImplV2::replace_opr), + VERSION_1, VERSION_1); MGB_SEREG_OPR(PowC, 1); MGB_SEREG_OPR(AddUpdate, 2); MGB_REG_OPR_SHALLOW_COPY(AddUpdate, opr_shallow_copy_add_update); diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp index 3b45f822..436fc2b8 100644 --- a/src/opr/impl/misc.cpp +++ b/src/opr/impl/misc.cpp @@ -119,6 +119,8 @@ MGB_IMPL_OPR_GRAD(ArgsortForward) { MGB_DYN_TYPE_OBJ_FINAL_IMPL(ArgsortBackward); MEGDNN_OPR_INIT3(ArgsortBackward, "argsort_bwd", 2, false) +/* ================= Cumprod ================= */ + /* ================= Cumsum ================= */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cumsum); diff --git a/src/opr/test/basic_arith/elemwise.cpp b/src/opr/test/basic_arith/elemwise.cpp index e9342d3d..7b66b0b5 100644 --- a/src/opr/test/basic_arith/elemwise.cpp +++ b/src/opr/test/basic_arith/elemwise.cpp @@ -349,6 +349,99 @@ struct CheckerConfig : public CheckerConfig {}; template <> struct CheckerConfig : public NoGradCheckerConfig {}; +template <> +struct CheckerConfig : public NoGradCheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(-1.2, 1.2); + } +}; +template <> +struct CheckerConfig : public CheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(-5, 5); + } + template + static void update_opt(Opt& opt) { + opt.numdiff_eps = 1e-2; + opt.numdiff_max_err = 0.1; + } +}; +template <> +struct CheckerConfig : public CheckerConfig {}; +template <> +struct CheckerConfig : public CheckerConfig { + template + static void update_opt(Opt& opt) { + opt.numdiff_eps = 1e-2; + opt.numdiff_max_err = 0.1; + } +}; +template <> +struct CheckerConfig : public CheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(1.05, 5); + } +}; +template <> +struct CheckerConfig : public CheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(-0.95, 0.95); + } +}; +template <> +struct CheckerConfig : public CheckerConfig {}; +template <> +struct CheckerConfig : public CheckerConfig {}; +template <> +struct CheckerConfig : public CheckerConfig {}; +template <> +struct CheckerConfig : public CheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(0.05, 5); + } + template + static void update_opt(Opt& opt) { + opt.numdiff_eps = 1e-2; + opt.numdiff_max_err = 0.1; + } +}; +template <> +struct CheckerConfig : public CheckerConfig { + template + static void do_update_checker(Checker& checker) { + auto icoord = [](const typename Checker::NumInpArray& inp) { + auto p0 = inp[0]->template ptr(); + for (size_t i = 0, it = inp[0]->shape().total_nr_elems(); i < it; ++i) { + if (std::abs(p0[i]) < 1) { + p0[i] += 2; + } else if (std::abs(p0[i] - 6) < 1) { + p0[i] += 2; + } + } + }; + checker.set_input_coordinator(icoord); + } + template + static void update_checker(Checker& checker) { + using ctype = typename Checker::ctype; + return do_update_checker(checker); + } +}; +template <> +struct CheckerConfig : public CheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(-2.95, 2.95); + } +}; +template <> +struct CheckerConfig : public NoZeroCheckerConfig<0> {}; + /* ======================= binary config ======================= */ template struct BinaryInputMinGap : public CheckerConfig { @@ -567,13 +660,89 @@ template <> struct CheckerConfig : public NoGradCheckerConfig {}; template <> struct CheckerConfig : public NoGradCheckerConfig {}; +template <> +struct CheckerConfig : public NoZeroCheckerConfig<0> {}; +template <> +struct CheckerConfig : public NoGradCheckerConfig {}; +template <> +struct CheckerConfig : public NoGradCheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(1.05, 5); + } +}; +template <> +struct CheckerConfig : public NoGradCheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(-0.95, 0.95); + } +}; +template <> +struct CheckerConfig : public NoGradCheckerConfig {}; +template <> +struct CheckerConfig : public NoGradCheckerConfig {}; +template <> +struct CheckerConfig : public NoGradCheckerConfig { + template + static InputGenerator get_inp_gen(size_t) { + return get_inp_gen_f32_range(-2.95, 2.95); + } +}; /* ======================= ternary config ======================= */ template <> struct CheckerConfig : public BinaryInputMinGap {}; + template <> struct CheckerConfig : public BinaryInputMinGap {}; +template <> +struct CheckerConfig : public NoGradCheckerConfig {}; + +template <> +struct CheckerConfig : public CheckerConfig { + template + static void do_update_checker(Checker& checker) { + auto icoord = [](const typename Checker::NumInpArray& inp) { + auto p0 = inp[0]->template ptr(), p1 = inp[1]->template ptr(), + p2 = inp[2]->template ptr(); + for (size_t i = 0, it = inp[0]->shape().total_nr_elems(); i < it; ++i) { + if (p1[i] > p2[i]) { + std::swap(p1[i], p2[i]); + } + if (p1[i] + 1 > p2[i]) { + p2[i] = p1[i] + 1; + } + if (std::abs(p1[i] - p0[i]) < 1) { + if (p1[i] < p0[i]) + p0[i] += 1; + else + p0[i] -= 1; + } + if (std::abs(p2[i] - p0[i]) < 1) { + if (p2[i] < p0[i]) + p0[i] += 1; + else + p0[i] -= 1; + } + } + }; + checker.set_input_coordinator(icoord); + } + + template + static void update_checker(Checker& checker) { + using ctype = typename Checker::ctype; + return do_update_checker(checker); + } + + template + static void update_opt(Opt& opt) { + opt.numdiff_eps = 1e-3; + opt.numdiff_max_err = 0.1; + } +}; /* ======================= test runner ======================= */ namespace detail { template @@ -721,6 +890,10 @@ void TestRunner::run() { } TensorShape shapes[] = {{1}, {23, 3}, {666}}; + if (Trait::ARITY == 4) { + checker.disable_graph_opt(); + shapes[0] = {32}; + } typename Checker::RunOptions opt; Config::update_opt(opt); Config::update_checker(checker); @@ -869,13 +1042,13 @@ TEST(TestOprBasicArithElemwise, FuseMulAdd4Shapes) { }; Checker checker{make_graph, fwd}; - checker.run({TensorShape{1, 2}, {2, 1}, {1, 2}, {2, 1}}) - .run({TensorShape{1, 2, 1, 2, 1, 2}, - {2, 1, 2, 1, 2, 1}, - {2, 1, 2, 1, 2, 1}, - {1, 2, 1, 2, 1, 2}}); + checker.run({TensorShape{1, 32}, {1, 32}, {1, 32}, {1, 32}}) + .run({TensorShape{1, 1, 1, 1, 1, 32}, + {1, 1, 1, 1, 1, 32}, + {1, 1, 1, 1, 1, 32}, + {1, 1, 1, 1, 1, 32}}); ASSERT_FALSE(opr->fuse_badlayout_warn_printed()); - checker.run({TensorShape{1, 2}, {2, 1}, {2, 2}, {2, 2}}); + checker.run({TensorShape{1, 32}, {32, 1}, {32, 32}, {32, 32}}); ASSERT_TRUE(opr->fuse_badlayout_warn_printed()); } diff --git a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl index 0663986a..1ed742db 100644 --- a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl @@ -41,6 +41,7 @@ DEF_TRAIT(SWITCH_GT0, x > 0 ? y : 0) DEF_TRAIT(TANH_GRAD, (1 - x * x) * y) DEF_TRAIT(FUSE_ADD_RELU, std::max(x + y, 0)) +DEF_TRAIT(PRELU, (x > 0) ? x : (x* y)) #undef _ALLOW_INT #define _ALLOW_INT false @@ -57,6 +58,12 @@ DEF_TRAIT( SILU_GRAD, y*(1 + std::exp(-x) + x * std::exp(-x)) / (1 + std::exp(-x)) / (1 + std::exp(-x))) DEF_TRAIT(GELU_GRAD, do_gelu_grad(x, y)) +DEF_TRAIT(ASINH_GRAD, y / std::sqrt(x * x + 1)) +DEF_TRAIT(ACOSH_GRAD, y / std::sqrt(x * x - 1)) +DEF_TRAIT(ATANH_GRAD, y / (1 - x * x)) +DEF_TRAIT(SOFTPLUS_GRAD, y* std::exp(x) / (1.f + std::exp(x))) +DEF_TRAIT(RELU6_GRAD, x <= 0.f ? 0.f : (x >= 6.f ? 0.f : y)) +DEF_TRAIT(HSIGMOID_GRAD, x <= -3.f ? 0.f : (x >= 3.f ? 0.f : (y / 6.f))) #undef _ALLOW_INT #undef _ALLOW_FLOAT diff --git a/src/opr/test/basic_arith/elemwise_ternary_trait_def.inl b/src/opr/test/basic_arith/elemwise_ternary_trait_def.inl index e9e5cd8e..a9bf8af1 100644 --- a/src/opr/test/basic_arith/elemwise_ternary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_ternary_trait_def.inl @@ -15,6 +15,10 @@ DEF_TRAIT(COND_LEQ_MOV, x <= y ? z : 0) DEF_TRAIT(COND_LT_MOV, x < y ? z : 0) DEF_TRAIT(FUSE_MUL_ADD3, x* y + z) +DEF_TRAIT(CLIP, x < y ? y : (x < z ? x : z)) +#undef _ALLOW_INT +#define _ALLOW_INT false +DEF_TRAIT(PRELU_GRAD, x > 0 ? y : (y * z)) #undef _ALLOW_INT #undef _ALLOW_FLOAT diff --git a/src/opr/test/basic_arith/elemwise_unary_trait_def.inl b/src/opr/test/basic_arith/elemwise_unary_trait_def.inl index a6e64ecd..edacc035 100644 --- a/src/opr/test/basic_arith/elemwise_unary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_unary_trait_def.inl @@ -22,6 +22,9 @@ DEF_TRAIT(NOT, !x) DEF_TRAIT(ABS, std::abs(x)) DEF_TRAIT(NEGATE, -x) DEF_TRAIT(RELU, std::max(x, 0)) +DEF_TRAIT(RELU6, std::min(std::max(x, 0), 6)) +DEF_TRAIT(SQUARE, x* x) +DEF_TRAIT(SIGN, x < 0 ? -1 : (x > 0 ? 1 : 0)) #undef _ALLOW_INT #define _ALLOW_INT false @@ -46,6 +49,16 @@ DEF_TRAIT(ERFCINV, do_erfcinv(x)) DEF_TRAIT(H_SWISH, do_h_swish(x)) DEF_TRAIT(SILU, x / (1 + std::exp(-x))) DEF_TRAIT(GELU, x*(0.5f * (1.f + std::erf(x / std::sqrt(2.f))))) +DEF_TRAIT(SINH, std::sinh(x)) +DEF_TRAIT(COSH, std::cosh(x)) +DEF_TRAIT(ASINH, std::asinh(x)) +DEF_TRAIT(ACOSH, std::acosh(x)) +DEF_TRAIT(ATANH, std::atanh(x)) +DEF_TRAIT(TAN, std::tan(x)) +DEF_TRAIT(SOFTPLUS, std::log1p(std::exp(-std::abs(x))) + std::max(x, 0)) +DEF_TRAIT(HSIGMOID, x <= -3.f ? 0.f : (x >= 3.f ? 1.f : ((x + 3.f) / 6.f))) +DEF_TRAIT(SQRT, std::sqrt(x)) +DEF_TRAIT(LOGSIGMOID, -std::log1p(std::exp(-std::abs(x))) - std::max(-x, 0)) #undef _ALLOW_INT #undef _ALLOW_FLOAT diff --git a/src/serialization/impl/serializer_oss_v2.cpp b/src/serialization/impl/serializer_oss_v2.cpp index 246e11c5..8b9dce64 100644 --- a/src/serialization/impl/serializer_oss_v2.cpp +++ b/src/serialization/impl/serializer_oss_v2.cpp @@ -54,9 +54,13 @@ public: auto new_opr = (it->second)(opr, new_inp); auto &&origin_out = opr->output(), &&cur_out = new_opr->output(); - for (size_t i = 0; i < std::min(origin_out.size(), cur_out.size()); - i++) { - rewriter.replace_var(origin_out[i], cur_out[i], nullptr); + if (opr == new_opr) { + rewriter.auto_replace_outputs(opr); + } else { + for (size_t i = 0; i < std::min(origin_out.size(), cur_out.size()); + i++) { + rewriter.replace_var(origin_out[i], cur_out[i], nullptr); + } } } else { rewriter.auto_replace_outputs(opr); diff --git a/test/src/autocheck.cpp b/test/src/autocheck.cpp index 1de0194c..e8465a18 100644 --- a/test/src/autocheck.cpp +++ b/test/src/autocheck.cpp @@ -224,6 +224,7 @@ DEF_IMPL(void)::do_run(const ShapeInpArray& shapes, const RunOptions& opt) { m_inputs_generator[i](*m_inputs[i]); mgb_assert(m_inputs[i]->shape().eq_shape(shapes[i])); } + if (MGB_GETENV("MGB_AUTOCHECK_DUMP_INPUT")) { static size_t run_id; auto fname = output_file(ssprintf("autocheck-inp-%zu.bin", run_id++));