Browse Source

fix(arm_common/pooling): check mode in pooling algo to avoid wrong use AVERAGE_COUNT_EXCLUDE_PADDING

GitOrigin-RevId: 7a2d243db7
release-1.5
Megvii Engine Team 4 years ago
parent
commit
5868d1fe4f
2 changed files with 17 additions and 2 deletions
  1. +5
    -2
      dnn/src/arm_common/pooling/algo.cpp
  2. +12
    -0
      dnn/test/arm_common/pooling.cpp

+ 5
- 2
dnn/src/arm_common/pooling/algo.cpp View File

@@ -93,6 +93,7 @@ const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW,
}
return need_pad ? sptr_base : src;
}

bool PoolingImpl::AlgoFilterxModexStride1::usable(
const PoolingKernSizeParam& param) const {
auto SH = param.stride[0];
@@ -104,7 +105,8 @@ bool PoolingImpl::AlgoFilterxModexStride1::usable(
param.src_type.category() == DTypeCategory::QUANTIZED) &&
param.format == Param::Format::NCHW && SH == 1 && SW == 1 &&
FH == FW && (FH == 2 || FH == 3);
return avaible;
bool is_mode_ok = (param.mode == Mode::MAX || param.mode == Mode::AVERAGE);
return avaible && is_mode_ok;
}

void PoolingImpl::AlgoFilterxModexStride1::exec(
@@ -202,7 +204,8 @@ bool PoolingImpl::AlgoFilter2ModexStride2::usable(
param.src_type.category() == DTypeCategory::QUANTIZED) &&
param.format == Param::Format::NCHW && FH == FW &&
SH == SW && FH == 2 && SH == 2;
return avaible;
bool is_mode_ok = (param.mode == Mode::MAX || param.mode == Mode::AVERAGE);
return avaible && is_mode_ok;
}

void PoolingImpl::AlgoFilter2ModexStride2::exec(


+ 12
- 0
dnn/test/arm_common/pooling.cpp View File

@@ -53,6 +53,18 @@ TEST_F(ARM_COMMON, POOLING)
if (ih + p * 2 >= 5 && iw + p * 2 >= 5)
checker.set_param(param).exec({{2, 3, ih, iw}, {}});
}
for (size_t ih: {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30})
for (size_t iw: {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30})
for (size_t p: {1, 2})
{
Param param;
param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
param.window_h = param.window_w = 3;
param.stride_h = param.stride_w = 1;
param.pad_h = param.pad_w = p;
Checker<Pooling> checker(handle());
checker.set_param(param).exec({{2, 3, ih, iw}, {}});
}
// clang-format on
}



Loading…
Cancel
Save