Browse Source

feat(dnn): add bool dtype

GitOrigin-RevId: 98c8a092b4
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
e258812f12
42 changed files with 424 additions and 32 deletions
  1. +8
    -4
      dnn/include/megdnn/dtype.h
  2. +2
    -1
      dnn/include/megdnn/oprs/general.h
  3. +4
    -0
      dnn/scripts/gen_elemwise_utils.py
  4. +6
    -1
      dnn/scripts/opr_param_defs.py
  5. +1
    -0
      dnn/src/common/cond_take/predicate.cuh
  6. +8
    -0
      dnn/src/common/elemwise/each_mode.inl
  7. +4
    -0
      dnn/src/common/elemwise/kern_defs.cuh
  8. +17
    -1
      dnn/src/common/elemwise/opr_impl.cpp
  9. +18
    -1
      dnn/src/common/elemwise/opr_impl_body.inl
  10. +3
    -0
      dnn/src/common/elemwise/opr_impl_class_def.inl
  11. +4
    -2
      dnn/src/common/type_cvt.cpp
  12. +27
    -0
      dnn/src/cuda/cond_take/kimpl/dt_bool.cu
  13. +18
    -12
      dnn/src/cuda/elemwise/kern_wrapper.cuh
  14. +15
    -0
      dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu
  15. +15
    -0
      dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu
  16. +15
    -0
      dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu
  17. +15
    -0
      dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu
  18. +7
    -0
      dnn/src/cuda/elemwise_helper.cpp
  19. +4
    -0
      dnn/src/cuda/elemwise_helper.cuh
  20. +10
    -4
      dnn/src/cuda/type_cvt/kern.cu
  21. +8
    -3
      dnn/src/fallback/type_cvt/opr_impl.cpp
  22. +15
    -0
      dnn/src/naive/elemwise/kimpl/AND_dt_bool.cpp
  23. +15
    -0
      dnn/src/naive/elemwise/kimpl/NOT_dt_bool.cpp
  24. +15
    -0
      dnn/src/naive/elemwise/kimpl/OR_dt_bool.cpp
  25. +15
    -0
      dnn/src/naive/elemwise/kimpl/XOR_dt_bool.cpp
  26. +2
    -0
      dnn/src/naive/type_cvt/opr_impl.cpp
  27. +2
    -0
      dnn/test/common/elemwise.cpp
  28. +1
    -0
      python_module/src/cpp/megbrain_wrap.cpp
  29. +1
    -0
      src/core/include/megbrain/dtype.h
  30. +2
    -2
      src/jit/impl/ast_c.cpp
  31. +8
    -0
      src/jit/impl/halide/ast_hl.cpp
  32. +6
    -0
      src/opr/impl/basic_arith.cpp
  33. +2
    -0
      src/opr/impl/loop/forward.cpp
  34. +2
    -0
      src/opr/impl/loop/impl.cpp
  35. +4
    -0
      src/opr/include/megbrain/opr/basic_arith_wrapper.h
  36. +62
    -1
      src/opr/test/basic_arith/elemwise.cpp
  37. +12
    -0
      src/opr/test/basic_arith/elemwise_binary_trait_def.inl
  38. +2
    -0
      src/opr/test/basic_arith/elemwise_ternary_trait_def.inl
  39. +11
    -0
      src/opr/test/basic_arith/elemwise_unary_trait_def.inl
  40. +1
    -0
      src/serialization/impl/dtype.fbs
  41. +15
    -0
      test/src/helper.cpp
  42. +22
    -0
      test/src/include/megbrain/test/helper.h

+ 8
- 4
dnn/include/megdnn/dtype.h View File

@@ -52,6 +52,7 @@ namespace megdnn {
MEGDNN_INC_FLOAT16(cb(Float16)) \
MEGDNN_INC_FLOAT16(cb(BFloat16)) \
cb(UintB4) \
cb(Bool) \

/*!
* \brief iterate through each full byte dtype
@@ -65,6 +66,7 @@ namespace megdnn {
cb(Byte) \
MEGDNN_INC_FLOAT16(cb(Float16)) \
MEGDNN_INC_FLOAT16(cb(BFloat16)) \
cb(Bool) \

/*!
* \brief iterate through each fractional byte dtype
@@ -122,7 +124,7 @@ namespace megdnn {
*/
#define MEGDNN_FOREACH_COMPUTING_DTYPE(cb) \
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \
MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb)
MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) \

//! In order to avoid an unnecessary increase in binary size, we just
//! use QuantizedS16 dtype in winograd_filter_preprocess now. So I didn't add
@@ -348,6 +350,7 @@ typedef int32_t dt_int32;
typedef int16_t dt_int16;
typedef int8_t dt_int8;
typedef uint8_t dt_uint8;
typedef bool dt_bool;
MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;)
MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)

@@ -375,7 +378,7 @@ MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)
#if !MEGDNN_DISABLE_FLOAT16
BFloat16 = 11,
#endif
Bool = 12,
#define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE,
#define D(_name) _name,
MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(FST, D)
@@ -392,7 +395,7 @@ MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)
#if MEGDNN_CC_HOST
//! dtype numeric category
enum class DTypeCategory: int {
OTHER, FLOAT, INT, LOWBIT, QUANTIZED
OTHER, FLOAT, INT, LOWBIT, QUANTIZED, BOOL
};
//! dtype signedness
enum class DTypeSignedness: int {
@@ -401,7 +404,7 @@ MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)
#else
struct DTypeCategory {
enum Ev {
OTHER, FLOAT, INT, LOWBIT, QUANTIZED
OTHER, FLOAT, INT, LOWBIT, QUANTIZED, BOOL
};
int ev;
};
@@ -707,6 +710,7 @@ MEGDNN_DEF_DT(Int32, dt_int32, INT, SIGNED, INT32_MIN, INT32_MAX);
MEGDNN_DEF_DT(Int16, dt_int16, INT, SIGNED, INT16_MIN, INT16_MAX);
MEGDNN_DEF_DT(Int8, dt_int8, INT, SIGNED, INT8_MIN, INT8_MAX);
MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX);
MEGDNN_DEF_DT(Bool, dt_bool, BOOL, UNSIGNED, false, true);
MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED,
std::numeric_limits<dt_float16>::lowest(),
std::numeric_limits<dt_float16>::max()));


