GitOrigin-RevId: f84c32f922
tags/v0.5.0
@@ -13,6 +13,7 @@ | |||
#include "megdnn/opr_param_defs.h" | |||
#include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h" | |||
#include "src/arm_common/pooling/do_max_pooling_4x4_nchw44.h" | |||
#include "src/arm_common/pooling/do_max_pooling_5x5_nchw44.h" | |||
#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h" | |||
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h" | |||
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h" | |||
@@ -806,6 +807,75 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( | |||
#undef DISPATCH_FUNC | |||
} | |||
bool PoolingImpl::AlgoFilter5MaxStridexNCHW44::usable( | |||
const PoolingKernSizeParam& param) const { | |||
auto SH = param.stride[0]; | |||
auto SW = param.stride[1]; | |||
auto FH = param.filter[0]; | |||
auto FW = param.filter[1]; | |||
auto PH = param.padding[0]; | |||
auto PW = param.padding[1]; | |||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
param.format == Param::Format::NCHW44 && | |||
param.mode == Mode::MAX && FH == 5 && FW == 5 && SH == SW && | |||
(SW == 1 || SW == 2) && PH == 0 && PW == 0; | |||
return avaible; | |||
} | |||
void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( | |||
const PoolingKernParam& param) const { | |||
auto IH = param.isz[0], IW = param.isz[1]; | |||
auto OH = param.osz[0], OW = param.osz[1]; | |||
auto N = param.n, C = param.ic; | |||
auto PH = param.padding[0]; | |||
auto PW = param.padding[1]; | |||
auto SW = param.stride[0]; | |||
void* src_ptr = param.src_ptr; | |||
void* dst_ptr = param.dst_ptr; | |||
#define DISPATCH_FUNC(type, func, midout_type_id, i) \ | |||
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ | |||
midout_iv(midout_type_id)) { \ | |||
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ | |||
size_t index, size_t thread_id) { \ | |||
MEGDNN_MARK_USED_VAR(thread_id); \ | |||
size_t n = index / C; \ | |||
size_t c = index % C; \ | |||
do_max_pooling_5x5_stride##i##_##func##_nchw44_NEON( \ | |||
static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \ | |||
c * IH * IW * 4, \ | |||
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | |||
c * OH * OW * 4, \ | |||
IH, IW, OH, OW, PH, PW); \ | |||
}; \ | |||
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | |||
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ | |||
run); \ | |||
} \ | |||
MIDOUT_END(); | |||
#define DISPATCH_STRIDE(type, func, midout_type_id) \ | |||
switch (SW) { \ | |||
case 1: { \ | |||
DISPATCH_FUNC(type, func, midout_type_id, 1); \ | |||
break; \ | |||
} \ | |||
case 2: { \ | |||
DISPATCH_FUNC(type, func, midout_type_id, 2); \ | |||
break; \ | |||
} \ | |||
default: \ | |||
megdnn_assert(0, "unsupport stride size"); \ | |||
} | |||
DISPATCH_STRIDE(int8_t, int8, 12); | |||
#undef DISPATCH_STRIDE | |||
#undef DISPATCH_FUNC | |||
} | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen | |||
@@ -115,6 +115,14 @@ public: | |||
void exec(const PoolingKernParam& param) const override; | |||
}; | |||
class PoolingImpl::AlgoFilter5MaxStridexNCHW44 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "ARM_POOLING_FILTER5_MAX_STRIDEX_NCHW44"; } | |||
bool usable(const PoolingKernSizeParam& param) const override; | |||
void exec(const PoolingKernParam& param) const override; | |||
}; | |||
WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param); | |||
} // namespace arm_common | |||
@@ -0,0 +1,202 @@ | |||
/** | |||
* \file dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/pooling/do_max_pooling_5x5_nchw44.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/unroll_macro.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||
size_t IH, size_t IW, | |||
size_t OH, size_t OW, | |||
size_t PH, size_t PW) { | |||
size_t oh = 0; | |||
for (; oh < OH; ++oh) { | |||
size_t ih = oh; | |||
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; | |||
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; | |||
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; | |||
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4; | |||
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4; | |||
int8_t* __restrict dptr = dst + oh * OW * 4; | |||
size_t ow = 0; | |||
for (; ow + 3 < OW; ow += 4) { | |||
int8x16_t src00, src04, max_out, max_tmp0, max_tmp1, max_tmp2, | |||
max_tmp3, max_tmp4; | |||
int32x4_t src1234, src2345, src3456; | |||
#define CACULATE_ROW(i) \ | |||
src00 = vld1q_s8(sptr##i); \ | |||
src04 = vld1q_s8(sptr##i + 4 * 4); \ | |||
src1234 = vextq_s32(vreinterpretq_s32_s8(src00), \ | |||
vreinterpretq_s32_s8(src04), 1); \ | |||
src2345 = vextq_s32(vreinterpretq_s32_s8(src00), \ | |||
vreinterpretq_s32_s8(src04), 2); \ | |||
src3456 = vextq_s32(vreinterpretq_s32_s8(src00), \ | |||
vreinterpretq_s32_s8(src04), 3); \ | |||
max_tmp##i = vmaxq_s8(src00, vreinterpretq_s8_s32(src1234)); \ | |||
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2345)); \ | |||
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3456)); \ | |||
max_tmp##i = vmaxq_s8(max_tmp##i, src04); | |||
UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) | |||
max_out = vmaxq_s8(max_tmp0, max_tmp1); | |||
max_out = vmaxq_s8(max_out, max_tmp2); | |||
max_out = vmaxq_s8(max_out, max_tmp3); | |||
max_out = vmaxq_s8(max_out, max_tmp4); | |||
vst1q_s8(dptr, max_out); | |||
sptr0 += 16; | |||
sptr1 += 16; | |||
sptr2 += 16; | |||
sptr3 += 16; | |||
sptr4 += 16; | |||
dptr += 16; | |||
#undef CACULATE_ROW | |||
} | |||
for (; ow < OW; ++ow) { | |||
int8x8_t src01, src23, max_out; | |||
#define CACULATE_ROW(i) \ | |||
src01 = vld1_s8(sptr##i); \ | |||
src23 = vld1_s8(sptr##i + 8); \ | |||
int8x8_t max_tmp##i = vmax_s8(src01, src23); | |||
UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) | |||
max_out = vmax_s8(max_tmp0, max_tmp1); | |||
max_out = vmax_s8(max_out, max_tmp2); | |||
max_out = vmax_s8(max_out, max_tmp3); | |||
max_out = vmax_s8(max_out, max_tmp4); | |||
#define COMPARE_SRC45(i) int8x8_t src##i##_45 = vld1_s8(sptr##i + 4 * 4); | |||
UNROLL_CALL_NOWRAPPER(5, COMPARE_SRC45) | |||
int8x8_t max_45 = vmax_s8(src0_45, src1_45); | |||
max_45 = vmax_s8(max_45, src1_45); | |||
max_45 = vmax_s8(max_45, src2_45); | |||
max_45 = vmax_s8(max_45, src3_45); | |||
max_45 = vmax_s8(max_45, src4_45); | |||
#define store(i) \ | |||
*(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]); | |||
UNROLL_CALL_NOWRAPPER(4, store) | |||
#undef store | |||
#undef COMPARE_SRC45 | |||
#undef CACULATE_ROW | |||
sptr0 += 4; | |||
sptr1 += 4; | |||
sptr2 += 4; | |||
sptr3 += 4; | |||
sptr4 += 4; | |||
dptr += 4; | |||
} | |||
} | |||
} | |||
void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||
size_t IH, size_t IW, | |||
size_t OH, size_t OW, | |||
size_t PH, size_t PW) { | |||
size_t oh = 0; | |||
for (; oh < OH; ++oh) { | |||
size_t ih = oh << 1; | |||
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; | |||
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; | |||
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; | |||
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4; | |||
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4; | |||
int8_t* __restrict dptr = dst + oh * OW * 4; | |||
size_t ow = 0; | |||
for (; ow + 3 < OW; ow += 4) { | |||
int8x16_t src00, src04, src08, src09, src10, max_tmp0, max_tmp1, | |||
max_tmp2, max_tmp3, max_tmp4; | |||
int32x4_t src0246, src1357, src2468, src3579, src46810; | |||
int32x4x2_t src_tmp; | |||
#define CACULATE_ROW(i) \ | |||
src00 = vld1q_s8(sptr##i); \ | |||
src04 = vld1q_s8(sptr##i + 4 * 4); \ | |||
src08 = vld1q_s8(sptr##i + 4 * 8); \ | |||
src09 = vld1q_s8(sptr##i + 4 * 9); \ | |||
src10 = vld1q_s8(sptr##i + 4 * 10); \ | |||
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ | |||
vreinterpretq_s32_s8(src04)); \ | |||
src0246 = src_tmp.val[0]; \ | |||
src1357 = src_tmp.val[1]; \ | |||
src2468 = vextq_s32(src0246, vreinterpretq_s32_s8(src08), 1); \ | |||
src3579 = vextq_s32(src1357, vreinterpretq_s32_s8(src09), 1); \ | |||
src46810 = vextq_s32(src2468, vreinterpretq_s32_s8(src10), 1); \ | |||
max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \ | |||
vreinterpretq_s8_s32(src1357)); \ | |||
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \ | |||
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579)); \ | |||
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src46810)); | |||
UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) | |||
int8x16_t max_out = vmaxq_s8(max_tmp0, max_tmp1); | |||
max_out = vmaxq_s8(max_out, max_tmp2); | |||
max_out = vmaxq_s8(max_out, max_tmp3); | |||
max_out = vmaxq_s8(max_out, max_tmp4); | |||
vst1q_s8(dptr, max_out); | |||
sptr0 += 32; | |||
sptr1 += 32; | |||
sptr2 += 32; | |||
sptr3 += 32; | |||
sptr4 += 32; | |||
dptr += 16; | |||
#undef CACULATE_ROW | |||
} | |||
for (; ow < OW; ++ow) { | |||
int8x8_t src01, src23, max_out; | |||
#define CACULATE_ROW(i) \ | |||
src01 = vld1_s8(sptr##i); \ | |||
src23 = vld1_s8(sptr##i + 8); \ | |||
int8x8_t max_tmp##i = vmax_s8(src01, src23); | |||
UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) | |||
max_out = vmax_s8(max_tmp0, max_tmp1); | |||
max_out = vmax_s8(max_out, max_tmp2); | |||
max_out = vmax_s8(max_out, max_tmp3); | |||
max_out = vmax_s8(max_out, max_tmp4); | |||
#define COMPARE_SRC45(i) int8x8_t src##i##_45 = vld1_s8(sptr##i + 4 * 4); | |||
UNROLL_CALL_NOWRAPPER(5, COMPARE_SRC45) | |||
int8x8_t max_45 = vmax_s8(src0_45, src1_45); | |||
max_45 = vmax_s8(max_45, src1_45); | |||
max_45 = vmax_s8(max_45, src2_45); | |||
max_45 = vmax_s8(max_45, src3_45); | |||
max_45 = vmax_s8(max_45, src4_45); | |||
#define store(i) \ | |||
*(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]); | |||
UNROLL_CALL_NOWRAPPER(4, store) | |||
#undef store | |||
#undef COMPARE_SRC45 | |||
#undef CACULATE_ROW | |||
sptr0 += 8; | |||
sptr1 += 8; | |||
sptr2 += 8; | |||
sptr3 += 8; | |||
sptr4 += 8; | |||
dptr += 4; | |||
} | |||
} | |||
} | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,30 @@ | |||
/** | |||
* \file dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
#define KERN(strdie) \ | |||
void do_max_pooling_5x5_##strdie##_int8_nchw44_NEON( \ | |||
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ | |||
size_t OW, size_t PH, size_t PW); | |||
KERN(stride1) | |||
KERN(stride2) | |||
#undef KERN | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -29,6 +29,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj { | |||
AlgoFilter3MaxStride2NCHW44 algo_filter3_max_stride2_nchw4; | |||
AlgoFilter3MaxStride1NCHW44 algo_filter3_max_stride1_nchw4; | |||
AlgoFilter4MaxStridexNCHW44 algo_filter4_max_stridex_nchw4; | |||
AlgoFilter5MaxStridexNCHW44 algo_filter5_max_stridex_nchw4; | |||
public: | |||
AlgoPack() { | |||
@@ -44,6 +45,7 @@ public: | |||
all_algos.emplace_back(&algo_filter3_max_stride1_nchw4); | |||
all_algos.emplace_back(&algo_filter2_max_stridex_nchw4); | |||
all_algos.emplace_back(&algo_filter4_max_stridex_nchw4); | |||
all_algos.emplace_back(&algo_filter5_max_stridex_nchw4); | |||
} | |||
SmallVector<AlgoBase*> all_algos; | |||
}; | |||
@@ -87,6 +87,7 @@ private: | |||
class AlgoFilter3MaxStride2NCHW44; | |||
class AlgoFilter3MaxStride1NCHW44; | |||
class AlgoFilter4MaxStridexNCHW44; | |||
class AlgoFilter5MaxStridexNCHW44; | |||
class AlgoPack; | |||
}; | |||
} // namespace arm_common | |||
@@ -254,6 +254,56 @@ TEST_F(ARM_COMMON, POOLING_MAX_W4x4_S2x2_NCHW44) | |||
} | |||
// clang-format on | |||
} | |||
TEST_F(ARM_COMMON, POOLING_MAX_W5x5_S1x1_NCHW44) | |||
{ | |||
// clang-format off | |||
for (size_t ih: {5, 9, 19, 20, 39}) | |||
for (size_t iw: {5, 12, 23, 27, 39}) | |||
for (size_t ph: {0}) | |||
for (size_t pw: {0}) | |||
if (ih+2*ph >= 5 && iw+2*pw >= 5) | |||
{ | |||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||
Checker<Pooling> checker(handle()); | |||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||
checker.set_rng(0,&rng); | |||
param::Pooling param; | |||
param.mode = param::Pooling::Mode::MAX; | |||
param.format = param::Pooling::Format::NCHW44; | |||
param.pad_h = ph; | |||
param.pad_w = pw; | |||
param.stride_h = param.stride_w = 1; | |||
param.window_h = param.window_w = 5; | |||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | |||
} | |||
// clang-format on | |||
} | |||
TEST_F(ARM_COMMON, POOLING_MAX_W5x5_S2x2_NCHW44) | |||
{ | |||
// clang-format off | |||
for (size_t ih: {5, 9, 19, 20, 39}) | |||
for (size_t iw: {5, 12, 23, 27, 39}) | |||
for (size_t ph: {0}) | |||
for (size_t pw: {0}) | |||
if (ih+2*ph >= 5 && iw+2*pw >= 5) | |||
{ | |||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||
Checker<Pooling> checker(handle()); | |||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||
checker.set_rng(0,&rng); | |||
param::Pooling param; | |||
param.mode = param::Pooling::Mode::MAX; | |||
param.format = param::Pooling::Format::NCHW44; | |||
param.pad_h = ph; | |||
param.pad_w = pw; | |||
param.stride_h = param.stride_w = 2; | |||
param.window_h = param.window_w = 5; | |||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | |||
} | |||
// clang-format on | |||
} | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
TEST_F(ARM_COMMON, POOLING_FP16) { | |||
@@ -204,6 +204,56 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_S2x2_NCHW44) | |||
} | |||
// clang-format on | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_S1x1_NCHW44) | |||
{ | |||
// clang-format off | |||
for (size_t ih: {5, 9, 19, 20, 39}) | |||
for (size_t iw: {5, 12, 23, 27, 39}) | |||
for (size_t ph: {0}) | |||
for (size_t pw: {0}) | |||
if (ih+2*ph >= 5 && iw+2*pw >= 5) | |||
{ | |||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||
Checker<Pooling> checker(handle()); | |||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||
checker.set_rng(0,&rng); | |||
param::Pooling param; | |||
param.mode = param::Pooling::Mode::MAX; | |||
param.format = param::Pooling::Format::NCHW44; | |||
param.pad_h = ph; | |||
param.pad_w = pw; | |||
param.stride_h = param.stride_w = 1; | |||
param.window_h = param.window_w = 5; | |||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | |||
} | |||
// clang-format on | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_S2x2_NCHW44) | |||
{ | |||
// clang-format off | |||
for (size_t ih: {5, 9, 19, 20, 39}) | |||
for (size_t iw: {5, 12, 23, 27, 39}) | |||
for (size_t ph: {0}) | |||
for (size_t pw: {0}) | |||
if (ih+2*ph >= 5 && iw+2*pw >= 5) | |||
{ | |||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||
Checker<Pooling> checker(handle()); | |||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||
checker.set_rng(0,&rng); | |||
param::Pooling param; | |||
param.mode = param::Pooling::Mode::MAX; | |||
param.format = param::Pooling::Format::NCHW44; | |||
param.pad_h = ph; | |||
param.pad_w = pw; | |||
param.stride_h = param.stride_w = 2; | |||
param.window_h = param.window_w = 5; | |||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | |||
} | |||
// clang-format on | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_INT8_W3x3_S2x2) | |||
{ | |||