2x2 3x3 5x5 7x7 directconv
GitOrigin-RevId: 3710182af1
tags/v1.0.0-rc1
@@ -38,6 +38,18 @@ public: | |||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
}; | }; | ||||
//! Direct convolution algorithm for int8 src/filter with int16 dst in the
//! NCHW44 layout. Supports stride 1/2 and 2x2/3x3/5x5/7x7 square filters;
//! see usable() for the full applicability predicate.
class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase {
public:
    AlgoS8x8x16DirectNCHW44() = default;
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "S8x8x16_NCHW44_DIRECT"; }
    //! whether this algo can handle the given problem size
    bool usable(const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    //! bytes of scratch memory needed (padded/packed source copy)
    size_t get_workspace(const NCBKernSizeParam& param) const override;
    //! redundant `virtual` dropped: `override` already implies it
    SmallVector<NCBKern> dispatch_kerns(
            const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { | class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { | ||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; | WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; | ||||
@@ -0,0 +1,481 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/int8x8x16/conv_direct_int8x8x16_nchw44_algo.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#include "src/arm_common/conv_bias/int8x8x16/algos.h" | |||||
#include "src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h" | |||||
#include "midout.h" | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
using conv_fun = std::function<void( | |||||
const WorkspaceBundle& bundle, | |||||
const ConvBiasImpl::NCBKernParam& kern_param, | |||||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | |||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8x8x16_nchw44_direct) | |||||
static void get_rectified_size( | |||||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, | |||||
int& iw2) { | |||||
auto&& fm = param.filter_meta; | |||||
int ih = param.isz[0]; | |||||
int iw = param.isz[1]; | |||||
int ph = fm.padding[0]; | |||||
int pw = fm.padding[1]; | |||||
ih2 = ih + ph * 2; | |||||
iw2 = iw + pw * 2; | |||||
} | |||||
//! Workspace layout: a single buffer holding the zero-padded (and, on
//! armv7 with stride 1, 4x-expanded) copy of the source.
//! group == 1 keeps one padded slice per (batch, channel); the grouped
//! path keeps one scratch slice per worker thread instead.
//! Fixes: stray `;` after the function body removed; the two branches
//! differed only in the slice count, so they are unified.
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
    auto&& fm = param.filter_meta;
    const size_t group = fm.group;
    const size_t batch = param.n;
    const size_t IC = fm.icpg;
    int IH2, IW2;
    get_rectified_size(param, IH2, IW2);
    //! group==1 pre-pads every (batch, channel) slice; otherwise each
    //! thread owns its own scratch slice.
    const size_t nr_slices = group == 1 ? batch * group : param.nr_threads;
    const bool need_padding = fm.padding[0] > 0 || fm.padding[1] > 0;
    size_t src_size =
            need_padding ? nr_slices * IC * IH2 * IW2 * sizeof(int8_t) : 0;
#if MEGDNN_ARMV7
    //! armv7 stride-1 kernels always read a 4x-expanded packed source,
    //! so the buffer is required even without padding.
    if (fm.stride[0] == 1) {
        constexpr int src_expand_element = 4;
        src_size = nr_slices * IC * IH2 * IW2 * sizeof(int8_t) *
                   src_expand_element;
    }
#endif
    return {nullptr, {src_size}};
}
#if MEGDNN_ARMV7
//! Copy one ic block of the source into the workspace with a zeroed
//! padding border, expanding every int8 lane 4x via nchw44_pack_src so
//! the armv7 stride-1 kernels can load pre-duplicated values.
//! workspace_ids = {batch, group, ic_block} locate the workspace slice;
//! ncb_index.ndrange_id supplies the real (batch, group) of the source.
//! Fix: double semicolon after `src_expand_element = 4` removed.
static void copy_padding_kern(const WorkspaceBundle& bundle,
                              const ConvBiasImpl::NCBKernParam& kern_param,
                              const ConvBiasImpl::NCBKernIndex& ncb_index,
                              const CpuNDRange& workspace_ids) {
    int IH = kern_param.isz[0];
    int IW = kern_param.isz[1];
    int IC = kern_param.filter_meta.icpg;
    int PH = kern_param.filter_meta.padding[0];
    int PW = kern_param.filter_meta.padding[1];
    int GROUP = kern_param.filter_meta.group;
    int IH2, IW2;
    get_rectified_size(kern_param, IH2, IW2);
    int padding_group_size = IH2 * IW2 * IC;
    //! used to compute the workspace offset
    constexpr int pack_ic = 4;
    constexpr int src_expand_element = 4;
    size_t workspace_ic_block = 4;
    size_t workspace_batch_id = workspace_ids[0];
    size_t workspace_group_id = workspace_ids[1];
    size_t workspace_ic_id = workspace_ids[2];
    size_t workspace_ic = workspace_ic_id * workspace_ic_block;
    size_t batch_id = ncb_index.ndrange_id[0];
    size_t group_id = ncb_index.ndrange_id[1];
    size_t group_pack_size = 1;
    //! byte widths of the four zero borders (after 4x expansion)
    int nr_pad_w = PW * pack_ic * src_expand_element;
    int nr_pad_h = PH * IW2 * pack_ic * src_expand_element;
    int row_last_pad = (IW2 - IW - PW) * pack_ic * src_expand_element;
    int col_last_pad = (IH2 - IH - PH) * IW2 * pack_ic * src_expand_element;
    const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>(
            batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic));
    //! copy to sptr_base to eliminate padding effect
    int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) +
                        (workspace_batch_id * GROUP * padding_group_size +
                         workspace_group_id * padding_group_size +
                         workspace_ic * IH2 * IW2) *
                                src_expand_element;
    //! grouped conv copies the whole ic range in one call; group==1 copies
    //! a single 4-channel block per kernel invocation
    size_t nr_ic = workspace_ic_block;
    if (GROUP > 1) {
        nr_ic = IC;
    }
    rep_step(ic_idx, nr_ic, pack_ic) {
        std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t));
        sptr_base += nr_pad_h;
        rep(ih_idx, IH) {
            std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t));
            sptr_base += nr_pad_w;
            int8x8x16_direct_nchw44::nchw44_pack_src(sptr, sptr_base, IW);
            sptr_base += IW * pack_ic * src_expand_element;
            sptr += IW * pack_ic;
            std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t));
            sptr_base += row_last_pad;
        }
        std::memset(sptr_base, 0, col_last_pad * sizeof(int8_t));
        sptr_base += col_last_pad;
    }
}
#endif
static void copy_padding_kern_no_pack_src(const WorkspaceBundle& bundle, | |||||
const ConvBiasImpl::NCBKernParam& kern_param, | |||||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const CpuNDRange& workspace_ids) { | |||||
int IH = kern_param.isz[0]; | |||||
int IW = kern_param.isz[1]; | |||||
int IC = kern_param.filter_meta.icpg; | |||||
int PH = kern_param.filter_meta.padding[0]; | |||||
int PW = kern_param.filter_meta.padding[1]; | |||||
int GROUP = kern_param.filter_meta.group; | |||||
int IH2, IW2; | |||||
get_rectified_size(kern_param, IH2, IW2); | |||||
int padding_group_size = IH2 * IW2 * IC; | |||||
//! Used for get the workspace offset | |||||
constexpr int pack_ic = 4; | |||||
constexpr int src_expand_element = 1; | |||||
size_t workspace_ic_block = 4; | |||||
size_t workspace_batch_id = workspace_ids[0]; | |||||
size_t workspace_group_id = workspace_ids[1]; | |||||
size_t workspace_ic_id = workspace_ids[2]; | |||||
size_t workspace_ic = workspace_ic_id * workspace_ic_block; | |||||
size_t batch_id = ncb_index.ndrange_id[0]; | |||||
size_t group_id = ncb_index.ndrange_id[1]; | |||||
size_t group_pack_size = 1; | |||||
int nr_pad_w = PW * pack_ic * src_expand_element; | |||||
int nr_pad_h = PH * IW2 * pack_ic * src_expand_element; | |||||
int row_last_pad = (IW2 - IW - PW) * pack_ic * src_expand_element; | |||||
int col_last_pad = (IH2 - IH - PH) * IW2 * pack_ic * src_expand_element; | |||||
const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>( | |||||
batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); | |||||
//! copy to sptr_base to eliminate padding effect | |||||
int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) + | |||||
(workspace_batch_id * GROUP * padding_group_size + | |||||
workspace_group_id * padding_group_size + | |||||
workspace_ic * IH2 * IW2) * | |||||
src_expand_element; | |||||
size_t nr_ic = workspace_ic_block; | |||||
if (GROUP > 1) { | |||||
nr_ic = IC; | |||||
} | |||||
rep_step(ic_idx, nr_ic, pack_ic) { | |||||
std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t)); | |||||
sptr_base += nr_pad_h; | |||||
rep(ih_idx, IH) { | |||||
std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); | |||||
sptr_base += nr_pad_w; | |||||
std::memcpy(sptr_base, sptr, IW * pack_ic); | |||||
sptr_base += IW * pack_ic * src_expand_element; | |||||
sptr += IW * pack_ic; | |||||
std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t)); | |||||
sptr_base += row_last_pad; | |||||
} | |||||
std::memset(sptr_base, 0, col_last_pad * sizeof(int8_t)); | |||||
sptr_base += col_last_pad; | |||||
} | |||||
} | |||||
//! Compute one (batch, group, oc-block) tile of the direct convolution.
//! workspace_ids selects the padded-source slice produced by the copy
//! kernels; ncb_range[2] is the number of chunks the OC axis was split
//! into by dispatch_kerns.
template <size_t filter, BiasMode bias_mode, int stride>
static void do_conv_kern(const WorkspaceBundle& bundle,
                         const ConvBiasImpl::NCBKernParam& kern_param,
                         const ConvBiasImpl::NCBKernIndex& ncb_index,
                         const CpuNDRange& workspace_ids,
                         const CpuNDRange& ncb_range) {
    size_t OH = kern_param.osz[0];
    size_t OW = kern_param.osz[1];
    size_t FH = kern_param.filter_meta.spatial[0];
    size_t FW = kern_param.filter_meta.spatial[1];
    size_t IC = kern_param.filter_meta.icpg;
    size_t OC = kern_param.filter_meta.ocpg;
    size_t GROUP = kern_param.filter_meta.group;
    int IH2, IW2;
    get_rectified_size(kern_param, IH2, IW2);
    size_t padding_group_size = IH2 * IW2 * IC;
    constexpr size_t pack_c = 4;
    const size_t workspace_batch_id = workspace_ids[0];
    const size_t workspace_group_id = workspace_ids[1];
    const size_t batch_id = ncb_index.ndrange_id[0];
    const size_t group_id = ncb_index.ndrange_id[1];
    const size_t oc_id = ncb_index.ndrange_id[2];
    const size_t oc_block_num = ncb_range[2];
    megdnn_assert((OC & (pack_c - 1)) == 0, "OC must times of 4");
    //! split OC into oc_block_num chunks of whole 4-channel packs; the
    //! last chunk absorbs the remainder
    size_t nr_pack_per_step = div_ceil(OC / pack_c, oc_block_num);
    size_t oc_block = nr_pack_per_step * pack_c;
    const size_t oc_idx = oc_id * oc_block;
    if (oc_id == (oc_block_num - 1)) {
        oc_block = OC - oc_id * nr_pack_per_step * pack_c;
    }
    megdnn_assert(oc_block % pack_c == 0,
                  "oc must be devisible by 4, but oc = %zu", oc_block);
    //! without padding the kernel reads the source tensor in place;
    //! otherwise it reads the zero-padded copy from the workspace
    bool need_padding = kern_param.filter_meta.padding[0] > 0 ||
                        kern_param.filter_meta.padding[1] > 0;
    const int8_t* sptr =
            need_padding ? static_cast<int8_t*>(bundle.get(0)) +
                                   workspace_batch_id * GROUP *
                                           padding_group_size +
                                   workspace_group_id * padding_group_size
                         : kern_param.src<int8_t>(batch_id, group_id);
    //! armv7 stride-1 kernels always read the 4x-expanded packed source,
    //! so the workspace pointer overrides the in-place source even when
    //! there is no padding
#if MEGDNN_ARMV7
    if (stride == 1) {
        constexpr size_t src_expand_size = 4;
        sptr = static_cast<int8_t*>(bundle.get(0)) +
               workspace_batch_id * GROUP * padding_group_size *
                       src_expand_size +
               workspace_group_id * padding_group_size * src_expand_size;
    }
#endif
    //! filter/bias/dst pointers offset to this oc block
    const int8_t* fptr =
            kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC;
    int16_t* dst = reinterpret_cast<int16_t*>(
            kern_param.dst<void>(batch_id, group_id, oc_idx));
    const int16_t* bptr =
            kern_param.bias<dt_int16>(batch_id, group_id) + oc_idx;
    int8x8x16_direct_nchw44::ConvDirectInt8Nchw44Choose<
            bias_mode, filter, stride>::impl(sptr, fptr, bptr, dst, oc_block,
                                             IC, IH2, IW2, OH, OW);
}
bool ConvBiasImpl::AlgoS8x8x16DirectNCHW44::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
MEGDNN_MARK_USED_VAR(algo_selection_strategy); | |||||
auto&& fm = param.filter_meta; | |||||
const int fh = fm.spatial[0]; | |||||
const int fw = fm.spatial[1]; | |||||
const int oc = fm.ocpg; | |||||
const int ic = fm.icpg; | |||||
const bool avaible = //! src and filter are int8, dst is int16_t | |||||
(param.src_type.enumv() == DTypeEnum::Int8 && | |||||
param.filter_type.enumv() == DTypeEnum::Int8 && | |||||
param.dst_type.enumv() == DTypeEnum::Int16) && | |||||
(fm.format == param::Convolution::Format::NCHW44) && | |||||
(oc % 4 == 0 && ic % 4 == 0 && oc >= 4) && !fm.should_flip && | |||||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] && | |||||
(fm.stride[0] == 2 || fm.stride[0] == 1) && fh == fw && | |||||
(fh == 2 || fh == 3 || fh == 5 || fh == 7) && | |||||
param.nonlineMode == NonlineMode::IDENTITY && | |||||
param.bias_mode != BiasMode::BIAS; | |||||
return avaible; | |||||
} | |||||
size_t ConvBiasImpl::AlgoS8x8x16DirectNCHW44::get_workspace( | |||||
const NCBKernSizeParam& param) const { | |||||
return get_bundle(param).total_size_in_bytes(); | |||||
} | |||||
//! Build the kernel list: optional padding-copy kernels followed by the
//! conv kernel, with the do_conv_kern instantiation selected by the macro
//! ladder below (stride x filter x bias_mode), each wrapped in midout
//! markers for selective compilation.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8x8x16DirectNCHW44::dispatch_kerns(
        const NCBKernSizeParam& param) const {
    auto fm = param.filter_meta;
    size_t N = param.n;
    size_t IC = fm.icpg;
    size_t OC = fm.ocpg;
    size_t group = fm.group;
    size_t fh = fm.spatial[0];
    size_t fw = fm.spatial[1];
    size_t ph = fm.padding[0];
    size_t pw = fm.padding[1];
    WorkspaceBundle wbundle = get_bundle(param);
    conv_fun do_conv_fun = nullptr;
#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode)           \
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8x8x16_nchw44_direct,   \
                 midout_iv("int8x8x16_nchw44_direct_"                   \
                           "conv" #stride #filter #bias_mode##_hash)) { \
        do_conv_fun = do_conv_kern<filter, bias_mode, stride>;          \
    }                                                                   \
    MIDOUT_END();
#define GET_OP_PARAM(stride, filter, bias_mode)                             \
    switch (param.nonlineMode) {                                            \
        case param::ConvBias::NonlineMode::IDENTITY:                        \
            DO_CONV_KERN_FUN(stride, dt_int16, filter, bias_mode)           \
            break;                                                          \
        default:                                                            \
            megdnn_throw(ssprintf("only support IDENTITY mode when dst is " \
                                  "dt_int16 nonlineMode is %d",             \
                                  uint32_t(param.nonlineMode))              \
                                 .c_str());                                 \
            break;                                                          \
    }
#define GET_BIAS_MODE_PARAM(stride, filter)                                  \
    switch (param.bias_mode) {                                               \
        case BiasMode::NO_BIAS:                                              \
            GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS)                  \
            break;                                                           \
        case BiasMode::BROADCAST_CHANNEL_BIAS:                               \
            GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)   \
            break;                                                           \
        default:                                                             \
            megdnn_throw(ssprintf("only support NO_BIAS/BROADCAST biasmode " \
                                  "when dst is "                             \
                                  "dt_int16 biasmode is %d",                 \
                                  uint32_t(param.bias_mode))                 \
                                 .c_str());                                  \
            break;                                                           \
    }
