From b778d22523420ba51027ea404bef350cf638b151 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 7 Aug 2020 16:53:28 +0800 Subject: [PATCH] feat(mgb/fallback): add conv1x1_gemv, conv1x1 and im2col 8x8x16/8x8x32 support bias GitOrigin-RevId: 3d97fedc8f33d0b41f94680d6710c56bc32b62e7 --- dnn/src/arm_common/conv_bias/postprocess_helper.h | 61 +++++++++++++- dnn/src/fallback/conv_bias/common.h | 8 +- dnn/src/fallback/conv_bias/conv1x1/algos.cpp | 3 +- .../conv_bias/conv1x1/algos_conv1x1_gemv.cpp | 41 ++++++---- .../conv_bias/conv1x1/conv1x1_strategy.cpp | 44 ++++++---- dnn/src/fallback/conv_bias/im2col/algos.cpp | 3 +- dnn/src/fallback/conv_bias/im2col/factory.h | 53 +++++++----- .../fallback/conv_bias/im2col/strategy_default.cpp | 10 +-- .../conv_bias/im2col/strategy_default_nchw44.cpp | 11 +-- .../fallback/conv_bias/im2col/strategy_nopack.cpp | 4 +- dnn/src/x86/conv_bias/postprocess_helper.h | 67 +++++++++++++++ dnn/src/x86/elemwise_helper/kimpl/add.h | 2 + dnn/test/arm_common/conv_bias_multi_thread.cpp | 94 +++++++++++----------- dnn/test/x86/conv_bias.cpp | 11 ++- 14 files changed, 294 insertions(+), 118 deletions(-) diff --git a/dnn/src/arm_common/conv_bias/postprocess_helper.h b/dnn/src/arm_common/conv_bias/postprocess_helper.h index d665bfec..bcfa718c 100644 --- a/dnn/src/arm_common/conv_bias/postprocess_helper.h +++ b/dnn/src/arm_common/conv_bias/postprocess_helper.h @@ -100,7 +100,6 @@ namespace { MIDOUT_END(); \ break; \ default: \ - megdnn_throw("no quantized unsupported biasmode"); \ break; \ } @@ -258,6 +257,66 @@ struct PostProcess { #undef FOR_NONLINEAR_NOBIAS #undef FOR_NONLINEAR #undef FOR_BIAS + +#define FOR_BINARY_BROADCAST(_op) \ + megdnn::arm_common:: \ + OpCallerBinary<_op, megdnn::arm_common::VEC_BCAST101>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, \ + dst_type, N, OC, OH* OW); + +#define FOR_BINARY_BROADCAST_NCHW44(_op) \ + megdnn::arm_common::OpCallerBinary<_op, \ + megdnn::arm_common::VEC_BCAST101x4>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, \ + dst_type, N, OC, OH* OW, pack_oc_size); + +#define FOR_BINARY(_op) \ + megdnn::arm_common:: \ + OpCallerBinary<_op, megdnn::arm_common::VEC_VEC>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, \ + dst_type, N* OC* OH* OW* pack_oc_size); + +#define FOR_BIAS(_bias_mode, OH, OW) \ + switch (_bias_mode) { \ + case megdnn::BiasMode::NO_BIAS: \ + break; \ + case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ + if (pack_oc_size == 1) { \ + FOR_BINARY_BROADCAST(CONCAT_OP(AddOp)); \ + } else { \ + megdnn_assert(pack_oc_size == 4, \ + "Only support nchw44 in ARM"); \ + FOR_BINARY_BROADCAST_NCHW44(CONCAT_OP(AddOp)); \ + } \ + break; \ + case megdnn::BiasMode::BIAS: \ + FOR_BINARY(CONCAT_OP(AddOp)); \ + break; \ + default: \ + break; \ + } + +template +struct PostProcess { + static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, + megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, + megdnn::DType bias_type, megdnn::DType dst_type, size_t N, + size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { + megdnn_assert(nonlineMode == megdnn::NonlineMode::IDENTITY); + FOR_BIAS(bias_mode, OH, OW); + } +}; + +#undef FOR_BINARY_BROADCAST +#undef FOR_BINARY_BROADCAST_NCHW44 +#undef FOR_BINARY +#undef FOR_BIAS #undef CB #undef CONCAT_OP #undef CONCAT_NL diff --git a/dnn/src/fallback/conv_bias/common.h b/dnn/src/fallback/conv_bias/common.h index abe313dd..60f18a6d 100644 --- a/dnn/src/fallback/conv_bias/common.h +++ b/dnn/src/fallback/conv_bias/common.h @@ -158,9 +158,11 @@ private: \ uint32_t m_tile_size; enum class PostprocessMode : uint8_t { - FLOAT = 0, ///< support all biasmode and no_nonlinemode - NO_PROCESS, ///::enumv && \ + param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ + conv1x1_gemv_worker = \ + Conv1x1GemvWorker<_src_ctype, _bias_ctype, _dst_ctype, \ + _bias_ctype, _dst_ctype, \ + _postprocess_mode, _format>::exec; \ + } \ + } \ + MIDOUT_END() switch (param.filter_meta.format) { case param::ConvBias::Format::NCHW: @@ -324,23 +337,23 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( PostprocessMode::NO_PROCESS, "NCHW::GEMV::FLOAT16_FLOAT16"_hash); #endif #endif - cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32, - dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS, + cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32, + dt_int8, dt_int32, dt_int32, PostprocessMode::ADD_BIAS, "NCHW::GEMV::INT8x8x32_INT32"_hash); - cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int16, dt_int16, - dt_int8, dt_int16, dt_int16, PostprocessMode::NO_PROCESS, + cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int16, dt_int16, + dt_int8, dt_int16, dt_int16, PostprocessMode::ADD_BIAS, "NCHW::GEMV::INT8x8x16_INT16"_hash); - cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8, + cb3(param::ConvBias::Format::NCHW, dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, - dt_int32, PostprocessMode::NO_PROCESS, + dt_int32, PostprocessMode::ADD_BIAS, "NCHW::GEMV::QINT8x8x32_QINT32"_hash); cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, PostprocessMode::QUANTIZED, "NCHW::GEMV::QINT8x8x32_QINT8"_hash); - cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, + cb3(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, dt_int32, - dt_int32, PostprocessMode::NO_PROCESS, + dt_int32, PostprocessMode::ADD_BIAS, "NCHW::GEMV::QUINT8x8x32_QINT32"_hash); cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8, dt_int32, @@ -365,13 +378,13 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( break; case param::ConvBias::Format::NCHW44_DOT: - cb2(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32, + cb3(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32, - PostprocessMode::NO_PROCESS, + PostprocessMode::ADD_BIAS, "NCHW44_DOT::GEMV::INT8x8x32_INT32"_hash); - cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, + cb3(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, - dt_int32, PostprocessMode::NO_PROCESS, + dt_int32, PostprocessMode::ADD_BIAS, "NCHW44_DOT::GEMV::QINT8x8x32_QINT32"_hash); cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, @@ -385,6 +398,7 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( } #undef cb1 #undef cb2 +#undef cb3 megdnn_assert(conv1x1_gemv_worker, "No suitable gemv worker"); @@ -448,8 +462,7 @@ bool ConvBiasImpl::AlgoConv1x1Gemv::usable(const NCBKernSizeParam& param, if (param.dst_type.enumv() == DTypeEnum::Int16 || param.dst_type.enumv() == DTypeEnum::Int32 || param.dst_type.enumv() == DTypeEnum::QuantizedS32) { - if (param.bias_mode != megdnn::BiasMode::NO_BIAS || - param.nonlineMode != megdnn::NonlineMode::IDENTITY) { + if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) { return false; } } diff --git a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp index c5f9151d..ee927eaf 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp +++ b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp @@ -56,6 +56,19 @@ std::unique_ptr create_conv1x1_strategy( } \ } \ MIDOUT_END() +#define cb3(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ + _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ + MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ + midout_iv(_midout_tag)) { \ + if (param.filter_type.enumv() == param.src_type.enumv() && \ + param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ + param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ + return std::make_unique>(pack_c_size); \ + } \ + } \ + MIDOUT_END() switch (pack_mode) { case MatrixMulImpl::AlgoBase::PackMode::DEFAULT: @@ -71,26 +84,26 @@ std::unique_ptr create_conv1x1_strategy( "Default::FLOAT16_FLOAT16"_hash); #endif #endif - cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int32, + cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32, - PostprocessMode::NO_PROCESS, "Default::INT8x8x32_INT32"_hash); - cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int16, + PostprocessMode::ADD_BIAS, "Default::INT8x8x32_INT32"_hash); + cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, dt_int16, - PostprocessMode::NO_PROCESS, "Default::INT8x8x16_INT16"_hash); + PostprocessMode::ADD_BIAS, "Default::INT8x8x16_INT16"_hash); #if MEGDNN_AARCH64 || MEGDNN_ARMV7 - cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, + cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, - PostprocessMode::NO_PROCESS, + PostprocessMode::ADD_BIAS, "Default::QUINT8x8x32_QINT32"_hash); cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, PostprocessMode::QUANTIZED, "Default::QUINT8x8x32_QUINT8"_hash); #endif - cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, + cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, - dt_int32, PostprocessMode::NO_PROCESS, + dt_int32, PostprocessMode::ADD_BIAS, "Default::QINT8x8x32_QINT32"_hash); cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, @@ -107,17 +120,17 @@ std::unique_ptr create_conv1x1_strategy( cb1(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_float32, dt_float32, PostprocessMode::FLOAT, "NoPack::FLOAT"_hash); - cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int16, + cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, dt_int16, - PostprocessMode::NO_PROCESS, "NoPack::INT8x8x16_INT16"_hash); + PostprocessMode::ADD_BIAS, "NoPack::INT8x8x16_INT16"_hash); - cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int32, + cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32, - PostprocessMode::NO_PROCESS, "NoPack::INT8x8x32_INT32"_hash); + PostprocessMode::ADD_BIAS, "NoPack::INT8x8x32_INT32"_hash); - cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dtype::QuantizedS8, + cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, - dt_int32, PostprocessMode::NO_PROCESS, + dt_int32, PostprocessMode::ADD_BIAS, "NoPack::QINT8x8x32_QINT32"_hash); break; @@ -127,6 +140,7 @@ std::unique_ptr create_conv1x1_strategy( } #undef cb1 #undef cb2 +#undef cb3 megdnn_throw("Invalid Data Type"); return nullptr; } @@ -207,4 +221,4 @@ bool Conv1x1Factory::can_make_conv1x1_strategy( } // namespace fallback } // namespace megdnn -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/im2col/algos.cpp b/dnn/src/fallback/conv_bias/im2col/algos.cpp index 666c099e..0068ef4c 100644 --- a/dnn/src/fallback/conv_bias/im2col/algos.cpp +++ b/dnn/src/fallback/conv_bias/im2col/algos.cpp @@ -746,8 +746,7 @@ bool ConvBiasImpl::AlgoIm2col::usable( if (param.dst_type.enumv() == DTypeEnum::Int16 || param.dst_type.enumv() == DTypeEnum::Int32 || param.dst_type.enumv() == DTypeEnum::QuantizedS32) { - if (param.bias_mode != megdnn::BiasMode::NO_BIAS || - param.nonlineMode != megdnn::NonlineMode::IDENTITY) { + if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) { return false; } } diff --git a/dnn/src/fallback/conv_bias/im2col/factory.h b/dnn/src/fallback/conv_bias/im2col/factory.h index f4fbf529..b48d4f0d 100644 --- a/dnn/src/fallback/conv_bias/im2col/factory.h +++ b/dnn/src/fallback/conv_bias/im2col/factory.h @@ -213,6 +213,22 @@ public: } \ MIDOUT_END(); \ return {}; +#define cb3(_format, _packmode, _i_src_type, _i_bias_type, _i_dst_type, \ + _src_ctype, _bias_ctype, _dst_ctype, _postprocess_mode, \ + _midout_tag) \ + MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ + midout_iv(_midout_tag)) { \ + if (param.filter_type.enumv() == param.src_type.enumv() && \ + param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ + param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ + return std::make_unique< \ + Strategy<_src_ctype, _bias_ctype, _dst_ctype, _bias_ctype, \ + _dst_ctype, _postprocess_mode, \ + PackMode::_packmode, FormatMode::_format>>(); \ + } \ + } \ + MIDOUT_END(); \ + return {}; static std::unique_ptr make_default_strategy( fallback::MatrixMulImpl::AlgoBase* matmul_algo, @@ -279,13 +295,13 @@ public: #endif case StrategyType::INT8x8x32: if (format == param::ConvBias::Format::NCHW) { - cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, - dt_int32, dt_int32, PostprocessMode::NO_PROCESS, + cb3(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, + dt_int32, dt_int32, PostprocessMode::ADD_BIAS, "DefaultStrategyType::INT8x8x32"_hash); } else if (format == param::ConvBias::Format::NCHW44 || format == param::ConvBias::Format::NCHW44_DOT) { - cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, - dt_int32, dt_int32, PostprocessMode::NO_PROCESS, + cb3(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, + dt_int32, dt_int32, PostprocessMode::ADD_BIAS, "DefaultStrategyType::INT8x8x32"_hash); } else { megdnn_throw( @@ -299,12 +315,12 @@ public: case StrategyType::INT8x8x16: if (format == param::ConvBias::Format::NCHW) { - cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, - dt_int16, dt_int16, PostprocessMode::NO_PROCESS, + cb3(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, + dt_int16, dt_int16, PostprocessMode::ADD_BIAS, "DefaultStrategyType::INT8x8x16"_hash); } else if (format == param::ConvBias::Format::NCHW44) { - cb2(NCHW44, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, - dt_int16, dt_int16, PostprocessMode::NO_PROCESS, + cb3(NCHW44, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, + dt_int16, dt_int16, PostprocessMode::ADD_BIAS, "DefaultStrategyType::INT8x8x16"_hash); } else { megdnn_throw( @@ -316,9 +332,9 @@ public: break; #if MEGDNN_AARCH64 || MEGDNN_ARMV7 case StrategyType::QUINT8x8x32: - cb2(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, + cb3(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, - PostprocessMode::NO_PROCESS, + PostprocessMode::ADD_BIAS, "DefaultStrategyType::QUINT8x8x32"_hash); break; @@ -331,15 +347,15 @@ public: #endif case StrategyType::QINT8x8x32: if (format == param::ConvBias::Format::NCHW) { - cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, + cb3(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, - PostprocessMode::NO_PROCESS, + PostprocessMode::ADD_BIAS, "DefaultStrategyTypeNCHW::QINT8x8x32"_hash); } else if (format == param::ConvBias::Format::NCHW44 || format == param::ConvBias::Format::NCHW44_DOT) { - cb2(NCHW44, DEFAULT, dtype::QuantizedS8, + cb3(NCHW44, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, - dt_int32, dt_int32, PostprocessMode::NO_PROCESS, + dt_int32, dt_int32, PostprocessMode::ADD_BIAS, "DefaultStrategyTypeHCHW44::QINT8x8x32"_hash); } else { megdnn_throw( @@ -467,13 +483,13 @@ public: #endif #endif case StrategyType::INT8x8x16: - cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, - dt_int16, dt_int16, PostprocessMode::NO_PROCESS, + cb3(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, + dt_int16, dt_int16, PostprocessMode::ADD_BIAS, "NoPackStrategyType::INT8x8x16"_hash); break; case StrategyType::INT8x8x32: - cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, - dt_int32, dt_int32, PostprocessMode::NO_PROCESS, + cb3(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, + dt_int32, dt_int32, PostprocessMode::ADD_BIAS, "NoPackStrategyType::INT8x8x32"_hash); break; default: @@ -509,6 +525,7 @@ public: #undef cb1 #undef cb2 +#undef cb3 static std::unique_ptr make_strategy( fallback::MatrixMulImpl::AlgoBase* matmul_algo, diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_default.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_default.cpp index 25d5e6ec..4b5fb720 100644 --- a/dnn/src/fallback/conv_bias/im2col/strategy_default.cpp +++ b/dnn/src/fallback/conv_bias/im2col/strategy_default.cpp @@ -203,18 +203,16 @@ INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, //! x86 do not have uint8 matmul so only armv7 armv8 support uint8 INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, megdnn::PostprocessMode::QUANTIZED) -INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, - megdnn::PostprocessMode::NO_PROCESS) +INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_int32, dt_int32, + megdnn::PostprocessMode::ADD_BIAS) #endif INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, megdnn::PostprocessMode::QUANTIZED) INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, - megdnn::PostprocessMode::NO_PROCESS) + megdnn::PostprocessMode::ADD_BIAS) INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, - megdnn::PostprocessMode::NO_PROCESS) -INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, - megdnn::PostprocessMode::NO_PROCESS) + megdnn::PostprocessMode::ADD_BIAS) #undef INSTANTIAL_CLASS } // namespace megdnn diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp index e2b77721..213a0193 100644 --- a/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp +++ b/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp @@ -119,19 +119,16 @@ INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, //! x86 do not have uint8 matmul so only armv7 armv8 support uint8 INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, megdnn::PostprocessMode::QUANTIZED) -INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, - megdnn::PostprocessMode::NO_PROCESS) +INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_int32, dt_int32, + megdnn::PostprocessMode::ADD_BIAS) #endif INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, megdnn::PostprocessMode::QUANTIZED) INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, - megdnn::PostprocessMode::NO_PROCESS) + megdnn::PostprocessMode::ADD_BIAS) INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, - megdnn::PostprocessMode::NO_PROCESS) -INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, - megdnn::PostprocessMode::NO_PROCESS) - + megdnn::PostprocessMode::ADD_BIAS) #undef INSTANTIAL_CLASS } // namespace megdnn diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp index 1ab41b71..c3a05d20 100644 --- a/dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp +++ b/dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp @@ -162,9 +162,9 @@ void Strategy { #undef FOR_BIAS } }; + +#undef CALL_BINARY +#undef CALL_BINARY_BROADCAST + +#define CALL_BINARY(_op, _simd_type) \ + thin_function \ + run = OpCallerBinary<_op<_simd_type, ctype, dtype>, _simd_type, \ + megdnn::x86::BcastType::VEC_VEC>::run; \ + run(static_cast(conv_dst_ptr), static_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ + N* OC* OH* OW); + +#define CALL_BINARY_BROADCAST(_op, _simd_type) \ + thin_function \ + run = OpCallerBinary<_op<_simd_type, ctype, dtype>, _simd_type, \ + megdnn::x86::BcastType::VEC_BCAST101>::run; \ + run(static_cast(conv_dst_ptr), static_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, \ + OC, OH* OW); + +#define FOR_SIMD(CALLER) \ + if (is_supported(SIMDType::AVX2)) { \ + CALLER(AddOp, SIMDType::AVX2) \ + } else if (is_supported(SIMDType::SSE4_2)) { \ + CALLER(AddOp, SIMDType::SSE4_2) \ + } else { \ + CALLER(AddOp, SIMDType::NONE) \ + } + +#define FOR_BIAS(bias_mode) \ + switch (bias_mode) { \ + case BiasMode::BIAS: \ + FOR_SIMD(CALL_BINARY); \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + FOR_SIMD(CALL_BINARY_BROADCAST); \ + break; \ + default: \ + break; \ + } + +template +struct PostProcess { + static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, + megdnn::ConvBiasForward::BiasMode bias_mode, + megdnn::param::ConvBiasV0::NonlineMode nonlineMode, + DType bias_type, DType dst_type, size_t N, size_t OC, + size_t OH, size_t OW, size_t pack_oc_size = 1) { + MEGDNN_MARK_USED_VAR(pack_oc_size); + megdnn_assert(pack_oc_size == 1, + "PostProcess only support nchw in x86"); + megdnn_assert( + nonlineMode == megdnn::param::ConvBiasV0::NonlineMode::IDENTITY, + "Add bias PostProcess only support IDENTITY"); + if (bias_mode == megdnn::ConvBiasForward::BiasMode::NO_BIAS) { + return; + } + FOR_BIAS(bias_mode); +#undef CALL_BINARY +#undef CALL_BINARY_BROADCAST +#undef FOR_SIMD +#undef FOR_BIAS + } +}; + #undef cb_unary #undef cb_binary #undef BIAS_CASE diff --git a/dnn/src/x86/elemwise_helper/kimpl/add.h b/dnn/src/x86/elemwise_helper/kimpl/add.h index b76149b5..1792484e 100644 --- a/dnn/src/x86/elemwise_helper/kimpl/add.h +++ b/dnn/src/x86/elemwise_helper/kimpl/add.h @@ -92,6 +92,8 @@ OP(dt_int8, SIMDType::AVX2, "avx2", __m256i, __m256ix2, __m256i, mm256, epi8, using AddOpBase::operator(); \ }; +OP(dt_int32, SIMDType::NONE); +OP(dt_int16, SIMDType::NONE); OP(dt_float32, SIMDType::NONE); #undef OP } // namespace x86 diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 91336e86..fae40b97 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -1992,13 +1992,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) { #define cb(name) \ checker_conv_bias( \ - get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ - true, false, true, false, false, true), \ + get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, \ + true, false, true, true, false, false), \ handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \ checker_conv_bias( \ get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ - false, false, true), \ + true, false, false), \ handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); @@ -2041,13 +2041,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32_MK4_DOT) { #define cb(name) \ checker_conv_bias( \ - get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ - true, false, true, false, false, true), \ + get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, \ + true, false, true, true, false, false), \ handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ dtype::Int32(), {}, name); \ checker_conv_bias( \ get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ - false, false, true), \ + true, false, false), \ handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ dtype::Int32(), {}, name); @@ -2118,7 +2118,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) { #if MEGDNN_AARCH64 || MEGDNN_ARMV7 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) { NormalRNG rng(128.f); - #define cb(name) \ checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ false, true, true), \ @@ -2188,18 +2187,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) { UniformIntRNG rng{-50, 50}; float epsilon = 0.001; -#define cb(name) \ - checker_conv_bias( \ - get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ - handle(), &rng, epsilon, \ - dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ - dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ - dtype::QuantizedS32(1.2 * 1.3), {}, name); \ - checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \ - &rng, epsilon, \ - dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ - dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ - dtype::QuantizedS32(1.2 * 1.3), {}, name); +#define cb(name) \ + checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ + true, true, false), \ + handle(), &rng, epsilon, \ + dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ + dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ + dtype::QuantizedS32(1.2 * 1.3), {}, name); \ + checker_conv_bias( \ + get_conv_bias_args({1}, 2, false, false, true, true, false), \ + handle(), &rng, epsilon, \ + dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ + dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ + dtype::QuantizedS32(1.2 * 1.3), {}, name); #if MEGDNN_AARCH64 #if __ARM_FEATURE_DOTPROD @@ -2252,18 +2252,18 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { UniformIntRNG rng{-50, 50}; float epsilon = 0.001; std::vector args_nchw44 = - get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, true, true, true, - false, false, false, false, true); + get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, true, false, true, + false, false, true, false, false); std::vector args_nchw44_1x1s2 = - get_nchw44_conv_bias_args({1}, 2, true, true, true, false, false, - false, false, true); -#define cb(name) \ - checker_conv_bias( \ - get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ - handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ - dtype::Int16{}, dtype::Int16{}, name); \ - checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \ - &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ + get_nchw44_conv_bias_args({1}, 2, true, false, true, false, false, + true, false, false); +#define cb(name) \ + checker_conv_bias( \ + get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, true), \ + handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ + dtype::Int16{}, dtype::Int16{}, name); \ + checker_conv_bias(get_conv_bias_args({1}, 2, false, false, true), \ + handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ dtype::Int16{}, dtype::Int16{}, name); #define cb_nchw44(name) \ @@ -2314,14 +2314,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCES TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPREPROCESS) { UniformIntRNG rng{-50, 50}; float epsilon = 0.001; -#define cb(name) \ - check_conv_bias_preprocess( \ - get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ - handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ - dtype::Int16{}, dtype::Int16{}, name); \ - check_conv_bias_preprocess(get_conv_bias_args({1}, 2, false, true, true), \ - handle(), &rng, epsilon, dtype::Int8{}, \ - dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, \ +#define cb(name) \ + check_conv_bias_preprocess( \ + get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, true), \ + handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ + dtype::Int16{}, dtype::Int16{}, name); \ + check_conv_bias_preprocess(get_conv_bias_args({1}, 2, false, false, true), \ + handle(), &rng, epsilon, dtype::Int8{}, \ + dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, \ name); #if MEGDNN_AARCH64 @@ -2406,7 +2406,7 @@ void checker_conv_bias_mul_int8x8x32(std::vector args, checker.set_dtype(0, dtype::QuantizedS8(2.5f)) .set_dtype(1, dtype::QuantizedS8(2.5f)) .set_dtype(2, dtype::QuantizedS32(6.25f)) - .set_dtype(4, {}) + .set_dtype(4, dtype::QuantizedS32(6.25f)) .set_rng(0, &rng) .set_rng(1, &rng) .set_rng(2, &rng) @@ -2436,7 +2436,7 @@ void checker_conv_bias_int8x8x32_preprocess(std::vector args checker.set_dtype(0, dtype::QuantizedS8(2.5f)) .set_dtype(1, dtype::QuantizedS8(2.5f)) .set_dtype(2, dtype::QuantizedS32(6.25f)) - .set_dtype(4, {}) + .set_dtype(4, dtype::QuantizedS32(6.25f)) .set_rng(0, &rng) .set_rng(1, &rng) .set_rng(2, &rng) @@ -2450,7 +2450,7 @@ void checker_conv_bias_int8x8x32_preprocess(std::vector args TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { using namespace conv_bias; std::vector args = - get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true); + get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true); #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); #if MEGDNN_AARCH64 @@ -2464,7 +2464,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) { using namespace conv_bias; std::vector args = - get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true); + get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true); #define cb(name) checker_conv_bias_int8x8x32_preprocess(args, handle(), name); #if MEGDNN_AARCH64 @@ -2478,7 +2478,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPR TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) { using namespace conv_bias; std::vector args = - get_nchw44_conv_bias_args({3, 4, 6}, 1, false, true, true); + get_nchw44_conv_bias_args({3, 4, 6}, 1, false, false, true); #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); #if MEGDNN_AARCH64 @@ -3080,9 +3080,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32_PREPROCESS) { TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { UniformIntRNG rng{-50, 50}; float epsilon = 0.001; - std::vector args = get_conv_bias_1x1_args(true, true); + std::vector args = + get_conv_bias_1x1_args(false, true, false, false); std::vector args_nchw44 = get_nchw44_conv_bias_args( - {1}, 1, true, true, true, false, false, false, false, true); + {1}, 1, true, true, true, false, false, true, false, false); #define cb(name) \ checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \ dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); @@ -3140,7 +3141,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16_PREPROCESS) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) { using namespace conv_bias; - std::vector args = get_conv_bias_1x1_args(true, true); + std::vector args = + get_conv_bias_1x1_args(false, true, false, false); #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); diff --git a/dnn/test/x86/conv_bias.cpp b/dnn/test/x86/conv_bias.cpp index 40a4efd3..df8c5404 100644 --- a/dnn/test/x86/conv_bias.cpp +++ b/dnn/test/x86/conv_bias.cpp @@ -834,6 +834,13 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X32) { //! no bias args.emplace_back(param, TensorShape{1, ic, h, w}, TensorShape{oc, ic, kernel, kernel}, TensorShape{}); + args.emplace_back(param, TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + TensorShape{1, oc, 1, 1}); + args.emplace_back(param, TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + TensorShape{1, oc, (h + 2 * p - kernel) + 1, + (h + 2 * p - kernel) + 1}); }; for (size_t kernel : {2, 3, 4, 5, 6, 7}) @@ -1384,7 +1391,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) { using namespace conv_bias; UniformIntRNG rng{-50, 50}; float epsilon = 0.001; - std::vector args = get_conv_bias_1x1_args(true, true); + std::vector args = get_conv_bias_1x1_args(false, true); #if MEGDNN_X86_WITH_MKL_DNN if (x86::is_supported(x86::SIMDType::VNNI)) { checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, @@ -1422,7 +1429,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32_PREPROCESS) { using namespace conv_bias; UniformIntRNG rng{-50, 50}; float epsilon = 0.001; - std::vector args = get_conv_bias_1x1_args(true, true); + std::vector args = get_conv_bias_1x1_args(false, true); #if MEGDNN_X86_WITH_VNNI if (x86::is_supported(x86::SIMDType::VNNI)) { checker_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{},