+ 2
- 1
dnn/include/megdnn/oprs/general.h View File

@@ -39,11 +39,12 @@ class ElemwiseForward: public OperatorBase {
bool commutable; //!< whether arity == 2 and inputs commutable
bool allow_int; //!< whether int inputs allowed
bool allow_float; //!< whether float inputs allowed
bool allow_bool; //!< whether bool inputs allowed
const char* name; //!< name of the mode


ModeTrait():
arity(0), commutable(0), allow_int(0), allow_float(0),
arity(0), commutable(0), allow_int(0), allow_float(0), allow_bool(0),
name(NULL)
{}



+ 4
- 0
dnn/scripts/gen_elemwise_utils.py View File

@@ -5,6 +5,7 @@ DTYPES = {'dt_int32': ('Int32', 'INT'),
'dt_uint8': ('Uint8', 'INT'),
'dt_int8': ('Int8', 'INT'),
'dt_int16': ('Int16', 'INT'),
'dt_bool': ('Bool', 'BOOL'),
'dt_float32': ('Float32', 'FLOAT'),
'dt_float16': ('Float16', 'FLOAT'),
'dt_bfloat16': ('BFloat16', 'FLOAT')
@@ -28,4 +29,7 @@ MODES = {
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD',
'FUSE_ADD_H_SWISH'],
(3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'],
(1, 'BOOL'): ['NOT'],
(2, 'BOOL'): ['AND', 'OR', 'XOR'],
(3, 'BOOL'): []
}

+ 6
- 1
dnn/scripts/opr_param_defs.py View File

@@ -314,7 +314,12 @@ pdef('Elemwise').add_enum(
Doc('ERFCINV', 'unary: inverse function of erfc(x)'),
Doc('H_SWISH', 'unary: x * clip(x + 3, 0, 6) / 6'),
Doc('H_SWISH_GRAD', 'binary: x < -3 ? 0 : (x > 3 ? y : (2 * x + 3) / 6 * y)'),
Doc('FUSE_ADD_H_SWISH', 'binary: hswish(x+y)')
Doc('FUSE_ADD_H_SWISH', 'binary: hswish(x+y)'),

Doc('NOT', 'unary: !x'),
Doc('AND', 'binary: x && y'),
Doc('OR', 'binary: x || y'),
Doc('XOR', 'binary: x ^ y')
)

pdef('ElemwiseMultiType').add_enum(


+ 1
- 0
dnn/src/common/cond_take/predicate.cuh View File

@@ -68,6 +68,7 @@ namespace cond_take {
#define inst_eq_i(_dt) do_inst_eq_i(DTypeTrait<_dt>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(inst_eq_f)
MEGDNN_FOREACH_COMPUTING_DTYPE_INT(inst_eq_i)
inst_eq_i(::megdnn::dtype::Bool)
#undef inst_eq_f
#undef inst_eq_i



+ 8
- 0
dnn/src/common/elemwise/each_mode.inl View File

@@ -9,6 +9,9 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_each_mode.py
#define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(NOT, cb) \

#define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \
@@ -38,6 +41,11 @@
MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) \

#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) \

#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) \


+ 4
- 0
dnn/src/common/elemwise/kern_defs.cuh View File

@@ -139,6 +139,7 @@ namespace megdnn {
DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f));

// int only
DEF_KERN(dt_bool, NOT, x ^ 1);

#undef KERN_SIG

@@ -156,6 +157,9 @@ namespace megdnn {
DEF_KERN_ALL(MAX, x > y ? x : y);
DEF_KERN_ALL(MIN, x < y ? x : y);
DEF_KERN_ALL(MUL, x* y);
DEF_KERN(dt_bool, AND, x && y);
DEF_KERN(dt_bool, OR, x || y);
DEF_KERN(dt_bool, XOR, x ^ y);
DEF_KERN_INT(RMULH, round_mulh_saturate(x, y));
DEF_KERN_ALL(SIGMOID_GRAD, x*(ctype(1) - x) * y);
DEF_KERN_ALL(SUB, x - y);


+ 17
- 1
dnn/src/common/elemwise/opr_impl.cpp View File

@@ -74,6 +74,15 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {

#define cb(_m) \
MIDOUT_BEGIN(megdnn_common_elemwise, midout_iv(Mode::_m)) { \
get(Mode::_m).allow_bool = true; \
} \
MIDOUT_END();
MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb);
MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb);
#undef cb

#define cb(_m) \
MIDOUT_BEGIN(megdnn_common_elemwise, midout_iv(Mode::_m)) { \
auto&& t = get(Mode::_m); \
t.arity = _a; \
t.name = megdnn_mangle(#_m); \
@@ -82,10 +91,12 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
#define _a 1
MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb);
MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb);
MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb);
#undef _a
#define _a 2
MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb);
MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb);
MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb);
#undef _a
#define _a 3
MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb);
@@ -98,6 +109,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
auto&& t = get(Mode::_m); \
t.allow_int = true; \
t.allow_float = true; \
t.allow_bool = true; \
t.arity = _arity; \
t.name = megdnn_mangle(#_m); \
} \
@@ -129,7 +141,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {

#if MEGDNN_ELEMWISE_MODE_ENABLE_ALL
for (auto&& i : traits) {
megdnn_assert(i.arity && (i.allow_int || i.allow_float) &&
megdnn_assert(i.arity && (i.allow_int || i.allow_float || i.allow_bool) &&
(!i.commutable || i.arity == 2));
}
#else
@@ -282,6 +294,10 @@ void ElemwiseForward::check_dtype(DType dtype) {
megdnn_assert(trait.allow_int, "unsupport mode %s for int\n",
trait.name);
break;
case DTypeCategory::BOOL:
megdnn_assert(trait.allow_bool, "unsupport mode %s for bool\n",
trait.name);
break;
default:
megdnn_throw("bad dtype");
}


+ 18
- 1
dnn/src/common/elemwise/opr_impl_body.inl View File

@@ -18,6 +18,15 @@ void ElemwiseForwardImpl::on_arity_dispatched() {
auto src = make_elemwise_op_param<arity>();
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(on_arity_dispatched_cb_dtype)
MEGDNN_FOREACH_COMPUTING_DTYPE_INT(on_arity_dispatched_cb_dtype)
on_arity_dispatched_cb_dtype(::megdnn::dtype::Bool)
megdnn_throw("bad dtype");
}

/*!
 * \brief dtype dispatch for a fixed arity that leaves out the Bool dtype
 *
 * Expands on_arity_dispatched_cb_dtype for every float and int computing
 * dtype only; unlike on_arity_dispatched(), no case is expanded for
 * ::megdnn::dtype::Bool. exec() selects this variant for three inputs.
 *
 * \tparam arity number of input tensors
 */
template<int arity>
void ElemwiseForwardImpl::on_arity_dispatched_no_bool() {
// `src` is not referenced directly here; presumably the expanded callback
// macro consumes it -- TODO confirm against on_arity_dispatched_cb_dtype
auto src = make_elemwise_op_param<arity>();
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(on_arity_dispatched_cb_dtype)
MEGDNN_FOREACH_COMPUTING_DTYPE_INT(on_arity_dispatched_cb_dtype)
// NOTE(review): presumably each expanded case returns once the dtype
// matches, so reaching this line means no float/int dtype matched -- confirm
megdnn_throw("bad dtype");
}

@@ -45,6 +54,14 @@ IMPL_MODE_DISPATCHER(2, DTypeCategory::FLOAT);
IMPL_MODE_DISPATCHER(3, DTypeCategory::FLOAT);
#undef FOREACH

#define FOREACH MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL
IMPL_MODE_DISPATCHER(1, DTypeCategory::BOOL);
#undef FOREACH

#define FOREACH MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL
IMPL_MODE_DISPATCHER(2, DTypeCategory::BOOL);
#undef FOREACH

void ElemwiseForwardImpl::exec(
const TensorNDArray &src,
_megdnn_tensor_out dst) {
@@ -97,8 +114,8 @@ void ElemwiseForwardImpl::exec(
#define D(_n) case _n: return on_arity_dispatched<_n>()
D(1);
D(2);
D(3);
#undef D
case 3: return on_arity_dispatched_no_bool<3>();
default:
megdnn_throw("bad size of input tensors");
}


+ 3
- 0
dnn/src/common/elemwise/opr_impl_class_def.inl View File

@@ -13,6 +13,9 @@
template<int arity>
void on_arity_dispatched();

template<int arity>
void on_arity_dispatched_no_bool();

template<int arity, DTypeCategory dtype_cat, typename ctype>
struct ModeDispatcher;



+ 4
- 2
dnn/src/common/type_cvt.cpp View File

@@ -19,10 +19,12 @@ void TypeCvt::check_exec(const TensorLayout &src, const TensorLayout &dst) {
megdnn_assert_eq_shape(src, dst);
auto cat = src.dtype.category();
megdnn_assert(cat == DTypeCategory::FLOAT || cat == DTypeCategory::INT ||
cat == DTypeCategory::QUANTIZED);
cat == DTypeCategory::QUANTIZED ||
cat == DTypeCategory::BOOL);
cat = dst.dtype.category();
megdnn_assert(cat == DTypeCategory::FLOAT || cat == DTypeCategory::INT ||
cat == DTypeCategory::QUANTIZED);
cat == DTypeCategory::QUANTIZED ||
cat == DTypeCategory::BOOL);
}

} // namespace megdnn