#define DISPATCH_CONV_KERN(stride)                                            \
    switch (param.filter_meta.spatial[0]) {                                   \
        case 2:                                                               \
            GET_BIAS_MODE_PARAM(stride, 2)                                    \
            break;                                                            \
        case 3:                                                               \
            GET_BIAS_MODE_PARAM(stride, 3)                                    \
            break;                                                            \
        case 5:                                                               \
            GET_BIAS_MODE_PARAM(stride, 5)                                    \
            break;                                                            \
        case 7:                                                               \
            GET_BIAS_MODE_PARAM(stride, 7)                                    \
            break;                                                            \
        default:                                                              \
            megdnn_throw(ssprintf("only support 2x2 3x3 5x5 7x7 filters size " \
                                  "when dst is "                              \
                                  "dt_int16 filter size is %u",               \
                                  uint32_t(param.filter_meta.spatial[0]))     \
                                 .c_str());                                   \
            break;                                                            \
    }
    switch (param.filter_meta.stride[0]) {
        case 1:
            DISPATCH_CONV_KERN(1);
            break;
        case 2:
            DISPATCH_CONV_KERN(2);
            break;
        default:
            megdnn_throw(ssprintf("Unsupport stride size %u for the 8x8x16 direct conv",
                                  param.filter_meta.stride[0])
                                 .c_str());
            break;
    }
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN
    megdnn_assert(do_conv_fun);
    SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
    constexpr size_t pack_oc = 4;
    size_t oc_step = pack_oc;
    //! small filters have oc8 kernel variants, so split OC in chunks of 8
    //! NOTE(review): `(fh == 2 || fw == 3)` mixes fh and fw; usable()
    //! enforces fh == fw so behavior is unaffected, but `fh == 3` was
    //! probably intended — confirm against the kernel implementations.
    if (fh == fw && (fh == 2 || fw == 3) && OC >= 8) {
        oc_step = 8;
    }
#if MEGDNN_ARMV7
    //! armv7 stride-1 always packs the source into the workspace first
    if (param.filter_meta.stride[0] == 1) {
        if (group == 1) {
            CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)};
            auto copy_padding = [wbundle](
                                        const NCBKernParam& kern_param,
                                        const NCBKernIndex& ncb_index) mutable {
                wbundle.set(kern_param.workspace_ptr);
                copy_padding_kern(wbundle, kern_param, ncb_index,
                                  ncb_index.ndrange_id);
            };
            constexpr size_t pack_ic = 4;
            ret_kerns.push_back(
                    {copy_padding, {N, group, div_ceil(IC, pack_ic)}});
            auto do_conv = [wbundle, do_conv_fun, ncb_range](
                                   const NCBKernParam& kern_param,
                                   const NCBKernIndex& ncb_index) mutable {
                wbundle.set(kern_param.workspace_ptr);
                do_conv_fun(wbundle, kern_param, ncb_index,
                            ncb_index.ndrange_id, ncb_range);
            };
            ret_kerns.push_back({do_conv, ncb_range});
        } else {
            //! grouped conv: one kernel per (batch, group) that pads into
            //! its per-thread slice and convolves immediately
            CpuNDRange ncb_range = {N, group, 1};
            auto do_conv = [wbundle, do_conv_fun, ncb_range](
                                   const NCBKernParam& kern_param,
                                   const NCBKernIndex& ncb_index) mutable {
                wbundle.set(kern_param.workspace_ptr);
                copy_padding_kern(wbundle, kern_param, ncb_index,
                                  {0, ncb_index.thread_id, 0});
                do_conv_fun(wbundle, kern_param, ncb_index,
                            {0, ncb_index.thread_id, 0}, ncb_range);
            };
            ret_kerns.push_back({do_conv, ncb_range});
        }
        return ret_kerns;
    }
