GitOrigin-RevId: 25eab33e14
tags/v0.5.0
@@ -73,14 +73,15 @@ WorkspaceBundle get_bundle_nchw44( | |||||
const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW, | const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW, | ||||
size_t& IH2, size_t& IW2, size_t PH, size_t PW, | size_t& IH2, size_t& IW2, size_t PH, size_t PW, | ||||
const WorkspaceBundle& ws) { | |||||
const WorkspaceBundle& ws, bool is_max_mode) { | |||||
int8_t* sptr_base = nullptr; | int8_t* sptr_base = nullptr; | ||||
int8_t padding_value = is_max_mode ? INT8_MIN : 0; | |||||
bool need_pad = ((PH != 0) || (PW != 0)) ? true : false; | bool need_pad = ((PH != 0) || (PW != 0)) ? true : false; | ||||
if (need_pad) { | if (need_pad) { | ||||
IH2 = IH + 2 * PH; | IH2 = IH + 2 * PH; | ||||
IW2 = IW + 2 * PW; | IW2 = IW + 2 * PW; | ||||
sptr_base = static_cast<int8_t*>(ws.get(0)); | sptr_base = static_cast<int8_t*>(ws.get(0)); | ||||
memset(sptr_base, -128, sizeof(int8_t) * IH2 * IW2 * 4); | |||||
memset(sptr_base, padding_value, sizeof(int8_t) * IH2 * IW2 * 4); | |||||
rep(ih, IH) { | rep(ih, IH) { | ||||
std::memcpy(sptr_base + (ih + PH) * IW2 * 4 + PW * 4, | std::memcpy(sptr_base + (ih + PH) * IW2 * 4 + PW * 4, | ||||
src + ih * IW * 4, sizeof(int8_t) * IW * 4); | src + ih * IW * 4, sizeof(int8_t) * IW * 4); | ||||
@@ -597,7 +598,7 @@ void PoolingImpl::AlgoInt8Filter3MaxStride2::exec( | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
} | } | ||||
bool PoolingImpl::AlgoFilter3MaxStridexNCHW44::usable( | |||||
bool PoolingImpl::AlgoFilter3ModexStridexNCHW44::usable( | |||||
const PoolingKernSizeParam& param) const { | const PoolingKernSizeParam& param) const { | ||||
auto SH = param.stride[0]; | auto SH = param.stride[0]; | ||||
auto SW = param.stride[1]; | auto SW = param.stride[1]; | ||||
@@ -606,12 +607,12 @@ bool PoolingImpl::AlgoFilter3MaxStridexNCHW44::usable( | |||||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | ||||
param.format == Param::Format::NCHW44 && | param.format == Param::Format::NCHW44 && | ||||
param.mode == Mode::MAX && FH == 3 && FW == 3 && SW == SH && | |||||
(SH == 1 || SW == 2); | |||||
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | |||||
FH == 3 && FW == 3 && SW == SH && (SH == 1 || SW == 2); | |||||
return avaible; | return avaible; | ||||
} | } | ||||
void PoolingImpl::AlgoFilter3MaxStridexNCHW44::exec( | |||||
void PoolingImpl::AlgoFilter3ModexStridexNCHW44::exec( | |||||
const PoolingKernParam& param) const { | const PoolingKernParam& param) const { | ||||
auto IH = param.isz[0], IW = param.isz[1]; | auto IH = param.isz[0], IW = param.isz[1]; | ||||
auto OH = param.osz[0], OW = param.osz[1]; | auto OH = param.osz[0], OW = param.osz[1]; | ||||
@@ -623,8 +624,8 @@ void PoolingImpl::AlgoFilter3MaxStridexNCHW44::exec( | |||||
void* src_ptr = param.src_ptr; | void* src_ptr = param.src_ptr; | ||||
void* dst_ptr = param.dst_ptr; | void* dst_ptr = param.dst_ptr; | ||||
#define DISPATCH_FUNC(type, func, i) \ | |||||
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ | |||||
#define DISPATCH_FUNC(type, func, i, mode) \ | |||||
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(8), \ | |||||
midout_iv(#type #i##_hash)) { \ | midout_iv(#type #i##_hash)) { \ | ||||
WorkspaceBundle wbundle = get_bundle_nchw44(param); \ | WorkspaceBundle wbundle = get_bundle_nchw44(param); \ | ||||
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ | auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ | ||||
@@ -635,7 +636,7 @@ void PoolingImpl::AlgoFilter3MaxStridexNCHW44::exec( | |||||
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ | ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ | ||||
size_t n = index / C; \ | size_t n = index / C; \ | ||||
size_t c = index % C; \ | size_t c = index % C; \ | ||||
do_max_pooling_3x3_stride##i##_##func##_nchw44_NEON( \ | |||||
do_##mode##_pooling_3x3_stride##i##_##func##_nchw44_NEON( \ | |||||
static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \ | static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \ | ||||
c * IH * IW * 4, \ | c * IH * IW * 4, \ | ||||
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | ||||
@@ -648,27 +649,43 @@ void PoolingImpl::AlgoFilter3MaxStridexNCHW44::exec( | |||||
} \ | } \ | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
#define DISPATCH_STRIDE(type, func) \ | |||||
switch (SW) { \ | |||||
case 1: { \ | |||||
DISPATCH_FUNC(type, func, 1); \ | |||||
break; \ | |||||
} \ | |||||
case 2: { \ | |||||
DISPATCH_FUNC(type, func, 2); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert(0, "unsupport stride size"); \ | |||||
#define DISPATCH_MODE(type, func, stride) \ | |||||
switch (param.mode) { \ | |||||
case Mode::MAX: { \ | |||||
DISPATCH_FUNC(type, func, stride, max); \ | |||||
break; \ | |||||
} \ | |||||
case Mode::AVERAGE: { \ | |||||
DISPATCH_FUNC(type, func, stride, avg); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_throw(ssprintf("Unsupport pooling mode %d", param.mode) \ | |||||
.c_str()); \ | |||||
} | |||||
#define DISPATCH_STRIDE(type, func) \ | |||||
switch (SW) { \ | |||||
case 1: { \ | |||||
DISPATCH_MODE(type, func, 1); \ | |||||
break; \ | |||||
} \ | |||||
case 2: { \ | |||||
DISPATCH_MODE(type, func, 2); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \ | |||||
} | } | ||||
DISPATCH_STRIDE(int8_t, int8); | DISPATCH_STRIDE(int8_t, int8); | ||||
#undef DISPATCH_STRIDE | #undef DISPATCH_STRIDE | ||||
#undef DISPATCH_MODE | |||||
#undef DISPATCH_FUNC | #undef DISPATCH_FUNC | ||||
} | } | ||||
bool PoolingImpl::AlgoFilter2MaxStridexNCHW44::usable( | |||||
bool PoolingImpl::AlgoFilter2ModexStridexNCHW44::usable( | |||||
const PoolingKernSizeParam& param) const { | const PoolingKernSizeParam& param) const { | ||||
auto SH = param.stride[0]; | auto SH = param.stride[0]; | ||||
auto SW = param.stride[1]; | auto SW = param.stride[1]; | ||||
@@ -677,12 +694,12 @@ bool PoolingImpl::AlgoFilter2MaxStridexNCHW44::usable( | |||||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | ||||
param.format == Param::Format::NCHW44 && | param.format == Param::Format::NCHW44 && | ||||
param.mode == Mode::MAX && FH == 2 && FW == 2 && SH == SW && | |||||
(SW == 1 || SW == 2); | |||||
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | |||||
FH == 2 && FW == 2 && SH == SW && (SW == 1 || SW == 2); | |||||
return avaible; | return avaible; | ||||
} | } | ||||
void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( | |||||
void PoolingImpl::AlgoFilter2ModexStridexNCHW44::exec( | |||||
const PoolingKernParam& param) const { | const PoolingKernParam& param) const { | ||||
auto IH = param.isz[0], IW = param.isz[1]; | auto IH = param.isz[0], IW = param.isz[1]; | ||||
auto OH = param.osz[0], OW = param.osz[1]; | auto OH = param.osz[0], OW = param.osz[1]; | ||||
@@ -694,8 +711,8 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( | |||||
void* src_ptr = param.src_ptr; | void* src_ptr = param.src_ptr; | ||||
void* dst_ptr = param.dst_ptr; | void* dst_ptr = param.dst_ptr; | ||||
#define DISPATCH_FUNC(type, func, i) \ | |||||
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ | |||||
#define DISPATCH_FUNC(type, func, i, mode) \ | |||||
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(9), \ | |||||
midout_iv(#func #i##_hash)) { \ | midout_iv(#func #i##_hash)) { \ | ||||
WorkspaceBundle wbundle = get_bundle_nchw44(param); \ | WorkspaceBundle wbundle = get_bundle_nchw44(param); \ | ||||
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ | auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ | ||||
@@ -706,7 +723,7 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( | |||||
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ | ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ | ||||
size_t n = index / C; \ | size_t n = index / C; \ | ||||
size_t c = index % C; \ | size_t c = index % C; \ | ||||
do_max_pooling_2x2_stride##i##_##func##_nchw44_NEON( \ | |||||
do_##mode##_pooling_2x2_stride##i##_##func##_nchw44_NEON( \ | |||||
static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \ | static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \ | ||||
c * IH * IW * 4, \ | c * IH * IW * 4, \ | ||||
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | ||||
@@ -719,27 +736,43 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( | |||||
} \ | } \ | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
#define DISPATCH_STRIDE(type, func) \ | |||||
switch (SW) { \ | |||||
case 1: { \ | |||||
DISPATCH_FUNC(type, func, 1); \ | |||||
break; \ | |||||
} \ | |||||
case 2: { \ | |||||
DISPATCH_FUNC(type, func, 2); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert(0, "unsupport stride size"); \ | |||||
#define DISPATCH_MODE(type, func, stride) \ | |||||
switch (param.mode) { \ | |||||
case Mode::MAX: { \ | |||||
DISPATCH_FUNC(type, func, stride, max); \ | |||||
break; \ | |||||
} \ | |||||
case Mode::AVERAGE: { \ | |||||
DISPATCH_FUNC(type, func, stride, avg); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_throw(ssprintf("Unsupport pooling mode %d", param.mode) \ | |||||
.c_str()); \ | |||||
} | |||||
#define DISPATCH_STRIDE(type, func) \ | |||||
switch (SW) { \ | |||||
case 1: { \ | |||||
DISPATCH_MODE(type, func, 1); \ | |||||
break; \ | |||||
} \ | |||||
case 2: { \ | |||||
DISPATCH_MODE(type, func, 2); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \ | |||||
} | } | ||||
DISPATCH_STRIDE(int8_t, int8); | DISPATCH_STRIDE(int8_t, int8); | ||||
#undef DISPATCH_STRIDE | #undef DISPATCH_STRIDE | ||||
#undef DISPATCH_MODE | |||||
#undef DISPATCH_FUNC | #undef DISPATCH_FUNC | ||||
} | } | ||||
bool PoolingImpl::AlgoFilter4MaxStridexNCHW44::usable( | |||||
bool PoolingImpl::AlgoFilter4ModexStridexNCHW44::usable( | |||||
const PoolingKernSizeParam& param) const { | const PoolingKernSizeParam& param) const { | ||||
auto SH = param.stride[0]; | auto SH = param.stride[0]; | ||||
auto SW = param.stride[1]; | auto SW = param.stride[1]; | ||||
@@ -748,12 +781,12 @@ bool PoolingImpl::AlgoFilter4MaxStridexNCHW44::usable( | |||||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | ||||
param.format == Param::Format::NCHW44 && | param.format == Param::Format::NCHW44 && | ||||
param.mode == Mode::MAX && FH == 4 && FW == 4 && SH == SW && | |||||
(SW == 1 || SW == 2); | |||||
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | |||||
FH == 4 && FW == 4 && SH == SW && (SW == 1 || SW == 2); | |||||
return avaible; | return avaible; | ||||
} | } | ||||
void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( | |||||
void PoolingImpl::AlgoFilter4ModexStridexNCHW44::exec( | |||||
const PoolingKernParam& param) const { | const PoolingKernParam& param) const { | ||||
auto IH = param.isz[0], IW = param.isz[1]; | auto IH = param.isz[0], IW = param.isz[1]; | ||||
auto OH = param.osz[0], OW = param.osz[1]; | auto OH = param.osz[0], OW = param.osz[1]; | ||||
@@ -765,8 +798,8 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( | |||||
void* src_ptr = param.src_ptr; | void* src_ptr = param.src_ptr; | ||||
void* dst_ptr = param.dst_ptr; | void* dst_ptr = param.dst_ptr; | ||||
#define DISPATCH_FUNC(type, func, i) \ | |||||
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ | |||||
#define DISPATCH_FUNC(type, func, i, mode) \ | |||||
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(10), \ | |||||
midout_iv(#func #i##_hash)) { \ | midout_iv(#func #i##_hash)) { \ | ||||
WorkspaceBundle wbundle = get_bundle_nchw44(param); \ | WorkspaceBundle wbundle = get_bundle_nchw44(param); \ | ||||
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ | auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ | ||||
@@ -777,7 +810,7 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( | |||||
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ | ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ | ||||
size_t n = index / C; \ | size_t n = index / C; \ | ||||
size_t c = index % C; \ | size_t c = index % C; \ | ||||
do_max_pooling_4x4_stride##i##_##func##_nchw44_NEON( \ | |||||
do_##mode##_pooling_4x4_stride##i##_##func##_nchw44_NEON( \ | |||||
static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \ | static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \ | ||||
c * IH * IW * 4, \ | c * IH * IW * 4, \ | ||||
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | ||||
@@ -790,27 +823,43 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( | |||||
} \ | } \ | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
#define DISPATCH_STRIDE(type, func) \ | |||||
switch (SW) { \ | |||||
case 1: { \ | |||||
DISPATCH_FUNC(type, func, 1); \ | |||||
break; \ | |||||
} \ | |||||
case 2: { \ | |||||
DISPATCH_FUNC(type, func, 2); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert(0, "unsupport stride size"); \ | |||||
#define DISPATCH_MODE(type, func, stride) \ | |||||
switch (param.mode) { \ | |||||
case Mode::MAX: { \ | |||||
DISPATCH_FUNC(type, func, stride, max); \ | |||||
break; \ | |||||
} \ | |||||
case Mode::AVERAGE: { \ | |||||
DISPATCH_FUNC(type, func, stride, avg); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_throw(ssprintf("Unsupport pooling mode %d", param.mode) \ | |||||
.c_str()); \ | |||||
} | |||||
#define DISPATCH_STRIDE(type, func) \ | |||||
switch (SW) { \ | |||||
case 1: { \ | |||||
DISPATCH_MODE(type, func, 1); \ | |||||
break; \ | |||||
} \ | |||||
case 2: { \ | |||||
DISPATCH_MODE(type, func, 2); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \ | |||||
} | } | ||||
DISPATCH_STRIDE(int8_t, int8); | DISPATCH_STRIDE(int8_t, int8); | ||||
#undef DISPATCH_STRIDE | #undef DISPATCH_STRIDE | ||||
#undef DISPATCH_MODE | |||||
#undef DISPATCH_FUNC | #undef DISPATCH_FUNC | ||||
} | } | ||||
bool PoolingImpl::AlgoFilter5MaxStridexNCHW44::usable( | |||||
bool PoolingImpl::AlgoFilter5ModexStridexNCHW44::usable( | |||||
const PoolingKernSizeParam& param) const { | const PoolingKernSizeParam& param) const { | ||||
auto SH = param.stride[0]; | auto SH = param.stride[0]; | ||||
auto SW = param.stride[1]; | auto SW = param.stride[1]; | ||||
@@ -819,12 +868,12 @@ bool PoolingImpl::AlgoFilter5MaxStridexNCHW44::usable( | |||||
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | ||||
param.format == Param::Format::NCHW44 && | param.format == Param::Format::NCHW44 && | ||||
param.mode == Mode::MAX && FH == 5 && FW == 5 && SH == SW && | |||||
(SW == 1 || SW == 2); | |||||
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | |||||
FH == 5 && FW == 5 && SH == SW && (SW == 1 || SW == 2); | |||||
return avaible; | return avaible; | ||||
} | } | ||||
void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( | |||||
void PoolingImpl::AlgoFilter5ModexStridexNCHW44::exec( | |||||
const PoolingKernParam& param) const { | const PoolingKernParam& param) const { | ||||
auto IH = param.isz[0], IW = param.isz[1]; | auto IH = param.isz[0], IW = param.isz[1]; | ||||
auto OH = param.osz[0], OW = param.osz[1]; | auto OH = param.osz[0], OW = param.osz[1]; | ||||
@@ -836,8 +885,8 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( | |||||
void* src_ptr = param.src_ptr; | void* src_ptr = param.src_ptr; | ||||
void* dst_ptr = param.dst_ptr; | void* dst_ptr = param.dst_ptr; | ||||
#define DISPATCH_FUNC(type, func, i) \ | |||||
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ | |||||
#define DISPATCH_FUNC(type, func, i, mode) \ | |||||
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(11), \ | |||||
midout_iv(#func #i##_hash)) { \ | midout_iv(#func #i##_hash)) { \ | ||||
WorkspaceBundle wbundle = get_bundle_nchw44(param); \ | WorkspaceBundle wbundle = get_bundle_nchw44(param); \ | ||||
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ | auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ | ||||
@@ -848,7 +897,7 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( | |||||
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ | ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ | ||||
size_t n = index / C; \ | size_t n = index / C; \ | ||||
size_t c = index % C; \ | size_t c = index % C; \ | ||||
do_max_pooling_5x5_stride##i##_##func##_nchw44_NEON( \ | |||||
do_##mode##_pooling_5x5_stride##i##_##func##_nchw44_NEON( \ | |||||
static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \ | static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \ | ||||
c * IH * IW * 4, \ | c * IH * IW * 4, \ | ||||
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ | ||||
@@ -861,23 +910,39 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( | |||||
} \ | } \ | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
#define DISPATCH_STRIDE(type, func) \ | |||||
switch (SW) { \ | |||||
case 1: { \ | |||||
DISPATCH_FUNC(type, func, 1); \ | |||||
break; \ | |||||
} \ | |||||
case 2: { \ | |||||
DISPATCH_FUNC(type, func, 2); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert(0, "unsupport stride size"); \ | |||||
#define DISPATCH_MODE(type, func, stride) \ | |||||
switch (param.mode) { \ | |||||
case Mode::MAX: { \ | |||||
DISPATCH_FUNC(type, func, stride, max); \ | |||||
break; \ | |||||
} \ | |||||
case Mode::AVERAGE: { \ | |||||
DISPATCH_FUNC(type, func, stride, avg); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_throw(ssprintf("Unsupport pooling mode %d", param.mode) \ | |||||
.c_str()); \ | |||||
} | |||||
#define DISPATCH_STRIDE(type, func) \ | |||||
switch (SW) { \ | |||||
case 1: { \ | |||||
DISPATCH_MODE(type, func, 1); \ | |||||
break; \ | |||||
} \ | |||||
case 2: { \ | |||||
DISPATCH_MODE(type, func, 2); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \ | |||||
} | } | ||||
DISPATCH_STRIDE(int8_t, int8); | DISPATCH_STRIDE(int8_t, int8); | ||||
#undef DISPATCH_STRIDE | #undef DISPATCH_STRIDE | ||||
#undef DISPATCH_MODE | |||||
#undef DISPATCH_FUNC | #undef DISPATCH_FUNC | ||||
} | } | ||||
@@ -83,34 +83,34 @@ public: | |||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
}; | }; | ||||
class PoolingImpl::AlgoFilter3MaxStridexNCHW44 final : public AlgoBase { | |||||
class PoolingImpl::AlgoFilter3ModexStridexNCHW44 final : public AlgoBase { | |||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { return "ARM_POOLING_FILTER3_MAX_STRIDEX_NCHW44"; } | |||||
const char* name() const override { return "ARM_POOLING_FILTER3_MODEX_STRIDEX_NCHW44"; } | |||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
}; | }; | ||||
class PoolingImpl::AlgoFilter2MaxStridexNCHW44 final : public AlgoBase { | |||||
class PoolingImpl::AlgoFilter2ModexStridexNCHW44 final : public AlgoBase { | |||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { return "ARM_POOLING_FILTER2_MAX_STRIDEX_NCHW44"; } | |||||
const char* name() const override { return "ARM_POOLING_FILTER2_MODEX_STRIDEX_NCHW44"; } | |||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
}; | }; | ||||
class PoolingImpl::AlgoFilter4MaxStridexNCHW44 final : public AlgoBase { | |||||
class PoolingImpl::AlgoFilter4ModexStridexNCHW44 final : public AlgoBase { | |||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { return "ARM_POOLING_FILTER4_MAX_STRIDEX_NCHW44"; } | |||||
const char* name() const override { return "ARM_POOLING_FILTER4_MODEX_STRIDEX_NCHW44"; } | |||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
}; | }; | ||||
class PoolingImpl::AlgoFilter5MaxStridexNCHW44 final : public AlgoBase { | |||||
class PoolingImpl::AlgoFilter5ModexStridexNCHW44 final : public AlgoBase { | |||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { return "ARM_POOLING_FILTER5_MAX_STRIDEX_NCHW44"; } | |||||
const char* name() const override { return "ARM_POOLING_FILTER5_MODEX_STRIDEX_NCHW44"; } | |||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
}; | }; | ||||
@@ -122,7 +122,7 @@ WorkspaceBundle get_bundle_nchw44( | |||||
const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW, | const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW, | ||||
size_t& IH2, size_t& IW2, size_t PH, size_t PW, | size_t& IH2, size_t& IW2, size_t PH, size_t PW, | ||||
const WorkspaceBundle& ws); | |||||
const WorkspaceBundle& ws, bool is_max_mode); | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.cpp | |||||
* \file dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
@@ -24,7 +24,7 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
const WorkspaceBundle& ws) { | const WorkspaceBundle& ws) { | ||||
const int8_t* sptr = nullptr; | const int8_t* sptr = nullptr; | ||||
size_t IH2, IW2; | size_t IH2, IW2; | ||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); | |||||
size_t oh = 0; | size_t oh = 0; | ||||
for (; oh < OH; ++oh) { | for (; oh < OH; ++oh) { | ||||
size_t ih = oh; | size_t ih = oh; | ||||
@@ -70,7 +70,7 @@ void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
const WorkspaceBundle& ws) { | const WorkspaceBundle& ws) { | ||||
const int8_t* sptr = nullptr; | const int8_t* sptr = nullptr; | ||||
size_t IH2, IW2; | size_t IH2, IW2; | ||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); | |||||
size_t oh = 0; | size_t oh = 0; | ||||
for (; oh < OH; ++oh) { | for (; oh < OH; ++oh) { | ||||
size_t ih = oh << 1; | size_t ih = oh << 1; | ||||
@@ -120,6 +120,206 @@ void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
} | } | ||||
} | } | ||||
void do_avg_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, | |||||
size_t OH, size_t OW, | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
int16_t filter_size = 4; | |||||
const int8_t* sptr = nullptr; | |||||
size_t IH2, IW2; | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false); | |||||
size_t oh = 0; | |||||
for (; oh < OH; ++oh) { | |||||
size_t ih = oh; | |||||
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; | |||||
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | |||||
size_t ow = 0; | |||||
for (; ow + 3 < OW; ow += 4) { | |||||
int8x16_t src0123, src1234; | |||||
int16x8_t src01, src23, src12, src34; | |||||
int16x8_t sum01 = vdupq_n_s16(0); | |||||
int16x8_t sum23 = vdupq_n_s16(0); | |||||
#define CACULATE_ROW(i) \ | |||||
src0123 = vld1q_s8(sptr##i); \ | |||||
src1234 = vld1q_s8(sptr##i + 4); \ | |||||
src01 = vmovl_s8(vget_low_s8(src0123)); \ | |||||
src23 = vmovl_s8(vget_high_s8(src0123)); \ | |||||
src12 = vmovl_s8(vget_low_s8(src1234)); \ | |||||
src34 = vmovl_s8(vget_high_s8(src1234)); \ | |||||
sum01 = vaddq_s16(sum01, src01); \ | |||||
sum01 = vaddq_s16(sum01, src12); \ | |||||
sum23 = vaddq_s16(sum23, src23); \ | |||||
sum23 = vaddq_s16(sum23, src34); | |||||
UNROLL_CALL_NOWRAPPER(2, CACULATE_ROW) | |||||
#define sum_define(i) int16_t sum##i; | |||||
UNROLL_CALL_NOWRAPPER(8, sum_define) | |||||
#define sum01_avg(i) \ | |||||
sum##i = vgetq_lane_s16(sum01, i) > 0 \ | |||||
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ | |||||
filter_size \ | |||||
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ | |||||
filter_size; | |||||
#define sum23_avg(i) \ | |||||
sum##i = vgetq_lane_s16(sum23, i) > 0 \ | |||||
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ | |||||
filter_size \ | |||||
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ | |||||
filter_size; | |||||
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i); | |||||
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i); | |||||
UNROLL_CALL_NOWRAPPER(8, sum01_avg) | |||||
UNROLL_CALL_NOWRAPPER(8, store_sum01) | |||||
UNROLL_CALL_NOWRAPPER(8, sum23_avg) | |||||
UNROLL_CALL_NOWRAPPER(8, store_sum23) | |||||
sptr0 += 16; | |||||
sptr1 += 16; | |||||
dptr += 16; | |||||
#undef store_sum01 | |||||
#undef store_sum23 | |||||
#undef sum01_avg | |||||
#undef sum23_avg | |||||
#undef sum_define | |||||
#undef CACULATE_ROW | |||||
} | |||||
for (; ow < OW; ++ow) { | |||||
int8x8_t src001 = vld1_s8(sptr0); | |||||
int8x8_t src101 = vld1_s8(sptr1); | |||||
int16x8_t src00 = vmovl_s8(src001); | |||||
int16x8_t src10 = vmovl_s8(src101); | |||||
int16x8_t max_tmp = vaddq_s16(src00, src10); | |||||
#define do_acc(i) \ | |||||
int16_t sum##i = \ | |||||
vgetq_lane_s16(max_tmp, i) + vgetq_lane_s16(max_tmp, i + 4); | |||||
#define do_avg(i) \ | |||||
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ | |||||
: (sum##i - filter_size / 2) / filter_size; | |||||
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i); | |||||
UNROLL_CALL_NOWRAPPER(4, do_acc) | |||||
UNROLL_CALL_NOWRAPPER(4, do_avg) | |||||
UNROLL_CALL_NOWRAPPER(4, store) | |||||
#undef store | |||||
#undef do_avg | |||||
#undef do_acc | |||||
sptr0 += 4; | |||||
sptr1 += 4; | |||||
dptr += 4; | |||||
} | |||||
} | |||||
} | |||||
void do_avg_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, | |||||
size_t OH, size_t OW, | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
int16_t filter_size = 4; | |||||
const int8_t* sptr = nullptr; | |||||
size_t IH2, IW2; | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false); | |||||
size_t oh = 0; | |||||
for (; oh < OH; ++oh) { | |||||
size_t ih = oh << 1; | |||||
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; | |||||
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | |||||
size_t ow = 0; | |||||
for (; ow + 3 < OW; ow += 4) { | |||||
int32x4x2_t src_tmp; | |||||
int8x16_t src00, src04; | |||||
int32x4_t src0246, src1357; | |||||
int16x8_t src02, src13, src46, src57; | |||||
int16x8_t sum01 = vdupq_n_s16(0); | |||||
int16x8_t sum23 = vdupq_n_s16(0); | |||||
#define CACULATE_ROW(i) \ | |||||
src00 = vld1q_s8(sptr##i); \ | |||||
src04 = vld1q_s8(sptr##i + 4 * 4); \ | |||||
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ | |||||
vreinterpretq_s32_s8(src04)); \ | |||||
src0246 = src_tmp.val[0]; \ | |||||
src1357 = src_tmp.val[1]; \ | |||||
src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \ | |||||
src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \ | |||||
src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \ | |||||
src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \ | |||||
sum01 = vaddq_s16(sum01, src02); \ | |||||
sum01 = vaddq_s16(sum01, src13); \ | |||||
sum23 = vaddq_s16(sum23, src46); \ | |||||
sum23 = vaddq_s16(sum23, src57); | |||||
UNROLL_CALL_NOWRAPPER(2, CACULATE_ROW) | |||||
#define sum_define(i) int16_t sum##i; | |||||
UNROLL_CALL_NOWRAPPER(8, sum_define) | |||||
#define sum01_avg(i) \ | |||||
sum##i = vgetq_lane_s16(sum01, i) > 0 \ | |||||
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ | |||||
filter_size \ | |||||
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ | |||||
filter_size; | |||||
#define sum23_avg(i) \ | |||||
sum##i = vgetq_lane_s16(sum23, i) > 0 \ | |||||
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ | |||||
filter_size \ | |||||
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ | |||||
filter_size; | |||||
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i); | |||||
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i); | |||||
UNROLL_CALL_NOWRAPPER(8, sum01_avg) | |||||
UNROLL_CALL_NOWRAPPER(8, store_sum01) | |||||
UNROLL_CALL_NOWRAPPER(8, sum23_avg) | |||||
UNROLL_CALL_NOWRAPPER(8, store_sum23) | |||||
sptr0 += 32; | |||||
sptr1 += 32; | |||||
dptr += 16; | |||||
#undef store_sum01 | |||||
#undef store_sum23 | |||||
#undef sum01_avg | |||||
#undef sum23_avg | |||||
#undef sum_define | |||||
#undef CACULATE_ROW | |||||
} | |||||
for (; ow < OW; ++ow) { | |||||
int8x8_t src001 = vld1_s8(sptr0); | |||||
int8x8_t src101 = vld1_s8(sptr1); | |||||
int16x8_t src00 = vmovl_s8(src001); | |||||
int16x8_t src10 = vmovl_s8(src101); | |||||
int16x8_t max_tmp = vaddq_s16(src00, src10); | |||||
#define do_acc(i) \ | |||||
int16_t sum##i = \ | |||||
vgetq_lane_s16(max_tmp, i) + vgetq_lane_s16(max_tmp, i + 4); | |||||
#define do_avg(i) \ | |||||
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ | |||||
: (sum##i - filter_size / 2) / filter_size; | |||||
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i); | |||||
UNROLL_CALL_NOWRAPPER(4, do_acc) | |||||
UNROLL_CALL_NOWRAPPER(4, do_avg) | |||||
UNROLL_CALL_NOWRAPPER(4, store) | |||||
#undef do_avg | |||||
#undef do_acc | |||||
#undef store | |||||
sptr0 += 8; | |||||
sptr1 += 8; | |||||
dptr += 4; | |||||
} | |||||
} | |||||
} | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.h | |||||
* \file dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
@@ -15,16 +15,15 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | namespace arm_common { | ||||
void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, | |||||
size_t OH, size_t OW, | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws); | |||||
void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, | |||||
size_t OH, size_t OW, | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws); | |||||
#define KERN(mode, stride, ctype) \ | |||||
void do_##mode##_pooling_2x2_##stride##_##ctype##_nchw44_NEON( \ | |||||
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ | |||||
size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws); | |||||
KERN(max, stride1, int8) | |||||
KERN(max, stride2, int8) | |||||
KERN(avg, stride1, int8) | |||||
KERN(avg, stride2, int8) | |||||
#undef KERN | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp | |||||
* \file dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
@@ -24,7 +24,7 @@ void do_max_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
const WorkspaceBundle& ws) { | const WorkspaceBundle& ws) { | ||||
const int8_t* sptr = nullptr; | const int8_t* sptr = nullptr; | ||||
size_t IH2, IW2; | size_t IH2, IW2; | ||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); | |||||
size_t oh = 0; | size_t oh = 0; | ||||
for (; oh < OH; ++oh) { | for (; oh < OH; ++oh) { | ||||
size_t ih = oh; | size_t ih = oh; | ||||
@@ -99,7 +99,7 @@ void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
const WorkspaceBundle& ws) { | const WorkspaceBundle& ws) { | ||||
const int8_t* sptr = nullptr; | const int8_t* sptr = nullptr; | ||||
size_t IH2, IW2; | size_t IH2, IW2; | ||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); | |||||
size_t oh = 0; | size_t oh = 0; | ||||
for (; oh < OH; ++oh) { | for (; oh < OH; ++oh) { | ||||
size_t ih = oh << 1; | size_t ih = oh << 1; | ||||
@@ -190,6 +190,241 @@ void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
} | } | ||||
} | } | ||||
void do_avg_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, | |||||
size_t OH, size_t OW, | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
int16_t filter_size = 9; | |||||
const int8_t* sptr = nullptr; | |||||
size_t IH2, IW2; | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false); | |||||
size_t oh = 0; | |||||
for (; oh < OH; ++oh) { | |||||
size_t ih = oh; | |||||
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; | |||||
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; | |||||
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | |||||
size_t ow = 0; | |||||
for (; ow + 3 < OW; ow += 4) { | |||||
int8x16_t src0123, src1234, src2345; | |||||
int16x8_t src01, src23, src12, src34, src45; | |||||
int16x8_t sum01 = vdupq_n_s16(0); | |||||
int16x8_t sum23 = vdupq_n_s16(0); | |||||
#define CACULATE_ROW(i) \ | |||||
src0123 = vld1q_s8(sptr##i); \ | |||||
src1234 = vld1q_s8(sptr##i + 4); \ | |||||
src2345 = vld1q_s8(sptr##i + 8); \ | |||||
src01 = vmovl_s8(vget_low_s8(src0123)); \ | |||||
src23 = vmovl_s8(vget_high_s8(src0123)); \ | |||||
src12 = vmovl_s8(vget_low_s8(src1234)); \ | |||||
src34 = vmovl_s8(vget_high_s8(src1234)); \ | |||||
src45 = vmovl_s8(vget_high_s8(src2345)); \ | |||||
sum01 = vaddq_s16(sum01, src01); \ | |||||
sum01 = vaddq_s16(sum01, src12); \ | |||||
sum01 = vaddq_s16(sum01, src23); \ | |||||
sum23 = vaddq_s16(sum23, src23); \ | |||||
sum23 = vaddq_s16(sum23, src34); \ | |||||
sum23 = vaddq_s16(sum23, src45); | |||||
UNROLL_CALL_NOWRAPPER(3, CACULATE_ROW) | |||||
#define sum_define(i) int16_t sum##i; | |||||
UNROLL_CALL_NOWRAPPER(8, sum_define) | |||||
#define sum01_avg(i) \ | |||||
sum##i = vgetq_lane_s16(sum01, i) > 0 \ | |||||
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ | |||||
filter_size \ | |||||
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ | |||||
filter_size; | |||||
#define sum23_avg(i) \ | |||||
sum##i = vgetq_lane_s16(sum23, i) > 0 \ | |||||
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ | |||||
filter_size \ | |||||
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ | |||||
filter_size; | |||||
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i); | |||||
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i); | |||||
UNROLL_CALL_NOWRAPPER(8, sum01_avg) | |||||
UNROLL_CALL_NOWRAPPER(8, store_sum01) | |||||
UNROLL_CALL_NOWRAPPER(8, sum23_avg) | |||||
UNROLL_CALL_NOWRAPPER(8, store_sum23) | |||||
sptr0 += 16; | |||||
sptr1 += 16; | |||||
sptr2 += 16; | |||||
dptr += 16; | |||||
#undef store_sum01 | |||||
#undef store_sum23 | |||||
#undef sum01_avg | |||||
#undef sum23_avg | |||||
#undef sum_define | |||||
#undef CACULATE_ROW | |||||
} | |||||
for (; ow < OW; ++ow) { | |||||
int8x8_t src001, src012; | |||||
int16x8_t src01, src12, sum01, sum02; | |||||
sum01 = vdupq_n_s16(0); | |||||
sum02 = vdupq_n_s16(0); | |||||
#define CACULATE_ROW(i) \ | |||||
src001 = vld1_s8(sptr##i); \ | |||||
src012 = vld1_s8(sptr##i + 4); \ | |||||
src01 = vmovl_s8(src001); \ | |||||
src12 = vmovl_s8(src012); \ | |||||
sum01 = vaddq_s16(sum01, src01); \ | |||||
sum02 = vaddq_s16(sum02, src12); | |||||
UNROLL_CALL_NOWRAPPER(3, CACULATE_ROW) | |||||
#define do_acc(i) \ | |||||
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4) + \ | |||||
vgetq_lane_s16(sum02, i + 4); | |||||
#define do_avg(i) \ | |||||
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ | |||||
: (sum##i - filter_size / 2) / filter_size; | |||||
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i); | |||||
UNROLL_CALL_NOWRAPPER(4, do_acc) | |||||
UNROLL_CALL_NOWRAPPER(4, do_avg) | |||||
UNROLL_CALL_NOWRAPPER(4, store) | |||||
#undef store | |||||
#undef do_avg | |||||
#undef do_acc | |||||
#undef CACULATE_ROW | |||||
sptr0 += 4; | |||||
sptr1 += 4; | |||||
sptr2 += 4; | |||||
dptr += 4; | |||||
} | |||||
} | |||||
} | |||||
void do_avg_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, | |||||
size_t OH, size_t OW, | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
int16_t filter_size = 9; | |||||
const int8_t* sptr = nullptr; | |||||
size_t IH2, IW2; | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false); | |||||
size_t oh = 0; | |||||
for (; oh < OH; ++oh) { | |||||
size_t ih = oh << 1; | |||||
const int8_t* sptr0 = sptr + (ih + 0) * IW2 * 4; | |||||
const int8_t* sptr1 = sptr + (ih + 1) * IW2 * 4; | |||||
const int8_t* sptr2 = sptr + (ih + 2) * IW2 * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | |||||
size_t ow = 0; | |||||
for (; ow + 3 < OW; ow += 4) { | |||||
int32x4x2_t src_tmp; | |||||
int8x16_t src00, src04; | |||||
int32x4_t src0246, src1357, src2468, src08; | |||||
int16x8_t src02, src46, src13, src57, src24, src68; | |||||
int16x8_t sum01 = vdupq_n_s16(0); | |||||
int16x8_t sum23 = vdupq_n_s16(0); | |||||
#define CACULATE_ROW(i) \ | |||||
src00 = vld1q_s8(sptr##i); \ | |||||
src04 = vld1q_s8(sptr##i + 4 * 4); \ | |||||
src08 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 8)); \ | |||||
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ | |||||
vreinterpretq_s32_s8(src04)); \ | |||||
src0246 = src_tmp.val[0]; \ | |||||
src1357 = src_tmp.val[1]; \ | |||||
src2468 = vextq_s32(src0246, src08, 1); \ | |||||
src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \ | |||||
src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \ | |||||
src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \ | |||||
src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \ | |||||
src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \ | |||||
src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \ | |||||
sum01 = vaddq_s16(sum01, src02); \ | |||||
sum01 = vaddq_s16(sum01, src13); \ | |||||
sum01 = vaddq_s16(sum01, src24); \ | |||||
sum23 = vaddq_s16(sum23, src46); \ | |||||
sum23 = vaddq_s16(sum23, src57); \ | |||||
sum23 = vaddq_s16(sum23, src68); | |||||
UNROLL_CALL_NOWRAPPER(3, CACULATE_ROW) | |||||
#define sum_define(i) int16_t sum##i; | |||||
UNROLL_CALL_NOWRAPPER(8, sum_define) | |||||
#define sum01_avg(i) \ | |||||
sum##i = vgetq_lane_s16(sum01, i) > 0 \ | |||||
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ | |||||
filter_size \ | |||||
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ | |||||
filter_size; | |||||
#define sum23_avg(i) \ | |||||
sum##i = vgetq_lane_s16(sum23, i) > 0 \ | |||||
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ | |||||
filter_size \ | |||||
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ | |||||
filter_size; | |||||
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i); | |||||
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i); | |||||
UNROLL_CALL_NOWRAPPER(8, sum01_avg) | |||||
UNROLL_CALL_NOWRAPPER(8, store_sum01) | |||||
UNROLL_CALL_NOWRAPPER(8, sum23_avg) | |||||
UNROLL_CALL_NOWRAPPER(8, store_sum23) | |||||
sptr0 += 32; | |||||
sptr1 += 32; | |||||
sptr2 += 32; | |||||
dptr += 16; | |||||
#undef store_sum01 | |||||
#undef store_sum23 | |||||
#undef sum01_avg | |||||
#undef sum23_avg | |||||
#undef sum_define | |||||
#undef CACULATE_ROW | |||||
} | |||||
for (; ow < OW; ++ow) { | |||||
int8x8_t src001, src012; | |||||
int16x8_t src01, src12, sum01, sum02; | |||||
sum01 = vdupq_n_s16(0); | |||||
sum02 = vdupq_n_s16(0); | |||||
#define CACULATE_ROW(i) \ | |||||
src001 = vld1_s8(sptr##i); \ | |||||
src012 = vld1_s8(sptr##i + 4); \ | |||||
src01 = vmovl_s8(src001); \ | |||||
src12 = vmovl_s8(src012); \ | |||||
sum01 = vaddq_s16(sum01, src01); \ | |||||
sum02 = vaddq_s16(sum02, src12); | |||||
UNROLL_CALL_NOWRAPPER(3, CACULATE_ROW) | |||||
#define do_acc(i) \ | |||||
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4) + \ | |||||
vgetq_lane_s16(sum02, i + 4); | |||||
#define do_avg(i) \ | |||||
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ | |||||
: (sum##i - filter_size / 2) / filter_size; | |||||
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i); | |||||
UNROLL_CALL_NOWRAPPER(4, do_acc) | |||||
UNROLL_CALL_NOWRAPPER(4, do_avg) | |||||
UNROLL_CALL_NOWRAPPER(4, store) | |||||
#undef store | |||||
#undef do_avg | |||||
#undef do_acc | |||||
sptr0 += 8; | |||||
sptr1 += 8; | |||||
sptr2 += 8; | |||||
dptr += 4; | |||||
} | |||||
} | |||||
} | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h | |||||
* \file dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
@@ -15,16 +15,15 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | namespace arm_common { | ||||
void do_max_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, | |||||
size_t OH, size_t OW, | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws); | |||||
void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, size_t OH, | |||||
size_t OW, size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws); | |||||
#define KERN(mode, stride, ctype) \ | |||||
void do_##mode##_pooling_3x3_##stride##_##ctype##_nchw44_NEON( \ | |||||
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ | |||||
size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws); | |||||
KERN(max, stride1, int8) | |||||
KERN(max, stride2, int8) | |||||
KERN(avg, stride1, int8) | |||||
KERN(avg, stride2, int8) | |||||
#undef KERN | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -24,7 +24,7 @@ void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
const WorkspaceBundle& ws) { | const WorkspaceBundle& ws) { | ||||
const int8_t* sptr = nullptr; | const int8_t* sptr = nullptr; | ||||
size_t IH2, IW2; | size_t IH2, IW2; | ||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); | |||||
size_t oh = 0; | size_t oh = 0; | ||||
for (; oh < OH; ++oh) { | for (; oh < OH; ++oh) { | ||||
size_t ih = oh; | size_t ih = oh; | ||||
@@ -99,7 +99,7 @@ void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
const WorkspaceBundle& ws) { | const WorkspaceBundle& ws) { | ||||
const int8_t* sptr = nullptr; | const int8_t* sptr = nullptr; | ||||
size_t IH2, IW2; | size_t IH2, IW2; | ||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); | |||||
size_t oh = 0; | size_t oh = 0; | ||||
for (; oh < OH; ++oh) { | for (; oh < OH; ++oh) { | ||||
size_t ih = oh << 1; | size_t ih = oh << 1; | ||||
@@ -171,6 +171,252 @@ void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
} | } | ||||
} | } | ||||
void do_avg_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, | |||||
size_t OH, size_t OW, | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
int16_t filter_size = 16; | |||||
const int8_t* sptr = nullptr; | |||||
size_t IH2, IW2; | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false); | |||||
size_t oh = 0; | |||||
for (; oh < OH; ++oh) { | |||||
size_t ih = oh; | |||||
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; | |||||
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; | |||||
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4; | |||||
const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW2 * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | |||||
size_t ow = 0; | |||||
for (; ow + 3 < OW; ow += 4) { | |||||
int16x8_t src01, src23, src12, src34, src45, src56; | |||||
int16x8_t sum01 = vdupq_n_s16(0); | |||||
int16x8_t sum23 = vdupq_n_s16(0); | |||||
#define CACULATE_ROW(i) \ | |||||
src01 = vmovl_s8(vld1_s8(sptr##i)); \ | |||||
src23 = vmovl_s8(vld1_s8(sptr##i + 8)); \ | |||||
src12 = vmovl_s8(vld1_s8(sptr##i + 4)); \ | |||||
src34 = vmovl_s8(vld1_s8(sptr##i + 12)); \ | |||||
src45 = vmovl_s8(vld1_s8(sptr##i + 16)); \ | |||||
src56 = vmovl_s8(vld1_s8(sptr##i + 20)); \ | |||||
sum01 = vaddq_s16(sum01, src01); \ | |||||
sum01 = vaddq_s16(sum01, src12); \ | |||||
sum01 = vaddq_s16(sum01, src23); \ | |||||
sum01 = vaddq_s16(sum01, src34); \ | |||||
sum23 = vaddq_s16(sum23, src23); \ | |||||
sum23 = vaddq_s16(sum23, src34); \ | |||||
sum23 = vaddq_s16(sum23, src45); \ | |||||
sum23 = vaddq_s16(sum23, src56); | |||||
UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW) | |||||
#define sum_define(i) int16_t sum##i; | |||||
UNROLL_CALL_NOWRAPPER(8, sum_define) | |||||
#define sum01_avg(i) \ | |||||
sum##i = vgetq_lane_s16(sum01, i) > 0 \ | |||||
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ | |||||
filter_size \ | |||||
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ | |||||
filter_size; | |||||
#define sum23_avg(i) \ | |||||
sum##i = vgetq_lane_s16(sum23, i) > 0 \ | |||||
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ | |||||
filter_size \ | |||||
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ | |||||
filter_size; | |||||
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i); | |||||
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i); | |||||
UNROLL_CALL_NOWRAPPER(8, sum01_avg) | |||||
UNROLL_CALL_NOWRAPPER(8, store_sum01) | |||||
UNROLL_CALL_NOWRAPPER(8, sum23_avg) | |||||
UNROLL_CALL_NOWRAPPER(8, store_sum23) | |||||
sptr0 += 16; | |||||
sptr1 += 16; | |||||
sptr2 += 16; | |||||
sptr3 += 16; | |||||
dptr += 16; | |||||
#undef store_sum01 | |||||
#undef store_sum23 | |||||
#undef sum01_avg | |||||
#undef sum23_avg | |||||
#undef sum_define | |||||
#undef CACULATE_ROW | |||||
} | |||||
for (; ow < OW; ++ow) { | |||||
int16x8_t src01, src23, sum01; | |||||
sum01 = vdupq_n_s16(0); | |||||
#define CACULATE_ROW(i) \ | |||||
src01 = vmovl_s8(vld1_s8(sptr##i)); \ | |||||
src23 = vmovl_s8(vld1_s8(sptr##i + 8)); \ | |||||
sum01 = vaddq_s16(sum01, src01); \ | |||||
sum01 = vaddq_s16(sum01, src23); | |||||
UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW) | |||||
#define do_acc(i) \ | |||||
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4); | |||||
#define do_avg(i) \ | |||||
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ | |||||
: (sum##i - filter_size / 2) / filter_size; | |||||
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i); | |||||
UNROLL_CALL_NOWRAPPER(4, do_acc) | |||||
UNROLL_CALL_NOWRAPPER(4, do_avg) | |||||
UNROLL_CALL_NOWRAPPER(4, store) | |||||
#undef store | |||||
#undef do_avg | |||||
#undef do_acc | |||||
#undef CACULATE_ROW | |||||
sptr0 += 4; | |||||
sptr1 += 4; | |||||
sptr2 += 4; | |||||
sptr3 += 4; | |||||
dptr += 4; | |||||
} | |||||
} | |||||
} | |||||
void do_avg_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, | |||||
size_t OH, size_t OW, | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
int16_t filter_size = 16; | |||||
const int8_t* sptr = nullptr; | |||||
size_t IH2, IW2; | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false); | |||||
size_t oh = 0; | |||||
for (; oh < OH; ++oh) { | |||||
size_t ih = oh << 1; | |||||
const int8_t* sptr0 = sptr + (ih + 0) * IW2 * 4; | |||||
const int8_t* sptr1 = sptr + (ih + 1) * IW2 * 4; | |||||
const int8_t* sptr2 = sptr + (ih + 2) * IW2 * 4; | |||||
const int8_t* sptr3 = sptr + (ih + 3) * IW2 * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | |||||
size_t ow = 0; | |||||
for (; ow + 3 < OW; ow += 4) { | |||||
int32x4x2_t src_tmp; | |||||
int8x16_t src00, src04; | |||||
int16x8_t src02, src13, src57, src24, src68, src35, src79, src46; | |||||
int32x4_t src08, src09, src0246, src1357, src2468, src3579; | |||||
int16x8_t sum01 = vdupq_n_s16(0); | |||||
int16x8_t sum23 = vdupq_n_s16(0); | |||||
#define CACULATE_ROW(i) \ | |||||
src00 = vld1q_s8(sptr##i); \ | |||||
src04 = vld1q_s8(sptr##i + 4 * 4); \ | |||||
src08 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 8)); \ | |||||
src09 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 9)); \ | |||||
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ | |||||
vreinterpretq_s32_s8(src04)); \ | |||||
src0246 = src_tmp.val[0]; \ | |||||
src1357 = src_tmp.val[1]; \ | |||||
src2468 = vextq_s32(src0246, src08, 1); \ | |||||
src3579 = vextq_s32(src1357, src09, 1); \ | |||||
src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \ | |||||
src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \ | |||||
src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \ | |||||
src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \ | |||||
src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \ | |||||
src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \ | |||||
src35 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src3579))); \ | |||||
src79 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src3579))); \ | |||||
sum01 = vaddq_s16(sum01, src02); \ | |||||
sum01 = vaddq_s16(sum01, src13); \ | |||||
sum01 = vaddq_s16(sum01, src24); \ | |||||
sum01 = vaddq_s16(sum01, src35); \ | |||||
sum23 = vaddq_s16(sum23, src46); \ | |||||
sum23 = vaddq_s16(sum23, src57); \ | |||||
sum23 = vaddq_s16(sum23, src68); \ | |||||
sum23 = vaddq_s16(sum23, src79); | |||||
UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW) | |||||
#define sum_define(i) int16_t sum##i; | |||||
UNROLL_CALL_NOWRAPPER(8, sum_define) | |||||
#define sum01_avg(i) \ | |||||
sum##i = vgetq_lane_s16(sum01, i) > 0 \ | |||||
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ | |||||
filter_size \ | |||||
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ | |||||
filter_size; | |||||
#define sum23_avg(i) \ | |||||
sum##i = vgetq_lane_s16(sum23, i) > 0 \ | |||||
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ | |||||
filter_size \ | |||||
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ | |||||
filter_size; | |||||
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i); | |||||
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i); | |||||
UNROLL_CALL_NOWRAPPER(8, sum01_avg) | |||||
UNROLL_CALL_NOWRAPPER(8, store_sum01) | |||||
UNROLL_CALL_NOWRAPPER(8, sum23_avg) | |||||
UNROLL_CALL_NOWRAPPER(8, store_sum23) | |||||
sptr0 += 32; | |||||
sptr1 += 32; | |||||
sptr2 += 32; | |||||
sptr3 += 32; | |||||
dptr += 16; | |||||
#undef store_sum01 | |||||
#undef store_sum23 | |||||
#undef sum01_avg | |||||
#undef sum23_avg | |||||
#undef sum_define | |||||
#undef CACULATE_ROW | |||||
} | |||||
for (; ow < OW; ++ow) { | |||||
int8x8_t src001, src023; | |||||
int16x8_t src01, src23, sum01; | |||||
sum01 = vdupq_n_s16(0); | |||||
#define CACULATE_ROW(i) \ | |||||
src001 = vld1_s8(sptr##i); \ | |||||
src023 = vld1_s8(sptr##i + 8); \ | |||||
src01 = vmovl_s8(src001); \ | |||||
src23 = vmovl_s8(src023); \ | |||||
sum01 = vaddq_s16(sum01, src01); \ | |||||
sum01 = vaddq_s16(sum01, src23); | |||||
UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW) | |||||
#define do_acc(i) \ | |||||
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4); | |||||
#define do_avg(i) \ | |||||
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ | |||||
: (sum##i - filter_size / 2) / filter_size; | |||||
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i); | |||||
UNROLL_CALL_NOWRAPPER(4, do_acc) | |||||
UNROLL_CALL_NOWRAPPER(4, do_avg) | |||||
UNROLL_CALL_NOWRAPPER(4, store) | |||||
#undef store | |||||
#undef do_avg | |||||
#undef do_acc | |||||
#undef CACULATE_ROW | |||||
sptr0 += 8; | |||||
sptr1 += 8; | |||||
sptr2 += 8; | |||||
sptr3 += 8; | |||||
dptr += 4; | |||||
} | |||||
} | |||||
} | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.h | |||||
* \file dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
@@ -15,15 +15,16 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | namespace arm_common { | ||||
#define KERN(strdie) \ | |||||
void do_max_pooling_4x4_##strdie##_int8_nchw44_NEON( \ | |||||
#define KERN(mode, stride, ctype) \ | |||||
void do_##mode##_pooling_4x4_##stride##_##ctype##_nchw44_NEON( \ | |||||
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ | const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ | ||||
size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws); | size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws); | ||||
KERN(stride1) | |||||
KERN(stride2) | |||||
KERN(max, stride1, int8) | |||||
KERN(max, stride2, int8) | |||||
KERN(avg, stride1, int8) | |||||
KERN(avg, stride2, int8) | |||||
#undef KERN | #undef KERN | ||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.cpp | |||||
* \file dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
@@ -24,7 +24,7 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
const WorkspaceBundle& ws) { | const WorkspaceBundle& ws) { | ||||
const int8_t* sptr = nullptr; | const int8_t* sptr = nullptr; | ||||
size_t IH2, IW2; | size_t IH2, IW2; | ||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); | |||||
size_t oh = 0; | size_t oh = 0; | ||||
for (; oh < OH; ++oh) { | for (; oh < OH; ++oh) { | ||||
size_t ih = oh; | size_t ih = oh; | ||||
@@ -118,7 +118,7 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
const WorkspaceBundle& ws) { | const WorkspaceBundle& ws) { | ||||
const int8_t* sptr = nullptr; | const int8_t* sptr = nullptr; | ||||
size_t IH2, IW2; | size_t IH2, IW2; | ||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); | |||||
size_t oh = 0; | size_t oh = 0; | ||||
for (; oh < OH; ++oh) { | for (; oh < OH; ++oh) { | ||||
size_t ih = oh << 1; | size_t ih = oh << 1; | ||||
@@ -213,6 +213,284 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
} | } | ||||
} | } | ||||
void do_avg_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, | |||||
size_t OH, size_t OW, | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
int16_t filter_size = 25; | |||||
const int8_t* sptr = nullptr; | |||||
size_t IH2, IW2; | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false); | |||||
size_t oh = 0; | |||||
for (; oh < OH; ++oh) { | |||||
size_t ih = oh; | |||||
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; | |||||
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; | |||||
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4; | |||||
const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW2 * 4; | |||||
const int8_t* __restrict sptr4 = sptr + (ih + 4) * IW2 * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | |||||
size_t ow = 0; | |||||
for (; ow + 3 < OW; ow += 4) { | |||||
int16x8_t src01, src23, src12, src34, src45, src56, src67; | |||||
int16x8_t sum01 = vdupq_n_s16(0); | |||||
int16x8_t sum23 = vdupq_n_s16(0); | |||||
#define CACULATE_ROW(i) \ | |||||
src01 = vmovl_s8(vld1_s8(sptr##i)); \ | |||||
src23 = vmovl_s8(vld1_s8(sptr##i + 8)); \ | |||||
src12 = vmovl_s8(vld1_s8(sptr##i + 4)); \ | |||||
src34 = vmovl_s8(vld1_s8(sptr##i + 12)); \ | |||||
src45 = vmovl_s8(vld1_s8(sptr##i + 16)); \ | |||||
src56 = vmovl_s8(vld1_s8(sptr##i + 20)); \ | |||||
src67 = vmovl_s8(vld1_s8(sptr##i + 24)); \ | |||||
sum01 = vaddq_s16(sum01, src01); \ | |||||
sum01 = vaddq_s16(sum01, src12); \ | |||||
sum01 = vaddq_s16(sum01, src23); \ | |||||
sum01 = vaddq_s16(sum01, src34); \ | |||||
sum01 = vaddq_s16(sum01, src45); \ | |||||
sum23 = vaddq_s16(sum23, src23); \ | |||||
sum23 = vaddq_s16(sum23, src34); \ | |||||
sum23 = vaddq_s16(sum23, src45); \ | |||||
sum23 = vaddq_s16(sum23, src56); \ | |||||
sum23 = vaddq_s16(sum23, src67); | |||||
UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) | |||||
#define sum_define(i) int16_t sum##i; | |||||
UNROLL_CALL_NOWRAPPER(8, sum_define) | |||||
#define sum01_avg(i) \ | |||||
sum##i = vgetq_lane_s16(sum01, i) > 0 \ | |||||
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ | |||||
filter_size \ | |||||
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ | |||||
filter_size; | |||||
#define sum23_avg(i) \ | |||||
sum##i = vgetq_lane_s16(sum23, i) > 0 \ | |||||
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ | |||||
filter_size \ | |||||
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ | |||||
filter_size; | |||||
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i); | |||||
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i); | |||||
UNROLL_CALL_NOWRAPPER(8, sum01_avg) | |||||
UNROLL_CALL_NOWRAPPER(8, store_sum01) | |||||
UNROLL_CALL_NOWRAPPER(8, sum23_avg) | |||||
UNROLL_CALL_NOWRAPPER(8, store_sum23) | |||||
sptr0 += 16; | |||||
sptr1 += 16; | |||||
sptr2 += 16; | |||||
sptr3 += 16; | |||||
sptr4 += 16; | |||||
dptr += 16; | |||||
#undef store_sum01 | |||||
#undef store_sum23 | |||||
#undef sum01_avg | |||||
#undef sum23_avg | |||||
#undef sum_define | |||||
#undef CACULATE_ROW | |||||
} | |||||
for (; ow < OW; ++ow) { | |||||
int32x2_t src004; | |||||
int8x8_t src001, src023; | |||||
int16x8_t src01, src23, src04, sum01, sum02; | |||||
sum01 = vdupq_n_s16(0); | |||||
sum02 = vdupq_n_s16(0); | |||||
#define CACULATE_ROW(i) \ | |||||
src001 = vld1_s8(sptr##i); \ | |||||
src023 = vld1_s8(sptr##i + 8); \ | |||||
src004 = vld1_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 4)); \ | |||||
src01 = vmovl_s8(src001); \ | |||||
src23 = vmovl_s8(src023); \ | |||||
src04 = vmovl_s8(vreinterpret_s8_s32(src004)); \ | |||||
sum01 = vaddq_s16(sum01, src01); \ | |||||
sum01 = vaddq_s16(sum01, src23); \ | |||||
sum02 = vaddq_s16(sum02, src04); | |||||
UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) | |||||
#define do_acc(i) \ | |||||
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4) + \ | |||||
vgetq_lane_s16(sum02, i + 4); | |||||
#define do_avg(i) \ | |||||
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ | |||||
: (sum##i - filter_size / 2) / filter_size; | |||||
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i); | |||||
UNROLL_CALL_NOWRAPPER(4, do_acc) | |||||
UNROLL_CALL_NOWRAPPER(4, do_avg) | |||||
UNROLL_CALL_NOWRAPPER(4, store) | |||||
#undef store | |||||
#undef do_avg | |||||
#undef do_acc | |||||
#undef CACULATE_ROW | |||||
sptr0 += 4; | |||||
sptr1 += 4; | |||||
sptr2 += 4; | |||||
sptr3 += 4; | |||||
sptr4 += 4; | |||||
dptr += 4; | |||||
} | |||||
} | |||||
} | |||||
void do_avg_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, | |||||
size_t IH, size_t IW, | |||||
size_t OH, size_t OW, | |||||
size_t PH, size_t PW, | |||||
const WorkspaceBundle& ws) { | |||||
int16_t filter_size = 25; | |||||
const int8_t* sptr = nullptr; | |||||
size_t IH2, IW2; | |||||
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false); | |||||
size_t oh = 0; | |||||
for (; oh < OH; ++oh) { | |||||
size_t ih = oh << 1; | |||||
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; | |||||
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; | |||||
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4; | |||||
const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW2 * 4; | |||||
const int8_t* __restrict sptr4 = sptr + (ih + 4) * IW2 * 4; | |||||
int8_t* __restrict dptr = dst + oh * OW * 4; | |||||
size_t ow = 0; | |||||
for (; ow + 3 < OW; ow += 4) { | |||||
int32x4x2_t src_tmp; | |||||
int8x16_t src00, src04; | |||||
int16x8_t src02, src13, src57, src24, src68, src35, src79, src46, | |||||
src810; | |||||
int32x4_t src08, src09, src10, src0246, src1357, src2468, src3579, | |||||
src46810; | |||||
int16x8_t sum01 = vdupq_n_s16(0); | |||||
int16x8_t sum23 = vdupq_n_s16(0); | |||||
#define CACULATE_ROW(i) \ | |||||
src00 = vld1q_s8(sptr##i); \ | |||||
src04 = vld1q_s8(sptr##i + 4 * 4); \ | |||||
src08 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 8)); \ | |||||
src09 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 9)); \ | |||||
src10 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 10)); \ | |||||
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ | |||||
vreinterpretq_s32_s8(src04)); \ | |||||
src0246 = src_tmp.val[0]; \ | |||||
src1357 = src_tmp.val[1]; \ | |||||
src2468 = vextq_s32(src0246, src08, 1); \ | |||||
src3579 = vextq_s32(src1357, src09, 1); \ | |||||
src46810 = vextq_s32(src2468, src10, 1); \ | |||||
src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \ | |||||
src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \ | |||||
src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \ | |||||
src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \ | |||||
src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \ | |||||
src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \ | |||||
src35 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src3579))); \ | |||||
src79 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src3579))); \ | |||||
src46 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src46810))); \ | |||||
src810 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src46810))); \ | |||||
sum01 = vaddq_s16(sum01, src02); \ | |||||
sum01 = vaddq_s16(sum01, src13); \ | |||||
sum01 = vaddq_s16(sum01, src24); \ | |||||
sum01 = vaddq_s16(sum01, src35); \ | |||||
sum01 = vaddq_s16(sum01, src46); \ | |||||
sum23 = vaddq_s16(sum23, src46); \ | |||||
sum23 = vaddq_s16(sum23, src57); \ | |||||
sum23 = vaddq_s16(sum23, src68); \ | |||||
sum23 = vaddq_s16(sum23, src79); \ | |||||
sum23 = vaddq_s16(sum23, src810); | |||||
UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) | |||||
#define sum_define(i) int16_t sum##i; | |||||
UNROLL_CALL_NOWRAPPER(8, sum_define) | |||||
#define sum01_avg(i) \ | |||||
sum##i = vgetq_lane_s16(sum01, i) > 0 \ | |||||
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ | |||||
filter_size \ | |||||
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ | |||||
filter_size; | |||||
#define sum23_avg(i) \ | |||||
sum##i = vgetq_lane_s16(sum23, i) > 0 \ | |||||
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ | |||||
filter_size \ | |||||
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ | |||||
filter_size; | |||||
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i); | |||||
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i); | |||||
UNROLL_CALL_NOWRAPPER(8, sum01_avg) | |||||
UNROLL_CALL_NOWRAPPER(8, store_sum01) | |||||
UNROLL_CALL_NOWRAPPER(8, sum23_avg) | |||||
UNROLL_CALL_NOWRAPPER(8, store_sum23) | |||||
sptr0 += 32; | |||||
sptr1 += 32; | |||||
sptr2 += 32; | |||||
sptr3 += 32; | |||||
sptr4 += 32; | |||||
dptr += 16; | |||||
#undef store_sum01 | |||||
#undef store_sum23 | |||||
#undef sum01_avg | |||||
#undef sum23_avg | |||||
#undef sum_define | |||||
#undef CACULATE_ROW | |||||
} | |||||
for (; ow < OW; ++ow) { | |||||
int32x2_t src004; | |||||
int8x8_t src001, src023; | |||||
int16x8_t src01, src23, src04, sum01, sum02; | |||||
sum01 = vdupq_n_s16(0); | |||||
sum02 = vdupq_n_s16(0); | |||||
#define CACULATE_ROW(i) \ | |||||
src001 = vld1_s8(sptr##i); \ | |||||
src023 = vld1_s8(sptr##i + 8); \ | |||||
src004 = vld1_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 4)); \ | |||||
src01 = vmovl_s8(src001); \ | |||||
src23 = vmovl_s8(src023); \ | |||||
src04 = vmovl_s8(vreinterpret_s8_s32(src004)); \ | |||||
sum01 = vaddq_s16(sum01, src01); \ | |||||
sum01 = vaddq_s16(sum01, src23); \ | |||||
sum02 = vaddq_s16(sum02, src04); | |||||
UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) | |||||
#define do_acc(i) \ | |||||
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4) + \ | |||||
vgetq_lane_s16(sum02, i + 4); | |||||
#define do_avg(i) \ | |||||
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ | |||||
: (sum##i - filter_size / 2) / filter_size; | |||||
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i); | |||||
UNROLL_CALL_NOWRAPPER(4, do_acc) | |||||
UNROLL_CALL_NOWRAPPER(4, do_avg) | |||||
UNROLL_CALL_NOWRAPPER(4, store) | |||||
#undef store | |||||
#undef do_avg | |||||
#undef do_acc | |||||
#undef CACULATE_ROW | |||||
sptr0 += 8; | |||||
sptr1 += 8; | |||||
sptr2 += 8; | |||||
sptr3 += 8; | |||||
sptr4 += 8; | |||||
dptr += 4; | |||||
} | |||||
} | |||||
} | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.h | |||||
* \file dnn/src/arm_common/pooling/do__pooling_5x5_nchw44.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
@@ -15,15 +15,16 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | namespace arm_common { | ||||
#define KERN(strdie) \ | |||||
void do_max_pooling_5x5_##strdie##_int8_nchw44_NEON( \ | |||||
#define KERN(mode, stride, ctype) \ | |||||
void do_##mode##_pooling_5x5_##stride##_##ctype##_nchw44_NEON( \ | |||||
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ | const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ | ||||
size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws); | size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws); | ||||
KERN(stride1) | |||||
KERN(stride2) | |||||
KERN(max, stride1, int8) | |||||
KERN(max, stride2, int8) | |||||
KERN(avg, stride1, int8) | |||||
KERN(avg, stride2, int8) | |||||
#undef KERN | #undef KERN | ||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -25,10 +25,10 @@ class PoolingImpl::AlgoPack : NonCopyableObj { | |||||
AlgoFilter5MaxStride2 algo_filter5_max_stride2; | AlgoFilter5MaxStride2 algo_filter5_max_stride2; | ||||
AlgoInt8Filter2MaxStride2 algo_int8_filter2_max_stride2; | AlgoInt8Filter2MaxStride2 algo_int8_filter2_max_stride2; | ||||
AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2; | AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2; | ||||
AlgoFilter2MaxStridexNCHW44 algo_filter2_max_stridex_nchw4; | |||||
AlgoFilter3MaxStridexNCHW44 algo_filter3_max_stridex_nchw4; | |||||
AlgoFilter4MaxStridexNCHW44 algo_filter4_max_stridex_nchw4; | |||||
AlgoFilter5MaxStridexNCHW44 algo_filter5_max_stridex_nchw4; | |||||
AlgoFilter2ModexStridexNCHW44 algo_filter2_modex_stridex_nchw4; | |||||
AlgoFilter3ModexStridexNCHW44 algo_filter3_modex_stridex_nchw4; | |||||
AlgoFilter4ModexStridexNCHW44 algo_filter4_modex_stridex_nchw4; | |||||
AlgoFilter5ModexStridexNCHW44 algo_filter5_modex_stridex_nchw4; | |||||
public: | public: | ||||
AlgoPack() { | AlgoPack() { | ||||
@@ -40,10 +40,10 @@ public: | |||||
all_algos.emplace_back(&algo_filter5_max_stride2); | all_algos.emplace_back(&algo_filter5_max_stride2); | ||||
all_algos.emplace_back(&algo_int8_filter2_max_stride2); | all_algos.emplace_back(&algo_int8_filter2_max_stride2); | ||||
all_algos.emplace_back(&algo_int8_filter3_max_stride2); | all_algos.emplace_back(&algo_int8_filter3_max_stride2); | ||||
all_algos.emplace_back(&algo_filter3_max_stridex_nchw4); | |||||
all_algos.emplace_back(&algo_filter2_max_stridex_nchw4); | |||||
all_algos.emplace_back(&algo_filter4_max_stridex_nchw4); | |||||
all_algos.emplace_back(&algo_filter5_max_stridex_nchw4); | |||||
all_algos.emplace_back(&algo_filter3_modex_stridex_nchw4); | |||||
all_algos.emplace_back(&algo_filter2_modex_stridex_nchw4); | |||||
all_algos.emplace_back(&algo_filter4_modex_stridex_nchw4); | |||||
all_algos.emplace_back(&algo_filter5_modex_stridex_nchw4); | |||||
} | } | ||||
SmallVector<AlgoBase*> all_algos; | SmallVector<AlgoBase*> all_algos; | ||||
}; | }; | ||||
@@ -83,10 +83,10 @@ private: | |||||
class AlgoFilter5MaxStride2; | class AlgoFilter5MaxStride2; | ||||
class AlgoInt8Filter2MaxStride2; | class AlgoInt8Filter2MaxStride2; | ||||
class AlgoInt8Filter3MaxStride2; | class AlgoInt8Filter3MaxStride2; | ||||
class AlgoFilter2MaxStridexNCHW44; | |||||
class AlgoFilter3MaxStridexNCHW44; | |||||
class AlgoFilter4MaxStridexNCHW44; | |||||
class AlgoFilter5MaxStridexNCHW44; | |||||
class AlgoFilter2ModexStridexNCHW44; | |||||
class AlgoFilter3ModexStridexNCHW44; | |||||
class AlgoFilter4ModexStridexNCHW44; | |||||
class AlgoFilter5ModexStridexNCHW44; | |||||
class AlgoPack; | class AlgoPack; | ||||
}; | }; | ||||
} // namespace arm_common | } // namespace arm_common | ||||
@@ -10,6 +10,7 @@ | |||||
*/ | */ | ||||
#include <vector> | #include <vector> | ||||
#include "megdnn/dtype.h" | #include "megdnn/dtype.h" | ||||
#include "megdnn/opr_param_defs.h" | |||||
#include "test/arm_common/fixture.h" | #include "test/arm_common/fixture.h" | ||||
#include "test/common/pooling.h" | #include "test/common/pooling.h" | ||||
@@ -56,13 +57,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING) { | |||||
} | } | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_NCHW44) | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44) | |||||
{ | { | ||||
// clang-format off | // clang-format off | ||||
for (size_t ih: {3, 5, 10}) | for (size_t ih: {3, 5, 10}) | ||||
for (size_t iw: {3, 5, 7, 9, 15, 20}) | for (size_t iw: {3, 5, 7, 9, 15, 20}) | ||||
for (size_t ph: {0, 1, 2}) | for (size_t ph: {0, 1, 2}) | ||||
for (size_t pw: {0, 1, 2}) | for (size_t pw: {0, 1, 2}) | ||||
for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) | |||||
if (ih+2*ph >= 3 && iw+2*pw >= 3) | if (ih+2*ph >= 3 && iw+2*pw >= 3) | ||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | ||||
@@ -71,7 +73,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_NCHW44) | |||||
checker.set_rng(0,&rng); | checker.set_rng(0,&rng); | ||||
param::Pooling param; | param::Pooling param; | ||||
param.mode = param::Pooling::Mode::MAX; | |||||
param.mode = mode; | |||||
param.format = param::Pooling::Format::NCHW44; | param.format = param::Pooling::Format::NCHW44; | ||||
param.pad_h = ph; | param.pad_h = ph; | ||||
param.pad_w = pw; | param.pad_w = pw; | ||||
@@ -86,13 +88,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_NCHW44) | |||||
// clang-format on | // clang-format on | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_NCHW44) | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W2x2_NCHW44) | |||||
{ | { | ||||
// clang-format off | // clang-format off | ||||
for (size_t ih: {2, 5, 10, 17}) | for (size_t ih: {2, 5, 10, 17}) | ||||
for (size_t iw: {2, 6, 8, 16, 26}) | for (size_t iw: {2, 6, 8, 16, 26}) | ||||
for (size_t ph: {0, 1}) | for (size_t ph: {0, 1}) | ||||
for (size_t pw: {0, 1}) | for (size_t pw: {0, 1}) | ||||
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) | |||||
if (ih+2*ph >= 2 && iw+2*pw >= 2) | if (ih+2*ph >= 2 && iw+2*pw >= 2) | ||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | ||||
@@ -101,7 +104,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_NCHW44) | |||||
checker.set_rng(0,&rng); | checker.set_rng(0,&rng); | ||||
param::Pooling param; | param::Pooling param; | ||||
param.mode = param::Pooling::Mode::MAX; | |||||
param.mode = mode; | |||||
param.format = param::Pooling::Format::NCHW44; | param.format = param::Pooling::Format::NCHW44; | ||||
param.pad_h = ph; | param.pad_h = ph; | ||||
param.pad_w = pw; | param.pad_w = pw; | ||||
@@ -115,13 +118,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_NCHW44) | |||||
// clang-format on | // clang-format on | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_NCHW44) | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W4x4_NCHW44) | |||||
{ | { | ||||
// clang-format off | // clang-format off | ||||
for (size_t ih: {4, 10, 18, 25, 30}) | for (size_t ih: {4, 10, 18, 25, 30}) | ||||
for (size_t iw: {4, 12, 17, 20, 25}) | for (size_t iw: {4, 12, 17, 20, 25}) | ||||
for (size_t ph: {0, 1, 2}) | for (size_t ph: {0, 1, 2}) | ||||
for (size_t pw: {0, 1, 2}) | for (size_t pw: {0, 1, 2}) | ||||
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) | |||||
if (ih+2*ph >= 4 && iw+2*pw >= 4) | if (ih+2*ph >= 4 && iw+2*pw >= 4) | ||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | ||||
@@ -130,7 +134,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_NCHW44) | |||||
checker.set_rng(0,&rng); | checker.set_rng(0,&rng); | ||||
param::Pooling param; | param::Pooling param; | ||||
param.mode = param::Pooling::Mode::MAX; | |||||
param.mode = mode; | |||||
param.format = param::Pooling::Format::NCHW44; | param.format = param::Pooling::Format::NCHW44; | ||||
param.pad_h = ph; | param.pad_h = ph; | ||||
param.pad_w = pw; | param.pad_w = pw; | ||||
@@ -143,13 +147,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_NCHW44) | |||||
} | } | ||||
// clang-format on | // clang-format on | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_NCHW44) | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W5x5_NCHW44) | |||||
{ | { | ||||
// clang-format off | // clang-format off | ||||
for (size_t ih: {5, 9, 19, 20, 39}) | for (size_t ih: {5, 9, 19, 20, 39}) | ||||
for (size_t iw: {5, 12, 23, 27, 39}) | for (size_t iw: {5, 12, 23, 27, 39}) | ||||
for (size_t ph: {0, 1, 2}) | for (size_t ph: {0, 1, 2}) | ||||
for (size_t pw: {0, 1, 2}) | for (size_t pw: {0, 1, 2}) | ||||
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) | |||||
if (ih+2*ph >= 5 && iw+2*pw >= 5) | if (ih+2*ph >= 5 && iw+2*pw >= 5) | ||||
{ | { | ||||
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | ||||
@@ -158,7 +163,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_NCHW44) | |||||
checker.set_rng(0,&rng); | checker.set_rng(0,&rng); | ||||
param::Pooling param; | param::Pooling param; | ||||
param.mode = param::Pooling::Mode::MAX; | |||||
param.mode = mode; | |||||
param.format = param::Pooling::Format::NCHW44; | param.format = param::Pooling::Format::NCHW44; | ||||
param.pad_h = ph; | param.pad_h = ph; | ||||
param.pad_w = pw; | param.pad_w = pw; | ||||
@@ -477,31 +482,37 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_POOLING_NCHW44) { | |||||
std::vector<SmallVector<TensorShape>> shapes; | std::vector<SmallVector<TensorShape>> shapes; | ||||
std::vector<std::vector<size_t>> filter_and_stride = { | std::vector<std::vector<size_t>> filter_and_stride = { | ||||
{2, 1}, {2, 2}, {3, 1}, {3, 2}, {4, 1}, {4, 2}, {5, 1}, {5, 2}}; | {2, 1}, {2, 2}, {3, 1}, {3, 2}, {4, 1}, {4, 2}, {5, 1}, {5, 2}}; | ||||
for(auto filter:filter_and_stride){ | |||||
shapes.push_back({{1, 32 * 4, 215, 215}, {}}); | |||||
shapes.push_back({{1, 32 * 4, 128, 128}, {}}); | |||||
shapes.push_back({{1, 16 * 4, 56, 56}, {}}); | |||||
param.window_h = param.window_w = filter[0]; | |||||
param.stride_h = param.stride_w = filter[1]; | |||||
param.format = Param::Format::NCHW; | |||||
printf("NCHW Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", param.window_h, | |||||
param.window_h, param.stride_h, static_cast<int>(param.mode)); | |||||
benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||||
dtype::QuantizedS8(1.1f)); | |||||
shapes.clear(); | |||||
shapes.push_back({{1, 32, 215, 215,4}, {}}); | |||||
shapes.push_back({{1, 32, 128, 128,4}, {}}); | |||||
shapes.push_back({{1, 16, 56, 56, 4}, {}}); | |||||
param.format = Param::Format::NCHW44; | |||||
printf("NCHW44 Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", param.window_h, | |||||
param.window_w, param.stride_h, static_cast<int>(param.mode)); | |||||
benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||||
dtype::QuantizedS8(1.1f)); | |||||
shapes.clear(); | |||||
} | |||||
for (auto mode : | |||||
{param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) { | |||||
for (auto filter : filter_and_stride) { | |||||
shapes.push_back({{1, 32 * 4, 215, 215}, {}}); | |||||
shapes.push_back({{1, 32 * 4, 128, 128}, {}}); | |||||
shapes.push_back({{1, 16 * 4, 56, 56}, {}}); | |||||
param.mode = mode; | |||||
param.window_h = param.window_w = filter[0]; | |||||
param.stride_h = param.stride_w = filter[1]; | |||||
param.format = Param::Format::NCHW; | |||||
printf("NCHW Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", | |||||
param.window_h, param.window_h, param.stride_h, | |||||
static_cast<int>(param.mode)); | |||||
benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}}, | |||||
{1, {4}}, dtype::QuantizedS8(1.1f)); | |||||
shapes.clear(); | |||||
shapes.push_back({{1, 32, 215, 215, 4}, {}}); | |||||
shapes.push_back({{1, 32, 128, 128, 4}, {}}); | |||||
shapes.push_back({{1, 16, 56, 56, 4}, {}}); | |||||
param.format = Param::Format::NCHW44; | |||||
printf("NCHW44 Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", | |||||
param.window_h, param.window_w, param.stride_h, | |||||
static_cast<int>(param.mode)); | |||||
benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}}, | |||||
{1, {4}}, dtype::QuantizedS8(1.1f)); | |||||
shapes.clear(); | |||||
} | |||||
} | |||||
} | } | ||||
#endif | #endif | ||||