From a1e67207560a8cbe271e2beee6d31923ff73ce0b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 20 Aug 2020 17:35:20 +0800 Subject: [PATCH] feat(dnn): enable bool comparison GitOrigin-RevId: 735693b81e46189db15be0b7f98fa64973c3035e --- dnn/scripts/gen_elemwise_utils.py | 2 +- dnn/src/common/elemwise/each_mode.inl | 3 +++ dnn/src/common/elemwise/kern_defs.cuh | 3 +++ dnn/src/cuda/elemwise/kimpl/EQ_dt_bool.cu | 15 +++++++++++++++ dnn/src/cuda/elemwise/kimpl/LEQ_dt_bool.cu | 15 +++++++++++++++ dnn/src/cuda/elemwise/kimpl/LT_dt_bool.cu | 15 +++++++++++++++ dnn/src/naive/elemwise/kimpl/EQ_dt_bool.cpp | 15 +++++++++++++++ dnn/src/naive/elemwise/kimpl/LEQ_dt_bool.cpp | 15 +++++++++++++++ dnn/src/naive/elemwise/kimpl/LT_dt_bool.cpp | 15 +++++++++++++++ src/opr/test/basic_arith/elemwise.cpp | 3 +++ src/opr/test/basic_arith/elemwise_binary_trait_def.inl | 11 +++++++---- 11 files changed, 107 insertions(+), 5 deletions(-) create mode 100644 dnn/src/cuda/elemwise/kimpl/EQ_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise/kimpl/LEQ_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise/kimpl/LT_dt_bool.cu create mode 100644 dnn/src/naive/elemwise/kimpl/EQ_dt_bool.cpp create mode 100644 dnn/src/naive/elemwise/kimpl/LEQ_dt_bool.cpp create mode 100644 dnn/src/naive/elemwise/kimpl/LT_dt_bool.cpp diff --git a/dnn/scripts/gen_elemwise_utils.py b/dnn/scripts/gen_elemwise_utils.py index 2fd6ca8f..5b48a7d0 100755 --- a/dnn/scripts/gen_elemwise_utils.py +++ b/dnn/scripts/gen_elemwise_utils.py @@ -30,6 +30,6 @@ MODES = { 'FUSE_ADD_H_SWISH'], (3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], (1, 'BOOL'): ['NOT'], - (2, 'BOOL'): ['AND', 'OR', 'XOR'], + (2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], (3, 'BOOL'): [] } diff --git a/dnn/src/common/elemwise/each_mode.inl b/dnn/src/common/elemwise/each_mode.inl index fca47da8..a9501afa 100644 --- a/dnn/src/common/elemwise/each_mode.inl +++ b/dnn/src/common/elemwise/each_mode.inl @@ -45,6 +45,9 @@ MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) \ #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ diff --git a/dnn/src/common/elemwise/kern_defs.cuh b/dnn/src/common/elemwise/kern_defs.cuh index 5a05384b..e5906411 100644 --- a/dnn/src/common/elemwise/kern_defs.cuh +++ b/dnn/src/common/elemwise/kern_defs.cuh @@ -173,6 +173,9 @@ namespace megdnn { DEF_KERN_ALL(LT, x < y); DEF_KERN_ALL(LEQ, x <= y); DEF_KERN_ALL(EQ, x == y); + DEF_KERN(dt_bool, LT, x < y); + DEF_KERN(dt_bool, LEQ, x <= y); + DEF_KERN(dt_bool, EQ, x == y); DEF_KERN_INT(FLOOR_DIV, x / y); DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y)); diff --git a/dnn/src/cuda/elemwise/kimpl/EQ_dt_bool.cu b/dnn/src/cuda/elemwise/kimpl/EQ_dt_bool.cu new file mode 100644 index 00000000..e52069cc --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/EQ_dt_bool.cu @@ -0,0 +1,15 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/EQ_dt_bool.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_bool.cu b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_bool.cu new file mode 100644 index 00000000..985b7a97 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_bool.cu @@ -0,0 +1,15 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/LEQ_dt_bool.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LT_dt_bool.cu b/dnn/src/cuda/elemwise/kimpl/LT_dt_bool.cu new file mode 100644 index 00000000..f3873934 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/LT_dt_bool.cu @@ -0,0 +1,15 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/LT_dt_bool.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/EQ_dt_bool.cpp b/dnn/src/naive/elemwise/kimpl/EQ_dt_bool.cpp new file mode 100644 index 00000000..1d3f3fab --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/EQ_dt_bool.cpp @@ -0,0 +1,15 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/EQ_dt_bool.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/LEQ_dt_bool.cpp b/dnn/src/naive/elemwise/kimpl/LEQ_dt_bool.cpp new file mode 100644 index 00000000..6a7848b1 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/LEQ_dt_bool.cpp @@ -0,0 +1,15 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/LEQ_dt_bool.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/LT_dt_bool.cpp b/dnn/src/naive/elemwise/kimpl/LT_dt_bool.cpp new file mode 100644 index 00000000..fad8b121 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/LT_dt_bool.cpp @@ -0,0 +1,15 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/LT_dt_bool.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool +#include "../kern_impl.inl" diff --git a/src/opr/test/basic_arith/elemwise.cpp b/src/opr/test/basic_arith/elemwise.cpp index 09b32566..6acd970e 100644 --- a/src/opr/test/basic_arith/elemwise.cpp +++ b/src/opr/test/basic_arith/elemwise.cpp @@ -812,6 +812,9 @@ TEST_OPR_BASIC_ARITH_UNARY_BOOL(NOT, !) TEST_OPR_BASIC_ARITH_BINARY_BOOL(AND, &&) TEST_OPR_BASIC_ARITH_BINARY_BOOL(OR, ||) TEST_OPR_BASIC_ARITH_BINARY_BOOL(XOR, ^) +TEST_OPR_BASIC_ARITH_BINARY_BOOL(LT, <) +TEST_OPR_BASIC_ARITH_BINARY_BOOL(LEQ, <=) +TEST_OPR_BASIC_ARITH_BINARY_BOOL(EQ, ==) TEST(TestOprBasicArithElemwise, FuseMulAdd3Shapes) { using Checker = AutoOprChecker<3, 1>; diff --git a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl index ba32cbf3..0a70f1b9 100644 --- a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl @@ -27,6 +27,13 @@ DEF_TRAIT(OR, x || y) DEF_TRAIT(XOR, x ^ y) #undef _ALLOW_INT #undef _ALLOW_FLOAT + +#define _ALLOW_INT true +#define _ALLOW_FLOAT true +DEF_TRAIT(EQ, x == y) +DEF_TRAIT(LEQ, x <= y) +DEF_TRAIT(LT, x < y) + #undef _ALLOW_BOOL #define _ALLOW_BOOL false @@ -44,10 +51,6 @@ DEF_TRAIT(SUB, x - y) DEF_TRAIT(SWITCH_GT0, x > 0 ? y : 0) DEF_TRAIT(TANH_GRAD, (1 - x * x) * y) -DEF_TRAIT(EQ, x == y) -DEF_TRAIT(LEQ, x <= y) -DEF_TRAIT(LT, x < y) - DEF_TRAIT(FUSE_ADD_RELU, std::max(x + y, 0)) #undef _ALLOW_INT