GitOrigin-RevId: 3d97fedc8f
tags/v1.0.0-rc1
@@ -100,7 +100,6 @@ namespace { | |||||
MIDOUT_END(); \ | MIDOUT_END(); \ | ||||
break; \ | break; \ | ||||
default: \ | default: \ | ||||
megdnn_throw("no quantized unsupported biasmode"); \ | |||||
break; \ | break; \ | ||||
} | } | ||||
@@ -258,6 +257,66 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> { | |||||
#undef FOR_NONLINEAR_NOBIAS | #undef FOR_NONLINEAR_NOBIAS | ||||
#undef FOR_NONLINEAR | #undef FOR_NONLINEAR | ||||
#undef FOR_BIAS | #undef FOR_BIAS | ||||
//! Helper macros for the ADD_BIAS post-process. Each one forwards | |||||
//! conv_dst_ptr, bias_ptr and dst_ptr (all accessed as ctype) together with | |||||
//! bias_type (passed twice) and dst_type to OpCallerBinary<_op<ctype>, ...>::run. | |||||
//! | |||||
//! FOR_BINARY_BROADCAST: VEC_BCAST101 variant, called with (N, OC, OH * OW). | |||||
#define FOR_BINARY_BROADCAST(_op) \ | |||||
megdnn::arm_common:: \ | |||||
OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101>::run( \ | |||||
static_cast<ctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||||
dst_type, N, OC, OH* OW); | |||||
//! FOR_BINARY_BROADCAST_NCHW44: VEC_BCAST101x4 variant, called with | |||||
//! (N, OC, OH * OW, pack_oc_size). | |||||
#define FOR_BINARY_BROADCAST_NCHW44(_op) \ | |||||
megdnn::arm_common::OpCallerBinary<_op<ctype>, \ | |||||
megdnn::arm_common::VEC_BCAST101x4>:: \ | |||||
run(static_cast<ctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||||
dst_type, N, OC, OH* OW, pack_oc_size); | |||||
//! FOR_BINARY: VEC_VEC variant, called with a single element count of | |||||
//! N * OC * OH * OW * pack_oc_size. | |||||
#define FOR_BINARY(_op) \ | |||||
megdnn::arm_common:: \ | |||||
OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \ | |||||
static_cast<ctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||||
dst_type, N* OC* OH* OW* pack_oc_size); | |||||
//! Dispatches on the bias mode, always with AddOp: | |||||
//!   NO_BIAS                -> nothing is done | |||||
//!   BROADCAST_CHANNEL_BIAS -> FOR_BINARY_BROADCAST when pack_oc_size == 1, | |||||
//!                             otherwise asserts pack_oc_size == 4 and uses | |||||
//!                             FOR_BINARY_BROADCAST_NCHW44 | |||||
//!   BIAS                   -> FOR_BINARY | |||||
//! NOTE(review): any other bias mode falls into `default` and is silently | |||||
//! ignored (no throw/assert) -- confirm this is intended. | |||||
#define FOR_BIAS(_bias_mode, OH, OW) \ | |||||
switch (_bias_mode) { \ | |||||
case megdnn::BiasMode::NO_BIAS: \ | |||||
break; \ | |||||
case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||||
if (pack_oc_size == 1) { \ | |||||
FOR_BINARY_BROADCAST(CONCAT_OP(AddOp)); \ | |||||
} else { \ | |||||
megdnn_assert(pack_oc_size == 4, \ | |||||
"Only support nchw44 in ARM"); \ | |||||
FOR_BINARY_BROADCAST_NCHW44(CONCAT_OP(AddOp)); \ | |||||
} \ | |||||
break; \ | |||||
case megdnn::BiasMode::BIAS: \ | |||||
FOR_BINARY(CONCAT_OP(AddOp)); \ | |||||
break; \ | |||||
default: \ | |||||
break; \ | |||||
} | |||||
//! PostProcess specialization for PostprocessMode::ADD_BIAS. | |||||
//! run() asserts that nonlineMode is IDENTITY and then expands FOR_BIAS, so the | |||||
//! only work performed is the bias addition selected by bias_mode. | |||||
//! pack_oc_size defaults to 1; FOR_BIAS only accepts 1 or 4 for | |||||
//! BROADCAST_CHANNEL_BIAS. | |||||
template <typename ctype, typename dtype> | |||||
struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> { | |||||
static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | |||||
megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | |||||
megdnn::DType bias_type, megdnn::DType dst_type, size_t N, | |||||
size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||||
megdnn_assert(nonlineMode == megdnn::NonlineMode::IDENTITY); | |||||
FOR_BIAS(bias_mode, OH, OW); | |||||
} | |||||
}; | |||||
//! The helper macros above are local to this specialization. | |||||
#undef FOR_BINARY_BROADCAST | |||||
#undef FOR_BINARY_BROADCAST_NCHW44 | |||||
#undef FOR_BINARY | |||||
#undef FOR_BIAS | |||||
#undef CB | #undef CB | ||||
#undef CONCAT_OP | #undef CONCAT_OP | ||||
#undef CONCAT_NL | #undef CONCAT_NL | ||||
@@ -158,9 +158,11 @@ private: \ | |||||
uint32_t m_tile_size; | uint32_t m_tile_size; | ||||
enum class PostprocessMode : uint8_t { | enum class PostprocessMode : uint8_t { | ||||
FLOAT = 0, ///< support all biasmode and no_nonlinemode | |||||
NO_PROCESS, ///<support non bias and identity | |||||
QUANTIZED,///<support NOBIAS ,BROADCAST_CHANNEL_BIAS and relu hswish identify nonline mode | |||||
FLOAT = 0, ///< support all biasmode and no_nonlinemode | |||||
NO_PROCESS, ///< supports no bias and the identity nonline mode | |||||
QUANTIZED, ///< supports NO_BIAS, BROADCAST_CHANNEL_BIAS and the relu, hswish | |||||
///< and identity nonline modes | |||||
ADD_BIAS, ///< only adds the bias; the nonline mode must be identity | |||||
}; | }; | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -227,8 +227,7 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param, | |||||
param.dst_type.enumv() == DTypeEnum::QuantizedS16 || | param.dst_type.enumv() == DTypeEnum::QuantizedS16 || | ||||
param.dst_type.enumv() == DTypeEnum::Int32 || | param.dst_type.enumv() == DTypeEnum::Int32 || | ||||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | ||||
if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||||
if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
@@ -310,6 +310,19 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||||
} \ | } \ | ||||
} \ | } \ | ||||
MIDOUT_END() | MIDOUT_END() | ||||
//! Sets conv1x1_gemv_worker to Conv1x1GemvWorker<...>::exec when the filter and | |||||
//! src dtype enums are equal, the src dtype matches _i_src_type and the dst | |||||
//! dtype matches _i_dst_type. _bias_ctype and _dst_ctype are each passed twice, | |||||
//! i.e. they are also used as the 4th and 5th template arguments. | |||||
//! NOTE(review): _i_bias_type is accepted but never referenced in the body, so | |||||
//! the bias dtype is not part of the match -- confirm this is intended. | |||||
#define cb3(_format, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||||
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, midout_iv(_midout_tag)) { \ | |||||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||||
conv1x1_gemv_worker = \ | |||||
Conv1x1GemvWorker<_src_ctype, _bias_ctype, _dst_ctype, \ | |||||
_bias_ctype, _dst_ctype, \ | |||||
_postprocess_mode, _format>::exec; \ | |||||
} \ | |||||
} \ | |||||
MIDOUT_END() | |||||
switch (param.filter_meta.format) { | switch (param.filter_meta.format) { | ||||
case param::ConvBias::Format::NCHW: | case param::ConvBias::Format::NCHW: | ||||
@@ -324,23 +337,23 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||||
PostprocessMode::NO_PROCESS, "NCHW::GEMV::FLOAT16_FLOAT16"_hash); | PostprocessMode::NO_PROCESS, "NCHW::GEMV::FLOAT16_FLOAT16"_hash); | ||||
#endif | #endif | ||||
#endif | #endif | ||||
cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32, | |||||
dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32, | |||||
dt_int8, dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||||
"NCHW::GEMV::INT8x8x32_INT32"_hash); | "NCHW::GEMV::INT8x8x32_INT32"_hash); | ||||
cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int16, dt_int16, | |||||
dt_int8, dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||||
cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int16, dt_int16, | |||||
dt_int8, dt_int16, dt_int16, PostprocessMode::ADD_BIAS, | |||||
"NCHW::GEMV::INT8x8x16_INT16"_hash); | "NCHW::GEMV::INT8x8x16_INT16"_hash); | ||||
cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8, | |||||
cb3(param::ConvBias::Format::NCHW, dtype::QuantizedS8, | |||||
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | ||||
dt_int32, PostprocessMode::NO_PROCESS, | |||||
dt_int32, PostprocessMode::ADD_BIAS, | |||||
"NCHW::GEMV::QINT8x8x32_QINT32"_hash); | "NCHW::GEMV::QINT8x8x32_QINT32"_hash); | ||||
cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8, | cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8, | ||||
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | ||||
dt_int8, PostprocessMode::QUANTIZED, | dt_int8, PostprocessMode::QUANTIZED, | ||||
"NCHW::GEMV::QINT8x8x32_QINT8"_hash); | "NCHW::GEMV::QINT8x8x32_QINT8"_hash); | ||||
cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, | |||||
cb3(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, | |||||
dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, dt_int32, | dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, dt_int32, | ||||
dt_int32, PostprocessMode::NO_PROCESS, | |||||
dt_int32, PostprocessMode::ADD_BIAS, | |||||
"NCHW::GEMV::QUINT8x8x32_QINT32"_hash); | "NCHW::GEMV::QUINT8x8x32_QINT32"_hash); | ||||
cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, | cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, | ||||
dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8, dt_int32, | dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8, dt_int32, | ||||
@@ -365,13 +378,13 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||||
break; | break; | ||||
case param::ConvBias::Format::NCHW44_DOT: | case param::ConvBias::Format::NCHW44_DOT: | ||||
cb2(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32, | |||||
cb3(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32, | |||||
dt_int32, dt_int8, dt_int32, dt_int32, | dt_int32, dt_int8, dt_int32, dt_int32, | ||||
PostprocessMode::NO_PROCESS, | |||||
PostprocessMode::ADD_BIAS, | |||||
"NCHW44_DOT::GEMV::INT8x8x32_INT32"_hash); | "NCHW44_DOT::GEMV::INT8x8x32_INT32"_hash); | ||||
cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, | |||||
cb3(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, | |||||
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | ||||
dt_int32, PostprocessMode::NO_PROCESS, | |||||
dt_int32, PostprocessMode::ADD_BIAS, | |||||
"NCHW44_DOT::GEMV::QINT8x8x32_QINT32"_hash); | "NCHW44_DOT::GEMV::QINT8x8x32_QINT32"_hash); | ||||
cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, | cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, | ||||
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | ||||
@@ -385,6 +398,7 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||||
} | } | ||||
#undef cb1 | #undef cb1 | ||||
#undef cb2 | #undef cb2 | ||||
#undef cb3 | |||||
megdnn_assert(conv1x1_gemv_worker, "No suitable gemv worker"); | megdnn_assert(conv1x1_gemv_worker, "No suitable gemv worker"); | ||||
@@ -448,8 +462,7 @@ bool ConvBiasImpl::AlgoConv1x1Gemv::usable(const NCBKernSizeParam& param, | |||||
if (param.dst_type.enumv() == DTypeEnum::Int16 || | if (param.dst_type.enumv() == DTypeEnum::Int16 || | ||||
param.dst_type.enumv() == DTypeEnum::Int32 || | param.dst_type.enumv() == DTypeEnum::Int32 || | ||||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | ||||
if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||||
if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
@@ -56,6 +56,19 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||||
} \ | } \ | ||||
} \ | } \ | ||||
MIDOUT_END() | MIDOUT_END() | ||||
//! Returns a std::make_unique<Conv1x1Strategy<...>>(pack_c_size) when the filter | |||||
//! and src dtype enums are equal, the src dtype matches _i_src_type and the dst | |||||
//! dtype matches _i_dst_type; otherwise control falls through to the code after | |||||
//! the macro. _bias_ctype and _dst_ctype are each passed twice, i.e. they are | |||||
//! also used as the 4th and 5th template arguments. | |||||
//! NOTE(review): _i_bias_type is accepted but never referenced in the body. | |||||
#define cb3(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||||
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||||
midout_iv(_midout_tag)) { \ | |||||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||||
return std::make_unique<Conv1x1Strategy< \ | |||||
_src_ctype, _bias_ctype, _dst_ctype, _bias_ctype, \ | |||||
_dst_ctype, _postprocess_mode, _packmode>>(pack_c_size); \ | |||||
} \ | |||||
} \ | |||||
MIDOUT_END() | |||||
switch (pack_mode) { | switch (pack_mode) { | ||||
case MatrixMulImpl::AlgoBase::PackMode::DEFAULT: | case MatrixMulImpl::AlgoBase::PackMode::DEFAULT: | ||||
@@ -71,26 +84,26 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||||
"Default::FLOAT16_FLOAT16"_hash); | "Default::FLOAT16_FLOAT16"_hash); | ||||
#endif | #endif | ||||
#endif | #endif | ||||
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int32, | |||||
cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int32, | |||||
dt_int32, dt_int8, dt_int32, dt_int32, | dt_int32, dt_int8, dt_int32, dt_int32, | ||||
PostprocessMode::NO_PROCESS, "Default::INT8x8x32_INT32"_hash); | |||||
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int16, | |||||
PostprocessMode::ADD_BIAS, "Default::INT8x8x32_INT32"_hash); | |||||
cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int16, | |||||
dt_int16, dt_int8, dt_int16, dt_int16, | dt_int16, dt_int8, dt_int16, dt_int16, | ||||
PostprocessMode::NO_PROCESS, "Default::INT8x8x16_INT16"_hash); | |||||
PostprocessMode::ADD_BIAS, "Default::INT8x8x16_INT16"_hash); | |||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, | |||||
cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, | |||||
dtype::Quantized8Asymm, dtype::QuantizedS32, | dtype::Quantized8Asymm, dtype::QuantizedS32, | ||||
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | ||||
PostprocessMode::NO_PROCESS, | |||||
PostprocessMode::ADD_BIAS, | |||||
"Default::QUINT8x8x32_QINT32"_hash); | "Default::QUINT8x8x32_QINT32"_hash); | ||||
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, | cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, | ||||
dtype::Quantized8Asymm, dtype::QuantizedS32, | dtype::Quantized8Asymm, dtype::QuantizedS32, | ||||
dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | ||||
PostprocessMode::QUANTIZED, "Default::QUINT8x8x32_QUINT8"_hash); | PostprocessMode::QUANTIZED, "Default::QUINT8x8x32_QUINT8"_hash); | ||||
#endif | #endif | ||||
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, | |||||
cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, | |||||
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | ||||
dt_int32, PostprocessMode::NO_PROCESS, | |||||
dt_int32, PostprocessMode::ADD_BIAS, | |||||
"Default::QINT8x8x32_QINT32"_hash); | "Default::QINT8x8x32_QINT32"_hash); | ||||
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, | cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, | ||||
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | ||||
@@ -107,17 +120,17 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||||
cb1(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_float32, | cb1(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_float32, | ||||
dt_float32, PostprocessMode::FLOAT, "NoPack::FLOAT"_hash); | dt_float32, PostprocessMode::FLOAT, "NoPack::FLOAT"_hash); | ||||
cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int16, | |||||
cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int16, | |||||
dt_int16, dt_int8, dt_int16, dt_int16, | dt_int16, dt_int8, dt_int16, dt_int16, | ||||
PostprocessMode::NO_PROCESS, "NoPack::INT8x8x16_INT16"_hash); | |||||
PostprocessMode::ADD_BIAS, "NoPack::INT8x8x16_INT16"_hash); | |||||
cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int32, | |||||
cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int32, | |||||
dt_int32, dt_int8, dt_int32, dt_int32, | dt_int32, dt_int8, dt_int32, dt_int32, | ||||
PostprocessMode::NO_PROCESS, "NoPack::INT8x8x32_INT32"_hash); | |||||
PostprocessMode::ADD_BIAS, "NoPack::INT8x8x32_INT32"_hash); | |||||
cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dtype::QuantizedS8, | |||||
cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dtype::QuantizedS8, | |||||
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | ||||
dt_int32, PostprocessMode::NO_PROCESS, | |||||
dt_int32, PostprocessMode::ADD_BIAS, | |||||
"NoPack::QINT8x8x32_QINT32"_hash); | "NoPack::QINT8x8x32_QINT32"_hash); | ||||
break; | break; | ||||
@@ -127,6 +140,7 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||||
} | } | ||||
#undef cb1 | #undef cb1 | ||||
#undef cb2 | #undef cb2 | ||||
#undef cb3 | |||||
megdnn_throw("Invalid Data Type"); | megdnn_throw("Invalid Data Type"); | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
@@ -207,4 +221,4 @@ bool Conv1x1Factory::can_make_conv1x1_strategy( | |||||
} // namespace fallback | } // namespace fallback | ||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | |||||
// vim: syntax=cpp.doxygen |
@@ -746,8 +746,7 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||||
if (param.dst_type.enumv() == DTypeEnum::Int16 || | if (param.dst_type.enumv() == DTypeEnum::Int16 || | ||||
param.dst_type.enumv() == DTypeEnum::Int32 || | param.dst_type.enumv() == DTypeEnum::Int32 || | ||||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | ||||
if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||||
if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
@@ -213,6 +213,22 @@ public: | |||||
} \ | } \ | ||||
MIDOUT_END(); \ | MIDOUT_END(); \ | ||||
return {}; | return {}; | ||||
//! Returns a std::make_unique<Strategy<...>>() built with PackMode::_packmode and | |||||
//! FormatMode::_format when the filter and src dtype enums are equal, the src | |||||
//! dtype matches _i_src_type and the dst dtype matches _i_dst_type. | |||||
//! _bias_ctype and _dst_ctype are each passed twice, i.e. they are also used as | |||||
//! the 4th and 5th template arguments. | |||||
//! The macro ends with `return {};`, so when nothing matches the enclosing | |||||
//! function returns a value-initialized result and any statement written after | |||||
//! the macro invocation is not reached. | |||||
//! NOTE(review): _i_bias_type is accepted but never referenced in the body. | |||||
#define cb3(_format, _packmode, _i_src_type, _i_bias_type, _i_dst_type, \ | |||||
_src_ctype, _bias_ctype, _dst_ctype, _postprocess_mode, \ | |||||
_midout_tag) \ | |||||
MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ | |||||
midout_iv(_midout_tag)) { \ | |||||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||||
return std::make_unique< \ | |||||
Strategy<_src_ctype, _bias_ctype, _dst_ctype, _bias_ctype, \ | |||||
_dst_ctype, _postprocess_mode, \ | |||||
PackMode::_packmode, FormatMode::_format>>(); \ | |||||
} \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
return {}; | |||||
static std::unique_ptr<StrategyBase> make_default_strategy( | static std::unique_ptr<StrategyBase> make_default_strategy( | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
@@ -279,13 +295,13 @@ public: | |||||
#endif | #endif | ||||
case StrategyType::INT8x8x32: | case StrategyType::INT8x8x32: | ||||
if (format == param::ConvBias::Format::NCHW) { | if (format == param::ConvBias::Format::NCHW) { | ||||
cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
cb3(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||||
"DefaultStrategyType::INT8x8x32"_hash); | "DefaultStrategyType::INT8x8x32"_hash); | ||||
} else if (format == param::ConvBias::Format::NCHW44 || | } else if (format == param::ConvBias::Format::NCHW44 || | ||||
format == param::ConvBias::Format::NCHW44_DOT) { | format == param::ConvBias::Format::NCHW44_DOT) { | ||||
cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
cb3(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||||
"DefaultStrategyType::INT8x8x32"_hash); | "DefaultStrategyType::INT8x8x32"_hash); | ||||
} else { | } else { | ||||
megdnn_throw( | megdnn_throw( | ||||
@@ -299,12 +315,12 @@ public: | |||||
case StrategyType::INT8x8x16: | case StrategyType::INT8x8x16: | ||||
if (format == param::ConvBias::Format::NCHW) { | if (format == param::ConvBias::Format::NCHW) { | ||||
cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||||
cb3(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
dt_int16, dt_int16, PostprocessMode::ADD_BIAS, | |||||
"DefaultStrategyType::INT8x8x16"_hash); | "DefaultStrategyType::INT8x8x16"_hash); | ||||
} else if (format == param::ConvBias::Format::NCHW44) { | } else if (format == param::ConvBias::Format::NCHW44) { | ||||
cb2(NCHW44, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||||
cb3(NCHW44, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
dt_int16, dt_int16, PostprocessMode::ADD_BIAS, | |||||
"DefaultStrategyType::INT8x8x16"_hash); | "DefaultStrategyType::INT8x8x16"_hash); | ||||
} else { | } else { | ||||
megdnn_throw( | megdnn_throw( | ||||
@@ -316,9 +332,9 @@ public: | |||||
break; | break; | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
case StrategyType::QUINT8x8x32: | case StrategyType::QUINT8x8x32: | ||||
cb2(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||||
cb3(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||||
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | ||||
PostprocessMode::NO_PROCESS, | |||||
PostprocessMode::ADD_BIAS, | |||||
"DefaultStrategyType::QUINT8x8x32"_hash); | "DefaultStrategyType::QUINT8x8x32"_hash); | ||||
break; | break; | ||||
@@ -331,15 +347,15 @@ public: | |||||
#endif | #endif | ||||
case StrategyType::QINT8x8x32: | case StrategyType::QINT8x8x32: | ||||
if (format == param::ConvBias::Format::NCHW) { | if (format == param::ConvBias::Format::NCHW) { | ||||
cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
cb3(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | ||||
PostprocessMode::NO_PROCESS, | |||||
PostprocessMode::ADD_BIAS, | |||||
"DefaultStrategyTypeNCHW::QINT8x8x32"_hash); | "DefaultStrategyTypeNCHW::QINT8x8x32"_hash); | ||||
} else if (format == param::ConvBias::Format::NCHW44 || | } else if (format == param::ConvBias::Format::NCHW44 || | ||||
format == param::ConvBias::Format::NCHW44_DOT) { | format == param::ConvBias::Format::NCHW44_DOT) { | ||||
cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | |||||
cb3(NCHW44, DEFAULT, dtype::QuantizedS8, | |||||
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, | dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, | ||||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||||
"DefaultStrategyTypeHCHW44::QINT8x8x32"_hash); | "DefaultStrategyTypeHCHW44::QINT8x8x32"_hash); | ||||
} else { | } else { | ||||
megdnn_throw( | megdnn_throw( | ||||
@@ -467,13 +483,13 @@ public: | |||||
#endif | #endif | ||||
#endif | #endif | ||||
case StrategyType::INT8x8x16: | case StrategyType::INT8x8x16: | ||||
cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||||
cb3(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
dt_int16, dt_int16, PostprocessMode::ADD_BIAS, | |||||
"NoPackStrategyType::INT8x8x16"_hash); | "NoPackStrategyType::INT8x8x16"_hash); | ||||
break; | break; | ||||
case StrategyType::INT8x8x32: | case StrategyType::INT8x8x32: | ||||
cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
cb3(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||||
"NoPackStrategyType::INT8x8x32"_hash); | "NoPackStrategyType::INT8x8x32"_hash); | ||||
break; | break; | ||||
default: | default: | ||||
@@ -509,6 +525,7 @@ public: | |||||
#undef cb1 | #undef cb1 | ||||
#undef cb2 | #undef cb2 | ||||
#undef cb3 | |||||
static std::unique_ptr<StrategyBase> make_strategy( | static std::unique_ptr<StrategyBase> make_strategy( | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
@@ -203,18 +203,16 @@ INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||||
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | //! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | ||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | ||||
megdnn::PostprocessMode::QUANTIZED) | megdnn::PostprocessMode::QUANTIZED) | ||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_int32, dt_int32, | |||||
megdnn::PostprocessMode::ADD_BIAS) | |||||
#endif | #endif | ||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | ||||
megdnn::PostprocessMode::QUANTIZED) | megdnn::PostprocessMode::QUANTIZED) | ||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | ||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
megdnn::PostprocessMode::ADD_BIAS) | |||||
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | ||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
megdnn::PostprocessMode::ADD_BIAS) | |||||
#undef INSTANTIAL_CLASS | #undef INSTANTIAL_CLASS | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -119,19 +119,16 @@ INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||||
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | //! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | ||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | ||||
megdnn::PostprocessMode::QUANTIZED) | megdnn::PostprocessMode::QUANTIZED) | ||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_int32, dt_int32, | |||||
megdnn::PostprocessMode::ADD_BIAS) | |||||
#endif | #endif | ||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | ||||
megdnn::PostprocessMode::QUANTIZED) | megdnn::PostprocessMode::QUANTIZED) | ||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | ||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
megdnn::PostprocessMode::ADD_BIAS) | |||||
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | ||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
megdnn::PostprocessMode::ADD_BIAS) | |||||
#undef INSTANTIAL_CLASS | #undef INSTANTIAL_CLASS | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -162,9 +162,9 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | ||||
megdnn::PostprocessMode::FLOAT) | megdnn::PostprocessMode::FLOAT) | ||||
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | ||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
megdnn::PostprocessMode::ADD_BIAS) | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | ||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
megdnn::PostprocessMode::ADD_BIAS) | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
#else | #else | ||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
@@ -294,6 +294,73 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> { | |||||
#undef FOR_BIAS | #undef FOR_BIAS | ||||
} | } | ||||
}; | }; | ||||
//! Drop any earlier definitions before redefining the helpers for ADD_BIAS. | |||||
#undef CALL_BINARY | |||||
#undef CALL_BINARY_BROADCAST | |||||
//! Binds OpCallerBinary<_op<_simd_type, ctype, dtype>, _simd_type, VEC_VEC>::run | |||||
//! to a thin_function and calls it with conv_dst_ptr and bias_ptr as ctype, | |||||
//! dst_ptr as dtype, bias_type (twice), dst_type and N * OC * OH * OW elements. | |||||
#define CALL_BINARY(_op, _simd_type) \ | |||||
thin_function<void(const ctype*, const ctype*, dtype*, DType, DType, \ | |||||
DType, size_t)> \ | |||||
run = OpCallerBinary<_op<_simd_type, ctype, dtype>, _simd_type, \ | |||||
megdnn::x86::BcastType::VEC_VEC>::run; \ | |||||
run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr), \ | |||||
reinterpret_cast<dtype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
N* OC* OH* OW); | |||||
//! Same as CALL_BINARY but uses the VEC_BCAST101 caller, which is invoked with | |||||
//! (N, OC, OH * OW) instead of a single element count. | |||||
#define CALL_BINARY_BROADCAST(_op, _simd_type) \ | |||||
thin_function<void(const ctype*, const ctype*, dtype*, DType, DType, \ | |||||
DType, size_t, size_t, size_t)> \ | |||||
run = OpCallerBinary<_op<_simd_type, ctype, dtype>, _simd_type, \ | |||||
megdnn::x86::BcastType::VEC_BCAST101>::run; \ | |||||
run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr), \ | |||||
reinterpret_cast<dtype*>(dst_ptr), bias_type, bias_type, dst_type, N, \ | |||||
OC, OH* OW); | |||||
//! Expands CALLER with AddOp for the first SIMD type reported by is_supported(): | |||||
//! AVX2, then SSE4_2, and SIMDType::NONE as the fallback. | |||||
#define FOR_SIMD(CALLER) \ | |||||
if (is_supported(SIMDType::AVX2)) { \ | |||||
CALLER(AddOp, SIMDType::AVX2) \ | |||||
} else if (is_supported(SIMDType::SSE4_2)) { \ | |||||
CALLER(AddOp, SIMDType::SSE4_2) \ | |||||
} else { \ | |||||
CALLER(AddOp, SIMDType::NONE) \ | |||||
} | |||||
//! Dispatches on the bias mode: BIAS uses CALL_BINARY, BROADCAST_CHANNEL_BIAS | |||||
//! uses CALL_BINARY_BROADCAST; every other mode falls into `default` and does | |||||
//! nothing. | |||||
#define FOR_BIAS(bias_mode) \ | |||||
switch (bias_mode) { \ | |||||
case BiasMode::BIAS: \ | |||||
FOR_SIMD(CALL_BINARY); \ | |||||
break; \ | |||||
case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||||
FOR_SIMD(CALL_BINARY_BROADCAST); \ | |||||
break; \ | |||||
default: \ | |||||
break; \ | |||||
} | |||||
//! PostProcess specialization for PostprocessMode::ADD_BIAS on x86. | |||||
//! run() asserts pack_oc_size == 1 and nonlineMode == IDENTITY, returns early | |||||
//! for NO_BIAS, and otherwise adds the bias through FOR_BIAS. | |||||
template <typename ctype, typename dtype> | |||||
struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> { | |||||
static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | |||||
megdnn::ConvBiasForward::BiasMode bias_mode, | |||||
megdnn::param::ConvBiasV0::NonlineMode nonlineMode, | |||||
DType bias_type, DType dst_type, size_t N, size_t OC, | |||||
size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||||
megdnn_assert(pack_oc_size == 1, | |||||
"PostProcess only support nchw in x86"); | |||||
megdnn_assert( | |||||
nonlineMode == megdnn::param::ConvBiasV0::NonlineMode::IDENTITY, | |||||
"Add bias PostProcess only support IDENTITY"); | |||||
//! Nothing to add when there is no bias. | |||||
if (bias_mode == megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||||
return; | |||||
} | |||||
FOR_BIAS(bias_mode); | |||||
//! The helper macros are undefined here, inside the function body, after | |||||
//! FOR_BIAS has been expanded. | |||||
#undef CALL_BINARY | |||||
#undef CALL_BINARY_BROADCAST | |||||
#undef FOR_SIMD | |||||
#undef FOR_BIAS | |||||
} | |||||
}; | |||||
#undef cb_unary | #undef cb_unary | ||||
#undef cb_binary | #undef cb_binary | ||||
#undef BIAS_CASE | #undef BIAS_CASE | ||||
@@ -92,6 +92,8 @@ OP(dt_int8, SIMDType::AVX2, "avx2", __m256i, __m256ix2, __m256i, mm256, epi8, | |||||
using AddOpBase::operator(); \ | using AddOpBase::operator(); \ | ||||
}; | }; | ||||
OP(dt_int32, SIMDType::NONE); | |||||
OP(dt_int16, SIMDType::NONE); | |||||
OP(dt_float32, SIMDType::NONE); | OP(dt_float32, SIMDType::NONE); | ||||
#undef OP | #undef OP | ||||
} // namespace x86 | } // namespace x86 | ||||
@@ -1992,13 +1992,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) { | |||||
#define cb(name) \ | #define cb(name) \ | ||||
checker_conv_bias( \ | checker_conv_bias( \ | ||||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||||
true, false, true, false, false, true), \ | |||||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, \ | |||||
true, false, true, true, false, false), \ | |||||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | ||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \ | dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \ | ||||
checker_conv_bias( \ | checker_conv_bias( \ | ||||
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | ||||
false, false, true), \ | |||||
true, false, false), \ | |||||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | ||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); | dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); | ||||
@@ -2041,13 +2041,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32_MK4_DOT) { | |||||
#define cb(name) \ | #define cb(name) \ | ||||
checker_conv_bias( \ | checker_conv_bias( \ | ||||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||||
true, false, true, false, false, true), \ | |||||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, \ | |||||
true, false, true, true, false, false), \ | |||||
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | ||||
dtype::Int32(), {}, name); \ | dtype::Int32(), {}, name); \ | ||||
checker_conv_bias( \ | checker_conv_bias( \ | ||||
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | ||||
false, false, true), \ | |||||
true, false, false), \ | |||||
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | ||||
dtype::Int32(), {}, name); | dtype::Int32(), {}, name); | ||||
@@ -2118,7 +2118,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) { | |||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) { | ||||
NormalRNG rng(128.f); | NormalRNG rng(128.f); | ||||
#define cb(name) \ | #define cb(name) \ | ||||
checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | ||||
false, true, true), \ | false, true, true), \ | ||||
@@ -2188,18 +2187,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) { | ||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
float epsilon = 0.001; | float epsilon = 0.001; | ||||
#define cb(name) \ | |||||
checker_conv_bias( \ | |||||
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ | |||||
handle(), &rng, epsilon, \ | |||||
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||||
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||||
dtype::QuantizedS32(1.2 * 1.3), {}, name); \ | |||||
checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \ | |||||
&rng, epsilon, \ | |||||
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||||
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||||
dtype::QuantizedS32(1.2 * 1.3), {}, name); | |||||
#define cb(name) \ | |||||
checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||||
true, true, false), \ | |||||
handle(), &rng, epsilon, \ | |||||
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||||
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||||
dtype::QuantizedS32(1.2 * 1.3), {}, name); \ | |||||
checker_conv_bias( \ | |||||
get_conv_bias_args({1}, 2, false, false, true, true, false), \ | |||||
handle(), &rng, epsilon, \ | |||||
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||||
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||||
dtype::QuantizedS32(1.2 * 1.3), {}, name); | |||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
@@ -2252,18 +2252,18 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { | |||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
float epsilon = 0.001; | float epsilon = 0.001; | ||||
std::vector<conv_bias::TestArg> args_nchw44 = | std::vector<conv_bias::TestArg> args_nchw44 = | ||||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, true, true, true, | |||||
false, false, false, false, true); | |||||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, true, false, true, | |||||
false, false, true, false, false); | |||||
std::vector<conv_bias::TestArg> args_nchw44_1x1s2 = | std::vector<conv_bias::TestArg> args_nchw44_1x1s2 = | ||||
get_nchw44_conv_bias_args({1}, 2, true, true, true, false, false, | |||||
false, false, true); | |||||
#define cb(name) \ | |||||
checker_conv_bias( \ | |||||
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ | |||||
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||||
dtype::Int16{}, dtype::Int16{}, name); \ | |||||
checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \ | |||||
&rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||||
get_nchw44_conv_bias_args({1}, 2, true, false, true, false, false, | |||||
true, false, false); | |||||
#define cb(name) \ | |||||
checker_conv_bias( \ | |||||
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, true), \ | |||||
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||||
dtype::Int16{}, dtype::Int16{}, name); \ | |||||
checker_conv_bias(get_conv_bias_args({1}, 2, false, false, true), \ | |||||
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||||
dtype::Int16{}, dtype::Int16{}, name); | dtype::Int16{}, dtype::Int16{}, name); | ||||
#define cb_nchw44(name) \ | #define cb_nchw44(name) \ | ||||
@@ -2314,14 +2314,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCES | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPREPROCESS) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPREPROCESS) { | ||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
float epsilon = 0.001; | float epsilon = 0.001; | ||||
#define cb(name) \ | |||||
check_conv_bias_preprocess( \ | |||||
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ | |||||
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||||
dtype::Int16{}, dtype::Int16{}, name); \ | |||||
check_conv_bias_preprocess(get_conv_bias_args({1}, 2, false, true, true), \ | |||||
handle(), &rng, epsilon, dtype::Int8{}, \ | |||||
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, \ | |||||
#define cb(name) \ | |||||
check_conv_bias_preprocess( \ | |||||
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, true), \ | |||||
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||||
dtype::Int16{}, dtype::Int16{}, name); \ | |||||
check_conv_bias_preprocess(get_conv_bias_args({1}, 2, false, false, true), \ | |||||
handle(), &rng, epsilon, dtype::Int8{}, \ | |||||
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, \ | |||||
name); | name); | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
@@ -2406,7 +2406,7 @@ void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args, | |||||
checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | ||||
.set_dtype(1, dtype::QuantizedS8(2.5f)) | .set_dtype(1, dtype::QuantizedS8(2.5f)) | ||||
.set_dtype(2, dtype::QuantizedS32(6.25f)) | .set_dtype(2, dtype::QuantizedS32(6.25f)) | ||||
.set_dtype(4, {}) | |||||
.set_dtype(4, dtype::QuantizedS32(6.25f)) | |||||
.set_rng(0, &rng) | .set_rng(0, &rng) | ||||
.set_rng(1, &rng) | .set_rng(1, &rng) | ||||
.set_rng(2, &rng) | .set_rng(2, &rng) | ||||
@@ -2436,7 +2436,7 @@ void checker_conv_bias_int8x8x32_preprocess(std::vector<conv_bias::TestArg> args | |||||
checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | ||||
.set_dtype(1, dtype::QuantizedS8(2.5f)) | .set_dtype(1, dtype::QuantizedS8(2.5f)) | ||||
.set_dtype(2, dtype::QuantizedS32(6.25f)) | .set_dtype(2, dtype::QuantizedS32(6.25f)) | ||||
.set_dtype(4, {}) | |||||
.set_dtype(4, dtype::QuantizedS32(6.25f)) | |||||
.set_rng(0, &rng) | .set_rng(0, &rng) | ||||
.set_rng(1, &rng) | .set_rng(1, &rng) | ||||
.set_rng(2, &rng) | .set_rng(2, &rng) | ||||
@@ -2450,7 +2450,7 @@ void checker_conv_bias_int8x8x32_preprocess(std::vector<conv_bias::TestArg> args | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<conv_bias::TestArg> args = | std::vector<conv_bias::TestArg> args = | ||||
get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true); | |||||
get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true); | |||||
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
@@ -2464,7 +2464,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<conv_bias::TestArg> args = | std::vector<conv_bias::TestArg> args = | ||||
get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true); | |||||
get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true); | |||||
#define cb(name) checker_conv_bias_int8x8x32_preprocess(args, handle(), name); | #define cb(name) checker_conv_bias_int8x8x32_preprocess(args, handle(), name); | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
@@ -2478,7 +2478,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPR | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<conv_bias::TestArg> args = | std::vector<conv_bias::TestArg> args = | ||||
get_nchw44_conv_bias_args({3, 4, 6}, 1, false, true, true); | |||||
get_nchw44_conv_bias_args({3, 4, 6}, 1, false, false, true); | |||||
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
@@ -3080,9 +3080,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32_PREPROCESS) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { | ||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
float epsilon = 0.001; | float epsilon = 0.001; | ||||
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_conv_bias_1x1_args(false, true, false, false); | |||||
std::vector<conv_bias::TestArg> args_nchw44 = get_nchw44_conv_bias_args( | std::vector<conv_bias::TestArg> args_nchw44 = get_nchw44_conv_bias_args( | ||||
{1}, 1, true, true, true, false, false, false, false, true); | |||||
{1}, 1, true, true, true, false, false, true, false, false); | |||||
#define cb(name) \ | #define cb(name) \ | ||||
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \ | checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \ | ||||
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); | dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); | ||||
@@ -3140,7 +3141,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16_PREPROCESS) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_conv_bias_1x1_args(false, true, false, false); | |||||
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | ||||
@@ -834,6 +834,13 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X32) { | |||||
//! no bias | //! no bias | ||||
args.emplace_back(param, TensorShape{1, ic, h, w}, | args.emplace_back(param, TensorShape{1, ic, h, w}, | ||||
TensorShape{oc, ic, kernel, kernel}, TensorShape{}); | TensorShape{oc, ic, kernel, kernel}, TensorShape{}); | ||||
args.emplace_back(param, TensorShape{1, ic, h, w}, | |||||
TensorShape{oc, ic, kernel, kernel}, | |||||
TensorShape{1, oc, 1, 1}); | |||||
args.emplace_back(param, TensorShape{1, ic, h, w}, | |||||
TensorShape{oc, ic, kernel, kernel}, | |||||
TensorShape{1, oc, (h + 2 * p - kernel) + 1, | |||||
(h + 2 * p - kernel) + 1}); | |||||
}; | }; | ||||
for (size_t kernel : {2, 3, 4, 5, 6, 7}) | for (size_t kernel : {2, 3, 4, 5, 6, 7}) | ||||
@@ -1384,7 +1391,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
float epsilon = 0.001; | float epsilon = 0.001; | ||||
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||||
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, true); | |||||
#if MEGDNN_X86_WITH_MKL_DNN | #if MEGDNN_X86_WITH_MKL_DNN | ||||
if (x86::is_supported(x86::SIMDType::VNNI)) { | if (x86::is_supported(x86::SIMDType::VNNI)) { | ||||
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, | checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, | ||||
@@ -1422,7 +1429,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32_PREPROCESS) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
float epsilon = 0.001; | float epsilon = 0.001; | ||||
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||||
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, true); | |||||
#if MEGDNN_X86_WITH_VNNI | #if MEGDNN_X86_WITH_VNNI | ||||
if (x86::is_supported(x86::SIMDType::VNNI)) { | if (x86::is_supported(x86::SIMDType::VNNI)) { | ||||
checker_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{}, | checker_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{}, | ||||