GitOrigin-RevId: f125004e1f
tags/v0.5.0
@@ -11,14 +11,13 @@ | |||||
*/ | */ | ||||
#include "src/arm_common/pooling/algo.h" | #include "src/arm_common/pooling/algo.h" | ||||
#include "megdnn/opr_param_defs.h" | #include "megdnn/opr_param_defs.h" | ||||
#include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h" | |||||
#include "src/arm_common/pooling/do_max_pooling_4x4_nchw44.h" | |||||
#include "src/arm_common/pooling/do_max_pooling_5x5_nchw44.h" | |||||
#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h" | |||||
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h" | #include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h" | ||||
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h" | |||||
#include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h" | #include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h" | ||||
#include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h" | #include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h" | ||||
#include "src/arm_common/pooling/do_pooling_2x2_nchw44.h" | |||||
#include "src/arm_common/pooling/do_pooling_3x3_nchw44.h" | |||||
#include "src/arm_common/pooling/do_pooling_4x4_nchw44.h" | |||||
#include "src/arm_common/pooling/do_pooling_5x5_nchw44.h" | |||||
#include "midout.h" | #include "midout.h" | ||||
@@ -57,6 +56,41 @@ WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) { | |||||
return ws; | return ws; | ||||
} | } | ||||
/**
 * Describe the scratch space needed by the int8 NCHW44 pooling kernels.
 *
 * The returned bundle has a single chunk. When either padding value is
 * non-zero the chunk holds one padded plane:
 * (IW + 2 * PW) * (IH + 2 * PH) pixels of 4 int8 values each.
 * Without padding the chunk is empty.
 *
 * \param param pooling size description; must be QuantizedS8 in NCHW44 format
 * \return bundle (without a backing pointer) describing that single chunk
 */