+ 27
- 0
dnn/src/cuda/cond_take/kimpl/dt_bool.cu View File

@@ -0,0 +1,27 @@
/**
* \file dnn/src/cuda/cond_take/kimpl/dt_bool.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_cond_take_kern_impls.py
#include "../kern.inl"

namespace megdnn {
namespace cuda {
namespace cond_take {

inst_genidx(::megdnn::dtype::Bool)
#undef inst_genidx

inst_copy(::megdnn::dtype::Bool)
#undef inst_copy
#undef inst_copy_

} // cond_take
} // cuda
} // megdnn

+ 18
- 12
dnn/src/cuda/elemwise/kern_wrapper.cuh View File

@@ -25,8 +25,9 @@ namespace cuda {
1, KernImpl,
typename std::enable_if<
!std::is_same<typename KernImpl::ctype, dt_int8>::value &&
!std::is_same<typename KernImpl::ctype,
dt_uint8>::value>::type> {
!std::is_same<typename KernImpl::ctype, dt_uint8>::value &&
!std::is_same<typename KernImpl::ctype,
dt_bool>::value>::type> {
typedef typename KernImpl::ctype ctype;
ctype* dst;

@@ -41,8 +42,9 @@ namespace cuda {
2, KernImpl,
typename std::enable_if<
!std::is_same<typename KernImpl::ctype, dt_int8>::value &&
!std::is_same<typename KernImpl::ctype,
dt_uint8>::value>::type> {
!std::is_same<typename KernImpl::ctype, dt_uint8>::value &&
!std::is_same<typename KernImpl::ctype,
dt_bool>::value>::type> {
typedef typename KernImpl::ctype ctype;
ctype* dst;

@@ -57,8 +59,9 @@ namespace cuda {
3, KernImpl,
typename std::enable_if<
!std::is_same<typename KernImpl::ctype, dt_int8>::value &&
!std::is_same<typename KernImpl::ctype,
dt_uint8>::value>::type> {
!std::is_same<typename KernImpl::ctype, dt_uint8>::value &&
!std::is_same<typename KernImpl::ctype,
dt_bool>::value>::type> {
typedef typename KernImpl::ctype ctype;
ctype* dst;

@@ -74,8 +77,9 @@ namespace cuda {
1, KernImpl,
typename std::enable_if<
std::is_same<typename KernImpl::ctype, dt_int8>::value ||
std::is_same<typename KernImpl::ctype,
dt_uint8>::value>::type> {
std::is_same<typename KernImpl::ctype, dt_uint8>::value ||
std::is_same<typename KernImpl::ctype,
dt_bool>::value>::type> {
typedef typename KernImpl::ctype ctype;
using VectTypeTrait = elemwise_intl::VectTypeTrait<ctype>;
typedef typename VectTypeTrait::vect_type vect_type;
@@ -99,8 +103,9 @@ namespace cuda {
2, KernImpl,
typename std::enable_if<
std::is_same<typename KernImpl::ctype, dt_int8>::value ||
std::is_same<typename KernImpl::ctype,
dt_uint8>::value>::type> {
std::is_same<typename KernImpl::ctype, dt_uint8>::value ||
std::is_same<typename KernImpl::ctype,
dt_bool>::value>::type> {
typedef typename KernImpl::ctype ctype;
using VectTypeTrait = elemwise_intl::VectTypeTrait<ctype>;
typedef typename VectTypeTrait::vect_type vect_type;
@@ -126,8 +131,9 @@ namespace cuda {
3, KernImpl,
typename std::enable_if<
std::is_same<typename KernImpl::ctype, dt_int8>::value ||
std::is_same<typename KernImpl::ctype,
dt_uint8>::value>::type> {
std::is_same<typename KernImpl::ctype, dt_uint8>::value ||
std::is_same<typename KernImpl::ctype,
dt_bool>::value>::type> {
typedef typename KernImpl::ctype ctype;
using VectTypeTrait = elemwise_intl::VectTypeTrait<ctype>;
typedef typename VectTypeTrait::vect_type vect_type;


+ 15
- 0
dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu View File

@@ -0,0 +1,15 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"

+ 15
- 0
dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu View File

@@ -0,0 +1,15 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NOT, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"

+ 15
- 0
dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu View File

@@ -0,0 +1,15 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"

+ 15
- 0
dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu View File

@@ -0,0 +1,15 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"

+ 7
- 0
dnn/src/cuda/elemwise_helper.cpp View File

@@ -169,6 +169,9 @@ INST_FOR_CTYPE
#define ct dt_qint32
INST_FOR_CTYPE
#undef ct
#define ct dt_bool
INST_FOR_CTYPE
#undef ct

#undef INST_FOR_CTYPE
#undef INST
@@ -216,6 +219,9 @@ INST_FOR_CTYPE
#define ct dt_qint32
INST_FOR_CTYPE
#undef ct
#define ct dt_bool
INST_FOR_CTYPE
#undef ct

#undef ndim_cb

@@ -225,6 +231,7 @@ INST_FOR_CTYPE
#define INST(dt_ibyte) template class ParamVectVisitor<4, dt_ibyte, BCAST_1010>
INST(dt_int8);
INST(dt_uint8);
INST(dt_bool);
INST(dt_qint8);
INST(dt_quint8);
#undef dt_ibyte


+ 4
- 0
dnn/src/cuda/elemwise_helper.cuh View File

@@ -102,6 +102,7 @@ INST(dt_float16, half4);
INST(dt_bfloat16, bhalf4);
INST(dt_int32, int4);
INST(dt_int16, short4);
INST(dt_bool, uchar4);
#undef as_raw
#define as_raw(x) x.as_int8()
INST(dt_qint8, char4);
@@ -454,6 +455,7 @@ INST_DT_IBYTE(dt_int8);
INST_DT_IBYTE(dt_uint8);
INST_DT_IBYTE(dt_qint8);
INST_DT_IBYTE(dt_quint8);
INST_DT_IBYTE(dt_bool);
#undef INST_DT_IBYTE
#undef DEVICE_WRAPPER
#undef INST_PARAM_VECT_VISITOR
@@ -913,6 +915,7 @@ INST_DT_IBYTE(dt_int8);
INST_DT_IBYTE(dt_uint8);
INST_DT_IBYTE(dt_qint8);
INST_DT_IBYTE(dt_quint8);
INST_DT_IBYTE(dt_bool);
#undef INST_DT_IBYTE

//! implement general case by UserOpInvokerToSameNdim
@@ -1259,6 +1262,7 @@ INST_DT_IBYTE(dt_int8);
INST_DT_IBYTE(dt_uint8);
INST_DT_IBYTE(dt_qint8);
INST_DT_IBYTE(dt_quint8);
INST_DT_IBYTE(dt_bool);
#undef INST_DT_IBYTE
#endif



+ 10
- 4
dnn/src/cuda/type_cvt/kern.cu View File

@@ -62,7 +62,8 @@ template <typename ctype_dest, typename ctype_src>
struct TypeCvtOp<ctype_dest, ctype_src,
typename std::enable_if<
std::is_same<ctype_src, dt_int8>::value ||
std::is_same<ctype_src, dt_uint8>::value>::type> {
std::is_same<ctype_src, dt_uint8>::value ||
std::is_same<ctype_src, dt_bool>::value>::type> {
ctype_dest* dest;
using src_vect_type = typename VectTypeTrait<ctype_src>::vect_type;
using dst_vect_type = typename VectTypeTrait<ctype_dest>::vect_type;
@@ -85,7 +86,8 @@ struct TypeCvtOpToQuantized<
ctype_dest, ctype_src,
typename std::enable_if<
std::is_same<ctype_src, dt_int8>::value ||
std::is_same<ctype_src, dt_uint8>::value>::type> {
std::is_same<ctype_src, dt_uint8>::value ||
std::is_same<ctype_src, dt_bool>::value>::type> {
ctype_dest* dest;
CudaDTypeParam<ctype_dest> param;
using src_vect_type = typename VectTypeTrait<ctype_src>::vect_type;
@@ -109,7 +111,8 @@ struct TypeCvtOpFromQuantized<
ctype_dest, ctype_src,
typename std::enable_if<
std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_quint8>::value>::type> {
std::is_same<ctype_src, dt_quint8>::value ||
std::is_same<ctype_src, dt_bool>::value>::type> {
ctype_dest* dest;
CudaDTypeParam<ctype_src> param;
using src_vect_type = typename VectTypeTrait<ctype_src>::vect_type;
@@ -137,7 +140,8 @@ struct TypeCvtOpBetweenQuantized<
ctype_dest, ctype_src,
typename std::enable_if<
std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_quint8>::value>::type> {
std::is_same<ctype_src, dt_quint8>::value ||
std::is_same<ctype_src, dt_bool>::value>::type> {
ctype_dest* dest;
CudaDTypeParam<ctype_src> src_param;
CudaDTypeParam<ctype_dest> dst_param;
@@ -243,6 +247,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src,
cb(dtype_src, dt_float32) \
cb(dtype_src, dt_float16) \
cb(dtype_src, dt_bfloat16) \
cb(dtype_src, dt_bool) \

#define MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \
cb(dtype_src, dt_quint8) \
@@ -265,6 +270,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src,
cb(dt_float32) \
cb(dt_float16) \
cb(dt_bfloat16) \
cb(dt_bool) \

#define MEGDNN_FOREACH_QUANTIZED_CTYPE(cb) \
cb(dt_quint8) \


+ 8
- 3
dnn/src/fallback/type_cvt/opr_impl.cpp View File

@@ -138,7 +138,8 @@ void do_cvt_s8_normal(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
dctype* __restrict dptr = dst.ptr<dctype>();
float scale = src.layout.dtype.param<dtype::QuantizedS8>().scale;
for (size_t i = 0; i < n; ++i) {
dptr[i] = static_cast<dctype>(sptr[i] * scale);
auto val = sptr[i] * scale;
dptr[i] = static_cast<dctype>(val);
}
}

@@ -150,7 +151,8 @@ void do_cvt_s32_normal(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
dctype* __restrict dptr = dst.ptr<dctype>();
float scale = src.layout.dtype.param<dtype::QuantizedS32>().scale;
for (size_t i = 0; i < n; ++i) {
dptr[i] = static_cast<dctype>(sptr[i] * scale);
auto val = sptr[i] * scale;
dptr[i] = static_cast<dctype>(val);
}
}

@@ -163,7 +165,8 @@ void do_cvt_asymm8_normal(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
float scale = src.layout.dtype.param<dtype::Quantized8Asymm>().scale;
uint8_t zp = src.layout.dtype.param<dtype::Quantized8Asymm>().zero_point;
for (size_t i = 0; i < n; ++i) {
dptr[i] = static_cast<dctype>((sptr[i] - zp) * scale);
auto val = (sptr[i] - zp) * scale;
dptr[i] = static_cast<dctype>(val);
}
}

@@ -310,6 +313,7 @@ void on_dest_ctype(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
break; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
case DTypeEnum::QuantizedS8:
MIDOUT_BEGIN(megdnn_fb_typecvt_src_dtype,
midout_iv(DTypeEnum::QuantizedS8)) {
@@ -467,6 +471,7 @@ void run_contiguous(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
}

MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
case DTypeEnum::QuantizedS8:
MIDOUT_BEGIN(megdnn_fb_typecvt_dst_dtype,
midout_iv(DTypeEnum::QuantizedS8)) {


+ 15
- 0
dnn/src/naive/elemwise/kimpl/AND_dt_bool.cpp View File

@@ -0,0 +1,15 @@
/**
* \file dnn/src/naive/elemwise/kimpl/AND_dt_bool.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"

+ 15
- 0
dnn/src/naive/elemwise/kimpl/NOT_dt_bool.cpp View File

@@ -0,0 +1,15 @@
/**
* \file dnn/src/naive/elemwise/kimpl/NOT_dt_bool.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NOT, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"

+ 15
- 0
dnn/src/naive/elemwise/kimpl/OR_dt_bool.cpp View File

@@ -0,0 +1,15 @@
/**
* \file dnn/src/naive/elemwise/kimpl/OR_dt_bool.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"

+ 15
- 0
dnn/src/naive/elemwise/kimpl/XOR_dt_bool.cpp View File

@@ -0,0 +1,15 @@
/**
* \file dnn/src/naive/elemwise/kimpl/XOR_dt_bool.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"

+ 2
- 0
dnn/src/naive/type_cvt/opr_impl.cpp View File

@@ -82,6 +82,7 @@ void on_dest_ctype(HandleImpl* handle, const TensorND& dest,
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
default:
megdnn_throw("bad dtype");
@@ -103,6 +104,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
default:
megdnn_throw("bad dtype");


+ 2
- 0
dnn/test/common/elemwise.cpp View File

@@ -942,6 +942,8 @@ TEST(TEST_ELEMWISE, MODE_TRAIT) {

ASSERT_TRUE(T::from_mode(M::RMULH).commutable);
ASSERT_FALSE(T::from_mode(M::RMULH).allow_float);

ASSERT_TRUE(T::from_mode(M::XOR).allow_bool);
}

} // namespace elemwise


+ 1
- 0
python_module/src/cpp/megbrain_wrap.cpp View File

@@ -916,6 +916,7 @@ SymbolVar fill_retain_dtype(SymbolVar var, PyObject *value) {
case DTypeEnum::QuantizedS4:
case DTypeEnum::Byte:
case DTypeEnum::QuantizedS16:
case DTypeEnum::Bool:
break;
#define cb(low_bit, size) \
case DTypeEnum::low_bit##size: \


+ 1
- 0
src/core/include/megbrain/dtype.h View File

@@ -27,6 +27,7 @@ using ::megdnn::dt_int32;
using ::megdnn::dt_quint8;
using ::megdnn::dt_qint8;
using ::megdnn::dt_qint32;
using ::megdnn::dt_bool;
using ::megdnn::DType;
using ::megdnn::DTypeEnum;
using ::megdnn::DTypeTrait;


+ 2
- 2
src/jit/impl/ast_c.cpp View File

@@ -145,9 +145,9 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() {
0.f}) /
6.f),
};
mgb_assert(map.size() + 8 == opr::Elemwise::Param::MODE_NR_MEMBER);
mgb_assert(map.size() + 12 == opr::Elemwise::Param::MODE_NR_MEMBER);
// unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH,
// ERFINV, ERFCINV
// ERFINV, ERFCINV, NOT, AND, OR, XOR
return map;
#undef ADD_OPR
}


+ 8
- 0
src/jit/impl/halide/ast_hl.cpp View File

@@ -193,6 +193,14 @@ Halide::Expr dispatch_elemwise_mode(
return Halide::round(inp(0));
case Mode::RMULH:
return (inp(0) * inp(1)) >> Halide::popcount(inp(0));
case Mode::NOT:
return cv(1) - cv(inp(0) != cv(0));
case Mode::AND:
return cv(inp(0) != cv(0)) * cv(inp(1) != cv(0));
case Mode::OR:
return cv(cv(inp(0) != cv(0)) + cv(inp(1) != cv(0)) > cv(0));
case Mode::XOR:
return cv(cv(inp(0) != cv(0)) + cv(inp(1) != cv(0)) == cv(1));
default:
mgb_throw(InternalError, "unsupported Elemwise mode(%d)",
static_cast<int>(mode));


+ 6
- 0
src/opr/impl/basic_arith.cpp View File

@@ -631,6 +631,8 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
RET(EL2(H_SWISH_GRAD, i0, og));
case Mode::FUSE_ADD_H_SWISH:
RET(EL2(H_SWISH_GRAD, (i0 + i1), og));
case Mode::NOT:
return nullptr;

// binary
case Mode::ABS_GRAD:
@@ -693,6 +695,10 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
return nullptr;
case Mode::EQ:
RET_INVALID();
case Mode::OR:
case Mode::XOR:
case Mode::AND:
return nullptr;

// ternary
case Mode::COND_LEQ_MOV:


+ 2
- 0
src/opr/impl/loop/forward.cpp View File

@@ -408,6 +408,8 @@ cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const {
break;
case DTypeEnum::UintB4:
break;
case DTypeEnum::Bool:
break;

#define cb(x) case DTypeEnum::x: break;
MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb)


+ 2
- 0
src/opr/impl/loop/impl.cpp View File

@@ -247,6 +247,8 @@ MGB_DEFINE_OPR_CLASS(LoopImpl::DescImplBase::LoopCondManager::GetCondOpr,
break;
case DTypeEnum::UintB4:
break;
case DTypeEnum::Bool:
break;
#define cb(_dt) \
case DTypeEnum::_dt: \
break;


+ 4
- 0
src/opr/include/megbrain/opr/basic_arith_wrapper.h View File

@@ -32,6 +32,7 @@ namespace opr {
EL1(exp, EXP)
EL1(log, LOG)
EL1(abs, ABS)
EL1(not_, NOT)

#undef EL1

@@ -53,6 +54,9 @@ namespace opr {
EL2(min, MIN)
EL2(switch_gt0, SWITCH_GT0)
EL2(eq, EQ)
EL2(and_, AND)
EL2(or_, OR)
EL2(xor_, XOR)

#undef EL2



+ 62
- 1
src/opr/test/basic_arith/elemwise.cpp View File

@@ -206,6 +206,7 @@ namespace {
static constexpr Mode MODE = Mode::_mode; \
static constexpr bool ALLOW_INT = _ALLOW_INT; \
static constexpr bool ALLOW_FLOAT = _ALLOW_FLOAT; \
static constexpr bool ALLOW_BOOL = _ALLOW_BOOL; \
static constexpr const char* NAME = #_mode; \
template<typename ctype> \
static inline ctype apply( \
@@ -588,6 +589,14 @@ namespace {
struct enable_for_dtype_impl<dtype::Int32, void> {
static constexpr bool value = false;
};
//! Bool dtype: whether the test is enabled is taken from the trait's
//! ALLOW_BOOL flag
template<class Trait>
struct enable_for_dtype_impl<dtype::Bool, Trait> {
static constexpr bool value = Trait::ALLOW_BOOL;
};
//! Bool dtype with a void trait: always disabled
template<>
struct enable_for_dtype_impl<dtype::Bool, void> {
static constexpr bool value = false;
};
}

//! whether to enable test for specific dtype and Trait
@@ -749,8 +758,60 @@ TYPED_TEST(TestOprBasicArithTernaryElemwise, Float32) {

TEST(TestOprBasicArithElemwise, CheckAllModeTested) {
size_t nr_member = opr::Elemwise::Param::MODE_NR_MEMBER;
ASSERT_EQ(nr_member, tested_mode.size());
ASSERT_EQ(nr_member, tested_mode.size() + 4);
// Not using TestRunner: NOT, AND, OR, XOR
}
// Defines a test case named _mode in the TestOprBasicArithElemwise suite for a
// unary bool Elemwise mode.
//   _mode: member of opr::Elemwise::Mode to run (also used as the test name)
//   _op:   C++ prefix operator that yields the expected result for one element
// The input is a Bool tensor of shape {2, 1} overwritten with {false, true}
// (i & 1), so both possible input values are covered. The graph output is
// copied back to the host; its shape and every element are then checked
// against `_op input`.
// Comments are kept outside the macro body because a `//` comment on a
// backslash-continued line would absorb the following lines.
#define TEST_OPR_BASIC_ARITH_UNARY_BOOL(_mode, _op) \
TEST(TestOprBasicArithElemwise, _mode) { \
HostTensorGenerator<dtype::Bool> gen; \
auto host_x = gen({2, 1}); \
auto ptr = host_x->ptr<dt_bool>(); \
for (size_t i = 0; i < 2; ++i) { \
ptr[i] = (i & 1); \
} \
auto graph = ComputingGraph::make(); \
using Mode = opr::Elemwise::Mode; \
auto x = opr::Host2DeviceCopy::make(*graph, host_x), \
y = opr::Elemwise::make({x}, Mode::_mode); \
HostTensorND host_y; \
auto func = graph->compile({make_callback_copy(y, host_y)}); \
func->execute(); \
ASSERT_EQ(TensorShape({2, 1}), host_y.shape()); \
auto ptry = host_y.ptr<dt_bool>(); \
for (int i = 0;i < 2;i ++) { \
ASSERT_EQ(_op ptr[i], ptry[i]); \
} \
} \

// Instantiate the unary bool test for mode NOT; the reference value is the
// C++ logical negation of each input element.
TEST_OPR_BASIC_ARITH_UNARY_BOOL(NOT, !)

//! Define a test case named \p _mode that runs the binary elemwise mode
//! \p _mode on two (2, 2) bool tensors and compares each output element
//! against `lhs _op rhs`. The inputs are filled so that the four elements
//! cover (true, false), (true, true), (false, false) and (false, true).
#define TEST_OPR_BASIC_ARITH_BINARY_BOOL(_mode, _op)                         \
    TEST(TestOprBasicArithElemwise, _mode) {                                 \
        using Mode = opr::Elemwise::Mode;                                    \
        HostTensorGenerator<dtype::Bool> gen;                                \
        auto host_lhs = gen({2, 2});                                         \
        auto host_rhs = gen({2, 2});                                         \
        auto lhs_ptr = host_lhs->ptr<dt_bool>();                             \
        auto rhs_ptr = host_rhs->ptr<dt_bool>();                             \
        for (size_t pos = 0; pos < 4; ++pos) {                               \
            lhs_ptr[pos] = (pos < 2);                                        \
            rhs_ptr[pos] = (pos % 2 != 0);                                   \
        }                                                                    \
        auto graph = ComputingGraph::make();                                 \
        auto lhs = opr::Host2DeviceCopy::make(*graph, host_lhs);             \
        auto rhs = opr::Host2DeviceCopy::make(*graph, host_rhs);             \
        auto out = opr::Elemwise::make({lhs, rhs}, Mode::_mode);             \
        HostTensorND host_out;                                               \
        auto func = graph->compile({make_callback_copy(out, host_out)});     \
        func->execute();                                                     \
        ASSERT_EQ(TensorShape({2, 2}), host_out.shape());                    \
        auto out_ptr = host_out.ptr<dt_bool>();                              \
        for (size_t pos = 0; pos < 4; ++pos) {                               \
            ASSERT_EQ(lhs_ptr[pos] _op rhs_ptr[pos], out_ptr[pos]);          \
        }                                                                    \
    }

// Instantiate the binary bool test for AND, OR and XOR; the reference values
// are computed with the C++ operators &&, || and ^ respectively.
TEST_OPR_BASIC_ARITH_BINARY_BOOL(AND, &&)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(OR, ||)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(XOR, ^)

TEST(TestOprBasicArithElemwise, FuseMulAdd3Shapes) {
using Checker = AutoOprChecker<3, 1>;


+ 12
- 0
src/opr/test/basic_arith/elemwise_binary_trait_def.inl View File

@@ -19,6 +19,17 @@
ctype x = inp[0][idx]; \
ctype y = inp[1][idx]

// Logical binary modes: the flags below enable these traits for Bool and
// disable them for float and int dtypes.
#define _ALLOW_BOOL true
#define _ALLOW_FLOAT false
#define _ALLOW_INT false
DEF_TRAIT(AND, x && y)
DEF_TRAIT(OR, x || y)
DEF_TRAIT(XOR, x ^ y)
// Drop the flags so that the next trait group can define its own values.
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
#undef _ALLOW_BOOL

#define _ALLOW_BOOL false
#define _ALLOW_FLOAT true
#define _ALLOW_INT true
DEF_TRAIT(ABS_GRAD, x > 0 ? y : -y)
@@ -60,6 +71,7 @@ DEF_TRAIT(SHR, do_shr(x, y))
DEF_TRAIT(RMULH, do_round_mulh_saturate(x, y))
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
#undef _ALLOW_BOOL

#undef _CUR_ARITY
#undef _EXPAND_PARAMS


+ 2
- 0
src/opr/test/basic_arith/elemwise_ternary_trait_def.inl View File

@@ -20,6 +20,7 @@
ctype y = inp[1][idx]; \
ctype z = inp[2][idx]

#define _ALLOW_BOOL false
#define _ALLOW_FLOAT true
#define _ALLOW_INT true
DEF_TRAIT(COND_LEQ_MOV, x <= y ? z : 0)
@@ -46,5 +47,6 @@ DEF_TRAIT(FUSE_MUL_ADD4, i0 * i1 + i2 * i3)

#undef _CUR_ARITY
#undef _EXPAND_PARAMS
#undef _ALLOW_BOOL

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 11
- 0
src/opr/test/basic_arith/elemwise_unary_trait_def.inl View File

@@ -18,6 +18,15 @@
#define _EXPAND_PARAMS \
ctype x = inp[0][idx]

// Logical unary mode: the flags below enable NOT for Bool and disable it for
// float and int dtypes.
#define _ALLOW_BOOL true
#define _ALLOW_FLOAT false
#define _ALLOW_INT false
DEF_TRAIT(NOT, !x)
// Drop the flags so that the next trait group can define its own values.
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
#undef _ALLOW_BOOL

#define _ALLOW_BOOL false

#define _ALLOW_FLOAT true

@@ -51,6 +60,8 @@ DEF_TRAIT(H_SWISH, do_h_swish(x))

#undef _ALLOW_FLOAT

#undef _ALLOW_BOOL

#undef _CUR_ARITY
#undef _EXPAND_PARAMS



+ 1
- 0
src/serialization/impl/dtype.fbs View File

@@ -21,6 +21,7 @@ enum DTypeEnum : byte {
QuantizedS4,
QuantizedS16,
BFloat16,
Bool,
}

table LinearQuantizationParam {


+ 15
- 0
test/src/helper.cpp View File

@@ -141,6 +141,21 @@ namespace mgb {
template class HostTensorGenerator<
dtype::Int32, RandomDistribution::CONSTANT>;
/*!
 * \brief create a host tensor of dtype Bool with the given shape
 *
 * Elements are written in flat index order as an alternating pattern: even
 * indices are false, odd indices are true. No random source is consulted in
 * this function, so the content depends only on the number of elements.
 *
 * \param shape shape of the tensor to create
 * \param cn comp node for the tensor; "xpu0" is loaded if it is not valid
 * \return the newly allocated and filled host tensor
 */
