GitOrigin-RevId: 42d144a813
tags/v0.5.0
@@ -11,9 +11,10 @@ | |||||
*/ | */ | ||||
#include "src/arm_common/pooling/algo.h" | #include "src/arm_common/pooling/algo.h" | ||||
#include "megdnn/opr_param_defs.h" | #include "megdnn/opr_param_defs.h" | ||||
#include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h" | |||||
#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h" | |||||
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h" | #include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h" | ||||
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h" | #include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h" | ||||
#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h" | |||||
#include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h" | #include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h" | ||||
#include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h" | #include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h" | ||||
@@ -666,6 +667,75 @@ void PoolingImpl::AlgoFilter3MaxStride1NCHW44::exec( | |||||
#undef DISPATCH_FUNC | #undef DISPATCH_FUNC | ||||
} | } | ||||
bool PoolingImpl::AlgoFilter2MaxStridexNCHW44::usable( | |||||
const PoolingKernSizeParam& param) const { | |||||
auto SH = param.stride[0]; | |||||
auto SW = param.stride[1]; | |||||
auto FH = param.filter[0]; | |||||
auto FW = param.filter[1]; | |||||
auto PH = param.padding[0]; | |||||
auto PW = param.padding[1]; | |||||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
param.format == Param::Format::NCHW44 && | |||||
param.mode == Mode::MAX && FH == 2 && FW == 2 && SH == SW && | |||||
(SW == 1 || SW == 2) && PH == 0 && PW == 0; | |||||
return avaible; | |||||
} | |||||
void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( | |||||
const PoolingKernParam& param) const { | |||||
auto IH = param.isz[0], IW = param.isz[1]; | |||||
auto OH = param.osz[0], OW = param.osz[1]; | |||||
auto N = param.n, C = param.ic; | |||||
auto PH = param.padding[0]; | |||||
auto PW = param.padding[1]; | |||||
auto SW = param.stride[0]; | |||||
void* src_ptr = param.src_ptr; | |||||
void* dst_ptr = param.dst_ptr; | |||||
#define DISPATCH_FUNC(type, func, midout_type_id, i) \ | |||||
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ | |||||
midout_iv(midout_type_id)) { \ | |||||
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ | |||||
size_t index, size_t thread_id) { \ | |||||
MEGDNN_MARK_USED_VAR(thread_id); \ | |||||
size_t n = index / C; \ | |||||
size_t c = index % C; \ | |||||
do_max_pooling_2x2_stride##i##_##func##_nchw44_NEON( \ | |||||
static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \ | |||||
c * IH * IW * 4, \ | |||||
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | |||||
c * OH * OW * 4, \ | |||||
IH, IW, OH, OW, PH, PW); \ | |||||
}; \ | |||||
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | |||||
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ | |||||
run); \ | |||||
} \ | |||||
MIDOUT_END(); | |||||
#define DISPATCH_STRIDE(type, func, midout_type_id) \ | |||||
switch (SW) { \ | |||||
case 1: { \ | |||||
DISPATCH_FUNC(type, func, midout_type_id, 1); \ | |||||
break; \ | |||||
} \ | |||||
case 2: { \ | |||||
DISPATCH_FUNC(type, func, midout_type_id, 2); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert(0, "unsupport stride size"); \ | |||||
} | |||||
DISPATCH_STRIDE(int8_t, int8, 10); | |||||
#undef DISPATCH_STRIDE | |||||
#undef DISPATCH_FUNC | |||||
} | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
@@ -99,6 +99,14 @@ public: | |||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
}; | }; | ||||
/**
 * Pooling algorithm registered under the name
 * "ARM_POOLING_FILTER2_MAX_STRIDEX_NCHW44".
 *
 * usable() and exec() are defined out of line.
 */
class PoolingImpl::AlgoFilter2MaxStridexNCHW44 final : public AlgoBase {
public:
    //! Identifier string of this algorithm.
    const char* name() const override {
        return "ARM_POOLING_FILTER2_MAX_STRIDEX_NCHW44";
    }

    //! This algorithm always reports itself as reproducible.
    bool is_reproducible() const override { return true; }

    //! Run the algorithm on the tensors described by \p param.
    void exec(const PoolingKernParam& param) const override;

