GitOrigin-RevId: ed460adb7a
tags/v1.0.0-rc1
@@ -58,7 +58,8 @@ WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) { | |||||
WorkspaceBundle get_bundle_nchw44( | WorkspaceBundle get_bundle_nchw44( | ||||
const PoolingImpl::PoolingKernSizeParam& param) { | const PoolingImpl::PoolingKernSizeParam& param) { | ||||
megdnn_assert((param.src_type.enumv() == DTypeEnum::QuantizedS8) && | |||||
megdnn_assert((param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
param.src_type.enumv() == DTypeEnum::Int8) && | |||||
(param.format == param::Pooling::Format::NCHW44)); | (param.format == param::Pooling::Format::NCHW44)); | ||||
auto IH = param.isz[0]; | auto IH = param.isz[0]; | ||||
auto IW = param.isz[1]; | auto IW = param.isz[1]; | ||||
@@ -605,10 +606,15 @@ bool PoolingImpl::AlgoFilter3ModexStridexNCHW44::usable( | |||||
auto FH = param.filter[0]; | auto FH = param.filter[0]; | ||||
auto FW = param.filter[1]; | auto FW = param.filter[1]; | ||||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
param.src_type.enumv() == DTypeEnum::Int8) && | |||||
param.format == Param::Format::NCHW44 && | param.format == Param::Format::NCHW44 && | ||||
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | ||||
FH == 3 && FW == 3 && SW == SH && (SH == 1 || SW == 2); | FH == 3 && FW == 3 && SW == SH && (SH == 1 || SW == 2); | ||||
//! Int8 not support average, because its round mode is different form | |||||
//! quint8 | |||||
avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 && | |||||
param.mode == Mode::AVERAGE); | |||||
return avaible; | return avaible; | ||||
} | } | ||||
@@ -693,10 +699,15 @@ bool PoolingImpl::AlgoFilter2ModexStridexNCHW44::usable( | |||||
auto FH = param.filter[0]; | auto FH = param.filter[0]; | ||||
auto FW = param.filter[1]; | auto FW = param.filter[1]; | ||||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
param.src_type.enumv() == DTypeEnum::Int8) && | |||||
param.format == Param::Format::NCHW44 && | param.format == Param::Format::NCHW44 && | ||||
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | ||||
FH == 2 && FW == 2 && SH == SW && (SW == 1 || SW == 2); | FH == 2 && FW == 2 && SH == SW && (SW == 1 || SW == 2); | ||||
//! Int8 not support average, because its round mode is different form | |||||
//! quint8 | |||||
avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 && | |||||
param.mode == Mode::AVERAGE); | |||||
return avaible; | return avaible; | ||||
} | } | ||||
@@ -781,10 +792,16 @@ bool PoolingImpl::AlgoFilter4ModexStridexNCHW44::usable( | |||||
auto FH = param.filter[0]; | auto FH = param.filter[0]; | ||||
auto FW = param.filter[1]; | auto FW = param.filter[1]; | ||||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
param.src_type.enumv() == DTypeEnum::Int8) && | |||||
param.format == Param::Format::NCHW44 && | param.format == Param::Format::NCHW44 && | ||||
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | ||||
FH == 4 && FW == 4 && SH == SW && (SW == 1 || SW == 2); | FH == 4 && FW == 4 && SH == SW && (SW == 1 || SW == 2); | ||||
//! Int8 not support average, because its round mode is different form | |||||
//! quint8 | |||||
avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 && | |||||
param.mode == Mode::AVERAGE); | |||||
return avaible; | return avaible; | ||||
} | } | ||||
@@ -869,10 +886,15 @@ bool PoolingImpl::AlgoFilter5ModexStridexNCHW44::usable( | |||||
auto FH = param.filter[0]; | auto FH = param.filter[0]; | ||||
auto FW = param.filter[1]; | auto FW = param.filter[1]; | ||||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
param.src_type.enumv() == DTypeEnum::Int8) && | |||||
param.format == Param::Format::NCHW44 && | param.format == Param::Format::NCHW44 && | ||||
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | ||||
FH == 5 && FW == 5 && SH == SW && (SW == 1 || SW == 2); | FH == 5 && FW == 5 && SH == SW && (SW == 1 || SW == 2); | ||||
//! Int8 not support average, because its round mode is different form | |||||
//! quint8 | |||||
avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 && | |||||
param.mode == Mode::AVERAGE); | |||||
return avaible; | return avaible; | ||||
} | } | ||||
@@ -119,7 +119,8 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
arm_common_workspace = ws.total_size_in_bytes() * nr_threads; | arm_common_workspace = ws.total_size_in_bytes() * nr_threads; | ||||
} | } | ||||
if ((param.src_type.enumv() == DTypeEnum::QuantizedS8) && | |||||
if ((param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
param.src_type.enumv() == DTypeEnum::Int8) && | |||||
(param.format == param::Pooling::Format::NCHW44)) { | (param.format == param::Pooling::Format::NCHW44)) { | ||||
WorkspaceBundle ws = get_bundle_nchw44(param); | WorkspaceBundle ws = get_bundle_nchw44(param); | ||||
arm_common_workspace = ws.total_size_in_bytes() * nr_threads; | arm_common_workspace = ws.total_size_in_bytes() * nr_threads; | ||||
@@ -118,18 +118,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_NCHW44_FP32) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44) | TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44) | ||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_rng(0, &rng); | |||||
// clang-format off | // clang-format off | ||||
for (size_t ih: {3, 5, 10}) | for (size_t ih: {3, 5, 10}) | ||||
for (size_t iw: {3, 5, 7, 9, 15, 20}) | for (size_t iw: {3, 5, 7, 9, 15, 20}) | ||||
for (size_t ph: {0, 1, 2}) | for (size_t ph: {0, 1, 2}) | ||||
for (size_t pw: {0, 1, 2}) | for (size_t pw: {0, 1, 2}) | ||||
for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) | for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) | ||||
for(auto data_type: SmallVector<DType>{dtype::QuantizedS8(1.1f), dtype::Int8()}) | |||||
if (ih+2*ph >= 3 && iw+2*pw >= 3) | if (ih+2*ph >= 3 && iw+2*pw >= 3) | ||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||||
checker.set_rng(0,&rng); | |||||
checker.set_dtype(0, data_type); | |||||
param::Pooling param; | param::Pooling param; | ||||
param.mode = mode; | param.mode = mode; | ||||
@@ -149,18 +150,20 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44) | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W2x2_NCHW44) | TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W2x2_NCHW44) | ||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_rng(0, &rng); | |||||
// clang-format off | // clang-format off | ||||
for (size_t ih: {2, 5, 10, 17}) | for (size_t ih: {2, 5, 10, 17}) | ||||
for (size_t iw: {2, 6, 8, 16, 26}) | for (size_t iw: {2, 6, 8, 16, 26}) | ||||
for (size_t ph: {0, 1}) | for (size_t ph: {0, 1}) | ||||
for (size_t pw: {0, 1}) | for (size_t pw: {0, 1}) | ||||
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) | |||||
for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) | |||||
for(auto data_type: SmallVector<DType>{dtype::QuantizedS8(1.1f), dtype::Int8()}) | |||||
if (ih+2*ph >= 2 && iw+2*pw >= 2) | if (ih+2*ph >= 2 && iw+2*pw >= 2) | ||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||||
checker.set_rng(0,&rng); | |||||
checker.set_dtype(0, data_type); | |||||
checker.set_dtype(1, data_type); | |||||
param::Pooling param; | param::Pooling param; | ||||
param.mode = mode; | param.mode = mode; | ||||
@@ -179,18 +182,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W2x2_NCHW44) | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W4x4_NCHW44) | TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W4x4_NCHW44) | ||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_rng(0, &rng); | |||||
// clang-format off | // clang-format off | ||||
for (size_t ih: {4, 10, 18, 25, 30}) | for (size_t ih: {4, 10, 18, 25, 30}) | ||||
for (size_t iw: {4, 12, 17, 20, 25}) | for (size_t iw: {4, 12, 17, 20, 25}) | ||||
for (size_t ph: {0, 1, 2}) | for (size_t ph: {0, 1, 2}) | ||||
for (size_t pw: {0, 1, 2}) | for (size_t pw: {0, 1, 2}) | ||||
for(auto data_type: SmallVector<DType>{dtype::QuantizedS8(1.1f), dtype::Int8()}) | |||||
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) | for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) | ||||
if (ih+2*ph >= 4 && iw+2*pw >= 4) | if (ih+2*ph >= 4 && iw+2*pw >= 4) | ||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||||
checker.set_rng(0,&rng); | |||||
checker.set_dtype(0, data_type); | |||||
param::Pooling param; | param::Pooling param; | ||||
param.mode = mode; | param.mode = mode; | ||||
@@ -208,18 +212,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W4x4_NCHW44) | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W5x5_NCHW44) | TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W5x5_NCHW44) | ||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_rng(0, &rng); | |||||
// clang-format off | // clang-format off | ||||
for (size_t ih: {5, 9, 19, 20, 39}) | for (size_t ih: {5, 9, 19, 20, 39}) | ||||
for (size_t iw: {5, 12, 23, 27, 39}) | for (size_t iw: {5, 12, 23, 27, 39}) | ||||
for (size_t ph: {0, 1, 2}) | for (size_t ph: {0, 1, 2}) | ||||
for (size_t pw: {0, 1, 2}) | for (size_t pw: {0, 1, 2}) | ||||
for(auto data_type: SmallVector<DType>{dtype::QuantizedS8(1.1f), dtype::Int8()}) | |||||
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) | for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) | ||||
if (ih+2*ph >= 5 && iw+2*pw >= 5) | if (ih+2*ph >= 5 && iw+2*pw >= 5) | ||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||||
checker.set_rng(0,&rng); | |||||
checker.set_dtype(0, data_type); | |||||
param::Pooling param; | param::Pooling param; | ||||
param.mode = mode; | param.mode = mode; | ||||