std::shared_ptr<HostTensorND>
HostTensorGenerator<dtype::Bool, RandomDistribution::UNIFORM>::
operator()(const TensorShape& shape, CompNode cn) {
    if (!cn.valid()) {
        cn = CompNode::load("xpu0");
    }
    auto bool_dtype = dtype::Bool();
    std::shared_ptr<HostTensorND> tensor =
            std::make_shared<HostTensorND>(cn, shape, bool_dtype);
    auto data = tensor->ptr<dt_bool>();
    const size_t nr_elems = shape.total_nr_elems();
    for (size_t pos = 0; pos < nr_elems; ++pos) {
        data[pos] = (pos % 2 != 0);
    }
    return tensor;
}

std::shared_ptr<HostTensorND>
HostTensorGenerator<dtype::QuantizedS8, RandomDistribution::UNIFORM>::
operator()(const TensorShape& shape, CompNode cn) {
if (!cn.valid())


+ 22
- 0
test/src/include/megbrain/test/helper.h View File

@@ -202,6 +202,10 @@ struct RandomDistributionDTypeDefault<dtype::Int32> {
static constexpr auto dist = RandomDistribution::UNIFORM;
};
//! Default distribution selected for dtype::Bool: UNIFORM.
template<>
struct RandomDistributionDTypeDefault<dtype::Bool> {
static constexpr auto dist = RandomDistribution::UNIFORM;
};
template<>
struct RandomDistributionDTypeDefault<dtype::QuantizedS8> {
static constexpr auto dist = RandomDistribution::UNIFORM;
};
@@ -251,6 +255,10 @@ struct UniformRNGDefaultRange<dtype::Uint8> {
static constexpr dt_uint8 LO = 0, HI = 255;
};
//! Default uniform range for dtype::Bool: LO is false and HI is true.
template<>
struct UniformRNGDefaultRange<dtype::Bool> {
static constexpr dt_bool LO = false, HI = true;
};
template<>
struct UniformRNGDefaultRange<dtype::Int16> {
static constexpr dt_int16 LO = -32767, HI = 32767;
};
@@ -341,6 +349,20 @@ class HostTensorGenerator<dtype, RandomDistribution::CONSTANT> final:
private:
ctype m_default_val;
};
//! Explicit specialization of HostTensorGenerator for dtype::Bool with the
//! UNIFORM distribution. The generating operator() is only declared here.
template <>
class HostTensorGenerator<dtype::Bool, RandomDistribution::UNIFORM> final
: public HostTensorGeneratorBase {
public:
//! C++ storage type associated with dtype::Bool.
using ctype = typename DTypeTrait<dtype::Bool>::ctype;

//! Forward the seed (by default taken from next_rand_seed()) to the base.
HostTensorGenerator(uint64_t seed = next_rand_seed())
: HostTensorGeneratorBase{seed} {}

//! Produce a Bool host tensor of the given shape on comp node \p cn.
std::shared_ptr<HostTensorND> operator()(const TensorShape& shape,
CompNode cn = {}) override;
// Keep the base-class operator() overloads visible; declaring operator()
// in this class would otherwise hide them.
using HostTensorGeneratorBase::operator();

};

template <>
class HostTensorGenerator<dtype::QuantizedS8, RandomDistribution::UNIFORM> final


Loading…
Cancel
Save