GitOrigin-RevId: 33d5a62478
release-1.5
@@ -124,4 +124,5 @@ __ai void load_helper_x(T& weight, T2 ptr, int oc_offset, XT... args) { | |||||
} // namespace | } // namespace | ||||
} // namespace megdnn | } // namespace megdnn | ||||
#undef __ai | #undef __ai | ||||
// vim: syntax=cpp.doxygen | |||||
// vim: syntax=cpp.doxygen |
@@ -30,10 +30,12 @@ bool PoolingImpl::AlgoFp32ModexStridexNCHW44::usable( | |||||
bool avaible = param.src_type.enumv() == DTypeEnum::Float32 && | bool avaible = param.src_type.enumv() == DTypeEnum::Float32 && | ||||
param.format == Param::Format::NCHW44 && | param.format == Param::Format::NCHW44 && | ||||
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | ||||
fh == fw && sh == sw && | |||||
(fh == 2 || fh == 3 || fh == 4 || fh == 5) && | |||||
(sh == 1 || sh == 2); | |||||
return avaible; | |||||
fh == fw && sh == sw; | |||||
bool size_ok = ((fh == 2 || fh == 3 || fh == 4 || fh == 5) && | |||||
(sh == 1 || sh == 2)); | |||||
size_ok |= ((fh == 9 || fh == 13) && (sh == 1)); | |||||
return avaible && size_ok; | |||||
} | } | ||||
void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( | void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( | ||||
@@ -94,6 +96,15 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( | |||||
megdnn_assert(0, "invalid stride %d", sh); \ | megdnn_assert(0, "invalid stride %d", sh); \ | ||||
} | } | ||||
#define DISPATCH_STRIDE_1(filter) \ | |||||
switch (sh) { \ | |||||
case 1: \ | |||||
DISPATCH_MODE(filter, 1); \ | |||||
break; \ | |||||
default: \ | |||||
megdnn_assert(0, "invalid stride %d", sh); \ | |||||
} | |||||
#define DISPATCH_FILTER() \ | #define DISPATCH_FILTER() \ | ||||
switch (fh) { \ | switch (fh) { \ | ||||
case 2: \ | case 2: \ | ||||
@@ -108,6 +119,12 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( | |||||
case 5: \ | case 5: \ | ||||
DISPATCH_STRIDE(5); \ | DISPATCH_STRIDE(5); \ | ||||
break; \ | break; \ | ||||
case 9: \ | |||||
DISPATCH_STRIDE_1(9); \ | |||||
break; \ | |||||
case 13: \ | |||||
DISPATCH_STRIDE_1(13); \ | |||||
break; \ | |||||
default: \ | default: \ | ||||
megdnn_assert(0, "invalid filter %d", fh); \ | megdnn_assert(0, "invalid filter %d", fh); \ | ||||
} | } | ||||
@@ -123,4 +140,4 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | |||||
// vim: syntax=cpp.doxygen |
@@ -64,6 +64,8 @@ INSTANCE_CAL(2) | |||||
INSTANCE_CAL(3) | INSTANCE_CAL(3) | ||||
INSTANCE_CAL(4) | INSTANCE_CAL(4) | ||||
INSTANCE_CAL(5) | INSTANCE_CAL(5) | ||||
INSTANCE_CAL(9) | |||||
INSTANCE_CAL(13) | |||||
#undef INSTANCE_CAL | #undef INSTANCE_CAL | ||||
#undef CALCULATE_AVG_CB | #undef CALCULATE_AVG_CB | ||||
@@ -305,4 +307,4 @@ static inline void pooling_fp32_nchw44(const float32_t* src, float32_t* dst, | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | |||||
// vim: syntax=cpp.doxygen |
@@ -116,6 +116,31 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_NCHW44_FP32) { | |||||
} | } | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W9_w13_NCHW44) | |||||
{ | |||||
UniformIntRNG rng{-10, 10}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_rng(0, &rng); | |||||
// clang-format off | |||||
for (size_t ih: {20, 15}) | |||||
for (size_t iw: {15, 20}) | |||||
for (size_t kernel: {9, 13}) | |||||
for (size_t pad: {4, 6}) | |||||
for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) | |||||
if (kernel > pad) | |||||
{ | |||||
param::Pooling param; | |||||
param.mode = mode; | |||||
param.format = param::Pooling::Format::NCHW44; | |||||
param.pad_h = pad; | |||||
param.pad_w = pad; | |||||
param.stride_h = param.stride_w = 1; | |||||
param.window_h = param.window_w = kernel ; | |||||
checker.set_param(param).exec(TensorShapeArray{{2, 8, ih, iw, 4}, {}}); | |||||
} | |||||
// clang-format on | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44) | TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44) | ||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | ||||
@@ -2,7 +2,7 @@ | |||||
set -e | set -e | ||||
ARCHS=("arm64-v8a" "armeabi-v7a") | ARCHS=("arm64-v8a" "armeabi-v7a") | ||||
BUILD_TYPE=Release | |||||
BUILD_TYPE=RelWithDebInfo | |||||
MGE_ARMV8_2_FEATURE_FP16=OFF | MGE_ARMV8_2_FEATURE_FP16=OFF | ||||
MGE_DISABLE_FLOAT16=OFF | MGE_DISABLE_FLOAT16=OFF | ||||
ARCH=arm64-v8a | ARCH=arm64-v8a | ||||