GitOrigin-RevId: ed460adb7a
tags/v1.0.0-rc1
@@ -58,7 +58,8 @@ WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) { | |||
WorkspaceBundle get_bundle_nchw44( | |||
const PoolingImpl::PoolingKernSizeParam& param) { | |||
megdnn_assert((param.src_type.enumv() == DTypeEnum::QuantizedS8) && | |||
megdnn_assert((param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||
param.src_type.enumv() == DTypeEnum::Int8) && | |||
(param.format == param::Pooling::Format::NCHW44)); | |||
auto IH = param.isz[0]; | |||
auto IW = param.isz[1]; | |||
@@ -605,10 +606,15 @@ bool PoolingImpl::AlgoFilter3ModexStridexNCHW44::usable( | |||
auto FH = param.filter[0]; | |||
auto FW = param.filter[1]; | |||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||
param.src_type.enumv() == DTypeEnum::Int8) && | |||
param.format == Param::Format::NCHW44 && | |||
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | |||
FH == 3 && FW == 3 && SW == SH && (SH == 1 || SW == 2); | |||
//! Int8 not support average, because its round mode is different form | |||
//! quint8 | |||
avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 && | |||
param.mode == Mode::AVERAGE); | |||
return avaible; | |||
} | |||
@@ -693,10 +699,15 @@ bool PoolingImpl::AlgoFilter2ModexStridexNCHW44::usable( | |||
auto FH = param.filter[0]; | |||
auto FW = param.filter[1]; | |||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||
param.src_type.enumv() == DTypeEnum::Int8) && | |||
param.format == Param::Format::NCHW44 && | |||
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | |||
FH == 2 && FW == 2 && SH == SW && (SW == 1 || SW == 2); | |||
//! Int8 not support average, because its round mode is different form | |||
//! quint8 | |||
avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 && | |||
param.mode == Mode::AVERAGE); | |||
return avaible; | |||
} | |||
@@ -781,10 +792,16 @@ bool PoolingImpl::AlgoFilter4ModexStridexNCHW44::usable( | |||
auto FH = param.filter[0]; | |||
auto FW = param.filter[1]; | |||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||
param.src_type.enumv() == DTypeEnum::Int8) && | |||
param.format == Param::Format::NCHW44 && | |||
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | |||
FH == 4 && FW == 4 && SH == SW && (SW == 1 || SW == 2); | |||
//! Int8 not support average, because its round mode is different form | |||
//! quint8 | |||
avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 && | |||
param.mode == Mode::AVERAGE); | |||
return avaible; | |||
} | |||
@@ -869,10 +886,15 @@ bool PoolingImpl::AlgoFilter5ModexStridexNCHW44::usable( | |||
auto FH = param.filter[0]; | |||
auto FW = param.filter[1]; | |||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||
param.src_type.enumv() == DTypeEnum::Int8) && | |||
param.format == Param::Format::NCHW44 && | |||
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | |||
FH == 5 && FW == 5 && SH == SW && (SW == 1 || SW == 2); | |||
//! Int8 not support average, because its round mode is different form | |||
//! quint8 | |||
avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 && | |||
param.mode == Mode::AVERAGE); | |||
return avaible; | |||
} | |||
@@ -119,7 +119,8 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src, | |||
arm_common_workspace = ws.total_size_in_bytes() * nr_threads; | |||
} | |||
if ((param.src_type.enumv() == DTypeEnum::QuantizedS8) && | |||
if ((param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||
param.src_type.enumv() == DTypeEnum::Int8) && | |||
(param.format == param::Pooling::Format::NCHW44)) { | |||
WorkspaceBundle ws = get_bundle_nchw44(param); | |||
arm_common_workspace = ws.total_size_in_bytes() * nr_threads; | |||
@@ -118,18 +118,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_NCHW44_FP32) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44) | |||
{ | |||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||
Checker<Pooling> checker(handle()); | |||
checker.set_rng(0, &rng); | |||
// clang-format off | |||
for (size_t ih: {3, 5, 10}) | |||
for (size_t iw: {3, 5, 7, 9, 15, 20}) | |||
for (size_t ph: {0, 1, 2}) | |||
for (size_t pw: {0, 1, 2}) | |||
for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) | |||
for(auto data_type: SmallVector<DType>{dtype::QuantizedS8(1.1f), dtype::Int8()}) | |||
if (ih+2*ph >= 3 && iw+2*pw >= 3) | |||
{ | |||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||
Checker<Pooling> checker(handle()); | |||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||
checker.set_rng(0,&rng); | |||
checker.set_dtype(0, data_type); | |||
param::Pooling param; | |||
param.mode = mode; | |||
@@ -149,18 +150,20 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44) | |||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W2x2_NCHW44) | |||
{ | |||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||
Checker<Pooling> checker(handle()); | |||
checker.set_rng(0, &rng); | |||
// clang-format off | |||
for (size_t ih: {2, 5, 10, 17}) | |||
for (size_t iw: {2, 6, 8, 16, 26}) | |||
for (size_t ph: {0, 1}) | |||
for (size_t pw: {0, 1}) | |||
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) | |||
for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) | |||
for(auto data_type: SmallVector<DType>{dtype::QuantizedS8(1.1f), dtype::Int8()}) | |||
if (ih+2*ph >= 2 && iw+2*pw >= 2) | |||
{ | |||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||
Checker<Pooling> checker(handle()); | |||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||
checker.set_rng(0,&rng); | |||
checker.set_dtype(0, data_type); | |||
checker.set_dtype(1, data_type); | |||
param::Pooling param; | |||
param.mode = mode; | |||
@@ -179,18 +182,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W2x2_NCHW44) | |||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W4x4_NCHW44) | |||
{ | |||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||
Checker<Pooling> checker(handle()); | |||
checker.set_rng(0, &rng); | |||
// clang-format off | |||
for (size_t ih: {4, 10, 18, 25, 30}) | |||
for (size_t iw: {4, 12, 17, 20, 25}) | |||
for (size_t ph: {0, 1, 2}) | |||
for (size_t pw: {0, 1, 2}) | |||
for(auto data_type: SmallVector<DType>{dtype::QuantizedS8(1.1f), dtype::Int8()}) | |||
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) | |||
if (ih+2*ph >= 4 && iw+2*pw >= 4) | |||
{ | |||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||
Checker<Pooling> checker(handle()); | |||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||
checker.set_rng(0,&rng); | |||
checker.set_dtype(0, data_type); | |||
param::Pooling param; | |||
param.mode = mode; | |||
@@ -208,18 +212,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W4x4_NCHW44) | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W5x5_NCHW44) | |||
{ | |||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||
Checker<Pooling> checker(handle()); | |||
checker.set_rng(0, &rng); | |||
// clang-format off | |||
for (size_t ih: {5, 9, 19, 20, 39}) | |||
for (size_t iw: {5, 12, 23, 27, 39}) | |||
for (size_t ph: {0, 1, 2}) | |||
for (size_t pw: {0, 1, 2}) | |||
for(auto data_type: SmallVector<DType>{dtype::QuantizedS8(1.1f), dtype::Int8()}) | |||
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) | |||
if (ih+2*ph >= 5 && iw+2*pw >= 5) | |||
{ | |||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||
Checker<Pooling> checker(handle()); | |||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||
checker.set_rng(0,&rng); | |||
checker.set_dtype(0, data_type); | |||
param::Pooling param; | |||
param.mode = mode; | |||