@@ -38,23 +38,6 @@ public: | |||
const NCBKernSizeParam& param) const override; | |||
}; | |||
class ConvBiasImpl::AlgoS8DirectStride1NCHW44 final : public AlgoBase { | |||
public: | |||
AlgoS8DirectStride1NCHW44() {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "S8_NCHW44_DIRECT_STRD1"; } | |||
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(fallback::ConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
bool is_preferred(megdnn::fallback::ConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override; | |||
}; | |||
class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | |||
bool m_large_group; | |||
@@ -74,11 +57,11 @@ public: | |||
const NCBKernSizeParam& param) const override; | |||
}; | |||
class ConvBiasImpl::AlgoS8DirectStride2NCHW44 final : public AlgoBase { | |||
class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase { | |||
public: | |||
AlgoS8DirectStride2NCHW44() {} | |||
AlgoS8DirectNCHW44() {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "S8_NCHW44_DIRECT_STRD2"; } | |||
const char* name() const override { return "S8_NCHW44_DIRECT"; } | |||
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(fallback::ConvBiasImpl*, | |||
@@ -245,8 +228,8 @@ private: | |||
//=======================input int8 compute fp32 output int8============ | |||
class ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44 final : public AlgoBase { | |||
public: | |||
AlgoS8CF32WinogradF23_4x4_NCHW44(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
uint32_t tile_size) | |||
AlgoS8CF32WinogradF23_4x4_NCHW44( | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) | |||
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
@@ -277,7 +260,7 @@ private: | |||
class ConvBiasImpl::AlgoS8WinogradF23_8x8_NCHW44 final : public AlgoBase { | |||
public: | |||
AlgoS8WinogradF23_8x8_NCHW44(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
uint32_t tile_size) | |||
uint32_t tile_size) | |||
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
@@ -36,26 +36,6 @@ KERN(stride2, 7, nchw) | |||
#undef KERN | |||
#define KERN(stride, i, layout) \ | |||
template <BiasMode bias_mode, typename Op, int remain_w> \ | |||
void conv_direct_##stride##_##i##x##i##_int8_##layout( \ | |||
const int8_t* src, const int8_t* filter, const int32_t* bias, \ | |||
int32_t* temp, int8_t* dst, const size_t OC, const size_t IC, \ | |||
const size_t IH, const size_t IW, const size_t OH, \ | |||
const size_t OW, const Op& op); | |||
KERN(stride1, 2, nchw44) | |||
KERN(stride1, 3, nchw44) | |||
KERN(stride1, 5, nchw44) | |||
KERN(stride1, 7, nchw44) | |||
KERN(stride2, 2, nchw44) | |||
KERN(stride2, 3, nchw44) | |||
KERN(stride2, 5, nchw44) | |||
KERN(stride2, 7, nchw44) | |||
#undef KERN | |||
void nchw44_pack_filter(const int8_t* src, int8_t* dst, int filter); | |||
void nchw44_pack_src(const int8_t* src, int8_t* dst, int length); | |||
} // namespace conv_bias | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_algo.cpp | |||
* \file dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
@@ -13,6 +13,7 @@ | |||
#include "megdnn/oprs.h" | |||
#include "src/arm_common/conv_bias/int8/algos.h" | |||
#include "src/arm_common/conv_bias/int8/direct.h" | |||
#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | |||
#include "src/arm_common/conv_bias/int8/strategy.h" | |||
#include "src/arm_common/elemwise_op.h" | |||
#include "src/common/opr_delegate.h" | |||
@@ -25,28 +26,19 @@ using conv_fun = std::function<void( | |||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_stride2) | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44) | |||
static void get_rectified_size( | |||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { | |||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, | |||
int& iw2) { | |||
auto&& fm = param.filter_meta; | |||
size_t SW = fm.stride[1]; | |||
size_t IH = param.isz[0]; | |||
size_t IW = param.isz[1]; | |||
size_t OH = param.osz[0]; | |||
size_t OW = param.osz[1]; | |||
size_t FH = fm.spatial[0]; | |||
size_t FW = fm.spatial[1]; | |||
int ih = param.isz[0]; | |||
int iw = param.isz[1]; | |||
int ph = fm.padding[0]; | |||
int pw = fm.padding[1]; | |||
OH2 = OH; | |||
OW2 = (OW + 7) & ~7; | |||
IH2 = SW * OH + FH - SW; | |||
IW2 = SW * OW2 + FW - SW; | |||
// Because stride is 2, sometimes IW == IW2+1. Do a max update to | |||
// handle this case. | |||
IH2 = std::max(IH2, IH); | |||
IW2 = std::max(IW2, IW); | |||
ih2 = ih + ph * 2; | |||
iw2 = iw + pw * 2; | |||
} | |||
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||
constexpr size_t src_expand = 4; | |||
@@ -57,8 +49,8 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||
size_t OC = fm.ocpg; | |||
size_t FH = fm.spatial[0]; | |||
size_t FW = fm.spatial[1]; | |||
size_t IH2, IW2, OH2, OW2; | |||
get_rectified_size(param, IH2, IW2, OH2, OW2); | |||
int IH2, IW2; | |||
get_rectified_size(param, IH2, IW2); | |||
if (group == 1) { | |||
size_t src_size = | |||
batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; | |||
@@ -76,16 +68,16 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||
const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids) { | |||
size_t IH = kern_param.isz[0]; | |||
size_t IW = kern_param.isz[1]; | |||
size_t IC = kern_param.filter_meta.icpg; | |||
size_t PH = kern_param.filter_meta.padding[0]; | |||
size_t PW = kern_param.filter_meta.padding[1]; | |||
size_t GROUP = kern_param.filter_meta.group; | |||
int IH = kern_param.isz[0]; | |||
int IW = kern_param.isz[1]; | |||
int IC = kern_param.filter_meta.icpg; | |||
int PH = kern_param.filter_meta.padding[0]; | |||
int PW = kern_param.filter_meta.padding[1]; | |||
int GROUP = kern_param.filter_meta.group; | |||
size_t IH2, IW2, OH2, OW2; | |||
get_rectified_size(kern_param, IH2, IW2, OH2, OW2); | |||
size_t padding_group_size = IH2 * IW2 * IC; | |||
int IH2, IW2; | |||
get_rectified_size(kern_param, IH2, IW2); | |||
int padding_group_size = IH2 * IW2 * IC; | |||
bundle.set(kern_param.workspace_ptr); | |||
//! Used for get the workspace offset | |||
constexpr int pack_ic = 4; | |||
@@ -100,16 +92,10 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||
size_t group_id = ncb_index.ndrange_id[1]; | |||
size_t group_pack_size = 1; | |||
int nr_pad_h = PH * IW2 * pack_ic * expend_element; | |||
int nr_pad_w = PW * pack_ic * expend_element; | |||
int over_pad = std::max(0_z, IW2 - IW - 2 * PW) * pack_ic * expend_element; | |||
int row_last_pad = ((int)IW2 - (int)IW - 2 * (int)PW) >= 0 | |||
? nr_pad_w + over_pad | |||
: (IW2 - IW - PW) * pack_ic * expend_element; | |||
int col_last_pad = | |||
((int)IH2 - (int)IH - 2 * (int)PH) >= 0 | |||
? nr_pad_h | |||
: (IH2 - IH - PH) * IW2 * pack_ic * expend_element; | |||
int nr_pad_h = PH * IW2 * pack_ic * expend_element; | |||
int row_last_pad = (IW2 - IW - PW) * pack_ic * expend_element; | |||
int col_last_pad = (IH2 - IH - PH) * IW2 * pack_ic * expend_element; | |||
const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>( | |||
batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); | |||
@@ -129,7 +115,7 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||
rep(ih_idx, IH) { | |||
std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); | |||
sptr_base += nr_pad_w; | |||
conv_bias::nchw44_pack_src(sptr, sptr_base, IW); | |||
nchw44_pack_src(sptr, sptr_base, IW); | |||
sptr_base += IW * pack_ic * expend_element; | |||
sptr += IW * pack_ic; | |||
std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t)); | |||
@@ -140,7 +126,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||
} | |||
} | |||
template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain> | |||
template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain, | |||
typename DstType, int stride> | |||
static void do_conv_kern(WorkspaceBundle bundle, | |||
const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
@@ -153,12 +140,12 @@ static void do_conv_kern(WorkspaceBundle bundle, | |||
size_t IC = kern_param.filter_meta.icpg; | |||
size_t OC = kern_param.filter_meta.ocpg; | |||
size_t GROUP = kern_param.filter_meta.group; | |||
size_t IH2, IW2, OH2, OW2; | |||
get_rectified_size(kern_param, IH2, IW2, OH2, OW2); | |||
int IH2, IW2; | |||
get_rectified_size(kern_param, IH2, IW2); | |||
bool need_post_process = | |||
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; | |||
//! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) | |||
Op op = Op(1.0f, 4.0f); | |||
Op op(1.f, 4.f); | |||
if (need_post_process) { | |||
float scale_bias = | |||
kern_param.bias_type.param<dtype::QuantizedS32>().scale; | |||
@@ -191,49 +178,43 @@ static void do_conv_kern(WorkspaceBundle bundle, | |||
const int8_t* fptr = | |||
kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC; | |||
void* dst = reinterpret_cast<void*>( | |||
reinterpret_cast<ptrdiff_t>( | |||
kern_param.dst<void>(batch_id, group_id)) + | |||
oc_idx * OH * OW); | |||
DstType* dst = reinterpret_cast<DstType*>( | |||
kern_param.dst<void>(batch_id, group_id, oc_idx)); | |||
const int32_t* bptr = | |||
kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx; | |||
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) + | |||
group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; | |||
conv_bias::nchw44_pack_filter(fptr, packed_weight, | |||
oc_block / 4 * IC / 4 * FH * FW); | |||
#define KERN1_NCHW44_CONV(filter) \ | |||
conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw44< \ | |||
bias_mode, Op, ow_remain>(sptr, packed_weight, bptr, nullptr, \ | |||
static_cast<int8_t*>(dst), oc_block, IC, \ | |||
IH2, IW2, OH, OW, op) | |||
DISPATCH_FILTER(filter, KERN1_NCHW44_CONV) | |||
#undef KERN1_NCHW44_CONV | |||
nchw44_pack_filter(fptr, packed_weight, oc_block / 4 * IC / 4 * FH * FW); | |||
conv_direct_int8_nchw44<bias_mode, Op, ow_remain, filter, DstType, stride>( | |||
sptr, packed_weight, bptr, nullptr, static_cast<DstType*>(dst), | |||
oc_block, IC, IH2, IW2, OH, OW, op); | |||
} | |||
/* ===================== stride2 algo ===================== */ | |||
bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::usable( | |||
bool ConvBiasImpl::AlgoS8DirectNCHW44::usable( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
MEGDNN_MARK_USED_VAR(algo_selection_strategy); | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
auto OC = fm.ocpg; | |||
auto IC = fm.icpg; | |||
bool avaible = //! src and filter are qint8, dst is qint8 or qint32 | |||
const int fh = fm.spatial[0]; | |||
const int fw = fm.spatial[1]; | |||
const int oc = fm.ocpg; | |||
const int ic = fm.icpg; | |||
const bool avaible = //! src and filter are qint8, dst is qint8 or qint32 | |||
((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||
(param.dst_type.enumv() == DTypeEnum::QuantizedS8 || | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS32))) && | |||
(fm.format == param::Convolution::Format::NCHW44) && | |||
(OC % 4 == 0 && IC % 4 == 0 && OC >= 4) && !fm.should_flip && | |||
(oc % 4 == 0 && ic % 4 == 0 && oc >= 4) && !fm.should_flip && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | |||
fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] && | |||
(fm.stride[0] == 2 || fm.stride[0] == 1) && fh == fw && | |||
(fh == 2 || fh == 3 || fh == 5 || fh == 7) && | |||
param.bias_mode != BiasMode::BIAS; | |||
return avaible; | |||
} | |||
bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::is_preferred( | |||
bool ConvBiasImpl::AlgoS8DirectNCHW44::is_preferred( | |||
megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr, | |||
const NCBKernSizeParam& param) const { | |||
// TODO: benchmark and fix | |||
@@ -242,13 +223,13 @@ bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::is_preferred( | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoS8DirectStride2NCHW44::get_workspace( | |||
size_t ConvBiasImpl::AlgoS8DirectNCHW44::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoS8DirectStride2NCHW44::dispatch_kerns( | |||
ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
auto fm = param.filter_meta; | |||
size_t N = param.n; | |||
@@ -261,97 +242,129 @@ ConvBiasImpl::AlgoS8DirectStride2NCHW44::dispatch_kerns( | |||
WorkspaceBundle wbundle = get_bundle(param); | |||
conv_fun do_conv_fun = nullptr; | |||
int ow_remain = OW % 8; | |||
bool need_post_process = param.dst_type.enumv() == DTypeEnum::QuantizedS8; | |||
// NOTE: remain_w is not used to gen hash of midout for compatible with changing | |||
// shape runtime | |||
#define DO_CONV_KERN_FUN(filter, bias_mode, remain_w, op) \ | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride2, \ | |||
midout_iv(#filter #bias_mode #op##_hash)) { \ | |||
do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w>; \ | |||
} \ | |||
#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, remain_w, op) \ | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44, \ | |||
midout_iv(#stride #dst_type #filter #bias_mode #op##_hash)) { \ | |||
do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w, dst_type, \ | |||
stride>; \ | |||
} \ | |||
MIDOUT_END(); | |||
#define GET_OP_PARAM(filter, bias_mode, remain_w) \ | |||
switch (param.nonlineMode) { \ | |||
case param::ConvBias::NonlineMode::IDENTITY: \ | |||
DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ | |||
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
break; \ | |||
case param::ConvBias::NonlineMode::RELU: \ | |||
DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ | |||
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
break; \ | |||
case param::ConvBias::NonlineMode::H_SWISH: \ | |||
DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ | |||
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0); \ | |||
break; \ | |||
#define GET_OP_PARAM(stride, filter, bias_mode, remain_w) \ | |||
if (need_post_process) { \ | |||
switch (param.nonlineMode) { \ | |||
case param::ConvBias::NonlineMode::IDENTITY: \ | |||
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ | |||
remain_w, \ | |||
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
break; \ | |||
case param::ConvBias::NonlineMode::RELU: \ | |||
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ | |||
remain_w, \ | |||
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
break; \ | |||
case param::ConvBias::NonlineMode::H_SWISH: \ | |||
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ | |||
remain_w, \ | |||
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0, "no supported noline mode"); \ | |||
break; \ | |||
} \ | |||
} else { \ | |||
switch (param.nonlineMode) { \ | |||
case param::ConvBias::NonlineMode::IDENTITY: \ | |||
DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, \ | |||
remain_w, NoneOp<dt_int32>) \ | |||
break; \ | |||
default: \ | |||
megdnn_assert( \ | |||
0, \ | |||
"only support IDENTITY mode when dst is not qint8"); \ | |||
break; \ | |||
} \ | |||
} | |||
#define GET_REMAIN_W_PARAM(filter, bias_mode) \ | |||
switch (ow_remain) { \ | |||
case 0: \ | |||
GET_OP_PARAM(filter, bias_mode, 0); \ | |||
break; \ | |||
case 1: \ | |||
GET_OP_PARAM(filter, bias_mode, 1); \ | |||
break; \ | |||
case 2: \ | |||
GET_OP_PARAM(filter, bias_mode, 2); \ | |||
break; \ | |||
case 3: \ | |||
GET_OP_PARAM(filter, bias_mode, 3); \ | |||
break; \ | |||
case 4: \ | |||
GET_OP_PARAM(filter, bias_mode, 4); \ | |||
break; \ | |||
case 5: \ | |||
GET_OP_PARAM(filter, bias_mode, 5); \ | |||
break; \ | |||
case 6: \ | |||
GET_OP_PARAM(filter, bias_mode, 6); \ | |||
break; \ | |||
case 7: \ | |||
GET_OP_PARAM(filter, bias_mode, 7); \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0); \ | |||
#define GET_REMAIN_W_PARAM(stride, filter, bias_mode) \ | |||
switch (ow_remain) { \ | |||
case 0: \ | |||
GET_OP_PARAM(stride, filter, bias_mode, 0); \ | |||
break; \ | |||
case 1: \ | |||
GET_OP_PARAM(stride, filter, bias_mode, 1); \ | |||
break; \ | |||
case 2: \ | |||
GET_OP_PARAM(stride, filter, bias_mode, 2); \ | |||
break; \ | |||
case 3: \ | |||
GET_OP_PARAM(stride, filter, bias_mode, 3); \ | |||
break; \ | |||
case 4: \ | |||
GET_OP_PARAM(stride, filter, bias_mode, 4); \ | |||
break; \ | |||
case 5: \ | |||
GET_OP_PARAM(stride, filter, bias_mode, 5); \ | |||
break; \ | |||
case 6: \ | |||
GET_OP_PARAM(stride, filter, bias_mode, 6); \ | |||
break; \ | |||
case 7: \ | |||
GET_OP_PARAM(stride, filter, bias_mode, 7); \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0); \ | |||
} | |||
#define GET_BIAS_MODE_PARAM(filter) \ | |||
switch (param.bias_mode) { \ | |||
case BiasMode::NO_BIAS: \ | |||
GET_REMAIN_W_PARAM(filter, BiasMode::NO_BIAS) \ | |||
break; \ | |||
case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||
GET_REMAIN_W_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0); \ | |||
break; \ | |||
#define GET_BIAS_MODE_PARAM(stride, filter) \ | |||
switch (param.bias_mode) { \ | |||
case BiasMode::NO_BIAS: \ | |||
GET_REMAIN_W_PARAM(stride, filter, BiasMode::NO_BIAS) \ | |||
break; \ | |||
case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||
GET_REMAIN_W_PARAM(stride, filter, \ | |||
BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0); \ | |||
break; \ | |||
} | |||
#define DISPATCH_CONV_KERN() \ | |||
#define DISPATCH_CONV_KERN(stride) \ | |||
switch (param.filter_meta.spatial[0]) { \ | |||
case 2: \ | |||
GET_BIAS_MODE_PARAM(2) \ | |||
GET_BIAS_MODE_PARAM(stride, 2) \ | |||
break; \ | |||
case 3: \ | |||
GET_BIAS_MODE_PARAM(3) \ | |||
GET_BIAS_MODE_PARAM(stride, 3) \ | |||
break; \ | |||
case 5: \ | |||
GET_BIAS_MODE_PARAM(5) \ | |||
GET_BIAS_MODE_PARAM(stride, 5) \ | |||
break; \ | |||
case 7: \ | |||
GET_BIAS_MODE_PARAM(7) \ | |||
GET_BIAS_MODE_PARAM(stride, 7) \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0); \ | |||
break; \ | |||
} | |||
DISPATCH_CONV_KERN(); | |||
switch (param.filter_meta.stride[0]) { | |||
case 1: | |||
DISPATCH_CONV_KERN(1); | |||
break; | |||
case 2: | |||
DISPATCH_CONV_KERN(2); | |||
break; | |||
default: | |||
megdnn_throw(ssprintf("Unsupport stride size %u for the first conv", | |||
param.filter_meta.stride[0]) | |||
.c_str()); | |||
break; | |||
} | |||
#undef DO_CONV_KERN_FUN | |||
#undef GET_REMAIN_W_PARAM |
@@ -1,393 +0,0 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_algo.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "src/arm_common/conv_bias/int8/algos.h" | |||
#include "src/arm_common/conv_bias/int8/direct.h" | |||
#include "src/arm_common/conv_bias/int8/strategy.h" | |||
#include "src/arm_common/elemwise_op.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "midout.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using conv_fun = std::function<void( | |||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_stride1) | |||
static void get_rectified_size( | |||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { | |||
auto&& fm = param.filter_meta; | |||
auto SW = fm.stride[1]; | |||
auto OH = param.osz[0]; | |||
auto OW = param.osz[1]; | |||
auto FH = fm.spatial[0]; | |||
auto FW = fm.spatial[1]; | |||
OH2 = OH; | |||
OW2 = (OW + 7) & ~7; | |||
IH2 = SW * OH + FH - SW; | |||
IW2 = SW * OW2 + FW - SW; | |||
} | |||
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||
constexpr size_t src_expand = 4; | |||
auto&& fm = param.filter_meta; | |||
size_t group = fm.group; | |||
size_t batch = param.n; | |||
size_t IC = fm.icpg; | |||
size_t OC = fm.ocpg; | |||
size_t FH = fm.spatial[0]; | |||
size_t FW = fm.spatial[1]; | |||
size_t IH2, IW2, OH2, OW2; | |||
get_rectified_size(param, IH2, IW2, OH2, OW2); | |||
if (group == 1) { | |||
size_t src_size = | |||
batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; | |||
size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); | |||
return {nullptr, {src_size, weight_size}}; | |||
} else { | |||
size_t src_size = | |||
param.nr_threads * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; | |||
size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); | |||
return {nullptr, {src_size, weight_size}}; | |||
} | |||
}; | |||
static void copy_padding_kern(WorkspaceBundle bundle, | |||
const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids) { | |||
size_t IH = kern_param.isz[0]; | |||
size_t IW = kern_param.isz[1]; | |||
size_t IC = kern_param.filter_meta.icpg; | |||
size_t PH = kern_param.filter_meta.padding[0]; | |||
size_t PW = kern_param.filter_meta.padding[1]; | |||
size_t GROUP = kern_param.filter_meta.group; | |||
size_t IH2, IW2, OH2, OW2; | |||
get_rectified_size(kern_param, IH2, IW2, OH2, OW2); | |||
size_t padding_group_size = IH2 * IW2 * IC; | |||
bundle.set(kern_param.workspace_ptr); | |||
//! Used for get the workspace offset | |||
constexpr int pack_ic = 4; | |||
constexpr int expend_element = 4; | |||
// TODO: block dim is better to get from arg | |||
size_t workspace_ic_block = 4; | |||
size_t workspace_batch_id = workspace_ids[0]; | |||
size_t workspace_group_id = workspace_ids[1]; | |||
size_t workspace_ic_id = workspace_ids[2]; | |||
size_t workspace_ic = workspace_ic_id * workspace_ic_block; | |||
size_t batch_id = ncb_index.ndrange_id[0]; | |||
size_t group_id = ncb_index.ndrange_id[1]; | |||
size_t group_pack_size = 1; | |||
int nr_pad_h = PH * IW2 * pack_ic * expend_element; | |||
int nr_pad_w = PW * pack_ic * expend_element; | |||
int over_pad = std::max(0_z, IW2 - IW - 2 * PW) * pack_ic * expend_element; | |||
//! copy to sptr_base to eliminate padding effect | |||
const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>( | |||
batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); | |||
int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) + | |||
(workspace_batch_id * GROUP * padding_group_size + | |||
workspace_group_id * padding_group_size + | |||
workspace_ic * IH2 * IW2) * | |||
expend_element; | |||
size_t nr_ic = workspace_ic_block; | |||
if (GROUP > 1) { | |||
nr_ic = IC; | |||
} | |||
rep_step(ic_idx, nr_ic, pack_ic) { | |||
std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t)); | |||
sptr_base += nr_pad_h; | |||
rep(ih_idx, IH) { | |||
std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); | |||
sptr_base += nr_pad_w; | |||
conv_bias::nchw44_pack_src(sptr, sptr_base, IW); | |||
sptr_base += IW * pack_ic * expend_element; | |||
sptr += IW * pack_ic; | |||
std::memset(sptr_base, 0, (nr_pad_w + over_pad) * sizeof(int8_t)); | |||
sptr_base += nr_pad_w + over_pad; | |||
} | |||
std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t)); | |||
sptr_base += nr_pad_h; | |||
} | |||
} | |||
template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain> | |||
static void do_conv_kern(WorkspaceBundle bundle, | |||
const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids, | |||
const CpuNDRange& ncb_range) { | |||
size_t OH = kern_param.osz[0]; | |||
size_t OW = kern_param.osz[1]; | |||
size_t FH = kern_param.filter_meta.spatial[0]; | |||
size_t FW = kern_param.filter_meta.spatial[1]; | |||
size_t IC = kern_param.filter_meta.icpg; | |||
size_t OC = kern_param.filter_meta.ocpg; | |||
size_t GROUP = kern_param.filter_meta.group; | |||
size_t IH2, IW2, OH2, OW2; | |||
get_rectified_size(kern_param, IH2, IW2, OH2, OW2); | |||
bool need_post_process = | |||
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; | |||
//! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) | |||
Op op = Op(1.0f, 4.0f); | |||
if (need_post_process) { | |||
float scale_bias = | |||
kern_param.bias_type.param<dtype::QuantizedS32>().scale; | |||
float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale; | |||
op = Op(scale_bias, scale_dst); | |||
} | |||
size_t padding_group_size = IH2 * IW2 * IC; | |||
bundle.set(kern_param.workspace_ptr); | |||
constexpr size_t pack_c = 4; | |||
constexpr size_t src_expand_size = 4; | |||
const size_t workspace_batch_id = workspace_ids[0]; | |||
const size_t workspace_group_id = workspace_ids[1]; | |||
const size_t batch_id = ncb_index.ndrange_id[0]; | |||
const size_t group_id = ncb_index.ndrange_id[1]; | |||
const size_t oc_id = ncb_index.ndrange_id[2]; | |||
const size_t oc_block_num = ncb_range[2]; | |||
size_t nr_pack_per_step = div_ceil(div_ceil(OC, pack_c), oc_block_num); | |||
size_t oc_block = nr_pack_per_step * pack_c; | |||
const size_t oc_idx = oc_id * oc_block; | |||
if (oc_id == (oc_block_num - 1)) { | |||
oc_block = OC - oc_id * nr_pack_per_step * pack_c; | |||
} | |||
megdnn_assert(oc_block % pack_c == 0, | |||
"oc must be devisible by 4, but oc = %zu", oc_block); | |||
const int8_t* sptr = | |||
static_cast<int8_t*>(bundle.get(0)) + | |||
workspace_batch_id * GROUP * padding_group_size * src_expand_size + | |||
workspace_group_id * padding_group_size * src_expand_size; | |||
const int8_t* fptr = | |||
kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC; | |||
void* dst = reinterpret_cast<void*>( | |||
reinterpret_cast<ptrdiff_t>( | |||
kern_param.dst<void>(batch_id, group_id)) + | |||
oc_idx * OH * OW); | |||
const int32_t* bptr = | |||
kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx; | |||
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) + | |||
group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; | |||
conv_bias::nchw44_pack_filter(fptr, packed_weight, | |||
oc_block / 4 * IC / 4 * FH * FW); | |||
#define KERN1_NCHW44_CONV(filter) \ | |||
conv_bias::conv_direct_stride1_##filter##x##filter##_int8_nchw44< \ | |||
bias_mode, Op, ow_remain>(sptr, packed_weight, bptr, nullptr, \ | |||
static_cast<int8_t*>(dst), oc_block, IC, \ | |||
IH2, IW2, OH, OW, op) | |||
DISPATCH_FILTER(filter, KERN1_NCHW44_CONV) | |||
#undef KERN1_NCHW44_CONV | |||
} | |||
/* ===================== stride1 algo ===================== */ | |||
bool ConvBiasImpl::AlgoS8DirectStride1NCHW44::usable( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
MEGDNN_MARK_USED_VAR(algo_selection_strategy); | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
auto OC = fm.ocpg; | |||
auto IC = fm.icpg; | |||
bool avaible = //! src and filter are qint8, dst is qint8 or qint32 | |||
((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||
(param.dst_type.enumv() == DTypeEnum::QuantizedS8 || | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS32))) && | |||
(fm.format == param::Convolution::Format::NCHW44) && | |||
(OC % 4 == 0 && IC % 4 == 0 && OC >= 4) && !fm.should_flip && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && | |||
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | |||
param.bias_mode != BiasMode::BIAS; | |||
return avaible; | |||
} | |||
bool ConvBiasImpl::AlgoS8DirectStride1NCHW44::is_preferred( | |||
megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr, | |||
const NCBKernSizeParam& param) const { | |||
// TODO: benchmark and fix | |||
MEGDNN_MARK_USED_VAR(conv_bias_impl_ptr); | |||
MEGDNN_MARK_USED_VAR(param); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoS8DirectStride1NCHW44::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectStride1NCHW44::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    auto fm = param.filter_meta;
    size_t N = param.n;
    size_t IC = fm.icpg;
    size_t OC = fm.ocpg;
    size_t OW = param.osz[1];
    size_t group = fm.group;
    size_t fh = fm.spatial[0];
    size_t fw = fm.spatial[1];
    WorkspaceBundle wbundle = get_bundle(param);
    conv_fun do_conv_fun = nullptr;
    // leftover output-width elements after tiling OW by the kernels' 8-wide
    // step; selects the remain_w template parameter of do_conv_kern below
    int ow_remain = OW % 8;
// NOTE: remain_w is not used to gen hash of midout for compatible with changing
// shape runtime
#define DO_CONV_KERN_FUN(filter, bias_mode, remain_w, op)                \
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride1,        \
                 midout_iv(#filter #bias_mode #op##_hash)) {             \
        do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w>;     \
    }                                                                    \
    MIDOUT_END();
#define GET_OP_PARAM(filter, bias_mode, remain_w)                        \
    switch (param.nonlineMode) {                                         \
        case param::ConvBias::NonlineMode::IDENTITY:                     \
            DO_CONV_KERN_FUN(filter, bias_mode, remain_w,                \
                             TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
            break;                                                       \
        case param::ConvBias::NonlineMode::RELU:                         \
            DO_CONV_KERN_FUN(filter, bias_mode, remain_w,                \
                             ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
            break;                                                       \
        case param::ConvBias::NonlineMode::H_SWISH:                      \
            DO_CONV_KERN_FUN(filter, bias_mode, remain_w,                \
                             HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)  \
            break;                                                       \
        default:                                                         \
            megdnn_assert(0);                                            \
            break;                                                       \
    }
#define GET_REMAIN_W_PARAM(filter, bias_mode)                            \
    switch (ow_remain) {                                                 \
        case 0:                                                          \
            GET_OP_PARAM(filter, bias_mode, 0);                          \
            break;                                                       \
        case 1:                                                          \
            GET_OP_PARAM(filter, bias_mode, 1);                          \
            break;                                                       \
        case 2:                                                          \
            GET_OP_PARAM(filter, bias_mode, 2);                          \
            break;                                                       \
        case 3:                                                          \
            GET_OP_PARAM(filter, bias_mode, 3);                          \
            break;                                                       \
        case 4:                                                          \
            GET_OP_PARAM(filter, bias_mode, 4);                          \
            break;                                                       \
        case 5:                                                          \
            GET_OP_PARAM(filter, bias_mode, 5);                          \
            break;                                                       \
        case 6:                                                          \
            GET_OP_PARAM(filter, bias_mode, 6);                          \
            break;                                                       \
        case 7:                                                          \
            GET_OP_PARAM(filter, bias_mode, 7);                          \
            break;                                                       \
        default:                                                         \
            megdnn_assert(0);                                            \
    }
#define GET_BIAS_MODE_PARAM(filter)                                      \
    switch (param.bias_mode) {                                           \
        case BiasMode::NO_BIAS:                                          \
            GET_REMAIN_W_PARAM(filter, BiasMode::NO_BIAS)                \
            break;                                                       \
        case BiasMode::BROADCAST_CHANNEL_BIAS:                           \
            GET_REMAIN_W_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
            break;                                                       \
        default:                                                         \
            megdnn_assert(0);                                            \
            break;                                                       \
    }
#define DISPATCH_CONV_KERN()                                             \
    switch (param.filter_meta.spatial[0]) {                              \
        case 2:                                                          \
            GET_BIAS_MODE_PARAM(2)                                       \
            break;                                                       \
        case 3:                                                          \
            GET_BIAS_MODE_PARAM(3)                                       \
            break;                                                       \
        case 5:                                                          \
            GET_BIAS_MODE_PARAM(5)                                       \
            break;                                                       \
        case 7:                                                          \
            GET_BIAS_MODE_PARAM(7)                                       \
            break;                                                       \
        default:                                                         \
            megdnn_assert(0);                                            \
            break;                                                       \
    }
    // Pick the do_conv_kern instantiation by filter size -> bias mode ->
    // ow remainder -> activation op (midout-instrumented).
    DISPATCH_CONV_KERN();
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN
    megdnn_assert(do_conv_fun);
    SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
    WorkspaceBundle bundle = wbundle;
    constexpr size_t pack_oc = 4;
    size_t oc_step = pack_oc;
    // 2x2 filters have an oc=8 kernel variant, so step two oc packs at a time
    // when OC permits
    if (fh == 2 && fw == 2 && OC >= 8) {
        oc_step = 8;
    }
    if (group == 1) {
        // single group: emit padding as its own kernel, parallel over
        // (N, group, IC/4), followed by the conv kernel parallel over
        // (N, group, OC/oc_step)
        CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)};
        auto copy_padding = [bundle](const NCBKernParam& kern_param,
                                     const NCBKernIndex& ncb_index) {
            copy_padding_kern(bundle, kern_param, ncb_index,
                              ncb_index.ndrange_id);
        };
        constexpr size_t pack_ic = 4;
        ret_kerns.push_back({copy_padding, {N, group, div_ceil(IC, pack_ic)}});
        auto do_conv = [bundle, do_conv_fun, ncb_range](
                               const NCBKernParam& kern_param,
                               const NCBKernIndex& ncb_index) {
            do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id,
                        ncb_range);
        };
        ret_kerns.push_back({do_conv, ncb_range});
    } else {
        // multi group: one fused kernel per (N, group) that pads then
        // convolves; the workspace slot is indexed by thread_id — presumably
        // so each worker reuses its own padding buffer (confirm against
        // copy_padding_kern)
        CpuNDRange ncb_range = {N, group, 1};
        auto do_conv = [bundle, do_conv_fun, ncb_range](
                               const NCBKernParam& kern_param,
                               const NCBKernIndex& ncb_index) {
            copy_padding_kern(bundle, kern_param, ncb_index,
                              {0, ncb_index.thread_id, 0});
            do_conv_fun(bundle, kern_param, ncb_index,
                        {0, ncb_index.thread_id, 0}, ncb_range);
        };
        ret_kerns.push_back({do_conv, ncb_range});
    }
    return ret_kerns;
}
// vim: syntax=cpp.doxygen |
// ---- dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp ----
/** | |||
* \file dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/int8/direct.h" | |||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
#include "src/arm_common/elemwise_op.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
namespace { | |||
/** | |||
dot like impl. dot 4 ic to 1 oc, accumulate to c <ow, oc>
example: (format like weight<oc, ic>) | |||
packed weight | |||
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> | |||
--------------------------------------------------------------------- | |||
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> | |||
dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> | |||
**/ | |||
// TODO: can try oh = 2 impl, oc = 8 impl | |||
/**
 * Compute one output tile of 4 output channels (one NCHW44 pack) x 8 output
 * width positions for a stride-1 conv with filter_size == 3.  Input must be in
 * the expanded layout written by nchw44_pack_src (pack_iw_len == 4, i.e. 16
 * bytes per pixel per ic4 group), weights in the layout written by
 * nchw44_pack_filter.  vdotq_s32_h is the int8 "dot-like" multiply-accumulate
 * described at the top of this file, using temp_c as int16 scratch.
 * remain_w (< 8) tells the store helper how many ow positions are valid.
 */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_3x3s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;  // packed bytes per (fh, fw) filter tap
    constexpr int pack_iw_len = 4;     // each pixel expanded to 4*4 = 16 bytes
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];    // c[i]: 4 output channels at ow position i
    int8x16_t weight[3];   // one filter row: fw taps of 16 packed bytes
    int8x16_t src[8 + 2];  // 8 ow positions + (fw - 1) extra input columns
    int16x8_t temp_c[2];   // int16 scratch for vdotq_s32_h
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            // accumulate the 3 taps of this filter row into ow pairs
            // (0,1), (2,3), (4,5), (6,7); tap k reads src[ow + k]
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]);
            c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]);
            c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]);
            c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;  // next ic4 group of the filter
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
/**
 * Compute one output tile of 8 output channels (two NCHW44 packs) x 8 output
 * width positions for a stride-1 conv with filter_size == 2.  c[0][*] holds
 * the first oc4 pack, c[1][*] the second; ld_dst_oc is the byte distance
 * between the two oc4 planes in dst (passed by the caller as oh * ow * 4).
 * Input/weight layouts and vdotq_s32_h semantics are as in the 3x3 kernel.
 */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, int ld_dst_oc,
                                             const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int oc_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;  // packed bytes per (fh, fw) filter tap
    constexpr int pack_iw_len = 4;     // each pixel expanded to 16 bytes
    const int ic_stride = ih * iw * pack_iw_len;
    // distance between the two oc4 packs inside the packed filter
    const int ld_weight_oc4 = oc_step * fh * fw * ic;
    int32x4_t c[2][8];       // c[p][i]: oc4 pack p, ow position i
    int8x16_t weight[2][2];  // weight[p][k]: pack p, tap k of the filter row
    int8x16_t src[8 + 1];    // 8 ow positions + (fw - 1) extra column
    int16x8_t temp_c[4];     // int16 scratch for vdotq_s32_h
    init_oc8_ow8<bias_mode>(c, bias_ptr, oc_step);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0][0] = vld1q_s8(read_weight_ptr);
            weight[0][1] = vld1q_s8(read_weight_ptr + 16);
            weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
            weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);
            // both taps of this row, both oc packs, ow pairs in sequence
            c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][0], src[1], c[1][1], temp_c[3]);
            c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][1], src[2], c[1][1], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][0], src[2], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][0], src[3], c[1][3], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][1], src[3], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][1], src[4], c[1][3], temp_c[3]);
            c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][0], src[4], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][0], src[5], c[1][5], temp_c[3]);
            c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][1], src[5], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][1], src[6], c[1][5], temp_c[3]);
            c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][0], src[6], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][0], src[7], c[1][7], temp_c[3]);
            c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][1], src[7], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][1], src[8], c[1][7], temp_c[3]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;  // next ic4 group of the filter
    }
    store_oc8_ow8_remain_static<remain_w>(c, op, dst_ptr, ld_dst_oc);
}
/**
 * Compute one output tile of 4 output channels (one NCHW44 pack) x 8 output
 * width positions for a stride-1 conv with filter_size == 2.  Used for the
 * oc tail when oc % 8 != 0 (the main loop uses the oc8 variant above).
 * Input/weight layouts and vdotq_s32_h semantics are as in the 3x3 kernel.
 */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;  // packed bytes per (fh, fw) filter tap
    constexpr int pack_iw_len = 4;     // each pixel expanded to 16 bytes
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];    // c[i]: 4 output channels at ow position i
    int8x16_t weight[2];   // one filter row: 2 taps of 16 packed bytes
    int8x16_t src[8 + 1];  // 8 ow positions + (fw - 1) extra column
    int16x8_t temp_c[2];   // int16 scratch for vdotq_s32_h
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            // both taps of this row into ow pairs (0,1) .. (6,7);
            // tap k reads src[ow + k]
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
            c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
            c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
            c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;  // next ic4 group of the filter
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
/**
 * Compute one output tile of 4 output channels (one NCHW44 pack) x 8 output
 * width positions for a stride-1 conv with filter_size == 5.  A 5-tap row
 * over 8 ow positions needs 12 input columns but only 10 src registers are
 * kept, so columns 10 and 11 are loaded into src[0]/src[1] after those
 * registers have been fully consumed by the earlier ow positions.
 * Input/weight layouts and vdotq_s32_h semantics are as in the 3x3 kernel.
 */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_5x5s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;  // packed bytes per (fh, fw) filter tap
    constexpr int pack_iw_len = 4;     // each pixel expanded to 16 bytes
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];    // c[i]: 4 output channels at ow position i
    int8x16_t weight[5];   // one filter row: 5 taps of 16 packed bytes
    int8x16_t src[8 + 2];  // rotating window over the 12 needed input columns
    int16x8_t temp_c[2];   // int16 scratch for vdotq_s32_h
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
            weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[3], src[4], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[4], src[5], c[1], temp_c[1]);
            c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[3], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[3], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[4], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[4], src[7], c[3], temp_c[1]);
            c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[3], src[7], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[3], src[8], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[4], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[4], src[9], c[5], temp_c[1]);
            // reuse src[0]/src[1] (no longer needed) for columns 10 and 11
            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
            c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[3], src[9], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[3], src[0], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[4], src[0], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[4], src[1], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;  // next ic4 group of the filter
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
/**
 * Compute one output tile of 4 output channels (one NCHW44 pack) x 8 output
 * width positions for a stride-1 conv with filter_size == 7.  A 7-tap row
 * over 8 ow positions needs 14 input columns; the 10 src registers form a
 * rotating window — columns 10..13 are loaded into src[0..3] once those
 * registers have been consumed by the earlier ow positions.
 * Input/weight layouts and vdotq_s32_h semantics are as in the 3x3 kernel.
 */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_7x7s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;  // packed bytes per (fh, fw) filter tap
    constexpr int pack_iw_len = 4;     // each pixel expanded to 16 bytes
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];    // c[i]: 4 output channels at ow position i
    int8x16_t weight[7];   // one filter row: 7 taps of 16 packed bytes
    int8x16_t src[8 + 2];  // rotating window over the 14 needed input columns
    int16x8_t temp_c[2];   // int16 scratch for vdotq_s32_h
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
            weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
            weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
            weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[3], src[4], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[4], src[5], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[5], src[5], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[5], src[6], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[6], src[6], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[6], src[7], c[1], temp_c[1]);
            c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[3], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[3], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[4], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[4], src[7], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[5], src[7], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[5], src[8], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[6], src[8], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[6], src[9], c[3], temp_c[1]);
            // reuse src[0]/src[1] for columns 10 and 11
            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
            c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[3], src[7], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[3], src[8], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[4], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[4], src[9], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[5], src[9], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[5], src[0], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[6], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[6], src[1], c[5], temp_c[1]);
            // reuse src[2]/src[3] for columns 12 and 13
            src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
            src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
            c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[3], src[9], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[3], src[0], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[4], src[0], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[4], src[1], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[5], src[1], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[5], src[2], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[6], src[2], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[6], src[3], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;  // next ic4 group of the filter
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
} // namespace | |||
/** | |||
origin weight shape <oc/4, ic/4, fh, fw, 4, 4> | |||
packed weight shape <oc/4, ic/4, fh, fw, 16> | |||
example: (format like weight<oc, ic>) | |||
origin | |||
<0, 0> <1, 0> <2, 0> <3, 0> | |||
<0, 1> <1, 1> <2, 1> <3, 1> | |||
<0, 2> <1, 2> <2, 2> <3, 2> | |||
<0, 3> <1, 3> <2, 3> <3, 3> | |||
packed | |||
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> | |||
--------------------------------------------------------------------- | |||
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> | |||
**/ | |||
void conv_bias::nchw44_pack_filter(const int8_t* src, int8_t* dst, int length) { | |||
static const uint8_t weight_idx_buffer[16] = {0, 4, 9, 13, 2, 6, 11, 15, | |||
12, 8, 5, 1, 14, 10, 7, 3}; | |||
constexpr int simd_len = 16; | |||
uint8x16_t weight_idx = vld1q_u8(weight_idx_buffer); | |||
for (int i = 0; i < length; i++) { | |||
int8x16_t result = vldq_tbl_s8(src + i * simd_len, weight_idx); | |||
vst1q_s8(dst + i * simd_len, result); | |||
} | |||
} | |||
/** | |||
origin src shape <n, ic/4, h, w, 4> | |||
packed src shape <n, ic/4, h, w, 16> | |||
example: (format like <ic>) | |||
origin | |||
<0> <0> <0> <0> | |||
packed | |||
low 64 bit <0> <1> <2> <3> | <0> <1> <2> <3> | |||
--------------------------------------------------------------------- | |||
high 64 bit <3> <2> <1> <0> | <3> <2> <1> <0> | |||
**/ | |||
void conv_bias::nchw44_pack_src(const int8_t* src, int8_t* dst, int length) { | |||
static const uint8_t src_idx_buffer[16] = {0, 1, 2, 3, 0, 1, 2, 3, | |||
3, 2, 1, 0, 3, 2, 1, 0}; | |||
constexpr int pack_ic = 4; | |||
constexpr int simd_len = 16; | |||
uint8x16_t src_idx = vld1q_u8(src_idx_buffer); | |||
for (int i = 0; i < length; i++) { | |||
int8x16_t result = vld_dup_tbl_s32(src + i * pack_ic, src_idx); | |||
vst1q_s8(dst + i * simd_len, result); | |||
} | |||
} | |||
/**
 * Direct stride-1 2x2 int8 conv on NCHW44-packed data.  Walks the output in
 * 8-oc x 1-oh x 8-ow tiles using the oc8 kernel, then handles the oc tail
 * (oc % 8, an oc4 pack) with the oc4 kernel; the ow tail of each row uses
 * the remain_w-instantiated kernel.  `temp` is unused by this filter size.
 */
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride1_2x2_int8_nchw44(
        const int8_t* src, const int8_t* filter, const int32_t* bias,
        int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
        const size_t ih, const size_t iw, const size_t oh, const size_t ow,
        const Op& op) {
    MEGDNN_MARK_USED_VAR(temp);
    constexpr size_t filter_size = 2;
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;
    constexpr size_t oc_step = 4;
    constexpr size_t big_oc_step = 8;  // the 2x2 kernel has an oc=8 variant
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;
    constexpr int pack_iw_len = 4;
    const size_t img_stride = oh * ow;
    const size_t ow_end = ow / ow_step * ow_step;
    const size_t ow_remain = ow - ow_end;
    const size_t oc_end = oc / big_oc_step * big_oc_step;
    const size_t oc_remain = oc - oc_end;
    // distance (in elements) between the two oc4 planes written by the
    // oc8 kernel
    const int ld_oc = oh * ow * ic_step;
    // main loop: two oc4 packs at a time
    for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, 0, filter_size>(
                        src + src_offset, filter + weight_offset, bias + oc_idx,
                        dst + dst_offset, ic, ih, iw, ld_oc, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * iw + ow_end) * ic_step * pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
                ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, remain_w,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset, bias + oc_idx,
                        dst + dst_offset, ic, ih, iw, ld_oc, op);
            }
        }
    }
    // oc tail: one remaining oc4 pack handled by the oc4 kernel
    if (oc_remain > 0) {
        const size_t oc_idx = oc_end;
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, 0, filter_size>(
                        src + src_offset, filter + weight_offset, bias + oc_idx,
                        dst + dst_offset, ic, ih, iw, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * iw + ow_end) * ic_step * pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
                ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, remain_w,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset, bias + oc_idx,
                        dst + dst_offset, ic, ih, iw, op);
            }
        }
    }
}
template <BiasMode bias_mode, typename Op, int remain_w> | |||
void conv_bias::conv_direct_stride1_3x3_int8_nchw44( | |||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||
const Op& op) { | |||
MEGDNN_MARK_USED_VAR(temp); | |||
constexpr size_t filter_size = 3; | |||
constexpr size_t fh = filter_size; | |||
constexpr size_t fw = filter_size; | |||
constexpr size_t ic_step = 4; | |||
constexpr size_t oc_step = 4; | |||
constexpr size_t oh_step = 1; | |||
constexpr size_t ow_step = 8; | |||
constexpr int pack_iw_len = 4; | |||
const size_t img_stride = oh * ow; | |||
const size_t ow_end = ow / ow_step * ow_step; | |||
const size_t ow_remain = ow - ow_end; | |||
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = | |||
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
ker_neon_dirctconv_3x3s1_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, op); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = | |||
(oh_idx * iw + ow_end) * ic_step * pack_iw_len; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
ker_neon_dirctconv_3x3s1_oc4_ow8<bias_mode, Op, remain_w, | |||
filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, op); | |||
} | |||
} | |||
} | |||
} | |||
template <BiasMode bias_mode, typename Op, int remain_w> | |||
void conv_bias::conv_direct_stride1_5x5_int8_nchw44( | |||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||
const Op& op) { | |||
MEGDNN_MARK_USED_VAR(temp); | |||
constexpr size_t filter_size = 5; | |||
constexpr size_t fh = filter_size; | |||
constexpr size_t fw = filter_size; | |||
constexpr size_t ic_step = 4; | |||
constexpr size_t oc_step = 4; | |||
constexpr size_t oh_step = 1; | |||
constexpr size_t ow_step = 8; | |||
constexpr int pack_iw_len = 4; | |||
const size_t img_stride = oh * ow; | |||
const size_t ow_end = ow / ow_step * ow_step; | |||
const size_t ow_remain = ow - ow_end; | |||
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = | |||
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
ker_neon_dirctconv_5x5s1_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, op); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = | |||
(oh_idx * iw + ow_end) * ic_step * pack_iw_len; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
ker_neon_dirctconv_5x5s1_oc4_ow8<bias_mode, Op, remain_w, | |||
filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, op); | |||
} | |||
} | |||
} | |||
} | |||
template <BiasMode bias_mode, typename Op, int remain_w> | |||
void conv_bias::conv_direct_stride1_7x7_int8_nchw44( | |||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||
const Op& op) { | |||
MEGDNN_MARK_USED_VAR(temp); | |||
constexpr size_t filter_size = 7; | |||
constexpr size_t fh = filter_size; | |||
constexpr size_t fw = filter_size; | |||
constexpr size_t ic_step = 4; | |||
constexpr size_t oc_step = 4; | |||
constexpr size_t oh_step = 1; | |||
constexpr size_t ow_step = 8; | |||
constexpr int pack_iw_len = 4; | |||
const size_t img_stride = oh * ow; | |||
const size_t ow_end = ow / ow_step * ow_step; | |||
const size_t ow_remain = ow - ow_end; | |||
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = | |||
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
ker_neon_dirctconv_7x7s1_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, op); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = | |||
(oh_idx * iw + ow_end) * ic_step * pack_iw_len; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
ker_neon_dirctconv_7x7s1_oc4_ow8<bias_mode, Op, remain_w, | |||
filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, op); | |||
} | |||
} | |||
} | |||
} | |||
// Explicit template instantiations for every combination of filter size
// (2/3/5/7), bias mode, tail width (remain_w = 0..7) and post-op used by
// the stride-1 NCHW44 direct conv entry points above.
#define INSTANTIATION(stride, i, bias, remain_w, Op)                         \
    template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_nchw44< \
            bias, Op, remain_w>(const int8_t*, const int8_t*, const int32_t*, \
                                int32_t*, int8_t*, const size_t, const size_t, \
                                const size_t, const size_t, const size_t,      \
                                const size_t, const Op&);
