From bbafe69974a79b3a16cb46d70b903ae2dcad28b3 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 20 Jun 2022 15:23:30 +0800 Subject: [PATCH] feat(dnn): add elemwise COND_LT_MOV GitOrigin-RevId: 444cd6825a775bed21562ebf5443b153b130745e --- dnn/scripts/gen_elemwise_multi_type_utils.py | 8 +++--- dnn/scripts/gen_elemwise_utils.py | 4 +-- dnn/scripts/opr_param_defs.py | 4 ++- dnn/src/common/elemwise/each_mode.inl | 4 ++- dnn/src/common/elemwise/kern_defs.cuh | 1 + dnn/src/common/elemwise/opr_impl.cpp | 1 + dnn/src/common/elemwise_multi_type/opr_impl.cpp | 1 + .../common/elemwise_multi_type/opr_impl_helper.cpp | 1 + .../cuda/elemwise/kimpl/COND_LT_MOV_dt_bfloat16.cu | 7 ++++++ .../cuda/elemwise/kimpl/COND_LT_MOV_dt_float16.cu | 7 ++++++ .../cuda/elemwise/kimpl/COND_LT_MOV_dt_float32.cu | 5 ++++ .../cuda/elemwise/kimpl/COND_LT_MOV_dt_int16.cu | 5 ++++ .../cuda/elemwise/kimpl/COND_LT_MOV_dt_int32.cu | 5 ++++ dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_int8.cu | 5 ++++ .../cuda/elemwise/kimpl/COND_LT_MOV_dt_uint8.cu | 5 ++++ .../kimpl/COND_LT_MOV_dt_qint4_dt_qint4.cu | 6 +++++ .../kimpl/COND_LT_MOV_dt_qint8_dt_qint8.cu | 6 +++++ .../kimpl/COND_LT_MOV_dt_quint4_dt_quint4.cu | 6 +++++ .../elemwise/kimpl/COND_LT_MOV_dt_bfloat16.cpp | 7 ++++++ .../elemwise/kimpl/COND_LT_MOV_dt_float16.cpp | 7 ++++++ .../elemwise/kimpl/COND_LT_MOV_dt_float32.cpp | 5 ++++ .../naive/elemwise/kimpl/COND_LT_MOV_dt_int16.cpp | 5 ++++ .../naive/elemwise/kimpl/COND_LT_MOV_dt_int32.cpp | 5 ++++ .../naive/elemwise/kimpl/COND_LT_MOV_dt_int8.cpp | 5 ++++ .../naive/elemwise/kimpl/COND_LT_MOV_dt_uint8.cpp | 5 ++++ dnn/src/naive/elemwise_multi_type/opr_impl_5.cpp | 1 + .../elemwise/kimpl/COND_LT_MOV_dt_bfloat16.cpp.hip | 7 ++++++ .../elemwise/kimpl/COND_LT_MOV_dt_float16.cpp.hip | 7 ++++++ .../elemwise/kimpl/COND_LT_MOV_dt_float32.cpp.hip | 5 ++++ .../elemwise/kimpl/COND_LT_MOV_dt_int16.cpp.hip | 5 ++++ .../elemwise/kimpl/COND_LT_MOV_dt_int32.cpp.hip | 5 ++++ .../elemwise/kimpl/COND_LT_MOV_dt_int8.cpp.hip | 5 ++++ .../elemwise/kimpl/COND_LT_MOV_dt_uint8.cpp.hip | 5 ++++ dnn/test/common/elemwise.cpp | 29 ++++++++++++++++++++++ dnn/test/common/elemwise.h | 2 ++ dnn/test/cuda/elemwise_multi_type.cpp | 4 +-- dnn/test/naive/elemwise_multi_type.cpp | 5 +++- 37 files changed, 189 insertions(+), 11 deletions(-) create mode 100644 dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_bfloat16.cu create mode 100644 dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_float16.cu create mode 100644 dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_float32.cu create mode 100644 dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_int16.cu create mode 100644 dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_int32.cu create mode 100644 dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_int8.cu create mode 100644 dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_uint8.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/COND_LT_MOV_dt_qint4_dt_qint4.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/COND_LT_MOV_dt_qint8_dt_qint8.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/COND_LT_MOV_dt_quint4_dt_quint4.cu create mode 100644 dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_bfloat16.cpp create mode 100644 dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_float16.cpp create mode 100644 dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_float32.cpp create mode 100644 dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_int16.cpp create mode 100644 dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_int32.cpp create mode 100644 dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_int8.cpp create mode 100644 dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_uint8.cpp create mode 100644 dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_bfloat16.cpp.hip create mode 100644 dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_float16.cpp.hip create mode 100644 dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_float32.cpp.hip create mode 100644 dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_int16.cpp.hip create mode 100644 dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_int32.cpp.hip create mode 100644 dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_int8.cpp.hip create mode 100644 dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_uint8.cpp.hip diff --git a/dnn/scripts/gen_elemwise_multi_type_utils.py b/dnn/scripts/gen_elemwise_multi_type_utils.py index f97505c6..9bbf65d6 100755 --- a/dnn/scripts/gen_elemwise_multi_type_utils.py +++ b/dnn/scripts/gen_elemwise_multi_type_utils.py @@ -17,16 +17,16 @@ MODES = { 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], - 3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], + 3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], } QINT4_MODES = { 1: ['RELU', 'ABS', 'NEGATE', 'CEIL', 'FLOOR', 'SIGMOID', 'TANH', 'FAST_TANH', 'ROUND', 'H_SWISH'], - 2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0', - 'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH', + 2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0', + 'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH', 'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH'], - 3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], + 3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], } QINT32_MODES = { diff --git a/dnn/scripts/gen_elemwise_utils.py b/dnn/scripts/gen_elemwise_utils.py index 5d744e74..84bc541f 100755 --- a/dnn/scripts/gen_elemwise_utils.py +++ b/dnn/scripts/gen_elemwise_utils.py @@ -16,7 +16,7 @@ MODES = { (2, 'INT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'SHL', 'SHR', 'RMULH'], - (3, 'INT'): ['COND_LEQ_MOV'], + (3, 'INT'): ['COND_LEQ_MOV', 'COND_LT_MOV'], (1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', @@ -28,7 +28,7 @@ MODES = { 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], - (3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], + (3, 'FLOAT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], (1, 'BOOL'): ['NOT'], (2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], (3, 'BOOL'): [] diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index a49e65b9..22055d0a 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -420,6 +420,7 @@ pdef('Elemwise').add_enum( Doc('SILU_GRAD = 57', 'binary: grad(x / (1 + exp(-x))'), Doc('GELU = 58', 'unary: x Phi(x)'), Doc('GELU_GRAD = 59', 'binary: grad(x Phi(x))'), + Doc('COND_LT_MOV = 60', 'ternary: x < y ? z : 0'), ) pdef('ElemwiseMultiType').add_enum( @@ -510,7 +511,8 @@ pdef('ElemwiseMultiType').add_enum( 'and the result is float32.'), Doc('FUSE_MUL_ADD3_UINT8xF32xF32xF32 = 56', 'compute ``a * b + c`` requiring that ``a`` be uint8 and ``b`` and ' - '``c`` float32, and the result is float32.') + '``c`` float32, and the result is float32.'), + Doc('QCOND_LT_MOV = 57', 'quantized cond_lt_mov'), ) pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) diff --git a/dnn/src/common/elemwise/each_mode.inl b/dnn/src/common/elemwise/each_mode.inl index 22892795..67a144ca 100644 --- a/dnn/src/common/elemwise/each_mode.inl +++ b/dnn/src/common/elemwise/each_mode.inl @@ -92,7 +92,9 @@ #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_INT(cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) + MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) diff --git a/dnn/src/common/elemwise/kern_defs.cuh b/dnn/src/common/elemwise/kern_defs.cuh index c2d30106..35a785ca 100644 --- a/dnn/src/common/elemwise/kern_defs.cuh +++ b/dnn/src/common/elemwise/kern_defs.cuh @@ -265,6 +265,7 @@ DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y)); // int and float DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0)); +DEF_KERN_ALL(COND_LT_MOV, x < y ? z : ctype(0)); DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z); #undef KERN_SIG diff --git a/dnn/src/common/elemwise/opr_impl.cpp b/dnn/src/common/elemwise/opr_impl.cpp index 5d5d3f40..ed994866 100644 --- a/dnn/src/common/elemwise/opr_impl.cpp +++ b/dnn/src/common/elemwise/opr_impl.cpp @@ -219,6 +219,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { CB_MODE(Mode::SILU_GRAD); CB_MODE(Mode::GELU); CB_MODE(Mode::GELU_GRAD); + CB_MODE(Mode::COND_LT_MOV); default: megdnn_assert( 0, diff --git a/dnn/src/common/elemwise_multi_type/opr_impl.cpp b/dnn/src/common/elemwise_multi_type/opr_impl.cpp index 641f9abe..b4c35bd2 100644 --- a/dnn/src/common/elemwise_multi_type/opr_impl.cpp +++ b/dnn/src/common/elemwise_multi_type/opr_impl.cpp @@ -239,6 +239,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { SET(init_quantized_ternary_op, QFUSE_MUL_ADD3); SET(init_quantized_ternary_op, QCOND_LEQ_MOV); + SET(init_quantized_ternary_op, QCOND_LT_MOV); #undef SET } diff --git a/dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp b/dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp index 1f29fbff..179c51e3 100644 --- a/dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp +++ b/dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp @@ -95,6 +95,7 @@ void ElemwiseMultiTypeImplHelper::exec( ON_QUANTIZED_MODE(FUSE_MUL_ADD3, 3); ON_QUANTIZED_MODE(COND_LEQ_MOV, 3); + ON_QUANTIZED_MODE(COND_LT_MOV, 3); default: megdnn_throw("invalid mode"); } diff --git a/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_bfloat16.cu new file mode 100644 index 00000000..534d221a --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_float16.cu new file mode 100644 index 00000000..712c521c --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_float32.cu new file mode 100644 index 00000000..b9392e7e --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_int16.cu new file mode 100644 index 00000000..d6037b08 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_int16.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_int32.cu new file mode 100644 index 00000000..76621fee --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_int32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_int8.cu new file mode 100644 index 00000000..5e94a4b5 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_int8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_uint8.cu new file mode 100644 index 00000000..4653afac --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/COND_LT_MOV_dt_uint8.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LT_MOV_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LT_MOV_dt_qint4_dt_qint4.cu new file mode 100644 index 00000000..0f0488de --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LT_MOV_dt_qint4_dt_qint4.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 +#include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LT_MOV_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LT_MOV_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..570920aa --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LT_MOV_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LT_MOV_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LT_MOV_dt_quint4_dt_quint4.cu new file mode 100644 index 00000000..cd06ae4d --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LT_MOV_dt_quint4_dt_quint4.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 +#include "../kern_impl_q4.inl" diff --git a/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_bfloat16.cpp new file mode 100644 index 00000000..534d221a --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_float16.cpp new file mode 100644 index 00000000..712c521c --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_float32.cpp new file mode 100644 index 00000000..b9392e7e --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_int16.cpp b/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_int16.cpp new file mode 100644 index 00000000..d6037b08 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_int16.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_int32.cpp b/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_int32.cpp new file mode 100644 index 00000000..76621fee --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_int32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_int8.cpp b/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_int8.cpp new file mode 100644 index 00000000..5e94a4b5 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_int8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_uint8.cpp b/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_uint8.cpp new file mode 100644 index 00000000..4653afac --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/COND_LT_MOV_dt_uint8.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise_multi_type/opr_impl_5.cpp b/dnn/src/naive/elemwise_multi_type/opr_impl_5.cpp index eb193ca8..520b1973 100644 --- a/dnn/src/naive/elemwise_multi_type/opr_impl_5.cpp +++ b/dnn/src/naive/elemwise_multi_type/opr_impl_5.cpp @@ -25,6 +25,7 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( DISPATCH(FUSE_MUL_ADD3); DISPATCH(COND_LEQ_MOV); + DISPATCH(COND_LT_MOV); #undef DISPATCH default: megdnn_assert_internal(0); diff --git a/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..7c23e1bd --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_float16.cpp.hip new file mode 100644 index 00000000..3e9cbc63 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_float32.cpp.hip new file mode 100644 index 00000000..e0bb83ed --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_int16.cpp.hip new file mode 100644 index 00000000..0cb1f113 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_int32.cpp.hip new file mode 100644 index 00000000..96b9f4a0 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_int8.cpp.hip new file mode 100644 index 00000000..c825b2cd --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_uint8.cpp.hip new file mode 100644 index 00000000..f4905585 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COND_LT_MOV_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/test/common/elemwise.cpp b/dnn/test/common/elemwise.cpp index f705f17c..f7d3cb79 100644 --- a/dnn/test/common/elemwise.cpp +++ b/dnn/test/common/elemwise.cpp @@ -179,6 +179,35 @@ DEF_TEST(ternary_non_contig) { checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}}); } +DEF_TEST(ternary_lt) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle); + checker.set_param(Mode::COND_LT_MOV); + checker.execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}}); + checker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()) + .execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}}); + checker.set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_dtype(2, dtype::Float16()) + .set_dtype(3, dtype::Float16()) + .execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}}); + checker.execs({{2, 1, 1, 5}, {4, 5}, {3, 1, 1}, {2, 3, 4, 5}}); + checker.execs({{3, 1, 1}, {5}, {4, 1}, {3, 4, 5}}); + ASSERT_THROW(checker.execs({{2, 3, 4}, {4, 1}, {1}, {2, 3, 4}}), MegDNNError); + ASSERT_THROW(checker.execs({{2, 4, 4}, {4, 1}, {3, 1, 1}, {2, 3, 4}}), MegDNNError); +} + +DEF_TEST(ternary_lt_non_contig) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle); + checker.set_param(Mode::COND_LT_MOV); + TensorLayout ly{{2, 3}, dtype::Float32()}; + ly.stride[0] = 4; + checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}}); +} + DEF_TEST(fuse_mul_add3) { using Mode = ElemwiseForward::Param::Mode; Checker checker(handle); diff --git a/dnn/test/common/elemwise.h b/dnn/test/common/elemwise.h index e1649d7a..776bb62c 100644 --- a/dnn/test/common/elemwise.h +++ b/dnn/test/common/elemwise.h @@ -16,6 +16,8 @@ namespace elemwise { cb(binary_non_contig) \ cb(ternary) \ cb(ternary_non_contig) \ + cb(ternary_lt) \ + cb(ternary_lt_non_contig) \ cb(fuse_mul_add3) \ cb(fuse_mul_add3_non_contig) \ cb(fuse_mul_add4) \ diff --git a/dnn/test/cuda/elemwise_multi_type.cpp b/dnn/test/cuda/elemwise_multi_type.cpp index 23182bc7..36b985fe 100644 --- a/dnn/test/cuda/elemwise_multi_type.cpp +++ b/dnn/test/cuda/elemwise_multi_type.cpp @@ -207,7 +207,7 @@ TEST_F(CUDA, ELEMWISE_QUANTIZED_MODE_TENARY) { using Mode = ElemwiseMultiType::Param::Mode; Checker checker(handle_cuda()); - for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV}) { + for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV, Mode::QCOND_LT_MOV}) { UniformIntRNG rng_int8{-127, 127}; UniformIntRNG rng_uint8{0, 225}; checker.set_param({mode}) @@ -368,7 +368,7 @@ TEST_F(CUDA, BENCHMARK_ELEMWISE_QUANTIZED_MODE_TENARY) { CUBenchmarker bencher(handle_cuda()); UniformIntRNG rng{-128, 127}; - for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV}) { + for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV, Mode::QCOND_LT_MOV}) { printf("Benchmark mode: %d\n", (int)mode); bencher.set_param({mode}) .set_rng(0, &rng) diff --git a/dnn/test/naive/elemwise_multi_type.cpp b/dnn/test/naive/elemwise_multi_type.cpp index 0617aa4c..cf81591b 100644 --- a/dnn/test/naive/elemwise_multi_type.cpp +++ b/dnn/test/naive/elemwise_multi_type.cpp @@ -59,6 +59,7 @@ Elemwise::Mode get_elem_mode(ElemwiseMultiType::Mode mode) { MODE(FAST_TANH_GRAD); MODE(ATAN2); MODE(COND_LEQ_MOV); + MODE(COND_LT_MOV); MODE(H_SWISH_GRAD); MODE(FUSE_ADD_H_SWISH); @@ -231,7 +232,9 @@ TEST_F(NAIVE, ELEMWISE_QUANTIZED_MODE_TERNARY) { .set_dtype(1, dtype::QuantizedS8(0.2f)) .set_dtype(2, dtype::QuantizedS8(0.3f)); - for (auto mode : {Param::Mode::QFUSE_MUL_ADD3, Param::Mode::QCOND_LEQ_MOV}) { + for (auto mode : + {Param::Mode::QFUSE_MUL_ADD3, Param::Mode::QCOND_LEQ_MOV, + Param::Mode::QCOND_LT_MOV}) { Param param{mode}; checker.set_param(param);