@@ -17,16 +17,16 @@ MODES = { | |||||
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | ||||
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | ||||
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | ||||
3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | |||||
3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||||
} | } | ||||
QINT4_MODES = { | QINT4_MODES = { | ||||
1: ['RELU', 'ABS', 'NEGATE', 'CEIL', 'FLOOR', 'SIGMOID', | 1: ['RELU', 'ABS', 'NEGATE', 'CEIL', 'FLOOR', 'SIGMOID', | ||||
'TANH', 'FAST_TANH', 'ROUND', 'H_SWISH'], | 'TANH', 'FAST_TANH', 'ROUND', 'H_SWISH'], | ||||
2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0', | |||||
'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH', | |||||
2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0', | |||||
'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH', | |||||
'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH'], | 'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH'], | ||||
3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | |||||
3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||||
} | } | ||||
QINT32_MODES = { | QINT32_MODES = { | ||||
@@ -16,7 +16,7 @@ MODES = { | |||||
(2, 'INT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | (2, 'INT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | ||||
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ', | 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ', | ||||
'EQ', 'FUSE_ADD_RELU', 'SHL', 'SHR', 'RMULH'], | 'EQ', 'FUSE_ADD_RELU', 'SHL', 'SHR', 'RMULH'], | ||||
(3, 'INT'): ['COND_LEQ_MOV'], | |||||
(3, 'INT'): ['COND_LEQ_MOV', 'COND_LT_MOV'], | |||||
(1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | (1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | ||||
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | ||||
@@ -28,7 +28,7 @@ MODES = { | |||||
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | ||||
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | ||||
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | ||||
(3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | |||||
(3, 'FLOAT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||||
(1, 'BOOL'): ['NOT'], | (1, 'BOOL'): ['NOT'], | ||||
(2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], | (2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], | ||||
(3, 'BOOL'): [] | (3, 'BOOL'): [] | ||||
@@ -420,6 +420,7 @@ pdef('Elemwise').add_enum( | |||||
Doc('SILU_GRAD = 57', 'binary: grad(x / (1 + exp(-x))'), | Doc('SILU_GRAD = 57', 'binary: grad(x / (1 + exp(-x))'), | ||||
Doc('GELU = 58', 'unary: x Phi(x)'), | Doc('GELU = 58', 'unary: x Phi(x)'), | ||||
Doc('GELU_GRAD = 59', 'binary: grad(x Phi(x))'), | Doc('GELU_GRAD = 59', 'binary: grad(x Phi(x))'), | ||||
Doc('COND_LT_MOV = 60', 'ternary: x < y ? z : 0'), | |||||
) | ) | ||||
pdef('ElemwiseMultiType').add_enum( | pdef('ElemwiseMultiType').add_enum( | ||||
@@ -510,7 +511,8 @@ pdef('ElemwiseMultiType').add_enum( | |||||
'and the result is float32.'), | 'and the result is float32.'), | ||||
Doc('FUSE_MUL_ADD3_UINT8xF32xF32xF32 = 56', | Doc('FUSE_MUL_ADD3_UINT8xF32xF32xF32 = 56', | ||||
'compute ``a * b + c`` requiring that ``a`` be uint8 and ``b`` and ' | 'compute ``a * b + c`` requiring that ``a`` be uint8 and ``b`` and ' | ||||
'``c`` float32, and the result is float32.') | |||||
'``c`` float32, and the result is float32.'), | |||||
Doc('QCOND_LT_MOV = 57', 'quantized cond_lt_mov'), | |||||
) | ) | ||||
pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) | pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) | ||||
@@ -92,7 +92,9 @@ | |||||
#define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb) \ | #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb) \ | ||||
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ | ||||
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \ | |||||
MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) | MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) | ||||
#define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_INT(cb) \ | #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_INT(cb) \ | ||||
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) | |||||
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ | |||||
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) |
@@ -265,6 +265,7 @@ DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y)); | |||||
// int and float | // int and float | ||||
DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0)); | DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0)); | ||||
DEF_KERN_ALL(COND_LT_MOV, x < y ? z : ctype(0)); | |||||
DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z); | DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z); | ||||
#undef KERN_SIG | #undef KERN_SIG | ||||
@@ -219,6 +219,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { | |||||
CB_MODE(Mode::SILU_GRAD); | CB_MODE(Mode::SILU_GRAD); | ||||
CB_MODE(Mode::GELU); | CB_MODE(Mode::GELU); | ||||
CB_MODE(Mode::GELU_GRAD); | CB_MODE(Mode::GELU_GRAD); | ||||
CB_MODE(Mode::COND_LT_MOV); | |||||
default: | default: | ||||
megdnn_assert( | megdnn_assert( | ||||
0, | 0, | ||||
@@ -239,6 +239,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { | |||||
SET(init_quantized_ternary_op, QFUSE_MUL_ADD3); | SET(init_quantized_ternary_op, QFUSE_MUL_ADD3); | ||||
SET(init_quantized_ternary_op, QCOND_LEQ_MOV); | SET(init_quantized_ternary_op, QCOND_LEQ_MOV); | ||||
SET(init_quantized_ternary_op, QCOND_LT_MOV); | |||||
#undef SET | #undef SET | ||||
} | } | ||||
@@ -95,6 +95,7 @@ void ElemwiseMultiTypeImplHelper::exec( | |||||
ON_QUANTIZED_MODE(FUSE_MUL_ADD3, 3); | ON_QUANTIZED_MODE(FUSE_MUL_ADD3, 3); | ||||
ON_QUANTIZED_MODE(COND_LEQ_MOV, 3); | ON_QUANTIZED_MODE(COND_LEQ_MOV, 3); | ||||
ON_QUANTIZED_MODE(COND_LT_MOV, 3); | |||||
default: | default: | ||||
megdnn_throw("invalid mode"); | megdnn_throw("invalid mode"); | ||||
} | } | ||||
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_float16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_float32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_int16 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_int32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_int8 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_uint8 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_STYPE dt_qint8 | |||||
#define KERN_IMPL_DTYPE dt_qint8 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_float16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_float32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_int16 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_int32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_int8 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_uint8 | |||||
#include "../kern_impl.inl" |
@@ -25,6 +25,7 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( | |||||
DISPATCH(FUSE_MUL_ADD3); | DISPATCH(FUSE_MUL_ADD3); | ||||
DISPATCH(COND_LEQ_MOV); | DISPATCH(COND_LEQ_MOV); | ||||
DISPATCH(COND_LT_MOV); | |||||
#undef DISPATCH | #undef DISPATCH | ||||
default: | default: | ||||
megdnn_assert_internal(0); | megdnn_assert_internal(0); | ||||
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,7 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_float16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_float32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_int16 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_int32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_int8 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,5 @@ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_uint8 | |||||
#include "../kern_impl.inl" |
@@ -179,6 +179,35 @@ DEF_TEST(ternary_non_contig) { | |||||
checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}}); | checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}}); | ||||
} | } | ||||
DEF_TEST(ternary_lt) { | |||||
using Mode = ElemwiseForward::Param::Mode; | |||||
Checker<ElemwiseForward> checker(handle); | |||||
checker.set_param(Mode::COND_LT_MOV); | |||||
checker.execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}}); | |||||
checker.set_dtype(0, dtype::Float32()) | |||||
.set_dtype(1, dtype::Float32()) | |||||
.set_dtype(2, dtype::Float32()) | |||||
.execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}}); | |||||
checker.set_dtype(0, dtype::Float16()) | |||||
.set_dtype(1, dtype::Float16()) | |||||
.set_dtype(2, dtype::Float16()) | |||||
.set_dtype(3, dtype::Float16()) | |||||
.execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}}); | |||||
checker.execs({{2, 1, 1, 5}, {4, 5}, {3, 1, 1}, {2, 3, 4, 5}}); | |||||
checker.execs({{3, 1, 1}, {5}, {4, 1}, {3, 4, 5}}); | |||||
ASSERT_THROW(checker.execs({{2, 3, 4}, {4, 1}, {1}, {2, 3, 4}}), MegDNNError); | |||||
ASSERT_THROW(checker.execs({{2, 4, 4}, {4, 1}, {3, 1, 1}, {2, 3, 4}}), MegDNNError); | |||||
} | |||||
DEF_TEST(ternary_lt_non_contig) { | |||||
using Mode = ElemwiseForward::Param::Mode; | |||||
Checker<ElemwiseForward> checker(handle); | |||||
checker.set_param(Mode::COND_LT_MOV); | |||||
TensorLayout ly{{2, 3}, dtype::Float32()}; | |||||
ly.stride[0] = 4; | |||||
checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}}); | |||||
} | |||||
DEF_TEST(fuse_mul_add3) { | DEF_TEST(fuse_mul_add3) { | ||||
using Mode = ElemwiseForward::Param::Mode; | using Mode = ElemwiseForward::Param::Mode; | ||||
Checker<ElemwiseForward> checker(handle); | Checker<ElemwiseForward> checker(handle); | ||||
@@ -16,6 +16,8 @@ namespace elemwise { | |||||
cb(binary_non_contig) \ | cb(binary_non_contig) \ | ||||
cb(ternary) \ | cb(ternary) \ | ||||
cb(ternary_non_contig) \ | cb(ternary_non_contig) \ | ||||
cb(ternary_lt) \ | |||||
cb(ternary_lt_non_contig) \ | |||||
cb(fuse_mul_add3) \ | cb(fuse_mul_add3) \ | ||||
cb(fuse_mul_add3_non_contig) \ | cb(fuse_mul_add3_non_contig) \ | ||||
cb(fuse_mul_add4) \ | cb(fuse_mul_add4) \ | ||||
@@ -207,7 +207,7 @@ TEST_F(CUDA, ELEMWISE_QUANTIZED_MODE_TENARY) { | |||||
using Mode = ElemwiseMultiType::Param::Mode; | using Mode = ElemwiseMultiType::Param::Mode; | ||||
Checker<ElemwiseMultiType> checker(handle_cuda()); | Checker<ElemwiseMultiType> checker(handle_cuda()); | ||||
for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV}) { | |||||
for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV, Mode::QCOND_LT_MOV}) { | |||||
UniformIntRNG rng_int8{-127, 127}; | UniformIntRNG rng_int8{-127, 127}; | ||||
UniformIntRNG rng_uint8{0, 225}; | UniformIntRNG rng_uint8{0, 225}; | ||||
checker.set_param({mode}) | checker.set_param({mode}) | ||||
@@ -368,7 +368,7 @@ TEST_F(CUDA, BENCHMARK_ELEMWISE_QUANTIZED_MODE_TENARY) { | |||||
CUBenchmarker<ElemwiseMultiType> bencher(handle_cuda()); | CUBenchmarker<ElemwiseMultiType> bencher(handle_cuda()); | ||||
UniformIntRNG rng{-128, 127}; | UniformIntRNG rng{-128, 127}; | ||||
for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV}) { | |||||
for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV, Mode::QCOND_LT_MOV}) { | |||||
printf("Benchmark mode: %d\n", (int)mode); | printf("Benchmark mode: %d\n", (int)mode); | ||||
bencher.set_param({mode}) | bencher.set_param({mode}) | ||||
.set_rng(0, &rng) | .set_rng(0, &rng) | ||||
@@ -59,6 +59,7 @@ Elemwise::Mode get_elem_mode(ElemwiseMultiType::Mode mode) { | |||||
MODE(FAST_TANH_GRAD); | MODE(FAST_TANH_GRAD); | ||||
MODE(ATAN2); | MODE(ATAN2); | ||||
MODE(COND_LEQ_MOV); | MODE(COND_LEQ_MOV); | ||||
MODE(COND_LT_MOV); | |||||
MODE(H_SWISH_GRAD); | MODE(H_SWISH_GRAD); | ||||
MODE(FUSE_ADD_H_SWISH); | MODE(FUSE_ADD_H_SWISH); | ||||
@@ -231,7 +232,9 @@ TEST_F(NAIVE, ELEMWISE_QUANTIZED_MODE_TERNARY) { | |||||
.set_dtype(1, dtype::QuantizedS8(0.2f)) | .set_dtype(1, dtype::QuantizedS8(0.2f)) | ||||
.set_dtype(2, dtype::QuantizedS8(0.3f)); | .set_dtype(2, dtype::QuantizedS8(0.3f)); | ||||
for (auto mode : {Param::Mode::QFUSE_MUL_ADD3, Param::Mode::QCOND_LEQ_MOV}) { | |||||
for (auto mode : | |||||
{Param::Mode::QFUSE_MUL_ADD3, Param::Mode::QCOND_LEQ_MOV, | |||||
Param::Mode::QCOND_LT_MOV}) { | |||||
Param param{mode}; | Param param{mode}; | ||||
checker.set_param(param); | checker.set_param(param); | ||||