WorkspaceBundle get_bundle_nchw44(
        const PoolingImpl::PoolingKernSizeParam& param) {
    megdnn_assert((param.src_type.enumv() == DTypeEnum::QuantizedS8) &&
                  (param.format == param::Pooling::Format::NCHW44));

    const auto pad_h = param.padding[0];
    const auto pad_w = param.padding[1];

    size_t plane_bytes = 0;
    const bool has_padding = !((pad_h == 0) && (pad_w == 0));
    if (has_padding) {
        const auto in_h = param.isz[0];
        const auto in_w = param.isz[1];
        // width first, then height, then the 4 packed channel values
        plane_bytes = (in_w + 2 * pad_w) * (in_h + 2 * pad_h) * 4 *
                      sizeof(int8_t);
    }
    return WorkspaceBundle(nullptr, {plane_bytes});
}
const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW, | |||||
size_t& IH2, size_t& IW2, size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
int8_t* sptr_base = nullptr; | |||||
bool need_pad = ((PH != 0) || (PW != 0)) ? true : false; | |||||
if (need_pad) { | |||||
IH2 = IH + 2 * PH; | |||||
IW2 = IW + 2 * PW; | |||||
sptr_base = static_cast<int8_t*>(ws.get(0)); | |||||
memset(sptr_base, -128, sizeof(int8_t) * IH2 * IW2 * 4); | |||||
rep(ih, IH) { | |||||
std::memcpy(sptr_base + (ih + PH) * IW2 * 4 + PW * 4, | |||||
src + ih * IW * 4, sizeof(int8_t) * IW * 4); | |||||
} | |||||
} else { | |||||
IH2 = IH; | |||||
IW2 = IW; | |||||
} | |||||
return need_pad ? sptr_base : src; | |||||
} | |||||
bool PoolingImpl::AlgoFilterxModexStride1::usable( | bool PoolingImpl::AlgoFilterxModexStride1::usable( | ||||
const PoolingKernSizeParam& param) const { | const PoolingKernSizeParam& param) const { | ||||
auto SH = param.stride[0]; | auto SH = param.stride[0]; | ||||
@@ -563,47 +597,50 @@ void PoolingImpl::AlgoInt8Filter3MaxStride2::exec( | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
} | } | ||||
// usable() for the 3x3 NCHW44 max-pooling algorithm. This rendering
// interleaves two variants of the function (AlgoFilter3MaxStride2NCHW44 and
// AlgoFilter3MaxStridexNCHW44); lines shown in both columns are shared.
//
// Both variants require a QuantizedS8 source, NCHW44 format, MAX mode and a
// 3x3 filter. The condition ending in `PH == 0 && PW == 0` additionally
// demands stride 2x2 and zero padding. The condition ending in
// `(SH == 1 || SW == 2)` demands equal strides of 1 or 2 and does not test
// padding.
//
// NOTE(review): `(SH == 1 || SW == 2)` mixes SH and SW. Because `SW == SH` is
// also required the result is the same as testing a single stride variable,
// but the sibling 2x2/4x4/5x5 checks write `(SW == 1 || SW == 2)`.
bool PoolingImpl::AlgoFilter3MaxStride2NCHW44::usable( | |||||
bool PoolingImpl::AlgoFilter3MaxStridexNCHW44::usable( | |||||
const PoolingKernSizeParam& param) const { | const PoolingKernSizeParam& param) const { | ||||
auto SH = param.stride[0]; | auto SH = param.stride[0]; | ||||
auto SW = param.stride[1]; | auto SW = param.stride[1]; | ||||
auto FH = param.filter[0]; | auto FH = param.filter[0]; | ||||
auto FW = param.filter[1]; | auto FW = param.filter[1]; | ||||
auto PH = param.padding[0]; | |||||
auto PW = param.padding[1]; | |||||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | ||||
param.format == Param::Format::NCHW44 && | param.format == Param::Format::NCHW44 && | ||||
param.mode == Mode::MAX && FH == 3 && FW == 3 && SH == 2 && | |||||
SW == 2 && PH == 0 && PW == 0; | |||||
param.mode == Mode::MAX && FH == 3 && FW == 3 && SW == SH && | |||||
(SH == 1 || SW == 2); | |||||
return avaible; | return avaible; | ||||
} | } | ||||
void PoolingImpl::AlgoFilter3MaxStride2NCHW44::exec( | |||||
void PoolingImpl::AlgoFilter3MaxStridexNCHW44::exec( | |||||
const PoolingKernParam& param) const { | const PoolingKernParam& param) const { | ||||
auto IH = param.isz[0], IW = param.isz[1]; | auto IH = param.isz[0], IW = param.isz[1]; | ||||
auto OH = param.osz[0], OW = param.osz[1]; | auto OH = param.osz[0], OW = param.osz[1]; | ||||
auto N = param.n, C = param.ic; | auto N = param.n, C = param.ic; | ||||
auto PH = param.padding[0]; | auto PH = param.padding[0]; | ||||
auto PW = param.padding[1]; | auto PW = param.padding[1]; | ||||
auto SW = param.stride[0]; | |||||
void* src_ptr = param.src_ptr; | void* src_ptr = param.src_ptr; | ||||
void* dst_ptr = param.dst_ptr; | void* dst_ptr = param.dst_ptr; | ||||
#define DISPATCH_FUNC(type, func, midout_type_id) \ | |||||
#define DISPATCH_FUNC(type, func, i) \ | |||||
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ | MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ | ||||
midout_iv(midout_type_id)) { \ | |||||
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ | |||||
midout_iv(#type #i##_hash)) { \ | |||||
WorkspaceBundle wbundle = get_bundle_nchw44(param); \ | |||||
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ | |||||
wbundle = wbundle, \ | |||||
workspace_ptr = param.workspace<dt_byte>()]( \ | |||||
size_t index, size_t thread_id) { \ | size_t index, size_t thread_id) { \ | ||||
MEGDNN_MARK_USED_VAR(thread_id); \ | |||||
auto ws = wbundle; \ | |||||
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ | |||||
size_t n = index / C; \ | size_t n = index / C; \ | ||||
size_t c = index % C; \ | size_t c = index % C; \ | ||||
do_max_pooling_3x3_s2x2_##func##_nchw44_NEON( \ | |||||
do_max_pooling_3x3_stride##i##_##func##_nchw44_NEON( \ | |||||
static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \ | static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \ | ||||
c * IH * IW * 4, \ | c * IH * IW * 4, \ | ||||
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | ||||
c * OH * OW * 4, \ | c * OH * OW * 4, \ | ||||
IH, IW, OH, OW, PH, PW); \ | |||||
IH, IW, OH, OW, PH, PW, ws); \ | |||||
}; \ | }; \ | ||||
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | ||||
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ | static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ | ||||
@@ -611,61 +648,23 @@ void PoolingImpl::AlgoFilter3MaxStride2NCHW44::exec( | |||||
} \ | } \ | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
DISPATCH_FUNC(int8_t, int8, 9); | |||||
#undef DISPATCH_FUNC | |||||
} | |||||
bool PoolingImpl::AlgoFilter3MaxStride1NCHW44::usable( | |||||
const PoolingKernSizeParam& param) const { | |||||
auto SH = param.stride[0]; | |||||
auto SW = param.stride[1]; | |||||
auto FH = param.filter[0]; | |||||
auto FW = param.filter[1]; | |||||
auto PH = param.padding[0]; | |||||
auto PW = param.padding[1]; | |||||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
param.format == Param::Format::NCHW44 && | |||||
param.mode == Mode::MAX && FH == 3 && FW == 3 && SH == 1 && | |||||
SW == 1 && PH == 0 && PW == 0; | |||||
return avaible; | |||||
} | |||||
void PoolingImpl::AlgoFilter3MaxStride1NCHW44::exec( | |||||
const PoolingKernParam& param) const { | |||||
auto IH = param.isz[0], IW = param.isz[1]; | |||||
auto OH = param.osz[0], OW = param.osz[1]; | |||||
auto N = param.n, C = param.ic; | |||||
auto PH = param.padding[0]; | |||||
auto PW = param.padding[1]; | |||||
void* src_ptr = param.src_ptr; | |||||
void* dst_ptr = param.dst_ptr; | |||||
#define DISPATCH_FUNC(type, func, midout_type_id) \ | |||||
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ | |||||
midout_iv(midout_type_id)) { \ | |||||
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ | |||||
size_t index, size_t thread_id) { \ | |||||
MEGDNN_MARK_USED_VAR(thread_id); \ | |||||
size_t n = index / C; \ | |||||
size_t c = index % C; \ | |||||
do_max_pooling_3x3_s1x1_##func##_nchw44_NEON( \ | |||||
static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \ | |||||
c * IH * IW * 4, \ | |||||
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | |||||
c * OH * OW * 4, \ | |||||
IH, IW, OH, OW, PH, PW); \ | |||||
}; \ | |||||
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | |||||
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ | |||||
run); \ | |||||
} \ | |||||
MIDOUT_END(); | |||||
#define DISPATCH_STRIDE(type, func) \ | |||||
switch (SW) { \ | |||||
case 1: { \ | |||||
DISPATCH_FUNC(type, func, 1); \ | |||||
break; \ | |||||
} \ | |||||
case 2: { \ | |||||
DISPATCH_FUNC(type, func, 2); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert(0, "unsupport stride size"); \ | |||||
} | |||||
DISPATCH_FUNC(int8_t, int8, 10); | |||||
DISPATCH_STRIDE(int8_t, int8); | |||||
#undef DISPATCH_STRIDE | |||||
#undef DISPATCH_FUNC | #undef DISPATCH_FUNC | ||||
} | } | ||||
@@ -675,13 +674,11 @@ bool PoolingImpl::AlgoFilter2MaxStridexNCHW44::usable( | |||||
auto SW = param.stride[1]; | auto SW = param.stride[1]; | ||||
auto FH = param.filter[0]; | auto FH = param.filter[0]; | ||||
auto FW = param.filter[1]; | auto FW = param.filter[1]; | ||||
auto PH = param.padding[0]; | |||||
auto PW = param.padding[1]; | |||||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | ||||
param.format == Param::Format::NCHW44 && | param.format == Param::Format::NCHW44 && | ||||
param.mode == Mode::MAX && FH == 2 && FW == 2 && SH == SW && | param.mode == Mode::MAX && FH == 2 && FW == 2 && SH == SW && | ||||
(SW == 1 || SW == 2) && PH == 0 && PW == 0; | |||||
(SW == 1 || SW == 2); | |||||
return avaible; | return avaible; | ||||
} | } | ||||
@@ -697,12 +694,16 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( | |||||
void* src_ptr = param.src_ptr; | void* src_ptr = param.src_ptr; | ||||
void* dst_ptr = param.dst_ptr; | void* dst_ptr = param.dst_ptr; | ||||
#define DISPATCH_FUNC(type, func, midout_type_id, i) \ | |||||
#define DISPATCH_FUNC(type, func, i) \ | |||||
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ | MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ | ||||
midout_iv(midout_type_id)) { \ | |||||
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ | |||||
midout_iv(#func #i##_hash)) { \ | |||||
WorkspaceBundle wbundle = get_bundle_nchw44(param); \ | |||||
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ | |||||
wbundle = wbundle, \ | |||||
workspace_ptr = param.workspace<dt_byte>()]( \ | |||||
size_t index, size_t thread_id) { \ | size_t index, size_t thread_id) { \ | ||||
MEGDNN_MARK_USED_VAR(thread_id); \ | |||||
auto ws = wbundle; \ | |||||
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ | |||||
size_t n = index / C; \ | size_t n = index / C; \ | ||||
size_t c = index % C; \ | size_t c = index % C; \ | ||||
do_max_pooling_2x2_stride##i##_##func##_nchw44_NEON( \ | do_max_pooling_2x2_stride##i##_##func##_nchw44_NEON( \ | ||||
@@ -710,7 +711,7 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( | |||||
c * IH * IW * 4, \ | c * IH * IW * 4, \ | ||||
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | ||||
c * OH * OW * 4, \ | c * OH * OW * 4, \ | ||||
IH, IW, OH, OW, PH, PW); \ | |||||
IH, IW, OH, OW, PH, PW, ws); \ | |||||
}; \ | }; \ | ||||
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | ||||
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ | static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ | ||||
@@ -718,21 +719,21 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( | |||||
} \ | } \ | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
#define DISPATCH_STRIDE(type, func, midout_type_id) \ | |||||
switch (SW) { \ | |||||
case 1: { \ | |||||
DISPATCH_FUNC(type, func, midout_type_id, 1); \ | |||||
break; \ | |||||
} \ | |||||
case 2: { \ | |||||
DISPATCH_FUNC(type, func, midout_type_id, 2); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert(0, "unsupport stride size"); \ | |||||
#define DISPATCH_STRIDE(type, func) \ | |||||
switch (SW) { \ | |||||
case 1: { \ | |||||
DISPATCH_FUNC(type, func, 1); \ | |||||
break; \ | |||||
} \ | |||||
case 2: { \ | |||||
DISPATCH_FUNC(type, func, 2); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert(0, "unsupport stride size"); \ | |||||
} | } | ||||
DISPATCH_STRIDE(int8_t, int8, 10); | |||||
DISPATCH_STRIDE(int8_t, int8); | |||||
#undef DISPATCH_STRIDE | #undef DISPATCH_STRIDE | ||||
#undef DISPATCH_FUNC | #undef DISPATCH_FUNC | ||||
@@ -744,13 +745,11 @@ bool PoolingImpl::AlgoFilter4MaxStridexNCHW44::usable( | |||||
auto SW = param.stride[1]; | auto SW = param.stride[1]; | ||||
auto FH = param.filter[0]; | auto FH = param.filter[0]; | ||||
auto FW = param.filter[1]; | auto FW = param.filter[1]; | ||||
auto PH = param.padding[0]; | |||||
auto PW = param.padding[1]; | |||||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | ||||
param.format == Param::Format::NCHW44 && | param.format == Param::Format::NCHW44 && | ||||
param.mode == Mode::MAX && FH == 4 && FW == 4 && SH == SW && | param.mode == Mode::MAX && FH == 4 && FW == 4 && SH == SW && | ||||
(SW == 1 || SW == 2) && PH == 0 && PW == 0; | |||||
(SW == 1 || SW == 2); | |||||
return avaible; | return avaible; | ||||
} | } | ||||
@@ -766,12 +765,16 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( | |||||
void* src_ptr = param.src_ptr; | void* src_ptr = param.src_ptr; | ||||
void* dst_ptr = param.dst_ptr; | void* dst_ptr = param.dst_ptr; | ||||
#define DISPATCH_FUNC(type, func, midout_type_id, i) \ | |||||
#define DISPATCH_FUNC(type, func, i) \ | |||||
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ | MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ | ||||
midout_iv(midout_type_id)) { \ | |||||
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ | |||||
midout_iv(#func #i##_hash)) { \ | |||||
WorkspaceBundle wbundle = get_bundle_nchw44(param); \ | |||||
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ | |||||
wbundle = wbundle, \ | |||||
workspace_ptr = param.workspace<dt_byte>()]( \ | |||||
size_t index, size_t thread_id) { \ | size_t index, size_t thread_id) { \ | ||||
MEGDNN_MARK_USED_VAR(thread_id); \ | |||||
auto ws = wbundle; \ | |||||
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ | |||||
size_t n = index / C; \ | size_t n = index / C; \ | ||||
size_t c = index % C; \ | size_t c = index % C; \ | ||||
do_max_pooling_4x4_stride##i##_##func##_nchw44_NEON( \ | do_max_pooling_4x4_stride##i##_##func##_nchw44_NEON( \ | ||||
@@ -779,7 +782,7 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( | |||||
c * IH * IW * 4, \ | c * IH * IW * 4, \ | ||||
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | ||||
c * OH * OW * 4, \ | c * OH * OW * 4, \ | ||||
IH, IW, OH, OW, PH, PW); \ | |||||
IH, IW, OH, OW, PH, PW, ws); \ | |||||
}; \ | }; \ | ||||
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | ||||
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ | static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ | ||||
@@ -787,21 +790,21 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( | |||||
} \ | } \ | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
#define DISPATCH_STRIDE(type, func, midout_type_id) \ | |||||
switch (SW) { \ | |||||
case 1: { \ | |||||
DISPATCH_FUNC(type, func, midout_type_id, 1); \ | |||||
break; \ | |||||
} \ | |||||
case 2: { \ | |||||
DISPATCH_FUNC(type, func, midout_type_id, 2); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert(0, "unsupport stride size"); \ | |||||
#define DISPATCH_STRIDE(type, func) \ | |||||
switch (SW) { \ | |||||
case 1: { \ | |||||
DISPATCH_FUNC(type, func, 1); \ | |||||
break; \ | |||||
} \ | |||||
case 2: { \ | |||||
DISPATCH_FUNC(type, func, 2); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert(0, "unsupport stride size"); \ | |||||
} | } | ||||
DISPATCH_STRIDE(int8_t, int8, 11); | |||||
DISPATCH_STRIDE(int8_t, int8); | |||||
#undef DISPATCH_STRIDE | #undef DISPATCH_STRIDE | ||||
#undef DISPATCH_FUNC | #undef DISPATCH_FUNC | ||||
@@ -813,13 +816,11 @@ bool PoolingImpl::AlgoFilter5MaxStridexNCHW44::usable( | |||||
auto SW = param.stride[1]; | auto SW = param.stride[1]; | ||||
auto FH = param.filter[0]; | auto FH = param.filter[0]; | ||||
auto FW = param.filter[1]; | auto FW = param.filter[1]; | ||||
auto PH = param.padding[0]; | |||||
auto PW = param.padding[1]; | |||||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | ||||
param.format == Param::Format::NCHW44 && | param.format == Param::Format::NCHW44 && | ||||
param.mode == Mode::MAX && FH == 5 && FW == 5 && SH == SW && | param.mode == Mode::MAX && FH == 5 && FW == 5 && SH == SW && | ||||
(SW == 1 || SW == 2) && PH == 0 && PW == 0; | |||||
(SW == 1 || SW == 2); | |||||
return avaible; | return avaible; | ||||
} | } | ||||
@@ -835,12 +836,16 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( | |||||
void* src_ptr = param.src_ptr; | void* src_ptr = param.src_ptr; | ||||
void* dst_ptr = param.dst_ptr; | void* dst_ptr = param.dst_ptr; | ||||
#define DISPATCH_FUNC(type, func, midout_type_id, i) \ | |||||
#define DISPATCH_FUNC(type, func, i) \ | |||||
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ | MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ | ||||
midout_iv(midout_type_id)) { \ | |||||
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ | |||||
midout_iv(#func #i##_hash)) { \ | |||||
WorkspaceBundle wbundle = get_bundle_nchw44(param); \ | |||||
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ | |||||
wbundle = wbundle, \ | |||||
workspace_ptr = param.workspace<dt_byte>()]( \ | |||||
size_t index, size_t thread_id) { \ | size_t index, size_t thread_id) { \ | ||||
MEGDNN_MARK_USED_VAR(thread_id); \ | |||||
auto ws = wbundle; \ | |||||
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ | |||||
size_t n = index / C; \ | size_t n = index / C; \ | ||||
size_t c = index % C; \ | size_t c = index % C; \ | ||||
do_max_pooling_5x5_stride##i##_##func##_nchw44_NEON( \ | do_max_pooling_5x5_stride##i##_##func##_nchw44_NEON( \ | ||||
@@ -848,7 +853,7 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( | |||||
c * IH * IW * 4, \ | c * IH * IW * 4, \ | ||||
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | ||||
c * OH * OW * 4, \ | c * OH * OW * 4, \ | ||||
IH, IW, OH, OW, PH, PW); \ | |||||
IH, IW, OH, OW, PH, PW, ws); \ | |||||
}; \ | }; \ | ||||
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | ||||
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ | static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ | ||||
@@ -856,21 +861,21 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( | |||||
} \ | } \ | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
#define DISPATCH_STRIDE(type, func, midout_type_id) \ | |||||
switch (SW) { \ | |||||
case 1: { \ | |||||
DISPATCH_FUNC(type, func, midout_type_id, 1); \ | |||||
break; \ | |||||
} \ | |||||
case 2: { \ | |||||
DISPATCH_FUNC(type, func, midout_type_id, 2); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert(0, "unsupport stride size"); \ | |||||
#define DISPATCH_STRIDE(type, func) \ | |||||
switch (SW) { \ | |||||
case 1: { \ | |||||
DISPATCH_FUNC(type, func, 1); \ | |||||
break; \ | |||||
} \ | |||||
case 2: { \ | |||||
DISPATCH_FUNC(type, func, 2); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert(0, "unsupport stride size"); \ | |||||
} | } | ||||
DISPATCH_STRIDE(int8_t, int8, 12); | |||||
DISPATCH_STRIDE(int8_t, int8); | |||||
#undef DISPATCH_STRIDE | #undef DISPATCH_STRIDE | ||||
#undef DISPATCH_FUNC | #undef DISPATCH_FUNC | ||||
@@ -83,18 +83,10 @@ public: | |||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
}; | }; | ||||
// Declarations of the 3x3 NCHW44 max-pooling algorithm classes. This
// rendering interleaves three class headers (Stride2, Stride1 and Stridex);
// lines shown in both columns are shared between variants.
//
// Every variant derives from AlgoBase, reports is_reproducible() == true,
// returns a fixed algorithm name string from name(), and declares usable()
// and exec() overrides whose definitions live in the implementation file.
class PoolingImpl::AlgoFilter3MaxStride2NCHW44 final : public AlgoBase { | |||||
class PoolingImpl::AlgoFilter3MaxStridexNCHW44 final : public AlgoBase { | |||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { return "ARM_POOLING_FILTER3_MAX_STRIDE2_NCHW44"; } | |||||
bool usable(const PoolingKernSizeParam& param) const override; | |||||
void exec(const PoolingKernParam& param) const override; | |||||
}; | |||||
class PoolingImpl::AlgoFilter3MaxStride1NCHW44 final : public AlgoBase { | |||||
public: | |||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "ARM_POOLING_FILTER3_MAX_STRIDE1_NCHW44"; } | |||||
const char* name() const override { return "ARM_POOLING_FILTER3_MAX_STRIDEX_NCHW44"; } | |||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
}; | }; | ||||
@@ -125,6 +117,12 @@ public: | |||||
WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param); | WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param); | ||||
WorkspaceBundle get_bundle_nchw44( | |||||
const PoolingImpl::PoolingKernSizeParam& param); | |||||
const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW, | |||||
size_t& IH2, size_t& IW2, size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws); | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -1,91 +0,0 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/unroll_macro.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
// 3x3 max pooling with stride 1 over one int8 plane in which every pixel
// occupies 4 consecutive int8 values (see the `* 4` factors in all offsets).
//
// src: input plane, rows of IW pixels.
// dst: output plane, OH rows of OW pixels.
// IH, PH and PW are accepted but never read in this body; no padding is
// applied here, so output (oh, ow) reads input rows oh..oh+2 and columns
// ow..ow+2.
void do_max_pooling_3x3_s1x1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, size_t OH, | |||||
size_t OW, size_t PH, size_t PW) { | |||||
size_t oh = 0; | |||||
for (; oh < OH; ++oh) { | |||||
// Stride 1: the window for output row oh starts at input row oh.
size_t ih = oh; | |||||
// One pointer per window row.
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; | |||||
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; | |||||
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | |||||
size_t ow = 0; | |||||
// Vector path: 4 output pixels (16 bytes) per iteration.
for (; ow + 3 < OW; ow += 4) { | |||||
// Three 16-byte loads shifted by one pixel (4 bytes) each give the
// left, middle and right window columns for 4 adjacent outputs.
int8x16_t src0123 = vld1q_s8(sptr0); | |||||
int8x16_t src1234 = vld1q_s8(sptr0 + 4); | |||||
int8x16_t src2345 = vld1q_s8(sptr0 + 8); | |||||
// Horizontal max of window row 0.
int8x16_t max0 = vmaxq_s8(src0123, src1234); | |||||
max0 = vmaxq_s8(max0, src2345); | |||||
// Same for window row 1.
src0123 = vld1q_s8(sptr1); | |||||
src1234 = vld1q_s8(sptr1 + 4); | |||||
src2345 = vld1q_s8(sptr1 + 8); | |||||
int8x16_t max1 = vmaxq_s8(src0123, src1234); | |||||
max1 = vmaxq_s8(max1, src2345); | |||||
// Same for window row 2.
src0123 = vld1q_s8(sptr2); | |||||
src1234 = vld1q_s8(sptr2 + 4); | |||||
src2345 = vld1q_s8(sptr2 + 8); | |||||
int8x16_t max2 = vmaxq_s8(src0123, src1234); | |||||
max2 = vmaxq_s8(max2, src2345); | |||||
// Vertical max across the three rows, then store 4 output pixels.
int8x16_t max_out = vmaxq_s8(max0, max1); | |||||
max_out = vmaxq_s8(max_out, max2); | |||||
vst1q_s8(dptr, max_out); | |||||
// Advance input and output by 4 pixels.
sptr0 += 16; | |||||
sptr1 += 16; | |||||
sptr2 += 16; | |||||
dptr += 16; | |||||
} | |||||
// Tail path: one output pixel per iteration.
for (; ow < OW; ++ow) { | |||||
// Each 8-byte load holds two pixels: columns (0,1) and columns (1,2).
int8x8_t src001 = vld1_s8(sptr0); | |||||
int8x8_t src012 = vld1_s8(sptr0 + 4); | |||||
int8x8_t src101 = vld1_s8(sptr1); | |||||
int8x8_t src112 = vld1_s8(sptr1 + 4); | |||||
int8x8_t src201 = vld1_s8(sptr2); | |||||
int8x8_t src212 = vld1_s8(sptr2 + 4); | |||||
// Vertical max over the three rows for each column pair.
int8x8_t max01_tmp = vmax_s8(src001, src101); | |||||
max01_tmp = vmax_s8(max01_tmp, src201); | |||||
int8x8_t max12_tmp = vmax_s8(src012, src112); | |||||
max12_tmp = vmax_s8(max12_tmp, src212); | |||||
// cb(i): for packed value i, combine column 0 (lane i), column 1
// (lane i + 4) and column 2 (lane i + 4 of max12_tmp).
#define cb(i) \ | |||||
int8_t dst##i = std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), \ | |||||
max12_tmp[i + 4]); | |||||
#define store(i) *(dptr + i) = dst##i; | |||||
UNROLL_CALL_NOWRAPPER(4, cb) | |||||
UNROLL_CALL_NOWRAPPER(4, store) | |||||
#undef store | |||||
#undef cb | |||||
// Advance input and output by one pixel.
sptr0 += 4; | |||||
sptr1 += 4; | |||||
sptr2 += 4; | |||||
dptr += 4; | |||||
} | |||||
} | |||||
} | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -1,25 +0,0 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/common/utils.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
// Declaration of the int8 3x3 stride-1 max-pooling kernel.
// src/dst are the input and output planes; IH/IW and OH/OW are their
// heights and widths; PH/PW are the padding sizes.
// NOTE(review): the 4-values-per-pixel NCHW44 layout is implied by the name
// only; confirm against the definition.
void do_max_pooling_3x3_s1x1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, size_t OH, | |||||
size_t OW, size_t PH, size_t PW); | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -1,112 +0,0 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/unroll_macro.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
// 3x3 max pooling with stride 2 over one int8 plane in which every pixel
// occupies 4 consecutive int8 values (see the `* 4` factors in all offsets).
//
// src: input plane, rows of IW pixels.
// dst: output plane, OH rows of OW pixels.
// IH, PH and PW are accepted but never read in this body; no padding is
// applied here, so output (oh, ow) reads input rows 2*oh..2*oh+2 and columns
// 2*ow..2*ow+2.
void do_max_pooling_3x3_s2x2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, size_t OH, | |||||
size_t OW, size_t PH, size_t PW) { | |||||
size_t oh = 0; | |||||
for (; oh < OH; ++oh) { | |||||
// Stride 2: the window for output row oh starts at input row 2 * oh.
size_t ih = oh << 1; | |||||
// One pointer per window row.
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; | |||||
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; | |||||
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | |||||
size_t ow = 0; | |||||
// Vector path: 4 output pixels per iteration, consuming 8 input pixels.
for (; ow + 3 < OW; ow += 4) { | |||||
// Load input pixels 0..3, 4..7 and 8..11 of window row 0.
// NOTE(review): only pixel 8 of the third load is used below (via
// vextq_s32); confirm the source buffer is large enough for the full
// 16-byte read at the end of the last row.
int8x16_t src00 = vld1q_s8(sptr0); | |||||
int8x16_t src04 = vld1q_s8(sptr0 + 4 * 4); | |||||
int8x16_t src08 = vld1q_s8(sptr0 + 4 * 8); | |||||
// Treat each 4-byte pixel as one 32-bit lane and de-interleave into
// even pixels (0,2,4,6) and odd pixels (1,3,5,7).
int32x4x2_t src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), | |||||
vreinterpretq_s32_s8(src04)); | |||||
int32x4_t src0246 = src_tmp.val[0]; | |||||
int32x4_t src1357 = src_tmp.val[1]; | |||||
// Shift the even pixels by one lane and append pixel 8: (2,4,6,8).
int32x4_t src2468 = | |||||
vextq_s32(src0246, vreinterpretq_s32_s8(src08), 1); | |||||
// Horizontal max of window row 0 for the 4 outputs.
int8x16_t max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), | |||||
vreinterpretq_s8_s32(src1357)); | |||||
int8x16_t max0 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); | |||||
// Same for window row 1.
int8x16_t src10 = vld1q_s8(sptr1); | |||||
int8x16_t src14 = vld1q_s8(sptr1 + 4 * 4); | |||||
int8x16_t src18 = vld1q_s8(sptr1 + 4 * 8); | |||||
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src10), | |||||
vreinterpretq_s32_s8(src14)); | |||||
src0246 = src_tmp.val[0]; | |||||
src1357 = src_tmp.val[1]; | |||||
src2468 = vextq_s32(src0246, vreinterpretq_s32_s8(src18), 1); | |||||
max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), | |||||
vreinterpretq_s8_s32(src1357)); | |||||
int8x16_t max1 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); | |||||
// Same for window row 2.
int8x16_t src20 = vld1q_s8(sptr2); | |||||
int8x16_t src24 = vld1q_s8(sptr2 + 4 * 4); | |||||
int8x16_t src28 = vld1q_s8(sptr2 + 4 * 8); | |||||
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src20), | |||||
vreinterpretq_s32_s8(src24)); | |||||
src0246 = src_tmp.val[0]; | |||||
src1357 = src_tmp.val[1]; | |||||
src2468 = vextq_s32(src0246, vreinterpretq_s32_s8(src28), 1); | |||||
max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), | |||||
vreinterpretq_s8_s32(src1357)); | |||||
int8x16_t max2 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); | |||||
// Vertical max across the three rows, then store 4 output pixels.
max_tmp = vmaxq_s8(max0, max1); | |||||
int8x16_t max_out = vmaxq_s8(max_tmp, max2); | |||||
vst1q_s8(dptr, max_out); | |||||
// Input advances 8 pixels (32 bytes), output 4 pixels (16 bytes).
sptr0 += 32; | |||||
sptr1 += 32; | |||||
sptr2 += 32; | |||||
dptr += 16; | |||||
} | |||||
// Tail path: one output pixel per iteration.
for (; ow < OW; ++ow) { | |||||
// Each 8-byte load holds two pixels: columns (0,1) and columns (1,2).
int8x8_t src001 = vld1_s8(sptr0); | |||||
int8x8_t src012 = vld1_s8(sptr0 + 4); | |||||
int8x8_t src101 = vld1_s8(sptr1); | |||||
int8x8_t src112 = vld1_s8(sptr1 + 4); | |||||
int8x8_t src201 = vld1_s8(sptr2); | |||||
int8x8_t src212 = vld1_s8(sptr2 + 4); | |||||
// Vertical max over the three rows for each column pair.
int8x8_t max01_tmp = vmax_s8(src001, src101); | |||||
max01_tmp = vmax_s8(max01_tmp, src201); | |||||
int8x8_t max12_tmp = vmax_s8(src012, src112); | |||||
max12_tmp = vmax_s8(max12_tmp, src212); | |||||
// cb(i): for packed value i, combine column 0 (lane i), column 1
// (lane i + 4) and column 2 (lane i + 4 of max12_tmp).
#define cb(i) \ | |||||
int8_t dst##i = std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), \ | |||||
max12_tmp[i + 4]); | |||||
#define store(i) *(dptr + i) = dst##i; | |||||
UNROLL_CALL_NOWRAPPER(4, cb) | |||||
UNROLL_CALL_NOWRAPPER(4, store) | |||||
#undef store | |||||
#undef cb | |||||
// Input advances 2 pixels (stride 2), output 1 pixel.
sptr0 += 8; | |||||
sptr1 += 8; | |||||
sptr2 += 8; | |||||
dptr += 4; | |||||
} | |||||
} | |||||
} | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -9,7 +9,8 @@ | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h" | |||||
#include "src/arm_common/pooling/do_pooling_2x2_nchw44.h" | |||||
#include "src/arm_common/pooling/algo.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
@@ -19,12 +20,16 @@ namespace arm_common { | |||||
void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | ||||
size_t IH, size_t IW, | size_t IH, size_t IW, | ||||
size_t OH, size_t OW, | size_t OH, size_t OW, | ||||
size_t PH, size_t PW) { | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
const int8_t* sptr = nullptr; | |||||
size_t IH2, IW2; | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); | |||||
size_t oh = 0; | size_t oh = 0; | ||||
for (; oh < OH; ++oh) { | for (; oh < OH; ++oh) { | ||||
size_t ih = oh; | size_t ih = oh; | ||||
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; | |||||
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; | |||||
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; | |||||
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | int8_t* __restrict dptr = dst + oh * OW * 4; | ||||
size_t ow = 0; | size_t ow = 0; | ||||
for (; ow + 3 < OW; ow += 4) { | for (; ow + 3 < OW; ow += 4) { | ||||
@@ -46,15 +51,10 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
} | } | ||||
for (; ow < OW; ++ow) { | for (; ow < OW; ++ow) { | ||||
int8x8_t src001 = vld1_s8(sptr0); | int8x8_t src001 = vld1_s8(sptr0); | ||||
int8x8_t src012 = vld1_s8(sptr0 + 4); | |||||
int8x8_t src101 = vld1_s8(sptr1); | int8x8_t src101 = vld1_s8(sptr1); | ||||
int8x8_t src112 = vld1_s8(sptr1 + 4); | |||||
int8x8_t max01_tmp = vmax_s8(src001, src101); | |||||
int8x8_t max12_tmp = vmax_s8(src012, src112); | |||||
int8x8_t mat_out = vmax_s8(max01_tmp, max12_tmp); | |||||
#define store(i) *(dptr + i) = mat_out[i]; | |||||
int8x8_t max_out = vmax_s8(src001, src101); | |||||
#define store(i) *(dptr + i) = std::max(max_out[i], max_out[i + 4]); | |||||
UNROLL_CALL_NOWRAPPER(4, store) | UNROLL_CALL_NOWRAPPER(4, store) | ||||
#undef store | #undef store | ||||
sptr0 += 4; | sptr0 += 4; | ||||
@@ -66,12 +66,16 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | ||||
size_t IH, size_t IW, | size_t IH, size_t IW, | ||||
size_t OH, size_t OW, | size_t OH, size_t OW, | ||||
size_t PH, size_t PW) { | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
const int8_t* sptr = nullptr; | |||||
size_t IH2, IW2; | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); | |||||
size_t oh = 0; | size_t oh = 0; | ||||
for (; oh < OH; ++oh) { | for (; oh < OH; ++oh) { | ||||
size_t ih = oh << 1; | size_t ih = oh << 1; | ||||
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; | |||||
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; | |||||
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; | |||||
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | int8_t* __restrict dptr = dst + oh * OW * 4; | ||||
size_t ow = 0; | size_t ow = 0; | ||||
for (; ow + 3 < OW; ow += 4) { | for (; ow + 3 < OW; ow += 4) { | ||||
@@ -103,15 +107,10 @@ void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
} | } | ||||
for (; ow < OW; ++ow) { | for (; ow < OW; ++ow) { | ||||
int8x8_t src001 = vld1_s8(sptr0); | int8x8_t src001 = vld1_s8(sptr0); | ||||
int8x8_t src012 = vld1_s8(sptr0 + 4); | |||||
int8x8_t src101 = vld1_s8(sptr1); | int8x8_t src101 = vld1_s8(sptr1); | ||||
int8x8_t src112 = vld1_s8(sptr1 + 4); | |||||
int8x8_t max01_tmp = vmax_s8(src001, src101); | |||||
int8x8_t max12_tmp = vmax_s8(src012, src112); | |||||
int8x8_t mat_out = vmax_s8(max01_tmp, max12_tmp); | |||||
#define store(i) *(dptr + i) = mat_out[i]; | |||||
int8x8_t max_out = vmax_s8(src001, src101); | |||||
#define store(i) *(dptr + i) = std::max(max_out[i], max_out[i + 4]); | |||||
UNROLL_CALL_NOWRAPPER(4, store) | UNROLL_CALL_NOWRAPPER(4, store) | ||||
#undef store | #undef store | ||||
sptr0 += 8; | sptr0 += 8; |
@@ -18,11 +18,13 @@ namespace arm_common { | |||||
void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | ||||
size_t IH, size_t IW, | size_t IH, size_t IW, | ||||
size_t OH, size_t OW, | size_t OH, size_t OW, | ||||
size_t PH, size_t PW); | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws); | |||||
void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | ||||
size_t IH, size_t IW, | size_t IH, size_t IW, | ||||
size_t OH, size_t OW, | size_t OH, size_t OW, | ||||
size_t PH, size_t PW); | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws); | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn |
@@ -0,0 +1,195 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/arm_common/pooling/do_pooling_3x3_nchw44.h" | |||||
#include "src/arm_common/pooling/algo.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/unroll_macro.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
void do_max_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, | |||||
size_t OH, size_t OW, | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
const int8_t* sptr = nullptr; | |||||
size_t IH2, IW2; | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); | |||||
size_t oh = 0; | |||||
for (; oh < OH; ++oh) { | |||||
size_t ih = oh; | |||||
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; | |||||
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; | |||||
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | |||||
size_t ow = 0; | |||||
for (; ow + 3 < OW; ow += 4) { | |||||
int8x16_t src0123 = vld1q_s8(sptr0); | |||||
int8x16_t src1234 = vld1q_s8(sptr0 + 4); | |||||
int8x16_t src2345 = vld1q_s8(sptr0 + 8); | |||||
int8x16_t max0 = vmaxq_s8(src0123, src1234); | |||||
max0 = vmaxq_s8(max0, src2345); | |||||
src0123 = vld1q_s8(sptr1); | |||||
src1234 = vld1q_s8(sptr1 + 4); | |||||
src2345 = vld1q_s8(sptr1 + 8); | |||||
int8x16_t max1 = vmaxq_s8(src0123, src1234); | |||||
max1 = vmaxq_s8(max1, src2345); | |||||
src0123 = vld1q_s8(sptr2); | |||||
src1234 = vld1q_s8(sptr2 + 4); | |||||
src2345 = vld1q_s8(sptr2 + 8); | |||||
int8x16_t max2 = vmaxq_s8(src0123, src1234); | |||||
max2 = vmaxq_s8(max2, src2345); | |||||
int8x16_t max_out = vmaxq_s8(max0, max1); | |||||
max_out = vmaxq_s8(max_out, max2); | |||||
vst1q_s8(dptr, max_out); | |||||
sptr0 += 16; | |||||
sptr1 += 16; | |||||
sptr2 += 16; | |||||
dptr += 16; | |||||
} | |||||
for (; ow < OW; ++ow) { | |||||
int8x8_t src001 = vld1_s8(sptr0); | |||||
int8x8_t src012 = vld1_s8(sptr0 + 4); | |||||
int8x8_t src101 = vld1_s8(sptr1); | |||||
int8x8_t src112 = vld1_s8(sptr1 + 4); | |||||
int8x8_t src201 = vld1_s8(sptr2); | |||||
int8x8_t src212 = vld1_s8(sptr2 + 4); | |||||
int8x8_t max01_tmp = vmax_s8(src001, src101); | |||||
max01_tmp = vmax_s8(max01_tmp, src201); | |||||
int8x8_t max12_tmp = vmax_s8(src012, src112); | |||||
max12_tmp = vmax_s8(max12_tmp, src212); | |||||
#define cb(i) \ | |||||
int8_t dst##i = std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), \ | |||||
max12_tmp[i + 4]); | |||||
#define store(i) *(dptr + i) = dst##i; | |||||
UNROLL_CALL_NOWRAPPER(4, cb) | |||||
UNROLL_CALL_NOWRAPPER(4, store) | |||||
#undef store | |||||
#undef cb | |||||
sptr0 += 4; | |||||
sptr1 += 4; | |||||
sptr2 += 4; | |||||
dptr += 4; | |||||
} | |||||
} | |||||
} | |||||
void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, | |||||
size_t OH, size_t OW, | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
const int8_t* sptr = nullptr; | |||||
size_t IH2, IW2; | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); | |||||
size_t oh = 0; | |||||
for (; oh < OH; ++oh) { | |||||
size_t ih = oh << 1; | |||||
const int8_t* sptr0 = sptr + (ih + 0) * IW2 * 4; | |||||
const int8_t* sptr1 = sptr + (ih + 1) * IW2 * 4; | |||||
const int8_t* sptr2 = sptr + (ih + 2) * IW2 * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | |||||
size_t ow = 0; | |||||
for (; ow + 3 < OW; ow += 4) { | |||||
int8x16_t src00 = vld1q_s8(sptr0); | |||||
int8x16_t src04 = vld1q_s8(sptr0 + 4 * 4); | |||||
int32x4_t src08 = vld1q_dup_s32( | |||||
reinterpret_cast<const int32_t*>(sptr0 + 4 * 8)); | |||||
int32x4x2_t src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), | |||||
vreinterpretq_s32_s8(src04)); | |||||
int32x4_t src0246 = src_tmp.val[0]; | |||||
int32x4_t src1357 = src_tmp.val[1]; | |||||
int32x4_t src2468 = vextq_s32(src0246, src08, 1); | |||||
int8x16_t max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), | |||||
vreinterpretq_s8_s32(src1357)); | |||||
int8x16_t max0 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); | |||||
int8x16_t src10 = vld1q_s8(sptr1); | |||||
int8x16_t src14 = vld1q_s8(sptr1 + 4 * 4); | |||||
int32x4_t src18 = vld1q_dup_s32( | |||||
reinterpret_cast<const int32_t*>(sptr1 + 4 * 8)); | |||||
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src10), | |||||
vreinterpretq_s32_s8(src14)); | |||||
src0246 = src_tmp.val[0]; | |||||
src1357 = src_tmp.val[1]; | |||||
src2468 = vextq_s32(src0246, src18, 1); | |||||
max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), | |||||
vreinterpretq_s8_s32(src1357)); | |||||
int8x16_t max1 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); | |||||
int8x16_t src20 = vld1q_s8(sptr2); | |||||
int8x16_t src24 = vld1q_s8(sptr2 + 4 * 4); | |||||
int32x4_t src28 = vld1q_dup_s32( | |||||
reinterpret_cast<const int32_t*>(sptr2 + 4 * 8)); | |||||
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src20), | |||||
vreinterpretq_s32_s8(src24)); | |||||
src0246 = src_tmp.val[0]; | |||||
src1357 = src_tmp.val[1]; | |||||
src2468 = vextq_s32(src0246, src28, 1); | |||||
max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), | |||||
vreinterpretq_s8_s32(src1357)); | |||||
int8x16_t max2 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); | |||||
max_tmp = vmaxq_s8(max0, max1); | |||||
int8x16_t max_out = vmaxq_s8(max_tmp, max2); | |||||
vst1q_s8(dptr, max_out); | |||||
sptr0 += 32; | |||||
sptr1 += 32; | |||||
sptr2 += 32; | |||||
dptr += 16; | |||||
} | |||||
for (; ow < OW; ++ow) { | |||||
int8x8_t src001 = vld1_s8(sptr0); | |||||
int8x8_t src012 = vld1_s8(sptr0 + 4); | |||||
int8x8_t src101 = vld1_s8(sptr1); | |||||
int8x8_t src112 = vld1_s8(sptr1 + 4); | |||||
int8x8_t src201 = vld1_s8(sptr2); | |||||
int8x8_t src212 = vld1_s8(sptr2 + 4); | |||||
int8x8_t max01_tmp = vmax_s8(src001, src101); | |||||
max01_tmp = vmax_s8(max01_tmp, src201); | |||||
int8x8_t max12_tmp = vmax_s8(src012, src112); | |||||
max12_tmp = vmax_s8(max12_tmp, src212); | |||||
#define cb(i) \ | |||||
int8_t dst##i = std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), \ | |||||
max12_tmp[i + 4]); | |||||
#define store(i) *(dptr + i) = dst##i; | |||||
UNROLL_CALL_NOWRAPPER(4, cb) | |||||
UNROLL_CALL_NOWRAPPER(4, store) | |||||
#undef store | |||||
#undef cb | |||||
sptr0 += 8; | |||||
sptr1 += 8; | |||||
sptr2 += 8; | |||||
dptr += 4; | |||||
} | |||||
} | |||||
} | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -15,9 +15,16 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | namespace arm_common { | ||||
void do_max_pooling_3x3_s2x2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
void do_max_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, | |||||
size_t OH, size_t OW, | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws); | |||||
void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, size_t OH, | size_t IH, size_t IW, size_t OH, | ||||
size_t OW, size_t PH, size_t PW); | |||||
size_t OW, size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws); | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn |
@@ -9,7 +9,8 @@ | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/pooling/do_max_pooling_4x4_nchw44.h" | |||||
#include "src/arm_common/pooling/do_pooling_4x4_nchw44.h" | |||||
#include "src/arm_common/pooling/algo.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
@@ -19,14 +20,18 @@ namespace arm_common { | |||||
void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | ||||
size_t IH, size_t IW, | size_t IH, size_t IW, | ||||
size_t OH, size_t OW, | size_t OH, size_t OW, | ||||
size_t PH, size_t PW) { | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
const int8_t* sptr = nullptr; | |||||
size_t IH2, IW2; | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); | |||||
size_t oh = 0; | size_t oh = 0; | ||||
for (; oh < OH; ++oh) { | for (; oh < OH; ++oh) { | ||||
size_t ih = oh; | size_t ih = oh; | ||||
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; | |||||
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; | |||||
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; | |||||
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4; | |||||
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; | |||||
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; | |||||
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4; | |||||
const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW2 * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | int8_t* __restrict dptr = dst + oh * OW * 4; | ||||
size_t ow = 0; | size_t ow = 0; | ||||
for (; ow + 3 < OW; ow += 4) { | for (; ow + 3 < OW; ow += 4) { | ||||
@@ -90,35 +95,38 @@ void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | ||||
size_t IH, size_t IW, | size_t IH, size_t IW, | ||||
size_t OH, size_t OW, | size_t OH, size_t OW, | ||||
size_t PH, size_t PW) { | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
const int8_t* sptr = nullptr; | |||||
size_t IH2, IW2; | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); | |||||
size_t oh = 0; | size_t oh = 0; | ||||
for (; oh < OH; ++oh) { | for (; oh < OH; ++oh) { | ||||
size_t ih = oh << 1; | size_t ih = oh << 1; | ||||
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; | |||||
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; | |||||
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; | |||||
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4; | |||||
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; | |||||
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; | |||||
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4; | |||||
const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW2 * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | int8_t* __restrict dptr = dst + oh * OW * 4; | ||||
size_t ow = 0; | size_t ow = 0; | ||||
for (; ow + 3 < OW; ow += 4) { | for (; ow + 3 < OW; ow += 4) { | ||||
int8x16_t src00, src04, src08, src09, max_tmp0, max_tmp1, max_tmp2, | |||||
max_tmp3; | |||||
int32x4_t src0246, src1357, src2468, src3579; | |||||
int8x16_t src00, src04, max_tmp0, max_tmp1, max_tmp2, max_tmp3; | |||||
int32x4_t src0246, src1357, src2468, src3579, src08, src09; | |||||
int32x4x2_t src_tmp; | int32x4x2_t src_tmp; | ||||
#define CACULATE_ROW(i) \ | |||||
src00 = vld1q_s8(sptr##i); \ | |||||
src04 = vld1q_s8(sptr##i + 4 * 4); \ | |||||
src08 = vld1q_s8(sptr##i + 4 * 8); \ | |||||
src09 = vld1q_s8(sptr##i + 4 * 9); \ | |||||
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ | |||||
vreinterpretq_s32_s8(src04)); \ | |||||
src0246 = src_tmp.val[0]; \ | |||||
src1357 = src_tmp.val[1]; \ | |||||
src2468 = vextq_s32(src0246, vreinterpretq_s32_s8(src08), 1); \ | |||||
src3579 = vextq_s32(src1357, vreinterpretq_s32_s8(src09), 1); \ | |||||
max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \ | |||||
vreinterpretq_s8_s32(src1357)); \ | |||||
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \ | |||||
#define CACULATE_ROW(i) \ | |||||
src00 = vld1q_s8(sptr##i); \ | |||||
src04 = vld1q_s8(sptr##i + 4 * 4); \ | |||||
src08 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 8)); \ | |||||
src09 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 9)); \ | |||||
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ | |||||
vreinterpretq_s32_s8(src04)); \ | |||||
src0246 = src_tmp.val[0]; \ | |||||
src1357 = src_tmp.val[1]; \ | |||||
src2468 = vextq_s32(src0246, src08, 1); \ | |||||
src3579 = vextq_s32(src1357, src09, 1); \ | |||||
max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \ | |||||
vreinterpretq_s8_s32(src1357)); \ | |||||
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \ | |||||
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579)); | max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579)); | ||||
UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW) | UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW) |
@@ -18,7 +18,7 @@ namespace arm_common { | |||||
#define KERN(strdie) \ | #define KERN(strdie) \ | ||||
void do_max_pooling_4x4_##strdie##_int8_nchw44_NEON( \ | void do_max_pooling_4x4_##strdie##_int8_nchw44_NEON( \ | ||||
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ | const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ | ||||
size_t OW, size_t PH, size_t PW); | |||||
size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws); | |||||
KERN(stride1) | KERN(stride1) | ||||
KERN(stride2) | KERN(stride2) |
@@ -9,7 +9,8 @@ | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/pooling/do_max_pooling_5x5_nchw44.h" | |||||
#include "src/arm_common/pooling/do_pooling_5x5_nchw44.h" | |||||
#include "src/arm_common/pooling/algo.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
@@ -19,15 +20,19 @@ namespace arm_common { | |||||
void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | ||||
size_t IH, size_t IW, | size_t IH, size_t IW, | ||||
size_t OH, size_t OW, | size_t OH, size_t OW, | ||||
size_t PH, size_t PW) { | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
const int8_t* sptr = nullptr; | |||||
size_t IH2, IW2; | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); | |||||
size_t oh = 0; | size_t oh = 0; | ||||
for (; oh < OH; ++oh) { | for (; oh < OH; ++oh) { | ||||
size_t ih = oh; | size_t ih = oh; | ||||
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; | |||||
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; | |||||
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; | |||||
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4; | |||||
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4; | |||||
const int8_t* sptr0 = sptr + (ih + 0) * IW2 * 4; | |||||
const int8_t* sptr1 = sptr + (ih + 1) * IW2 * 4; | |||||
const int8_t* sptr2 = sptr + (ih + 2) * IW2 * 4; | |||||
const int8_t* sptr3 = sptr + (ih + 3) * IW2 * 4; | |||||
const int8_t* sptr4 = sptr + (ih + 4) * IW2 * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | int8_t* __restrict dptr = dst + oh * OW * 4; | ||||
size_t ow = 0; | size_t ow = 0; | ||||
for (; ow + 3 < OW; ow += 4) { | for (; ow + 3 < OW; ow += 4) { | ||||
@@ -80,13 +85,16 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
max_out = vmax_s8(max_out, max_tmp3); | max_out = vmax_s8(max_out, max_tmp3); | ||||
max_out = vmax_s8(max_out, max_tmp4); | max_out = vmax_s8(max_out, max_tmp4); | ||||
#define COMPARE_SRC45(i) int8x8_t src##i##_45 = vld1_s8(sptr##i + 4 * 4); | |||||
#define COMPARE_SRC45(i) \ | |||||
int32x2_t src##i##_45 = \ | |||||
vld1_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 4)); | |||||
UNROLL_CALL_NOWRAPPER(5, COMPARE_SRC45) | UNROLL_CALL_NOWRAPPER(5, COMPARE_SRC45) | ||||
int8x8_t max_45 = vmax_s8(src0_45, src1_45); | |||||
max_45 = vmax_s8(max_45, src1_45); | |||||
max_45 = vmax_s8(max_45, src2_45); | |||||
max_45 = vmax_s8(max_45, src3_45); | |||||
max_45 = vmax_s8(max_45, src4_45); | |||||
int8x8_t max_45 = vmax_s8(vreinterpret_s8_s32(src0_45), | |||||
vreinterpret_s8_s32(src1_45)); | |||||
max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src1_45)); | |||||
max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src2_45)); | |||||
max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src3_45)); | |||||
max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src4_45)); | |||||
#define store(i) \ | #define store(i) \ | ||||
*(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]); | *(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]); | ||||
@@ -106,39 +114,44 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | ||||
size_t IH, size_t IW, | size_t IH, size_t IW, | ||||
size_t OH, size_t OW, | size_t OH, size_t OW, | ||||
size_t PH, size_t PW) { | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
const int8_t* sptr = nullptr; | |||||
size_t IH2, IW2; | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); | |||||
size_t oh = 0; | size_t oh = 0; | ||||
for (; oh < OH; ++oh) { | for (; oh < OH; ++oh) { | ||||
size_t ih = oh << 1; | size_t ih = oh << 1; | ||||
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; | |||||
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; | |||||
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; | |||||
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4; | |||||
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4; | |||||
const int8_t* sptr0 = sptr + (ih + 0) * IW2 * 4; | |||||
const int8_t* sptr1 = sptr + (ih + 1) * IW2 * 4; | |||||
const int8_t* sptr2 = sptr + (ih + 2) * IW2 * 4; | |||||
const int8_t* sptr3 = sptr + (ih + 3) * IW2 * 4; | |||||
const int8_t* sptr4 = sptr + (ih + 4) * IW2 * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | int8_t* __restrict dptr = dst + oh * OW * 4; | ||||
size_t ow = 0; | size_t ow = 0; | ||||
for (; ow + 3 < OW; ow += 4) { | for (; ow + 3 < OW; ow += 4) { | ||||
int8x16_t src00, src04, src08, src09, src10, max_tmp0, max_tmp1, | |||||
max_tmp2, max_tmp3, max_tmp4; | |||||
int32x4_t src0246, src1357, src2468, src3579, src46810; | |||||
int8x16_t src00, src04, max_tmp0, max_tmp1, max_tmp2, max_tmp3, | |||||
max_tmp4; | |||||
int32x4_t src0246, src1357, src2468, src3579, src46810, src10, | |||||
src09, src08; | |||||
int32x4x2_t src_tmp; | int32x4x2_t src_tmp; | ||||
#define CACULATE_ROW(i) \ | |||||
src00 = vld1q_s8(sptr##i); \ | |||||
src04 = vld1q_s8(sptr##i + 4 * 4); \ | |||||
src08 = vld1q_s8(sptr##i + 4 * 8); \ | |||||
src09 = vld1q_s8(sptr##i + 4 * 9); \ | |||||
src10 = vld1q_s8(sptr##i + 4 * 10); \ | |||||
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ | |||||
vreinterpretq_s32_s8(src04)); \ | |||||
src0246 = src_tmp.val[0]; \ | |||||
src1357 = src_tmp.val[1]; \ | |||||
src2468 = vextq_s32(src0246, vreinterpretq_s32_s8(src08), 1); \ | |||||
src3579 = vextq_s32(src1357, vreinterpretq_s32_s8(src09), 1); \ | |||||
src46810 = vextq_s32(src2468, vreinterpretq_s32_s8(src10), 1); \ | |||||
max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \ | |||||
vreinterpretq_s8_s32(src1357)); \ | |||||
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \ | |||||
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579)); \ | |||||
#define CACULATE_ROW(i) \ | |||||
src00 = vld1q_s8(sptr##i); \ | |||||
src04 = vld1q_s8(sptr##i + 4 * 4); \ | |||||
src08 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 8)); \ | |||||
src09 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 9)); \ | |||||
src10 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 10)); \ | |||||
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ | |||||
vreinterpretq_s32_s8(src04)); \ | |||||
src0246 = src_tmp.val[0]; \ | |||||
src1357 = src_tmp.val[1]; \ | |||||
src2468 = vextq_s32(src0246, src08, 1); \ | |||||
src3579 = vextq_s32(src1357, src09, 1); \ | |||||
src46810 = vextq_s32(src2468, src10, 1); \ | |||||
max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \ | |||||
vreinterpretq_s8_s32(src1357)); \ | |||||
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \ | |||||
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579)); \ | |||||
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src46810)); | max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src46810)); | ||||
UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) | UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) | ||||
@@ -173,13 +186,16 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
max_out = vmax_s8(max_out, max_tmp3); | max_out = vmax_s8(max_out, max_tmp3); | ||||
max_out = vmax_s8(max_out, max_tmp4); | max_out = vmax_s8(max_out, max_tmp4); | ||||
#define COMPARE_SRC45(i) int8x8_t src##i##_45 = vld1_s8(sptr##i + 4 * 4); | |||||
#define COMPARE_SRC45(i) \ | |||||
int32x2_t src##i##_45 = \ | |||||
vld1_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 4)); | |||||
UNROLL_CALL_NOWRAPPER(5, COMPARE_SRC45) | UNROLL_CALL_NOWRAPPER(5, COMPARE_SRC45) | ||||
int8x8_t max_45 = vmax_s8(src0_45, src1_45); | |||||
max_45 = vmax_s8(max_45, src1_45); | |||||
max_45 = vmax_s8(max_45, src2_45); | |||||
max_45 = vmax_s8(max_45, src3_45); | |||||
max_45 = vmax_s8(max_45, src4_45); | |||||
int8x8_t max_45 = vmax_s8(vreinterpret_s8_s32(src0_45), | |||||
vreinterpret_s8_s32(src1_45)); | |||||
max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src1_45)); | |||||
max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src2_45)); | |||||
max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src3_45)); | |||||
max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src4_45)); | |||||
#define store(i) \ | #define store(i) \ | ||||
*(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]); | *(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]); |
@@ -18,7 +18,7 @@ namespace arm_common { | |||||
#define KERN(strdie) \ | #define KERN(strdie) \ | ||||
void do_max_pooling_5x5_##strdie##_int8_nchw44_NEON( \ | void do_max_pooling_5x5_##strdie##_int8_nchw44_NEON( \ | ||||
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ | const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ | ||||
size_t OW, size_t PH, size_t PW); | |||||
size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws); | |||||
KERN(stride1) | KERN(stride1) | ||||
KERN(stride2) | KERN(stride2) |
@@ -26,8 +26,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj { | |||||
AlgoInt8Filter2MaxStride2 algo_int8_filter2_max_stride2; | AlgoInt8Filter2MaxStride2 algo_int8_filter2_max_stride2; | ||||
AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2; | AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2; | ||||
AlgoFilter2MaxStridexNCHW44 algo_filter2_max_stridex_nchw4; | AlgoFilter2MaxStridexNCHW44 algo_filter2_max_stridex_nchw4; | ||||
AlgoFilter3MaxStride2NCHW44 algo_filter3_max_stride2_nchw4; | |||||
AlgoFilter3MaxStride1NCHW44 algo_filter3_max_stride1_nchw4; | |||||
AlgoFilter3MaxStridexNCHW44 algo_filter3_max_stridex_nchw4; | |||||
AlgoFilter4MaxStridexNCHW44 algo_filter4_max_stridex_nchw4; | AlgoFilter4MaxStridexNCHW44 algo_filter4_max_stridex_nchw4; | ||||
AlgoFilter5MaxStridexNCHW44 algo_filter5_max_stridex_nchw4; | AlgoFilter5MaxStridexNCHW44 algo_filter5_max_stridex_nchw4; | ||||
@@ -41,8 +40,7 @@ public: | |||||
all_algos.emplace_back(&algo_filter5_max_stride2); | all_algos.emplace_back(&algo_filter5_max_stride2); | ||||
all_algos.emplace_back(&algo_int8_filter2_max_stride2); | all_algos.emplace_back(&algo_int8_filter2_max_stride2); | ||||
all_algos.emplace_back(&algo_int8_filter3_max_stride2); | all_algos.emplace_back(&algo_int8_filter3_max_stride2); | ||||
all_algos.emplace_back(&algo_filter3_max_stride2_nchw4); | |||||
all_algos.emplace_back(&algo_filter3_max_stride1_nchw4); | |||||
all_algos.emplace_back(&algo_filter3_max_stridex_nchw4); | |||||
all_algos.emplace_back(&algo_filter2_max_stridex_nchw4); | all_algos.emplace_back(&algo_filter2_max_stridex_nchw4); | ||||
all_algos.emplace_back(&algo_filter4_max_stridex_nchw4); | all_algos.emplace_back(&algo_filter4_max_stridex_nchw4); | ||||
all_algos.emplace_back(&algo_filter5_max_stridex_nchw4); | all_algos.emplace_back(&algo_filter5_max_stridex_nchw4); | ||||
@@ -119,6 +117,12 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
arm_common_workspace = ws.total_size_in_bytes() * nr_threads; | arm_common_workspace = ws.total_size_in_bytes() * nr_threads; | ||||
} | } | ||||
if ((param.src_type.enumv() == DTypeEnum::QuantizedS8) && | |||||
(param.format == param::Pooling::Format::NCHW44)) { | |||||
WorkspaceBundle ws = get_bundle_nchw44(param); | |||||
arm_common_workspace = ws.total_size_in_bytes() * nr_threads; | |||||
} | |||||
if (find_algo) { | if (find_algo) { | ||||
return arm_common_workspace; | return arm_common_workspace; | ||||
} else { | } else { | ||||
@@ -84,8 +84,7 @@ private: | |||||
class AlgoInt8Filter2MaxStride2; | class AlgoInt8Filter2MaxStride2; | ||||
class AlgoInt8Filter3MaxStride2; | class AlgoInt8Filter3MaxStride2; | ||||
class AlgoFilter2MaxStridexNCHW44; | class AlgoFilter2MaxStridexNCHW44; | ||||
class AlgoFilter3MaxStride2NCHW44; | |||||
class AlgoFilter3MaxStride1NCHW44; | |||||
class AlgoFilter3MaxStridexNCHW44; | |||||
class AlgoFilter4MaxStridexNCHW44; | class AlgoFilter4MaxStridexNCHW44; | ||||
class AlgoFilter5MaxStridexNCHW44; | class AlgoFilter5MaxStridexNCHW44; | ||||
class AlgoPack; | class AlgoPack; | ||||
@@ -8,8 +8,6 @@ | |||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "megdnn/dtype.h" | |||||
#include "megdnn/opr_param_defs.h" | |||||
#include "test/arm_common/fixture.h" | #include "test/arm_common/fixture.h" | ||||
#include "test/common/pooling.h" | #include "test/common/pooling.h" | ||||
@@ -102,209 +100,6 @@ TEST_F(ARM_COMMON, POOLING_INT8_W3x3_S2x2) | |||||
// clang-format on | // clang-format on | ||||
} | } | ||||
// Runs 3x3 / stride-2 MAX pooling on QuantizedS8 tensors in NCHW44 format
// through the Checker for a grid of input sizes; padding is fixed at zero.
TEST_F(ARM_COMMON, POOLING_MAX_W3x3_S2x2_NCHW44) {
    for (size_t ih : {3, 5, 10}) {
        for (size_t iw : {3, 5, 7, 9, 15, 20}) {
            for (size_t ph : {0}) {
                for (size_t pw : {0}) {
                    // Skip shapes whose padded extent cannot hold one window.
                    if (ih + 2 * ph < 3 || iw + 2 * pw < 3) {
                        continue;
                    }

                    // Inputs are drawn from half of the int8 range.
                    UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
                    Checker<Pooling> checker(handle());
                    checker.set_dtype(0, dtype::QuantizedS8(1.1f));
                    checker.set_rng(0, &rng);

                    param::Pooling param;
                    param.mode = param::Pooling::Mode::MAX;
                    param.format = param::Pooling::Format::NCHW44;
                    param.pad_h = ph;
                    param.pad_w = pw;
                    param.stride_w = 2;
                    param.stride_h = 2;
                    param.window_w = 3;
                    param.window_h = 3;

                    checker.set_param(param).exec(
                            TensorShapeArray{{2, 2, ih, iw, 4}, {}});
                }
            }
        }
    }
}
TEST_F(ARM_COMMON, POOLING_MAX_W3x3_S1x1_NCHW44) | |||||
{ | |||||
// clang-format off | |||||
for (size_t ih: {3, 5, 10}) | |||||
for (size_t iw: {3, 5, 7, 9, 15, 20}) | |||||
for (size_t ph: {0}) | |||||
for (size_t pw: {0}) | |||||
if (ih+2*ph >= 3 && iw+2*pw >= 3) | |||||
{ | |||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||||
checker.set_rng(0,&rng); | |||||
param::Pooling param; | |||||
param.mode = param::Pooling::Mode::MAX; | |||||
param.format = param::Pooling::Format::NCHW44; | |||||
param.pad_h = ph; | |||||
param.pad_w = pw; | |||||
param.stride_h = param.stride_w = 1; | |||||
param.window_h = param.window_w = 3; | |||||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | |||||
} | |||||
// clang-format on | |||||
} | |||||
TEST_F(ARM_COMMON, POOLING_MAX_W2x2_S1x1_NCHW44) | |||||
{ | |||||
// clang-format off | |||||
for (size_t ih: {2, 5, 10, 17}) | |||||
for (size_t iw: {2, 6, 8, 16, 26}) | |||||
for (size_t ph: {0}) | |||||
for (size_t pw: {0}) | |||||
if (ih+2*ph >= 2 && iw+2*pw >= 2) | |||||
{ | |||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||||
checker.set_rng(0,&rng); | |||||
param::Pooling param; | |||||
param.mode = param::Pooling::Mode::MAX; | |||||
param.format = param::Pooling::Format::NCHW44; | |||||
param.pad_h = ph; | |||||
param.pad_w = pw; | |||||
param.stride_h = param.stride_w = 1; | |||||
param.window_h = param.window_w = 2; | |||||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | |||||
} | |||||
// clang-format on | |||||
} | |||||
TEST_F(ARM_COMMON, POOLING_MAX_W2x2_S2x2_NCHW44) | |||||
{ | |||||
// clang-format off | |||||
for (size_t ih: {2, 5, 10, 17}) | |||||
for (size_t iw: {2, 6, 8, 16, 26}) | |||||
for (size_t ph: {0}) | |||||
for (size_t pw: {0}) | |||||
if (ih+2*ph >= 2 && iw+2*pw >= 2) | |||||
{ | |||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||||
checker.set_rng(0,&rng); | |||||
param::Pooling param; | |||||
param.mode = param::Pooling::Mode::MAX; | |||||
param.format = param::Pooling::Format::NCHW44; | |||||
param.pad_h = ph; | |||||
param.pad_w = pw; | |||||
param.stride_h = param.stride_w = 2; | |||||
param.window_h = param.window_w = 2; | |||||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | |||||
} | |||||
// clang-format on | |||||
} | |||||
TEST_F(ARM_COMMON, POOLING_MAX_W4x4_S1x1_NCHW44) | |||||
{ | |||||
// clang-format off | |||||
for (size_t ih: {4, 7, 10, 17, 20}) | |||||
for (size_t iw: {4, 8, 10, 21, 32}) | |||||
for (size_t ph: {0}) | |||||
for (size_t pw: {0}) | |||||
if (ih+2*ph >= 2 && iw+2*pw >= 2) | |||||
{ | |||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||||
checker.set_rng(0,&rng); | |||||
param::Pooling param; | |||||
param.mode = param::Pooling::Mode::MAX; | |||||
param.format = param::Pooling::Format::NCHW44; | |||||
param.pad_h = ph; | |||||
param.pad_w = pw; | |||||
param.stride_h = param.stride_w = 1; | |||||
param.window_h = param.window_w = 4; | |||||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | |||||
} | |||||
// clang-format on | |||||
} | |||||
TEST_F(ARM_COMMON, POOLING_MAX_W4x4_S2x2_NCHW44) | |||||
{ | |||||
// clang-format off | |||||
for (size_t ih: {4, 10, 18, 25, 30}) | |||||
for (size_t iw: {4, 12, 17, 20, 25}) | |||||
for (size_t ph: {0}) | |||||
for (size_t pw: {0}) | |||||
if (ih+2*ph >= 2 && iw+2*pw >= 2) | |||||
{ | |||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||||
checker.set_rng(0,&rng); | |||||
param::Pooling param; | |||||
param.mode = param::Pooling::Mode::MAX; | |||||
param.format = param::Pooling::Format::NCHW44; | |||||
param.pad_h = ph; | |||||
param.pad_w = pw; | |||||
param.stride_h = param.stride_w = 2; | |||||
param.window_h = param.window_w = 4; | |||||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | |||||
} | |||||
// clang-format on | |||||
} | |||||
TEST_F(ARM_COMMON, POOLING_MAX_W5x5_S1x1_NCHW44) | |||||
{ | |||||
// clang-format off | |||||
for (size_t ih: {5, 9, 19, 20, 39}) | |||||
for (size_t iw: {5, 12, 23, 27, 39}) | |||||
for (size_t ph: {0}) | |||||
for (size_t pw: {0}) | |||||
if (ih+2*ph >= 5 && iw+2*pw >= 5) | |||||
{ | |||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||||
checker.set_rng(0,&rng); | |||||
param::Pooling param; | |||||
param.mode = param::Pooling::Mode::MAX; | |||||
param.format = param::Pooling::Format::NCHW44; | |||||
param.pad_h = ph; | |||||
param.pad_w = pw; | |||||
param.stride_h = param.stride_w = 1; | |||||
param.window_h = param.window_w = 5; | |||||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | |||||
} | |||||
// clang-format on | |||||
} | |||||
TEST_F(ARM_COMMON, POOLING_MAX_W5x5_S2x2_NCHW44) | |||||
{ | |||||
// clang-format off | |||||
for (size_t ih: {5, 9, 19, 20, 39}) | |||||
for (size_t iw: {5, 12, 23, 27, 39}) | |||||
for (size_t ph: {0}) | |||||
for (size_t pw: {0}) | |||||
if (ih+2*ph >= 5 && iw+2*pw >= 5) | |||||
{ | |||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||||
checker.set_rng(0,&rng); | |||||
param::Pooling param; | |||||
param.mode = param::Pooling::Mode::MAX; | |||||
param.format = param::Pooling::Format::NCHW44; | |||||
param.pad_h = ph; | |||||
param.pad_w = pw; | |||||
param.stride_h = param.stride_w = 2; | |||||
param.window_h = param.window_w = 5; | |||||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | |||||
} | |||||
// clang-format on | |||||
} | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
TEST_F(ARM_COMMON, POOLING_FP16) { | TEST_F(ARM_COMMON, POOLING_FP16) { | ||||
Checker<Pooling> checker(handle()); | Checker<Pooling> checker(handle()); | ||||
@@ -8,6 +8,8 @@ | |||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include <vector> | |||||
#include "megdnn/dtype.h" | |||||
#include "test/arm_common/fixture.h" | #include "test/arm_common/fixture.h" | ||||
#include "test/common/pooling.h" | #include "test/common/pooling.h" | ||||
@@ -53,38 +55,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING) { | |||||
checker.set_param(param).exec({{2, 3, ih, iw}, {}}); | checker.set_param(param).exec({{2, 3, ih, iw}, {}}); | ||||
} | } | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S2x2_NCHW44) | |||||
{ | |||||
// clang-format off | |||||
for (size_t ih: {3, 5, 10}) | |||||
for (size_t iw: {3, 5, 7, 9, 15, 20}) | |||||
for (size_t ph: {0}) | |||||
for (size_t pw: {0}) | |||||
if (ih+2*ph >= 3 && iw+2*pw >= 3) | |||||
{ | |||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||||
checker.set_rng(0,&rng); | |||||
param::Pooling param; | |||||
param.mode = param::Pooling::Mode::MAX; | |||||
param.format = param::Pooling::Format::NCHW44; | |||||
param.pad_h = ph; | |||||
param.pad_w = pw; | |||||
param.stride_h = param.stride_w = 2; | |||||
param.window_h = param.window_w = 3; | |||||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | |||||
} | |||||
// clang-format on | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S1x1_NCHW44) | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_NCHW44) | |||||
{ | { | ||||
// clang-format off | // clang-format off | ||||
for (size_t ih: {3, 5, 10}) | for (size_t ih: {3, 5, 10}) | ||||
for (size_t iw: {3, 5, 7, 9, 15, 20}) | for (size_t iw: {3, 5, 7, 9, 15, 20}) | ||||
for (size_t ph: {0}) | |||||
for (size_t pw: {0}) | |||||
for (size_t ph: {0, 1, 2}) | |||||
for (size_t pw: {0, 1, 2}) | |||||
if (ih+2*ph >= 3 && iw+2*pw >= 3) | if (ih+2*ph >= 3 && iw+2*pw >= 3) | ||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | ||||
@@ -100,18 +78,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S1x1_NCHW44) | |||||
param.stride_h = param.stride_w = 1; | param.stride_h = param.stride_w = 1; | ||||
param.window_h = param.window_w = 3; | param.window_h = param.window_w = 3; | ||||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | ||||
param.stride_h = param.stride_w = 2; | |||||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | |||||
} | } | ||||
// clang-format on | // clang-format on | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S1x1_NCHW44) | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_NCHW44) | |||||
{ | { | ||||
// clang-format off | // clang-format off | ||||
for (size_t ih: {2, 5, 10, 17}) | |||||
for (size_t iw: {2, 6, 8, 16, 26}) | |||||
for (size_t ph: {0}) | |||||
for (size_t pw: {0}) | |||||
if (ih+2*ph >= 3 && iw+2*pw >= 3) | |||||
for (size_t ih: {2, 5, 10, 17}) | |||||
for (size_t iw: {2, 6, 8, 16, 26}) | |||||
for (size_t ph: {0, 1}) | |||||
for (size_t pw: {0, 1}) | |||||
if (ih+2*ph >= 2 && iw+2*pw >= 2) | |||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | ||||
Checker<Pooling> checker(handle()); | Checker<Pooling> checker(handle()); | ||||
@@ -126,41 +108,20 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S1x1_NCHW44) | |||||
param.stride_h = param.stride_w = 1; | param.stride_h = param.stride_w = 1; | ||||
param.window_h = param.window_w = 2; | param.window_h = param.window_w = 2; | ||||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | ||||
} | |||||
// clang-format on | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S2x2_NCHW44) | |||||
{ | |||||
// clang-format off | |||||
for (size_t ih: {2, 5, 10, 17}) | |||||
for (size_t iw: {2, 6, 8, 16, 26}) | |||||
for (size_t ph: {0}) | |||||
for (size_t pw: {0}) | |||||
if (ih+2*ph >= 3 && iw+2*pw >= 3) | |||||
{ | |||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||||
checker.set_rng(0,&rng); | |||||
param::Pooling param; | |||||
param.mode = param::Pooling::Mode::MAX; | |||||
param.format = param::Pooling::Format::NCHW44; | |||||
param.pad_h = ph; | |||||
param.pad_w = pw; | |||||
param.stride_h = param.stride_w = 2; | param.stride_h = param.stride_w = 2; | ||||
param.window_h = param.window_w = 2; | |||||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | ||||
} | } | ||||
// clang-format on | // clang-format on | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_S1x1_NCHW44) | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_NCHW44) | |||||
{ | { | ||||
// clang-format off | // clang-format off | ||||
for (size_t ih: {4, 7, 10, 17, 20}) | |||||
for (size_t iw: {4, 8, 10, 21, 32}) | |||||
for (size_t ph: {0}) | |||||
for (size_t pw: {0}) | |||||
for (size_t ih: {4, 10, 18, 25, 30}) | |||||
for (size_t iw: {4, 12, 17, 20, 25}) | |||||
for (size_t ph: {0, 1, 2}) | |||||
for (size_t pw: {0, 1, 2}) | |||||
if (ih+2*ph >= 4 && iw+2*pw >= 4) | if (ih+2*ph >= 4 && iw+2*pw >= 4) | ||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | ||||
@@ -176,41 +137,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_S1x1_NCHW44) | |||||
param.stride_h = param.stride_w = 1; | param.stride_h = param.stride_w = 1; | ||||
param.window_h = param.window_w = 4; | param.window_h = param.window_w = 4; | ||||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | ||||
} | |||||
// clang-format on | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_S2x2_NCHW44) | |||||
{ | |||||
// clang-format off | |||||
for (size_t ih: {4, 10, 18, 25, 30}) | |||||
for (size_t iw: {4, 12, 17, 20, 25}) | |||||
for (size_t ph: {0}) | |||||
for (size_t pw: {0}) | |||||
if (ih+2*ph >= 4 && iw+2*pw >= 4) | |||||
{ | |||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||||
checker.set_rng(0,&rng); | |||||
param::Pooling param; | |||||
param.mode = param::Pooling::Mode::MAX; | |||||
param.format = param::Pooling::Format::NCHW44; | |||||
param.pad_h = ph; | |||||
param.pad_w = pw; | |||||
param.stride_h = param.stride_w = 2; | param.stride_h = param.stride_w = 2; | ||||
param.window_h = param.window_w = 4; | |||||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | ||||
} | } | ||||
// clang-format on | // clang-format on | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_S1x1_NCHW44) | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_NCHW44) | |||||
{ | { | ||||
// clang-format off | // clang-format off | ||||
for (size_t ih: {5, 9, 19, 20, 39}) | for (size_t ih: {5, 9, 19, 20, 39}) | ||||
for (size_t iw: {5, 12, 23, 27, 39}) | for (size_t iw: {5, 12, 23, 27, 39}) | ||||
for (size_t ph: {0}) | |||||
for (size_t pw: {0}) | |||||
for (size_t ph: {0, 1, 2}) | |||||
for (size_t pw: {0, 1, 2}) | |||||
if (ih+2*ph >= 5 && iw+2*pw >= 5) | if (ih+2*ph >= 5 && iw+2*pw >= 5) | ||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | ||||
@@ -226,31 +165,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_S1x1_NCHW44) | |||||
param.stride_h = param.stride_w = 1; | param.stride_h = param.stride_w = 1; | ||||
param.window_h = param.window_w = 5; | param.window_h = param.window_w = 5; | ||||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | ||||
} | |||||
// clang-format on | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_S2x2_NCHW44) | |||||
{ | |||||
// clang-format off | |||||
for (size_t ih: {5, 9, 19, 20, 39}) | |||||
for (size_t iw: {5, 12, 23, 27, 39}) | |||||
for (size_t ph: {0}) | |||||
for (size_t pw: {0}) | |||||
if (ih+2*ph >= 5 && iw+2*pw >= 5) | |||||
{ | |||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | |||||
Checker<Pooling> checker(handle()); | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.1f)); | |||||
checker.set_rng(0,&rng); | |||||
param::Pooling param; | |||||
param.mode = param::Pooling::Mode::MAX; | |||||
param.format = param::Pooling::Format::NCHW44; | |||||
param.pad_h = ph; | |||||
param.pad_w = pw; | |||||
param.stride_h = param.stride_w = 2; | param.stride_h = param.stride_w = 2; | ||||
param.window_h = param.window_w = 5; | |||||
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); | ||||
} | } | ||||
// clang-format on | // clang-format on | ||||
} | } | ||||
@@ -473,13 +391,15 @@ template <typename Opr> | |||||
void benchmark_impl(const typename Opr::Param& param, | void benchmark_impl(const typename Opr::Param& param, | ||||
std::vector<SmallVector<TensorShape>> shapes, size_t RUNS, | std::vector<SmallVector<TensorShape>> shapes, size_t RUNS, | ||||
TaskExecutorConfig&& multi_thread_config, | TaskExecutorConfig&& multi_thread_config, | ||||
TaskExecutorConfig&& single_thread_config) { | |||||
TaskExecutorConfig&& single_thread_config, | |||||
DType data_type) { | |||||
std::vector<float> multi_thread_times, single_thread_times; | std::vector<float> multi_thread_times, single_thread_times; | ||||
{ | { | ||||
auto multi_thread_hanle = | auto multi_thread_hanle = | ||||
create_cpu_handle(0, true, &multi_thread_config); | create_cpu_handle(0, true, &multi_thread_config); | ||||
auto benchmarker = Benchmarker<Opr>(multi_thread_hanle.get()); | auto benchmarker = Benchmarker<Opr>(multi_thread_hanle.get()); | ||||
benchmarker.set_times(RUNS).set_display(false).set_param(param); | benchmarker.set_times(RUNS).set_display(false).set_param(param); | ||||
benchmarker.set_dtype(0, data_type); | |||||
for (auto shape : shapes) { | for (auto shape : shapes) { | ||||
multi_thread_times.push_back(benchmarker.exec(shape) / RUNS); | multi_thread_times.push_back(benchmarker.exec(shape) / RUNS); | ||||
} | } | ||||
@@ -489,6 +409,7 @@ void benchmark_impl(const typename Opr::Param& param, | |||||
create_cpu_handle(0, true, &single_thread_config); | create_cpu_handle(0, true, &single_thread_config); | ||||
auto benchmarker = Benchmarker<Opr>(single_thread_handle.get()); | auto benchmarker = Benchmarker<Opr>(single_thread_handle.get()); | ||||
benchmarker.set_times(RUNS).set_display(false).set_param(param); | benchmarker.set_times(RUNS).set_display(false).set_param(param); | ||||
benchmarker.set_dtype(0, data_type); | |||||
for (auto shape : shapes) { | for (auto shape : shapes) { | ||||
single_thread_times.push_back(benchmarker.exec(shape) / RUNS); | single_thread_times.push_back(benchmarker.exec(shape) / RUNS); | ||||
} | } | ||||
@@ -540,10 +461,47 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_POOLING) { | |||||
param.stride_h = param.stride_w = 2; | param.stride_h = param.stride_w = 2; | ||||
param.pad_h = param.pad_w = 1; | param.pad_h = param.pad_w = 1; | ||||
printf("Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", param.window_h, | printf("Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", param.window_h, | ||||
param.stride_h, param.pad_h, static_cast<int>(param.mode)); | |||||
benchmark_impl<Pooling>(param, shapes, RUNS, {4, {0, 1, 2, 3}}, {1, {0}}); | |||||
benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}); | |||||
benchmark_impl<Pooling>(param, shapes, RUNS, {2, {0, 1}}, {1, {0}}); | |||||
param.window_w, param.stride_h, static_cast<int>(param.mode)); | |||||
benchmark_impl<Pooling>(param, shapes, RUNS, {4, {0, 1, 2, 3}}, {1, {0}}, dtype::Float32()); | |||||
benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, dtype::Float32()); | |||||
benchmark_impl<Pooling>(param, shapes, RUNS, {2, {0, 1}}, {1, {0}}, dtype::Float32()); | |||||
} | |||||
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_POOLING_NCHW44) { | |||||
constexpr size_t RUNS = 50; | |||||
using Param = param::Pooling; | |||||
Param param; | |||||
param.pad_h = param.pad_w = 0; | |||||
param.mode = Param::Mode::MAX; | |||||
std::vector<SmallVector<TensorShape>> shapes; | |||||
std::vector<std::vector<size_t>> filter_and_stride = { | |||||
{2, 1}, {2, 2}, {3, 1}, {3, 2}, {4, 1}, {4, 2}, {5, 1}, {5, 2}}; | |||||
for(auto filter:filter_and_stride){ | |||||
shapes.push_back({{1, 32 * 4, 215, 215}, {}}); | |||||
shapes.push_back({{1, 32 * 4, 128, 128}, {}}); | |||||
shapes.push_back({{1, 16 * 4, 56, 56}, {}}); | |||||
param.window_h = param.window_w = filter[0]; | |||||
param.stride_h = param.stride_w = filter[1]; | |||||
param.format = Param::Format::NCHW; | |||||
printf("NCHW Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", param.window_h, | |||||
param.window_h, param.stride_h, static_cast<int>(param.mode)); | |||||
benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||||
dtype::QuantizedS8(1.1f)); | |||||
shapes.clear(); | |||||
shapes.push_back({{1, 32, 215, 215,4}, {}}); | |||||
shapes.push_back({{1, 32, 128, 128,4}, {}}); | |||||
shapes.push_back({{1, 16, 56, 56, 4}, {}}); | |||||
param.format = Param::Format::NCHW44; | |||||
printf("NCHW44 Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", param.window_h, | |||||
param.window_w, param.stride_h, static_cast<int>(param.mode)); | |||||
benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||||
dtype::QuantizedS8(1.1f)); | |||||
shapes.clear(); | |||||
} | |||||
} | } | ||||
#endif | #endif | ||||