From c357db013477495a49cc4c08764e9497c01088ab Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 10 Aug 2020 15:23:55 +0800 Subject: [PATCH] feat(mgb/arm_common): add 8x8x16 nchw44 max pooling GitOrigin-RevId: ed460adb7a47930546f8c9e13b729d9e1306c55f --- dnn/src/arm_common/pooling/algo.cpp | 32 +++++++++++++++++++---- dnn/src/arm_common/pooling/opr_impl.cpp | 3 ++- dnn/test/arm_common/pooling_multi_thread.cpp | 39 ++++++++++++++++------------ 3 files changed, 51 insertions(+), 23 deletions(-) diff --git a/dnn/src/arm_common/pooling/algo.cpp b/dnn/src/arm_common/pooling/algo.cpp index 4962a00c..64c9b767 100644 --- a/dnn/src/arm_common/pooling/algo.cpp +++ b/dnn/src/arm_common/pooling/algo.cpp @@ -58,7 +58,8 @@ WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) { WorkspaceBundle get_bundle_nchw44( const PoolingImpl::PoolingKernSizeParam& param) { - megdnn_assert((param.src_type.enumv() == DTypeEnum::QuantizedS8) && + megdnn_assert((param.src_type.enumv() == DTypeEnum::QuantizedS8 || + param.src_type.enumv() == DTypeEnum::Int8) && (param.format == param::Pooling::Format::NCHW44)); auto IH = param.isz[0]; auto IW = param.isz[1]; @@ -605,10 +606,15 @@ bool PoolingImpl::AlgoFilter3ModexStridexNCHW44::usable( auto FH = param.filter[0]; auto FW = param.filter[1]; - bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && + bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 || + param.src_type.enumv() == DTypeEnum::Int8) && param.format == Param::Format::NCHW44 && (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && FH == 3 && FW == 3 && SW == SH && (SH == 1 || SW == 2); + //! Int8 not support average, because its round mode is different form + //! quint8 + avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 && + param.mode == Mode::AVERAGE); return avaible; } @@ -693,10 +699,15 @@ bool PoolingImpl::AlgoFilter2ModexStridexNCHW44::usable( auto FH = param.filter[0]; auto FW = param.filter[1]; - bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && + bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 || + param.src_type.enumv() == DTypeEnum::Int8) && param.format == Param::Format::NCHW44 && (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && FH == 2 && FW == 2 && SH == SW && (SW == 1 || SW == 2); + //! Int8 not support average, because its round mode is different form + //! quint8 + avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 && + param.mode == Mode::AVERAGE); return avaible; } @@ -781,10 +792,16 @@ bool PoolingImpl::AlgoFilter4ModexStridexNCHW44::usable( auto FH = param.filter[0]; auto FW = param.filter[1]; - bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && + bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 || + param.src_type.enumv() == DTypeEnum::Int8) && param.format == Param::Format::NCHW44 && (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && FH == 4 && FW == 4 && SH == SW && (SW == 1 || SW == 2); + + //! Int8 not support average, because its round mode is different form + //! quint8 + avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 && + param.mode == Mode::AVERAGE); return avaible; } @@ -869,10 +886,15 @@ bool PoolingImpl::AlgoFilter5ModexStridexNCHW44::usable( auto FH = param.filter[0]; auto FW = param.filter[1]; - bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && + bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 || + param.src_type.enumv() == DTypeEnum::Int8) && param.format == Param::Format::NCHW44 && (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && FH == 5 && FW == 5 && SH == SW && (SW == 1 || SW == 2); + //! Int8 not support average, because its round mode is different form + //! quint8 + avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 && + param.mode == Mode::AVERAGE); return avaible; } diff --git a/dnn/src/arm_common/pooling/opr_impl.cpp b/dnn/src/arm_common/pooling/opr_impl.cpp index 46da10c8..ef7c7a6b 100644 --- a/dnn/src/arm_common/pooling/opr_impl.cpp +++ b/dnn/src/arm_common/pooling/opr_impl.cpp @@ -119,7 +119,8 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src, arm_common_workspace = ws.total_size_in_bytes() * nr_threads; } - if ((param.src_type.enumv() == DTypeEnum::QuantizedS8) && + if ((param.src_type.enumv() == DTypeEnum::QuantizedS8 || + param.src_type.enumv() == DTypeEnum::Int8) && (param.format == param::Pooling::Format::NCHW44)) { WorkspaceBundle ws = get_bundle_nchw44(param); arm_common_workspace = ws.total_size_in_bytes() * nr_threads; diff --git a/dnn/test/arm_common/pooling_multi_thread.cpp b/dnn/test/arm_common/pooling_multi_thread.cpp index aa73370c..5472e9cb 100644 --- a/dnn/test/arm_common/pooling_multi_thread.cpp +++ b/dnn/test/arm_common/pooling_multi_thread.cpp @@ -118,18 +118,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_NCHW44_FP32) { TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44) { + UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; + Checker checker(handle()); + checker.set_rng(0, &rng); // clang-format off for (size_t ih: {3, 5, 10}) for (size_t iw: {3, 5, 7, 9, 15, 20}) for (size_t ph: {0, 1, 2}) for (size_t pw: {0, 1, 2}) for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) + for(auto data_type: SmallVector{dtype::QuantizedS8(1.1f), dtype::Int8()}) if (ih+2*ph >= 3 && iw+2*pw >= 3) { - UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; - Checker checker(handle()); - checker.set_dtype(0, dtype::QuantizedS8(1.1f)); - checker.set_rng(0,&rng); + checker.set_dtype(0, data_type); param::Pooling param; param.mode = mode; @@ -149,18 +150,20 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44) TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W2x2_NCHW44) { + UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; + Checker checker(handle()); + checker.set_rng(0, &rng); // clang-format off for (size_t ih: {2, 5, 10, 17}) for (size_t iw: {2, 6, 8, 16, 26}) for (size_t ph: {0, 1}) for (size_t pw: {0, 1}) - for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) + for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) + for(auto data_type: SmallVector{dtype::QuantizedS8(1.1f), dtype::Int8()}) if (ih+2*ph >= 2 && iw+2*pw >= 2) { - UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; - Checker checker(handle()); - checker.set_dtype(0, dtype::QuantizedS8(1.1f)); - checker.set_rng(0,&rng); + checker.set_dtype(0, data_type); + checker.set_dtype(1, data_type); param::Pooling param; param.mode = mode; @@ -179,18 +182,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W2x2_NCHW44) TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W4x4_NCHW44) { + UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; + Checker checker(handle()); + checker.set_rng(0, &rng); // clang-format off for (size_t ih: {4, 10, 18, 25, 30}) for (size_t iw: {4, 12, 17, 20, 25}) for (size_t ph: {0, 1, 2}) for (size_t pw: {0, 1, 2}) + for(auto data_type: SmallVector{dtype::QuantizedS8(1.1f), dtype::Int8()}) for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) if (ih+2*ph >= 4 && iw+2*pw >= 4) { - UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; - Checker checker(handle()); - checker.set_dtype(0, dtype::QuantizedS8(1.1f)); - checker.set_rng(0,&rng); + checker.set_dtype(0, data_type); param::Pooling param; param.mode = mode; @@ -208,18 +212,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W4x4_NCHW44) } TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W5x5_NCHW44) { + UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; + Checker checker(handle()); + checker.set_rng(0, &rng); // clang-format off for (size_t ih: {5, 9, 19, 20, 39}) for (size_t iw: {5, 12, 23, 27, 39}) for (size_t ph: {0, 1, 2}) for (size_t pw: {0, 1, 2}) + for(auto data_type: SmallVector{dtype::QuantizedS8(1.1f), dtype::Int8()}) for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) if (ih+2*ph >= 5 && iw+2*pw >= 5) { - UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; - Checker checker(handle()); - checker.set_dtype(0, dtype::QuantizedS8(1.1f)); - checker.set_rng(0,&rng); + checker.set_dtype(0, data_type); param::Pooling param; param.mode = mode;