    //! Whether the algorithm can handle the configuration in \p param.
    bool usable(const PoolingKernSizeParam& param) const override;
};
WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param); | WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param); | ||||
} // namespace arm_common | } // namespace arm_common | ||||
@@ -0,0 +1,126 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/unroll_macro.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, | |||||
size_t OH, size_t OW, | |||||
size_t PH, size_t PW) { | |||||
size_t oh = 0; | |||||
for (; oh < OH; ++oh) { | |||||
size_t ih = oh; | |||||
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; | |||||
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | |||||
size_t ow = 0; | |||||
for (; ow + 3 < OW; ow += 4) { | |||||
int8x16_t src0123 = vld1q_s8(sptr0); | |||||
int8x16_t src1234 = vld1q_s8(sptr0 + 4); | |||||
int8x16_t max0 = vmaxq_s8(src0123, src1234); | |||||
src0123 = vld1q_s8(sptr1); | |||||
src1234 = vld1q_s8(sptr1 + 4); | |||||
int8x16_t max1 = vmaxq_s8(src0123, src1234); | |||||
int8x16_t max_out = vmaxq_s8(max0, max1); | |||||
vst1q_s8(dptr, max_out); | |||||
sptr0 += 16; | |||||
sptr1 += 16; | |||||
dptr += 16; | |||||
} | |||||
for (; ow < OW; ++ow) { | |||||
int8x8_t src001 = vld1_s8(sptr0); | |||||
int8x8_t src012 = vld1_s8(sptr0 + 4); | |||||
int8x8_t src101 = vld1_s8(sptr1); | |||||
int8x8_t src112 = vld1_s8(sptr1 + 4); | |||||
int8x8_t max01_tmp = vmax_s8(src001, src101); | |||||
int8x8_t max12_tmp = vmax_s8(src012, src112); | |||||
int8x8_t mat_out = vmax_s8(max01_tmp, max12_tmp); | |||||
#define store(i) *(dptr + i) = mat_out[i]; | |||||
UNROLL_CALL_NOWRAPPER(4, store) | |||||
#undef store | |||||
sptr0 += 4; | |||||
sptr1 += 4; | |||||
dptr += 4; | |||||
} | |||||
} | |||||
} | |||||
void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, | |||||
size_t OH, size_t OW, | |||||
size_t PH, size_t PW) { | |||||
size_t oh = 0; | |||||
for (; oh < OH; ++oh) { | |||||
size_t ih = oh << 1; | |||||
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; | |||||
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | |||||
size_t ow = 0; | |||||
for (; ow + 3 < OW; ow += 4) { | |||||
int8x16_t src00 = vld1q_s8(sptr0); | |||||
int8x16_t src04 = vld1q_s8(sptr0 + 4 * 4); | |||||
int32x4x2_t src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), | |||||
vreinterpretq_s32_s8(src04)); | |||||
int32x4_t src0246 = src_tmp.val[0]; | |||||
int32x4_t src1357 = src_tmp.val[1]; | |||||
int8x16_t max0 = vmaxq_s8(vreinterpretq_s8_s32(src0246), | |||||
vreinterpretq_s8_s32(src1357)); | |||||
src00 = vld1q_s8(sptr1); | |||||
src04 = vld1q_s8(sptr1 + 4 * 4); | |||||
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), | |||||
vreinterpretq_s32_s8(src04)); | |||||
src0246 = src_tmp.val[0]; | |||||
src1357 = src_tmp.val[1]; | |||||
int8x16_t max1 = vmaxq_s8(vreinterpretq_s8_s32(src0246), | |||||
vreinterpretq_s8_s32(src1357)); | |||||
int8x16_t max_out = vmaxq_s8(max0, max1); | |||||
vst1q_s8(dptr, max_out); | |||||
sptr0 += 32; | |||||
sptr1 += 32; | |||||
dptr += 16; | |||||
} | |||||
for (; ow < OW; ++ow) { | |||||
int8x8_t src001 = vld1_s8(sptr0); | |||||
int8x8_t src012 = vld1_s8(sptr0 + 4); | |||||
int8x8_t src101 = vld1_s8(sptr1); | |||||
int8x8_t src112 = vld1_s8(sptr1 + 4); | |||||
int8x8_t max01_tmp = vmax_s8(src001, src101); | |||||
int8x8_t max12_tmp = vmax_s8(src012, src112); | |||||
int8x8_t mat_out = vmax_s8(max01_tmp, max12_tmp); | |||||
#define store(i) *(dptr + i) = mat_out[i]; | |||||
UNROLL_CALL_NOWRAPPER(4, store) | |||||
#undef store | |||||
sptr0 += 8; | |||||
sptr1 += 8; | |||||
dptr += 4; | |||||
} | |||||
} | |||||
} | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,30 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/common/utils.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
/**
 * 2x2 max pooling kernel, stride 1, int8 data.
 *
 * \param src input pointer
 * \param dst output pointer
 * \param IH, IW input height and width
 * \param OH, OW output height and width
 * \param PH, PW vertical and horizontal padding
 */
