GitOrigin-RevId: 3d97fedc8f
tags/v1.0.0-rc1
@@ -100,7 +100,6 @@ namespace { | |||
MIDOUT_END(); \ | |||
break; \ | |||
default: \ | |||
megdnn_throw("no quantized unsupported biasmode"); \ | |||
break; \ | |||
} | |||
@@ -258,6 +257,66 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> { | |||
#undef FOR_NONLINEAR_NOBIAS | |||
#undef FOR_NONLINEAR | |||
#undef FOR_BIAS | |||
#define FOR_BINARY_BROADCAST(_op) \ | |||
megdnn::arm_common:: \ | |||
OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101>::run( \ | |||
static_cast<ctype*>(conv_dst_ptr), \ | |||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||
dst_type, N, OC, OH* OW); | |||
#define FOR_BINARY_BROADCAST_NCHW44(_op) \ | |||
megdnn::arm_common::OpCallerBinary<_op<ctype>, \ | |||
megdnn::arm_common::VEC_BCAST101x4>:: \ | |||
run(static_cast<ctype*>(conv_dst_ptr), \ | |||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||
dst_type, N, OC, OH* OW, pack_oc_size); | |||
#define FOR_BINARY(_op) \ | |||
megdnn::arm_common:: \ | |||
OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \ | |||
static_cast<ctype*>(conv_dst_ptr), \ | |||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||
dst_type, N* OC* OH* OW* pack_oc_size); | |||
#define FOR_BIAS(_bias_mode, OH, OW) \ | |||
switch (_bias_mode) { \ | |||
case megdnn::BiasMode::NO_BIAS: \ | |||
break; \ | |||
case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||
if (pack_oc_size == 1) { \ | |||
FOR_BINARY_BROADCAST(CONCAT_OP(AddOp)); \ | |||
} else { \ | |||
megdnn_assert(pack_oc_size == 4, \ | |||
"Only support nchw44 in ARM"); \ | |||
FOR_BINARY_BROADCAST_NCHW44(CONCAT_OP(AddOp)); \ | |||
} \ | |||
break; \ | |||
case megdnn::BiasMode::BIAS: \ | |||
FOR_BINARY(CONCAT_OP(AddOp)); \ | |||
break; \ | |||
default: \ | |||
break; \ | |||
} | |||
template <typename ctype, typename dtype> | |||
struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> { | |||
static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | |||
megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | |||
megdnn::DType bias_type, megdnn::DType dst_type, size_t N, | |||
size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
megdnn_assert(nonlineMode == megdnn::NonlineMode::IDENTITY); | |||
FOR_BIAS(bias_mode, OH, OW); | |||
} | |||
}; | |||
#undef FOR_BINARY_BROADCAST | |||
#undef FOR_BINARY_BROADCAST_NCHW44 | |||
#undef FOR_BINARY | |||
#undef FOR_BIAS | |||
#undef CB | |||
#undef CONCAT_OP | |||
#undef CONCAT_NL | |||
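For readers who do not want to trace the OpCallerBinary macros above: after expansion, the new ARM ADD_BIAS post-process is simply an elementwise add of the bias onto the convolution output, broadcast per output channel or applied as a full tensor. A minimal scalar sketch for the plain NCHW case follows; the helper name add_bias_sketch is hypothetical, and the real kernels above dispatch to vectorized AddOp implementations and additionally handle NCHW44 packing.

#include <cstddef>

// Scalar equivalent of PostProcess<..., ADD_BIAS>::run for pack_oc_size == 1;
// illustration only, not the actual megdnn implementation.
template <typename ctype>
void add_bias_sketch(const ctype* conv_dst, const ctype* bias, ctype* dst,
                     bool broadcast_channel_bias, size_t N, size_t OC,
                     size_t OH, size_t OW) {
    const size_t spatial = OH * OW;
    for (size_t n = 0; n < N; ++n) {
        for (size_t oc = 0; oc < OC; ++oc) {
            for (size_t i = 0; i < spatial; ++i) {
                const size_t idx = (n * OC + oc) * spatial + i;
                // BROADCAST_CHANNEL_BIAS: one value per output channel;
                // BIAS: a full tensor with the same layout as the output.
                const ctype b = broadcast_channel_bias ? bias[oc] : bias[idx];
                dst[idx] = conv_dst[idx] + b;
            }
        }
    }
}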
@@ -158,9 +158,11 @@ private: \ | |||
uint32_t m_tile_size; | |||
enum class PostprocessMode : uint8_t { | |||
FLOAT = 0, ///< support all biasmode and no_nonlinemode | |||
NO_PROCESS, ///<support non bias and identity | |||
QUANTIZED,///<support NOBIAS ,BROADCAST_CHANNEL_BIAS and relu hswish identify nonline mode | |||
FLOAT = 0, ///< support all biasmode and no_nonlinemode | |||
    NO_PROCESS,  ///< support no bias and identity | |||
    QUANTIZED,   ///< support NO_BIAS, BROADCAST_CHANNEL_BIAS and relu, hswish, | |||
                 ///< identity nonline modes | |||
ADD_BIAS, ///< only add bias | |||
}; | |||
} // namespace megdnn | |||
@@ -227,8 +227,7 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param, | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS16 || | |||
param.dst_type.enumv() == DTypeEnum::Int32 || | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | |||
if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
return false; | |||
} | |||
} | |||
@@ -310,6 +310,19 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||
} \ | |||
} \ | |||
MIDOUT_END() | |||
#define cb3(_format, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, midout_iv(_midout_tag)) { \ | |||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||
conv1x1_gemv_worker = \ | |||
Conv1x1GemvWorker<_src_ctype, _bias_ctype, _dst_ctype, \ | |||
_bias_ctype, _dst_ctype, \ | |||
_postprocess_mode, _format>::exec; \ | |||
} \ | |||
} \ | |||
MIDOUT_END() | |||
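cb3 follows the same pattern as the existing cb2 callback but instantiates the worker with the plain bias/dst ctypes reused as the operator ctypes and pairs it with PostprocessMode::ADD_BIAS. As a rough illustration, the NCHW INT8x8x32 invocation below expands to approximately the following (MIDOUT guard elided):

if (param.filter_type.enumv() == param.src_type.enumv() &&
    param.src_type.enumv() == DTypeTrait<dt_int8>::enumv &&
    param.dst_type.enumv() == DTypeTrait<dt_int32>::enumv) {
    conv1x1_gemv_worker =
            Conv1x1GemvWorker<dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
                              PostprocessMode::ADD_BIAS,
                              param::ConvBias::Format::NCHW>::exec;
}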
switch (param.filter_meta.format) { | |||
case param::ConvBias::Format::NCHW: | |||
@@ -324,23 +337,23 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||
PostprocessMode::NO_PROCESS, "NCHW::GEMV::FLOAT16_FLOAT16"_hash); | |||
#endif | |||
#endif | |||
cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32, | |||
dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32, | |||
dt_int8, dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||
"NCHW::GEMV::INT8x8x32_INT32"_hash); | |||
cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int16, dt_int16, | |||
dt_int8, dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int16, dt_int16, | |||
dt_int8, dt_int16, dt_int16, PostprocessMode::ADD_BIAS, | |||
"NCHW::GEMV::INT8x8x16_INT16"_hash); | |||
cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8, | |||
cb3(param::ConvBias::Format::NCHW, dtype::QuantizedS8, | |||
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | |||
dt_int32, PostprocessMode::NO_PROCESS, | |||
dt_int32, PostprocessMode::ADD_BIAS, | |||
"NCHW::GEMV::QINT8x8x32_QINT32"_hash); | |||
cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8, | |||
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | |||
dt_int8, PostprocessMode::QUANTIZED, | |||
"NCHW::GEMV::QINT8x8x32_QINT8"_hash); | |||
cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, | |||
cb3(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, | |||
dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, dt_int32, | |||
dt_int32, PostprocessMode::NO_PROCESS, | |||
dt_int32, PostprocessMode::ADD_BIAS, | |||
"NCHW::GEMV::QUINT8x8x32_QINT32"_hash); | |||
cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, | |||
dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8, dt_int32, | |||
@@ -365,13 +378,13 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||
break; | |||
case param::ConvBias::Format::NCHW44_DOT: | |||
cb2(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32, | |||
cb3(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32, | |||
dt_int32, dt_int8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, | |||
PostprocessMode::ADD_BIAS, | |||
"NCHW44_DOT::GEMV::INT8x8x32_INT32"_hash); | |||
cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, | |||
cb3(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, | |||
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | |||
dt_int32, PostprocessMode::NO_PROCESS, | |||
dt_int32, PostprocessMode::ADD_BIAS, | |||
"NCHW44_DOT::GEMV::QINT8x8x32_QINT32"_hash); | |||
cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, | |||
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | |||
@@ -385,6 +398,7 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||
} | |||
#undef cb1 | |||
#undef cb2 | |||
#undef cb3 | |||
megdnn_assert(conv1x1_gemv_worker, "No suitable gemv worker"); | |||
@@ -448,8 +462,7 @@ bool ConvBiasImpl::AlgoConv1x1Gemv::usable(const NCBKernSizeParam& param, | |||
if (param.dst_type.enumv() == DTypeEnum::Int16 || | |||
param.dst_type.enumv() == DTypeEnum::Int32 || | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | |||
if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
return false; | |||
} | |||
} | |||
@@ -56,6 +56,19 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||
} \ | |||
} \ | |||
MIDOUT_END() | |||
#define cb3(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||
midout_iv(_midout_tag)) { \ | |||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||
return std::make_unique<Conv1x1Strategy< \ | |||
_src_ctype, _bias_ctype, _dst_ctype, _bias_ctype, \ | |||
_dst_ctype, _postprocess_mode, _packmode>>(pack_c_size); \ | |||
} \ | |||
} \ | |||
MIDOUT_END() | |||
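The strategy-factory variant of cb3 performs the same type substitution when building a Conv1x1Strategy. For instance, the Default::INT8x8x32_INT32 call further down expands to roughly the following (MIDOUT guard elided; illustration only):

if (param.filter_type.enumv() == param.src_type.enumv() &&
    param.src_type.enumv() == DTypeTrait<dt_int8>::enumv &&
    param.dst_type.enumv() == DTypeTrait<dt_int32>::enumv) {
    return std::make_unique<Conv1x1Strategy<
            dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
            PostprocessMode::ADD_BIAS,
            MatrixMulImpl::AlgoBase::PackMode::DEFAULT>>(pack_c_size);
}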
switch (pack_mode) { | |||
case MatrixMulImpl::AlgoBase::PackMode::DEFAULT: | |||
@@ -71,26 +84,26 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||
"Default::FLOAT16_FLOAT16"_hash); | |||
#endif | |||
#endif | |||
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int32, | |||
cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int32, | |||
dt_int32, dt_int8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, "Default::INT8x8x32_INT32"_hash); | |||
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int16, | |||
PostprocessMode::ADD_BIAS, "Default::INT8x8x32_INT32"_hash); | |||
cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int16, | |||
dt_int16, dt_int8, dt_int16, dt_int16, | |||
PostprocessMode::NO_PROCESS, "Default::INT8x8x16_INT16"_hash); | |||
PostprocessMode::ADD_BIAS, "Default::INT8x8x16_INT16"_hash); | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, | |||
cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, | |||
dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, | |||
PostprocessMode::ADD_BIAS, | |||
"Default::QUINT8x8x32_QINT32"_hash); | |||
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, | |||
dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | |||
PostprocessMode::QUANTIZED, "Default::QUINT8x8x32_QUINT8"_hash); | |||
#endif | |||
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, | |||
cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, | |||
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | |||
dt_int32, PostprocessMode::NO_PROCESS, | |||
dt_int32, PostprocessMode::ADD_BIAS, | |||
"Default::QINT8x8x32_QINT32"_hash); | |||
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, | |||
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | |||
@@ -107,17 +120,17 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||
cb1(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_float32, | |||
dt_float32, PostprocessMode::FLOAT, "NoPack::FLOAT"_hash); | |||
cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int16, | |||
cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int16, | |||
dt_int16, dt_int8, dt_int16, dt_int16, | |||
PostprocessMode::NO_PROCESS, "NoPack::INT8x8x16_INT16"_hash); | |||
PostprocessMode::ADD_BIAS, "NoPack::INT8x8x16_INT16"_hash); | |||
cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int32, | |||
cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int32, | |||
dt_int32, dt_int8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, "NoPack::INT8x8x32_INT32"_hash); | |||
PostprocessMode::ADD_BIAS, "NoPack::INT8x8x32_INT32"_hash); | |||
cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dtype::QuantizedS8, | |||
cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dtype::QuantizedS8, | |||
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | |||
dt_int32, PostprocessMode::NO_PROCESS, | |||
dt_int32, PostprocessMode::ADD_BIAS, | |||
"NoPack::QINT8x8x32_QINT32"_hash); | |||
break; | |||
@@ -127,6 +140,7 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||
} | |||
#undef cb1 | |||
#undef cb2 | |||
#undef cb3 | |||
megdnn_throw("Invalid Data Type"); | |||
return nullptr; | |||
} | |||
@@ -207,4 +221,4 @@ bool Conv1x1Factory::can_make_conv1x1_strategy( | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen | |||
// vim: syntax=cpp.doxygen |
@@ -746,8 +746,7 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||
if (param.dst_type.enumv() == DTypeEnum::Int16 || | |||
param.dst_type.enumv() == DTypeEnum::Int32 || | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | |||
if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
return false; | |||
} | |||
} | |||
@@ -213,6 +213,22 @@ public: | |||
} \ | |||
MIDOUT_END(); \ | |||
return {}; | |||
#define cb3(_format, _packmode, _i_src_type, _i_bias_type, _i_dst_type, \ | |||
_src_ctype, _bias_ctype, _dst_ctype, _postprocess_mode, \ | |||
_midout_tag) \ | |||
MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ | |||
midout_iv(_midout_tag)) { \ | |||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||
return std::make_unique< \ | |||
Strategy<_src_ctype, _bias_ctype, _dst_ctype, _bias_ctype, \ | |||
_dst_ctype, _postprocess_mode, \ | |||
PackMode::_packmode, FormatMode::_format>>(); \ | |||
} \ | |||
} \ | |||
MIDOUT_END(); \ | |||
return {}; | |||
static std::unique_ptr<StrategyBase> make_default_strategy( | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
@@ -279,13 +295,13 @@ public: | |||
#endif | |||
case StrategyType::INT8x8x32: | |||
if (format == param::ConvBias::Format::NCHW) { | |||
cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
cb3(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||
"DefaultStrategyType::INT8x8x32"_hash); | |||
} else if (format == param::ConvBias::Format::NCHW44 || | |||
format == param::ConvBias::Format::NCHW44_DOT) { | |||
cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
cb3(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||
"DefaultStrategyType::INT8x8x32"_hash); | |||
} else { | |||
megdnn_throw( | |||
@@ -299,12 +315,12 @@ public: | |||
case StrategyType::INT8x8x16: | |||
if (format == param::ConvBias::Format::NCHW) { | |||
cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
cb3(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||
dt_int16, dt_int16, PostprocessMode::ADD_BIAS, | |||
"DefaultStrategyType::INT8x8x16"_hash); | |||
} else if (format == param::ConvBias::Format::NCHW44) { | |||
cb2(NCHW44, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
cb3(NCHW44, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||
dt_int16, dt_int16, PostprocessMode::ADD_BIAS, | |||
"DefaultStrategyType::INT8x8x16"_hash); | |||
} else { | |||
megdnn_throw( | |||
@@ -316,9 +332,9 @@ public: | |||
break; | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
case StrategyType::QUINT8x8x32: | |||
cb2(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
cb3(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, | |||
PostprocessMode::ADD_BIAS, | |||
"DefaultStrategyType::QUINT8x8x32"_hash); | |||
break; | |||
@@ -331,15 +347,15 @@ public: | |||
#endif | |||
case StrategyType::QINT8x8x32: | |||
if (format == param::ConvBias::Format::NCHW) { | |||
cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||
cb3(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, | |||
PostprocessMode::ADD_BIAS, | |||
"DefaultStrategyTypeNCHW::QINT8x8x32"_hash); | |||
} else if (format == param::ConvBias::Format::NCHW44 || | |||
format == param::ConvBias::Format::NCHW44_DOT) { | |||
cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | |||
cb3(NCHW44, DEFAULT, dtype::QuantizedS8, | |||
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||
"DefaultStrategyTypeHCHW44::QINT8x8x32"_hash); | |||
} else { | |||
megdnn_throw( | |||
@@ -467,13 +483,13 @@ public: | |||
#endif | |||
#endif | |||
case StrategyType::INT8x8x16: | |||
cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | |||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
cb3(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | |||
dt_int16, dt_int16, PostprocessMode::ADD_BIAS, | |||
"NoPackStrategyType::INT8x8x16"_hash); | |||
break; | |||
case StrategyType::INT8x8x32: | |||
cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
cb3(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||
"NoPackStrategyType::INT8x8x32"_hash); | |||
break; | |||
default: | |||
@@ -509,6 +525,7 @@ public: | |||
#undef cb1 | |||
#undef cb2 | |||
#undef cb3 | |||
static std::unique_ptr<StrategyBase> make_strategy( | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
@@ -203,18 +203,16 @@ INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||
//! x86 does not have uint8 matmul, so only armv7/armv8 support uint8 | |||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||
megdnn::PostprocessMode::QUANTIZED) | |||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_int32, dt_int32, | |||
megdnn::PostprocessMode::ADD_BIAS) | |||
#endif | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||
megdnn::PostprocessMode::QUANTIZED) | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
megdnn::PostprocessMode::ADD_BIAS) | |||
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
megdnn::PostprocessMode::ADD_BIAS) | |||
#undef INSTANTIAL_CLASS | |||
} // namespace megdnn | |||
@@ -119,19 +119,16 @@ INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||
//! x86 does not have uint8 matmul, so only armv7/armv8 support uint8 | |||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||
megdnn::PostprocessMode::QUANTIZED) | |||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_int32, dt_int32, | |||
megdnn::PostprocessMode::ADD_BIAS) | |||
#endif | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||
megdnn::PostprocessMode::QUANTIZED) | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
megdnn::PostprocessMode::ADD_BIAS) | |||
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
megdnn::PostprocessMode::ADD_BIAS) | |||
#undef INSTANTIAL_CLASS | |||
} // namespace megdnn | |||
@@ -162,9 +162,9 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | |||
megdnn::PostprocessMode::FLOAT) | |||
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
megdnn::PostprocessMode::ADD_BIAS) | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
megdnn::PostprocessMode::ADD_BIAS) | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
#else | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
@@ -294,6 +294,73 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> { | |||
#undef FOR_BIAS | |||
} | |||
}; | |||
#undef CALL_BINARY | |||
#undef CALL_BINARY_BROADCAST | |||
#define CALL_BINARY(_op, _simd_type) \ | |||
thin_function<void(const ctype*, const ctype*, dtype*, DType, DType, \ | |||
DType, size_t)> \ | |||
run = OpCallerBinary<_op<_simd_type, ctype, dtype>, _simd_type, \ | |||
megdnn::x86::BcastType::VEC_VEC>::run; \ | |||
run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr), \ | |||
reinterpret_cast<dtype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||
N* OC* OH* OW); | |||
#define CALL_BINARY_BROADCAST(_op, _simd_type) \ | |||
thin_function<void(const ctype*, const ctype*, dtype*, DType, DType, \ | |||
DType, size_t, size_t, size_t)> \ | |||
run = OpCallerBinary<_op<_simd_type, ctype, dtype>, _simd_type, \ | |||
megdnn::x86::BcastType::VEC_BCAST101>::run; \ | |||
run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr), \ | |||
reinterpret_cast<dtype*>(dst_ptr), bias_type, bias_type, dst_type, N, \ | |||
OC, OH* OW); | |||
#define FOR_SIMD(CALLER) \ | |||
if (is_supported(SIMDType::AVX2)) { \ | |||
CALLER(AddOp, SIMDType::AVX2) \ | |||
} else if (is_supported(SIMDType::SSE4_2)) { \ | |||
CALLER(AddOp, SIMDType::SSE4_2) \ | |||
} else { \ | |||
CALLER(AddOp, SIMDType::NONE) \ | |||
} | |||
#define FOR_BIAS(bias_mode) \ | |||
switch (bias_mode) { \ | |||
case BiasMode::BIAS: \ | |||
FOR_SIMD(CALL_BINARY); \ | |||
break; \ | |||
case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||
FOR_SIMD(CALL_BINARY_BROADCAST); \ | |||
break; \ | |||
default: \ | |||
break; \ | |||
} | |||
template <typename ctype, typename dtype> | |||
struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> { | |||
static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | |||
megdnn::ConvBiasForward::BiasMode bias_mode, | |||
megdnn::param::ConvBiasV0::NonlineMode nonlineMode, | |||
DType bias_type, DType dst_type, size_t N, size_t OC, | |||
size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
megdnn_assert(pack_oc_size == 1, | |||
"PostProcess only support nchw in x86"); | |||
megdnn_assert( | |||
nonlineMode == megdnn::param::ConvBiasV0::NonlineMode::IDENTITY, | |||
"Add bias PostProcess only support IDENTITY"); | |||
if (bias_mode == megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||
return; | |||
} | |||
FOR_BIAS(bias_mode); | |||
#undef CALL_BINARY | |||
#undef CALL_BINARY_BROADCAST | |||
#undef FOR_SIMD | |||
#undef FOR_BIAS | |||
} | |||
}; | |||
#undef cb_unary | |||
#undef cb_binary | |||
#undef BIAS_CASE | |||
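Expanded by hand, the BROADCAST_CHANNEL_BIAS path of the new x86 ADD_BIAS post-process selects a SIMD-specific AddOp at runtime. The AVX2 branch of FOR_SIMD(CALL_BINARY_BROADCAST) looks roughly like this; the SSE4_2 and plain branches differ only in the SIMDType:

if (is_supported(SIMDType::AVX2)) {
    thin_function<void(const ctype*, const ctype*, dtype*, DType, DType, DType,
                       size_t, size_t, size_t)>
            run = OpCallerBinary<AddOp<SIMDType::AVX2, ctype, dtype>,
                                 SIMDType::AVX2,
                                 megdnn::x86::BcastType::VEC_BCAST101>::run;
    run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr),
        reinterpret_cast<dtype*>(dst_ptr), bias_type, bias_type, dst_type, N,
        OC, OH * OW);
}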
@@ -92,6 +92,8 @@ OP(dt_int8, SIMDType::AVX2, "avx2", __m256i, __m256ix2, __m256i, mm256, epi8, | |||
using AddOpBase::operator(); \ | |||
}; | |||
OP(dt_int32, SIMDType::NONE); | |||
OP(dt_int16, SIMDType::NONE); | |||
OP(dt_float32, SIMDType::NONE); | |||
#undef OP | |||
} // namespace x86 | |||
@@ -1992,13 +1992,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) { | |||
#define cb(name) \ | |||
checker_conv_bias( \ | |||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||
true, false, true, false, false, true), \ | |||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, \ | |||
true, false, true, true, false, false), \ | |||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \ | |||
checker_conv_bias( \ | |||
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | |||
false, false, true), \ | |||
true, false, false), \ | |||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); | |||
@@ -2041,13 +2041,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32_MK4_DOT) { | |||
#define cb(name) \ | |||
checker_conv_bias( \ | |||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||
true, false, true, false, false, true), \ | |||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, \ | |||
true, false, true, true, false, false), \ | |||
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | |||
dtype::Int32(), {}, name); \ | |||
checker_conv_bias( \ | |||
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | |||
false, false, true), \ | |||
true, false, false), \ | |||
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | |||
dtype::Int32(), {}, name); | |||
@@ -2118,7 +2118,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) { | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) { | |||
NormalRNG rng(128.f); | |||
#define cb(name) \ | |||
checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||
false, true, true), \ | |||
@@ -2188,18 +2187,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) { | |||
UniformIntRNG rng{-50, 50}; | |||
float epsilon = 0.001; | |||
#define cb(name) \ | |||
checker_conv_bias( \ | |||
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ | |||
handle(), &rng, epsilon, \ | |||
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||
dtype::QuantizedS32(1.2 * 1.3), {}, name); \ | |||
checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \ | |||
&rng, epsilon, \ | |||
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||
dtype::QuantizedS32(1.2 * 1.3), {}, name); | |||
#define cb(name) \ | |||
checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||
true, true, false), \ | |||
handle(), &rng, epsilon, \ | |||
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||
dtype::QuantizedS32(1.2 * 1.3), {}, name); \ | |||
checker_conv_bias( \ | |||
get_conv_bias_args({1}, 2, false, false, true, true, false), \ | |||
handle(), &rng, epsilon, \ | |||
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||
dtype::QuantizedS32(1.2 * 1.3), {}, name); | |||
#if MEGDNN_AARCH64 | |||
#if __ARM_FEATURE_DOTPROD | |||
@@ -2252,18 +2252,18 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { | |||
UniformIntRNG rng{-50, 50}; | |||
float epsilon = 0.001; | |||
std::vector<conv_bias::TestArg> args_nchw44 = | |||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, true, true, true, | |||
false, false, false, false, true); | |||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, true, false, true, | |||
false, false, true, false, false); | |||
std::vector<conv_bias::TestArg> args_nchw44_1x1s2 = | |||
get_nchw44_conv_bias_args({1}, 2, true, true, true, false, false, | |||
false, false, true); | |||
#define cb(name) \ | |||
checker_conv_bias( \ | |||
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ | |||
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||
dtype::Int16{}, dtype::Int16{}, name); \ | |||
checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \ | |||
&rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||
get_nchw44_conv_bias_args({1}, 2, true, false, true, false, false, | |||
true, false, false); | |||
#define cb(name) \ | |||
checker_conv_bias( \ | |||
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, true), \ | |||
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||
dtype::Int16{}, dtype::Int16{}, name); \ | |||
checker_conv_bias(get_conv_bias_args({1}, 2, false, false, true), \ | |||
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||
dtype::Int16{}, dtype::Int16{}, name); | |||
#define cb_nchw44(name) \ | |||
@@ -2314,14 +2314,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCES | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPREPROCESS) { | |||
UniformIntRNG rng{-50, 50}; | |||
float epsilon = 0.001; | |||
#define cb(name) \ | |||
check_conv_bias_preprocess( \ | |||
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ | |||
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||
dtype::Int16{}, dtype::Int16{}, name); \ | |||
check_conv_bias_preprocess(get_conv_bias_args({1}, 2, false, true, true), \ | |||
handle(), &rng, epsilon, dtype::Int8{}, \ | |||
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, \ | |||
#define cb(name) \ | |||
check_conv_bias_preprocess( \ | |||
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, true), \ | |||
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||
dtype::Int16{}, dtype::Int16{}, name); \ | |||
check_conv_bias_preprocess(get_conv_bias_args({1}, 2, false, false, true), \ | |||
handle(), &rng, epsilon, dtype::Int8{}, \ | |||
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, \ | |||
name); | |||
#if MEGDNN_AARCH64 | |||
@@ -2406,7 +2406,7 @@ void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args, | |||
checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||
.set_dtype(1, dtype::QuantizedS8(2.5f)) | |||
.set_dtype(2, dtype::QuantizedS32(6.25f)) | |||
.set_dtype(4, {}) | |||
.set_dtype(4, dtype::QuantizedS32(6.25f)) | |||
.set_rng(0, &rng) | |||
.set_rng(1, &rng) | |||
.set_rng(2, &rng) | |||
@@ -2436,7 +2436,7 @@ void checker_conv_bias_int8x8x32_preprocess(std::vector<conv_bias::TestArg> args | |||
checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||
.set_dtype(1, dtype::QuantizedS8(2.5f)) | |||
.set_dtype(2, dtype::QuantizedS32(6.25f)) | |||
.set_dtype(4, {}) | |||
.set_dtype(4, dtype::QuantizedS32(6.25f)) | |||
.set_rng(0, &rng) | |||
.set_rng(1, &rng) | |||
.set_rng(2, &rng) | |||
@@ -2450,7 +2450,7 @@ void checker_conv_bias_int8x8x32_preprocess(std::vector<conv_bias::TestArg> args | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { | |||
using namespace conv_bias; | |||
std::vector<conv_bias::TestArg> args = | |||
get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true); | |||
get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true); | |||
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | |||
#if MEGDNN_AARCH64 | |||
@@ -2464,7 +2464,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) { | |||
using namespace conv_bias; | |||
std::vector<conv_bias::TestArg> args = | |||
get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true); | |||
get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true); | |||
#define cb(name) checker_conv_bias_int8x8x32_preprocess(args, handle(), name); | |||
#if MEGDNN_AARCH64 | |||
@@ -2478,7 +2478,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPR | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) { | |||
using namespace conv_bias; | |||
std::vector<conv_bias::TestArg> args = | |||
get_nchw44_conv_bias_args({3, 4, 6}, 1, false, true, true); | |||
get_nchw44_conv_bias_args({3, 4, 6}, 1, false, false, true); | |||
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | |||
#if MEGDNN_AARCH64 | |||
@@ -3080,9 +3080,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32_PREPROCESS) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { | |||
UniformIntRNG rng{-50, 50}; | |||
float epsilon = 0.001; | |||
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||
std::vector<conv_bias::TestArg> args = | |||
get_conv_bias_1x1_args(false, true, false, false); | |||
std::vector<conv_bias::TestArg> args_nchw44 = get_nchw44_conv_bias_args( | |||
{1}, 1, true, true, true, false, false, false, false, true); | |||
{1}, 1, true, true, true, false, false, true, false, false); | |||
#define cb(name) \ | |||
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \ | |||
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); | |||
@@ -3140,7 +3141,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16_PREPROCESS) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) { | |||
using namespace conv_bias; | |||
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||
std::vector<conv_bias::TestArg> args = | |||
get_conv_bias_1x1_args(false, true, false, false); | |||
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | |||
@@ -834,6 +834,13 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X32) { | |||
//! no bias | |||
args.emplace_back(param, TensorShape{1, ic, h, w}, | |||
TensorShape{oc, ic, kernel, kernel}, TensorShape{}); | |||
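//! broadcast channel bias | |||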
args.emplace_back(param, TensorShape{1, ic, h, w}, | |||
TensorShape{oc, ic, kernel, kernel}, | |||
TensorShape{1, oc, 1, 1}); | |||
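//! full bias | |||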
args.emplace_back(param, TensorShape{1, ic, h, w}, | |||
TensorShape{oc, ic, kernel, kernel}, | |||
TensorShape{1, oc, (h + 2 * p - kernel) + 1, | |||
(h + 2 * p - kernel) + 1}); | |||
}; | |||
for (size_t kernel : {2, 3, 4, 5, 6, 7}) | |||
@@ -1384,7 +1391,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) { | |||
using namespace conv_bias; | |||
UniformIntRNG rng{-50, 50}; | |||
float epsilon = 0.001; | |||
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, true); | |||
#if MEGDNN_X86_WITH_MKL_DNN | |||
if (x86::is_supported(x86::SIMDType::VNNI)) { | |||
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, | |||
@@ -1422,7 +1429,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32_PREPROCESS) { | |||
using namespace conv_bias; | |||
UniformIntRNG rng{-50, 50}; | |||
float epsilon = 0.001; | |||
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, true); | |||
#if MEGDNN_X86_WITH_VNNI | |||
if (x86::is_supported(x86::SIMDType::VNNI)) { | |||
checker_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{}, | |||