Browse Source

feat(mgb/arm_common): add 8x8x16 nchw44 max pooling

GitOrigin-RevId: ed460adb7a
tags/v1.0.0-rc1
Megvii Engine Team Xinran Xu 4 years ago
parent
commit
c357db0134
3 changed files with 51 additions and 23 deletions
  1. +27
    -5
      dnn/src/arm_common/pooling/algo.cpp
  2. +2
    -1
      dnn/src/arm_common/pooling/opr_impl.cpp
  3. +22
    -17
      dnn/test/arm_common/pooling_multi_thread.cpp

+ 27
- 5
dnn/src/arm_common/pooling/algo.cpp View File

@@ -58,7 +58,8 @@ WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) {

WorkspaceBundle get_bundle_nchw44(
const PoolingImpl::PoolingKernSizeParam& param) {
megdnn_assert((param.src_type.enumv() == DTypeEnum::QuantizedS8) &&
megdnn_assert((param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
(param.format == param::Pooling::Format::NCHW44));
auto IH = param.isz[0];
auto IW = param.isz[1];
@@ -605,10 +606,15 @@ bool PoolingImpl::AlgoFilter3ModexStridexNCHW44::usable(
auto FH = param.filter[0];
auto FW = param.filter[1];

bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
param.format == Param::Format::NCHW44 &&
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) &&
FH == 3 && FW == 3 && SW == SH && (SH == 1 || SW == 2);
//! Int8 not support average, because its round mode is different form
//! quint8
avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 &&
param.mode == Mode::AVERAGE);
return avaible;
}

@@ -693,10 +699,15 @@ bool PoolingImpl::AlgoFilter2ModexStridexNCHW44::usable(
auto FH = param.filter[0];
auto FW = param.filter[1];

bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
param.format == Param::Format::NCHW44 &&
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) &&
FH == 2 && FW == 2 && SH == SW && (SW == 1 || SW == 2);
//! Int8 not support average, because its round mode is different form
//! quint8
avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 &&
param.mode == Mode::AVERAGE);
return avaible;
}

@@ -781,10 +792,16 @@ bool PoolingImpl::AlgoFilter4ModexStridexNCHW44::usable(
auto FH = param.filter[0];
auto FW = param.filter[1];

bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
param.format == Param::Format::NCHW44 &&
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) &&
FH == 4 && FW == 4 && SH == SW && (SW == 1 || SW == 2);

//! Int8 not support average, because its round mode is different form
//! quint8
avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 &&
param.mode == Mode::AVERAGE);
return avaible;
}

@@ -869,10 +886,15 @@ bool PoolingImpl::AlgoFilter5ModexStridexNCHW44::usable(
auto FH = param.filter[0];
auto FW = param.filter[1];

bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
param.format == Param::Format::NCHW44 &&
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) &&
FH == 5 && FW == 5 && SH == SW && (SW == 1 || SW == 2);
//! Int8 not support average, because its round mode is different form
//! quint8
avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 &&
param.mode == Mode::AVERAGE);
return avaible;
}



+ 2
- 1
dnn/src/arm_common/pooling/opr_impl.cpp View File

@@ -119,7 +119,8 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src,
arm_common_workspace = ws.total_size_in_bytes() * nr_threads;
}

if ((param.src_type.enumv() == DTypeEnum::QuantizedS8) &&
if ((param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
(param.format == param::Pooling::Format::NCHW44)) {
WorkspaceBundle ws = get_bundle_nchw44(param);
arm_common_workspace = ws.total_size_in_bytes() * nr_threads;


+ 22
- 17
dnn/test/arm_common/pooling_multi_thread.cpp View File

@@ -118,18 +118,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_NCHW44_FP32) {

TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_rng(0, &rng);
// clang-format off
for (size_t ih: {3, 5, 10})
for (size_t iw: {3, 5, 7, 9, 15, 20})
for (size_t ph: {0, 1, 2})
for (size_t pw: {0, 1, 2})
for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE})
for(auto data_type: SmallVector<DType>{dtype::QuantizedS8(1.1f), dtype::Int8()})
if (ih+2*ph >= 3 && iw+2*pw >= 3)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_dtype(0, dtype::QuantizedS8(1.1f));
checker.set_rng(0,&rng);
checker.set_dtype(0, data_type);

param::Pooling param;
param.mode = mode;
@@ -149,18 +150,20 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44)

TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W2x2_NCHW44)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_rng(0, &rng);
// clang-format off
for (size_t ih: {2, 5, 10, 17})
for (size_t iw: {2, 6, 8, 16, 26})
for (size_t ph: {0, 1})
for (size_t pw: {0, 1})
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE})
for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE})
for(auto data_type: SmallVector<DType>{dtype::QuantizedS8(1.1f), dtype::Int8()})
if (ih+2*ph >= 2 && iw+2*pw >= 2)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_dtype(0, dtype::QuantizedS8(1.1f));
checker.set_rng(0,&rng);
checker.set_dtype(0, data_type);
checker.set_dtype(1, data_type);

param::Pooling param;
param.mode = mode;
@@ -179,18 +182,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W2x2_NCHW44)

TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W4x4_NCHW44)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_rng(0, &rng);
// clang-format off
for (size_t ih: {4, 10, 18, 25, 30})
for (size_t iw: {4, 12, 17, 20, 25})
for (size_t ph: {0, 1, 2})
for (size_t pw: {0, 1, 2})
for(auto data_type: SmallVector<DType>{dtype::QuantizedS8(1.1f), dtype::Int8()})
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE})
if (ih+2*ph >= 4 && iw+2*pw >= 4)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_dtype(0, dtype::QuantizedS8(1.1f));
checker.set_rng(0,&rng);
checker.set_dtype(0, data_type);

param::Pooling param;
param.mode = mode;
@@ -208,18 +212,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W4x4_NCHW44)
}
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W5x5_NCHW44)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_rng(0, &rng);
// clang-format off
for (size_t ih: {5, 9, 19, 20, 39})
for (size_t iw: {5, 12, 23, 27, 39})
for (size_t ph: {0, 1, 2})
for (size_t pw: {0, 1, 2})
for(auto data_type: SmallVector<DType>{dtype::QuantizedS8(1.1f), dtype::Int8()})
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE})
if (ih+2*ph >= 5 && iw+2*pw >= 5)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_dtype(0, dtype::QuantizedS8(1.1f));
checker.set_rng(0,&rng);
checker.set_dtype(0, data_type);

param::Pooling param;
param.mode = mode;


Loading…
Cancel
Save