#endif
    //! generic path: pad (only when needed) without source expansion
    bool need_padding = ph > 0 || pw > 0;
    if (group == 1) {
        CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)};
        auto copy_padding = [wbundle](const NCBKernParam& kern_param,
                                      const NCBKernIndex& ncb_index) mutable {
            wbundle.set(kern_param.workspace_ptr);
            copy_padding_kern_no_pack_src(wbundle, kern_param, ncb_index,
                                          ncb_index.ndrange_id);
        };
        constexpr size_t pack_ic = 4;
        if (need_padding) {
            ret_kerns.push_back(
                    {copy_padding, {N, group, div_ceil(IC, pack_ic)}});
        }
        auto do_conv = [wbundle, do_conv_fun, ncb_range](
                               const NCBKernParam& kern_param,
                               const NCBKernIndex& ncb_index) mutable {
            wbundle.set(kern_param.workspace_ptr);
            do_conv_fun(wbundle, kern_param, ncb_index, ncb_index.ndrange_id,
                        ncb_range);
        };
        ret_kerns.push_back({do_conv, ncb_range});
    } else {
        CpuNDRange ncb_range = {N, group, 1};
        auto do_conv = [wbundle, do_conv_fun, ncb_range, need_padding](
                               const NCBKernParam& kern_param,
                               const NCBKernIndex& ncb_index) mutable {
            wbundle.set(kern_param.workspace_ptr);
            if (need_padding) {
                copy_padding_kern_no_pack_src(wbundle, kern_param, ncb_index,
                                              {0, ncb_index.thread_id, 0});
            };
            do_conv_fun(wbundle, kern_param, ncb_index,
                        {0, ncb_index.thread_id, 0}, ncb_range);
        };
        ret_kerns.push_back({do_conv, ncb_range});
    }
    return ret_kerns;
}
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,56 @@ | |||||
/** | |||||
 * \file dnn/src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
namespace int8x8x16_direct_nchw44 { | |||||
/** | |||||
origin src shape <n, ic/4, h, w, 4> | |||||
packed src shape <n, ic/4, h, w, 16> | |||||
example: (format like <ic>) | |||||
origin | |||||
<0> <1> <2> <3> | |||||
packed | |||||
low 64 bit <0> <0> <0> <0> | <1> <1> <1> <1> | |||||
--------------------------------------------------------------------- | |||||
high 64 bit <2> <2> <2> <2> | <3> <3> <3> <3> | |||||
**/ | |||||
static inline void nchw44_pack_src(const int8_t* src, int8_t* dst, int length) { | |||||
static const uint8_t src_idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, | |||||
2, 2, 2, 2, 3, 3, 3, 3}; | |||||
constexpr int pack_ic = 4; | |||||
constexpr int simd_len = 16; | |||||
uint8x16_t src_idx = vld1q_u8(src_idx_buffer); | |||||
for (int i = 0; i < length; i++) { | |||||
int8x16_t result = vld_dup_tbl_s32(src + i * pack_ic, src_idx); | |||||
vst1q_s8(dst + i * simd_len, result); | |||||
} | |||||
} | |||||
//! Dispatch shim: explicit specializations (defined in the per-arch
//! kernel .cpp files) implement the direct conv for one concrete
//! (bias_mode, filter_size, stride) combination.
template <BiasMode bias_mode, int filter_size, int stride>
struct ConvDirectInt8Nchw44Choose {
    static void impl(const int8_t* src, const int8_t* filter,
                     const int16_t* bias, int16_t* dst, const size_t oc,
                     const size_t ic, const size_t ih, const size_t iw,
                     const size_t oh, const size_t ow);
};
}  // namespace int8x8x16_direct_nchw44
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,971 @@ | |||||
/** | |||||
* \file | |||||
* dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8_direct_nchw44_s1_aarch64.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#if MEGDNN_AARCH64 | |||||
#include "src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
namespace { | |||||
//! Initialize `init_sum` with the bias of the current 4-oc block, or zero
//! when bias_mode == NO_BIAS. Expects `bias_mode` and `bias_ptr` in the
//! enclosing scope.
#define INIT_SUM()                        \
    int16x4_t init_sum;                   \
    if (bias_mode == BiasMode::NO_BIAS) { \
        init_sum = vdup_n_s16(0);         \
    } else {                              \
        init_sum = vld1_s16(bias_ptr);    \
    }
//! Store up to 8 output pixels (4 int16 channels each) of one oc block
//! from c[0][0..7] to dst_ptr; `remain_w` is the number of valid pixels.
#define STORE_1_LINE_RESULT()                                        \
    switch (remain_w) {                                              \
        case 8:                                                      \
            vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));      \
            vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3]));  \
            vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
            vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \
            break;                                                   \
        case 1:                                                      \
            vst1_s16(dst_ptr, c[0][0]);                              \
            break;                                                   \
        case 2:                                                      \
            vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));      \
            break;                                                   \
        case 3:                                                      \
            vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));      \
            vst1_s16(dst_ptr + 8, c[0][2]);                          \
            break;                                                   \
        case 4:                                                      \
            vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));      \
            vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3]));  \
            break;                                                   \
        case 5:                                                      \
            vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));      \
            vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3]));  \
            vst1_s16(dst_ptr + 16, c[0][4]);                         \
            break;                                                   \
        case 6:                                                      \
            vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));      \
            vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3]));  \
            vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
            break;                                                   \
        case 7:                                                      \
            vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));      \
            vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3]));  \
            vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
            vst1_s16(dst_ptr + 24, c[0][6]);                         \
            break;                                                   \
        default:                                                     \
            megdnn_assert(0, "oc 1 error remainw");                  \
    };
//! Store one row of up to 4 output pixels for TWO oc blocks: c[0][*] goes
//! to dst_ptr, c[1][*] to dst_ptr + ld_dst_oc.
#define STORE_2_LINE_RESULT_OW4()                                          \
    switch (remain_w) {                                                    \
        case 4:                                                            \
            vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));            \
            vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3]));        \
            vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
            vst1q_s16(dst_ptr + ld_dst_oc + 8,                             \
                      vcombine_s16(c[1][2], c[1][3]));                     \
            break;                                                         \
        case 1:                                                            \
            vst1_s16(dst_ptr, c[0][0]);                                    \
            vst1_s16(dst_ptr + ld_dst_oc, c[1][0]);                        \
            break;                                                         \
        case 2:                                                            \
            vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));            \
            vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
            break;                                                         \
        case 3:                                                            \
            vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));            \
            vst1_s16(dst_ptr + 8, c[0][2]);                                \
            vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
            vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]);                    \
            break;                                                         \
        default:                                                           \
            megdnn_assert(0, "oc 2 error remainw");                        \
            break;                                                         \
    }
//! Store TWO consecutive output rows of up to 4 pixels each for one oc
//! block: c[0][0..3] is row 0, c[0][4..7] is row 1 (at dst_ptr + ow).
#define STORE_1_LINE_RESULT_OW4_OH2()                                 \
    switch (remain_w) {                                               \
        case 4:                                                       \
            vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));       \
            vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3]));   \
            vst1q_s16(dst_ptr + ow, vcombine_s16(c[0][4], c[0][5]));  \
            vst1q_s16(dst_ptr + 8 + ow, vcombine_s16(c[0][6], c[0][7])); \
            break;                                                    \
        case 1:                                                       \
            vst1_s16(dst_ptr, c[0][0]);                               \
            vst1_s16(dst_ptr + ow, c[0][4]);                          \
            break;                                                    \
        case 2:                                                       \
            vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));       \
            vst1q_s16(dst_ptr + ow, vcombine_s16(c[0][4], c[0][5]));  \
            break;                                                    \
        case 3:                                                       \
            vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));       \
            vst1_s16(dst_ptr + 8, c[0][2]);                           \
            vst1q_s16(dst_ptr + ow, vcombine_s16(c[0][4], c[0][5]));  \
            vst1_s16(dst_ptr + ow + 8, c[0][6]);                      \
            break;                                                    \
        default:                                                      \
            megdnn_assert(0, "oc 2 error remainw");                   \
            break;                                                    \
    }
//! Store one row of up to 4 output pixels for ONE oc block.
#define STORE_1_LINE_RESULT_OW4()                                   \
    switch (remain_w) {                                             \
        case 4:                                                     \
            vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));     \
            vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
            break;                                                  \
        case 1:                                                     \
            vst1_s16(dst_ptr, c[0][0]);                             \
            break;                                                  \
        case 2:                                                     \
            vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));     \
            break;                                                  \
        case 3:                                                     \
            vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));     \
            vst1_s16(dst_ptr + 8, c[0][2]);                         \
            break;                                                  \
        default:                                                    \
            megdnn_assert(0, "oc 1 error remainw");                 \
    };
//! 2x2 stride-1 kernel producing one 4-wide output row for two oc blocks
//! (8 output channels). The source is plain NCHW44: each pixel's 4 input
//! channels are broadcast across a 16-byte vector on the fly with a tbl
//! lookup (this aarch64 path does not pre-pack the source).
template <BiasMode bias_mode, int filter_size>
static void ker_neon_dirctconv_2x2s1_oc8_ow4(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int16_t* bias_ptr,
                                             int16_t* dst_ptr, int ic, int ih,
                                             int iw, int remain_w,
                                             int ld_dst_oc) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int oc_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    const int ic_stride = ih * iw;
    const int ld_weight_oc4 = oc_step * fh * fw * ic;
    int16x4_t c[2][4];       //! accumulators: [oc block][output column]
    int8x16_t weight[2][2];  //! [oc block][filter column]
    int8x16_t src[5];        //! 4 outputs need 5 source columns for fw==2
    static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
                                           2, 2, 2, 2, 3, 3, 3, 3};
    //! NOTE(review): function-local `static` with non-constant init pays a
    //! thread-safe-init guard on every call; a plain const local would do.
    static uint8x16_t idx = vld1q_u8(idx_buffer);
    INIT_SUM();
#define cb(_i)           \
    c[0][_i] = init_sum; \
    c[1][_i] = init_sum;
    UNROLL_CALL_RAW(4, cb);
#undef cb
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        //! the two source rows touched by a 2x2 filter at this oh
        const int8_t* src_row0 =
                src_ptr + ic_idx * ic_stride + 0 * iw * ic_step;
        const int8_t* src_row1 =
                src_ptr + ic_idx * ic_stride + 1 * iw * ic_step;
        src[0] = vld_dup_tbl_s32(src_row0 + 0, idx);
        src[1] = vld_dup_tbl_s32(src_row0 + 4, idx);
        src[2] = vld_dup_tbl_s32(src_row0 + 8, idx);
        weight[0][0] = vld1q_s8(weight_ptr);
        weight[0][1] = vld1q_s8(weight_ptr + 16);
        weight[1][0] = vld1q_s8(weight_ptr + ld_weight_oc4);
        weight[1][1] = vld1q_s8(weight_ptr + ld_weight_oc4 + 16);
//! multiply-accumulate two adjacent source columns against the two filter
//! taps of one row, then horizontally reduce into one int16x4 pixel
#define CALC_ONE_RESULT(_src0, _src1, _w0, _w1, _c)                           \
    do {                                                                      \
        tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0));                \
        tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w0));        \
        tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1));          \
        tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w1));        \
        _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \
    } while (0);
        int16x8_t tmp0;
        CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][0]);
        CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][0]);
        src[3] = vld_dup_tbl_s32(src_row0 + 12, idx);
        src[4] = vld_dup_tbl_s32(src_row0 + 16, idx);
        CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][1]);
        CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][1]);
        CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][2]);
        CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][2]);
        CALC_ONE_RESULT(src[3], src[4], weight[0][0], weight[0][1], c[0][3]);
        CALC_ONE_RESULT(src[3], src[4], weight[1][0], weight[1][1], c[1][3]);
        //! second filter row: next source row and next 32 weight bytes
        src[0] = vld_dup_tbl_s32(src_row1 + 0, idx);
        src[1] = vld_dup_tbl_s32(src_row1 + 4, idx);
        src[2] = vld_dup_tbl_s32(src_row1 + 8, idx);
        weight[0][0] = vld1q_s8(weight_ptr + 32);
        weight[0][1] = vld1q_s8(weight_ptr + 48);
        weight[1][0] = vld1q_s8(weight_ptr + ld_weight_oc4 + 32);
        weight[1][1] = vld1q_s8(weight_ptr + ld_weight_oc4 + 48);
        CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][0]);
        CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][0]);
        src[3] = vld_dup_tbl_s32(src_row1 + 12, idx);
        src[4] = vld_dup_tbl_s32(src_row1 + 16, idx);
        CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][1]);
        CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][1]);
        CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][2]);
        CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][2]);
        CALC_ONE_RESULT(src[3], src[4], weight[0][0], weight[0][1], c[0][3]);
        CALC_ONE_RESULT(src[3], src[4], weight[1][0], weight[1][1], c[1][3]);
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    STORE_2_LINE_RESULT_OW4();
}
//! 2x2 stride-1 kernel producing one 4-wide output row for a single oc
//! block (4 output channels); used when fewer than 8 ocs remain. Relies
//! on CALC_ONE_RESULT defined by the oc8 kernel above.
template <BiasMode bias_mode, int filter_size>
static void ker_neon_dirctconv_2x2s1_oc4_ow4(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int16_t* bias_ptr,
                                             int16_t* dst_ptr, int ic, int ih,
                                             int iw, int remain_w,
                                             int /*ld_dst_oc*/) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
                                           2, 2, 2, 2, 3, 3, 3, 3};
    //! NOTE(review): static non-constant init adds a guard check per call
    static uint8x16_t idx = vld1q_u8(idx_buffer);
    const int ic_stride = ih * iw;
    int16x4_t c[1][4];       //! accumulators, one per output column
    int8x16_t weight[1][2];  //! the two filter columns of one row
    int8x16_t src[5];        //! 4 outputs need 5 source columns for fw==2
    INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
    UNROLL_CALL_RAW(4, cb);
#undef cb
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 =
                    src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step;
            src[0] = vld_dup_tbl_s32(src_ic_0_3 + 0, idx);
            src[1] = vld_dup_tbl_s32(src_ic_0_3 + 4, idx);
            src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx);
            src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx);
            src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx);
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0][0] = vld1q_s8(read_weight_ptr);
            weight[0][1] = vld1q_s8(read_weight_ptr + 16);
            int16x8_t tmp0;
            CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1],
                            c[0][0]);
            CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1],
                            c[0][1]);
            CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1],
                            c[0][2]);
            CALC_ONE_RESULT(src[3], src[4], weight[0][0], weight[0][1],
                            c[0][3]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    STORE_1_LINE_RESULT_OW4();
}
#undef CALC_ONE_RESULT
// Accumulate one 3-tap (one 3x3 filter row) contribution for one output
// column: widening int8 multiply-accumulate of three input columns against
// the three taps in _w[0..2], then horizontally reduced and added into the
// int16x4 accumulator _c (one value per output channel of the NCHW44 group).
#define CALC_ONE_RESULT(_src0, _src1, _src2, _w, _c)                          \
    do {                                                                      \
        int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0]));    \
        tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w[0]));      \
        tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1]));        \
        tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w[1]));      \
        tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2]));        \
        tmp0 = vmlal_s8(tmp0, vget_high_s8(_src2), vget_high_s8(_w[2]));      \
        _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \
    } while (0);
// Direct conv kernel: 3x3 filter, stride 1, NCHW44, int8 x int8 -> int16,
// 4 output channels x 4 output columns per call. The three filter rows are
// fully unrolled: for each ic block, rows 0/1/2 of the filter are applied to
// src_row0/1/2 in turn, all accumulating into the same four column
// accumulators c[0][0..3].
template <BiasMode bias_mode, int filter_size>
static void ker_neon_dirctconv_3x3s1_oc4_ow4(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int16_t* bias_ptr,
                                             int16_t* dst_ptr, int ic, int ih,
                                             int iw, int remain_w,
                                             int /*ld_dst_oc*/) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;  // NCHW44: 4 input channels per pixel
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    const int ic_stride = ih * iw;
    int16x4_t c[1][4];       // accumulators for 4 output columns
    int8x16_t weight[1][3];  // one filter row: 3 taps
    int8x16_t src[6];        // 4 output columns need 6 input columns
    static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
                                           2, 2, 2, 2, 3, 3, 3, 3};
    static uint8x16_t idx = vld1q_u8(idx_buffer);
    INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
    UNROLL_CALL_RAW(4, cb);
