From e258812f1260766458e798ec92837b31d9b329c8 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 22 Jul 2020 19:01:26 +0800 Subject: [PATCH] feat(dnn): add bool dtype GitOrigin-RevId: 98c8a092b4872f4f0ef6068c58e47871d9f042be --- dnn/include/megdnn/dtype.h | 12 +++-- dnn/include/megdnn/oprs/general.h | 3 +- dnn/scripts/gen_elemwise_utils.py | 4 ++ dnn/scripts/opr_param_defs.py | 7 ++- dnn/src/common/cond_take/predicate.cuh | 1 + dnn/src/common/elemwise/each_mode.inl | 8 +++ dnn/src/common/elemwise/kern_defs.cuh | 4 ++ dnn/src/common/elemwise/opr_impl.cpp | 18 ++++++- dnn/src/common/elemwise/opr_impl_body.inl | 19 ++++++- dnn/src/common/elemwise/opr_impl_class_def.inl | 3 ++ dnn/src/common/type_cvt.cpp | 6 ++- dnn/src/cuda/cond_take/kimpl/dt_bool.cu | 27 ++++++++++ dnn/src/cuda/elemwise/kern_wrapper.cuh | 30 ++++++----- dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu | 15 ++++++ dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu | 15 ++++++ dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu | 15 ++++++ dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu | 15 ++++++ dnn/src/cuda/elemwise_helper.cpp | 7 +++ dnn/src/cuda/elemwise_helper.cuh | 4 ++ dnn/src/cuda/type_cvt/kern.cu | 14 +++-- dnn/src/fallback/type_cvt/opr_impl.cpp | 11 ++-- dnn/src/naive/elemwise/kimpl/AND_dt_bool.cpp | 15 ++++++ dnn/src/naive/elemwise/kimpl/NOT_dt_bool.cpp | 15 ++++++ dnn/src/naive/elemwise/kimpl/OR_dt_bool.cpp | 15 ++++++ dnn/src/naive/elemwise/kimpl/XOR_dt_bool.cpp | 15 ++++++ dnn/src/naive/type_cvt/opr_impl.cpp | 2 + dnn/test/common/elemwise.cpp | 2 + python_module/src/cpp/megbrain_wrap.cpp | 1 + src/core/include/megbrain/dtype.h | 1 + src/jit/impl/ast_c.cpp | 4 +- src/jit/impl/halide/ast_hl.cpp | 8 +++ src/opr/impl/basic_arith.cpp | 6 +++ src/opr/impl/loop/forward.cpp | 2 + src/opr/impl/loop/impl.cpp | 2 + src/opr/include/megbrain/opr/basic_arith_wrapper.h | 4 ++ src/opr/test/basic_arith/elemwise.cpp | 63 +++++++++++++++++++++- .../test/basic_arith/elemwise_binary_trait_def.inl | 12 +++++ .../basic_arith/elemwise_ternary_trait_def.inl | 2 + .../test/basic_arith/elemwise_unary_trait_def.inl | 11 ++++ src/serialization/impl/dtype.fbs | 1 + test/src/helper.cpp | 15 ++++++ test/src/include/megbrain/test/helper.h | 22 ++++++++ 42 files changed, 424 insertions(+), 32 deletions(-) create mode 100644 dnn/src/cuda/cond_take/kimpl/dt_bool.cu create mode 100644 dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu create mode 100644 dnn/src/naive/elemwise/kimpl/AND_dt_bool.cpp create mode 100644 dnn/src/naive/elemwise/kimpl/NOT_dt_bool.cpp create mode 100644 dnn/src/naive/elemwise/kimpl/OR_dt_bool.cpp create mode 100644 dnn/src/naive/elemwise/kimpl/XOR_dt_bool.cpp diff --git a/dnn/include/megdnn/dtype.h b/dnn/include/megdnn/dtype.h index bd11994e..bdfff413 100644 --- a/dnn/include/megdnn/dtype.h +++ b/dnn/include/megdnn/dtype.h @@ -52,6 +52,7 @@ namespace megdnn { MEGDNN_INC_FLOAT16(cb(Float16)) \ MEGDNN_INC_FLOAT16(cb(BFloat16)) \ cb(UintB4) \ + cb(Bool) \ /*! * \brief iterate through each full byte dtype @@ -65,6 +66,7 @@ namespace megdnn { cb(Byte) \ MEGDNN_INC_FLOAT16(cb(Float16)) \ MEGDNN_INC_FLOAT16(cb(BFloat16)) \ + cb(Bool) \ /*! 
* \brief iterate through each fractional byte dtype @@ -122,7 +124,7 @@ namespace megdnn { */ #define MEGDNN_FOREACH_COMPUTING_DTYPE(cb) \ MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \ - MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) + MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) \ //! In order to avoid an unnecessary increase in binary size, we just //! use QuantizedS16 dtype in winograd_filter_preprocess now. So I didn't add @@ -348,6 +350,7 @@ typedef int32_t dt_int32; typedef int16_t dt_int16; typedef int8_t dt_int8; typedef uint8_t dt_uint8; +typedef bool dt_bool; MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;) MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;) @@ -375,7 +378,7 @@ MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;) #if !MEGDNN_DISABLE_FLOAT16 BFloat16 = 11, #endif - + Bool = 12, #define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE, #define D(_name) _name, MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(FST, D) @@ -392,7 +395,7 @@ MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;) #if MEGDNN_CC_HOST //! dtype numeric category fo enum class DTypeCategory: int { - OTHER, FLOAT, INT, LOWBIT, QUANTIZED + OTHER, FLOAT, INT, LOWBIT, QUANTIZED, BOOL }; //! dtype signedness enum class DTypeSignedness: int { @@ -401,7 +404,7 @@ MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;) #else struct DTypeCategory { enum Ev { - OTHER, FLOAT, INT, LOWBIT, QUANTIZED + OTHER, FLOAT, INT, LOWBIT, QUANTIZED, BOOL }; int ev; }; @@ -707,6 +710,7 @@ MEGDNN_DEF_DT(Int32, dt_int32, INT, SIGNED, INT32_MIN, INT32_MAX); MEGDNN_DEF_DT(Int16, dt_int16, INT, SIGNED, INT16_MIN, INT16_MAX); MEGDNN_DEF_DT(Int8, dt_int8, INT, SIGNED, INT8_MIN, INT8_MAX); MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX); +MEGDNN_DEF_DT(Bool, dt_bool, BOOL, UNSIGNED, false, true); MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED, std::numeric_limits::lowest(), std::numeric_limits::max())); diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index acc4a75b..490c8e9f 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -39,11 +39,12 @@ class ElemwiseForward: public OperatorBase { bool commutable; //!< whether arity == 2 and inputs commutable bool allow_int; //!< whether int inputs allowed bool allow_float; //!< whether float inputs allowed + bool allow_bool; //!< whether bool inputs allowed const char* name; //!< name of the mode ModeTrait(): - arity(0), commutable(0), allow_int(0), allow_float(0), + arity(0), commutable(0), allow_int(0), allow_float(0), allow_bool(0), name(NULL) {} diff --git a/dnn/scripts/gen_elemwise_utils.py b/dnn/scripts/gen_elemwise_utils.py index f6968f0e..2fd6ca8f 100755 --- a/dnn/scripts/gen_elemwise_utils.py +++ b/dnn/scripts/gen_elemwise_utils.py @@ -5,6 +5,7 @@ DTYPES = {'dt_int32': ('Int32', 'INT'), 'dt_uint8': ('Uint8', 'INT'), 'dt_int8': ('Int8', 'INT'), 'dt_int16': ('Int16', 'INT'), + 'dt_bool': ('Bool', 'BOOL'), 'dt_float32': ('Float32', 'FLOAT'), 'dt_float16': ('Float16', 'FLOAT'), 'dt_bfloat16': ('BFloat16', 'FLOAT') @@ -28,4 +29,7 @@ MODES = { 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', 'FUSE_ADD_H_SWISH'], (3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], + (1, 'BOOL'): ['NOT'], + (2, 'BOOL'): ['AND', 'OR', 'XOR'], + (3, 'BOOL'): [] } diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index b3ef8155..3d16e178 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -314,7 +314,12 @@ 
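The hunk that follows registers four new logical modes on the Elemwise opr. For reference, their scalar semantics, mirroring the DEF_KERN(dt_bool, ...) definitions added to kern_defs.cuh further down in this patch, can be sketched as below; the helper names are illustrative only and are not part of the patch:

    typedef bool dt_bool;
    inline dt_bool do_not(dt_bool x) { return !x; }                 // NOT; kern_defs.cuh uses x ^ 1, equivalent for bool
    inline dt_bool do_and(dt_bool x, dt_bool y) { return x && y; }  // AND
    inline dt_bool do_or(dt_bool x, dt_bool y)  { return x || y; }  // OR
    inline dt_bool do_xor(dt_bool x, dt_bool y) { return x ^ y; }   // XOR; on bool this equals x != y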
pdef('Elemwise').add_enum( Doc('ERFCINV', 'unary: inverse function of erfc(x)'), Doc('H_SWISH', 'unary: x * clip(x + 3, 0, 6) / 6'), Doc('H_SWISH_GRAD', 'binary: x < -3 ? 0 : (x > 3 ? y : (2 * x + 3) / 6 * y)'), - Doc('FUSE_ADD_H_SWISH', 'binary: hswish(x+y)') + Doc('FUSE_ADD_H_SWISH', 'binary: hswish(x+y)'), + + Doc('NOT', 'unary: !x'), + Doc('AND', 'binary: x && y'), + Doc('OR', 'binary: x || y'), + Doc('XOR', 'binary: x ^ y') ) pdef('ElemwiseMultiType').add_enum( diff --git a/dnn/src/common/cond_take/predicate.cuh b/dnn/src/common/cond_take/predicate.cuh index 75359a6f..a00aa476 100644 --- a/dnn/src/common/cond_take/predicate.cuh +++ b/dnn/src/common/cond_take/predicate.cuh @@ -68,6 +68,7 @@ namespace cond_take { #define inst_eq_i(_dt) do_inst_eq_i(DTypeTrait<_dt>::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(inst_eq_f) MEGDNN_FOREACH_COMPUTING_DTYPE_INT(inst_eq_i) + inst_eq_i(::megdnn::dtype::Bool) #undef inst_eq_f #undef inst_eq_i diff --git a/dnn/src/common/elemwise/each_mode.inl b/dnn/src/common/elemwise/each_mode.inl index 52fa48b9..fca47da8 100644 --- a/dnn/src/common/elemwise/each_mode.inl +++ b/dnn/src/common/elemwise/each_mode.inl @@ -9,6 +9,9 @@ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ // generated by gen_elemwise_each_mode.py +#define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(NOT, cb) \ + #define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \ @@ -38,6 +41,11 @@ MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) \ +#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) \ + #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) \ diff --git a/dnn/src/common/elemwise/kern_defs.cuh b/dnn/src/common/elemwise/kern_defs.cuh index 02f2eabf..5a05384b 100644 --- a/dnn/src/common/elemwise/kern_defs.cuh +++ b/dnn/src/common/elemwise/kern_defs.cuh @@ -139,6 +139,7 @@ namespace megdnn { DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f)); // int only + DEF_KERN(dt_bool, NOT, x ^ 1); #undef KERN_SIG @@ -156,6 +157,9 @@ namespace megdnn { DEF_KERN_ALL(MAX, x > y ? x : y); DEF_KERN_ALL(MIN, x < y ? 
x : y); DEF_KERN_ALL(MUL, x* y); + DEF_KERN(dt_bool, AND, x && y); + DEF_KERN(dt_bool, OR, x || y); + DEF_KERN(dt_bool, XOR, x ^ y); DEF_KERN_INT(RMULH, round_mulh_saturate(x, y)); DEF_KERN_ALL(SIGMOID_GRAD, x*(ctype(1) - x) * y); DEF_KERN_ALL(SUB, x - y); diff --git a/dnn/src/common/elemwise/opr_impl.cpp b/dnn/src/common/elemwise/opr_impl.cpp index 5ba020b7..0881934b 100644 --- a/dnn/src/common/elemwise/opr_impl.cpp +++ b/dnn/src/common/elemwise/opr_impl.cpp @@ -74,6 +74,15 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { #define cb(_m) \ MIDOUT_BEGIN(megdnn_common_elemwise, midout_iv(Mode::_m)) { \ + get(Mode::_m).allow_bool = true; \ + } \ + MIDOUT_END(); + MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb); + MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb); +#undef cb + +#define cb(_m) \ + MIDOUT_BEGIN(megdnn_common_elemwise, midout_iv(Mode::_m)) { \ auto&& t = get(Mode::_m); \ t.arity = _a; \ t.name = megdnn_mangle(#_m); \ @@ -82,10 +91,12 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { #define _a 1 MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb); MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb); + MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb); #undef _a #define _a 2 MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb); MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb); + MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb); #undef _a #define _a 3 MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb); @@ -98,6 +109,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { auto&& t = get(Mode::_m); \ t.allow_int = true; \ t.allow_float = true; \ + t.allow_bool = true; \ t.arity = _arity; \ t.name = megdnn_mangle(#_m); \ } \ @@ -129,7 +141,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { #if MEGDNN_ELEMWISE_MODE_ENABLE_ALL for (auto&& i : traits) { - megdnn_assert(i.arity && (i.allow_int || i.allow_float) && + megdnn_assert(i.arity && (i.allow_int || i.allow_float || i.allow_bool) && (!i.commutable || i.arity == 2)); } #else @@ -282,6 +294,10 @@ void ElemwiseForward::check_dtype(DType dtype) { megdnn_assert(trait.allow_int, "unsupport mode %s for int\n", trait.name); break; + case DTypeCategory::BOOL: + megdnn_assert(trait.allow_bool, "unsupport mode %s for bool\n", + trait.name); + break; default: megdnn_throw("bad dtype"); } diff --git a/dnn/src/common/elemwise/opr_impl_body.inl b/dnn/src/common/elemwise/opr_impl_body.inl index 7cbcaaa6..4ea99efe 100644 --- a/dnn/src/common/elemwise/opr_impl_body.inl +++ b/dnn/src/common/elemwise/opr_impl_body.inl @@ -18,6 +18,15 @@ void ElemwiseForwardImpl::on_arity_dispatched() { auto src = make_elemwise_op_param(); MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(on_arity_dispatched_cb_dtype) MEGDNN_FOREACH_COMPUTING_DTYPE_INT(on_arity_dispatched_cb_dtype) + on_arity_dispatched_cb_dtype(::megdnn::dtype::Bool) + megdnn_throw("bad dtype"); +} + +template +void ElemwiseForwardImpl::on_arity_dispatched_no_bool() { + auto src = make_elemwise_op_param(); + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(on_arity_dispatched_cb_dtype) + MEGDNN_FOREACH_COMPUTING_DTYPE_INT(on_arity_dispatched_cb_dtype) megdnn_throw("bad dtype"); } @@ -45,6 +54,14 @@ IMPL_MODE_DISPATCHER(2, DTypeCategory::FLOAT); IMPL_MODE_DISPATCHER(3, DTypeCategory::FLOAT); #undef FOREACH +#define FOREACH MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL +IMPL_MODE_DISPATCHER(1, DTypeCategory::BOOL); +#undef FOREACH + +#define FOREACH MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL +IMPL_MODE_DISPATCHER(2, DTypeCategory::BOOL); +#undef FOREACH + void ElemwiseForwardImpl::exec( const TensorNDArray &src, _megdnn_tensor_out dst) { @@ -97,8 +114,8 @@ void 
ElemwiseForwardImpl::exec( #define D(_n) case _n: return on_arity_dispatched<_n>() D(1); D(2); - D(3); #undef D + case 3: return on_arity_dispatched_no_bool<3>(); default: megdnn_throw("bad size of input tensors"); } diff --git a/dnn/src/common/elemwise/opr_impl_class_def.inl b/dnn/src/common/elemwise/opr_impl_class_def.inl index cab89521..c7d60ee2 100644 --- a/dnn/src/common/elemwise/opr_impl_class_def.inl +++ b/dnn/src/common/elemwise/opr_impl_class_def.inl @@ -13,6 +13,9 @@ template void on_arity_dispatched(); + template + void on_arity_dispatched_no_bool(); + template struct ModeDispatcher; diff --git a/dnn/src/common/type_cvt.cpp b/dnn/src/common/type_cvt.cpp index 885a81b5..5a9cf995 100644 --- a/dnn/src/common/type_cvt.cpp +++ b/dnn/src/common/type_cvt.cpp @@ -19,10 +19,12 @@ void TypeCvt::check_exec(const TensorLayout &src, const TensorLayout &dst) { megdnn_assert_eq_shape(src, dst); auto cat = src.dtype.category(); megdnn_assert(cat == DTypeCategory::FLOAT || cat == DTypeCategory::INT || - cat == DTypeCategory::QUANTIZED); + cat == DTypeCategory::QUANTIZED || + cat == DTypeCategory::BOOL); cat = dst.dtype.category(); megdnn_assert(cat == DTypeCategory::FLOAT || cat == DTypeCategory::INT || - cat == DTypeCategory::QUANTIZED); + cat == DTypeCategory::QUANTIZED || + cat == DTypeCategory::BOOL); } } // namespace megdnn diff --git a/dnn/src/cuda/cond_take/kimpl/dt_bool.cu b/dnn/src/cuda/cond_take/kimpl/dt_bool.cu new file mode 100644 index 00000000..524ccc77 --- /dev/null +++ b/dnn/src/cuda/cond_take/kimpl/dt_bool.cu @@ -0,0 +1,27 @@ +/** + * \file dnn/src/cuda/cond_take/kimpl/dt_bool.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_cond_take_kern_impls.py +#include "../kern.inl" + +namespace megdnn { +namespace cuda { +namespace cond_take { + +inst_genidx(::megdnn::dtype::Bool) +#undef inst_genidx + +inst_copy(::megdnn::dtype::Bool) +#undef inst_copy +#undef inst_copy_ + +} // cond_take +} // cuda +} // megdnn diff --git a/dnn/src/cuda/elemwise/kern_wrapper.cuh b/dnn/src/cuda/elemwise/kern_wrapper.cuh index 5f666ffc..441cc163 100644 --- a/dnn/src/cuda/elemwise/kern_wrapper.cuh +++ b/dnn/src/cuda/elemwise/kern_wrapper.cuh @@ -25,8 +25,9 @@ namespace cuda { 1, KernImpl, typename std::enable_if< !std::is_same::value && - !std::is_same::value>::type> { + !std::is_same::value && + !std::is_same::value>::type> { typedef typename KernImpl::ctype ctype; ctype* dst; @@ -41,8 +42,9 @@ namespace cuda { 2, KernImpl, typename std::enable_if< !std::is_same::value && - !std::is_same::value>::type> { + !std::is_same::value && + !std::is_same::value>::type> { typedef typename KernImpl::ctype ctype; ctype* dst; @@ -57,8 +59,9 @@ namespace cuda { 3, KernImpl, typename std::enable_if< !std::is_same::value && - !std::is_same::value>::type> { + !std::is_same::value && + !std::is_same::value>::type> { typedef typename KernImpl::ctype ctype; ctype* dst; @@ -74,8 +77,9 @@ namespace cuda { 1, KernImpl, typename std::enable_if< std::is_same::value || - std::is_same::value>::type> { + std::is_same::value || + std::is_same::value>::type> { typedef typename KernImpl::ctype ctype; using VectTypeTrait = elemwise_intl::VectTypeTrait; typedef typename VectTypeTrait::vect_type vect_type; @@ -99,8 +103,9 @@ namespace cuda { 2, KernImpl, typename std::enable_if< std::is_same::value || - std::is_same::value>::type> { + std::is_same::value || + std::is_same::value>::type> { typedef typename KernImpl::ctype ctype; using VectTypeTrait = elemwise_intl::VectTypeTrait; typedef typename VectTypeTrait::vect_type vect_type; @@ -126,8 +131,9 @@ namespace cuda { 3, KernImpl, typename std::enable_if< std::is_same::value || - std::is_same::value>::type> { + std::is_same::value || + std::is_same::value>::type> { typedef typename KernImpl::ctype ctype; using VectTypeTrait = elemwise_intl::VectTypeTrait; typedef typename VectTypeTrait::vect_type vect_type; diff --git a/dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu b/dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu new file mode 100644 index 00000000..0f952baf --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu @@ -0,0 +1,15 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu b/dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu new file mode 100644 index 00000000..6774e133 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu @@ -0,0 +1,15 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NOT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bool +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu b/dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu new file mode 100644 index 00000000..cab52df0 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu @@ -0,0 +1,15 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu b/dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu new file mode 100644 index 00000000..0e7a2192 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu @@ -0,0 +1,15 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_helper.cpp b/dnn/src/cuda/elemwise_helper.cpp index b2295fba..bcec4f0e 100644 --- a/dnn/src/cuda/elemwise_helper.cpp +++ b/dnn/src/cuda/elemwise_helper.cpp @@ -169,6 +169,9 @@ INST_FOR_CTYPE #define ct dt_qint32 INST_FOR_CTYPE #undef ct +#define ct dt_bool +INST_FOR_CTYPE +#undef ct #undef INST_FOR_CTYPE #undef INST @@ -216,6 +219,9 @@ INST_FOR_CTYPE #define ct dt_qint32 INST_FOR_CTYPE #undef ct +#define ct dt_bool +INST_FOR_CTYPE +#undef ct #undef ndim_cb @@ -225,6 +231,7 @@ INST_FOR_CTYPE #define INST(dt_ibyte) template class ParamVectVisitor<4, dt_ibyte, BCAST_1010> INST(dt_int8); INST(dt_uint8); +INST(dt_bool); INST(dt_qint8); INST(dt_quint8); #undef dt_ibyte diff --git a/dnn/src/cuda/elemwise_helper.cuh b/dnn/src/cuda/elemwise_helper.cuh index 471591d2..94e184f6 100644 --- a/dnn/src/cuda/elemwise_helper.cuh +++ b/dnn/src/cuda/elemwise_helper.cuh @@ -102,6 +102,7 @@ INST(dt_float16, half4); INST(dt_bfloat16, bhalf4); INST(dt_int32, int4); INST(dt_int16, short4); +INST(dt_bool, uchar4); #undef as_raw #define as_raw(x) x.as_int8() INST(dt_qint8, char4); @@ -454,6 +455,7 @@ INST_DT_IBYTE(dt_int8); INST_DT_IBYTE(dt_uint8); INST_DT_IBYTE(dt_qint8); INST_DT_IBYTE(dt_quint8); +INST_DT_IBYTE(dt_bool); #undef INST_DT_IBYTE #undef DEVICE_WRAPPER #undef INST_PARAM_VECT_VISITOR @@ -913,6 +915,7 @@ INST_DT_IBYTE(dt_int8); INST_DT_IBYTE(dt_uint8); INST_DT_IBYTE(dt_qint8); INST_DT_IBYTE(dt_quint8); +INST_DT_IBYTE(dt_bool); #undef INST_DT_IBYTE //! implement general case by UserOpInvokerToSameNdim @@ -1259,6 +1262,7 @@ INST_DT_IBYTE(dt_int8); INST_DT_IBYTE(dt_uint8); INST_DT_IBYTE(dt_qint8); INST_DT_IBYTE(dt_quint8); +INST_DT_IBYTE(dt_bool); #undef INST_DT_IBYTE #endif diff --git a/dnn/src/cuda/type_cvt/kern.cu b/dnn/src/cuda/type_cvt/kern.cu index dd52b8c7..1d9131eb 100644 --- a/dnn/src/cuda/type_cvt/kern.cu +++ b/dnn/src/cuda/type_cvt/kern.cu @@ -62,7 +62,8 @@ template struct TypeCvtOp::value || - std::is_same::value>::type> { + std::is_same::value || + std::is_same::value>::type> { ctype_dest* dest; using src_vect_type = typename VectTypeTrait::vect_type; using dst_vect_type = typename VectTypeTrait::vect_type; @@ -85,7 +86,8 @@ struct TypeCvtOpToQuantized< ctype_dest, ctype_src, typename std::enable_if< std::is_same::value || - std::is_same::value>::type> { + std::is_same::value || + std::is_same::value>::type> { ctype_dest* dest; CudaDTypeParam param; using src_vect_type = typename VectTypeTrait::vect_type; @@ -109,7 +111,8 @@ struct TypeCvtOpFromQuantized< ctype_dest, ctype_src, typename std::enable_if< std::is_same::value || - std::is_same::value>::type> { + std::is_same::value || + std::is_same::value>::type> { ctype_dest* dest; CudaDTypeParam param; using src_vect_type = typename VectTypeTrait::vect_type; @@ -137,7 +140,8 @@ struct TypeCvtOpBetweenQuantized< ctype_dest, ctype_src, typename std::enable_if< std::is_same::value || - std::is_same::value>::type> { + std::is_same::value || + std::is_same::value>::type> { ctype_dest* dest; CudaDTypeParam src_param; CudaDTypeParam dst_param; @@ -243,6 +247,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cb(dtype_src, dt_float32) \ cb(dtype_src, dt_float16) \ cb(dtype_src, dt_bfloat16) \ + cb(dtype_src, dt_bool) \ #define MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \ 
cb(dtype_src, dt_quint8) \ @@ -265,6 +270,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cb(dt_float32) \ cb(dt_float16) \ cb(dt_bfloat16) \ + cb(dt_bool) \ #define MEGDNN_FOREACH_QUANTIZED_CTYPE(cb) \ cb(dt_quint8) \ diff --git a/dnn/src/fallback/type_cvt/opr_impl.cpp b/dnn/src/fallback/type_cvt/opr_impl.cpp index 4c952130..04020f9d 100644 --- a/dnn/src/fallback/type_cvt/opr_impl.cpp +++ b/dnn/src/fallback/type_cvt/opr_impl.cpp @@ -138,7 +138,8 @@ void do_cvt_s8_normal(_megdnn_tensor_in src, _megdnn_tensor_out dst) { dctype* __restrict dptr = dst.ptr(); float scale = src.layout.dtype.param().scale; for (size_t i = 0; i < n; ++i) { - dptr[i] = static_cast(sptr[i] * scale); + auto val = sptr[i] * scale; + dptr[i] = static_cast(val); } } @@ -150,7 +151,8 @@ void do_cvt_s32_normal(_megdnn_tensor_in src, _megdnn_tensor_out dst) { dctype* __restrict dptr = dst.ptr(); float scale = src.layout.dtype.param().scale; for (size_t i = 0; i < n; ++i) { - dptr[i] = static_cast(sptr[i] * scale); + auto val = sptr[i] * scale; + dptr[i] = static_cast(val); } } @@ -163,7 +165,8 @@ void do_cvt_asymm8_normal(_megdnn_tensor_in src, _megdnn_tensor_out dst) { float scale = src.layout.dtype.param().scale; uint8_t zp = src.layout.dtype.param().zero_point; for (size_t i = 0; i < n; ++i) { - dptr[i] = static_cast((sptr[i] - zp) * scale); + auto val = (sptr[i] - zp) * scale; + dptr[i] = static_cast(val); } } @@ -310,6 +313,7 @@ void on_dest_ctype(_megdnn_tensor_in src, _megdnn_tensor_out dst) { break; \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) case DTypeEnum::QuantizedS8: MIDOUT_BEGIN(megdnn_fb_typecvt_src_dtype, midout_iv(DTypeEnum::QuantizedS8)) { @@ -467,6 +471,7 @@ void run_contiguous(_megdnn_tensor_in src, _megdnn_tensor_out dst) { } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) case DTypeEnum::QuantizedS8: MIDOUT_BEGIN(megdnn_fb_typecvt_dst_dtype, midout_iv(DTypeEnum::QuantizedS8)) { diff --git a/dnn/src/naive/elemwise/kimpl/AND_dt_bool.cpp b/dnn/src/naive/elemwise/kimpl/AND_dt_bool.cpp new file mode 100644 index 00000000..605eb17d --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/AND_dt_bool.cpp @@ -0,0 +1,15 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/AND_dt_bool.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/NOT_dt_bool.cpp b/dnn/src/naive/elemwise/kimpl/NOT_dt_bool.cpp new file mode 100644 index 00000000..33408885 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/NOT_dt_bool.cpp @@ -0,0 +1,15 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/NOT_dt_bool.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NOT, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bool +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/OR_dt_bool.cpp b/dnn/src/naive/elemwise/kimpl/OR_dt_bool.cpp new file mode 100644 index 00000000..f47ff571 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/OR_dt_bool.cpp @@ -0,0 +1,15 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/OR_dt_bool.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/kimpl/XOR_dt_bool.cpp b/dnn/src/naive/elemwise/kimpl/XOR_dt_bool.cpp new file mode 100644 index 00000000..b101c06d --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/XOR_dt_bool.cpp @@ -0,0 +1,15 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/XOR_dt_bool.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool +#include "../kern_impl.inl" diff --git a/dnn/src/naive/type_cvt/opr_impl.cpp b/dnn/src/naive/type_cvt/opr_impl.cpp index 6c0a06b8..87f6133a 100644 --- a/dnn/src/naive/type_cvt/opr_impl.cpp +++ b/dnn/src/naive/type_cvt/opr_impl.cpp @@ -82,6 +82,7 @@ void on_dest_ctype(HandleImpl* handle, const TensorND& dest, MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) + cb(::megdnn::dtype::Bool) #undef cb default: megdnn_throw("bad dtype"); @@ -103,6 +104,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) + cb(::megdnn::dtype::Bool) #undef cb default: megdnn_throw("bad dtype"); diff --git a/dnn/test/common/elemwise.cpp b/dnn/test/common/elemwise.cpp index ee0d7eb6..523ff206 100644 --- a/dnn/test/common/elemwise.cpp +++ b/dnn/test/common/elemwise.cpp @@ -942,6 +942,8 @@ TEST(TEST_ELEMWISE, MODE_TRAIT) { ASSERT_TRUE(T::from_mode(M::RMULH).commutable); ASSERT_FALSE(T::from_mode(M::RMULH).allow_float); + + ASSERT_TRUE(T::from_mode(M::XOR).allow_bool); } } // namespace elemwise diff --git a/python_module/src/cpp/megbrain_wrap.cpp b/python_module/src/cpp/megbrain_wrap.cpp index 9f9becf0..466a9604 100644 --- a/python_module/src/cpp/megbrain_wrap.cpp +++ b/python_module/src/cpp/megbrain_wrap.cpp @@ -916,6 +916,7 @@ SymbolVar fill_retain_dtype(SymbolVar var, PyObject *value) { case DTypeEnum::QuantizedS4: case DTypeEnum::Byte: case DTypeEnum::QuantizedS16: + case DTypeEnum::Bool: break; #define cb(low_bit, size) \ case DTypeEnum::low_bit##size: \ diff 
--git a/src/core/include/megbrain/dtype.h b/src/core/include/megbrain/dtype.h index 57d2955c..52fd71db 100644 --- a/src/core/include/megbrain/dtype.h +++ b/src/core/include/megbrain/dtype.h @@ -27,6 +27,7 @@ using ::megdnn::dt_int32; using ::megdnn::dt_quint8; using ::megdnn::dt_qint8; using ::megdnn::dt_qint32; +using ::megdnn::dt_bool; using ::megdnn::DType; using ::megdnn::DTypeEnum; using ::megdnn::DTypeTrait; diff --git a/src/jit/impl/ast_c.cpp b/src/jit/impl/ast_c.cpp index 4ed30928..95143db3 100644 --- a/src/jit/impl/ast_c.cpp +++ b/src/jit/impl/ast_c.cpp @@ -145,9 +145,9 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() { 0.f}) / 6.f), }; - mgb_assert(map.size() + 8 == opr::Elemwise::Param::MODE_NR_MEMBER); + mgb_assert(map.size() + 12 == opr::Elemwise::Param::MODE_NR_MEMBER); // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH, - // ERFINV, ERFCINV + // ERFINV, ERFCINV, NOT, AND, OR, XOR return map; #undef ADD_OPR } diff --git a/src/jit/impl/halide/ast_hl.cpp b/src/jit/impl/halide/ast_hl.cpp index cf0249c1..f67d2210 100644 --- a/src/jit/impl/halide/ast_hl.cpp +++ b/src/jit/impl/halide/ast_hl.cpp @@ -193,6 +193,14 @@ Halide::Expr dispatch_elemwise_mode( return Halide::round(inp(0)); case Mode::RMULH: return (inp(0) * inp(1)) >> Halide::popcount(inp(0)); + case Mode::NOT: + return cv(1) - cv(inp(0) != cv(0)); + case Mode::AND: + return cv(inp(0) != cv(0)) * cv(inp(1) != cv(0)); + case Mode::OR: + return cv(cv(inp(0) != cv(0)) + cv(inp(1) != cv(0)) > cv(0)); + case Mode::XOR: + return cv(cv(inp(0) != cv(0)) + cv(inp(1) != cv(0)) == cv(1)); default: mgb_throw(InternalError, "unsupported Elemwise mode(%d)", static_cast(mode)); diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index d3a00c29..d9195544 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -631,6 +631,8 @@ MGB_IMPL_OPR_GRAD(Elemwise) { RET(EL2(H_SWISH_GRAD, i0, og)); case Mode::FUSE_ADD_H_SWISH: RET(EL2(H_SWISH_GRAD, (i0 + i1), og)); + case Mode::NOT: + return nullptr; // binary case Mode::ABS_GRAD: @@ -693,6 +695,10 @@ MGB_IMPL_OPR_GRAD(Elemwise) { return nullptr; case Mode::EQ: RET_INVALID(); + case Mode::OR: + case Mode::XOR: + case Mode::AND: + return nullptr; // ternary case Mode::COND_LEQ_MOV: diff --git a/src/opr/impl/loop/forward.cpp b/src/opr/impl/loop/forward.cpp index b52344ec..0c7eee01 100644 --- a/src/opr/impl/loop/forward.cpp +++ b/src/opr/impl/loop/forward.cpp @@ -408,6 +408,8 @@ cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const { break; case DTypeEnum::UintB4: break; + case DTypeEnum::Bool: + break; #define cb(x) case DTypeEnum::x: break; MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb) diff --git a/src/opr/impl/loop/impl.cpp b/src/opr/impl/loop/impl.cpp index 45b36a19..305ca817 100644 --- a/src/opr/impl/loop/impl.cpp +++ b/src/opr/impl/loop/impl.cpp @@ -247,6 +247,8 @@ MGB_DEFINE_OPR_CLASS(LoopImpl::DescImplBase::LoopCondManager::GetCondOpr, break; case DTypeEnum::UintB4: break; + case DTypeEnum::Bool: + break; #define cb(_dt) \ case DTypeEnum::_dt: \ break; diff --git a/src/opr/include/megbrain/opr/basic_arith_wrapper.h b/src/opr/include/megbrain/opr/basic_arith_wrapper.h index 5f24b2ba..1199176f 100644 --- a/src/opr/include/megbrain/opr/basic_arith_wrapper.h +++ b/src/opr/include/megbrain/opr/basic_arith_wrapper.h @@ -32,6 +32,7 @@ namespace opr { EL1(exp, EXP) EL1(log, LOG) EL1(abs, ABS) + EL1(not_, NOT) #undef EL1 @@ -53,6 +54,9 @@ namespace opr { EL2(min, MIN) EL2(switch_gt0, SWITCH_GT0) EL2(eq, EQ) + EL2(and_, AND) + 
EL2(or_, OR) + EL2(xor_, XOR) #undef EL2 diff --git a/src/opr/test/basic_arith/elemwise.cpp b/src/opr/test/basic_arith/elemwise.cpp index 08cf6ad4..09b32566 100644 --- a/src/opr/test/basic_arith/elemwise.cpp +++ b/src/opr/test/basic_arith/elemwise.cpp @@ -206,6 +206,7 @@ namespace { static constexpr Mode MODE = Mode::_mode; \ static constexpr bool ALLOW_INT = _ALLOW_INT; \ static constexpr bool ALLOW_FLOAT = _ALLOW_FLOAT; \ + static constexpr bool ALLOW_BOOL = _ALLOW_BOOL; \ static constexpr const char* NAME = #_mode; \ template \ static inline ctype apply( \ @@ -588,6 +589,14 @@ namespace { struct enable_for_dtype_impl { static constexpr bool value = false; }; + template + struct enable_for_dtype_impl { + static constexpr bool value = Trait::ALLOW_BOOL; + }; + template<> + struct enable_for_dtype_impl { + static constexpr bool value = false; + }; } //! whether to enable test for specific dtype and Trait @@ -749,8 +758,60 @@ TYPED_TEST(TestOprBasicArithTernaryElemwise, Float32) { TEST(TestOprBasicArithElemwise, CheckAllModeTested) { size_t nr_member = opr::Elemwise::Param::MODE_NR_MEMBER; - ASSERT_EQ(nr_member, tested_mode.size()); + ASSERT_EQ(nr_member, tested_mode.size() + 4); + // Not using TestRunner: NOT, AND, OR, XOR } +#define TEST_OPR_BASIC_ARITH_UNARY_BOOL(_mode, _op) \ + TEST(TestOprBasicArithElemwise, _mode) { \ + HostTensorGenerator gen; \ + auto host_x = gen({2, 1}); \ + auto ptr = host_x->ptr(); \ + for (size_t i = 0; i < 2; ++i) { \ + ptr[i] = (i & 1); \ + } \ + auto graph = ComputingGraph::make(); \ + using Mode = opr::Elemwise::Mode; \ + auto x = opr::Host2DeviceCopy::make(*graph, host_x), \ + y = opr::Elemwise::make({x}, Mode::_mode); \ + HostTensorND host_y; \ + auto func = graph->compile({make_callback_copy(y, host_y)}); \ + func->execute(); \ + ASSERT_EQ(TensorShape({2, 1}), host_y.shape()); \ + auto ptry = host_y.ptr(); \ + for (int i = 0;i < 2;i ++) { \ + ASSERT_EQ(_op ptr[i], ptry[i]); \ + } \ + } \ + +TEST_OPR_BASIC_ARITH_UNARY_BOOL(NOT, !) 
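Written out by hand for illustration, the NOT case generated by the macro above amounts to roughly the following (the generator and helpers are the ones this patch touches; the expansion is paraphrased, not copied verbatim):

    HostTensorGenerator<dtype::Bool> gen;  // specialization added in test/src/helper.cpp below
    auto host_x = gen({2, 1});
    auto ptr = host_x->ptr<dt_bool>();
    ptr[0] = false; ptr[1] = true;
    auto graph = ComputingGraph::make();
    using Mode = opr::Elemwise::Mode;
    auto x = opr::Host2DeviceCopy::make(*graph, host_x),
         y = opr::Elemwise::make({x}, Mode::NOT);
    HostTensorND host_y;
    auto func = graph->compile({make_callback_copy(y, host_y)});
    func->execute();
    // host_y now holds {!ptr[0], !ptr[1]}, still with dtype::Bool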
+ +#define TEST_OPR_BASIC_ARITH_BINARY_BOOL(_mode, _op) \ + TEST(TestOprBasicArithElemwise, _mode) { \ + HostTensorGenerator gen; \ + auto host_x1 = gen({2, 2}), host_x2 = gen({2, 2}); \ + auto ptr1 = host_x1->ptr(), ptr2 = host_x2->ptr(); \ + for (size_t i = 0; i < 4; ++i) { \ + ptr1[i] = (i < 2); \ + ptr2[i] = (i & 1); \ + } \ + auto graph = ComputingGraph::make(); \ + using Mode = opr::Elemwise::Mode; \ + auto x1 = opr::Host2DeviceCopy::make(*graph, host_x1), \ + x2 = opr::Host2DeviceCopy::make(*graph, host_x2), \ + y = opr::Elemwise::make({x1, x2}, Mode::_mode); \ + HostTensorND host_y; \ + auto func = graph->compile({make_callback_copy(y, host_y)}); \ + func->execute(); \ + ASSERT_EQ(TensorShape({2, 2}), host_y.shape()); \ + auto ptry = host_y.ptr(); \ + for (int i = 0;i < 4;i ++) { \ + ASSERT_EQ(ptr1[i] _op ptr2[i], ptry[i]); \ + } \ + } \ + +TEST_OPR_BASIC_ARITH_BINARY_BOOL(AND, &&) +TEST_OPR_BASIC_ARITH_BINARY_BOOL(OR, ||) +TEST_OPR_BASIC_ARITH_BINARY_BOOL(XOR, ^) TEST(TestOprBasicArithElemwise, FuseMulAdd3Shapes) { using Checker = AutoOprChecker<3, 1>; diff --git a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl index 7267a4ed..ba32cbf3 100644 --- a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl @@ -19,6 +19,17 @@ ctype x = inp[0][idx]; \ ctype y = inp[1][idx] +#define _ALLOW_BOOL true +#define _ALLOW_FLOAT false +#define _ALLOW_INT false +DEF_TRAIT(AND, x && y) +DEF_TRAIT(OR, x || y) +DEF_TRAIT(XOR, x ^ y) +#undef _ALLOW_INT +#undef _ALLOW_FLOAT +#undef _ALLOW_BOOL + +#define _ALLOW_BOOL false #define _ALLOW_FLOAT true #define _ALLOW_INT true DEF_TRAIT(ABS_GRAD, x > 0 ? y : -y) @@ -60,6 +71,7 @@ DEF_TRAIT(SHR, do_shr(x, y)) DEF_TRAIT(RMULH, do_round_mulh_saturate(x, y)) #undef _ALLOW_INT #undef _ALLOW_FLOAT +#undef _ALLOW_BOOL #undef _CUR_ARITY #undef _EXPAND_PARAMS diff --git a/src/opr/test/basic_arith/elemwise_ternary_trait_def.inl b/src/opr/test/basic_arith/elemwise_ternary_trait_def.inl index d0103d0c..4c4a5fda 100644 --- a/src/opr/test/basic_arith/elemwise_ternary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_ternary_trait_def.inl @@ -20,6 +20,7 @@ ctype y = inp[1][idx]; \ ctype z = inp[2][idx] +#define _ALLOW_BOOL false #define _ALLOW_FLOAT true #define _ALLOW_INT true DEF_TRAIT(COND_LEQ_MOV, x <= y ? 
z : 0) @@ -46,5 +47,6 @@ DEF_TRAIT(FUSE_MUL_ADD4, i0 * i1 + i2 * i3) #undef _CUR_ARITY #undef _EXPAND_PARAMS +#undef _ALLOW_BOOL // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/test/basic_arith/elemwise_unary_trait_def.inl b/src/opr/test/basic_arith/elemwise_unary_trait_def.inl index 6e90beab..e0e9d41b 100644 --- a/src/opr/test/basic_arith/elemwise_unary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_unary_trait_def.inl @@ -18,6 +18,15 @@ #define _EXPAND_PARAMS \ ctype x = inp[0][idx] +#define _ALLOW_BOOL true +#define _ALLOW_FLOAT false +#define _ALLOW_INT false +DEF_TRAIT(NOT, !x) +#undef _ALLOW_INT +#undef _ALLOW_FLOAT +#undef _ALLOW_BOOL + +#define _ALLOW_BOOL false #define _ALLOW_FLOAT true @@ -51,6 +60,8 @@ DEF_TRAIT(H_SWISH, do_h_swish(x)) #undef _ALLOW_FLOAT +#undef _ALLOW_BOOL + #undef _CUR_ARITY #undef _EXPAND_PARAMS diff --git a/src/serialization/impl/dtype.fbs b/src/serialization/impl/dtype.fbs index debc2d98..6fe3ec23 100644 --- a/src/serialization/impl/dtype.fbs +++ b/src/serialization/impl/dtype.fbs @@ -21,6 +21,7 @@ enum DTypeEnum : byte { QuantizedS4, QuantizedS16, BFloat16, + Bool, } table LinearQuantizationParam { diff --git a/test/src/helper.cpp b/test/src/helper.cpp index df4fe649..8a56bd5c 100644 --- a/test/src/helper.cpp +++ b/test/src/helper.cpp @@ -141,6 +141,21 @@ namespace mgb { template class HostTensorGenerator< dtype::Int32, RandomDistribution::CONSTANT>; std::shared_ptr + HostTensorGenerator:: + operator()(const TensorShape& shape, CompNode cn) { + if (!cn.valid()) + cn = CompNode::load("xpu0"); + auto dtype = dtype::Bool(); + std::shared_ptr ret = + std::make_shared(cn, shape, dtype); + auto ptr = ret->ptr(); + for (size_t i = 0, it = shape.total_nr_elems(); i < it; ++i) { + ptr[i] = (i % 2 == 1); + } + return ret; + } + + std::shared_ptr HostTensorGenerator:: operator()(const TensorShape& shape, CompNode cn) { if (!cn.valid()) diff --git a/test/src/include/megbrain/test/helper.h b/test/src/include/megbrain/test/helper.h index 065fade4..01e3940e 100644 --- a/test/src/include/megbrain/test/helper.h +++ b/test/src/include/megbrain/test/helper.h @@ -202,6 +202,10 @@ struct RandomDistributionDTypeDefault { static constexpr auto dist = RandomDistribution::UNIFORM; }; template<> +struct RandomDistributionDTypeDefault { + static constexpr auto dist = RandomDistribution::UNIFORM; +}; +template<> struct RandomDistributionDTypeDefault { static constexpr auto dist = RandomDistribution::UNIFORM; }; @@ -251,6 +255,10 @@ struct UniformRNGDefaultRange { static constexpr dt_uint8 LO = 0, HI = 255; }; template<> +struct UniformRNGDefaultRange { + static constexpr dt_bool LO = false, HI = true; +}; +template<> struct UniformRNGDefaultRange { static constexpr dt_int16 LO = -32767, HI = 32767; }; @@ -341,6 +349,20 @@ class HostTensorGenerator final: private: ctype m_default_val; }; +template <> +class HostTensorGenerator final + : public HostTensorGeneratorBase { +public: + using ctype = typename DTypeTrait::ctype; + + HostTensorGenerator(uint64_t seed = next_rand_seed()) + : HostTensorGeneratorBase{seed} {} + + std::shared_ptr operator()(const TensorShape& shape, + CompNode cn = {}) override; + using HostTensorGeneratorBase::operator(); + +}; template <> class HostTensorGenerator final
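With the EL1/EL2 wrappers added in basic_arith_wrapper.h above, the new modes compose like any other elemwise helper. A minimal usage sketch (illustrative only; assumes two Bool-dtype SymbolVars a and b already built in a graph):

    #include "megbrain/opr/basic_arith_wrapper.h"

    // a && !b and a XOR b, via the wrappers added by this patch:
    auto c = opr::and_(a, opr::not_(b));  // Mode::AND over Mode::NOT
    auto d = opr::xor_(a, b);             // Mode::XOR; on bool this is a != b

Note that none of the four modes defines a gradient: MGB_IMPL_OPR_GRAD(Elemwise) in basic_arith.cpp returns nullptr for NOT/AND/OR/XOR, so no gradient flows through them.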