void do_max_pooling_2x2_stride1_int8_nchw44_NEON(
        const int8_t* src, int8_t* dst,
        size_t IH, size_t IW,
        size_t OH, size_t OW,
        size_t PH, size_t PW);

/**
 * 2x2 max pooling kernel, stride 2, int8 data.
 *
 * \param src input pointer
 * \param dst output pointer
 * \param IH, IW input height and width
 * \param OH, OW output height and width
 * \param PH, PW vertical and horizontal padding
 */
void do_max_pooling_2x2_stride2_int8_nchw44_NEON(
        const int8_t* src, int8_t* dst,
        size_t IH, size_t IW,
        size_t OH, size_t OW,
        size_t PH, size_t PW);
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -27,6 +27,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj { | |||||
AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2; | AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2; | ||||
AlgoFilter3MaxStride2NCHW44 algo_filter3_max_stride2_nchw4; | AlgoFilter3MaxStride2NCHW44 algo_filter3_max_stride2_nchw4; | ||||
AlgoFilter3MaxStride1NCHW44 algo_filter3_max_stride1_nchw4; | AlgoFilter3MaxStride1NCHW44 algo_filter3_max_stride1_nchw4; | ||||
AlgoFilter2MaxStridexNCHW44 algo_filter2_max_stridex_nchw4; | |||||
public: | public: | ||||
AlgoPack() { | AlgoPack() { | ||||
@@ -40,6 +41,7 @@ public: | |||||
all_algos.emplace_back(&algo_int8_filter3_max_stride2); | all_algos.emplace_back(&algo_int8_filter3_max_stride2); | ||||
all_algos.emplace_back(&algo_filter3_max_stride2_nchw4); | all_algos.emplace_back(&algo_filter3_max_stride2_nchw4); | ||||
all_algos.emplace_back(&algo_filter3_max_stride1_nchw4); | all_algos.emplace_back(&algo_filter3_max_stride1_nchw4); | ||||
all_algos.emplace_back(&algo_filter2_max_stridex_nchw4); | |||||
} | } | ||||
SmallVector<AlgoBase*> all_algos; | SmallVector<AlgoBase*> all_algos; | ||||
}; | }; | ||||
@@ -85,6 +85,7 @@ private: | |||||
class AlgoInt8Filter3MaxStride2; | class AlgoInt8Filter3MaxStride2; | ||||
class AlgoFilter3MaxStride2NCHW44; | class AlgoFilter3MaxStride2NCHW44; | ||||
class AlgoFilter3MaxStride1NCHW44; | class AlgoFilter3MaxStride1NCHW44; | ||||
class AlgoFilter2MaxStridexNCHW44; | |||||
class AlgoPack; | class AlgoPack; | ||||
}; | }; | ||||
} // namespace arm_common | } // namespace arm_common | ||||
@@ -154,6 +154,57 @@ TEST_F(ARM_COMMON, POOLING_MAX_W3x3_S1x1_NCHW44) | |||||
// clang-format on | // clang-format on | ||||
} | } | ||||
// Checks 2x2 / stride-1 MAX pooling on QuantizedS8 NCHW44 tensors of shape
// {2, 2, ih, iw, 4} for a grid of input heights and widths, without padding.
TEST_F(ARM_COMMON, POOLING_MAX_W2x2_S1x1_NCHW44) {
    // clang-format off
    for (size_t ih : {2, 5, 10, 17}) {
    for (size_t iw : {2, 6, 8, 16, 26}) {
    for (size_t ph : {0}) {
    for (size_t pw : {0}) {
        // Skip shapes whose padded extent cannot hold one 2x2 window.
        if (ih + 2 * ph < 2 || iw + 2 * pw < 2) {
            continue;
        }

        param::Pooling param;
        param.mode = param::Pooling::Mode::MAX;
        param.format = param::Pooling::Format::NCHW44;
        param.pad_h = ph;
        param.pad_w = pw;
        param.stride_w = 1;
        param.stride_h = param.stride_w;
        param.window_w = 2;
        param.window_h = param.window_w;

        UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
        Checker<Pooling> checker(handle());
        checker.set_dtype(0, dtype::QuantizedS8(1.1f));
        checker.set_rng(0, &rng);
        checker.set_param(param).exec(
                TensorShapeArray{{2, 2, ih, iw, 4}, {}});
    }
    }
    }
    }
    // clang-format on
}
// Checks 2x2 / stride-2 MAX pooling on QuantizedS8 NCHW44 tensors of shape
// {2, 2, ih, iw, 4} for a grid of input heights and widths, without padding.
TEST_F(ARM_COMMON, POOLING_MAX_W2x2_S2x2_NCHW44) {
    // clang-format off
    for (size_t ih : {2, 5, 10, 17}) {
    for (size_t iw : {2, 6, 8, 16, 26}) {
    for (size_t ph : {0}) {
    for (size_t pw : {0}) {
        // Skip shapes whose padded extent cannot hold one 2x2 window.
        if (ih + 2 * ph < 2 || iw + 2 * pw < 2) {
            continue;
        }

        param::Pooling param;
        param.mode = param::Pooling::Mode::MAX;
        param.format = param::Pooling::Format::NCHW44;
        param.pad_h = ph;
        param.pad_w = pw;
        param.stride_w = 2;
        param.stride_h = param.stride_w;
        param.window_w = 2;
        param.window_h = param.window_w;

        UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
        Checker<Pooling> checker(handle());
        checker.set_dtype(0, dtype::QuantizedS8(1.1f));
        checker.set_rng(0, &rng);
        checker.set_param(param).exec(
                TensorShapeArray{{2, 2, ih, iw, 4}, {}});
    }
    }
    }
    }
    // clang-format on
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
TEST_F(ARM_COMMON, POOLING_FP16) { | TEST_F(ARM_COMMON, POOLING_FP16) { | ||||
Checker<Pooling> checker(handle()); | Checker<Pooling> checker(handle()); | ||||
@@ -104,6 +104,57 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S1x1_NCHW44) | |||||
// clang-format on | // clang-format on | ||||
} | } | ||||
// Multi-threaded check of 2x2 / stride-1 MAX pooling on QuantizedS8 NCHW44
// tensors of shape {2, 2, ih, iw, 4}, without padding.
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S1x1_NCHW44)
{
    // clang-format off
    for (size_t ih: {2, 5, 10, 17})
    for (size_t iw: {2, 6, 8, 16, 26})
    for (size_t ph: {0})
    for (size_t pw: {0})
    // The window is 2x2, so a padded extent of 2 is already valid. A
    // threshold of 3 would silently skip every ih == 2 / iw == 2 shape.
    if (ih+2*ph >= 2 && iw+2*pw >= 2)
    {
        UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
        Checker<Pooling> checker(handle());
        checker.set_dtype(0, dtype::QuantizedS8(1.1f));
        checker.set_rng(0,&rng);

        param::Pooling param;
        param.mode = param::Pooling::Mode::MAX;
        param.format = param::Pooling::Format::NCHW44;
        param.pad_h = ph;
        param.pad_w = pw;
        param.stride_h = param.stride_w = 1;
        param.window_h = param.window_w = 2;
        checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
    }
    // clang-format on
}
// Multi-threaded check of 2x2 / stride-2 MAX pooling on QuantizedS8 NCHW44
// tensors of shape {2, 2, ih, iw, 4}, without padding.
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S2x2_NCHW44)
{
    // clang-format off
    for (size_t ih: {2, 5, 10, 17})
    for (size_t iw: {2, 6, 8, 16, 26})
    for (size_t ph: {0})
    for (size_t pw: {0})
    // The window is 2x2, so a padded extent of 2 is already valid. A
    // threshold of 3 would silently skip every ih == 2 / iw == 2 shape.
    if (ih+2*ph >= 2 && iw+2*pw >= 2)
    {
        UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
        Checker<Pooling> checker(handle());
        checker.set_dtype(0, dtype::QuantizedS8(1.1f));
        checker.set_rng(0,&rng);

        param::Pooling param;
        param.mode = param::Pooling::Mode::MAX;
        param.format = param::Pooling::Format::NCHW44;
        param.pad_h = ph;
        param.pad_w = pw;
        param.stride_h = param.stride_w = 2;
        param.window_h = param.window_w = 2;
        checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
    }
    // clang-format on
}
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_INT8_W3x3_S2x2) | TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_INT8_W3x3_S2x2) | ||||
{ | { | ||||
for (size_t ih: {2, 3, 7, 13, 52, 53, 54, 55}) | for (size_t ih: {2, 3, 7, 13, 52, 53, 54, 55}) | ||||