GitOrigin-RevId: fd0134fca2
HuaHua404-patch-4
@@ -6,6 +6,10 @@ SUPPORT_QINT32_DTYPES = [('dt_qint32', 'dt_qint8'), ('dt_qint8', 'dt_qint32'),
SUPPORT_DTYPES_Q4 = [('dt_qint4', 'dt_qint4'), ('dt_quint4', 'dt_quint4')]
SUPPORT_QINT32_DTYPES_Q4 = [('dt_qint32', 'dt_qint4'), ('dt_qint32', 'dt_quint4')]
SUPPORT_ARRITY2_DTYPES = ['dt_int32', 'dt_uint8', 'dt_int8', 'dt_int16', 'dt_bool', 'dt_float32',
                          'dt_float16', 'dt_bfloat16']
SUPPORT_ARRITY1_DTYPES = ['dt_float32', 'dt_float16', 'dt_bfloat16']
MODES = {
    1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS',
        'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN',
@@ -34,3 +38,11 @@ QINT32_MODES = {
    2: ['ADD', 'FUSE_ADD_RELU', 'FUSE_ADD_SIGMOID',
        'FUSE_ADD_TANH', 'FUSE_ADD_H_SWISH']
}
ARRITY1_BOOL_MODES = {
    1: ['ISINF', 'ISNAN'],
}
ARRITY2_BOOL_MODES = {
    2: ['EQ', 'LEQ', 'NEQ', 'LT'],
}
@@ -421,6 +421,9 @@ pdef('Elemwise').add_enum(
    Doc('GELU = 58', 'unary: x Phi(x)'),
    Doc('GELU_GRAD = 59', 'binary: grad(x Phi(x))'),
    Doc('COND_LT_MOV = 60', 'ternary: x < y ? z : 0'),
    Doc('NEQ = 61', 'binary: x != y'),
    Doc('ISNAN = 62', 'unary: isnan(x)'),
    Doc('ISINF = 63', 'unary: isinf(x)'),
)
pdef('ElemwiseMultiType').add_enum(
@@ -513,6 +516,12 @@ pdef('ElemwiseMultiType').add_enum(
        'compute ``a * b + c`` requiring that ``a`` be uint8 and ``b`` and '
        '``c`` float32, and the result is float32.'),
    Doc('QCOND_LT_MOV = 57', 'quantized cond_lt_mov'),
    Doc('EQ = 58', 'binary: x == y'),
    Doc('NEQ = 59', 'binary: x != y'),
    Doc('LT = 60', 'binary: x < y'),
    Doc('LEQ = 61', 'binary: x <= y'),
    Doc('ISNAN = 62', 'unary: isnan(x)'),
    Doc('ISINF = 63', 'unary: isinf(x)')
)
pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
@@ -10,6 +10,7 @@
#include <cmath>
#include <cstdlib>
#include "math.h"
#if MEGDNN_CC_HOST
#include <algorithm>
@@ -272,6 +273,57 @@ DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z);
#undef DEF_KERN_AD
#undef DEF_KERN
#undef DEF_KERN_FLOAT
#undef DEF_KERN_INT
#undef DEF_KERN_ALL
/* ================== bool kernels ================== */
//! define kernel
template <megcorePlatform_t plat, uint32_t mode, typename stype, typename dtype>
struct ElemwiseBoolKern;
#define DEF_KERN(_ctype, _dtype, _mode, _imp) \
    template <megcorePlatform_t plat> \
    struct ElemwiseBoolKern< \
            plat, param_enumv::Elemwise::Mode::_mode, _ctype, _dtype> { \
        typedef _ctype ctype; \
        static __host__ __device__ _dtype apply(KERN_SIG) { return _dtype(_imp); } \
    }
//! define kernel for all float types
#define DEF_KERN_FLOAT(_mode, _imp) \
    DEF_KERN(dt_float32, dt_bool, _mode, _imp); \
    DNN_INC_FLOAT16(DEF_KERN(dt_float16, dt_bool, _mode, _imp);) \
    DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, dt_bool, _mode, _imp);)
//! define kernel for all int types
#define DEF_KERN_INT(_mode, _imp) \
    DEF_KERN(dt_int32, dt_bool, _mode, _imp); \
    DEF_KERN(dt_int16, dt_bool, _mode, _imp); \
    DEF_KERN(dt_int8, dt_bool, _mode, _imp); \
    DEF_KERN(dt_uint8, dt_bool, _mode, _imp);
//! define kernel for all ctypes
#define DEF_KERN_ALL(_mode, _imp) \
    DEF_KERN_INT(_mode, _imp); \
    DEF_KERN_FLOAT(_mode, _imp); \
    DEF_KERN(dt_bool, dt_bool, _mode, _imp);
#define KERN_SIG ctype x
DEF_KERN_FLOAT(ISNAN, isnan(float(x)));
DEF_KERN_FLOAT(ISINF, isinf(float(x)));
#undef KERN_SIG
#define KERN_SIG ctype x, ctype y
DEF_KERN_ALL(LT, x < y);
DEF_KERN_ALL(LEQ, x <= y);
DEF_KERN_ALL(EQ, x == y);
DEF_KERN_ALL(NEQ, x != y);
#undef KERN_SIG
#undef DEF_KERN_AD
#undef DEF_KERN
#undef DEF_KERN_FLOAT
#undef DEF_KERN_INT
#undef DEF_KERN_ALL
} // namespace megdnn
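For readers tracing the macros above, this is roughly what one binary expansion produces. It is a standalone sketch only: the platform template parameter, the __host__ __device__ qualifiers and the megdnn dtype aliases are dropped, and all names are illustrative rather than the real megdnn ones.

#include <cstdint>
#include <cstdio>

// Stand-ins for param_enumv::Elemwise::Mode and the megdnn dtypes.
enum class Mode : uint32_t { LT, LEQ, EQ, NEQ };

// Primary template, mirroring `struct ElemwiseBoolKern;` above.
template <Mode mode, typename stype, typename dtype>
struct ElemwiseBoolKernSketch;

// Roughly what DEF_KERN(dt_float32, dt_bool, LT, x < y) expands to, with
// KERN_SIG being "ctype x, ctype y" for the binary comparison modes.
template <>
struct ElemwiseBoolKernSketch<Mode::LT, float, bool> {
    typedef float ctype;
    static bool apply(ctype x, ctype y) { return bool(x < y); }
};

int main() {
    std::printf("%d\n", ElemwiseBoolKernSketch<Mode::LT, float, bool>::apply(1.f, 2.f));
    return 0;
}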
@@ -28,7 +28,6 @@ MEGDNN_HOST MEGDNN_DEVICE dtype round_shr_saturate(stype x, int k) {
    }
    return static_cast<dtype>(result);
}
} // namespace elemwise_multi_type
} // namespace megdnn
@@ -31,6 +31,14 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
        return func;
    };
    auto make_not_check_dtype_func = []() {
        auto func = [](DType dtype) {
            megdnn_assert(
                    true, "dtype %s is intentionally not checked here", dtype.name());
        };
        return func;
    };
    auto make_check_category = [](DTypeCategory expected) {
        auto func = [expected](DType dtype) {
            megdnn_assert(expected == dtype.category());
@@ -126,6 +134,23 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
        dst.need_specify_out_dtype = true;
    };
    auto init_bool_unary_op = [&](ModeTrait& dst, const char* name) {
        dst.arity = 1;
        dst.check_inp[0] = make_check_category(DTypeCategory::FLOAT);
        dst.check_out = make_out_dtype_func(dtype::Bool());
        dst.name = name;
        dst.need_specify_out_dtype = true;
    };
    auto init_bool_binary_op = [&](ModeTrait& dst, const char* name) {
        dst.arity = 2;
        dst.check_inp[0] = make_not_check_dtype_func();
        dst.check_inp[1] = make_not_check_dtype_func();
        dst.check_out = make_out_dtype_func(dtype::Bool());
        dst.name = name;
        dst.need_specify_out_dtype = true;
    };
    auto init_quantized_binary_op = [&](ModeTrait& dst, const char* name) {
        dst.arity = 2;
        dst.check_inp[0] = make_check_category(DTypeCategory::QUANTIZED);
@@ -240,6 +265,13 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
        SET(init_quantized_ternary_op, QFUSE_MUL_ADD3);
        SET(init_quantized_ternary_op, QCOND_LEQ_MOV);
        SET(init_quantized_ternary_op, QCOND_LT_MOV);
        SET(init_bool_binary_op, LT);
        SET(init_bool_binary_op, LEQ);
        SET(init_bool_binary_op, EQ);
        SET(init_bool_binary_op, NEQ);
        SET(init_bool_unary_op, ISNAN);
        SET(init_bool_unary_op, ISINF);
#undef SET
    }
@@ -273,4 +305,4 @@ void ElemwiseMultiType::check_layout_and_broadcast(
    megdnn_assert(dst.is_contiguous());
}
// vim: syntax=cpp.doxygen
@@ -9,6 +9,12 @@ using namespace megdnn;
                make_elemwise_op_param<_n>(src, dst), dst, Elemwise::Mode::_MODE); \
        break
#define ON_BOOL_MODE(_MODE, _n) \
    case Mode::_MODE: \
        dest_type_bool_mode( \
                make_elemwise_op_param<_n>(src, dst), dst, Elemwise::Mode::_MODE); \
        break
void ElemwiseMultiTypeImplHelper::exec(
        _megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) {
    switch (m_param.mode) {
@@ -96,6 +102,13 @@ void ElemwiseMultiTypeImplHelper::exec(
        ON_QUANTIZED_MODE(FUSE_MUL_ADD3, 3);
        ON_QUANTIZED_MODE(COND_LEQ_MOV, 3);
        ON_QUANTIZED_MODE(COND_LT_MOV, 3);
        ON_BOOL_MODE(LT, 2);
        ON_BOOL_MODE(LEQ, 2);
        ON_BOOL_MODE(EQ, 2);
        ON_BOOL_MODE(NEQ, 2);
        ON_BOOL_MODE(ISNAN, 1);
        ON_BOOL_MODE(ISINF, 1);
        default:
            megdnn_throw("invalid mode");
    }
@@ -73,6 +73,24 @@ protected:
            const ElemwiseOpParamN<2>& param, const TensorND& dst,
            Elemwise::Mode mode) = 0;
    virtual void dest_type_bool_mode(
            const ElemwiseOpParamN<1>& param, const TensorND& dst,
            Elemwise::Mode mode) {
        MEGDNN_MARK_USED_VAR(param);
        MEGDNN_MARK_USED_VAR(dst);
        MEGDNN_MARK_USED_VAR(mode);
        megdnn_throw("dest_type_bool_mode is not implemented on this backend");
    }
    virtual void dest_type_bool_mode(
            const ElemwiseOpParamN<2>& param, const TensorND& dst,
            Elemwise::Mode mode) {
        MEGDNN_MARK_USED_VAR(param);
        MEGDNN_MARK_USED_VAR(dst);
        MEGDNN_MARK_USED_VAR(mode);
        megdnn_throw("dest_type_bool_mode is not implemented on this backend");
    }
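The two hooks above are virtuals with throwing default bodies rather than pure virtuals, so backends that never handle bool-output modes still compile unchanged and only the backends that do (CUDA and naive in this patch) override them. A minimal standalone sketch of that pattern, with purely illustrative names:

#include <cstdio>
#include <stdexcept>

struct ImplHelperSketch {
    virtual ~ImplHelperSketch() = default;
    // Default body throws instead of being pure virtual.
    virtual void dest_type_bool_mode_sketch() {
        throw std::runtime_error("bool-output modes not implemented on this backend");
    }
};

struct CudaLikeImplSketch : ImplHelperSketch {
    void dest_type_bool_mode_sketch() override { std::puts("bool mode handled"); }
};

int main() {
    CudaLikeImplSketch impl;
    impl.dest_type_bool_mode_sketch();  // dispatches to the overriding backend
    return 0;
}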
    virtual void on_quantized_mode(
            const ElemwiseOpParamN<3>& param, const TensorND& dst,
            Elemwise::Mode mode) {
@@ -0,0 +1,27 @@
#pragma once
#ifndef KERN_IMPL_MODE
#error "KERN_IMPL_MODE, KERN_IMPL_ARITY, KERN_IMPL_STYPE, KERN_IMPL_DTYPE must be defined"
#endif
#include "src/cuda/elemwise_multi_type/kern_ops.cuh"
namespace megdnn {
namespace cuda {
#define cb(_m) \
    typedef ElemwiseBoolKern< \
            megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m, KERN_IMPL_STYPE, \
            KERN_IMPL_DTYPE> \
            KernImpl; \
    typedef kern_ops_quantized::QuantizedMultiTypeOp< \
            KERN_IMPL_ARITY, KERN_IMPL_STYPE, KERN_IMPL_DTYPE, KernImpl> \
            Op; \
    INST_RUN_ELEMWISE(Op, KERN_IMPL_STYPE, KERN_IMPL_ARITY);
KERN_IMPL_MODE(cb)
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
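The generated .cu files that follow all use this "define, then include the shared .inl" pattern: each file picks a mode, arity and source dtype via KERN_IMPL_* defines and the .inl stamps out the corresponding instantiation. A self-contained sketch of the mechanism, folded into one file and with illustrative names (not the real megdnn macros):

#include <cstdio>

template <typename stype, int arity>
struct OpSketch {
    static void run() { std::printf("instantiated: arity=%d\n", arity); }
};

// -- what a generated file does: pick the arity and source type -------------
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE float

// -- what the shared .inl does with those definitions -----------------------
#define INSTANTIATE_BOOL_KERN() \
    template struct OpSketch<KERN_IMPL_STYPE, KERN_IMPL_ARITY>;
INSTANTIATE_BOOL_KERN()

int main() {
    OpSketch<float, 2>::run();
    return 0;
}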
@@ -4,7 +4,6 @@
#include "src/cuda/elemwise_multi_type/kern.cuh"
#include "src/cuda/integer_subbyte_utils.cuh"
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
using namespace elemwise_intl;
@@ -122,6 +121,7 @@ struct QuantizedMultiTypeOp<
                (std::is_same<ctype_src, dt_qint8>::value ||
                 std::is_same<ctype_src, dt_qint32>::value ||
                 std::is_same<ctype_src, dt_quint8>::value) &&
                !std::is_same<ctype_dst, dt_bool>::value &&
                IsNotTypeQ4<ctype_dst>::value>::type> {
    ctype_dst* dst;
    CudaDTypeParam<ctype_dst> dst_param;
@@ -160,11 +160,43 @@ struct QuantizedMultiTypeOp<
template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
        1, ctype_src, ctype_dst, KernImpl,
        typename std::enable_if<std::is_same<ctype_dst, dt_bool>::value>::type> {
    ctype_dst* dst;
    typedef typename elemwise_intl::VectTypeTrait<ctype_src>::vect_type src_vect_type;
    typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type dst_vect_type;
#if !MEGDNN_CC_CUDA
    QuantizedMultiTypeOp(ctype_dst* m_dst) : dst{m_dst} {}
#endif
#if MEGDNN_CC_CUDA
    __device__ __forceinline__ ctype_dst apply(ctype_src v1) {
        ctype_dst rv = KernImpl::apply(v1);
        return rv;
    }
    __device__ __forceinline__ void operator()(uint32_t idx, ctype_src a) {
        dst[idx] = KernImpl::apply(a);
    }
    __device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a) {
        ctype_src a_x(a.x), a_y(a.y), a_z(a.z), a_w(a.w);
        ctype_dst x = apply(a_x), y = apply(a_y), z = apply(a_z), w = apply(a_w);
        *(dst_vect_type*)(&dst[idx]) =
                elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y, z, w);
    }
#endif
};
template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
        2, ctype_src, ctype_dst, KernImpl,
        typename std::enable_if<
                (std::is_same<ctype_src, dt_qint8>::value ||
                 std::is_same<ctype_src, dt_qint32>::value ||
                 std::is_same<ctype_src, dt_quint8>::value) &&
                !std::is_same<ctype_dst, dt_bool>::value &&
                IsNotTypeQ4<ctype_dst>::value>::type> {
    ctype_dst* dst;
    CudaDTypeParam<ctype_dst> dst_param;
@@ -208,6 +240,40 @@ struct QuantizedMultiTypeOp<
template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
        2, ctype_src, ctype_dst, KernImpl,
        typename std::enable_if<(std::is_same<ctype_dst, dt_bool>::value)>::type> {
    ctype_dst* dst;
    typedef typename elemwise_intl::VectTypeTrait<ctype_src>::vect_type src_vect_type;
    typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type dst_vect_type;
#if !MEGDNN_CC_CUDA
    QuantizedMultiTypeOp(ctype_dst* m_dst) : dst{m_dst} {}
#endif
#if MEGDNN_CC_CUDA
    __device__ __forceinline__ ctype_dst apply(ctype_src v1, ctype_src v2) {
        ctype_dst rv = KernImpl::apply(v1, v2);
        return rv;
    }
    __device__ __forceinline__ void operator()(uint32_t idx, ctype_src a, ctype_src b) {
        dst[idx] = KernImpl::apply(a, b);
    }
    __device__ __forceinline__ void operator()(
            uint32_t idx, src_vect_type a, src_vect_type b) {
        ctype_src a_x(a.x), a_y(a.y), a_z(a.z), a_w(a.w), b_x(b.x), b_y(b.y), b_z(b.z),
                b_w(b.w);
        ctype_dst x = apply(a_x, b_x), y = apply(a_y, b_y), z = apply(a_z, b_z),
                  w = apply(a_w, b_w);
        *(dst_vect_type*)(&dst[idx]) =
                elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y, z, w);
    }
#endif
};
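The vectorized operator() above applies the scalar kernel to the four lanes of a packed input and stores four bool results in one go. A host-side, CUDA-free sketch of that lane-wise pattern (illustrative names; the real code packs the result through VectTypeTrait<dt_bool>::make_vector and a single vector store):

#include <cstdio>

struct Float4Sketch { float x, y, z, w; };

struct LessThanKernSketch {
    static bool apply(float a, float b) { return a < b; }
};

// Apply the scalar kernel per lane and write four contiguous bool results.
template <typename Kern>
void run_vect4(bool* dst, unsigned idx, Float4Sketch a, Float4Sketch b) {
    dst[idx + 0] = Kern::apply(a.x, b.x);
    dst[idx + 1] = Kern::apply(a.y, b.y);
    dst[idx + 2] = Kern::apply(a.z, b.z);
    dst[idx + 3] = Kern::apply(a.w, b.w);
}

int main() {
    bool out[4];
    run_vect4<LessThanKernSketch>(out, 0, {1, 2, 3, 4}, {4, 3, 2, 1});
    std::printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);
    return 0;
}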
template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
        3, ctype_src, ctype_dst, KernImpl,
        typename std::enable_if<
                (std::is_same<ctype_src, dt_qint8>::value ||
@@ -262,7 +328,8 @@ template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
        1, ctype_src, ctype_dst, KernImpl,
        typename std::enable_if<
                IsTypeQ4<ctype_src>::value && IsNotTypeQ4<ctype_dst>::value>::type> {
                IsTypeQ4<ctype_src>::value && IsNotTypeQ4<ctype_dst>::value &&
                !std::is_same<ctype_dst, dt_bool>::value>::type> {
    ctype_dst* dst;
    CudaDTypeParam<ctype_dst> dst_param;
    CudaDTypeParam<ctype_src> param_a;
@@ -293,7 +360,8 @@ template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
        2, ctype_src, ctype_dst, KernImpl,
        typename std::enable_if<
                IsTypeQ4<ctype_src>::value && IsNotTypeQ4<ctype_dst>::value>::type> {
                IsTypeQ4<ctype_src>::value && IsNotTypeQ4<ctype_dst>::value &&
                !std::is_same<ctype_dst, dt_bool>::value>::type> {
    ctype_dst* dst;
    CudaDTypeParam<ctype_dst> dst_param;
    CudaDTypeParam<ctype_src> param_a, param_b;
@@ -326,7 +394,8 @@ template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
        1, ctype_src, ctype_dst, KernImpl,
        typename std::enable_if<
                IsTypeQ4<ctype_src>::value && IsTypeQ4<ctype_dst>::value>::type> {
                IsTypeQ4<ctype_src>::value && IsTypeQ4<ctype_dst>::value &&
                !std::is_same<ctype_dst, dt_bool>::value>::type> {
    using src_storage = typename elemwise_intl::VectTypeTrait<ctype_src>::Storage;
    using dst_storage = typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage;
    dst_storage* dst;
@@ -371,6 +440,7 @@ struct QuantizedMultiTypeOp<
                (std::is_same<ctype_src, dt_qint8>::value ||
                 std::is_same<ctype_src, dt_qint32>::value ||
                 std::is_same<ctype_src, dt_quint8>::value) &&
                !std::is_same<ctype_dst, dt_bool>::value &&
                IsTypeQ4<ctype_dst>::value>::type> {
    using dst_storage = typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage;
    dst_storage* dst;
@@ -407,7 +477,8 @@ template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
        2, ctype_src, ctype_dst, KernImpl,
        typename std::enable_if<
                IsTypeQ4<ctype_src>::value && IsTypeQ4<ctype_dst>::value>::type> {
                IsTypeQ4<ctype_src>::value && IsTypeQ4<ctype_dst>::value &&
                !std::is_same<ctype_dst, dt_bool>::value>::type> {
    using src_storage = typename elemwise_intl::VectTypeTrait<ctype_src>::Storage;
    using dst_storage = typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage;
    dst_storage* dst;
@@ -460,6 +531,7 @@ struct QuantizedMultiTypeOp<
                (std::is_same<ctype_src, dt_qint8>::value ||
                 std::is_same<ctype_src, dt_qint32>::value ||
                 std::is_same<ctype_src, dt_quint8>::value) &&
                !std::is_same<ctype_dst, dt_bool>::value &&
                IsTypeQ4<ctype_dst>::value>::type> {
    using dst_storage = typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage;
    dst_storage* dst;
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_bfloat16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_bool
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_float16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_float32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int8
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_uint8
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISINF, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_bfloat16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISINF, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_float16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISINF, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_float32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISNAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_bfloat16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISNAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_float16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISNAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_STYPE dt_float32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_bfloat16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_bool
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_float16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_float32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int8
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_uint8
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_bfloat16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_bool
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_float16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_float32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int8
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_uint8
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_bfloat16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_bool
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_float16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_float32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int16
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int32
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_int8
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -0,0 +1,6 @@
// generated by gen_elemwise_multi_type_kern_impls_bool.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_STYPE dt_uint8
#define KERN_IMPL_DTYPE dt_bool
#include "../kern_impl_bool.inl"
@@ -295,6 +295,60 @@ IMPL_MODE_DISPATCHER(2, dt_qint32, dt_quint4);
#undef _cb_dispatch_mode
#undef IMPL_MODE_DISPATCHER
#define _cb_dispatch_mode(_m) \
    case param::Elemwise::Mode::_m: \
        do { \
            using KernImpl = ElemwiseBoolKern< \
                    megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m, src_ctype, \
                    dt_bool>; \
            using Op = kern_ops_quantized::QuantizedMultiTypeOp< \
                    arity, src_ctype, bool, KernImpl>; \
            dst_ctype* dst = dst_tensor.ptr<dst_ctype>(); \
            Op op(dst); \
            return run_elemwise<Op, src_ctype, arity>(src, stream, op); \
        } while (0);
#define IMPL_MODE_DISPATCHER_BOOL(_arity, _src_ctype, _dst_ctype) \
    template <> \
    struct ModeDispatcher<_arity, _src_ctype, _dst_ctype> { \
        static constexpr int arity = _arity; \
        using src_ctype = _src_ctype; \
        using dst_ctype = _dst_ctype; \
        static void run( \
                const ElemwiseOpParamN<_arity>& src, const TensorND& dst_tensor, \
                param::Elemwise::Mode mode, cudaStream_t stream) { \
            switch (mode) { \
                FOREACH(_cb_dispatch_mode) \
                default: \
                    megdnn_throw("bad mode"); \
            } \
        } \
    }
#define FOREACH(cb) \
    MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) \
    MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) \
    MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) \
    MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb)
IMPL_MODE_DISPATCHER_BOOL(2, dt_int8, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(2, dt_float32, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(2, dt_bfloat16, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(2, dt_float16, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(2, dt_int16, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(2, dt_int32, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(2, dt_bool, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(2, dt_uint8, dt_bool);
#undef FOREACH
#define FOREACH(cb) \
    MEGDNN_ELEMWISE_MODE_ENABLE(ISNAN, cb) \
    MEGDNN_ELEMWISE_MODE_ENABLE(ISINF, cb)
IMPL_MODE_DISPATCHER_BOOL(1, dt_float16, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(1, dt_float32, dt_bool);
IMPL_MODE_DISPATCHER_BOOL(1, dt_bfloat16, dt_bool);
#undef FOREACH
#undef _cb_dispatch_mode
#undef IMPL_MODE_DISPATCHER_BOOL
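The FOREACH/_cb_dispatch_mode pair above is the classic X-macro dispatch: the mode list is written once and each entry is expanded into a switch case. A self-contained sketch of the idea with illustrative names and modes (not the real megdnn macros):

#include <cstdio>
#include <stdexcept>

enum class ModeSketch { LT, LEQ, EQ, NEQ, ADD };

// List the supported modes once; cb is applied to each entry.
#define FOREACH_BOOL_MODE(cb) cb(LT) cb(LEQ) cb(EQ) cb(NEQ)

const char* dispatch_sketch(ModeSketch mode) {
    switch (mode) {
#define cb(_m) \
    case ModeSketch::_m: \
        return #_m;
        FOREACH_BOOL_MODE(cb)
#undef cb
        default:
            throw std::runtime_error("bad mode");
    }
}

int main() {
    std::printf("%s\n", dispatch_sketch(ModeSketch::EQ));
    return 0;
}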
template <typename ctype_src>
void dispatch_src_ctype(
        const ElemwiseOpParamN<1>&, const TensorND& dst_tensor, Elemwise::Mode,
@@ -578,6 +632,62 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
#undef DISPATCH
}
void ElemwiseMultiTypeImpl::dest_type_bool_mode(
        const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor,
        Elemwise::Mode mode) {
    auto stream = cuda_stream(this->handle());
    switch (param[0].layout.dtype.enumv()) {
#define DISPATCH(_dt) \
    case DTypeTrait<_dt>::enumv: { \
        ModeDispatcher<1, typename DTypeTrait<_dt>::ctype, bool>::run( \
                param, dst_tensor, mode, stream); \
        break; \
    }
        DISPATCH(dtype::Float32);
        DISPATCH(dtype::Float16);
        DISPATCH(dtype::BFloat16);
        default:
            megdnn_throw(ssprintf(
                    "Unsupported input dtype %s for ElemwiseMultiType",
                    param[0].layout.dtype.name()));
    }
#undef DISPATCH
}
void ElemwiseMultiTypeImpl::dest_type_bool_mode(
        const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor,
        Elemwise::Mode mode) {
    megdnn_assert(param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv());
    auto stream = cuda_stream(this->handle());
    switch (param[0].layout.dtype.enumv()) {
#define DISPATCH(_dt) \
    case DTypeTrait<_dt>::enumv: { \
        ModeDispatcher<2, typename DTypeTrait<_dt>::ctype, bool>::run( \
                param, dst_tensor, mode, stream); \
        break; \
    }
        DISPATCH(dtype::Int8);
        DISPATCH(dtype::Float32);
        DISPATCH(dtype::BFloat16);
        DISPATCH(dtype::Bool);
        DISPATCH(dtype::Float16);
        DISPATCH(dtype::Int16);
        DISPATCH(dtype::Int32);
        DISPATCH(dtype::Uint8);
        default:
            megdnn_throw(ssprintf(
                    "Unsupported input dtype %s for ElemwiseMultiType",
                    param[0].layout.dtype.name()));
    }
#undef DISPATCH
}
void ElemwiseMultiTypeImpl::on_quantized_mode(
        const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor,
        Elemwise::Mode mode) {
@@ -36,6 +36,14 @@ class ElemwiseMultiTypeImpl final : public ElemwiseMultiTypeImplHelper {
            const ElemwiseOpParamN<3>& param, const TensorND& dst,
            Elemwise::Mode mode) override;
    void dest_type_bool_mode(
            const ElemwiseOpParamN<1>& param, const TensorND& dst,
            Elemwise::Mode mode) override;
    void dest_type_bool_mode(
            const ElemwiseOpParamN<2>& param, const TensorND& dst,
            Elemwise::Mode mode) override;
public:
    using ElemwiseMultiTypeImplHelper::ElemwiseMultiTypeImplHelper;
};
@@ -3,7 +3,6 @@
#include "megdnn/tensor_iter.h"
#include "src/common/elemwise_multi_type/opr_impl_helper.h"
#include "src/naive/handle.h"
namespace megdnn {
namespace naive {
@@ -68,6 +67,25 @@ class ElemwiseMultiTypeImpl : public ElemwiseMultiTypeImplHelper {
    }
    template <typename KernImpl, typename src_ctype, typename dst_ctype>
    void dispatch_dst_bool_op(
            const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor) {
        auto size = param.size;
        auto src0 = param[0];
        auto work = [src0, size, dst_tensor]() {
            // Construct the iterators inside the lambda: captured variables are
            // const when captured by value.
            auto iA = tensor_iter_valonly<src_ctype>(src0).begin();
            auto pD = tensor_iter_valonly<dst_ctype>(dst_tensor).begin();
            for (size_t i = 0; i < size; i++) {
                src_ctype a = *iA;
                *pD = KernImpl::apply(a);
                ++iA;
                ++pD;
            }
        };
        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
    }
    template <typename KernImpl, typename src_ctype, typename dst_ctype>
    void dispatch_add_qint_op(
            const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor) {
        auto size = param.size;
@@ -98,6 +116,29 @@ class ElemwiseMultiTypeImpl : public ElemwiseMultiTypeImplHelper {
    }
    template <typename KernImpl, typename src_ctype, typename dst_ctype>
    void dispatch_dst_bool_op(
            const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor) {
        auto size = param.size;
        auto src0 = param[0];
        auto src1 = param[1];
        auto work = [src0, src1, size, dst_tensor]() {
            // Construct the iterators inside the lambda: captured variables are
            // const when captured by value.
            auto iA = tensor_iter_valonly<src_ctype>(src0).begin();
            auto iB = tensor_iter_valonly<src_ctype>(src1).begin();
            auto pD = tensor_iter_valonly<dst_ctype>(dst_tensor).begin();
            for (size_t i = 0; i < size; i++) {
                src_ctype a = *iA;
                src_ctype b = *iB;
                *pD = KernImpl::apply(a, b);
                ++iA;
                ++iB;
                ++pD;
            }
        };
        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
    }
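The naive binary path above simply walks both inputs with layout-aware iterators and writes one bool per element. A contiguous-only, self-contained sketch of that loop (illustrative names; the real code uses tensor_iter_valonly so broadcasted and strided layouts work too):

#include <cstddef>
#include <cstdio>

struct LeqKernSketch {
    static bool apply(float a, float b) { return a <= b; }
};

// Walk both inputs element by element and store KernImpl::apply(a, b) as bool.
template <typename KernImpl, typename src_ctype>
void dispatch_dst_bool_op_sketch(
        const src_ctype* a, const src_ctype* b, bool* dst, size_t size) {
    for (size_t i = 0; i < size; ++i) {
        dst[i] = KernImpl::apply(a[i], b[i]);
    }
}

int main() {
    float a[4] = {0.f, 1.f, 2.f, 3.f};
    float b[4] = {1.f, 1.f, 1.f, 1.f};
    bool out[4];
    dispatch_dst_bool_op_sketch<LeqKernSketch>(a, b, out, 4);
    std::printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);
    return 0;
}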
    template <typename KernImpl, typename src_ctype, typename dst_ctype>
    void dispatch_add_qint_op(
            const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor) {
        auto size = param.size;
@@ -178,6 +219,14 @@ protected:
            const ElemwiseOpParamN<3>& param, const TensorND& dst,
            Elemwise::Mode mode) override;
    void dest_type_bool_mode(
            const ElemwiseOpParamN<1>& param, const TensorND& dst,
            Elemwise::Mode mode) override;
    void dest_type_bool_mode(
            const ElemwiseOpParamN<2>& param, const TensorND& dst,
            Elemwise::Mode mode) override;
public:
    using ElemwiseMultiTypeImplHelper::ElemwiseMultiTypeImplHelper;
};
@@ -54,4 +54,121 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
    }
}
void ElemwiseMultiTypeImpl::dest_type_bool_mode(
        const ElemwiseOpParamN<1>& param, const TensorND& dst, Elemwise::Mode mode) {
    switch (mode) {
        case Elemwise::Mode::ISINF: {
            switch (param[0].layout.dtype.enumv()) {
#define DISPATCH(_dt, _mode) \
    case DTypeTrait<_dt>::enumv: { \
        typedef ElemwiseBoolKern< \
                megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, \
                typename DTypeTrait<_dt>::ctype, dt_bool> \
                KernImpl##_mode; \
        dispatch_dst_bool_op< \
                KernImpl##_mode, typename DTypeTrait<_dt>::ctype, dt_bool>( \
                param, dst); \
        break; \
    }
#define DISPATCH_MODE(_mode) \
    DISPATCH(megdnn::dtype::Float32, _mode); \
    DNN_INC_FLOAT16(DISPATCH(megdnn::dtype::Float16, _mode);) \
    DNN_INC_FLOAT16(DISPATCH(megdnn::dtype::BFloat16, _mode);)
                DISPATCH_MODE(ISINF);
                default:
                    megdnn_throw(ssprintf(
                            "Unsupported input dtype %s for ElemwiseMultiType",
                            param[0].layout.dtype.name()));
            };
            break;
        };
        case Elemwise::Mode::ISNAN: {
            switch (param[0].layout.dtype.enumv()) {
                DISPATCH_MODE(ISNAN);
                default:
                    megdnn_throw(ssprintf(
                            "Unsupported input dtype %s for ElemwiseMultiType",
                            param[0].layout.dtype.name()));
            };
            break;
        };
        default:
            megdnn_assert_internal(0);
    }
#undef DISPATCH_MODE
#undef DISPATCH
}
void ElemwiseMultiTypeImpl::dest_type_bool_mode(
        const ElemwiseOpParamN<2>& param, const TensorND& dst, Elemwise::Mode mode) {
    megdnn_assert(param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv());
    switch (mode) {
        case Elemwise::Mode::EQ: {
            switch (param[0].layout.dtype.enumv()) {
#define DISPATCH(_dt, _mode) \
    case DTypeTrait<_dt>::enumv: { \
        typedef ElemwiseBoolKern< \
                megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, \
                typename DTypeTrait<_dt>::ctype, dt_bool> \
                KernImpl##_mode; \
        dispatch_dst_bool_op< \
                KernImpl##_mode, typename DTypeTrait<_dt>::ctype, dt_bool>( \
                param, dst); \
        break; \
    };
#define DISPATCH_MODE(_mode) \
    DISPATCH(megdnn::dtype::Float32, _mode); \
    DNN_INC_FLOAT16(DISPATCH(megdnn::dtype::Float16, _mode);) \
    DNN_INC_FLOAT16(DISPATCH(megdnn::dtype::BFloat16, _mode);) \
    DISPATCH(megdnn::dtype::Int32, _mode); \
    DISPATCH(megdnn::dtype::Int16, _mode); \
    DISPATCH(megdnn::dtype::Int8, _mode); \
    DISPATCH(megdnn::dtype::Uint8, _mode); \
    DISPATCH(megdnn::dtype::Bool, _mode);
                DISPATCH_MODE(EQ);
                break;
                default:
                    megdnn_throw(ssprintf(
                            "Unsupported input dtype %s for ElemwiseMultiType",
                            param[0].layout.dtype.name()));
            };
            break;
        };
        case Elemwise::Mode::NEQ: {
            switch (param[0].layout.dtype.enumv()) {
                DISPATCH_MODE(NEQ);
                default:
                    megdnn_throw(ssprintf(
                            "Unsupported input dtype %s for ElemwiseMultiType",
                            param[0].layout.dtype.name()));
            };
            break;
        };
        case Elemwise::Mode::LT: {
            switch (param[0].layout.dtype.enumv()) {
                DISPATCH_MODE(LT);
                default:
                    megdnn_throw(ssprintf(
                            "Unsupported input dtype %s for ElemwiseMultiType",
                            param[0].layout.dtype.name()));
            };
            break;
        };
        case Elemwise::Mode::LEQ: {
            switch (param[0].layout.dtype.enumv()) {
                DISPATCH_MODE(LEQ);
                default:
                    megdnn_throw(ssprintf(
                            "Unsupported input dtype %s for ElemwiseMultiType",
                            param[0].layout.dtype.name()));
            };
            break;
        };
        default:
            megdnn_assert_internal(0);
    }
#undef DISPATCH_MODE
#undef DISPATCH
}
// vim: syntax=cpp.doxygen
@@ -149,8 +149,9 @@ void copy_tensors(
        //! use QuantizedS16 dtype in winograd_filter_preprocess now.
        cb(::megdnn::dtype::QuantizedS16) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
                cb(::megdnn::dtype::Uint16) cb(::megdnn::dtype::QuantizedS1)
                        cb(::megdnn::dtype::Bool)
#undef cb
        default : megdnn_trap();
    }
}
@@ -325,6 +326,7 @@ void CheckerHelper::do_exec(
        m_output_canonizer(tensors_cur_host);
        m_output_canonizer(tensors_naive);
    }
    check_tensors(tensors_naive, tensors_cur_host);
    if (m_extra_opr_impl) {
        check_tensors(tensors_naive, *tensors_extra_opr_impl);
@@ -756,6 +756,14 @@ DEF_TEST(all_modes) {
    auto should_ignore = [handle](Mode mode) {
        MEGDNN_MARK_USED_VAR(mode);
        switch (mode) {
            case Mode::NEQ:
            case Mode::ISNAN:
            case Mode::ISINF:
                return true;
            default:
                break;
        }
        return false;
    };
@@ -195,6 +195,9 @@ void IIDRNG::gen(const TensorND& tensor) {
    if (tensor.layout.dtype.enumv() == DTypeEnum::Uint16) {
        return;
    }
    if (tensor.layout.dtype.enumv() == DTypeEnum::Bool) {
        return;
    }
    megdnn_assert(
            0, "IIDRNG does not know how to generate value for DType %s",
            tensor.layout.dtype.name());
@@ -4,7 +4,6 @@
#include "test/cuda/benchmark.h"
#include "test/cuda/fixture.h"
#include "test/cuda/utils.h"
#undef cuda_check
#include "src/cuda/utils.h"
@@ -143,6 +142,43 @@ static void run_test_q4(int arity, Checker<ElemwiseMultiType>& checker, Mode mod
    }
}
static void run_test_bool(int arity, Checker<ElemwiseMultiType>& checker, Mode mode) {
    for (DType type :
         std::vector<DType>{{dtype::Int8()}, {dtype::Float32()}, {dtype::Float16()}}) {
        if ((mode == Mode::ISNAN || mode == Mode::ISINF) && type == dtype::Int8()) {
            continue;
        }
        checker.set_param(mode);
        UniformIntRNG rng_int8{1, 1};
        NormalRNG rng_normal{0, 1};
        auto set_inp_rng = [&](DType dtype, size_t i) {
            if (dtype.enumv() == DTypeEnum::Int8) {
                checker.set_rng(i, &rng_int8);
            } else if (
                    dtype.enumv() == DTypeEnum::Float32 ||
                    dtype.enumv() == DTypeEnum::Float16) {
                checker.set_rng(i, &rng_normal);
            } else {
                megdnn_assert(0);
            }
            checker.set_dtype(i, dtype);
        };
        auto src_type = type;
        for (int i = 0; i < arity; i++) {
            set_inp_rng(src_type, i);
        }
        if (arity == 1) {
            checker.execs({{3, 4, 5, 6}, {}});
        } else if (arity == 2) {
            checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {}});
        } else {
            megdnn_assert(0);
        }
    }
}
TEST_F(CUDA, ELEMWISE_QUANTIZED_MODE_UNARY) {
    Checker<ElemwiseMultiType> checker(handle_cuda());
    for (auto mode :
@@ -203,6 +239,18 @@ TEST_F(CUDA, ELEMWISE_QUANTIZED_MODE_BINARY) {
    }
}
TEST_F(CUDA, ELEMWISE_BOOL_MODE_BINARY) {
    using Mode = ElemwiseMultiType::Param::Mode;
    Checker<ElemwiseMultiType> checker(handle_cuda());
    for (auto mode : {Mode::EQ, Mode::NEQ, Mode::LT, Mode::LEQ}) {
        run_test_bool(2, checker, mode);
    }
    for (auto mode : {Mode::ISNAN, Mode::ISINF}) {
        run_test_bool(1, checker, mode);
    }
}
TEST_F(CUDA, ELEMWISE_QUANTIZED_MODE_TENARY) {
    using Mode = ElemwiseMultiType::Param::Mode;
    Checker<ElemwiseMultiType> checker(handle_cuda());
@@ -23,11 +23,25 @@ from .._imperative_rt.core2 import (
)
from ..ops import builtin
from . import amp
from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph
from .utils import (
    _normalize_axis,
    astensor1d,
    cast_tensors,
    convert_inputs,
    make_shape_tuple,
    subgraph,
)
_ElwMod = builtin.Elemwise.Mode
def _elemwise_multi_type(*args, mode, **kwargs):
    op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
    args = convert_inputs(*args)
    (result,) = apply(op, *args)
    return result
def _elwise_apply(args, mode):
    op = builtin.Elemwise(mode)
    (result,) = apply(op, *args)
@@ -234,13 +248,23 @@ class ArrayMethodMixin(abc.ABC):
    __hash__ = None  # because __eq__ deviates from the python convention
    __lt__ = lambda self, value: _elwise(self, value, mode=_ElwMod.LT).astype("bool")
    __le__ = lambda self, value: _elwise(self, value, mode=_ElwMod.LEQ).astype("bool")
    __gt__ = lambda self, value: _elwise(value, self, mode=_ElwMod.LT).astype("bool")
    __ge__ = lambda self, value: _elwise(value, self, mode=_ElwMod.LEQ).astype("bool")
    __eq__ = lambda self, value: _elwise(self, value, mode=_ElwMod.EQ).astype("bool")
    __ne__ = lambda self, value: _elwise(
        _elwise(self, value, mode=_ElwMod.EQ).astype("bool"), mode=_ElwMod.NOT,
    __lt__ = lambda self, value: _elemwise_multi_type(
        self, value, mode="lt", dtype="Bool"
    )
    __le__ = lambda self, value: _elemwise_multi_type(
        self, value, mode="leq", dtype="Bool"
    )
    __gt__ = lambda self, value: _elemwise_multi_type(
        value, self, mode="lt", dtype="Bool"
    )
    __ge__ = lambda self, value: _elemwise_multi_type(
        value, self, mode="leq", dtype="Bool"
    )
    __eq__ = lambda self, value: _elemwise_multi_type(
        self, value, mode="eq", dtype="Bool"
    )
    __ne__ = lambda self, value: _elemwise_multi_type(
        self, value, mode="neq", dtype="Bool"
    )
    __neg__ = _unary_elwise(_ElwMod.NEGATE)
@@ -10,7 +10,7 @@ from ..core.tensor.array_method import _matmul
from ..core.tensor.utils import _normalize_axis
from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default
from .elemwise import clip
from .elemwise import _elemwise_multi_type, clip
from .tensor import expand_dims, squeeze
__all__ = [
@@ -52,7 +52,7 @@ def isnan(inp: Tensor) -> Tensor:
        >>> F.isnan(x).numpy()
        array([False, True, False])
    """
    return inp != inp
    return _elemwise_multi_type(inp, mode="isnan", dtype="Bool")
def isinf(inp: Tensor) -> Tensor:
@@ -69,7 +69,7 @@ def isinf(inp: Tensor) -> Tensor:
        >>> F.isinf(x).numpy()
        array([False, True, False])
    """
    return abs(inp).astype("float32") == float("inf")
    return _elemwise_multi_type(inp, mode="isinf", dtype="Bool")
def sign(inp: Tensor):
@@ -133,9 +133,9 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() {
                         0.f}) /
                6.f),
    };
    mgb_assert(map.size() + 16 == opr::Elemwise::Param::MODE_NR_MEMBER);
    mgb_assert(map.size() + 19 == opr::Elemwise::Param::MODE_NR_MEMBER);
    // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH,
    // ERFINV, ERFCINV, NOT, AND, OR, XOR
    // ERFINV, ERFCINV, NOT, AND, OR, XOR, NEQ, ISNAN, ISINF
    return map;
#undef ADD_OPR
}
@@ -756,8 +756,8 @@ TYPED_TEST(TestOprBasicArithTernaryElemwise, Float32) {
TEST(TestOprBasicArithElemwise, CheckAllModeTested) {
    size_t nr_member = opr::Elemwise::Param::MODE_NR_MEMBER;
    ASSERT_EQ(nr_member, tested_mode.size() + 4);
    // Not using TestRunner: NOT, AND, OR, XOR
    ASSERT_EQ(nr_member, tested_mode.size() + 7);
    // Not using TestRunner: NOT, AND, OR, XOR, NEQ, ISNAN, ISINF
}
#define TEST_OPR_BASIC_ARITH_UNARY_BOOL(_mode, _op) \
    TEST(TestOprBasicArithElemwise, _mode) { \