From 2318ea3f150dc965245be2d118de8212a237509e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 19 Nov 2021 19:40:14 +0800 Subject: [PATCH] fix(dnn): fix naive average pooling overflow bug for int8 type GitOrigin-RevId: b60a7b6cf824d104efbc7c22020220898cdff4fe --- dnn/src/naive/pooling/opr_impl.cpp | 16 +++++++++++++--- dnn/test/naive/pooling.cpp | 31 +++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/dnn/src/naive/pooling/opr_impl.cpp b/dnn/src/naive/pooling/opr_impl.cpp index c655893f..4a4d8b14 100644 --- a/dnn/src/naive/pooling/opr_impl.cpp +++ b/dnn/src/naive/pooling/opr_impl.cpp @@ -47,6 +47,8 @@ struct MaxPooler { } }; +//! WARNING:for Integer, if sum ctype_ set incorrectly may cause overflow such as +//! (stype_=ctype_ =int8_t) template struct MeanIncludePoolerBase { using stype = stype_; @@ -65,6 +67,7 @@ struct MeanIncludePooler : public MeanIncludePoolerBase { ctype get_ans() { return this->sum / this->count; } }; +//! WARNING: the result is truncated template <> struct MeanIncludePooler : public MeanIncludePoolerBase { using MeanIncludePoolerBase::MeanIncludePoolerBase; @@ -74,7 +77,9 @@ struct MeanIncludePooler : public MeanIncludePoolerBase std::numeric_limits::max()); } }; - +/*! + * average pooling with zero point for quint8 +/*/ template <> struct MeanIncludePooler { int32_t sum; @@ -107,7 +112,7 @@ struct MeanIncludePooler { /*! * \brief Average pooling operation within a single window. - * Works on integers. Rounds toward +INF. + * Works on integers. Rounds toward nearest Integer * \tparam T input data type * \tparam U convert input data type to U before accumulating * \tparam ICType data type for intermediate result @@ -228,10 +233,11 @@ struct MeanExcludePooler { /*! * \brief Average pooling operation within a single window. - * Works on integers. Rounds toward +INF. + * Works on integers. Rounds toward nearest Integer * \tparam T input data type * \tparam U convert input data type to U before accumulating * \tparam ICType data type for intermediate result + * WARNING:for Integer, if type U or ICType set incorrectly may cause overflow */ template struct MeanExcludeRoundedPooler { @@ -257,6 +263,10 @@ struct MeanExcludeRoundedPooler { }; template <> +struct MeanExcludePooler : MeanExcludeRoundedPooler { + using MeanExcludeRoundedPooler::MeanExcludeRoundedPooler; +}; +template <> struct MeanExcludePooler : MeanExcludeRoundedPooler { using MeanExcludeRoundedPooler::MeanExcludeRoundedPooler; diff --git a/dnn/test/naive/pooling.cpp b/dnn/test/naive/pooling.cpp index c5a445c4..f0988fea 100644 --- a/dnn/test/naive/pooling.cpp +++ b/dnn/test/naive/pooling.cpp @@ -100,4 +100,35 @@ TEST_F(NAIVE, POOLING_QUANTIZED_Q4) { TensorValueLowbit4({1, 1, 2, 2}, u4_dt, u8_avg_exclu_dst_vec)}); } } + +TEST_F(NAIVE, POOLING_INT_AVERAGE) { + using Mode = Pooling::Param::Mode; + + Checker checker(handle(), /* check_dispatch */ false); + auto dt = dtype::Int8(); + Pooling::Param param = {Mode::AVERAGE, 0, 0, 1, 1, 2, 2}; + Testcase input_positive{ + TensorValue( + {1, 1, 3, 3}, dt, {127, 127, 127, 127, 127, 127, 127, 127, 127}), + {}}; + Testcase input_negative{ + TensorValue( + {1, 1, 3, 3}, dt, + {-127, -127, -127, -127, -127, -127, -127, -127, -127}), + {}}; + checker.set_param(param).exect( + input_positive, + Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {127, 127, 127, 127})}); + checker.set_param(param).exect( + input_negative, + Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {-127, -127, -127, -127})}); + param = {Mode::AVERAGE_COUNT_EXCLUDE_PADDING, 0, 0, 1, 1, 2, 2}; + checker.set_param(param).exect( + input_positive, + Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {127, 127, 127, 127})}); + checker.set_param(param).exect( + input_negative, + Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {-127, -127, -127, -127})}); +} + // vim: syntax=cpp.doxygen