@@ -30,6 +30,6 @@ MODES = { | |||||
'FUSE_ADD_H_SWISH'], | 'FUSE_ADD_H_SWISH'], | ||||
(3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | (3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | ||||
(1, 'BOOL'): ['NOT'], | (1, 'BOOL'): ['NOT'], | ||||
(2, 'BOOL'): ['AND', 'OR', 'XOR'], | |||||
(2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], | |||||
(3, 'BOOL'): [] | (3, 'BOOL'): [] | ||||
} | } |
@@ -45,6 +45,9 @@ | |||||
MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \ | ||||
MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) \ | ||||
MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) \ | ||||
MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) \ | |||||
MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) \ | |||||
MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) \ | |||||
#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb) \ | #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb) \ | ||||
MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ | ||||
@@ -173,6 +173,9 @@ namespace megdnn { | |||||
DEF_KERN_ALL(LT, x < y); | DEF_KERN_ALL(LT, x < y); | ||||
DEF_KERN_ALL(LEQ, x <= y); | DEF_KERN_ALL(LEQ, x <= y); | ||||
DEF_KERN_ALL(EQ, x == y); | DEF_KERN_ALL(EQ, x == y); | ||||
DEF_KERN(dt_bool, LT, x < y); | |||||
DEF_KERN(dt_bool, LEQ, x <= y); | |||||
DEF_KERN(dt_bool, EQ, x == y); | |||||
DEF_KERN_INT(FLOOR_DIV, x / y); | DEF_KERN_INT(FLOOR_DIV, x / y); | ||||
DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y)); | DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y)); | ||||
@@ -0,0 +1,15 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/EQ_dt_bool.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bool | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,15 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/LEQ_dt_bool.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bool | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,15 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/LT_dt_bool.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bool | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,15 @@ | |||||
/** | |||||
* \file dnn/src/naive/elemwise/kimpl/EQ_dt_bool.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bool | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,15 @@ | |||||
/** | |||||
* \file dnn/src/naive/elemwise/kimpl/LEQ_dt_bool.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bool | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,15 @@ | |||||
/** | |||||
* \file dnn/src/naive/elemwise/kimpl/LT_dt_bool.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bool | |||||
#include "../kern_impl.inl" |
@@ -812,6 +812,9 @@ TEST_OPR_BASIC_ARITH_UNARY_BOOL(NOT, !) | |||||
TEST_OPR_BASIC_ARITH_BINARY_BOOL(AND, &&) | TEST_OPR_BASIC_ARITH_BINARY_BOOL(AND, &&) | ||||
TEST_OPR_BASIC_ARITH_BINARY_BOOL(OR, ||) | TEST_OPR_BASIC_ARITH_BINARY_BOOL(OR, ||) | ||||
TEST_OPR_BASIC_ARITH_BINARY_BOOL(XOR, ^) | TEST_OPR_BASIC_ARITH_BINARY_BOOL(XOR, ^) | ||||
TEST_OPR_BASIC_ARITH_BINARY_BOOL(LT, <) | |||||
TEST_OPR_BASIC_ARITH_BINARY_BOOL(LEQ, <=) | |||||
TEST_OPR_BASIC_ARITH_BINARY_BOOL(EQ, ==) | |||||
TEST(TestOprBasicArithElemwise, FuseMulAdd3Shapes) { | TEST(TestOprBasicArithElemwise, FuseMulAdd3Shapes) { | ||||
using Checker = AutoOprChecker<3, 1>; | using Checker = AutoOprChecker<3, 1>; | ||||
@@ -27,6 +27,13 @@ DEF_TRAIT(OR, x || y) | |||||
DEF_TRAIT(XOR, x ^ y) | DEF_TRAIT(XOR, x ^ y) | ||||
#undef _ALLOW_INT | #undef _ALLOW_INT | ||||
#undef _ALLOW_FLOAT | #undef _ALLOW_FLOAT | ||||
#define _ALLOW_INT true | |||||
#define _ALLOW_FLOAT true | |||||
DEF_TRAIT(EQ, x == y) | |||||
DEF_TRAIT(LEQ, x <= y) | |||||
DEF_TRAIT(LT, x < y) | |||||
#undef _ALLOW_BOOL | #undef _ALLOW_BOOL | ||||
#define _ALLOW_BOOL false | #define _ALLOW_BOOL false | ||||
@@ -44,10 +51,6 @@ DEF_TRAIT(SUB, x - y) | |||||
DEF_TRAIT(SWITCH_GT0, x > 0 ? y : 0) | DEF_TRAIT(SWITCH_GT0, x > 0 ? y : 0) | ||||
DEF_TRAIT(TANH_GRAD, (1 - x * x) * y) | DEF_TRAIT(TANH_GRAD, (1 - x * x) * y) | ||||
DEF_TRAIT(EQ, x == y) | |||||
DEF_TRAIT(LEQ, x <= y) | |||||
DEF_TRAIT(LT, x < y) | |||||
DEF_TRAIT(FUSE_ADD_RELU, std::max<ctype>(x + y, 0)) | DEF_TRAIT(FUSE_ADD_RELU, std::max<ctype>(x + y, 0)) | ||||
#undef _ALLOW_INT | #undef _ALLOW_INT | ||||