#undef cb
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        const int8_t* src_row0 =
                src_ptr + ic_idx * ic_stride + 0 * iw * ic_step;
        const int8_t* src_row1 =
                src_ptr + ic_idx * ic_stride + 1 * iw * ic_step;
        const int8_t* src_row2 =
                src_ptr + ic_idx * ic_stride + 2 * iw * ic_step;
        // --- filter row 0 against input row 0 ---
        src[0] = vld_dup_tbl_s32(src_row0 + 0, idx);
        src[1] = vld_dup_tbl_s32(src_row0 + 4, idx);
        src[2] = vld_dup_tbl_s32(src_row0 + 8, idx);
        weight[0][0] = vld1q_s8(weight_ptr);
        weight[0][1] = vld1q_s8(weight_ptr + 16);
        weight[0][2] = vld1q_s8(weight_ptr + 32);
        // loads are interleaved with the MACs to hide load latency
        src[3] = vld_dup_tbl_s32(src_row0 + 12, idx);
        CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], c[0][0]);
        src[4] = vld_dup_tbl_s32(src_row0 + 16, idx);
        CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], c[0][1]);
        src[5] = vld_dup_tbl_s32(src_row0 + 20, idx);
        CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], c[0][2]);
        CALC_ONE_RESULT(src[3], src[4], src[5], weight[0], c[0][3]);
        // --- filter row 1 against input row 1 ---
        src[0] = vld_dup_tbl_s32(src_row1 + 0, idx);
        src[1] = vld_dup_tbl_s32(src_row1 + 4, idx);
        src[2] = vld_dup_tbl_s32(src_row1 + 8, idx);
        weight[0][0] = vld1q_s8(weight_ptr + 48);
        weight[0][1] = vld1q_s8(weight_ptr + 64);
        weight[0][2] = vld1q_s8(weight_ptr + 80);
        src[3] = vld_dup_tbl_s32(src_row1 + 12, idx);
        CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], c[0][0]);
        src[4] = vld_dup_tbl_s32(src_row1 + 16, idx);
        CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], c[0][1]);
        src[5] = vld_dup_tbl_s32(src_row1 + 20, idx);
        CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], c[0][2]);
        CALC_ONE_RESULT(src[3], src[4], src[5], weight[0], c[0][3]);
        // --- filter row 2 against input row 2 ---
        src[0] = vld_dup_tbl_s32(src_row2 + 0, idx);
        src[1] = vld_dup_tbl_s32(src_row2 + 4, idx);
        src[2] = vld_dup_tbl_s32(src_row2 + 8, idx);
        weight[0][0] = vld1q_s8(weight_ptr + 96);
        weight[0][1] = vld1q_s8(weight_ptr + 112);
        weight[0][2] = vld1q_s8(weight_ptr + 128);
        src[3] = vld_dup_tbl_s32(src_row2 + 12, idx);
        CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], c[0][0]);
        src[4] = vld_dup_tbl_s32(src_row2 + 16, idx);
        CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], c[0][1]);
        src[5] = vld_dup_tbl_s32(src_row2 + 20, idx);
        CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], c[0][2]);
        CALC_ONE_RESULT(src[3], src[4], src[5], weight[0], c[0][3]);
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    STORE_1_LINE_RESULT_OW4();
}
// Direct conv kernel: 3x3 filter, stride 1, NCHW44, int8 x int8 -> int16,
// 4 output channels x 4 output columns x 2 output ROWS per call.
// Accumulators: c[0][0..3] hold output row oh, c[0][4..7] hold row oh+1.
// Each loaded input row is reused for both output rows (row r contributes to
// output row r with filter row 0..2 as the comments below track), so four
// input rows (src_row0..3) cover the two outputs.
// NOTE(review): the extra `ow` parameter is consumed by
// STORE_1_LINE_RESULT_OW4_OH2 (defined elsewhere), presumably to address the
// second output row — confirm against the macro definition.
template <BiasMode bias_mode, int filter_size>
static void ker_neon_dirctconv_3x3s1_oc4_ow4_oh2(const int8_t* src_ptr,
                                                 const int8_t* weight_ptr,
                                                 const int16_t* bias_ptr,
                                                 int16_t* dst_ptr, int ic,
                                                 int ih, int iw, int remain_w,
                                                 int /*ld_dst_oc*/, int ow) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    const int ic_stride = ih * iw;
    int16x4_t c[1][8];       // [0..3] = output row 0, [4..7] = output row 1
    int8x16_t weight[2][3];  // two filter rows kept live at a time
    int8x16_t src[1][6];     // 6 input columns for 4 output columns
    static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
                                           2, 2, 2, 2, 3, 3, 3, 3};
    static uint8x16_t idx = vld1q_u8(idx_buffer);
    INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
    UNROLL_CALL_RAW(8, cb);
#undef cb
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        const int8_t* src_row0 =
                src_ptr + ic_idx * ic_stride + 0 * iw * ic_step;
        const int8_t* src_row1 =
                src_ptr + ic_idx * ic_stride + 1 * iw * ic_step;
        const int8_t* src_row2 =
                src_ptr + ic_idx * ic_stride + 2 * iw * ic_step;
        const int8_t* src_row3 =
                src_ptr + ic_idx * ic_stride + 3 * iw * ic_step;
// Load the 6 input columns of one row into _src[0..5].
#define LOAD_SRC(_src, _src_ptr)                  \
    _src[0] = vld_dup_tbl_s32(_src_ptr + 0, idx); \
    _src[1] = vld_dup_tbl_s32(_src_ptr + 4, idx); \
    _src[2] = vld_dup_tbl_s32(_src_ptr + 8, idx); \
    _src[3] = vld_dup_tbl_s32(_src_ptr + 12, idx); \
    _src[4] = vld_dup_tbl_s32(_src_ptr + 16, idx); \
    _src[5] = vld_dup_tbl_s32(_src_ptr + 20, idx);
        LOAD_SRC(src[0], src_row0);
        weight[0][0] = vld1q_s8(weight_ptr);        // filter row 0
        weight[0][1] = vld1q_s8(weight_ptr + 16);
        weight[0][2] = vld1q_s8(weight_ptr + 32);
        weight[1][0] = vld1q_s8(weight_ptr + 48);   // filter row 1
        weight[1][1] = vld1q_s8(weight_ptr + 64);
        weight[1][2] = vld1q_s8(weight_ptr + 80);
        CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0],
                        c[0][0]);  // row0 src0 w0
        CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][1]);
        CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][2]);
        CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][3]);
        LOAD_SRC(src[0], src_row1);
        CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0],
                        c[0][4]);  // row1 src1 w0
        CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][5]);
        CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][6]);
        CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][7]);
        CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[1],
                        c[0][0]);  // row1 src1 w1
        CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[1], c[0][1]);
        CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[1], c[0][2]);
        CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[1], c[0][3]);
        LOAD_SRC(src[0], src_row2);
        weight[0][0] = vld1q_s8(weight_ptr + 96);   // filter row 2
        weight[0][1] = vld1q_s8(weight_ptr + 112);
        weight[0][2] = vld1q_s8(weight_ptr + 128);
        CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[1],
                        c[0][4]);  // row2 src0 w1
        CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[1], c[0][5]);
        CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[1], c[0][6]);
        CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[1], c[0][7]);
        CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0],
                        c[0][0]);  // row2 w0 src[0]
        CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][1]);
        CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][2]);
        CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][3]);
        LOAD_SRC(src[0], src_row3);
        CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0],
                        c[0][4]);  // row3 w0 src1
        CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][5]);
        CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][6]);
        CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][7]);
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    STORE_1_LINE_RESULT_OW4_OH2();
}
#undef LOAD_SRC
#undef CALC_ONE_RESULT
// Stride-1 kernel dispatcher, specialized per filter size below (5 and 7 in
// this file). impl() computes one tile of 4 output channels x up to 8 output
// columns; remain_w selects how many columns the store writes.
template <BiasMode bias_mode, int filter_size>
struct KerNeonDirectStride1Int8 {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih,
                     int iw, int remain_w, int ld_dst_oc);
};
// 5x5 stride-1 specialization: 4 output channels x 8 output columns per call.
// src[0..9] is used as a small ring buffer: after computing the first six
// columns, src[0]/src[1] are overwritten with input columns 10/11 for the
// last two outputs.
template <BiasMode bias_mode>
struct KerNeonDirectStride1Int8<bias_mode, 5> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih,
                     int iw, int remain_w, int /*ld_dst_oc*/) {
        constexpr int filter_size = 5;
        constexpr int fh = filter_size;
        constexpr int fw = filter_size;
        constexpr int ic_step = 4;
        constexpr int loop_ic_step = 4;
        constexpr int ld_weight_ic4 = 16;
        const int ic_stride = ih * iw;
        static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
                                               2, 2, 2, 2, 3, 3, 3, 3};
        static uint8x16_t idx = vld1q_u8(idx_buffer);
        int16x4_t c[1][8];     // 8 output-column accumulators
        int8x16_t weight[5];   // one filter row: 5 taps
        int8x16_t src[8 + 2];  // 8 outputs need 12 input columns, recycled
        INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
        UNROLL_CALL_RAW(8, cb);
#undef cb
        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
                const int8_t* src_ic_0_3 =
                        src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step;
                src[0] = vld_dup_tbl_s32(src_ic_0_3 + 0 * 4, idx);
                src[1] = vld_dup_tbl_s32(src_ic_0_3 + 1 * 4, idx);
                src[2] = vld_dup_tbl_s32(src_ic_0_3 + 2 * 4, idx);
                src[3] = vld_dup_tbl_s32(src_ic_0_3 + 3 * 4, idx);
                src[4] = vld_dup_tbl_s32(src_ic_0_3 + 4 * 4, idx);
                src[5] = vld_dup_tbl_s32(src_ic_0_3 + 5 * 4, idx);
                src[6] = vld_dup_tbl_s32(src_ic_0_3 + 6 * 4, idx);
                src[7] = vld_dup_tbl_s32(src_ic_0_3 + 7 * 4, idx);
                src[8] = vld_dup_tbl_s32(src_ic_0_3 + 8 * 4, idx);
                src[9] = vld_dup_tbl_s32(src_ic_0_3 + 9 * 4, idx);
                const int8_t* read_weight_ptr =
                        weight_ptr + fh_idx * fw * ld_weight_ic4;
                weight[0] = vld1q_s8(read_weight_ptr);
                weight[1] = vld1q_s8(read_weight_ptr + 16);
                weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
                weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
                weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
// One 5-tap row contribution for one output column: widening int8 MACs in two
// parallel chains (tmp0/tmp1), combined and horizontally reduced into the
// int16x4 accumulator _c.
#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _w0, _w1, _w2, _w3, \
                        _w4, _c)                                               \
    do {                                                                       \
        int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0));       \
        int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w0));     \
        tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1));           \
        tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w1));         \
        tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w2));           \
        tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w2));         \
        tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w3));           \
        tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w3));         \
        tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w4));           \
        tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w4));         \
        tmp0 = vaddq_s16(tmp0, tmp1);                                          \
        _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0)));  \
    } while (0);
                CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4],
                                weight[0], weight[1], weight[2], weight[3],
                                weight[4], c[0][0]);
                CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5],
                                weight[0], weight[1], weight[2], weight[3],
                                weight[4], c[0][1]);
                CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6],
                                weight[0], weight[1], weight[2], weight[3],
                                weight[4], c[0][2]);
                CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7],
                                weight[0], weight[1], weight[2], weight[3],
                                weight[4], c[0][3]);
                CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8],
                                weight[0], weight[1], weight[2], weight[3],
                                weight[4], c[0][4]);
                CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9],
                                weight[0], weight[1], weight[2], weight[3],
                                weight[4], c[0][5]);
                // recycle src[0]/src[1] for input columns 10/11
                src[0] = vld_dup_tbl_s32(src_ic_0_3 + 10 * 4, idx);
                src[1] = vld_dup_tbl_s32(src_ic_0_3 + 11 * 4, idx);
                CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0],
                                weight[0], weight[1], weight[2], weight[3],
                                weight[4], c[0][6]);
                CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1],
                                weight[0], weight[1], weight[2], weight[3],
                                weight[4], c[0][7]);
            }
            weight_ptr += fh * fw * ld_weight_ic4;
        }
        STORE_1_LINE_RESULT();
    }
};
#undef CALC_ONE_RESULT
// 7x7 stride-1 specialization: 4 output channels x 8 output columns per call.
// src[0..9] is recycled as a ring buffer: columns 10..13 overwrite
// src[0..3] once the earlier outputs no longer need them.
template <BiasMode bias_mode>
struct KerNeonDirectStride1Int8<bias_mode, 7> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih,
                     int iw, int remain_w, int /*ld_dst_oc*/) {
        constexpr int filter_size = 7;
        constexpr int fh = filter_size;
        constexpr int fw = filter_size;
        constexpr int ic_step = 4;
        constexpr int loop_ic_step = 4;
        constexpr int ld_weight_ic4 = 16;
        const int ic_stride = ih * iw;
        static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
                                               2, 2, 2, 2, 3, 3, 3, 3};
        static uint8x16_t idx = vld1q_u8(idx_buffer);
        int16x4_t c[1][8];     // 8 output-column accumulators
        int8x16_t weight[7];   // one filter row: 7 taps
        int8x16_t src[8 + 2];  // 8 outputs need 14 input columns, recycled
        INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
        UNROLL_CALL_RAW(8, cb);
#undef cb
        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
                const int8_t* src_ic_0_3 =
                        src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step;
                src[0] = vld_dup_tbl_s32(src_ic_0_3 + 0 * 4, idx);
                src[1] = vld_dup_tbl_s32(src_ic_0_3 + 1 * 4, idx);
                src[2] = vld_dup_tbl_s32(src_ic_0_3 + 2 * 4, idx);
                src[3] = vld_dup_tbl_s32(src_ic_0_3 + 3 * 4, idx);
                src[4] = vld_dup_tbl_s32(src_ic_0_3 + 4 * 4, idx);
                src[5] = vld_dup_tbl_s32(src_ic_0_3 + 5 * 4, idx);
                src[6] = vld_dup_tbl_s32(src_ic_0_3 + 6 * 4, idx);
                src[7] = vld_dup_tbl_s32(src_ic_0_3 + 7 * 4, idx);
                src[8] = vld_dup_tbl_s32(src_ic_0_3 + 8 * 4, idx);
                src[9] = vld_dup_tbl_s32(src_ic_0_3 + 9 * 4, idx);
                const int8_t* read_weight_ptr =
                        weight_ptr + fh_idx * fw * ld_weight_ic4;
                weight[0] = vld1q_s8(read_weight_ptr);
                weight[1] = vld1q_s8(read_weight_ptr + 16);
                weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
                weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
                weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
                weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
                weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);
