Browse Source

feat(dnn): enable bool comparison

GitOrigin-RevId: 735693b81e
tags/v1.0.0-rc1
Megvii Engine Team Xinran Xu 4 years ago
parent
commit
a1e6720756
11 changed files with 107 additions and 5 deletions
  1. +1
    -1
      dnn/scripts/gen_elemwise_utils.py
  2. +3
    -0
      dnn/src/common/elemwise/each_mode.inl
  3. +3
    -0
      dnn/src/common/elemwise/kern_defs.cuh
  4. +15
    -0
      dnn/src/cuda/elemwise/kimpl/EQ_dt_bool.cu
  5. +15
    -0
      dnn/src/cuda/elemwise/kimpl/LEQ_dt_bool.cu
  6. +15
    -0
      dnn/src/cuda/elemwise/kimpl/LT_dt_bool.cu
  7. +15
    -0
      dnn/src/naive/elemwise/kimpl/EQ_dt_bool.cpp
  8. +15
    -0
      dnn/src/naive/elemwise/kimpl/LEQ_dt_bool.cpp
  9. +15
    -0
      dnn/src/naive/elemwise/kimpl/LT_dt_bool.cpp
  10. +3
    -0
      src/opr/test/basic_arith/elemwise.cpp
  11. +7
    -4
      src/opr/test/basic_arith/elemwise_binary_trait_def.inl

+ 1
- 1
dnn/scripts/gen_elemwise_utils.py View File

@@ -30,6 +30,6 @@ MODES = {
'FUSE_ADD_H_SWISH'], 'FUSE_ADD_H_SWISH'],
(3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], (3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'],
(1, 'BOOL'): ['NOT'], (1, 'BOOL'): ['NOT'],
(2, 'BOOL'): ['AND', 'OR', 'XOR'],
(2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'],
(3, 'BOOL'): [] (3, 'BOOL'): []
} }

+ 3
- 0
dnn/src/common/elemwise/each_mode.inl View File

@@ -45,6 +45,9 @@
MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) \


#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb) \ #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \


+ 3
- 0
dnn/src/common/elemwise/kern_defs.cuh View File

@@ -173,6 +173,9 @@ namespace megdnn {
DEF_KERN_ALL(LT, x < y); DEF_KERN_ALL(LT, x < y);
DEF_KERN_ALL(LEQ, x <= y); DEF_KERN_ALL(LEQ, x <= y);
DEF_KERN_ALL(EQ, x == y); DEF_KERN_ALL(EQ, x == y);
DEF_KERN(dt_bool, LT, x < y);
DEF_KERN(dt_bool, LEQ, x <= y);
DEF_KERN(dt_bool, EQ, x == y);


DEF_KERN_INT(FLOOR_DIV, x / y); DEF_KERN_INT(FLOOR_DIV, x / y);
DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y)); DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y));


+ 15
- 0
dnn/src/cuda/elemwise/kimpl/EQ_dt_bool.cu View File

@@ -0,0 +1,15 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/EQ_dt_bool.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"

+ 15
- 0
dnn/src/cuda/elemwise/kimpl/LEQ_dt_bool.cu View File

@@ -0,0 +1,15 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/LEQ_dt_bool.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"

+ 15
- 0
dnn/src/cuda/elemwise/kimpl/LT_dt_bool.cu View File

@@ -0,0 +1,15 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/LT_dt_bool.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"

+ 15
- 0
dnn/src/naive/elemwise/kimpl/EQ_dt_bool.cpp View File

@@ -0,0 +1,15 @@
/**
* \file dnn/src/naive/elemwise/kimpl/EQ_dt_bool.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"

+ 15
- 0
dnn/src/naive/elemwise/kimpl/LEQ_dt_bool.cpp View File

@@ -0,0 +1,15 @@
/**
* \file dnn/src/naive/elemwise/kimpl/LEQ_dt_bool.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"

+ 15
- 0
dnn/src/naive/elemwise/kimpl/LT_dt_bool.cpp View File

@@ -0,0 +1,15 @@
/**
* \file dnn/src/naive/elemwise/kimpl/LT_dt_bool.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"

+ 3
- 0
src/opr/test/basic_arith/elemwise.cpp View File

@@ -812,6 +812,9 @@ TEST_OPR_BASIC_ARITH_UNARY_BOOL(NOT, !)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(AND, &&) TEST_OPR_BASIC_ARITH_BINARY_BOOL(AND, &&)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(OR, ||) TEST_OPR_BASIC_ARITH_BINARY_BOOL(OR, ||)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(XOR, ^) TEST_OPR_BASIC_ARITH_BINARY_BOOL(XOR, ^)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(LT, <)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(LEQ, <=)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(EQ, ==)


TEST(TestOprBasicArithElemwise, FuseMulAdd3Shapes) { TEST(TestOprBasicArithElemwise, FuseMulAdd3Shapes) {
using Checker = AutoOprChecker<3, 1>; using Checker = AutoOprChecker<3, 1>;


+ 7
- 4
src/opr/test/basic_arith/elemwise_binary_trait_def.inl View File

@@ -27,6 +27,13 @@ DEF_TRAIT(OR, x || y)
DEF_TRAIT(XOR, x ^ y) DEF_TRAIT(XOR, x ^ y)
#undef _ALLOW_INT #undef _ALLOW_INT
#undef _ALLOW_FLOAT #undef _ALLOW_FLOAT

#define _ALLOW_INT true
#define _ALLOW_FLOAT true
DEF_TRAIT(EQ, x == y)
DEF_TRAIT(LEQ, x <= y)
DEF_TRAIT(LT, x < y)

#undef _ALLOW_BOOL #undef _ALLOW_BOOL


#define _ALLOW_BOOL false #define _ALLOW_BOOL false
@@ -44,10 +51,6 @@ DEF_TRAIT(SUB, x - y)
DEF_TRAIT(SWITCH_GT0, x > 0 ? y : 0) DEF_TRAIT(SWITCH_GT0, x > 0 ? y : 0)
DEF_TRAIT(TANH_GRAD, (1 - x * x) * y) DEF_TRAIT(TANH_GRAD, (1 - x * x) * y)


DEF_TRAIT(EQ, x == y)
DEF_TRAIT(LEQ, x <= y)
DEF_TRAIT(LT, x < y)

DEF_TRAIT(FUSE_ADD_RELU, std::max<ctype>(x + y, 0)) DEF_TRAIT(FUSE_ADD_RELU, std::max<ctype>(x + y, 0))
#undef _ALLOW_INT #undef _ALLOW_INT




Loading…
Cancel
Save