Browse Source

perf(arm_common): optimize arm common pooling 9x9 and 13x13

GitOrigin-RevId: 33d5a62478
release-1.5
Megvii Engine Team 3 years ago
parent
commit
3afa3893d7
5 changed files with 53 additions and 8 deletions
  1. +2
    -1
      dnn/src/arm_common/intrinsic_helper.h
  2. +22
    -5
      dnn/src/arm_common/pooling/algo_fp32_pooling_nchw44.cpp
  3. +3
    -1
      dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h
  4. +25
    -0
      dnn/test/arm_common/pooling_multi_thread.cpp
  5. +1
    -1
      scripts/cmake-build/cross_build_android_arm_inference.sh

+ 2
- 1
dnn/src/arm_common/intrinsic_helper.h View File

@@ -124,4 +124,5 @@ __ai void load_helper_x(T& weight, T2 ptr, int oc_offset, XT... args) {
} // namespace } // namespace
} // namespace megdnn } // namespace megdnn
#undef __ai #undef __ai
// vim: syntax=cpp.doxygen

// vim: syntax=cpp.doxygen

+ 22
- 5
dnn/src/arm_common/pooling/algo_fp32_pooling_nchw44.cpp View File

@@ -30,10 +30,12 @@ bool PoolingImpl::AlgoFp32ModexStridexNCHW44::usable(
bool avaible = param.src_type.enumv() == DTypeEnum::Float32 && bool avaible = param.src_type.enumv() == DTypeEnum::Float32 &&
param.format == Param::Format::NCHW44 && param.format == Param::Format::NCHW44 &&
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) &&
fh == fw && sh == sw &&
(fh == 2 || fh == 3 || fh == 4 || fh == 5) &&
(sh == 1 || sh == 2);
return avaible;
fh == fw && sh == sw;
bool size_ok = ((fh == 2 || fh == 3 || fh == 4 || fh == 5) &&
(sh == 1 || sh == 2));
size_ok |= ((fh == 9 || fh == 13) && (sh == 1));

return avaible && size_ok;
} }


void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec(
@@ -94,6 +96,15 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec(
megdnn_assert(0, "invalid stride %d", sh); \ megdnn_assert(0, "invalid stride %d", sh); \
} }


#define DISPATCH_STRIDE_1(filter) \
switch (sh) { \
case 1: \
DISPATCH_MODE(filter, 1); \
break; \
default: \
megdnn_assert(0, "invalid stride %d", sh); \
}

#define DISPATCH_FILTER() \ #define DISPATCH_FILTER() \
switch (fh) { \ switch (fh) { \
case 2: \ case 2: \
@@ -108,6 +119,12 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec(
case 5: \ case 5: \
DISPATCH_STRIDE(5); \ DISPATCH_STRIDE(5); \
break; \ break; \
case 9: \
DISPATCH_STRIDE_1(9); \
break; \
case 13: \
DISPATCH_STRIDE_1(13); \
break; \
default: \ default: \
megdnn_assert(0, "invalid filter %d", fh); \ megdnn_assert(0, "invalid filter %d", fh); \
} }
@@ -123,4 +140,4 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec(
} // namespace arm_common } // namespace arm_common
} // namespace megdnn } // namespace megdnn


// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen

+ 3
- 1
dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h View File

@@ -64,6 +64,8 @@ INSTANCE_CAL(2)
INSTANCE_CAL(3) INSTANCE_CAL(3)
INSTANCE_CAL(4) INSTANCE_CAL(4)
INSTANCE_CAL(5) INSTANCE_CAL(5)
INSTANCE_CAL(9)
INSTANCE_CAL(13)


#undef INSTANCE_CAL #undef INSTANCE_CAL
#undef CALCULATE_AVG_CB #undef CALCULATE_AVG_CB
@@ -305,4 +307,4 @@ static inline void pooling_fp32_nchw44(const float32_t* src, float32_t* dst,
} // namespace arm_common } // namespace arm_common
} // namespace megdnn } // namespace megdnn


// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen

+ 25
- 0
dnn/test/arm_common/pooling_multi_thread.cpp View File

@@ -116,6 +116,31 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_NCHW44_FP32) {
} }
} }


TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W9_w13_NCHW44)
{
UniformIntRNG rng{-10, 10};
Checker<Pooling> checker(handle());
checker.set_rng(0, &rng);
// clang-format off
for (size_t ih: {20, 15})
for (size_t iw: {15, 20})
for (size_t kernel: {9, 13})
for (size_t pad: {4, 6})
for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE})
if (kernel > pad)
{
param::Pooling param;
param.mode = mode;
param.format = param::Pooling::Format::NCHW44;
param.pad_h = pad;
param.pad_w = pad;
param.stride_h = param.stride_w = 1;
param.window_h = param.window_w = kernel ;
checker.set_param(param).exec(TensorShapeArray{{2, 8, ih, iw, 4}, {}});
}
// clang-format on
}

TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44) TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44)
{ {
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};


+ 1
- 1
scripts/cmake-build/cross_build_android_arm_inference.sh View File

@@ -2,7 +2,7 @@
set -e set -e


ARCHS=("arm64-v8a" "armeabi-v7a") ARCHS=("arm64-v8a" "armeabi-v7a")
BUILD_TYPE=Release
BUILD_TYPE=RelWithDebInfo
MGE_ARMV8_2_FEATURE_FP16=OFF MGE_ARMV8_2_FEATURE_FP16=OFF
MGE_DISABLE_FLOAT16=OFF MGE_DISABLE_FLOAT16=OFF
ARCH=arm64-v8a ARCH=arm64-v8a


Loading…
Cancel
Save