// One instantiation per supported post-op (identity cast, relu, hswish).
#define FOR_OP(stride, i, bias, remain_w)                       \
    INSTANTIATION(stride, i, bias, remain_w,                    \
                  TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>)   \
    INSTANTIATION(stride, i, bias, remain_w,                    \
                  ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)      \
    INSTANTIATION(stride, i, bias, remain_w,                    \
                  HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
// All possible widths of the rightmost partial group of 8 output columns.
#define FOR_REMAIN(stride, i, bias) \
    FOR_OP(stride, i, bias, 0)      \
    FOR_OP(stride, i, bias, 1)      \
    FOR_OP(stride, i, bias, 2)      \
    FOR_OP(stride, i, bias, 3)      \
    FOR_OP(stride, i, bias, 4)      \
    FOR_OP(stride, i, bias, 5)      \
    FOR_OP(stride, i, bias, 6)      \
    FOR_OP(stride, i, bias, 7)
#define FOR_BIAS(stride, i)                          \
    FOR_REMAIN(stride, i, BiasMode::NO_BIAS)         \
    FOR_REMAIN(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
#define FOR_FILTER(stride) \
    FOR_BIAS(stride, 2)    \
    FOR_BIAS(stride, 3)    \
    FOR_BIAS(stride, 5)    \
    FOR_BIAS(stride, 7)
FOR_FILTER(stride1)
// Note: FOR_STRIDE / FOR_IC / FOR_NONLINEAR are never defined here; the
// extra #undef lines are harmless (undef of an undefined macro is a no-op).
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
@@ -1,793 +0,0 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/int8/direct.h" | |||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
#include "src/arm_common/elemwise_op.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
namespace { | |||
/** | |||
dot-like impl.: dot 4 ic to 1 oc, accumulate into c <ow, oc>
example: (format like weight<oc, ic>) | |||
packed weight | |||
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> | |||
--------------------------------------------------------------------- | |||
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> | |||
dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> | |||
**/ | |||
// TODO: can try oh = 2 impl, oc = 8 impl | |||
// 3x3 stride-2 micro-kernel: accumulates a 4-OC x 8-output-column tile.
// src_ptr points at packed input where each ic4 pixel occupies 16 bytes
// (pack_iw_len = 4); weight_ptr at packed weights, 16 bytes per filter tap
// (ld_weight_ic4).  remain_w selects the partial-store variant used for the
// rightmost, possibly incomplete, group of 8 output columns.
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_3x3s2_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;  // bytes per packed ic4 filter tap
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];   // 8 accumulators, one 4-OC vector per output column
    int8x16_t weight[3];  // the 3 taps of one filter row
    int8x16_t src[8 + 2];
    int16x8_t temp_c[2];  // 16-bit scratch registers for vdotq_s32_h
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            // Output column n reads input columns 2n..2n+2 (stride 2), so
            // c[0] uses src[0..2], c[1] uses src[2..4], etc.
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            // Output columns 0..3 from the 10 loaded pixels.
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]);
            c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]);
            // Reuse src[0..2] as a rotating window for pixels 10..12;
            // the statement order below is load-before-overwrite critical.
            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
            src[2] = vld1q_s8((src_ic_0_3 + 12 * 16));
            c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]);
            src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 14 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 15 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 16 * 16));
            c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    // Apply post-op and write up to 8 columns (remain_w-aware store).
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
// 2x2 stride-2 micro-kernel computing an 8-OC x 8-output-column tile: two
// 4-OC groups are processed together (c[0][*] / c[1][*]) using two weight
// rows separated by ld_weight_oc4.  ld_dst_oc is the byte distance between
// the two 4-OC output planes passed through to the store helper.
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, int ld_dst_oc,
                                             const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int oc_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;  // bytes per packed ic4 filter tap
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    const int ld_weight_oc4 = oc_step * fh * fw * ic;  // to 2nd 4-OC group
    int32x4_t c[2][8];       // [oc group][output column] accumulators
    int8x16_t weight[2][2];  // [oc group][filter tap]
    int8x16_t src[8 + 1];
    int16x8_t temp_c[4];     // 16-bit scratch registers for vdotq_s32_h
    init_oc8_ow8<bias_mode>(c, bias_ptr, oc_step);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            // Output column n reads input columns 2n..2n+1 (stride 2).
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0][0] = vld1q_s8(read_weight_ptr);
            weight[0][1] = vld1q_s8(read_weight_ptr + 16);
            weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
            weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);
            // Output columns 0..3 for both oc groups.
            c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][0], src[2], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][0], src[2], c[1][1], temp_c[3]);
            c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][1], src[3], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][1], src[3], c[1][1], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][0], src[4], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][0], src[4], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][0], src[6], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][0], src[6], c[1][3], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][1], src[5], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][1], src[5], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][1], src[7], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][1], src[7], c[1][3], temp_c[3]);
            // Rotate src[0..2] to hold pixels 9..11 (load-before-overwrite
            // order is critical).
            src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
            c[0][4] = vdotq_s32_h(weight[0][0], src[8], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][0], src[8], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][0], src[1], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][0], src[1], c[1][5], temp_c[3]);
            c[0][4] = vdotq_s32_h(weight[0][1], src[0], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][1], src[0], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][1], src[2], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][1], src[2], c[1][5], temp_c[3]);
            src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
            c[0][6] = vdotq_s32_h(weight[0][0], src[3], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][0], src[3], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][0], src[5], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][0], src[5], c[1][7], temp_c[3]);
            c[0][6] = vdotq_s32_h(weight[0][1], src[4], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][1], src[4], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][1], src[6], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][1], src[6], c[1][7], temp_c[3]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    // Apply post-op and write both 4-OC planes (remain_w-aware store).
    store_oc8_ow8_remain_static<remain_w>(c, op, dst_ptr, ld_dst_oc);
}
// 2x2 stride-2 micro-kernel for a 4-OC x 8-output-column tile; used for the
// trailing 4 output channels when oc is not a multiple of 8 (see the oc8
// variant above for the main body).
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;  // bytes per packed ic4 filter tap
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];   // 8 accumulators, one 4-OC vector per output column
    int8x16_t weight[2];  // the 2 taps of one filter row
    int8x16_t src[8 + 1];
    int16x8_t temp_c[2];  // 16-bit scratch registers for vdotq_s32_h
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            // Output column n reads input columns 2n..2n+1 (stride 2).
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
            c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
            // Rotate src[0..2] to hold pixels 9..11 (load-before-overwrite
            // order is critical).
            src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
            c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[1], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[2], c[5], temp_c[1]);
            src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
            c[6] = vdotq_s32_h(weight[0], src[3], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[5], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[4], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[6], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    // Apply post-op and write up to 8 columns (remain_w-aware store).
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
// 5x5 stride-2 micro-kernel: accumulates a 4-OC x 8-output-column tile.
// Same layout conventions as the 3x3 variant above; 19 input pixels per
// filter row are consumed through a rotating 10-register window.
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_5x5s2_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;  // bytes per packed ic4 filter tap
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];   // 8 accumulators, one 4-OC vector per output column
    int8x16_t weight[5];  // the 5 taps of one filter row
    int8x16_t src[8 + 2];
    int16x8_t temp_c[2];  // 16-bit scratch registers for vdotq_s32_h
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            // Output column n reads input columns 2n..2n+4 (stride 2).
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
            weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[3], src[5], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[4], src[6], c[1], temp_c[1]);
            // Rotate the window: src registers are reloaded with later
            // pixels just before their old contents are no longer needed.
            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[3], src[7], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[3], src[9], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[4], src[8], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[4], src[0], c[3], temp_c[1]);
            src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
            src[2] = vld1q_s8((src_ic_0_3 + 12 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 14 * 16));
            c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[3], src[1], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[3], src[3], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[4], src[2], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[4], src[4], c[5], temp_c[1]);
            src[5] = vld1q_s8((src_ic_0_3 + 15 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 16 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 17 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 18 * 16));
            c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[3], src[5], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[3], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[4], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[4], src[8], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    // Apply post-op and write up to 8 columns (remain_w-aware store).
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
// 7x7 stride-2 micro-kernel: accumulates a 4-OC x 8-output-column tile.
// Same layout conventions as the 3x3 variant above; 21 input pixels per
// filter row are consumed through a rotating 10-register window.
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_7x7s2_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;  // bytes per packed ic4 filter tap
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];   // 8 accumulators, one 4-OC vector per output column
    int8x16_t weight[7];  // the 7 taps of one filter row
    int8x16_t src[8 + 2];
    int16x8_t temp_c[2];  // 16-bit scratch registers for vdotq_s32_h
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            // Output column n reads input columns 2n..2n+6 (stride 2).
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 1 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
            weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
            weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
            weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[3], src[5], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[4], src[6], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[5], src[5], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[5], src[7], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[6], src[6], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[6], src[8], c[1], temp_c[1]);
            // Rotate the window: src registers are reloaded with later
            // pixels just before their old contents are no longer needed.
            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
            c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[3], src[7], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[3], src[9], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[4], src[8], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[4], src[0], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[5], src[9], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[5], src[1], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[6], src[0], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[6], src[2], c[3], temp_c[1]);
            src[3] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 14 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 15 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 16 * 16);
            c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[3], src[1], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[3], src[3], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[4], src[2], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[4], src[4], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[5], src[3], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[5], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[6], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[6], src[6], c[5], temp_c[1]);
            src[7] = vld1q_s8(src_ic_0_3 + 17 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 18 * 16);
            src[9] = vld1q_s8(src_ic_0_3 + 19 * 16);
            src[0] = vld1q_s8(src_ic_0_3 + 20 * 16);
            c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[3], src[5], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[3], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[4], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[4], src[8], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[5], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[5], src[9], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[6], src[8], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[6], src[0], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    // Apply post-op and write up to 8 columns (remain_w-aware store).
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
} // namespace | |||
template <BiasMode bias_mode, typename Op, int remain_w> | |||
void conv_bias::conv_direct_stride2_2x2_int8_nchw44( | |||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||
const Op& op) { | |||
MEGDNN_MARK_USED_VAR(temp); | |||
constexpr size_t filter_size = 2; | |||
constexpr size_t fh = filter_size; | |||
constexpr size_t fw = filter_size; | |||
constexpr size_t ic_step = 4; | |||
constexpr size_t oc_step = 4; | |||
constexpr size_t big_oc_step = 8; | |||
constexpr size_t oh_step = 1; | |||
constexpr size_t ow_step = 8; | |||
constexpr size_t stride_h = 2; | |||
constexpr size_t stride_w = 2; | |||
constexpr int pack_iw_len = 4; | |||
const size_t out_img_stride = oh * ow; | |||
const size_t ow_end = ow / ow_step * ow_step; | |||
const size_t ow_remain = ow - ow_end; | |||
const size_t oc_end = oc / big_oc_step * big_oc_step; | |||
const size_t oc_remain = oc - oc_end; | |||
const int ld_oc = oh * ow * ic_step; | |||
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = | |||
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * | |||
pack_iw_len; | |||
const size_t dst_offset = oc_idx * out_img_stride + | |||
(oh_idx * ow + ow_idx) * oc_step; | |||
ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, 0, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, ld_oc, op); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = | |||
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * | |||
pack_iw_len; | |||
const size_t dst_offset = oc_idx * out_img_stride + | |||
(oh_idx * ow + ow_end) * oc_step; | |||
ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, remain_w, | |||
filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, ld_oc, op); | |||
} | |||
} | |||
} | |||
if (oc_remain > 0) { | |||
const size_t oc_idx = oc_end; | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = | |||
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * | |||
pack_iw_len; | |||
const size_t dst_offset = oc_idx * out_img_stride + | |||
(oh_idx * ow + ow_idx) * oc_step; | |||
ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, op); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = | |||
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * | |||
pack_iw_len; | |||
const size_t dst_offset = oc_idx * out_img_stride + | |||
(oh_idx * ow + ow_end) * oc_step; | |||
ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, remain_w, | |||
filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, op); | |||
} | |||
} | |||
} | |||
} | |||
template <BiasMode bias_mode, typename Op, int remain_w> | |||
void conv_bias::conv_direct_stride2_3x3_int8_nchw44( | |||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||
const Op& op) { | |||
MEGDNN_MARK_USED_VAR(temp); | |||
constexpr size_t filter_size = 3; | |||
constexpr size_t fh = filter_size; | |||
constexpr size_t fw = filter_size; | |||
constexpr size_t ic_step = 4; | |||
constexpr size_t oc_step = 4; | |||
constexpr size_t oh_step = 1; | |||
constexpr size_t ow_step = 8; | |||
constexpr size_t stride_h = 2; | |||
constexpr size_t stride_w = 2; | |||
constexpr int pack_iw_len = 4; | |||
const size_t img_stride = oh * ow; | |||
const size_t ow_end = ow / ow_step * ow_step; | |||
const size_t ow_remain = ow - ow_end; | |||
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = | |||
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * | |||
pack_iw_len; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
ker_neon_dirctconv_3x3s2_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, op); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = | |||
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * | |||
pack_iw_len; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
ker_neon_dirctconv_3x3s2_oc4_ow8<bias_mode, Op, remain_w, | |||
filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, op); | |||
} | |||
} | |||
} | |||
} | |||
template <BiasMode bias_mode, typename Op, int remain_w> | |||
void conv_bias::conv_direct_stride2_5x5_int8_nchw44( | |||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||
const Op& op) { | |||
MEGDNN_MARK_USED_VAR(temp); | |||
constexpr size_t filter_size = 5; | |||
constexpr size_t fh = filter_size; | |||
constexpr size_t fw = filter_size; | |||
constexpr size_t ic_step = 4; | |||
constexpr size_t oc_step = 4; | |||
constexpr size_t oh_step = 1; | |||
constexpr size_t ow_step = 8; | |||
constexpr size_t stride_h = 2; | |||
constexpr size_t stride_w = 2; | |||
constexpr int pack_iw_len = 4; | |||
const size_t img_stride = oh * ow; | |||
const size_t ow_end = ow / ow_step * ow_step; | |||
const size_t ow_remain = ow - ow_end; | |||
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = | |||
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * | |||
pack_iw_len; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
ker_neon_dirctconv_5x5s2_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, op); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = | |||
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * | |||
pack_iw_len; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
ker_neon_dirctconv_5x5s2_oc4_ow8<bias_mode, Op, remain_w, | |||
filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, op); | |||
} | |||
} | |||
} | |||
} | |||
template <BiasMode bias_mode, typename Op, int remain_w> | |||
void conv_bias::conv_direct_stride2_7x7_int8_nchw44( | |||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||
const Op& op) { | |||
MEGDNN_MARK_USED_VAR(temp); | |||
constexpr size_t filter_size = 7; | |||
constexpr size_t fh = filter_size; | |||
constexpr size_t fw = filter_size; | |||
constexpr size_t ic_step = 4; | |||
constexpr size_t oc_step = 4; | |||
constexpr size_t oh_step = 1; | |||
constexpr size_t ow_step = 8; | |||
constexpr size_t stride_h = 2; | |||
constexpr size_t stride_w = 2; | |||
constexpr int pack_iw_len = 4; | |||
const size_t img_stride = oh * ow; | |||
const size_t ow_end = ow / ow_step * ow_step; | |||
const size_t ow_remain = ow - ow_end; | |||
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = | |||
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * | |||
pack_iw_len; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
ker_neon_dirctconv_7x7s2_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, op); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = | |||
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * | |||
pack_iw_len; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
ker_neon_dirctconv_7x7s2_oc4_ow8<bias_mode, Op, remain_w, | |||
filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, op); | |||
} | |||
} | |||
} | |||
} | |||
#define INSTANTIATION(stride, i, bias, remain_w, Op) \ | |||
template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_nchw44< \ | |||
bias, Op, remain_w>(const int8_t*, const int8_t*, const int32_t*, \ | |||
int32_t*, int8_t*, const size_t, const size_t, \ | |||
const size_t, const size_t, const size_t, \ | |||
const size_t, const Op&); | |||
#define FOR_OP(stride, i, bias, remain_w) \ | |||
INSTANTIATION(stride, i, bias, remain_w, \ | |||
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
INSTANTIATION(stride, i, bias, remain_w, \ | |||
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
INSTANTIATION(stride, i, bias, remain_w, \ | |||
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) | |||
#define FOR_REMAIN(stride, i, bias) \ | |||
FOR_OP(stride, i, bias, 0) \ | |||
FOR_OP(stride, i, bias, 1) \ | |||
FOR_OP(stride, i, bias, 2) \ | |||
FOR_OP(stride, i, bias, 3) \ | |||
FOR_OP(stride, i, bias, 4) \ | |||
FOR_OP(stride, i, bias, 5) \ | |||
FOR_OP(stride, i, bias, 6) \ | |||
FOR_OP(stride, i, bias, 7) | |||
#define FOR_BIAS(stride, i) \ | |||
FOR_REMAIN(stride, i, BiasMode::NO_BIAS) \ | |||
FOR_REMAIN(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) | |||
#define FOR_FILTER(stride) \ | |||
FOR_BIAS(stride, 2) \ | |||
FOR_BIAS(stride, 3) \ | |||
FOR_BIAS(stride, 5) \ | |||
FOR_BIAS(stride, 7) | |||
FOR_FILTER(stride2) | |||
#undef FOR_STRIDE | |||
#undef FOR_FILTER | |||
#undef FOR_IC | |||
#undef FOR_BIAS | |||
#undef FOR_NONLINEAR | |||
#undef FOR_REMAIN | |||
#undef INSTANTIATION |
@@ -46,11 +46,10 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
AlgoQU8DirectStride1 qu8_direct_stride1_small_group{false}; | |||
AlgoS8DirectStride2 s8_direct_stride2_large_group{true}; | |||
AlgoS8DirectStride2 s8_direct_stride2_small_group{false}; | |||
AlgoS8DirectStride2NCHW44 s8_direct_stride2_nchw44; | |||
AlgoS8DirectNCHW44 s8_direct_nchw44; | |||
AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44; | |||
AlgoS8DirectStride1 s8_direct_stride1_large_group{true}; | |||
AlgoS8DirectStride1 s8_direct_stride1_small_group{false}; | |||
AlgoS8DirectStride1NCHW44 s8_direct_stride1_nchw44; | |||
AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; | |||
AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; | |||
@@ -114,11 +113,10 @@ public: | |||
direct_algos.emplace_back(&qu8_direct_stride1_small_group); | |||
direct_algos.emplace_back(&s8_direct_stride2_large_group); | |||
direct_algos.emplace_back(&s8_direct_stride2_small_group); | |||
direct_algos.emplace_back(&s8_direct_stride2_nchw44); | |||
direct_algos.emplace_back(&s8_direct_nchw44); | |||
direct_algos.emplace_back(&s8_direct_nchw_nchw44); | |||
direct_algos.emplace_back(&s8_direct_stride1_large_group); | |||
direct_algos.emplace_back(&s8_direct_stride1_small_group); | |||
direct_algos.emplace_back(&s8_direct_stride1_nchw44); | |||
direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); | |||
direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); | |||
@@ -37,9 +37,8 @@ protected: | |||
private: | |||
class AlgoS8DirectStride1; | |||
class AlgoS8DirectStride1NCHW44; | |||
class AlgoS8DirectStride2; | |||
class AlgoS8DirectStride2NCHW44; | |||
class AlgoS8DirectNCHW44; | |||
class AlgoS8DirectNCHWNCHW44; | |||
class AlgoQU8DirectStride1; | |||
class AlgoQU8DirectStride2; | |||
@@ -27,6 +27,8 @@ struct NoneOp; | |||
#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ | |||
template <> \ | |||
struct NoneOp<_ctype> : NoneOpBase<_ctype> { \ | |||
NoneOp(){}; \ | |||
NoneOp(float, float){}; \ | |||
using NoneOpBase::NoneOpBase; \ | |||
using NoneOpBase::operator(); \ | |||
constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
@@ -226,7 +226,15 @@ static void benchmark_convbias(Handle* handle, std::string int_name, | |||
run(1, 3, 32, 224, 224, 5, 1, true); | |||
run(1, 3, 64, 224, 224, 7, 1, true); | |||
for (size_t stride : {1, 2}) { | |||
run(1, 64, 128, 56, 56, 3, 2, false); | |||
run(1, 128, 256, 28, 28, 3, 2, false); | |||
run(1, 256, 512, 14, 14, 3, 2, false); | |||
run(1, 128, 128, 28, 28, 3, 1, false); | |||
run(1, 256, 256, 14, 14, 3, 1, false); | |||
run(1, 512, 512, 7, 7, 3, 1, false); | |||
for (size_t stride : {1}) { | |||
printf("stride %zu\n", stride); | |||
for (size_t filter_size : {2, 3, 5, 7}) { | |||
for (size_t img_size : {32}) { | |||
@@ -527,12 +527,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) { | |||
checker_conv_bias_qint8x8x8( | |||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), | |||
handle(), "S8_NCHW44_DIRECT_STRD1"); | |||
handle(), "S8_NCHW44_DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) { | |||
checker_conv_bias_qint8x8x32( | |||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true), | |||
handle(), "S8_NCHW44_DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) { | |||
checker_conv_bias_qint8x8x32( | |||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true), | |||
handle(), "S8_NCHW44_DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) { | |||
checker_conv_bias_qint8x8x8( | |||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | |||
handle(), "S8_NCHW44_DIRECT_STRD2"); | |||
handle(), "S8_NCHW44_DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) { | |||
checker_conv_bias_qint8x8x8( | |||
@@ -1085,7 +1095,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) { | |||
dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { | |||
using namespace conv_bias; | |||
@@ -1096,17 +1105,17 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { | |||
param::MatrixMul::Format format, float eps) { | |||
for (auto&& arg : args) { | |||
for (uint32_t m : out_size) { | |||
checker.set_extra_opr_impl(std::bind( | |||
winograd_algo_extra_impl, std::placeholders::_1, m, | |||
arg.param, handle, format)); | |||
checker.set_dtype(0, A_dtype) | |||
.set_dtype(1, B_dtype) | |||
.set_dtype(2, C_dtype) | |||
.set_dtype(4, D_dtype) | |||
.set_epsilon(eps) | |||
.set_param(arg.param) | |||
.execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||
} | |||
checker.set_extra_opr_impl(std::bind( | |||
winograd_algo_extra_impl, std::placeholders::_1, m, | |||
arg.param, handle, format)); | |||
checker.set_dtype(0, A_dtype) | |||
.set_dtype(1, B_dtype) | |||
.set_dtype(2, C_dtype) | |||
.set_dtype(4, D_dtype) | |||
.set_epsilon(eps) | |||
.set_param(arg.param) | |||
.execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||
} | |||
} | |||
}; | |||
@@ -1118,7 +1127,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { | |||
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | |||
ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str())); | |||
std::vector<TestArg> quantized_args = get_int8_nchw44_args (3,4); | |||
std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4); | |||
UniformIntRNG int_rng{-50, 50}; | |||
checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng); | |||
run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f), | |||
@@ -1126,8 +1135,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { | |||
dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||
CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) { | |||
using namespace conv_bias; | |||
Checker<ConvBiasForward> checker(handle()); | |||
@@ -1137,17 +1146,17 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPM | |||
param::MatrixMul::Format format, float eps) { | |||
for (auto&& arg : args) { | |||
for (uint32_t m : out_size) { | |||
checker.set_extra_opr_impl(std::bind( | |||
winograd_algo_extra_impl, std::placeholders::_1, m, | |||
arg.param, handle, format)); | |||
checker.set_dtype(0, A_dtype) | |||
.set_dtype(1, B_dtype) | |||
.set_dtype(2, C_dtype) | |||
.set_dtype(4, D_dtype) | |||
.set_epsilon(eps) | |||
.set_param(arg.param) | |||
.execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||
} | |||
checker.set_extra_opr_impl(std::bind( | |||
winograd_algo_extra_impl, std::placeholders::_1, m, | |||
arg.param, handle, format)); | |||
checker.set_dtype(0, A_dtype) | |||
.set_dtype(1, B_dtype) | |||
.set_dtype(2, C_dtype) | |||
.set_dtype(4, D_dtype) | |||
.set_epsilon(eps) | |||
.set_param(arg.param) | |||
.execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||
} | |||
} | |||
}; | |||
@@ -1168,7 +1177,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPM | |||
dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||
CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) { | |||
using namespace conv_bias; | |||
Checker<ConvBiasForward> checker(handle()); | |||
@@ -1196,21 +1206,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F | |||
#if MEGDNN_AARCH64 | |||
const char* matmul_name = "AARCH64_F32_MK4_4x16"; | |||
#else | |||
const char* matmul_name = "ARMV7_F32_MK4_4x8"; | |||
const char* matmul_name = "ARMV7_F32_MK4_4x8"; | |||
#endif | |||
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | |||
ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str())); | |||
std::vector<TestArg> quantized_args = | |||
get_int8_nchw44_args(3, 4, true); | |||
std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true); | |||
UniformIntRNG int_rng{-50, 50}; | |||
checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng); | |||
run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f), | |||
dtype::QuantizedS8(0.01887994f), | |||
dtype::QuantizedS32(0.41113496f * 0.01887994f), | |||
dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, epsilon); | |||
dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, | |||
epsilon); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||
CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) { | |||
using namespace conv_bias; | |||
Checker<ConvBiasForward> checker(handle()); | |||
@@ -1238,7 +1249,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F | |||
#if MEGDNN_AARCH64 | |||
const char* matmul_name = "AARCH64_F32_MK4_4x16"; | |||
#else | |||
const char* matmul_name = "ARMV7_F32_MK4_4x8"; | |||
const char* matmul_name = "ARMV7_F32_MK4_4x8"; | |||
#endif | |||
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | |||
ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str())); | |||
@@ -1249,10 +1260,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F | |||
run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f), | |||
dtype::QuantizedS8(0.01887994f), | |||
dtype::QuantizedS32(0.41113496f * 0.01887994f), | |||
dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, epsilon); | |||
dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, | |||
epsilon); | |||
} | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) { | |||
using namespace conv_bias; | |||