Browse Source

feat(mgb/dnn): add modes that the output type is bool in elemwise

GitOrigin-RevId: fd0134fca2
HuaHua404-patch-4
Megvii Engine Team 3 years ago
parent
commit
247e2f59a4
59 changed files with 856 additions and 25 deletions
  1. +12
    -0
      dnn/scripts/gen_elemwise_multi_type_utils.py
  2. +9
    -0
      dnn/scripts/opr_param_defs.py
  3. +52
    -0
      dnn/src/common/elemwise/kern_defs.cuh
  4. +0
    -1
      dnn/src/common/elemwise_multi_type/kern_defs.cuh
  5. +33
    -1
      dnn/src/common/elemwise_multi_type/opr_impl.cpp
  6. +13
    -0
      dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp
  7. +18
    -0
      dnn/src/common/elemwise_multi_type/opr_impl_helper.h
  8. +27
    -0
      dnn/src/cuda/elemwise_multi_type/kern_impl_bool.inl
  9. +77
    -5
      dnn/src/cuda/elemwise_multi_type/kern_ops.cuh
  10. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_bfloat16_dt_bool.cu
  11. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_bool_dt_bool.cu
  12. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_float16_dt_bool.cu
  13. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_float32_dt_bool.cu
  14. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int16_dt_bool.cu
  15. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int32_dt_bool.cu
  16. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int8_dt_bool.cu
  17. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_uint8_dt_bool.cu
  18. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_bfloat16_dt_bool.cu
  19. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_float16_dt_bool.cu
  20. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_float32_dt_bool.cu
  21. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_bfloat16_dt_bool.cu
  22. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_float16_dt_bool.cu
  23. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_float32_dt_bool.cu
  24. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_bfloat16_dt_bool.cu
  25. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_bool_dt_bool.cu
  26. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_float16_dt_bool.cu
  27. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_float32_dt_bool.cu
  28. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int16_dt_bool.cu
  29. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int32_dt_bool.cu
  30. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int8_dt_bool.cu
  31. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_uint8_dt_bool.cu
  32. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_bfloat16_dt_bool.cu
  33. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_bool_dt_bool.cu
  34. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_float16_dt_bool.cu
  35. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_float32_dt_bool.cu
  36. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int16_dt_bool.cu
  37. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int32_dt_bool.cu
  38. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int8_dt_bool.cu
  39. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_uint8_dt_bool.cu
  40. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_bfloat16_dt_bool.cu
  41. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_bool_dt_bool.cu
  42. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_float16_dt_bool.cu
  43. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_float32_dt_bool.cu
  44. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int16_dt_bool.cu
  45. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int32_dt_bool.cu
  46. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int8_dt_bool.cu
  47. +6
    -0
      dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_uint8_dt_bool.cu
  48. +110
    -0
      dnn/src/cuda/elemwise_multi_type/opr_impl.cpp
  49. +8
    -0
      dnn/src/cuda/elemwise_multi_type/opr_impl.h
  50. +50
    -1
      dnn/src/naive/elemwise_multi_type/opr_impl.h
  51. +117
    -0
      dnn/src/naive/elemwise_multi_type/opr_impl_4.cpp
  52. +3
    -1
      dnn/test/common/checker.cpp
  53. +8
    -0
      dnn/test/common/elemwise.cpp
  54. +3
    -0
      dnn/test/common/rng.cpp
  55. +49
    -1
      dnn/test/cuda/elemwise_multi_type.cpp
  56. +32
    -8
      imperative/python/megengine/core/tensor/array_method.py
  57. +3
    -3
      imperative/python/megengine/functional/math.py
  58. +2
    -2
      src/jit/impl/ast_c.cpp
  59. +2
    -2
      src/opr/test/basic_arith/elemwise.cpp

+ 12
- 0
dnn/scripts/gen_elemwise_multi_type_utils.py View File

@@ -6,6 +6,10 @@ SUPPORT_QINT32_DTYPES = [('dt_qint32', 'dt_qint8'), ('dt_qint8', 'dt_qint32'),
SUPPORT_DTYPES_Q4 = [('dt_qint4', 'dt_qint4'), ('dt_quint4', 'dt_quint4')] SUPPORT_DTYPES_Q4 = [('dt_qint4', 'dt_qint4'), ('dt_quint4', 'dt_quint4')]
SUPPORT_QINT32_DTYPES_Q4 = [('dt_qint32', 'dt_qint4'), ('dt_qint32', 'dt_quint4')] SUPPORT_QINT32_DTYPES_Q4 = [('dt_qint32', 'dt_qint4'), ('dt_qint32', 'dt_quint4')]


SUPPORT_ARRITY2_DTYPES = ['dt_int32', 'dt_uint8', 'dt_int8', 'dt_int16', 'dt_bool', 'dt_float32',
'dt_float16', 'dt_bfloat16']
SUPPORT_ARRITY1_DTYPES = ['dt_float32','dt_float16', 'dt_bfloat16']

MODES = { MODES = {
1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', 1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS',
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN',
@@ -34,3 +38,11 @@ QINT32_MODES = {
2: ['ADD', 'FUSE_ADD_RELU', 'FUSE_ADD_SIGMOID', 2: ['ADD', 'FUSE_ADD_RELU', 'FUSE_ADD_SIGMOID',
'FUSE_ADD_TANH', 'FUSE_ADD_H_SWISH'] 'FUSE_ADD_TANH', 'FUSE_ADD_H_SWISH']
} }

ARRITY1_BOOL_MODES = {
1: ['ISINF','ISNAN'],
}

ARRITY2_BOOL_MODES = {
2: ['EQ','LEQ','NEQ','LT'],
}

+ 9
- 0
dnn/scripts/opr_param_defs.py View File

@@ -421,6 +421,9 @@ pdef('Elemwise').add_enum(
Doc('GELU = 58', 'unary: x Phi(x)'), Doc('GELU = 58', 'unary: x Phi(x)'),
Doc('GELU_GRAD = 59', 'binary: grad(x Phi(x))'), Doc('GELU_GRAD = 59', 'binary: grad(x Phi(x))'),
Doc('COND_LT_MOV = 60', 'ternary: x < y ? z : 0'), Doc('COND_LT_MOV = 60', 'ternary: x < y ? z : 0'),
Doc('NEQ = 61', 'binary: x != y'),
Doc('ISNAN = 62', 'unary: isnan(x)'),
Doc('ISINF = 63', 'unary: isinf(x)'),
) )


pdef('ElemwiseMultiType').add_enum( pdef('ElemwiseMultiType').add_enum(
@@ -513,6 +516,12 @@ pdef('ElemwiseMultiType').add_enum(
'compute ``a * b + c`` requiring that ``a`` be uint8 and ``b`` and ' 'compute ``a * b + c`` requiring that ``a`` be uint8 and ``b`` and '
'``c`` float32, and the result is float32.'), '``c`` float32, and the result is float32.'),
Doc('QCOND_LT_MOV = 57', 'quantized cond_lt_mov'), Doc('QCOND_LT_MOV = 57', 'quantized cond_lt_mov'),
Doc('EQ = 58', 'eq'),
Doc('NEQ = 59', 'eq'),
Doc('LT = 60', 'lt'),
Doc('LEQ = 61', 'leq'),
Doc('ISNAN = 62', 'isnan'),
Doc('ISINF = 63', 'isinf')
) )


pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)


+ 52
- 0
dnn/src/common/elemwise/kern_defs.cuh View File

@@ -10,6 +10,7 @@


#include <cmath> #include <cmath>
#include <cstdlib> #include <cstdlib>
#include "math.h"


#if MEGDNN_CC_HOST #if MEGDNN_CC_HOST
#include <algorithm> #include <algorithm>
@@ -272,6 +273,57 @@ DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z);


#undef DEF_KERN_AD #undef DEF_KERN_AD
#undef DEF_KERN #undef DEF_KERN
#undef DEF_KERN_FLOAT
#undef DEF_KERN_INT
#undef DEF_KERN_ALL

/* ================== bool kernels ================== */
//! define kernel
template <megcorePlatform_t plat, uint32_t mode, typename stype, typename dtype>
struct ElemwiseBoolKern;

#define DEF_KERN(_ctype, _dtype, _mode, _imp) \
template <megcorePlatform_t plat> \
struct ElemwiseBoolKern< \
plat, param_enumv::Elemwise::Mode::_mode, _ctype, _dtype> { \
typedef _ctype ctype; \
static __host__ __device__ _dtype apply(KERN_SIG) { return _dtype(_imp); } \
}

//! define kernel for all float types
#define DEF_KERN_FLOAT(_mode, _imp) \
DEF_KERN(dt_float32, dt_bool, _mode, _imp); \
DNN_INC_FLOAT16(DEF_KERN(dt_float16, dt_bool, _mode, _imp);) \
DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, dt_bool, _mode, _imp);)

//! define kernel for all int types
#define DEF_KERN_INT(_mode, _imp) \
DEF_KERN(dt_int32, dt_bool, _mode, _imp); \
DEF_KERN(dt_int16, dt_bool, _mode, _imp); \
DEF_KERN(dt_int8, dt_bool, _mode, _imp); \
DEF_KERN(dt_uint8, dt_bool, _mode, _imp);

//! define kernel for all ctypes
#define DEF_KERN_ALL(_mode, _imp) \
DEF_KERN_INT(_mode, _imp); \
DEF_KERN_FLOAT(_mode, _imp); \
DEF_KERN(dt_bool, dt_bool, _mode, _imp);
#define KERN_SIG ctype x
DEF_KERN_FLOAT(ISNAN, isnan(float(x)));
DEF_KERN_FLOAT(ISINF, isinf(float(x)));
#undef KERN_SIG
#define KERN_SIG ctype x, ctype y
DEF_KERN_ALL(LT, x < y);
DEF_KERN_ALL(LEQ, x <= y);
DEF_KERN_ALL(EQ, x == y);
DEF_KERN_ALL(NEQ, x != y);
#undef KERN_SIG

#undef DEF_KERN_AD
#undef DEF_KERN
#undef DEF_KERN_FLOAT
#undef DEF_KERN_INT
#undef DEF_KERN_ALL


} // namespace megdnn } // namespace megdnn




+ 0
- 1
dnn/src/common/elemwise_multi_type/kern_defs.cuh View File

@@ -28,7 +28,6 @@ MEGDNN_HOST MEGDNN_DEVICE dtype round_shr_saturate(stype x, int k) {
} }
return static_cast<dtype>(result); return static_cast<dtype>(result);
} }

} // namespace elemwise_multi_type } // namespace elemwise_multi_type
} // namespace megdnn } // namespace megdnn




+ 33
- 1
dnn/src/common/elemwise_multi_type/opr_impl.cpp View File

@@ -31,6 +31,14 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
return func; return func;
}; };


auto make_not_check_dtype_func = []() {
auto func = [](DType dtype) {
megdnn_assert(
true, "This function is to not check the dtype %s", dtype.name());
};
return func;
};

auto make_check_category = [](DTypeCategory expected) { auto make_check_category = [](DTypeCategory expected) {
auto func = [expected](DType dtype) { auto func = [expected](DType dtype) {
megdnn_assert(expected == dtype.category()); megdnn_assert(expected == dtype.category());
@@ -126,6 +134,23 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
dst.need_specify_out_dtype = true; dst.need_specify_out_dtype = true;
}; };


auto init_bool_unary_op = [&](ModeTrait& dst, const char* name) {
dst.arity = 1;
dst.check_inp[0] = make_check_category(DTypeCategory::FLOAT);
dst.check_out = make_out_dtype_func(dtype::Bool());
dst.name = name;
dst.need_specify_out_dtype = true;
};

auto init_bool_binary_op = [&](ModeTrait& dst, const char* name) {
dst.arity = 2;
dst.check_inp[0] = make_not_check_dtype_func();
dst.check_inp[1] = make_not_check_dtype_func();
dst.check_out = make_out_dtype_func(dtype::Bool());
dst.name = name;
dst.need_specify_out_dtype = true;
};

auto init_quantized_binary_op = [&](ModeTrait& dst, const char* name) { auto init_quantized_binary_op = [&](ModeTrait& dst, const char* name) {
dst.arity = 2; dst.arity = 2;
dst.check_inp[0] = make_check_category(DTypeCategory::QUANTIZED); dst.check_inp[0] = make_check_category(DTypeCategory::QUANTIZED);
@@ -240,6 +265,13 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
SET(init_quantized_ternary_op, QFUSE_MUL_ADD3); SET(init_quantized_ternary_op, QFUSE_MUL_ADD3);
SET(init_quantized_ternary_op, QCOND_LEQ_MOV); SET(init_quantized_ternary_op, QCOND_LEQ_MOV);
SET(init_quantized_ternary_op, QCOND_LT_MOV); SET(init_quantized_ternary_op, QCOND_LT_MOV);

SET(init_bool_binary_op, LT);
SET(init_bool_binary_op, LEQ);
SET(init_bool_binary_op, EQ);
SET(init_bool_binary_op, NEQ);
SET(init_bool_unary_op, ISNAN);
SET(init_bool_unary_op, ISINF);
#undef SET #undef SET
} }


@@ -273,4 +305,4 @@ void ElemwiseMultiType::check_layout_and_broadcast(
megdnn_assert(dst.is_contiguous()); megdnn_assert(dst.is_contiguous());
} }


// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen

+ 13
- 0
dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp View File

@@ -9,6 +9,12 @@ using namespace megdnn;
make_elemwise_op_param<_n>(src, dst), dst, Elemwise::Mode::_MODE); \ make_elemwise_op_param<_n>(src, dst), dst, Elemwise::Mode::_MODE); \
break break


#define ON_BOOL_MODE(_MODE, _n) \
case Mode::_MODE: \
dest_type_bool_mode( \
make_elemwise_op_param<_n>(src, dst), dst, Elemwise::Mode::_MODE); \
break

