GitOrigin-RevId: b60a7b6cf8
release-1.7
@@ -47,6 +47,8 @@ struct MaxPooler { | |||
} | |||
}; | |||
//! WARNING:for Integer, if sum ctype_ set incorrectly may cause overflow such as | |||
//! (stype_=ctype_ =int8_t) | |||
template <typename stype_, typename ctype_> | |||
struct MeanIncludePoolerBase { | |||
using stype = stype_; | |||
@@ -65,6 +67,7 @@ struct MeanIncludePooler : public MeanIncludePoolerBase<T, T> { | |||
ctype get_ans() { return this->sum / this->count; } | |||
}; | |||
//! WARNING: the result is truncated | |||
template <> | |||
struct MeanIncludePooler<int8_t> : public MeanIncludePoolerBase<int8_t, int32_t> { | |||
using MeanIncludePoolerBase::MeanIncludePoolerBase; | |||
@@ -74,7 +77,9 @@ struct MeanIncludePooler<int8_t> : public MeanIncludePoolerBase<int8_t, int32_t> | |||
std::numeric_limits<int8_t>::max()); | |||
} | |||
}; | |||
/*! | |||
* average pooling with zero point for quint8 | |||
/*/ | |||
template <> | |||
struct MeanIncludePooler<dt_quint8> { | |||
int32_t sum; | |||
@@ -107,7 +112,7 @@ struct MeanIncludePooler<dt_quint8> { | |||
/*! | |||
* \brief Average pooling operation within a single window. | |||
* Works on integers. Rounds toward +INF. | |||
* Works on integers. Rounds toward nearest Integer | |||
* \tparam T input data type | |||
* \tparam U convert input data type to U before accumulating | |||
* \tparam ICType data type for intermediate result | |||
@@ -228,10 +233,11 @@ struct MeanExcludePooler { | |||
/*! | |||
* \brief Average pooling operation within a single window. | |||
* Works on integers. Rounds toward +INF. | |||
* Works on integers. Rounds toward nearest Integer | |||
* \tparam T input data type | |||
* \tparam U convert input data type to U before accumulating | |||
* \tparam ICType data type for intermediate result | |||
* WARNING:for Integer, if type U or ICType set incorrectly may cause overflow | |||
*/ | |||
template <typename T, typename U, typename ICType = U> | |||
struct MeanExcludeRoundedPooler { | |||
@@ -257,6 +263,10 @@ struct MeanExcludeRoundedPooler { | |||
}; | |||
template <> | |||
struct MeanExcludePooler<int8_t> : MeanExcludeRoundedPooler<int8_t, int8_t, int32_t> { | |||
using MeanExcludeRoundedPooler::MeanExcludeRoundedPooler; | |||
}; | |||
template <> | |||
struct MeanExcludePooler<dt_quint8> | |||
: MeanExcludeRoundedPooler<dt_quint8, uint8_t, uint32_t> { | |||
using MeanExcludeRoundedPooler::MeanExcludeRoundedPooler; | |||
@@ -100,4 +100,35 @@ TEST_F(NAIVE, POOLING_QUANTIZED_Q4) { | |||
TensorValueLowbit4({1, 1, 2, 2}, u4_dt, u8_avg_exclu_dst_vec)}); | |||
} | |||
} | |||
TEST_F(NAIVE, POOLING_INT_AVERAGE) { | |||
using Mode = Pooling::Param::Mode; | |||
Checker<Pooling> checker(handle(), /* check_dispatch */ false); | |||
auto dt = dtype::Int8(); | |||
Pooling::Param param = {Mode::AVERAGE, 0, 0, 1, 1, 2, 2}; | |||
Testcase input_positive{ | |||
TensorValue( | |||
{1, 1, 3, 3}, dt, {127, 127, 127, 127, 127, 127, 127, 127, 127}), | |||
{}}; | |||
Testcase input_negative{ | |||
TensorValue( | |||
{1, 1, 3, 3}, dt, | |||
{-127, -127, -127, -127, -127, -127, -127, -127, -127}), | |||
{}}; | |||
checker.set_param(param).exect( | |||
input_positive, | |||
Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {127, 127, 127, 127})}); | |||
checker.set_param(param).exect( | |||
input_negative, | |||
Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {-127, -127, -127, -127})}); | |||
param = {Mode::AVERAGE_COUNT_EXCLUDE_PADDING, 0, 0, 1, 1, 2, 2}; | |||
checker.set_param(param).exect( | |||
input_positive, | |||
Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {127, 127, 127, 127})}); | |||
checker.set_param(param).exect( | |||
input_negative, | |||
Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {-127, -127, -127, -127})}); | |||
} | |||
// vim: syntax=cpp.doxygen |