// One 7-tap row contribution for one output column: widening int8 MACs in
// four parallel chains (tmp0..tmp3) to expose ILP, combined and horizontally
// reduced into the int16x4 accumulator _c.
#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _src5, _src6, _w,  \
                        _c)                                                   \
    do {                                                                      \
        int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0]));    \
        int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w[0]));  \
        int16x8_t tmp2 = vmull_s8(vget_low_s8(_src1), vget_low_s8(_w[1]));    \
        int16x8_t tmp3 = vmull_s8(vget_high_s8(_src1), vget_high_s8(_w[1]));  \
        tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2]));        \
        tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w[2]));      \
        tmp2 = vmlal_s8(tmp2, vget_low_s8(_src3), vget_low_s8(_w[3]));        \
        tmp3 = vmlal_s8(tmp3, vget_high_s8(_src3), vget_high_s8(_w[3]));      \
        tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4]));        \
        tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w[4]));      \
        tmp2 = vmlal_s8(tmp2, vget_low_s8(_src5), vget_low_s8(_w[5]));        \
        tmp3 = vmlal_s8(tmp3, vget_high_s8(_src5), vget_high_s8(_w[5]));      \
        tmp0 = vmlal_s8(tmp0, vget_low_s8(_src6), vget_low_s8(_w[6]));        \
        tmp1 = vmlal_s8(tmp1, vget_high_s8(_src6), vget_high_s8(_w[6]));      \
        tmp0 = vaddq_s16(tmp0, tmp1);                                         \
        tmp2 = vaddq_s16(tmp2, tmp3);                                         \
        tmp0 = vaddq_s16(tmp0, tmp2);                                         \
        _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \
    } while (0);
                CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], src[5],
                                src[6], weight, c[0][0]);
                CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], src[6],
                                src[7], weight, c[0][1]);
                // recycle src[0]/src[1] for input columns 10/11
                src[0] = vld_dup_tbl_s32(src_ic_0_3 + 10 * 4, idx);
                src[1] = vld_dup_tbl_s32(src_ic_0_3 + 11 * 4, idx);
                CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], src[7],
                                src[8], weight, c[0][2]);
                CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], src[8],
                                src[9], weight, c[0][3]);
                // recycle src[2]/src[3] for input columns 12/13
                src[2] = vld_dup_tbl_s32(src_ic_0_3 + 12 * 4, idx);
                src[3] = vld_dup_tbl_s32(src_ic_0_3 + 13 * 4, idx);
                CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], src[9],
                                src[0], weight, c[0][4]);
                CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], src[0],
                                src[1], weight, c[0][5]);
                CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], src[1],
                                src[2], weight, c[0][6]);
                CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], src[2],
                                src[3], weight, c[0][7]);
            }
            weight_ptr += fh * fw * ld_weight_ic4;
        }
        STORE_1_LINE_RESULT();
    }
};
#undef CALC_ONE_RESULT
template <BiasMode bias_mode> | |||||
void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, | |||||
const int8_t* filter, | |||||
const int16_t* bias, int16_t* dst, | |||||
const size_t oc, const size_t ic, | |||||
const size_t ih, const size_t iw, | |||||
const size_t oh, const size_t ow) { | |||||
constexpr size_t filter_size = 2; | |||||
constexpr size_t fh = filter_size; | |||||
constexpr size_t fw = filter_size; | |||||
constexpr size_t ic_step = 4; | |||||
constexpr size_t oc_step = 4; | |||||
constexpr size_t big_oc_step = 8; | |||||
constexpr size_t oh_step = 1; | |||||
constexpr size_t ow_step = 4; | |||||
const size_t img_stride = oh * ow; | |||||
const size_t ow_end = ow / ow_step * ow_step; | |||||
const size_t ow_remain = ow - ow_end; | |||||
const size_t oc_end = oc / big_oc_step * big_oc_step; | |||||
const size_t oc_remain = oc - oc_end; | |||||
const int ld_oc = oh * ow * oc_step; | |||||
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
size_t oh_idx = 0; | |||||
for (; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
ker_neon_dirctconv_2x2s1_oc8_ow4<bias_mode, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ow_step, ld_oc); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = (oh_idx * iw + ow_end) * ic_step; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
ker_neon_dirctconv_2x2s1_oc8_ow4<bias_mode, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ow_remain, ld_oc); | |||||
} | |||||
} | |||||
} | |||||
if (oc_remain > 0) { | |||||
const size_t oc_idx = oc_end; | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
ker_neon_dirctconv_2x2s1_oc4_ow4<bias_mode, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ow_step, ld_oc); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = (oh_idx * iw + ow_end) * ic_step; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
ker_neon_dirctconv_2x2s1_oc4_ow4<bias_mode, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ow_remain, ld_oc); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
template <BiasMode bias_mode> | |||||
void conv_direct_stride1_3x3_int8x8x16_oh2_nchw44( | |||||
const int8_t* src, const int8_t* filter, const int16_t* bias, | |||||
int16_t* dst, const size_t oc, const size_t ic, const size_t ih, | |||||
const size_t iw, const size_t oh, const size_t ow) { | |||||
constexpr size_t filter_size = 3; | |||||
constexpr size_t fh = filter_size; | |||||
constexpr size_t fw = filter_size; | |||||
constexpr size_t ic_step = 4; | |||||
constexpr size_t oc_step = 4; | |||||
constexpr size_t big_oc_step = 4; | |||||
constexpr size_t oh_step = 1; | |||||
constexpr size_t ow_step = 4; | |||||
const size_t img_stride = oh * ow; | |||||
const size_t ow_end = ow / ow_step * ow_step; | |||||
const size_t ow_remain = ow - ow_end; | |||||
const size_t oc_end = oc / big_oc_step * big_oc_step; | |||||
const int ld_oc = oh * ow * oc_step; | |||||
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
size_t oh_idx = 0; | |||||
for (; oh_idx + 1 < oh; oh_idx += 2) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
ker_neon_dirctconv_3x3s1_oc4_ow4_oh2<bias_mode, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ow_step, ld_oc, | |||||
ow * oc_step); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = (oh_idx * iw + ow_end) * ic_step; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
ker_neon_dirctconv_3x3s1_oc4_ow4_oh2<bias_mode, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ow_remain, ld_oc, | |||||
ow * oc_step); | |||||
} | |||||
} | |||||
for (; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
ker_neon_dirctconv_3x3s1_oc4_ow4<bias_mode, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ow_step, ld_oc); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = (oh_idx * iw + ow_end) * ic_step; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
ker_neon_dirctconv_3x3s1_oc4_ow4<bias_mode, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ow_remain, ld_oc); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
template <BiasMode bias_mode, int filter_size> | |||||
void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, | |||||
const int8_t* filter, | |||||
const int16_t* bias, int16_t* dst, | |||||
const size_t oc, const size_t ic, | |||||
const size_t ih, const size_t iw, | |||||
const size_t oh, const size_t ow) { | |||||
constexpr size_t fh = filter_size; | |||||
constexpr size_t fw = filter_size; | |||||
constexpr size_t ic_step = 4; | |||||
constexpr size_t oc_step = 4; | |||||
constexpr size_t oh_step = 1; | |||||
constexpr size_t ow_step = 8; | |||||
const size_t img_stride = oh * ow; | |||||
const int ld_dst_oc = oh * ow * oc_step; | |||||
const size_t ow_end = ow / ow_step * ow_step; | |||||
const size_t ow_remain = ow - ow_end; | |||||
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
KerNeonDirectStride1Int8<bias_mode, filter_size>::impl( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ow_step, ld_dst_oc); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = (oh_idx * iw + ow_end) * ic_step; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
KerNeonDirectStride1Int8<bias_mode, filter_size>::impl( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ow_remain, ld_dst_oc); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} // namespace | |||||
namespace int8x8x16_direct_nchw44 {
// Stride-1 entry points of ConvDirectInt8Nchw44Choose. The generic template
// routes through the shared dispatch kernel (used for 5x5/7x7); the 2x2 and
// 3x3 filter sizes get dedicated dispatchers below.
template <BiasMode bias_mode, int filter_size>
struct ConvDirectInt8Nchw44Choose<bias_mode, filter_size, 1> {
    static void impl(const int8_t* src, const int8_t* filter,
                     const int16_t* bias, int16_t* dst, const size_t oc,
                     const size_t ic, const size_t ih, const size_t iw,
                     const size_t oh, const size_t ow) {
        conv_direct_stride1_int8_nchw44_kern<bias_mode, filter_size>(
                src, filter, bias, dst, oc, ic, ih, iw, oh, ow);
    }
};
// 2x2 stride-1: dedicated dispatcher with an 8-channel fast path.
template <BiasMode bias_mode>
struct ConvDirectInt8Nchw44Choose<bias_mode, 2, 1> {
    static void impl(const int8_t* src, const int8_t* filter,
                     const int16_t* bias, int16_t* dst, const size_t oc,
                     const size_t ic, const size_t ih, const size_t iw,
                     const size_t oh, const size_t ow) {
        conv_direct_stride1_2x2_int8_nchw44<bias_mode>(src, filter, bias, dst,
                                                       oc, ic, ih, iw, oh, ow);
    }
};
// 3x3 stride-1: dedicated dispatcher that computes two output rows per call.
template <BiasMode bias_mode>
struct ConvDirectInt8Nchw44Choose<bias_mode, 3, 1> {
    static void impl(const int8_t* src, const int8_t* filter,
                     const int16_t* bias, int16_t* dst, const size_t oc,
                     const size_t ic, const size_t ih, const size_t iw,
                     const size_t oh, const size_t ow) {
        conv_direct_stride1_3x3_int8x8x16_oh2_nchw44<bias_mode>(
                src, filter, bias, dst, oc, ic, ih, iw, oh, ow);
    }
};
// Explicitly instantiate every (bias_mode, filter_size) combination for
// stride 1 so the definitions above are emitted in this translation unit.
#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode) \
    template struct ConvDirectInt8Nchw44Choose<bias_mode, filter_size, stride>;
#define GET_OP_PARAM(stride, filter, bias_mode) \
    DO_CONV_KERN_FUN(stride, filter, bias_mode)
#define GET_BIAS_MODE_PARAM(stride, filter)         \
    GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
    GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
#define DISPATCH_CONV_KERN(stride)  \
    GET_BIAS_MODE_PARAM(stride, 2) \
    GET_BIAS_MODE_PARAM(stride, 3) \
    GET_BIAS_MODE_PARAM(stride, 5) \
    GET_BIAS_MODE_PARAM(stride, 7)
DISPATCH_CONV_KERN(1);
}  // namespace int8x8x16_direct_nchw44
} // namespace arm_common | |||||
} // namespace megdnn | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,854 @@ | |||||
/** | |||||
* \file | |||||
* dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8_direct_nchw44_s1_armv7.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#include "src/common/utils.h" | |||||
#if MEGDNN_ARMV7 | |||||
#include "src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
namespace { | |||||
// Seed value for the per-column accumulators: zero when there is no bias,
// otherwise the 4 per-channel bias values of the current NCHW44 group.
// Expects `bias_mode` and `bias_ptr` in the enclosing scope and declares
// `init_sum` for the kernels' accumulator setup.
#define INIT_SUM()                        \
    int16x4_t init_sum;                   \
    if (bias_mode == BiasMode::NO_BIAS) { \
        init_sum = vdup_n_s16(0);         \
    } else {                              \
        init_sum = vld1_s16(bias_ptr);    \
    }
#define STORE_1_LINE_RESULT() \ | |||||
switch (remain_w) { \ | |||||
case 8: \ | |||||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||||
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ | |||||
vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \ | |||||
break; \ | |||||
case 1: \ | |||||
vst1_s16(dst_ptr, c[0][0]); \ | |||||
break; \ | |||||
case 2: \ | |||||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||||
break; \ | |||||
case 3: \ | |||||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||||
vst1_s16(dst_ptr + 8, c[0][2]); \ | |||||
break; \ | |||||
case 4: \ | |||||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||||
break; \ | |||||
case 5: \ | |||||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||||
vst1_s16(dst_ptr + 16, c[0][4]); \ | |||||
break; \ | |||||
case 6: \ | |||||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||||
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ | |||||
break; \ | |||||
case 7: \ | |||||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||||
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ | |||||
vst1_s16(dst_ptr + 24, c[0][6]); \ | |||||
break; \ | |||||
default: \ | |||||
megdnn_assert(0, "oc 1 error remainw"); \ | |||||
}; | |||||
#define STORE_2_LINE_RESULT() \ | |||||
switch (remain_w) { \ | |||||
case 8: \ | |||||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||||
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ | |||||
vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \ | |||||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||||
vst1q_s16(dst_ptr + ld_dst_oc + 8, \ | |||||
vcombine_s16(c[1][2], c[1][3])); \ | |||||
vst1q_s16(dst_ptr + ld_dst_oc + 16, \ | |||||
vcombine_s16(c[1][4], c[1][5])); \ | |||||
vst1q_s16(dst_ptr + ld_dst_oc + 24, \ | |||||
vcombine_s16(c[1][6], c[1][7])); \ | |||||
break; \ | |||||
case 1: \ | |||||
vst1_s16(dst_ptr, c[0][0]); \ | |||||
vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \ | |||||
break; \ | |||||
case 2: \ | |||||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||||
break; \ | |||||
case 3: \ | |||||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||||
vst1_s16(dst_ptr + 8, c[0][2]); \ | |||||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||||
vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \ | |||||
break; \ | |||||
case 4: \ | |||||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||||
vst1q_s16(dst_ptr + ld_dst_oc + 8, \ | |||||
vcombine_s16(c[1][2], c[1][3])); \ | |||||
break; \ | |||||
case 5: \ | |||||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||||
vst1_s16(dst_ptr + 16, c[0][4]); \ | |||||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||||
vst1q_s16(dst_ptr + ld_dst_oc + 8, \ | |||||
vcombine_s16(c[1][2], c[1][3])); \ | |||||
vst1_s16(dst_ptr + ld_dst_oc + 16, c[1][4]); \ | |||||
break; \ | |||||
case 6: \ | |||||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||||
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ | |||||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||||
vst1q_s16(dst_ptr + ld_dst_oc + 8, \ | |||||
vcombine_s16(c[1][2], c[1][3])); \ | |||||
vst1q_s16(dst_ptr + ld_dst_oc + 16, \ | |||||
vcombine_s16(c[1][4], c[1][5])); \ | |||||
break; \ | |||||
case 7: \ | |||||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||||
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ | |||||
vst1_s16(dst_ptr + 24, c[0][6]); \ | |||||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||||
vst1q_s16(dst_ptr + ld_dst_oc + 8, \ | |||||
vcombine_s16(c[1][2], c[1][3])); \ | |||||
vst1q_s16(dst_ptr + ld_dst_oc + 16, \ | |||||
vcombine_s16(c[1][4], c[1][5])); \ | |||||
vst1_s16(dst_ptr + ld_dst_oc + 24, c[1][6]); \ | |||||
break; \ | |||||
default: \ | |||||
megdnn_assert(0, "oc 2 error remainw"); \ | |||||
break; \ | |||||
} | |||||
template <BiasMode bias_mode, int remain_w, int filter_size> | |||||
static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr, | |||||
const int8_t* weight_ptr, | |||||
const int16_t* bias_ptr, | |||||
int16_t* dst_ptr, int ic, int ih, | |||||
int iw, int ld_dst_oc) { | |||||
constexpr int fh = filter_size; | |||||
constexpr int fw = filter_size; | |||||
constexpr int ic_step = 4; | |||||
constexpr int oc_step = 4; | |||||
constexpr int loop_ic_step = 4; | |||||
constexpr int ld_weight_ic4 = 16; | |||||
constexpr int src_expand_size = 4; | |||||
const int ic_stride = ih * iw * src_expand_size; | |||||
const int ld_weight_oc4 = oc_step * fh * fw * ic; | |||||
int16x4_t c[2][8]; | |||||
int8x16_t weight[2][2]; | |||||
int8x16_t src[4]; | |||||
INIT_SUM(); | |||||
#define cb(_i) \ | |||||
c[0][_i] = init_sum; \ | |||||
c[1][_i] = init_sum; | |||||
UNROLL_CALL_RAW(8, cb); | |||||
#undef cb | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
const int8_t* src_row0 = src_ptr + ic_idx * ic_stride + | |||||
0 * iw * ic_step * src_expand_size; | |||||
const int8_t* src_row1 = src_ptr + ic_idx * ic_stride + | |||||
1 * iw * ic_step * src_expand_size; | |||||
src[0] = vld1q_s8(src_row0); | |||||
src[1] = vld1q_s8(src_row0 + 16); | |||||
weight[0][0] = vld1q_s8(weight_ptr); | |||||
weight[0][1] = vld1q_s8(weight_ptr + 16); | |||||
weight[1][0] = vld1q_s8(weight_ptr + ld_weight_oc4); | |||||
weight[1][1] = vld1q_s8(weight_ptr + ld_weight_oc4 + 16); | |||||
#define CALC_ONE_RESULT(_src0, _src1, _w0, _w1, _c) \ | |||||
do { \ | |||||
tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w0)); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w1)); \ | |||||
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ | |||||
} while (0); | |||||
int16x8_t tmp0; | |||||
src[2] = vld1q_s8(src_row0 + 2 * 16); | |||||
src[3] = vld1q_s8(src_row0 + 3 * 16); | |||||
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][0]); | |||||
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][0]); | |||||
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][1]); | |||||
CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][1]); | |||||
src[0] = vld1q_s8(src_row0 + 4 * 16); | |||||
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][2]); | |||||
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][2]); | |||||
src[1] = vld1q_s8(src_row0 + 5 * 16); | |||||
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][3]); | |||||
CALC_ONE_RESULT(src[3], src[0], weight[1][0], weight[1][1], c[1][3]); | |||||
src[2] = vld1q_s8(src_row0 + 6 * 16); | |||||
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][4]); | |||||
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][4]); | |||||
src[3] = vld1q_s8(src_row0 + 7 * 16); | |||||
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][5]); | |||||
CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][5]); | |||||
src[0] = vld1q_s8(src_row0 + 8 * 16); | |||||
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][6]); | |||||
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][6]); | |||||
src[1] = vld1q_s8(src_row1 + 0 * 16); | |||||
src[2] = vld1q_s8(src_row1 + 1 * 16); | |||||
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][7]); | |||||
CALC_ONE_RESULT(src[3], src[0], weight[1][0], weight[1][1], c[1][7]); | |||||
weight[0][0] = vld1q_s8(weight_ptr + 32); | |||||
weight[0][1] = vld1q_s8(weight_ptr + 48); | |||||
src[3] = vld1q_s8(src_row1 + 2 * 16); | |||||
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][0]); | |||||
weight[1][0] = vld1q_s8(weight_ptr + ld_weight_oc4 + 32); | |||||
weight[1][1] = vld1q_s8(weight_ptr + ld_weight_oc4 + 48); | |||||
src[0] = vld1q_s8(src_row1 + 3 * 16); | |||||
CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][0]); | |||||
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][1]); | |||||
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][1]); | |||||
src[1] = vld1q_s8(src_row1 + 4 * 16); | |||||
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][2]); | |||||
CALC_ONE_RESULT(src[3], src[0], weight[1][0], weight[1][1], c[1][2]); | |||||
src[2] = vld1q_s8(src_row1 + 5 * 16); | |||||
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][3]); | |||||
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][3]); | |||||
src[3] = vld1q_s8(src_row1 + 6 * 16); | |||||
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][4]); | |||||
CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][4]); | |||||
src[0] = vld1q_s8(src_row1 + 7 * 16); | |||||
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][5]); | |||||
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][5]); | |||||
src[1] = vld1q_s8(src_row1 + 8 * 16); | |||||
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][6]); | |||||
CALC_ONE_RESULT(src[3], src[0], weight[1][0], weight[1][1], c[1][6]); | |||||
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][7]); | |||||
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][7]); | |||||
weight_ptr += fh * fw * ld_weight_ic4; | |||||
} | |||||
STORE_2_LINE_RESULT(); | |||||
} | |||||
template <BiasMode bias_mode, int remain_w, int filter_size> | |||||
static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, | |||||
const int8_t* weight_ptr, | |||||
const int16_t* bias_ptr, | |||||
int16_t* dst_ptr, int ic, int ih, | |||||
int iw, int /*ld_dst_oc*/) { | |||||
constexpr int fh = filter_size; | |||||
constexpr int fw = filter_size; | |||||
constexpr int ic_step = 4; | |||||
constexpr int loop_ic_step = 4; | |||||
constexpr int ld_weight_ic4 = 16; | |||||
constexpr int src_expand_size = 4; | |||||
const int ic_stride = ih * iw * src_expand_size; | |||||
int16x4_t c[1][8]; | |||||
int8x16_t weight[1][2]; | |||||
int8x16_t src[4]; | |||||
INIT_SUM(); | |||||
#define cb(_i) c[0][_i] = init_sum; | |||||
UNROLL_CALL_RAW(8, cb); | |||||
#undef cb | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||||
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + | |||||
fh_idx * iw * ic_step * src_expand_size; | |||||
src[0] = vld1q_s8(src_ic_0_3); | |||||
src[1] = vld1q_s8(src_ic_0_3 + 16); | |||||
const int8_t* read_weight_ptr = | |||||
weight_ptr + fh_idx * fw * ld_weight_ic4; | |||||
weight[0][0] = vld1q_s8(read_weight_ptr); | |||||
weight[0][1] = vld1q_s8(read_weight_ptr + 16); | |||||
int16x8_t tmp0; | |||||
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); | |||||
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); | |||||
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], | |||||
c[0][0]); | |||||
src[0] = vld1q_s8(src_ic_0_3 + 4 * 16); | |||||
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], | |||||
c[0][1]); | |||||
src[1] = vld1q_s8(src_ic_0_3 + 5 * 16); | |||||
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], | |||||
c[0][2]); | |||||
src[2] = vld1q_s8(src_ic_0_3 + 6 * 16); | |||||
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], | |||||
c[0][3]); | |||||
src[3] = vld1q_s8(src_ic_0_3 + 7 * 16); | |||||
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], | |||||
c[0][4]); | |||||
src[0] = vld1q_s8(src_ic_0_3 + 8 * 16); | |||||
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], | |||||
c[0][5]); | |||||
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], | |||||
c[0][6]); | |||||
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], | |||||
c[0][7]); | |||||
} | |||||
weight_ptr += fh * fw * ld_weight_ic4; | |||||
} | |||||
STORE_1_LINE_RESULT(); | |||||
} | |||||
#undef CALC_ONE_RESULT | |||||
template <BiasMode bias_mode, int remain_w, int filter_size> | |||||
struct KerNeonDirectStride1Int8 { | |||||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, | |||||
int iw, int ld_dst_oc); | |||||
}; | |||||
template <BiasMode bias_mode, int remain_w> | |||||
struct KerNeonDirectStride1Int8<bias_mode, remain_w, 3> { | |||||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, | |||||
int iw, int /*ld_dst_oc*/) { | |||||
constexpr int filter_size = 3; | |||||
constexpr int fh = filter_size; | |||||
constexpr int fw = filter_size; | |||||
constexpr int ic_step = 4; | |||||
constexpr int loop_ic_step = 4; | |||||
constexpr int ld_weight_ic4 = 16; | |||||
constexpr int src_expand_size = 4; | |||||
const int ic_stride = ih * iw * src_expand_size; | |||||
int16x4_t c[1][8]; | |||||
int8x16_t weight[3]; | |||||
int8x16_t src[5]; | |||||
INIT_SUM(); | |||||
#define cb(_i) c[0][_i] = init_sum; | |||||
UNROLL_CALL_RAW(8, cb); | |||||
#undef cb | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||||
const int8_t* src_ic_0_3 = | |||||
src_ptr + ic_idx * ic_stride + | |||||
fh_idx * iw * ic_step * src_expand_size; | |||||
src[0] = vld1q_s8(src_ic_0_3 + 0 * 16); | |||||
src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); | |||||
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); | |||||
const int8_t* read_weight_ptr = | |||||
weight_ptr + fh_idx * fw * ld_weight_ic4; | |||||
weight[0] = vld1q_s8(read_weight_ptr); | |||||
weight[1] = vld1q_s8(read_weight_ptr + 16); | |||||
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); | |||||
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); | |||||
#define CALC_ONE_RESULT(_src0, _src1, _src2, _w0, _w1, _w2, _c) \ | |||||
do { \ | |||||
tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w0)); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w1)); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w2)); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src2), vget_high_s8(_w2)); \ | |||||
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ | |||||
} while (0); | |||||
int16x8_t tmp0; | |||||
CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], weight[1], | |||||
weight[2], c[0][0]); | |||||
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); | |||||
CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], weight[1], | |||||
weight[2], c[0][1]); | |||||
src[0] = vld1q_s8(src_ic_0_3 + 5 * 16); | |||||
CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], weight[1], | |||||
weight[2], c[0][2]); | |||||
src[1] = vld1q_s8(src_ic_0_3 + 6 * 16); | |||||
CALC_ONE_RESULT(src[3], src[4], src[0], weight[0], weight[1], | |||||
weight[2], c[0][3]); | |||||
src[2] = vld1q_s8(src_ic_0_3 + 7 * 16); | |||||
CALC_ONE_RESULT(src[4], src[0], src[1], weight[0], weight[1], | |||||
weight[2], c[0][4]); | |||||
src[3] = vld1q_s8(src_ic_0_3 + 8 * 16); | |||||
CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], weight[1], | |||||
weight[2], c[0][5]); | |||||
src[4] = vld1q_s8(src_ic_0_3 + 9 * 16); | |||||
CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], weight[1], | |||||
weight[2], c[0][6]); | |||||
CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], weight[1], | |||||
weight[2], c[0][7]); | |||||
} | |||||
weight_ptr += fh * fw * ld_weight_ic4; | |||||
} | |||||
STORE_1_LINE_RESULT(); | |||||
} | |||||
}; | |||||
#undef CALC_ONE_RESULT | |||||
template <BiasMode bias_mode, int remain_w> | |||||
struct KerNeonDirectStride1Int8<bias_mode, remain_w, 5> { | |||||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, | |||||
int iw, int /*ld_dst_oc*/) { | |||||
constexpr int filter_size = 5; | |||||
constexpr int fh = filter_size; | |||||
constexpr int fw = filter_size; | |||||
constexpr int ic_step = 4; | |||||
constexpr int loop_ic_step = 4; | |||||
constexpr int ld_weight_ic4 = 16; | |||||
constexpr int src_expand_size = 4; | |||||
const int ic_stride = ih * iw * src_expand_size; | |||||
int16x4_t c[1][8]; | |||||
int8x16_t weight[5]; | |||||
int8x16_t src[8 + 2]; | |||||
INIT_SUM(); | |||||
#define cb(_i) c[0][_i] = init_sum; | |||||
UNROLL_CALL_RAW(8, cb); | |||||
#undef cb | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||||
const int8_t* src_ic_0_3 = | |||||
src_ptr + ic_idx * ic_stride + | |||||
fh_idx * iw * ic_step * src_expand_size; | |||||
src[0] = vld1q_s8(src_ic_0_3 + 0 * 16); | |||||
src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); | |||||
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); | |||||
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); | |||||
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); | |||||
src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); | |||||
src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); | |||||
src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); | |||||
src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); | |||||
src[9] = vld1q_s8(src_ic_0_3 + 9 * 16); | |||||
const int8_t* read_weight_ptr = | |||||
weight_ptr + fh_idx * fw * ld_weight_ic4; | |||||
weight[0] = vld1q_s8(read_weight_ptr); | |||||
weight[1] = vld1q_s8(read_weight_ptr + 16); | |||||
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); | |||||
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); | |||||
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); | |||||
#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _w0, _w1, _w2, _w3, \ | |||||
_w4, _c) \ | |||||
do { \ | |||||
int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ | |||||
int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w0)); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ | |||||
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w1)); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w2)); \ | |||||
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w2)); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w3)); \ | |||||
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w3)); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w4)); \ | |||||
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w4)); \ | |||||
tmp0 = vaddq_s16(tmp0, tmp1); \ | |||||
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ | |||||
} while (0); | |||||
CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], | |||||
weight[0], weight[1], weight[2], weight[3], | |||||
weight[4], c[0][0]); | |||||
CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], | |||||
weight[0], weight[1], weight[2], weight[3], | |||||
weight[4], c[0][1]); | |||||
CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], | |||||
weight[0], weight[1], weight[2], weight[3], | |||||
weight[4], c[0][2]); | |||||
CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], | |||||
weight[0], weight[1], weight[2], weight[3], | |||||
weight[4], c[0][3]); | |||||
CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], | |||||
weight[0], weight[1], weight[2], weight[3], | |||||
weight[4], c[0][4]); | |||||
CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], | |||||
weight[0], weight[1], weight[2], weight[3], | |||||
weight[4], c[0][5]); | |||||
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); | |||||
src[1] = vld1q_s8(src_ic_0_3 + 11 * 16); | |||||
CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], | |||||
weight[0], weight[1], weight[2], weight[3], | |||||
weight[4], c[0][6]); | |||||
CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], | |||||
weight[0], weight[1], weight[2], weight[3], | |||||
weight[4], c[0][7]); | |||||
} | |||||
weight_ptr += fh * fw * ld_weight_ic4; | |||||
} | |||||
STORE_1_LINE_RESULT(); | |||||
} | |||||
}; | |||||
#undef CALC_ONE_RESULT | |||||
template <BiasMode bias_mode, int remain_w> | |||||
struct KerNeonDirectStride1Int8<bias_mode, remain_w, 7> { | |||||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, | |||||
int iw, int /*ld_dst_oc*/) { | |||||
constexpr int filter_size = 7; | |||||
constexpr int fh = filter_size; | |||||
constexpr int fw = filter_size; | |||||
constexpr int ic_step = 4; | |||||
constexpr int loop_ic_step = 4; | |||||
constexpr int ld_weight_ic4 = 16; | |||||
constexpr int src_expand_size = 4; | |||||
const int ic_stride = ih * iw * src_expand_size; | |||||
int16x4_t c[1][8]; | |||||
int8x16_t weight[7]; | |||||
int8x16_t src[8 + 2]; | |||||
INIT_SUM(); | |||||
#define cb(_i) c[0][_i] = init_sum; | |||||
UNROLL_CALL_RAW(8, cb); | |||||
#undef cb | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||||
const int8_t* src_ic_0_3 = | |||||
src_ptr + ic_idx * ic_stride + | |||||
fh_idx * iw * ic_step * src_expand_size; | |||||
src[0] = vld1q_s8(src_ic_0_3 + 0 * 16); | |||||
src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); | |||||
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); | |||||
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); | |||||
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); | |||||
src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); | |||||
src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); | |||||
src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); | |||||
src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); | |||||
src[9] = vld1q_s8(src_ic_0_3 + 9 * 16); | |||||
const int8_t* read_weight_ptr = | |||||
weight_ptr + fh_idx * fw * ld_weight_ic4; | |||||
weight[0] = vld1q_s8(read_weight_ptr); | |||||
weight[1] = vld1q_s8(read_weight_ptr + 16); | |||||
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); | |||||
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); | |||||
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); | |||||
weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); | |||||
weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); | |||||
#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _src5, _src6, _w, \ | |||||
_c) \ | |||||
do { \ | |||||
tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w[0])); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w[1])); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src2), vget_high_s8(_w[2])); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w[3])); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src3), vget_high_s8(_w[3])); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src4), vget_high_s8(_w[4])); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src5), vget_low_s8(_w[5])); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src5), vget_high_s8(_w[5])); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src6), vget_low_s8(_w[6])); \ | |||||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src6), vget_high_s8(_w[6])); \ | |||||
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ | |||||
} while (0); | |||||
int16x8_t tmp0; | |||||
CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], src[5], | |||||
src[6], weight, c[0][0]); | |||||
CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], src[6], | |||||
src[7], weight, c[0][1]); | |||||
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); | |||||
src[1] = vld1q_s8(src_ic_0_3 + 11 * 16); | |||||
CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], src[7], | |||||
src[8], weight, c[0][2]); | |||||
CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], src[8], | |||||
src[9], weight, c[0][3]); | |||||
src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); | |||||
src[3] = vld1q_s8(src_ic_0_3 + 13 * 16); | |||||
CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], src[9], | |||||
src[0], weight, c[0][4]); | |||||
CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], src[0], | |||||
src[1], weight, c[0][5]); | |||||
CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], src[1], | |||||
src[2], weight, c[0][6]); | |||||
CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], src[2], | |||||
src[3], weight, c[0][7]); | |||||
} | |||||
weight_ptr += fh * fw * ld_weight_ic4; | |||||
} | |||||
STORE_1_LINE_RESULT(); | |||||
} | |||||
}; | |||||
template <BiasMode bias_mode> | |||||
void conv_direct_stride1_2x2_int8_oc8_ow8_nchw44( | |||||
const int8_t* src, const int8_t* filter, const int16_t* bias, | |||||
int16_t* dst, const size_t oc, const size_t ic, const size_t ih, | |||||
const size_t iw, const size_t oh, const size_t ow) { | |||||
constexpr size_t filter_size = 2; | |||||
constexpr size_t fh = filter_size; | |||||
constexpr size_t fw = filter_size; | |||||
constexpr size_t ic_step = 4; | |||||
constexpr size_t oc_step = 4; | |||||
constexpr size_t big_oc_step = 8; | |||||
constexpr size_t oh_step = 1; | |||||
constexpr size_t ow_step = 8; | |||||
constexpr size_t src_expand_size = 4; | |||||
const size_t img_stride = oh * ow; | |||||
const size_t ow_end = ow / ow_step * ow_step; | |||||
const size_t ow_remain = ow - ow_end; | |||||
const size_t oc_end = oc / big_oc_step * big_oc_step; | |||||
const size_t oc_remain = oc - oc_end; | |||||
const int ld_oc = oh * ow * oc_step; | |||||
using remain_fun = | |||||
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
const int16_t* bias_ptr, int16_t* dst_ptr, | |||||
int ic, int ih, int iw, int ld_dst_oc)>; | |||||
remain_fun kern_big_oc_remain = nullptr; | |||||
remain_fun kern_small_oc_remain = nullptr; | |||||
switch (ow_remain) { | |||||
#define cb(step) \ | |||||
case step: \ | |||||
kern_big_oc_remain = ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, step, \ | |||||
filter_size>; \ | |||||
kern_small_oc_remain = \ | |||||
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, step, \ | |||||
filter_size>; \ | |||||
break; | |||||
UNROLL_CALL_RAW(8, cb); | |||||
default: | |||||
megdnn_assert(0, "no remain %zu for kern", ow_remain); | |||||
} | |||||
#undef cb | |||||
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
size_t oh_idx = 0; | |||||
for (; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = | |||||
(oh_idx * iw + ow_idx) * ic_step * src_expand_size; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, ow_step, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ld_oc); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = | |||||
(oh_idx * iw + ow_end) * ic_step * src_expand_size; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
kern_big_oc_remain(src + src_offset, filter + weight_offset, | |||||
bias + oc_idx, dst + dst_offset, ic, ih, iw, | |||||
ld_oc); | |||||
} | |||||
} | |||||
} | |||||
if (oc_remain > 0) { | |||||
const size_t oc_idx = oc_end; | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = | |||||
(oh_idx * iw + ow_idx) * ic_step * src_expand_size; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, ow_step, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ld_oc); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = | |||||
(oh_idx * iw + ow_end) * ic_step * src_expand_size; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
kern_small_oc_remain(src + src_offset, filter + weight_offset, | |||||
bias + oc_idx, dst + dst_offset, ic, ih, | |||||
iw, ld_oc); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
#undef CALC_ONE_RESULT | |||||
template <BiasMode bias_mode, int filter_size> | |||||
void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, | |||||
const int8_t* filter, | |||||
const int16_t* bias, int16_t* dst, | |||||
const size_t oc, const size_t ic, | |||||
const size_t ih, const size_t iw, | |||||
const size_t oh, const size_t ow) { | |||||
constexpr size_t fh = filter_size; | |||||
constexpr size_t fw = filter_size; | |||||
constexpr size_t ic_step = 4; | |||||
constexpr size_t oc_step = 4; | |||||
constexpr size_t oh_step = 1; | |||||
constexpr size_t ow_step = 8; | |||||
constexpr size_t src_expand_size = 4; | |||||
const size_t img_stride = oh * ow; | |||||
const int ld_dst_oc = oh * ow * oc_step; | |||||
const size_t ow_end = ow / ow_step * ow_step; | |||||
const size_t ow_remain = ow - ow_end; | |||||
using remain_fun = | |||||
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
const int16_t* bias_ptr, int16_t* dst_ptr, | |||||
int ic, int ih, int iw, int ld_dst_oc)>; | |||||
remain_fun kern_small_oc_remain = nullptr; | |||||
switch (ow_remain) { | |||||
#define cb(step) \ | |||||
case step: \ | |||||
kern_small_oc_remain = KerNeonDirectStride1Int8<bias_mode, step, \ | |||||
filter_size>::impl; \ | |||||
break; | |||||
UNROLL_CALL_RAW(8, cb); | |||||
default: | |||||
megdnn_assert(0, "no remain %zu for kern", ow_remain); | |||||
} | |||||
#undef cb | |||||
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = | |||||
(oh_idx * iw + ow_idx) * ic_step * src_expand_size; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
KerNeonDirectStride1Int8<bias_mode, ow_step, filter_size>::impl( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ld_dst_oc); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = | |||||
(oh_idx * iw + ow_end) * ic_step * src_expand_size; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
kern_small_oc_remain(src + src_offset, filter + weight_offset, | |||||
bias + oc_idx, dst + dst_offset, ic, ih, | |||||
iw, ld_dst_oc); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} // namespace | |||||
namespace int8x8x16_direct_nchw44 { | |||||
template <BiasMode bias_mode, int filter_size> | |||||
struct ConvDirectInt8Nchw44Choose<bias_mode, filter_size, 1> { | |||||
static void impl(const int8_t* src, const int8_t* filter, | |||||
const int16_t* bias, int16_t* dst, const size_t oc, | |||||
const size_t ic, const size_t ih, const size_t iw, | |||||
const size_t oh, const size_t ow) { | |||||
conv_direct_stride1_int8_nchw44_kern<bias_mode, filter_size>( | |||||
src, filter, bias, dst, oc, ic, ih, iw, oh, ow); | |||||
} | |||||
}; | |||||
template <BiasMode bias_mode> | |||||
struct ConvDirectInt8Nchw44Choose<bias_mode, 2, 1> { | |||||
static void impl(const int8_t* src, const int8_t* filter, | |||||
const int16_t* bias, int16_t* dst, const size_t oc, | |||||
const size_t ic, const size_t ih, const size_t iw, | |||||
const size_t oh, const size_t ow) { | |||||
conv_direct_stride1_2x2_int8_oc8_ow8_nchw44<bias_mode>( | |||||
src, filter, bias, dst, oc, ic, ih, iw, oh, ow); | |||||
} | |||||
}; | |||||
#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode) \ | |||||
template struct ConvDirectInt8Nchw44Choose<bias_mode, filter_size, stride>; | |||||
#define GET_OP_PARAM(stride, filter, bias_mode) \ | |||||
DO_CONV_KERN_FUN(stride, filter, bias_mode) | |||||
#define GET_BIAS_MODE_PARAM(stride, filter) \ | |||||
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ | |||||
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | |||||
#define DISPATCH_CONV_KERN(stride) \ | |||||
GET_BIAS_MODE_PARAM(stride, 2) \ | |||||
GET_BIAS_MODE_PARAM(stride, 3) \ | |||||
GET_BIAS_MODE_PARAM(stride, 5) \ | |||||
GET_BIAS_MODE_PARAM(stride, 7) | |||||
DISPATCH_CONV_KERN(1); | |||||
} // namespace int8x8x16_direct_nchw44 | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -44,6 +44,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
AlgoQU8DirectStride1 qu8_direct_stride1; | AlgoQU8DirectStride1 qu8_direct_stride1; | ||||
AlgoS8DirectStride2 s8_direct_stride2; | AlgoS8DirectStride2 s8_direct_stride2; | ||||
AlgoS8DirectNCHW44 s8_direct_nchw44; | AlgoS8DirectNCHW44 s8_direct_nchw44; | ||||
AlgoS8x8x16DirectNCHW44 s8x8x16_direct_nchw44; | |||||
AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44; | AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44; | ||||
AlgoS8DirectStride1 s8_direct_stride1; | AlgoS8DirectStride1 s8_direct_stride1; | ||||
AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; | AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; | ||||
@@ -94,6 +95,7 @@ public: | |||||
direct_algos.emplace_back(&qu8_direct_stride1); | direct_algos.emplace_back(&qu8_direct_stride1); | ||||
direct_algos.emplace_back(&s8_direct_stride2); | direct_algos.emplace_back(&s8_direct_stride2); | ||||
direct_algos.emplace_back(&s8_direct_nchw44); | direct_algos.emplace_back(&s8_direct_nchw44); | ||||
direct_algos.emplace_back(&s8x8x16_direct_nchw44); | |||||
direct_algos.emplace_back(&s8_direct_nchw_nchw44); | direct_algos.emplace_back(&s8_direct_nchw_nchw44); | ||||
direct_algos.emplace_back(&s8_direct_stride1); | direct_algos.emplace_back(&s8_direct_stride1); | ||||
@@ -39,6 +39,7 @@ private: | |||||
class AlgoS8DirectStride1; | class AlgoS8DirectStride1; | ||||
class AlgoS8DirectStride2; | class AlgoS8DirectStride2; | ||||
class AlgoS8DirectNCHW44; | class AlgoS8DirectNCHW44; | ||||
class AlgoS8x8x16DirectNCHW44; | |||||
class AlgoS8DirectNCHWNCHW44; | class AlgoS8DirectNCHWNCHW44; | ||||
class AlgoQU8DirectStride1; | class AlgoQU8DirectStride1; | ||||
class AlgoQU8DirectStride2; | class AlgoQU8DirectStride2; | ||||
@@ -518,6 +518,116 @@ void benchmark_im2col_single_algo(const char* im2col_name, Handle* handle, | |||||
} | } | ||||
} | } | ||||
void benchmark_nchw44_8x8x16_vs_8x8x32(const char* im2col_name, Handle* handle, | |||||
size_t kernel, size_t stride, | |||||
size_t pack_size = 1) { | |||||
megdnn_assert(stride == 1 || stride == 2, "only support stride 1 or 2"); | |||||
std::vector<conv_bias::TestArg> args; | |||||
auto pack = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, | |||||
size_t p) { | |||||
if (ic % pack_size != 0 || oc % pack_size != 0) | |||||
return; | |||||
if (w + 2 * p < kernel || h + 2 * p < kernel) | |||||
return; | |||||
param::ConvBias param; | |||||
param.format = param::ConvBias::Format::NCHW44; | |||||
param.stride_h = stride; | |||||
param.stride_w = stride; | |||||
param.pad_h = p; | |||||
param.pad_w = p; | |||||
param.sparse = param::ConvBias::Sparse::DENSE; | |||||
args.push_back(conv_bias::TestArg{ | |||||
param, | |||||
TensorShape{1, ic / 4, h, w, 4}, | |||||
TensorShape{oc / 4, ic / 4, kernel, kernel, 4, 4}, | |||||
{1, oc / 4, 1, 1, 4}}); | |||||
}; | |||||
pack(1, 64, 56, 56, kernel, 0); | |||||
pack(8, 64, 56, 56, kernel, 0); | |||||
pack(16, 64, 56, 56, kernel, 1); | |||||
pack(32, 64, 56, 56, kernel, 1); | |||||
pack(1, 64, 100, 100, kernel, 1); | |||||
pack(8, 64, 100, 100, kernel, 1); | |||||
pack(1, 64, 100, 100, kernel, 0); | |||||
pack(8, 64, 100, 100, kernel, 0); | |||||
pack(16, 64, 100, 100, kernel, 1); | |||||
pack(32, 64, 100, 100, kernel, 1); | |||||
pack(64, 64, 100, 100, kernel, 1); | |||||
pack(128, 64, 100, 100, kernel, 1); | |||||
pack(256, 64, 100, 100, kernel, 1); | |||||
pack(512, 64, 100, 100, kernel, 1); | |||||
pack(1024, 64, 100, 100, kernel, 1); | |||||
pack(1, 32, 200, 200, kernel, 1); | |||||
pack(8, 64, 200, 200, kernel, 1); | |||||
pack(1, 32, 200, 200, kernel, 0); | |||||
pack(8, 64, 200, 200, kernel, 0); | |||||
pack(16, 96, 200, 200, kernel, 1); | |||||
pack(32, 32, 200, 200, kernel, 1); | |||||
pack(64, 64, 200, 200, kernel, 1); | |||||
pack(128, 96, 200, 200, kernel, 1); | |||||
pack(1, 64, 10, 10, kernel, 1); | |||||
pack(8, 64, 10, 10, kernel, 1); | |||||
pack(16, 64, 10, 10, kernel, 1); | |||||
pack(32, 64, 10, 10, kernel, 1); | |||||
pack(64, 64, 10, 10, kernel, 1); | |||||
pack(128, 64, 10, 10, kernel, 1); | |||||
pack(256, 64, 10, 10, kernel, 1); | |||||
pack(512, 64, 10, 10, kernel, 1); | |||||
pack(1024, 64, 10, 10, kernel, 1); | |||||
using namespace conv_bias; | |||||
constexpr size_t RUN = 20; | |||||
Benchmarker<ConvBias> benchmark_im2col(handle); | |||||
benchmark_im2col.set_display(false); | |||||
benchmark_im2col.set_times(RUN); | |||||
Benchmarker<ConvBias> benchmark_8832(handle); | |||||
benchmark_8832.set_display(false); | |||||
benchmark_8832.set_times(RUN); | |||||
for (auto&& arg : args) { | |||||
TensorLayout dst_layout; | |||||
auto opr = handle->create_operator<ConvBias>(); | |||||
opr->param() = arg.param; | |||||
opr->deduce_layout({arg.src, dtype::Float32()}, | |||||
{arg.filter, dtype::Float32()}, | |||||
{arg.bias, dtype::Float32()}, {}, dst_layout); | |||||
//! dst.nr_elems * IC * FH * FW * 2 | |||||
float computations = dst_layout.total_nr_elems() * arg.filter[1] * | |||||
arg.filter[2] * arg.filter[3] * 2.0 * 4 / | |||||
(1024 * 1024 * 1024) * 1e3; | |||||
benchmark_im2col.set_param(arg.param); | |||||
benchmark_im2col.set_dtype(0, dtype::Int8()); | |||||
benchmark_im2col.set_dtype(1, dtype::Int8()); | |||||
benchmark_im2col.set_dtype(2, dtype::Int16()); | |||||
benchmark_im2col.set_dtype(4, dtype::Int16()); | |||||
auto used_8816 = | |||||
algo_benchmark<ConvBias>(benchmark_im2col, | |||||
{arg.src, arg.filter, {}, {}, {}}, | |||||
im2col_name) / | |||||
RUN; | |||||
benchmark_8832.set_param(arg.param); | |||||
benchmark_8832.set_dtype(0, dtype::QuantizedS8(2.5)); | |||||
benchmark_8832.set_dtype(1, dtype::QuantizedS8(2.5)); | |||||
benchmark_8832.set_dtype(2, dtype::QuantizedS32(6.25)); | |||||
benchmark_8832.set_dtype(4, {}); | |||||
auto used_8832 = | |||||
algo_benchmark<ConvBias>(benchmark_8832, | |||||
{arg.src, arg.filter, {}, {}, {}}, | |||||
"S8_NCHW44_DIRECT") / | |||||
RUN; | |||||
printf("%s %s: 8816: %f ms %f GFlops ", arg.src.to_string().c_str(), | |||||
arg.filter.to_string().c_str(), used_8816, | |||||
computations / used_8816); | |||||
printf("%s %s: 8832: %f ms %f GFlops ", arg.src.to_string().c_str(), | |||||
arg.filter.to_string().c_str(), used_8832, | |||||
computations / used_8832); | |||||
printf("speedup %f \n", used_8832 / used_8816); | |||||
} | |||||
} | |||||
void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name, | void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name, | ||||
const char* im2col_name, Handle* handle, | const char* im2col_name, Handle* handle, | ||||
size_t kernel, DType src_type, | size_t kernel, DType src_type, | ||||
@@ -872,6 +982,28 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_MATMUL) { | |||||
#endif | #endif | ||||
#if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_8X8X16_DIRECT_STRIDE1) {
    //! compare the stride-1 int8x8x16 direct kernels against the int8x8x32
    //! direct baseline for every supported filter size
    for (size_t kernel : {2, 3, 5, 7}) {
        benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(),
                                          kernel, 1, 4);
    }
}
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_8X8X16_DIRECT_STRIDE2) {
    //! compare the stride-2 int8x8x16 direct kernels against the int8x8x32
    //! direct baseline for every supported filter size
    for (size_t kernel : {2, 3, 5, 7}) {
        benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(),
                                          kernel, 2, 4);
    }
}
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23) { | TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23) { | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
benchmark_winograd("WINOGRAD:AARCH64_F32:1:2", handle(), 3); | benchmark_winograd("WINOGRAD:AARCH64_F32:1:2", handle(), 3); | ||||
@@ -534,11 +534,25 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) { | |||||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), | get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), | ||||
handle(), "S8_NCHW44_DIRECT"); | handle(), "S8_NCHW44_DIRECT"); | ||||
} | } | ||||
//! correctness of the int8x8x16 NCHW44 direct algorithm at stride 1 for all
//! supported filter sizes
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8816) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true),
            handle(), "S8x8x16_NCHW44_DIRECT");
}
//! correctness of the int8x8x16 NCHW44 direct algorithm at stride 2 for all
//! supported filter sizes
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8816) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true),
            handle(), "S8x8x16_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) { | ||||
checker_conv_bias_qint8x8x32( | checker_conv_bias_qint8x8x32( | ||||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true), | get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true), | ||||
handle(), "S8_NCHW44_DIRECT"); | handle(), "S8_NCHW44_DIRECT"); | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) { | ||||
checker_conv_bias_qint8x8x32( | checker_conv_bias_qint8x8x32( | ||||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true), | get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true), | ||||