void ElemwiseMultiTypeImplHelper::exec( void ElemwiseMultiTypeImplHelper::exec(
_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) { _megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) {
switch (m_param.mode) { switch (m_param.mode) {
@@ -96,6 +102,13 @@ void ElemwiseMultiTypeImplHelper::exec(
ON_QUANTIZED_MODE(FUSE_MUL_ADD3, 3); ON_QUANTIZED_MODE(FUSE_MUL_ADD3, 3);
ON_QUANTIZED_MODE(COND_LEQ_MOV, 3); ON_QUANTIZED_MODE(COND_LEQ_MOV, 3);
ON_QUANTIZED_MODE(COND_LT_MOV, 3); ON_QUANTIZED_MODE(COND_LT_MOV, 3);

ON_BOOL_MODE(LT, 2);
ON_BOOL_MODE(LEQ, 2);
ON_BOOL_MODE(EQ, 2);
ON_BOOL_MODE(NEQ, 2);
ON_BOOL_MODE(ISNAN, 1);
ON_BOOL_MODE(ISINF, 1);
default: default:
megdnn_throw("invalid mode"); megdnn_throw("invalid mode");
} }


+ 18
- 0
dnn/src/common/elemwise_multi_type/opr_impl_helper.h View File

@@ -73,6 +73,24 @@ protected:
const ElemwiseOpParamN<2>& param, const TensorND& dst, const ElemwiseOpParamN<2>& param, const TensorND& dst,
Elemwise::Mode mode) = 0; Elemwise::Mode mode) = 0;


virtual void dest_type_bool_mode(
const ElemwiseOpParamN<1>& param, const TensorND& dst,
Elemwise::Mode mode) {
MEGDNN_MARK_USED_VAR(param);
MEGDNN_MARK_USED_VAR(dst);
MEGDNN_MARK_USED_VAR(mode);
megdnn_throw("Unrealized except arm_common");
}

virtual void dest_type_bool_mode(
const ElemwiseOpParamN<2>& param, const TensorND& dst,
Elemwise::Mode mode) {
MEGDNN_MARK_USED_VAR(param);
MEGDNN_MARK_USED_VAR(dst);
MEGDNN_MARK_USED_VAR(mode);
megdnn_throw("Unrealized except arm_common");
}

virtual void on_quantized_mode( virtual void on_quantized_mode(
const ElemwiseOpParamN<3>& param, const TensorND& dst, const ElemwiseOpParamN<3>& param, const TensorND& dst,
Elemwise::Mode mode) { Elemwise::Mode mode) {


+ 27
- 0
dnn/src/cuda/elemwise_multi_type/kern_impl_bool.inl View File

@@ -0,0 +1,27 @@
#pragma once

#ifndef KERN_IMPL_MODE
#error "KERN_IMPL_MODE, KERN_IMPL_ARITY, KERN_IMPL_STYPE, KERN_IMPL_DTYPE must be defined"
#endif

#include "src/cuda/elemwise_multi_type/kern_ops.cuh"

namespace megdnn {
namespace cuda {

#define cb(_m) \
typedef ElemwiseBoolKern< \
megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m, KERN_IMPL_STYPE, \
KERN_IMPL_DTYPE> \
KernImpl; \
typedef kern_ops_quantized::QuantizedMultiTypeOp< \
KERN_IMPL_ARITY, KERN_IMPL_STYPE, KERN_IMPL_DTYPE, KernImpl> \
Op; \
INST_RUN_ELEMWISE(Op, KERN_IMPL_STYPE, KERN_IMPL_ARITY);

KERN_IMPL_MODE(cb)

} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 77
- 5
dnn/src/cuda/elemwise_multi_type/kern_ops.cuh View File

@@ -4,7 +4,6 @@
#include "src/cuda/elemwise_multi_type/kern.cuh" #include "src/cuda/elemwise_multi_type/kern.cuh"
#include "src/cuda/integer_subbyte_utils.cuh" #include "src/cuda/integer_subbyte_utils.cuh"
#include "src/cuda/utils.cuh" #include "src/cuda/utils.cuh"

namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {
using namespace elemwise_intl; using namespace elemwise_intl;
@@ -122,6 +121,7 @@ struct QuantizedMultiTypeOp<
(std::is_same<ctype_src, dt_qint8>::value || (std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_qint32>::value || std::is_same<ctype_src, dt_qint32>::value ||
std::is_same<ctype_src, dt_quint8>::value) && std::is_same<ctype_src, dt_quint8>::value) &&
!std::is_same<ctype_dst, dt_bool>::value &&
IsNotTypeQ4<ctype_dst>::value>::type> { IsNotTypeQ4<ctype_dst>::value>::type> {
ctype_dst* dst; ctype_dst* dst;
CudaDTypeParam<ctype_dst> dst_param; CudaDTypeParam<ctype_dst> dst_param;
@@ -160,11 +160,43 @@ struct QuantizedMultiTypeOp<


template <typename ctype_src, typename ctype_dst, typename KernImpl> template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp< struct QuantizedMultiTypeOp<
1, ctype_src, ctype_dst, KernImpl,
typename std::enable_if<std::is_same<ctype_dst, dt_bool>::value>::type> {
ctype_dst* dst;
typedef typename elemwise_intl::VectTypeTrait<ctype_src>::vect_type src_vect_type;
typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type dst_vect_type;

#if !MEGDNN_CC_CUDA
QuantizedMultiTypeOp(ctype_dst* m_dst) : dst{m_dst} {}
#endif

#if MEGDNN_CC_CUDA
__device__ __forceinline__ ctype_dst apply(ctype_src v1) {
ctype_dst rv = KernImpl::apply(v1);
return rv;
}

__device__ __forceinline__ void operator()(uint32_t idx, ctype_src a) {
dst[idx] = KernImpl::apply(a);
}

__device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a) {
ctype_src a_x(a.x), a_y(a.y), a_z(a.z), a_w(a.w);
ctype_dst x = apply(a_x), y = apply(a_y), z = apply(a_z), w = apply(a_w);
*(dst_vect_type*)(&dst[idx]) =
elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y, z, w);
}
#endif
};

template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
2, ctype_src, ctype_dst, KernImpl, 2, ctype_src, ctype_dst, KernImpl,
typename std::enable_if< typename std::enable_if<
(std::is_same<ctype_src, dt_qint8>::value || (std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_qint32>::value || std::is_same<ctype_src, dt_qint32>::value ||
std::is_same<ctype_src, dt_quint8>::value) && std::is_same<ctype_src, dt_quint8>::value) &&
!std::is_same<ctype_dst, dt_bool>::value &&
IsNotTypeQ4<ctype_dst>::value>::type> { IsNotTypeQ4<ctype_dst>::value>::type> {
ctype_dst* dst; ctype_dst* dst;
CudaDTypeParam<ctype_dst> dst_param; CudaDTypeParam<ctype_dst> dst_param;
@@ -208,6 +240,40 @@ struct QuantizedMultiTypeOp<


template <typename ctype_src, typename ctype_dst, typename KernImpl> template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp< struct QuantizedMultiTypeOp<
2, ctype_src, ctype_dst, KernImpl,
typename std::enable_if<(std::is_same<ctype_dst, dt_bool>::value)>::type> {
ctype_dst* dst;
typedef typename elemwise_intl::VectTypeTrait<ctype_src>::vect_type src_vect_type;
typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type dst_vect_type;

#if !MEGDNN_CC_CUDA
QuantizedMultiTypeOp(ctype_dst* m_dst) : dst{m_dst} {}
#endif

#if MEGDNN_CC_CUDA
__device__ __forceinline__ ctype_dst apply(ctype_src v1, ctype_src v2) {
ctype_dst rv = KernImpl::apply(v1, v2);
return rv;
}

__device__ __forceinline__ void operator()(uint32_t idx, ctype_src a, ctype_src b) {
dst[idx] = KernImpl::apply(a, b);
}

__device__ __forceinline__ void operator()(
uint32_t idx, src_vect_type a, src_vect_type b) {
ctype_src a_x(a.x), a_y(a.y), a_z(a.z), a_w(a.w), b_x(b.x), b_y(b.y), b_z(b.z),
b_w(b.w);
ctype_dst x = apply(a_x, b_x), y = apply(a_y, b_y), z = apply(a_z, b_z),
w = apply(a_w, b_w);
*(dst_vect_type*)(&dst[idx]) =
elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y, z, w);
}
#endif
};

template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
3, ctype_src, ctype_dst, KernImpl, 3, ctype_src, ctype_dst, KernImpl,
typename std::enable_if< typename std::enable_if<
(std::is_same<ctype_src, dt_qint8>::value || (std::is_same<ctype_src, dt_qint8>::value ||
@@ -262,7 +328,8 @@ template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp< struct QuantizedMultiTypeOp<
1, ctype_src, ctype_dst, KernImpl, 1, ctype_src, ctype_dst, KernImpl,
typename std::enable_if< typename std::enable_if<
IsTypeQ4<ctype_src>::value && IsNotTypeQ4<ctype_dst>::value>::type> {
IsTypeQ4<ctype_src>::value && IsNotTypeQ4<ctype_dst>::value &&
!std::is_same<ctype_dst, dt_bool>::value>::type> {
ctype_dst* dst; ctype_dst* dst;
CudaDTypeParam<ctype_dst> dst_param; CudaDTypeParam<ctype_dst> dst_param;
CudaDTypeParam<ctype_src> param_a; CudaDTypeParam<ctype_src> param_a;
@@ -293,7 +360,8 @@ template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp< struct QuantizedMultiTypeOp<
2, ctype_src, ctype_dst, KernImpl, 2, ctype_src, ctype_dst, KernImpl,
typename std::enable_if< typename std::enable_if<
IsTypeQ4<ctype_src>::value && IsNotTypeQ4<ctype_dst>::value>::type> {
IsTypeQ4<ctype_src>::value && IsNotTypeQ4<ctype_dst>::value &&
!std::is_same<ctype_dst, dt_bool>::value>::type> {
ctype_dst* dst; ctype_dst* dst;
CudaDTypeParam<ctype_dst> dst_param; CudaDTypeParam<ctype_dst> dst_param;
CudaDTypeParam<ctype_src> param_a, param_b; CudaDTypeParam<ctype_src> param_a, param_b;
@@ -326,7 +394,8 @@ template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp< struct QuantizedMultiTypeOp<
1, ctype_src, ctype_dst, KernImpl, 1, ctype_src, ctype_dst, KernImpl,
typename std::enable_if< typename std::enable_if<
IsTypeQ4<ctype_src>::value && IsTypeQ4<ctype_dst>::value>::type> {
IsTypeQ4<ctype_src>::value && IsTypeQ4<ctype_dst>::value &&
!std::is_same<ctype_dst, dt_bool>::value>::type> {
using src_storage = typename elemwise_intl::VectTypeTrait<ctype_src>::Storage; using src_storage = typename elemwise_intl::VectTypeTrait<ctype_src>::Storage;
using dst_storage = typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage; using dst_storage = typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage;
dst_storage* dst; dst_storage* dst;
@@ -371,6 +440,7 @@ struct QuantizedMultiTypeOp<
(std::is_same<ctype_src, dt_qint8>::value || (std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_qint32>::value || std::is_same<ctype_src, dt_qint32>::value ||
std::is_same<ctype_src, dt_quint8>::value) && std::is_same<ctype_src, dt_quint8>::value) &&
!std::is_same<ctype_dst, dt_bool>::value &&
IsTypeQ4<ctype_dst>::value>::type> { IsTypeQ4<ctype_dst>::value>::type> {
using dst_storage = typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage; using dst_storage = typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage;
dst_storage* dst; dst_storage* dst;
@@ -407,7 +477,8 @@ template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp< struct QuantizedMultiTypeOp<
2, ctype_src, ctype_dst, KernImpl, 2, ctype_src, ctype_dst, KernImpl,
typename std::enable_if< typename std::enable_if<
IsTypeQ4<ctype_src>::value && IsTypeQ4<ctype_dst>::value>::type> {
IsTypeQ4<ctype_src>::value && IsTypeQ4<ctype_dst>::value &&
!std::is_same<ctype_dst, dt_bool>::value>::type> {
using src_storage = typename elemwise_intl::VectTypeTrait<ctype_src>::Storage; using src_storage = typename elemwise_intl::VectTypeTrait<ctype_src>::Storage;
using dst_storage = typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage; using dst_storage = typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage;
dst_storage* dst; dst_storage* dst;
@@ -460,6 +531,7 @@ struct QuantizedMultiTypeOp<
(std::is_same<ctype_src, dt_qint8>::value || (std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_qint32>::value || std::is_same<ctype_src, dt_qint32>::value ||
std::is_same<ctype_src, dt_quint8>::value) && std::is_same<ctype_src, dt_quint8>::value) &&
!std::is_same<ctype_dst, dt_bool>::value &&
IsTypeQ4<ctype_dst>::value>::type> { IsTypeQ4<ctype_dst>::value>::type> {
using dst_storage = typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage; using dst_storage = typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage;
dst_storage* dst; dst_storage* dst;


+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_bfloat16_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_bfloat16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_bool_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_bool
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_float16_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_float16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_float32_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_float32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int16_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int32_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int8_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int8
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_uint8_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_uint8
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_bfloat16_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISINF, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_bfloat16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_float16_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISINF, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_float16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_float32_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISINF, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_float32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_bfloat16_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISNAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_bfloat16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_float16_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISNAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_float16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_float32_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISNAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_float32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_bfloat16_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_bfloat16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_bool_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_bool
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_float16_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_float16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_float32_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_float32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int16_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int32_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int8_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int8
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_uint8_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_uint8
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_bfloat16_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_bfloat16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_bool_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_bool
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_float16_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_float16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_float32_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_float32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int16_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int32_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int8_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int8
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_uint8_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_uint8
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_bfloat16_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_bfloat16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_bool_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_bool
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_float16_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_float16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_float32_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_float32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int16_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int32_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int8_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int8
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 6
- 0
dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_uint8_dt_bool.cu View File

@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_uint8
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"

+ 110
- 0
dnn/src/cuda/elemwise_multi_type/opr_impl.cpp View File

@@ -295,6 +295,60 @@ IMPL_MODE_DISPATCHER(2, dt_qint32, dt_quint4);
#undef _cb_dispatch_mode #undef _cb_dispatch_mode
#undef IMPL_MODE_DISPATCHER #undef IMPL_MODE_DISPATCHER


// Expand one switch case for a bool-output elemwise mode: instantiate the
// CUDA kernel for (_m, src_ctype) producing dt_bool and launch it through
// run_elemwise on the given stream.
#define _cb_dispatch_mode(_m)                                                \
    case param::Elemwise::Mode::_m:                                          \
        do {                                                                 \
            using KernImpl = ElemwiseBoolKern<                               \
                    megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m,    \
                    src_ctype, dt_bool>;                                     \
            using Op = kern_ops_quantized::QuantizedMultiTypeOp<             \
                    arity, src_ctype, bool, KernImpl>;                       \
            dst_ctype* dst = dst_tensor.ptr<dst_ctype>();                    \
            Op op(dst);                                                      \
            return run_elemwise<Op, src_ctype, arity>(src, stream, op);      \
        } while (0);

// Specialize ModeDispatcher for (_arity, _src_ctype, _dst_ctype): switch on
// the runtime mode and expand one _cb_dispatch_mode case per mode listed in
// whichever FOREACH macro is in effect at the instantiation point below.
#define IMPL_MODE_DISPATCHER_BOOL(_arity, _src_ctype, _dst_ctype)            \
    template <>                                                              \
    struct ModeDispatcher<_arity, _src_ctype, _dst_ctype> {                  \
        static constexpr int arity = _arity;                                 \
        using src_ctype = _src_ctype;                                        \
        using dst_ctype = _dst_ctype;                                        \
        static void run(                                                     \
                const ElemwiseOpParamN<_arity>& src, const TensorND& dst_tensor, \
                param::Elemwise::Mode mode, cudaStream_t stream) {           \
            switch (mode) {                                                  \
                FOREACH(_cb_dispatch_mode)                                   \
                default:                                                     \
                    megdnn_throw("bad mode");                                \
            }                                                                \
        }                                                                    \
    }

// Binary comparison modes whose output dtype is bool.
#define FOREACH(cb)                         \
    MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)     \
    MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)    \
    MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)     \
    MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
IMPL_MODE_DISPATCHER_BOOL(2, dt_int8, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(2, dt_float32, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(2, dt_bfloat16, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(2, dt_float16, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(2, dt_int16, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(2, dt_int32, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(2, dt_bool, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(2, dt_uint8, dt_bool);
#undef FOREACH
// Unary float-classification modes (bool output); floating-point inputs only.
#define FOREACH(cb)                          \
    MEGDNN_ELEMWISE_MODE_ENABLE(ISNAN, cb)   \
    MEGDNN_ELEMWISE_MODE_ENABLE(ISINF, cb)
IMPL_MODE_DISPATCHER_BOOL(1, dt_float16, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(1, dt_float32, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(1, dt_bfloat16, dt_bool);
#undef FOREACH
#undef _cb_dispatch_mode
#undef IMPL_MODE_DISPATCHER_BOOL

template <typename ctype_src> template <typename ctype_src>
void dispatch_src_ctype( void dispatch_src_ctype(
const ElemwiseOpParamN<1>&, const TensorND& dst_tensor, Elemwise::Mode, const ElemwiseOpParamN<1>&, const TensorND& dst_tensor, Elemwise::Mode,
@@ -578,6 +632,62 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
#undef DISPATCH #undef DISPATCH
} }


// Run a unary bool-output elemwise mode (ISNAN / ISINF) on CUDA.
// `param` holds the single input tensor; `dst_tensor` receives dt_bool
// results.  Only floating-point input dtypes are accepted here.
void ElemwiseMultiTypeImpl::dest_type_bool_mode(
        const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor,
        Elemwise::Mode mode) {
    auto stream = cuda_stream(this->handle());
    switch (param[0].layout.dtype.enumv()) {
// Forward to the ModeDispatcher specialization for this input dtype.
#define DISPATCH(_dt)                                                      \
    case DTypeTrait<_dt>::enumv: {                                         \
        ModeDispatcher<1, typename DTypeTrait<_dt>::ctype, bool>::run(     \
                param, dst_tensor, mode, stream);                          \
        break;                                                             \
    }

        DISPATCH(dtype::Float32);
        DISPATCH(dtype::Float16);
        DISPATCH(dtype::BFloat16);

        default:
            megdnn_throw(ssprintf(
                    "Unsupported input dtype %s for ElemwiseMultiType",
                    param[0].layout.dtype.name()));
    }

#undef DISPATCH
}

// Run a binary bool-output elemwise mode (EQ / NEQ / LT / LEQ) on CUDA.
// Both inputs must share the same dtype; `dst_tensor` receives dt_bool
// results.
void ElemwiseMultiTypeImpl::dest_type_bool_mode(
        const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor,
        Elemwise::Mode mode) {
    megdnn_assert(param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv());
    auto stream = cuda_stream(this->handle());
    switch (param[0].layout.dtype.enumv()) {
// Forward to the ModeDispatcher specialization for this input dtype.
#define DISPATCH(_dt)                                                      \
    case DTypeTrait<_dt>::enumv: {                                         \
        ModeDispatcher<2, typename DTypeTrait<_dt>::ctype, bool>::run(     \
                param, dst_tensor, mode, stream);                          \
        break;                                                             \
    }

        DISPATCH(dtype::Int8);
        DISPATCH(dtype::Float32);
        DISPATCH(dtype::BFloat16);
        DISPATCH(dtype::Bool);
        DISPATCH(dtype::Float16);
        DISPATCH(dtype::Int16);
        DISPATCH(dtype::Int32);
        DISPATCH(dtype::Uint8);

        default:
            megdnn_throw(ssprintf(
                    "Unsupported input dtype %s for ElemwiseMultiType",
                    param[0].layout.dtype.name()));
    }

#undef DISPATCH
}

void ElemwiseMultiTypeImpl::on_quantized_mode( void ElemwiseMultiTypeImpl::on_quantized_mode(
const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor, const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor,
Elemwise::Mode mode) { Elemwise::Mode mode) {


+ 8
- 0
dnn/src/cuda/elemwise_multi_type/opr_impl.h View File

@@ -36,6 +36,14 @@ class ElemwiseMultiTypeImpl final : public ElemwiseMultiTypeImplHelper {
const ElemwiseOpParamN<3>& param, const TensorND& dst, const ElemwiseOpParamN<3>& param, const TensorND& dst,
Elemwise::Mode mode) override; Elemwise::Mode mode) override;


void dest_type_bool_mode(
const ElemwiseOpParamN<1>& param, const TensorND& dst,
Elemwise::Mode mode) override;

void dest_type_bool_mode(
const ElemwiseOpParamN<2>& param, const TensorND& dst,
Elemwise::Mode mode) override;

public: public:
using ElemwiseMultiTypeImplHelper::ElemwiseMultiTypeImplHelper; using ElemwiseMultiTypeImplHelper::ElemwiseMultiTypeImplHelper;
}; };


+ 50
- 1
dnn/src/naive/elemwise_multi_type/opr_impl.h View File

@@ -3,7 +3,6 @@
#include "megdnn/tensor_iter.h" #include "megdnn/tensor_iter.h"
#include "src/common/elemwise_multi_type/opr_impl_helper.h" #include "src/common/elemwise_multi_type/opr_impl_helper.h"
#include "src/naive/handle.h" #include "src/naive/handle.h"

namespace megdnn { namespace megdnn {
namespace naive { namespace naive {


@@ -68,6 +67,25 @@ class ElemwiseMultiTypeImpl : public ElemwiseMultiTypeImplHelper {
} }


template <typename KernImpl, typename src_ctype, typename dst_ctype> template <typename KernImpl, typename src_ctype, typename dst_ctype>
void dispatch_dst_bool_op(
const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor) {
auto size = param.size;
auto src0 = param[0];
auto work = [src0, size, dst_tensor]() {
// This is needed as these iterators are captured as const value.
auto iA = tensor_iter_valonly<src_ctype>(src0).begin();
auto pD = tensor_iter_valonly<dst_ctype>(dst_tensor).begin();
for (size_t i = 0; i < size; i++) {
src_ctype a = *iA;
*pD = KernImpl::apply(a);
++iA;
++pD;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

template <typename KernImpl, typename src_ctype, typename dst_ctype>
void dispatch_add_qint_op( void dispatch_add_qint_op(
const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor) { const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor) {
auto size = param.size; auto size = param.size;
@@ -98,6 +116,29 @@ class ElemwiseMultiTypeImpl : public ElemwiseMultiTypeImplHelper {
} }


template <typename KernImpl, typename src_ctype, typename dst_ctype> template <typename KernImpl, typename src_ctype, typename dst_ctype>
void dispatch_dst_bool_op(
const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto work = [src0, src1, size, dst_tensor]() {
// This is needed as these iterators are captured as const value.
auto iA = tensor_iter_valonly<src_ctype>(src0).begin();
auto iB = tensor_iter_valonly<src_ctype>(src1).begin();
auto pD = tensor_iter_valonly<dst_ctype>(dst_tensor).begin();
for (size_t i = 0; i < size; i++) {
src_ctype a = *iA;
src_ctype b = *iB;
*pD = KernImpl::apply(a, b);
++iA;
++iB;
++pD;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

template <typename KernImpl, typename src_ctype, typename dst_ctype>
void dispatch_add_qint_op( void dispatch_add_qint_op(
const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor) { const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor) {
auto size = param.size; auto size = param.size;
@@ -178,6 +219,14 @@ protected:
const ElemwiseOpParamN<3>& param, const TensorND& dst, const ElemwiseOpParamN<3>& param, const TensorND& dst,
Elemwise::Mode mode) override; Elemwise::Mode mode) override;


void dest_type_bool_mode(
const ElemwiseOpParamN<1>& param, const TensorND& dst,
Elemwise::Mode mode) override;

void dest_type_bool_mode(
const ElemwiseOpParamN<2>& param, const TensorND& dst,
Elemwise::Mode mode) override;

public: public:
using ElemwiseMultiTypeImplHelper::ElemwiseMultiTypeImplHelper; using ElemwiseMultiTypeImplHelper::ElemwiseMultiTypeImplHelper;
}; };


+ 117
- 0
dnn/src/naive/elemwise_multi_type/opr_impl_4.cpp View File

@@ -54,4 +54,121 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
} }
} }


// Naive fallback for unary bool-output modes (ISINF, ISNAN): dispatch first
// on the runtime mode, then on the input dtype, and run the scalar kernel
// over every element via dispatch_dst_bool_op.
void ElemwiseMultiTypeImpl::dest_type_bool_mode(
        const ElemwiseOpParamN<1>& param, const TensorND& dst, Elemwise::Mode mode) {
    switch (mode) {
        case Elemwise::Mode::ISINF: {
            switch (param[0].layout.dtype.enumv()) {
// Instantiate the CPU bool-output kernel for one (dtype, mode) pair and
// run it element-wise.
#define DISPATCH(_dt, _mode)                                                \
    case DTypeTrait<_dt>::enumv: {                                          \
        typedef ElemwiseBoolKern<                                           \
                megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode,     \
                typename DTypeTrait<_dt>::ctype, dt_bool>                   \
                KernImpl##_mode;                                            \
        dispatch_dst_bool_op<                                               \
                KernImpl##_mode, typename DTypeTrait<_dt>::ctype, dt_bool>( \
                param, dst);                                                \
        break;                                                              \
    }
// Expand DISPATCH for every dtype the unary modes support (floating point
// only; the fp16/bf16 cases are compiled in only with DNN_INC_FLOAT16).
#define DISPATCH_MODE(_mode)                                  \
    DISPATCH(megdnn::dtype::Float32, _mode);                  \
    DNN_INC_FLOAT16(DISPATCH(megdnn::dtype::Float16, _mode);) \
    DNN_INC_FLOAT16(DISPATCH(megdnn::dtype::BFloat16, _mode);)
                DISPATCH_MODE(ISINF);
                default:
                    megdnn_throw(ssprintf(
                            "Unsupported input dtype %s for ElemwiseMultiType",
                            param[0].layout.dtype.name()));
            };
            break;
        };
        case Elemwise::Mode::ISNAN: {
            switch (param[0].layout.dtype.enumv()) {
                DISPATCH_MODE(ISNAN);
                default:
                    megdnn_throw(ssprintf(
                            "Unsupported input dtype %s for ElemwiseMultiType",
                            param[0].layout.dtype.name()));
            };
            break;
        };
        default:
            // Other modes must not reach the bool-output path.
            megdnn_assert_internal(0);
    }
#undef DISPATCH_MODE
#undef DISPATCH
}

// Naive fallback for binary bool-output modes (EQ, NEQ, LT, LEQ): dispatch
// first on the runtime mode, then on the (shared) input dtype, and run the
// scalar kernel over every element via dispatch_dst_bool_op.
void ElemwiseMultiTypeImpl::dest_type_bool_mode(
        const ElemwiseOpParamN<2>& param, const TensorND& dst, Elemwise::Mode mode) {
    megdnn_assert(param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv());
    switch (mode) {
        case Elemwise::Mode::EQ: {
            switch (param[0].layout.dtype.enumv()) {
// Instantiate the CPU bool-output kernel for one (dtype, mode) pair and
// run it element-wise.
#define DISPATCH(_dt, _mode)                                                \
    case DTypeTrait<_dt>::enumv: {                                          \
        typedef ElemwiseBoolKern<                                           \
                megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode,     \
                typename DTypeTrait<_dt>::ctype, dt_bool>                   \
                KernImpl##_mode;                                            \
        dispatch_dst_bool_op<                                               \
                KernImpl##_mode, typename DTypeTrait<_dt>::ctype, dt_bool>( \
                param, dst);                                                \
        break;                                                              \
    };
// Expand DISPATCH for every dtype the binary comparisons support (fp16/bf16
// cases are compiled in only with DNN_INC_FLOAT16).
#define DISPATCH_MODE(_mode)                                  \
    DISPATCH(megdnn::dtype::Float32, _mode);                  \
    DNN_INC_FLOAT16(DISPATCH(megdnn::dtype::Float16, _mode);) \
    DNN_INC_FLOAT16(DISPATCH(megdnn::dtype::BFloat16, _mode);) \
    DISPATCH(megdnn::dtype::Int32, _mode);                    \
    DISPATCH(megdnn::dtype::Int16, _mode);                    \
    DISPATCH(megdnn::dtype::Int8, _mode);                     \
    DISPATCH(megdnn::dtype::Uint8, _mode);                    \
    DISPATCH(megdnn::dtype::Bool, _mode);
                DISPATCH_MODE(EQ);
                // NOTE(review): this break is unreachable — every DISPATCH
                // case already breaks; the other modes below omit it.
                break;
                default:
                    megdnn_throw(ssprintf(
                            "Unsupported input dtype %s for ElemwiseMultiType",
                            param[0].layout.dtype.name()));
            };
            break;
        };
        case Elemwise::Mode::NEQ: {
            switch (param[0].layout.dtype.enumv()) {
                DISPATCH_MODE(NEQ);
                default:
                    megdnn_throw(ssprintf(
                            "Unsupported input dtype %s for ElemwiseMultiType",
                            param[0].layout.dtype.name()));
            };
            break;
        };
        case Elemwise::Mode::LT: {
            switch (param[0].layout.dtype.enumv()) {
                DISPATCH_MODE(LT);
                default:
                    megdnn_throw(ssprintf(
                            "Unsupported input dtype %s for ElemwiseMultiType",
                            param[0].layout.dtype.name()));
            };
            break;
        };
        case Elemwise::Mode::LEQ: {
            switch (param[0].layout.dtype.enumv()) {
                DISPATCH_MODE(LEQ);
                default:
                    megdnn_throw(ssprintf(
                            "Unsupported input dtype %s for ElemwiseMultiType",
                            param[0].layout.dtype.name()));
            };
            break;
        };
        default:
            // Other modes must not reach the bool-output path.
            megdnn_assert_internal(0);
    }
#undef DISPATCH_MODE
#undef DISPATCH
}

// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 3
- 1
dnn/test/common/checker.cpp View File

@@ -149,8 +149,9 @@ void copy_tensors(
//! use QuantizedS16 dtype in winograd_filter_preprocess now. //! use QuantizedS16 dtype in winograd_filter_preprocess now.
cb(::megdnn::dtype::QuantizedS16) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) cb(::megdnn::dtype::QuantizedS16) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
cb(::megdnn::dtype::Uint16) cb(::megdnn::dtype::QuantizedS1) cb(::megdnn::dtype::Uint16) cb(::megdnn::dtype::QuantizedS1)
cb(::megdnn::dtype::Bool)
#undef cb #undef cb
default : megdnn_trap();
default : megdnn_trap();
} }
} }


@@ -325,6 +326,7 @@ void CheckerHelper::do_exec(
m_output_canonizer(tensors_cur_host); m_output_canonizer(tensors_cur_host);
m_output_canonizer(tensors_naive); m_output_canonizer(tensors_naive);
} }

check_tensors(tensors_naive, tensors_cur_host); check_tensors(tensors_naive, tensors_cur_host);
if (m_extra_opr_impl) { if (m_extra_opr_impl) {
check_tensors(tensors_naive, *tensors_extra_opr_impl); check_tensors(tensors_naive, *tensors_extra_opr_impl);


+ 8
- 0
dnn/test/common/elemwise.cpp View File

@@ -756,6 +756,14 @@ DEF_TEST(all_modes) {


auto should_ignore = [handle](Mode mode) { auto should_ignore = [handle](Mode mode) {
MEGDNN_MARK_USED_VAR(mode); MEGDNN_MARK_USED_VAR(mode);
switch (mode) {
case Mode::NEQ:
case Mode::ISNAN:
case Mode::ISINF:
return true;
default:
break;
}
return false; return false;
}; };




+ 3
- 0
dnn/test/common/rng.cpp View File

@@ -195,6 +195,9 @@ void IIDRNG::gen(const TensorND& tensor) {
if (tensor.layout.dtype.enumv() == DTypeEnum::Uint16) { if (tensor.layout.dtype.enumv() == DTypeEnum::Uint16) {
return; return;
} }
if (tensor.layout.dtype.enumv() == DTypeEnum::Bool) {
return;
}
megdnn_assert( megdnn_assert(
0, "IIDRNG does not know how to generate value for DType %s", 0, "IIDRNG does not know how to generate value for DType %s",
tensor.layout.dtype.name()); tensor.layout.dtype.name());


+ 49
- 1
dnn/test/cuda/elemwise_multi_type.cpp View File

@@ -4,7 +4,6 @@
#include "test/cuda/benchmark.h" #include "test/cuda/benchmark.h"
#include "test/cuda/fixture.h" #include "test/cuda/fixture.h"
#include "test/cuda/utils.h" #include "test/cuda/utils.h"

#undef cuda_check #undef cuda_check
#include "src/cuda/utils.h" #include "src/cuda/utils.h"


@@ -143,6 +142,43 @@ static void run_test_q4(int arity, Checker<ElemwiseMultiType>& checker, Mode mod
} }
} }


static void run_test_bool(int arity, Checker<ElemwiseMultiType>& checker, Mode mode) {
for (DType type :
std::vector<DType>{{dtype::Int8()}, {dtype::Float32()}, {dtype::Float16()}}) {
if ((mode == Mode::ISNAN || mode == Mode::ISINF) && type == dtype::Int8()) {
continue;
}
checker.set_param(mode);
UniformIntRNG rng_int8{1, 1};
NormalRNG rng_normal{0, 1};

auto set_inp_rng = [&](DType dtype, size_t i) {
if (dtype.enumv() == DTypeEnum::Int8) {
checker.set_rng(i, &rng_int8);
} else if (
dtype.enumv() == DTypeEnum::Float32 ||
dtype.enumv() == DTypeEnum::Float16) {
checker.set_rng(i, &rng_normal);
} else {
megdnn_assert(0);
}
checker.set_dtype(i, dtype);
};
auto src_type = type;
for (int i = 0; i < arity; i++) {
set_inp_rng(src_type, i);
}

if (arity == 1) {
checker.execs({{3, 4, 5, 6}, {}});
} else if (arity == 2) {
checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {}});
} else {
megdnn_assert(0);
}
}
}

TEST_F(CUDA, ELEMWISE_QUANTIZED_MODE_UNARY) { TEST_F(CUDA, ELEMWISE_QUANTIZED_MODE_UNARY) {
Checker<ElemwiseMultiType> checker(handle_cuda()); Checker<ElemwiseMultiType> checker(handle_cuda());
for (auto mode : for (auto mode :
@@ -203,6 +239,18 @@ TEST_F(CUDA, ELEMWISE_QUANTIZED_MODE_BINARY) {
} }
} }


// Cover all bool-output elemwise modes on CUDA: binary comparisons first,
// then the unary float-classification modes.
TEST_F(CUDA, ELEMWISE_BOOL_MODE_BINARY) {
    using Mode = ElemwiseMultiType::Param::Mode;

    Checker<ElemwiseMultiType> checker(handle_cuda());
    const Mode binary_modes[] = {Mode::EQ, Mode::NEQ, Mode::LT, Mode::LEQ};
    for (Mode m : binary_modes) {
        run_test_bool(2, checker, m);
    }
    const Mode unary_modes[] = {Mode::ISNAN, Mode::ISINF};
    for (Mode m : unary_modes) {
        run_test_bool(1, checker, m);
    }
}

TEST_F(CUDA, ELEMWISE_QUANTIZED_MODE_TENARY) { TEST_F(CUDA, ELEMWISE_QUANTIZED_MODE_TENARY) {
using Mode = ElemwiseMultiType::Param::Mode; using Mode = ElemwiseMultiType::Param::Mode;
Checker<ElemwiseMultiType> checker(handle_cuda()); Checker<ElemwiseMultiType> checker(handle_cuda());


+ 32
- 8
imperative/python/megengine/core/tensor/array_method.py View File

@@ -23,11 +23,25 @@ from .._imperative_rt.core2 import (
) )
from ..ops import builtin from ..ops import builtin
from . import amp from . import amp
from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph
from .utils import (
_normalize_axis,
astensor1d,
cast_tensors,
convert_inputs,
make_shape_tuple,
subgraph,
)


_ElwMod = builtin.Elemwise.Mode _ElwMod = builtin.Elemwise.Mode




def _elemwise_multi_type(*args, mode, **kwargs):
    """Apply an ``ElemwiseMultiType`` op with the given *mode* to ``args``.

    Inputs are normalized via ``convert_inputs``; extra keyword arguments
    (e.g. ``dtype``) are forwarded to the op constructor.  Returns the
    single output tensor of the op.
    """
    inputs = convert_inputs(*args)
    op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
    (result,) = apply(op, *inputs)
    return result


def _elwise_apply(args, mode): def _elwise_apply(args, mode):
op = builtin.Elemwise(mode) op = builtin.Elemwise(mode)
(result,) = apply(op, *args) (result,) = apply(op, *args)
@@ -234,13 +248,23 @@ class ArrayMethodMixin(abc.ABC):


__hash__ = None # due to __eq__ diviates from python convention __hash__ = None # due to __eq__ diviates from python convention


__lt__ = lambda self, value: _elwise(self, value, mode=_ElwMod.LT).astype("bool")
__le__ = lambda self, value: _elwise(self, value, mode=_ElwMod.LEQ).astype("bool")
__gt__ = lambda self, value: _elwise(value, self, mode=_ElwMod.LT).astype("bool")
__ge__ = lambda self, value: _elwise(value, self, mode=_ElwMod.LEQ).astype("bool")
__eq__ = lambda self, value: _elwise(self, value, mode=_ElwMod.EQ).astype("bool")
__ne__ = lambda self, value: _elwise(
_elwise(self, value, mode=_ElwMod.EQ).astype("bool"), mode=_ElwMod.NOT,
__lt__ = lambda self, value: _elemwise_multi_type(
self, value, mode="lt", dtype="Bool"
)
__le__ = lambda self, value: _elemwise_multi_type(
self, value, mode="leq", dtype="Bool"
)
__gt__ = lambda self, value: _elemwise_multi_type(
value, self, mode="lt", dtype="Bool"
)
__ge__ = lambda self, value: _elemwise_multi_type(
value, self, mode="leq", dtype="Bool"
)
__eq__ = lambda self, value: _elemwise_multi_type(
self, value, mode="eq", dtype="Bool"
)
__ne__ = lambda self, value: _elemwise_multi_type(
self, value, mode="neq", dtype="Bool"
) )


__neg__ = _unary_elwise(_ElwMod.NEGATE) __neg__ = _unary_elwise(_ElwMod.NEGATE)


+ 3
- 3
imperative/python/megengine/functional/math.py View File

@@ -10,7 +10,7 @@ from ..core.tensor.array_method import _matmul
from ..core.tensor.utils import _normalize_axis from ..core.tensor.utils import _normalize_axis
from ..tensor import Tensor from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default from ..utils.deprecation import deprecated_kwargs_default
from .elemwise import clip
from .elemwise import _elemwise_multi_type, clip
from .tensor import expand_dims, squeeze from .tensor import expand_dims, squeeze


__all__ = [ __all__ = [
@@ -52,7 +52,7 @@ def isnan(inp: Tensor) -> Tensor:
>>> F.isnan(x).numpy() >>> F.isnan(x).numpy()
array([False, True, False]) array([False, True, False])
""" """
return inp != inp
return _elemwise_multi_type(inp, mode="isnan", dtype="Bool")




def isinf(inp: Tensor) -> Tensor: def isinf(inp: Tensor) -> Tensor:
@@ -69,7 +69,7 @@ def isinf(inp: Tensor) -> Tensor:
>>> F.isinf(x).numpy() >>> F.isinf(x).numpy()
array([False, True, False]) array([False, True, False])
""" """
return abs(inp).astype("float32") == float("inf")
return _elemwise_multi_type(inp, mode="isinf", dtype="Bool")




def sign(inp: Tensor): def sign(inp: Tensor):


+ 2
- 2
src/jit/impl/ast_c.cpp View File

@@ -133,9 +133,9 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() {
0.f}) / 0.f}) /
6.f), 6.f),
}; };
mgb_assert(map.size() + 16 == opr::Elemwise::Param::MODE_NR_MEMBER);
mgb_assert(map.size() + 19 == opr::Elemwise::Param::MODE_NR_MEMBER);
// unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH, // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH,
// ERFINV, ERFCINV, NOT, AND, OR, XOR
// ERFINV, ERFCINV, NOT, AND, OR, XOR, NEQ, ISNAN, ISINF
return map; return map;
#undef ADD_OPR #undef ADD_OPR
} }


+ 2
- 2
src/opr/test/basic_arith/elemwise.cpp View File

@@ -756,8 +756,8 @@ TYPED_TEST(TestOprBasicArithTernaryElemwise, Float32) {


TEST(TestOprBasicArithElemwise, CheckAllModeTested) { TEST(TestOprBasicArithElemwise, CheckAllModeTested) {
size_t nr_member = opr::Elemwise::Param::MODE_NR_MEMBER; size_t nr_member = opr::Elemwise::Param::MODE_NR_MEMBER;
ASSERT_EQ(nr_member, tested_mode.size() + 4);
// Not using TestRunner: NOT, AND, OR, XOR
ASSERT_EQ(nr_member, tested_mode.size() + 7);
// Not using TestRunner: NOT, AND, OR, XOR, NEQ, ISNAN, ISINF
} }
#define TEST_OPR_BASIC_ARITH_UNARY_BOOL(_mode, _op) \ #define TEST_OPR_BASIC_ARITH_UNARY_BOOL(_mode, _op) \
TEST(TestOprBasicArithElemwise, _mode) { \ TEST(TestOprBasicArithElemwise, _mode) { \


Loading…
Cancel
Save