2x2 3x3 5x5 7x7 direct conv
GitOrigin-RevId: 3710182af1
tags/v1.0.0-rc1
@@ -38,6 +38,18 @@ public: | |||
const NCBKernSizeParam& param) const override; | |||
}; | |||
class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase { | |||
public: | |||
AlgoS8x8x16DirectNCHW44() {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "S8x8x16_NCHW44_DIRECT"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
}; | |||
class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; | |||
@@ -0,0 +1,481 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/int8x8x16/conv_direct_int8x8x16_nchw44_algo.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/int8x8x16/algos.h" | |||
#include "src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h" | |||
#include "midout.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using conv_fun = std::function<void( | |||
const WorkspaceBundle& bundle, | |||
const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8x8x16_nchw44_direct) | |||
static void get_rectified_size( | |||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, | |||
int& iw2) { | |||
auto&& fm = param.filter_meta; | |||
int ih = param.isz[0]; | |||
int iw = param.isz[1]; | |||
int ph = fm.padding[0]; | |||
int pw = fm.padding[1]; | |||
ih2 = ih + ph * 2; | |||
iw2 = iw + pw * 2; | |||
} | |||
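// ih2/iw2 are simply the input dims grown by the zero padding on both
// sides, e.g. ih = 32 with ph = 1 gives ih2 = 34.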
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||
auto&& fm = param.filter_meta; | |||
size_t group = fm.group; | |||
size_t batch = param.n; | |||
size_t IC = fm.icpg; | |||
int IH2, IW2; | |||
get_rectified_size(param, IH2, IW2); | |||
if (group == 1) { | |||
size_t src_size = 0; | |||
bool need_padding = param.filter_meta.padding[0] > 0 || | |||
param.filter_meta.padding[1] > 0; | |||
src_size = need_padding | |||
? batch * group * IC * IH2 * IW2 * sizeof(int8_t) | |||
: 0; | |||
#if MEGDNN_ARMV7 | |||
if (fm.stride[0] == 1) { | |||
constexpr int src_expand_element = 4; | |||
src_size = batch * group * IC * IH2 * IW2 * sizeof(int8_t) * | |||
src_expand_element; | |||
} | |||
#endif | |||
return {nullptr, {src_size}}; | |||
} else { | |||
size_t src_size = 0; | |||
bool need_padding = param.filter_meta.padding[0] > 0 || | |||
param.filter_meta.padding[1] > 0; | |||
src_size = need_padding | |||
? param.nr_threads * IC * IH2 * IW2 * sizeof(int8_t) | |||
: 0; | |||
#if MEGDNN_ARMV7 | |||
if (fm.stride[0] == 1) { | |||
constexpr int src_expand_element = 4; | |||
src_size = param.nr_threads * IC * IH2 * IW2 * sizeof(int8_t) * | |||
src_expand_element; | |||
} | |||
#endif | |||
return {nullptr, {src_size}}; | |||
} | |||
} | |||
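// The bundle above sizes a single scratch buffer: one padded source image
// per (batch, group) when group == 1, or one per worker thread otherwise,
// expanded 4x on ARMv7 stride-1 where the packed-source path is taken.
// A minimal sketch of the same arithmetic; `PaddedSizeArgs` and
// `padded_src_bytes` are illustrative names, not MegDNN APIs.
struct PaddedSizeArgs {
    size_t copies;        // batch * group, or nr_threads for group > 1
    size_t ic, ih2, iw2;  // channels per group and padded spatial dims
    bool need_padding;    // any nonzero padding on H or W
    bool armv7_stride1;   // packed-source path: always materialized, 4x wider
};
static inline size_t padded_src_bytes(const PaddedSizeArgs& a) {
    size_t bytes = a.need_padding ? a.copies * a.ic * a.ih2 * a.iw2 : 0;
    if (a.armv7_stride1) {
        bytes = a.copies * a.ic * a.ih2 * a.iw2 * 4;  // src_expand_element
    }
    return bytes;  // sizeof(int8_t) == 1
}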
#if MEGDNN_ARMV7 | |||
static void copy_padding_kern(const WorkspaceBundle& bundle, | |||
const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids) { | |||
int IH = kern_param.isz[0]; | |||
int IW = kern_param.isz[1]; | |||
int IC = kern_param.filter_meta.icpg; | |||
int PH = kern_param.filter_meta.padding[0]; | |||
int PW = kern_param.filter_meta.padding[1]; | |||
int GROUP = kern_param.filter_meta.group; | |||
int IH2, IW2; | |||
get_rectified_size(kern_param, IH2, IW2); | |||
int padding_group_size = IH2 * IW2 * IC; | |||
//! used to compute the workspace offset | |||
constexpr int pack_ic = 4; | |||
constexpr int src_expand_element = 4; | |||
size_t workspace_ic_block = 4; | |||
size_t workspace_batch_id = workspace_ids[0]; | |||
size_t workspace_group_id = workspace_ids[1]; | |||
size_t workspace_ic_id = workspace_ids[2]; | |||
size_t workspace_ic = workspace_ic_id * workspace_ic_block; | |||
size_t batch_id = ncb_index.ndrange_id[0]; | |||
size_t group_id = ncb_index.ndrange_id[1]; | |||
size_t group_pack_size = 1; | |||
int nr_pad_w = PW * pack_ic * src_expand_element; | |||
int nr_pad_h = PH * IW2 * pack_ic * src_expand_element; | |||
int row_last_pad = (IW2 - IW - PW) * pack_ic * src_expand_element; | |||
int col_last_pad = (IH2 - IH - PH) * IW2 * pack_ic * src_expand_element; | |||
const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>( | |||
batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); | |||
//! copy to sptr_base to eliminate padding effect | |||
int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) + | |||
(workspace_batch_id * GROUP * padding_group_size + | |||
workspace_group_id * padding_group_size + | |||
workspace_ic * IH2 * IW2) * | |||
src_expand_element; | |||
size_t nr_ic = workspace_ic_block; | |||
if (GROUP > 1) { | |||
nr_ic = IC; | |||
} | |||
rep_step(ic_idx, nr_ic, pack_ic) { | |||
std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t)); | |||
sptr_base += nr_pad_h; | |||
rep(ih_idx, IH) { | |||
std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); | |||
sptr_base += nr_pad_w; | |||
int8x8x16_direct_nchw44::nchw44_pack_src(sptr, sptr_base, IW); | |||
sptr_base += IW * pack_ic * src_expand_element; | |||
sptr += IW * pack_ic; | |||
std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t)); | |||
sptr_base += row_last_pad; | |||
} | |||
std::memset(sptr_base, 0, col_last_pad * sizeof(int8_t)); | |||
sptr_base += col_last_pad; | |||
} | |||
} | |||
#endif | |||
static void copy_padding_kern_no_pack_src(const WorkspaceBundle& bundle, | |||
const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids) { | |||
int IH = kern_param.isz[0]; | |||
int IW = kern_param.isz[1]; | |||
int IC = kern_param.filter_meta.icpg; | |||
int PH = kern_param.filter_meta.padding[0]; | |||
int PW = kern_param.filter_meta.padding[1]; | |||
int GROUP = kern_param.filter_meta.group; | |||
int IH2, IW2; | |||
get_rectified_size(kern_param, IH2, IW2); | |||
int padding_group_size = IH2 * IW2 * IC; | |||
//! used to compute the workspace offset | |||
constexpr int pack_ic = 4; | |||
constexpr int src_expand_element = 1; | |||
size_t workspace_ic_block = 4; | |||
size_t workspace_batch_id = workspace_ids[0]; | |||
size_t workspace_group_id = workspace_ids[1]; | |||
size_t workspace_ic_id = workspace_ids[2]; | |||
size_t workspace_ic = workspace_ic_id * workspace_ic_block; | |||
size_t batch_id = ncb_index.ndrange_id[0]; | |||
size_t group_id = ncb_index.ndrange_id[1]; | |||
size_t group_pack_size = 1; | |||
int nr_pad_w = PW * pack_ic * src_expand_element; | |||
int nr_pad_h = PH * IW2 * pack_ic * src_expand_element; | |||
int row_last_pad = (IW2 - IW - PW) * pack_ic * src_expand_element; | |||
int col_last_pad = (IH2 - IH - PH) * IW2 * pack_ic * src_expand_element; | |||
const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>( | |||
batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); | |||
//! copy to sptr_base to eliminate padding effect | |||
int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) + | |||
(workspace_batch_id * GROUP * padding_group_size + | |||
workspace_group_id * padding_group_size + | |||
workspace_ic * IH2 * IW2) * | |||
src_expand_element; | |||
size_t nr_ic = workspace_ic_block; | |||
if (GROUP > 1) { | |||
nr_ic = IC; | |||
} | |||
rep_step(ic_idx, nr_ic, pack_ic) { | |||
std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t)); | |||
sptr_base += nr_pad_h; | |||
rep(ih_idx, IH) { | |||
std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); | |||
sptr_base += nr_pad_w; | |||
std::memcpy(sptr_base, sptr, IW * pack_ic); | |||
sptr_base += IW * pack_ic * src_expand_element; | |||
sptr += IW * pack_ic; | |||
std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t)); | |||
sptr_base += row_last_pad; | |||
} | |||
std::memset(sptr_base, 0, col_last_pad * sizeof(int8_t)); | |||
sptr_base += col_last_pad; | |||
} | |||
} | |||
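// Both padding kerns above trace the same walk per 4-channel block: zero
// the PH top rows, then for each payload row zero the left border, copy
// (or pack, on the ARMv7 path) the row, zero the right border, and finally
// zero the bottom rows. A scalar sketch of that walk for one ic4 block,
// assuming the NCHW44 layout (4 interleaved channels per pixel):
static inline void pad_one_ic4_block_sketch(const int8_t* src, int8_t* dst,
                                            int ih, int iw, int ih2, int iw2,
                                            int ph, int pw) {
    std::memset(dst, 0, ph * iw2 * 4);           // top border rows
    dst += ph * iw2 * 4;
    for (int h = 0; h < ih; ++h) {
        std::memset(dst, 0, pw * 4);             // left border
        std::memcpy(dst + pw * 4, src, iw * 4);  // payload row
        std::memset(dst + (pw + iw) * 4, 0,
                    (iw2 - iw - pw) * 4);        // right border
        dst += iw2 * 4;
        src += iw * 4;
    }
    std::memset(dst, 0, (ih2 - ih - ph) * iw2 * 4);  // bottom border rows
}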
template <size_t filter, BiasMode bias_mode, int stride> | |||
static void do_conv_kern(const WorkspaceBundle& bundle, | |||
const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids, | |||
const CpuNDRange& ncb_range) { | |||
size_t OH = kern_param.osz[0]; | |||
size_t OW = kern_param.osz[1]; | |||
size_t FH = kern_param.filter_meta.spatial[0]; | |||
size_t FW = kern_param.filter_meta.spatial[1]; | |||
size_t IC = kern_param.filter_meta.icpg; | |||
size_t OC = kern_param.filter_meta.ocpg; | |||
size_t GROUP = kern_param.filter_meta.group; | |||
int IH2, IW2; | |||
get_rectified_size(kern_param, IH2, IW2); | |||
size_t padding_group_size = IH2 * IW2 * IC; | |||
constexpr size_t pack_c = 4; | |||
const size_t workspace_batch_id = workspace_ids[0]; | |||
const size_t workspace_group_id = workspace_ids[1]; | |||
const size_t batch_id = ncb_index.ndrange_id[0]; | |||
const size_t group_id = ncb_index.ndrange_id[1]; | |||
const size_t oc_id = ncb_index.ndrange_id[2]; | |||
const size_t oc_block_num = ncb_range[2]; | |||
megdnn_assert((OC & (pack_c - 1)) == 0, "OC must times of 4"); | |||
size_t nr_pack_per_step = div_ceil(OC / pack_c, oc_block_num); | |||
size_t oc_block = nr_pack_per_step * pack_c; | |||
const size_t oc_idx = oc_id * oc_block; | |||
if (oc_id == (oc_block_num - 1)) { | |||
oc_block = OC - oc_id * nr_pack_per_step * pack_c; | |||
} | |||
megdnn_assert(oc_block % pack_c == 0, | |||
"oc must be devisible by 4, but oc = %zu", oc_block); | |||
bool need_padding = kern_param.filter_meta.padding[0] > 0 || | |||
kern_param.filter_meta.padding[1] > 0; | |||
const int8_t* sptr = need_padding | |||
? static_cast<int8_t*>(bundle.get(0)) + | |||
workspace_batch_id * GROUP * padding_group_size + | |||
workspace_group_id * padding_group_size | |||
: kern_param.src<int8_t>(batch_id, group_id); | |||
//! armv7 uses the pack-src mode | |||
#if MEGDNN_ARMV7 | |||
if (stride == 1) { | |||
constexpr size_t src_expand_size = 4; | |||
sptr = static_cast<int8_t*>(bundle.get(0)) + | |||
workspace_batch_id * GROUP * padding_group_size * | |||
src_expand_size + | |||
workspace_group_id * padding_group_size * src_expand_size; | |||
} | |||
#endif | |||
const int8_t* fptr = | |||
kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC; | |||
int16_t* dst = reinterpret_cast<int16_t*>( | |||
kern_param.dst<void>(batch_id, group_id, oc_idx)); | |||
const int16_t* bptr = | |||
kern_param.bias<dt_int16>(batch_id, group_id) + oc_idx; | |||
int8x8x16_direct_nchw44::ConvDirectInt8Nchw44Choose< | |||
bias_mode, filter, stride>::impl(sptr, fptr, bptr, dst, oc_block, | |||
IC, IH2, IW2, OH, OW); | |||
} | |||
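// Worked example of the oc blocking above (numbers are illustrative):
// OC = 20 (5 packs of 4), oc_block_num = 2
//   nr_pack_per_step = div_ceil(5, 2) = 3  ->  oc_block = 12
//   oc_id 0 covers oc [0, 12); the last block (oc_id 1) takes the
//   remainder, oc_block = OC - 1 * 3 * 4 = 8, covering oc [12, 20).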
bool ConvBiasImpl::AlgoS8x8x16DirectNCHW44::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
MEGDNN_MARK_USED_VAR(algo_selection_strategy); | |||
auto&& fm = param.filter_meta; | |||
const int fh = fm.spatial[0]; | |||
const int fw = fm.spatial[1]; | |||
const int oc = fm.ocpg; | |||
const int ic = fm.icpg; | |||
const bool available = //! src and filter are int8, dst is int16_t | |||
(param.src_type.enumv() == DTypeEnum::Int8 && | |||
param.filter_type.enumv() == DTypeEnum::Int8 && | |||
param.dst_type.enumv() == DTypeEnum::Int16) && | |||
(fm.format == param::Convolution::Format::NCHW44) && | |||
(oc % 4 == 0 && ic % 4 == 0 && oc >= 4) && !fm.should_flip && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] && | |||
(fm.stride[0] == 2 || fm.stride[0] == 1) && fh == fw && | |||
(fh == 2 || fh == 3 || fh == 5 || fh == 7) && | |||
param.nonlineMode == NonlineMode::IDENTITY && | |||
param.bias_mode != BiasMode::BIAS; | |||
return available; | |||
} | |||
size_t ConvBiasImpl::AlgoS8x8x16DirectNCHW44::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoS8x8x16DirectNCHW44::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
auto fm = param.filter_meta; | |||
size_t N = param.n; | |||
size_t IC = fm.icpg; | |||
size_t OC = fm.ocpg; | |||
size_t group = fm.group; | |||
size_t fh = fm.spatial[0]; | |||
size_t fw = fm.spatial[1]; | |||
size_t ph = fm.padding[0]; | |||
size_t pw = fm.padding[1]; | |||
WorkspaceBundle wbundle = get_bundle(param); | |||
conv_fun do_conv_fun = nullptr; | |||
#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode) \ | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8x8x16_nchw44_direct, \ | |||
midout_iv("int8x8x16_nchw44_direct_" \ | |||
"conv" #stride #filter #bias_mode##_hash)) { \ | |||
do_conv_fun = do_conv_kern<filter, bias_mode, stride>; \ | |||
} \ | |||
MIDOUT_END(); | |||
#define GET_OP_PARAM(stride, filter, bias_mode) \ | |||
switch (param.nonlineMode) { \ | |||
case param::ConvBias::NonlineMode::IDENTITY: \ | |||
DO_CONV_KERN_FUN(stride, dt_int16, filter, bias_mode) \ | |||
break; \ | |||
default: \ | |||
megdnn_throw(ssprintf("only support IDENTITY mode when dst is " \ | |||
"dt_int16 nonlineMode is %d", \ | |||
uint32_t(param.nonlineMode)) \ | |||
.c_str()); \ | |||
break; \ | |||
} | |||
#define GET_BIAS_MODE_PARAM(stride, filter) \ | |||
switch (param.bias_mode) { \ | |||
case BiasMode::NO_BIAS: \ | |||
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ | |||
break; \ | |||
case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||
break; \ | |||
default: \ | |||
megdnn_throw(ssprintf("only support NO_BIAS/BROADCAST biasmode " \ | |||
"when dst is " \ | |||
"dt_int16 biasmode is %d", \ | |||
uint32_t(param.bias_mode)) \ | |||
.c_str()); \ | |||
break; \ | |||
} | |||
#define DISPATCH_CONV_KERN(stride) \ | |||
switch (param.filter_meta.spatial[0]) { \ | |||
case 2: \ | |||
GET_BIAS_MODE_PARAM(stride, 2) \ | |||
break; \ | |||
case 3: \ | |||
GET_BIAS_MODE_PARAM(stride, 3) \ | |||
break; \ | |||
case 5: \ | |||
GET_BIAS_MODE_PARAM(stride, 5) \ | |||
break; \ | |||
case 7: \ | |||
GET_BIAS_MODE_PARAM(stride, 7) \ | |||
break; \ | |||
default: \ | |||
megdnn_throw(ssprintf("only support 2x2 3x3 5x5 7x7 filters size " \ | |||
"when dst is " \ | |||
"dt_int16 filter size is %u", \ | |||
uint32_t(param.filter_meta.spatial[0])) \ | |||
.c_str()); \ | |||
break; \ | |||
} | |||
switch (param.filter_meta.stride[0]) { | |||
case 1: | |||
DISPATCH_CONV_KERN(1); | |||
break; | |||
case 2: | |||
DISPATCH_CONV_KERN(2); | |||
break; | |||
default: | |||
megdnn_throw(ssprintf("Unsupport stride size %u for the 8x8x16 direct conv", | |||
param.filter_meta.stride[0]) | |||
.c_str()); | |||
break; | |||
} | |||
#undef DO_CONV_KERN_FUN | |||
#undef GET_OP_PARAM | |||
#undef GET_BIAS_MODE_PARAM | |||
#undef DISPATCH_CONV_KERN | |||
megdnn_assert(do_conv_fun); | |||
SmallVector<ConvBiasImpl::NCBKern> ret_kerns; | |||
constexpr size_t pack_oc = 4; | |||
size_t oc_step = pack_oc; | |||
if (fh == fw && (fh == 2 || fw == 3) && OC >= 8) { | |||
oc_step = 8; | |||
} | |||
#if MEGDNN_ARMV7 | |||
if (param.filter_meta.stride[0] == 1) { | |||
if (group == 1) { | |||
CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)}; | |||
auto copy_padding = [wbundle]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
wbundle.set(kern_param.workspace_ptr); | |||
copy_padding_kern(wbundle, kern_param, ncb_index, | |||
ncb_index.ndrange_id); | |||
}; | |||
constexpr size_t pack_ic = 4; | |||
ret_kerns.push_back( | |||
{copy_padding, {N, group, div_ceil(IC, pack_ic)}}); | |||
auto do_conv = [wbundle, do_conv_fun, ncb_range]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
wbundle.set(kern_param.workspace_ptr); | |||
do_conv_fun(wbundle, kern_param, ncb_index, | |||
ncb_index.ndrange_id, ncb_range); | |||
}; | |||
ret_kerns.push_back({do_conv, ncb_range}); | |||
} else { | |||
CpuNDRange ncb_range = {N, group, 1}; | |||
auto do_conv = [wbundle, do_conv_fun, ncb_range]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
wbundle.set(kern_param.workspace_ptr); | |||
copy_padding_kern(wbundle, kern_param, ncb_index, | |||
{0, ncb_index.thread_id, 0}); | |||
do_conv_fun(wbundle, kern_param, ncb_index, | |||
{0, ncb_index.thread_id, 0}, ncb_range); | |||
}; | |||
ret_kerns.push_back({do_conv, ncb_range}); | |||
} | |||
return ret_kerns; | |||
} | |||
#endif | |||
bool need_padding = ph > 0 || pw > 0; | |||
if (group == 1) { | |||
CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)}; | |||
auto copy_padding = [wbundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
wbundle.set(kern_param.workspace_ptr); | |||
copy_padding_kern_no_pack_src(wbundle, kern_param, ncb_index, | |||
ncb_index.ndrange_id); | |||
}; | |||
constexpr size_t pack_ic = 4; | |||
if (need_padding) { | |||
ret_kerns.push_back( | |||
{copy_padding, {N, group, div_ceil(IC, pack_ic)}}); | |||
} | |||
auto do_conv = [wbundle, do_conv_fun, ncb_range]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
wbundle.set(kern_param.workspace_ptr); | |||
do_conv_fun(wbundle, kern_param, ncb_index, ncb_index.ndrange_id, | |||
ncb_range); | |||
}; | |||
ret_kerns.push_back({do_conv, ncb_range}); | |||
} else { | |||
CpuNDRange ncb_range = {N, group, 1}; | |||
auto do_conv = [wbundle, do_conv_fun, ncb_range, need_padding]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
wbundle.set(kern_param.workspace_ptr); | |||
if (need_padding) { | |||
copy_padding_kern_no_pack_src(wbundle, kern_param, ncb_index, | |||
{0, ncb_index.thread_id, 0}); | |||
} | |||
do_conv_fun(wbundle, kern_param, ncb_index, | |||
{0, ncb_index.thread_id, 0}, ncb_range); | |||
}; | |||
ret_kerns.push_back({do_conv, ncb_range}); | |||
} | |||
return ret_kerns; | |||
} | |||
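// Resulting kern sequence (illustrative): for group == 1, a padding pass
// over {N, group, IC/4} (only when padding is needed, or always on the
// ARMv7 stride-1 packed path) followed by the conv pass over
// {N, group, OC blocks}; for group > 1, one fused kern per (batch, group)
// pads into a per-thread scratch buffer and convolves immediately.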
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,56 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace int8x8x16_direct_nchw44 { | |||
/** | |||
origin src shape <n, ic/4, h, w, 4> | |||
packed src shape <n, ic/4, h, w, 16> | |||
example (each cell <ic> holds the value of input channel ic) | |||
origin | |||
<0> <1> <2> <3> | |||
packed | |||
low 64 bit <0> <0> <0> <0> | <1> <1> <1> <1> | |||
--------------------------------------------------------------------- | |||
high 64 bit <2> <2> <2> <2> | <3> <3> <3> <3> | |||
**/ | |||
static inline void nchw44_pack_src(const int8_t* src, int8_t* dst, int length) { | |||
static const uint8_t src_idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, | |||
2, 2, 2, 2, 3, 3, 3, 3}; | |||
constexpr int pack_ic = 4; | |||
constexpr int simd_len = 16; | |||
uint8x16_t src_idx = vld1q_u8(src_idx_buffer); | |||
for (int i = 0; i < length; i++) { | |||
int8x16_t result = vld_dup_tbl_s32(src + i * pack_ic, src_idx); | |||
vst1q_s8(dst + i * simd_len, result); | |||
} | |||
} | |||
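// Scalar equivalent of nchw44_pack_src, as a reference sketch (not used by
// the kernels): every int8 value of a 4-channel group is duplicated 4x so
// that one 16-byte vector feeds four output channels without further
// shuffling inside the compute loop.
static inline void nchw44_pack_src_scalar(const int8_t* src, int8_t* dst,
                                          int length) {
    for (int i = 0; i < length; ++i) {
        for (int c = 0; c < 4; ++c) {          // 4 input channels
            for (int dup = 0; dup < 4; ++dup) {  // 4 copies of each value
                dst[i * 16 + c * 4 + dup] = src[i * 4 + c];
            }
        }
    }
}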
template <BiasMode bias_mode, int filter_size, int stride> | |||
struct ConvDirectInt8Nchw44Choose { | |||
static void impl(const int8_t* src, const int8_t* filter, | |||
const int16_t* bias, int16_t* dst, const size_t oc, | |||
const size_t ic, const size_t ih, const size_t iw, | |||
const size_t oh, const size_t ow); | |||
}; | |||
} // namespace int8x8x16_direct_nchw44 | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,971 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8_direct_nchw44_s1_aarch64.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#if MEGDNN_AARCH64 | |||
#include "src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace { | |||
#define INIT_SUM() \ | |||
int16x4_t init_sum; \ | |||
if (bias_mode == BiasMode::NO_BIAS) { \ | |||
init_sum = vdup_n_s16(0); \ | |||
} else { \ | |||
init_sum = vld1_s16(bias_ptr); \ | |||
} | |||
#define STORE_1_LINE_RESULT() \ | |||
switch (remain_w) { \ | |||
case 8: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ | |||
vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \ | |||
break; \ | |||
case 1: \ | |||
vst1_s16(dst_ptr, c[0][0]); \ | |||
break; \ | |||
case 2: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
break; \ | |||
case 3: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1_s16(dst_ptr + 8, c[0][2]); \ | |||
break; \ | |||
case 4: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
break; \ | |||
case 5: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
vst1_s16(dst_ptr + 16, c[0][4]); \ | |||
break; \ | |||
case 6: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ | |||
break; \ | |||
case 7: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ | |||
vst1_s16(dst_ptr + 24, c[0][6]); \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0, "oc 1 error remainw"); \ | |||
}; | |||
#define STORE_2_LINE_RESULT_OW4() \ | |||
switch (remain_w) { \ | |||
case 4: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc + 8, \ | |||
vcombine_s16(c[1][2], c[1][3])); \ | |||
break; \ | |||
case 1: \ | |||
vst1_s16(dst_ptr, c[0][0]); \ | |||
vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \ | |||
break; \ | |||
case 2: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||
break; \ | |||
case 3: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1_s16(dst_ptr + 8, c[0][2]); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||
vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0, "oc 2 error remainw"); \ | |||
break; \ | |||
} | |||
#define STORE_1_LINE_RESULT_OW4_OH2() \ | |||
switch (remain_w) { \ | |||
case 4: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
vst1q_s16(dst_ptr + ow, vcombine_s16(c[0][4], c[0][5])); \ | |||
vst1q_s16(dst_ptr + 8 + ow, vcombine_s16(c[0][6], c[0][7])); \ | |||
break; \ | |||
case 1: \ | |||
vst1_s16(dst_ptr, c[0][0]); \ | |||
vst1_s16(dst_ptr + ow, c[0][4]); \ | |||
break; \ | |||
case 2: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + ow, vcombine_s16(c[0][4], c[0][5])); \ | |||
break; \ | |||
case 3: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1_s16(dst_ptr + 8, c[0][2]); \ | |||
vst1q_s16(dst_ptr + ow, vcombine_s16(c[0][4], c[0][5])); \ | |||
vst1_s16(dst_ptr + ow + 8, c[0][6]); \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0, "oc 2 error remainw"); \ | |||
break; \ | |||
} | |||
#define STORE_1_LINE_RESULT_OW4() \ | |||
switch (remain_w) { \ | |||
case 4: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
break; \ | |||
case 1: \ | |||
vst1_s16(dst_ptr, c[0][0]); \ | |||
break; \ | |||
case 2: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
break; \ | |||
case 3: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1_s16(dst_ptr + 8, c[0][2]); \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0, "oc 1 error remainw"); \ | |||
}; | |||
template <BiasMode bias_mode, int filter_size> | |||
static void ker_neon_directconv_2x2s1_oc8_ow4(const int8_t* src_ptr, | |||
const int8_t* weight_ptr, | |||
const int16_t* bias_ptr, | |||
int16_t* dst_ptr, int ic, int ih, | |||
int iw, int remain_w, int ld_dst_oc) { | |||
constexpr int fh = filter_size; | |||
constexpr int fw = filter_size; | |||
constexpr int ic_step = 4; | |||
constexpr int oc_step = 4; | |||
constexpr int loop_ic_step = 4; | |||
constexpr int ld_weight_ic4 = 16; | |||
const int ic_stride = ih * iw; | |||
const int ld_weight_oc4 = oc_step * fh * fw * ic; | |||
int16x4_t c[2][4]; | |||
int8x16_t weight[2][2]; | |||
int8x16_t src[5]; | |||
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, | |||
2, 2, 2, 2, 3, 3, 3, 3}; | |||
static uint8x16_t idx = vld1q_u8(idx_buffer); | |||
INIT_SUM(); | |||
#define cb(_i) \ | |||
c[0][_i] = init_sum; \ | |||
c[1][_i] = init_sum; | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
const int8_t* src_row0 = | |||
src_ptr + ic_idx * ic_stride + 0 * iw * ic_step; | |||
const int8_t* src_row1 = | |||
src_ptr + ic_idx * ic_stride + 1 * iw * ic_step; | |||
src[0] = vld_dup_tbl_s32(src_row0 + 0, idx); | |||
src[1] = vld_dup_tbl_s32(src_row0 + 4, idx); | |||
src[2] = vld_dup_tbl_s32(src_row0 + 8, idx); | |||
weight[0][0] = vld1q_s8(weight_ptr); | |||
weight[0][1] = vld1q_s8(weight_ptr + 16); | |||
weight[1][0] = vld1q_s8(weight_ptr + ld_weight_oc4); | |||
weight[1][1] = vld1q_s8(weight_ptr + ld_weight_oc4 + 16); | |||
#define CALC_ONE_RESULT(_src0, _src1, _w0, _w1, _c) \ | |||
do { \ | |||
tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w0)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w1)); \ | |||
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ | |||
} while (0); | |||
int16x8_t tmp0; | |||
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][0]); | |||
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][0]); | |||
src[3] = vld_dup_tbl_s32(src_row0 + 12, idx); | |||
src[4] = vld_dup_tbl_s32(src_row0 + 16, idx); | |||
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][1]); | |||
CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][1]); | |||
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][2]); | |||
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][2]); | |||
CALC_ONE_RESULT(src[3], src[4], weight[0][0], weight[0][1], c[0][3]); | |||
CALC_ONE_RESULT(src[3], src[4], weight[1][0], weight[1][1], c[1][3]); | |||
src[0] = vld_dup_tbl_s32(src_row1 + 0, idx); | |||
src[1] = vld_dup_tbl_s32(src_row1 + 4, idx); | |||
src[2] = vld_dup_tbl_s32(src_row1 + 8, idx); | |||
weight[0][0] = vld1q_s8(weight_ptr + 32); | |||
weight[0][1] = vld1q_s8(weight_ptr + 48); | |||
weight[1][0] = vld1q_s8(weight_ptr + ld_weight_oc4 + 32); | |||
weight[1][1] = vld1q_s8(weight_ptr + ld_weight_oc4 + 48); | |||
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][0]); | |||
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][0]); | |||
src[3] = vld_dup_tbl_s32(src_row1 + 12, idx); | |||
src[4] = vld_dup_tbl_s32(src_row1 + 16, idx); | |||
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][1]); | |||
CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][1]); | |||
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][2]); | |||
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][2]); | |||
CALC_ONE_RESULT(src[3], src[4], weight[0][0], weight[0][1], c[0][3]); | |||
CALC_ONE_RESULT(src[3], src[4], weight[1][0], weight[1][1], c[1][3]); | |||
weight_ptr += fh * fw * ld_weight_ic4; | |||
} | |||
STORE_2_LINE_RESULT_OW4(); | |||
} | |||
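// Per output pixel, each CALC_ONE_RESULT above reduces to the following
// scalar loop over the two horizontal taps of one filter row and the four
// packed input channels, yielding four oc lanes at once (reference sketch;
// accumulation truncates to int16 exactly as the NEON code does):
static inline void calc_one_result_scalar(const int8_t* src,  // pixels x, x+1
                                          const int8_t* w,    // [tap][ic][oc]
                                          int16_t c[4]) {
    for (int tap = 0; tap < 2; ++tap) {
        for (int ic = 0; ic < 4; ++ic) {
            for (int oc = 0; oc < 4; ++oc) {
                c[oc] += src[tap * 4 + ic] * w[(tap * 4 + ic) * 4 + oc];
            }
        }
    }
}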
template <BiasMode bias_mode, int filter_size> | |||
static void ker_neon_directconv_2x2s1_oc4_ow4(const int8_t* src_ptr, | |||
const int8_t* weight_ptr, | |||
const int16_t* bias_ptr, | |||
int16_t* dst_ptr, int ic, int ih, | |||
int iw, int remain_w, | |||
int /*ld_dst_oc*/) { | |||
constexpr int fh = filter_size; | |||
constexpr int fw = filter_size; | |||
constexpr int ic_step = 4; | |||
constexpr int loop_ic_step = 4; | |||
constexpr int ld_weight_ic4 = 16; | |||
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, | |||
2, 2, 2, 2, 3, 3, 3, 3}; | |||
static uint8x16_t idx = vld1q_u8(idx_buffer); | |||
const int ic_stride = ih * iw; | |||
int16x4_t c[1][4]; | |||
int8x16_t weight[1][2]; | |||
int8x16_t src[5]; | |||
INIT_SUM(); | |||
#define cb(_i) c[0][_i] = init_sum; | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||
const int8_t* src_ic_0_3 = | |||
src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step; | |||
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 0, idx); | |||
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 4, idx); | |||
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx); | |||
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx); | |||
src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx); | |||
const int8_t* read_weight_ptr = | |||
weight_ptr + fh_idx * fw * ld_weight_ic4; | |||
weight[0][0] = vld1q_s8(read_weight_ptr); | |||
weight[0][1] = vld1q_s8(read_weight_ptr + 16); | |||
int16x8_t tmp0; | |||
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], | |||
c[0][0]); | |||
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], | |||
c[0][1]); | |||
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], | |||
c[0][2]); | |||
CALC_ONE_RESULT(src[3], src[4], weight[0][0], weight[0][1], | |||
c[0][3]); | |||
} | |||
weight_ptr += fh * fw * ld_weight_ic4; | |||
} | |||
STORE_1_LINE_RESULT_OW4(); | |||
} | |||
#undef CALC_ONE_RESULT | |||
#define CALC_ONE_RESULT(_src0, _src1, _src2, _w, _c) \ | |||
do { \ | |||
int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w[0])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w[1])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src2), vget_high_s8(_w[2])); \ | |||
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ | |||
} while (0); | |||
template <BiasMode bias_mode, int filter_size> | |||
static void ker_neon_directconv_3x3s1_oc4_ow4(const int8_t* src_ptr, | |||
const int8_t* weight_ptr, | |||
const int16_t* bias_ptr, | |||
int16_t* dst_ptr, int ic, int ih, | |||
int iw, int remain_w, | |||
int /*ld_dst_oc*/) { | |||
constexpr int fh = filter_size; | |||
constexpr int fw = filter_size; | |||
constexpr int ic_step = 4; | |||
constexpr int loop_ic_step = 4; | |||
constexpr int ld_weight_ic4 = 16; | |||
const int ic_stride = ih * iw; | |||
int16x4_t c[1][4]; | |||
int8x16_t weight[1][3]; | |||
int8x16_t src[6]; | |||
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, | |||
2, 2, 2, 2, 3, 3, 3, 3}; | |||
static uint8x16_t idx = vld1q_u8(idx_buffer); | |||
INIT_SUM(); | |||
#define cb(_i) c[0][_i] = init_sum; | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
const int8_t* src_row0 = | |||
src_ptr + ic_idx * ic_stride + 0 * iw * ic_step; | |||
const int8_t* src_row1 = | |||
src_ptr + ic_idx * ic_stride + 1 * iw * ic_step; | |||
const int8_t* src_row2 = | |||
src_ptr + ic_idx * ic_stride + 2 * iw * ic_step; | |||
src[0] = vld_dup_tbl_s32(src_row0 + 0, idx); | |||
src[1] = vld_dup_tbl_s32(src_row0 + 4, idx); | |||
src[2] = vld_dup_tbl_s32(src_row0 + 8, idx); | |||
weight[0][0] = vld1q_s8(weight_ptr); | |||
weight[0][1] = vld1q_s8(weight_ptr + 16); | |||
weight[0][2] = vld1q_s8(weight_ptr + 32); | |||
src[3] = vld_dup_tbl_s32(src_row0 + 12, idx); | |||
CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], c[0][0]); | |||
src[4] = vld_dup_tbl_s32(src_row0 + 16, idx); | |||
CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], c[0][1]); | |||
src[5] = vld_dup_tbl_s32(src_row0 + 20, idx); | |||
CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], c[0][2]); | |||
CALC_ONE_RESULT(src[3], src[4], src[5], weight[0], c[0][3]); | |||
src[0] = vld_dup_tbl_s32(src_row1 + 0, idx); | |||
src[1] = vld_dup_tbl_s32(src_row1 + 4, idx); | |||
src[2] = vld_dup_tbl_s32(src_row1 + 8, idx); | |||
weight[0][0] = vld1q_s8(weight_ptr + 48); | |||
weight[0][1] = vld1q_s8(weight_ptr + 64); | |||
weight[0][2] = vld1q_s8(weight_ptr + 80); | |||
src[3] = vld_dup_tbl_s32(src_row1 + 12, idx); | |||
CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], c[0][0]); | |||
src[4] = vld_dup_tbl_s32(src_row1 + 16, idx); | |||
CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], c[0][1]); | |||
src[5] = vld_dup_tbl_s32(src_row1 + 20, idx); | |||
CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], c[0][2]); | |||
CALC_ONE_RESULT(src[3], src[4], src[5], weight[0], c[0][3]); | |||
src[0] = vld_dup_tbl_s32(src_row2 + 0, idx); | |||
src[1] = vld_dup_tbl_s32(src_row2 + 4, idx); | |||
src[2] = vld_dup_tbl_s32(src_row2 + 8, idx); | |||
weight[0][0] = vld1q_s8(weight_ptr + 96); | |||
weight[0][1] = vld1q_s8(weight_ptr + 112); | |||
weight[0][2] = vld1q_s8(weight_ptr + 128); | |||
src[3] = vld_dup_tbl_s32(src_row2 + 12, idx); | |||
CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], c[0][0]); | |||
src[4] = vld_dup_tbl_s32(src_row2 + 16, idx); | |||
CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], c[0][1]); | |||
src[5] = vld_dup_tbl_s32(src_row2 + 20, idx); | |||
CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], c[0][2]); | |||
CALC_ONE_RESULT(src[3], src[4], src[5], weight[0], c[0][3]); | |||
weight_ptr += fh * fw * ld_weight_ic4; | |||
} | |||
STORE_1_LINE_RESULT_OW4(); | |||
} | |||
template <BiasMode bias_mode, int filter_size> | |||
static void ker_neon_directconv_3x3s1_oc4_ow4_oh2(const int8_t* src_ptr, | |||
const int8_t* weight_ptr, | |||
const int16_t* bias_ptr, | |||
int16_t* dst_ptr, int ic, | |||
int ih, int iw, int remain_w, | |||
int /*ld_dst_oc*/, int ow) { | |||
constexpr int fh = filter_size; | |||
constexpr int fw = filter_size; | |||
constexpr int ic_step = 4; | |||
constexpr int loop_ic_step = 4; | |||
constexpr int ld_weight_ic4 = 16; | |||
const int ic_stride = ih * iw; | |||
int16x4_t c[1][8]; | |||
int8x16_t weight[2][3]; | |||
int8x16_t src[1][6]; | |||
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, | |||
2, 2, 2, 2, 3, 3, 3, 3}; | |||
static uint8x16_t idx = vld1q_u8(idx_buffer); | |||
INIT_SUM(); | |||
#define cb(_i) c[0][_i] = init_sum; | |||
UNROLL_CALL_RAW(8, cb); | |||
#undef cb | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
const int8_t* src_row0 = | |||
src_ptr + ic_idx * ic_stride + 0 * iw * ic_step; | |||
const int8_t* src_row1 = | |||
src_ptr + ic_idx * ic_stride + 1 * iw * ic_step; | |||
const int8_t* src_row2 = | |||
src_ptr + ic_idx * ic_stride + 2 * iw * ic_step; | |||
const int8_t* src_row3 = | |||
src_ptr + ic_idx * ic_stride + 3 * iw * ic_step; | |||
#define LOAD_SRC(_src, _src_ptr) \ | |||
_src[0] = vld_dup_tbl_s32(_src_ptr + 0, idx); \ | |||
_src[1] = vld_dup_tbl_s32(_src_ptr + 4, idx); \ | |||
_src[2] = vld_dup_tbl_s32(_src_ptr + 8, idx); \ | |||
_src[3] = vld_dup_tbl_s32(_src_ptr + 12, idx); \ | |||
_src[4] = vld_dup_tbl_s32(_src_ptr + 16, idx); \ | |||
_src[5] = vld_dup_tbl_s32(_src_ptr + 20, idx); | |||
LOAD_SRC(src[0], src_row0); | |||
weight[0][0] = vld1q_s8(weight_ptr); | |||
weight[0][1] = vld1q_s8(weight_ptr + 16); | |||
weight[0][2] = vld1q_s8(weight_ptr + 32); | |||
weight[1][0] = vld1q_s8(weight_ptr + 48); | |||
weight[1][1] = vld1q_s8(weight_ptr + 64); | |||
weight[1][2] = vld1q_s8(weight_ptr + 80); | |||
CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0], | |||
c[0][0]); // row0 src0 w0 | |||
CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][1]); | |||
CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][2]); | |||
CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][3]); | |||
LOAD_SRC(src[0], src_row1); | |||
CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0], | |||
c[0][4]); // row1 src1 w0 | |||
CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][5]); | |||
CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][6]); | |||
CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][7]); | |||
CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[1], | |||
c[0][0]); // row1 src1 w1 | |||
CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[1], c[0][1]); | |||
CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[1], c[0][2]); | |||
CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[1], c[0][3]); | |||
LOAD_SRC(src[0], src_row2); | |||
weight[0][0] = vld1q_s8(weight_ptr + 96); | |||
weight[0][1] = vld1q_s8(weight_ptr + 112); | |||
weight[0][2] = vld1q_s8(weight_ptr + 128); | |||
CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[1], | |||
c[0][4]); // row2 src0 w1 | |||
CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[1], c[0][5]); | |||
CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[1], c[0][6]); | |||
CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[1], c[0][7]); | |||
CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0], | |||
c[0][0]); // row2 w0 src[0] | |||
CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][1]); | |||
CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][2]); | |||
CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][3]); | |||
LOAD_SRC(src[0], src_row3); | |||
CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0], | |||
c[0][4]); // row3 w0 src1 | |||
CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][5]); | |||
CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][6]); | |||
CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][7]); | |||
weight_ptr += fh * fw * ld_weight_ic4; | |||
} | |||
STORE_1_LINE_RESULT_OW4_OH2(); | |||
} | |||
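// Source-row reuse in the oh2 kernel above (stride 1, 3x3): output rows r
// and r+1 read source rows r..r+2 and r+1..r+3, so the two shared rows are
// loaded once and consumed twice:
//   out[r]   += src_row[r]   * w0 + src_row[r+1] * w1 + src_row[r+2] * w2
//   out[r+1] += src_row[r+1] * w0 + src_row[r+2] * w1 + src_row[r+3] * w2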
#undef LOAD_SRC | |||
#undef CALC_ONE_RESULT | |||
template <BiasMode bias_mode, int filter_size> | |||
struct KerNeonDirectStride1Int8 { | |||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, | |||
int iw, int remain_w, int ld_dst_oc); | |||
}; | |||
template <BiasMode bias_mode> | |||
struct KerNeonDirectStride1Int8<bias_mode, 5> { | |||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, | |||
int iw, int remain_w, int /*ld_dst_oc*/) { | |||
constexpr int filter_size = 5; | |||
constexpr int fh = filter_size; | |||
constexpr int fw = filter_size; | |||
constexpr int ic_step = 4; | |||
constexpr int loop_ic_step = 4; | |||
constexpr int ld_weight_ic4 = 16; | |||
const int ic_stride = ih * iw; | |||
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, | |||
2, 2, 2, 2, 3, 3, 3, 3}; | |||
static uint8x16_t idx = vld1q_u8(idx_buffer); | |||
int16x4_t c[1][8]; | |||
int8x16_t weight[5]; | |||
int8x16_t src[8 + 2]; | |||
INIT_SUM(); | |||
#define cb(_i) c[0][_i] = init_sum; | |||
UNROLL_CALL_RAW(8, cb); | |||
#undef cb | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||
const int8_t* src_ic_0_3 = | |||
src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step; | |||
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 0 * 4, idx); | |||
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 1 * 4, idx); | |||
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 2 * 4, idx); | |||
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 3 * 4, idx); | |||
src[4] = vld_dup_tbl_s32(src_ic_0_3 + 4 * 4, idx); | |||
src[5] = vld_dup_tbl_s32(src_ic_0_3 + 5 * 4, idx); | |||
src[6] = vld_dup_tbl_s32(src_ic_0_3 + 6 * 4, idx); | |||
src[7] = vld_dup_tbl_s32(src_ic_0_3 + 7 * 4, idx); | |||
src[8] = vld_dup_tbl_s32(src_ic_0_3 + 8 * 4, idx); | |||
src[9] = vld_dup_tbl_s32(src_ic_0_3 + 9 * 4, idx); | |||
const int8_t* read_weight_ptr = | |||
weight_ptr + fh_idx * fw * ld_weight_ic4; | |||
weight[0] = vld1q_s8(read_weight_ptr); | |||
weight[1] = vld1q_s8(read_weight_ptr + 16); | |||
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); | |||
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); | |||
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); | |||
#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _w0, _w1, _w2, _w3, \ | |||
_w4, _c) \ | |||
do { \ | |||
int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ | |||
int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w0)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ | |||
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w1)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w2)); \ | |||
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w2)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w3)); \ | |||
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w3)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w4)); \ | |||
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w4)); \ | |||
tmp0 = vaddq_s16(tmp0, tmp1); \ | |||
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ | |||
} while (0); | |||
CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], | |||
weight[0], weight[1], weight[2], weight[3], | |||
weight[4], c[0][0]); | |||
CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], | |||
weight[0], weight[1], weight[2], weight[3], | |||
weight[4], c[0][1]); | |||
CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], | |||
weight[0], weight[1], weight[2], weight[3], | |||
weight[4], c[0][2]); | |||
CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], | |||
weight[0], weight[1], weight[2], weight[3], | |||
weight[4], c[0][3]); | |||
CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], | |||
weight[0], weight[1], weight[2], weight[3], | |||
weight[4], c[0][4]); | |||
CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], | |||
weight[0], weight[1], weight[2], weight[3], | |||
weight[4], c[0][5]); | |||
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 10 * 4, idx); | |||
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 11 * 4, idx); | |||
CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], | |||
weight[0], weight[1], weight[2], weight[3], | |||
weight[4], c[0][6]); | |||
CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], | |||
weight[0], weight[1], weight[2], weight[3], | |||
weight[4], c[0][7]); | |||
} | |||
weight_ptr += fh * fw * ld_weight_ic4; | |||
} | |||
STORE_1_LINE_RESULT(); | |||
} | |||
}; | |||
#undef CALC_ONE_RESULT | |||
template <BiasMode bias_mode> | |||
struct KerNeonDirectStride1Int8<bias_mode, 7> { | |||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, | |||
int iw, int remain_w, int /*ld_dst_oc*/) { | |||
constexpr int filter_size = 7; | |||
constexpr int fh = filter_size; | |||
constexpr int fw = filter_size; | |||
constexpr int ic_step = 4; | |||
constexpr int loop_ic_step = 4; | |||
constexpr int ld_weight_ic4 = 16; | |||
const int ic_stride = ih * iw; | |||
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, | |||
2, 2, 2, 2, 3, 3, 3, 3}; | |||
static uint8x16_t idx = vld1q_u8(idx_buffer); | |||
int16x4_t c[1][8]; | |||
int8x16_t weight[7]; | |||
int8x16_t src[8 + 2]; | |||
INIT_SUM(); | |||
#define cb(_i) c[0][_i] = init_sum; | |||
UNROLL_CALL_RAW(8, cb); | |||
#undef cb | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||
const int8_t* src_ic_0_3 = | |||
src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step; | |||
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 0 * 4, idx); | |||
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 1 * 4, idx); | |||
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 2 * 4, idx); | |||
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 3 * 4, idx); | |||
src[4] = vld_dup_tbl_s32(src_ic_0_3 + 4 * 4, idx); | |||
src[5] = vld_dup_tbl_s32(src_ic_0_3 + 5 * 4, idx); | |||
src[6] = vld_dup_tbl_s32(src_ic_0_3 + 6 * 4, idx); | |||
src[7] = vld_dup_tbl_s32(src_ic_0_3 + 7 * 4, idx); | |||
src[8] = vld_dup_tbl_s32(src_ic_0_3 + 8 * 4, idx); | |||
src[9] = vld_dup_tbl_s32(src_ic_0_3 + 9 * 4, idx); | |||
const int8_t* read_weight_ptr = | |||
weight_ptr + fh_idx * fw * ld_weight_ic4; | |||
weight[0] = vld1q_s8(read_weight_ptr); | |||
weight[1] = vld1q_s8(read_weight_ptr + 16); | |||
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); | |||
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); | |||
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); | |||
weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); | |||
weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); | |||
#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _src5, _src6, _w, \ | |||
_c) \ | |||
do { \ | |||
int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ | |||
int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w[0])); \ | |||
int16x8_t tmp2 = vmull_s8(vget_low_s8(_src1), vget_low_s8(_w[1])); \ | |||
int16x8_t tmp3 = vmull_s8(vget_high_s8(_src1), vget_high_s8(_w[1])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ | |||
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w[2])); \ | |||
tmp2 = vmlal_s8(tmp2, vget_low_s8(_src3), vget_low_s8(_w[3])); \ | |||
tmp3 = vmlal_s8(tmp3, vget_high_s8(_src3), vget_high_s8(_w[3])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \ | |||
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w[4])); \ | |||
tmp2 = vmlal_s8(tmp2, vget_low_s8(_src5), vget_low_s8(_w[5])); \ | |||
tmp3 = vmlal_s8(tmp3, vget_high_s8(_src5), vget_high_s8(_w[5])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src6), vget_low_s8(_w[6])); \ | |||
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src6), vget_high_s8(_w[6])); \ | |||
tmp0 = vaddq_s16(tmp0, tmp1); \ | |||
tmp2 = vaddq_s16(tmp2, tmp3); \ | |||
tmp0 = vaddq_s16(tmp0, tmp2); \ | |||
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ | |||
} while (0); | |||
CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], src[5], | |||
src[6], weight, c[0][0]); | |||
CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], src[6], | |||
src[7], weight, c[0][1]); | |||
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 10 * 4, idx); | |||
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 11 * 4, idx); | |||
CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], src[7], | |||
src[8], weight, c[0][2]); | |||
CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], src[8], | |||
src[9], weight, c[0][3]); | |||
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 12 * 4, idx); | |||
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 13 * 4, idx); | |||
CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], src[9], | |||
src[0], weight, c[0][4]); | |||
CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], src[0], | |||
src[1], weight, c[0][5]); | |||
CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], src[1], | |||
src[2], weight, c[0][6]); | |||
CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], src[2], | |||
src[3], weight, c[0][7]); | |||
} | |||
weight_ptr += fh * fw * ld_weight_ic4; | |||
} | |||
STORE_1_LINE_RESULT(); | |||
} | |||
}; | |||
#undef CALC_ONE_RESULT | |||
template <BiasMode bias_mode> | |||
void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, | |||
const int8_t* filter, | |||
const int16_t* bias, int16_t* dst, | |||
const size_t oc, const size_t ic, | |||
const size_t ih, const size_t iw, | |||
const size_t oh, const size_t ow) { | |||
constexpr size_t filter_size = 2; | |||
constexpr size_t fh = filter_size; | |||
constexpr size_t fw = filter_size; | |||
constexpr size_t ic_step = 4; | |||
constexpr size_t oc_step = 4; | |||
constexpr size_t big_oc_step = 8; | |||
constexpr size_t oh_step = 1; | |||
constexpr size_t ow_step = 4; | |||
const size_t img_stride = oh * ow; | |||
const size_t ow_end = ow / ow_step * ow_step; | |||
const size_t ow_remain = ow - ow_end; | |||
const size_t oc_end = oc / big_oc_step * big_oc_step; | |||
const size_t oc_remain = oc - oc_end; | |||
const int ld_oc = oh * ow * oc_step; | |||
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
size_t oh_idx = 0; | |||
for (; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
ker_neon_directconv_2x2s1_oc8_ow4<bias_mode, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, ow_step, ld_oc); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = (oh_idx * iw + ow_end) * ic_step; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
ker_neon_directconv_2x2s1_oc8_ow4<bias_mode, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, ow_remain, ld_oc); | |||
} | |||
} | |||
} | |||
if (oc_remain > 0) { | |||
const size_t oc_idx = oc_end; | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
ker_neon_directconv_2x2s1_oc4_ow4<bias_mode, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, ow_step, ld_oc); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = (oh_idx * iw + ow_end) * ic_step; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
ker_neon_directconv_2x2s1_oc4_ow4<bias_mode, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, ow_remain, ld_oc); | |||
} | |||
} | |||
} | |||
} | |||
template <BiasMode bias_mode> | |||
void conv_direct_stride1_3x3_int8x8x16_oh2_nchw44( | |||
const int8_t* src, const int8_t* filter, const int16_t* bias, | |||
int16_t* dst, const size_t oc, const size_t ic, const size_t ih, | |||
const size_t iw, const size_t oh, const size_t ow) { | |||
constexpr size_t filter_size = 3; | |||
constexpr size_t fh = filter_size; | |||
constexpr size_t fw = filter_size; | |||
constexpr size_t ic_step = 4; | |||
constexpr size_t oc_step = 4; | |||
constexpr size_t big_oc_step = 4; | |||
constexpr size_t oh_step = 1; | |||
constexpr size_t ow_step = 4; | |||
const size_t img_stride = oh * ow; | |||
const size_t ow_end = ow / ow_step * ow_step; | |||
const size_t ow_remain = ow - ow_end; | |||
const size_t oc_end = oc / big_oc_step * big_oc_step; | |||
const int ld_oc = oh * ow * oc_step; | |||
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
size_t oh_idx = 0; | |||
for (; oh_idx + 1 < oh; oh_idx += 2) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
ker_neon_directconv_3x3s1_oc4_ow4_oh2<bias_mode, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, ow_step, ld_oc, | |||
ow * oc_step); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = (oh_idx * iw + ow_end) * ic_step; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
ker_neon_directconv_3x3s1_oc4_ow4_oh2<bias_mode, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, ow_remain, ld_oc, | |||
ow * oc_step); | |||
} | |||
} | |||
for (; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
ker_neon_directconv_3x3s1_oc4_ow4<bias_mode, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, ow_step, ld_oc); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = (oh_idx * iw + ow_end) * ic_step; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
ker_neon_directconv_3x3s1_oc4_ow4<bias_mode, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, ow_remain, ld_oc); | |||
} | |||
} | |||
} | |||
} | |||
template <BiasMode bias_mode, int filter_size> | |||
void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, | |||
const int8_t* filter, | |||
const int16_t* bias, int16_t* dst, | |||
const size_t oc, const size_t ic, | |||
const size_t ih, const size_t iw, | |||
const size_t oh, const size_t ow) { | |||
constexpr size_t fh = filter_size; | |||
constexpr size_t fw = filter_size; | |||
constexpr size_t ic_step = 4; | |||
constexpr size_t oc_step = 4; | |||
constexpr size_t oh_step = 1; | |||
constexpr size_t ow_step = 8; | |||
const size_t img_stride = oh * ow; | |||
const int ld_dst_oc = oh * ow * oc_step; | |||
const size_t ow_end = ow / ow_step * ow_step; | |||
const size_t ow_remain = ow - ow_end; | |||
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
KerNeonDirectStride1Int8<bias_mode, filter_size>::impl( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, ow_step, ld_dst_oc); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = (oh_idx * iw + ow_end) * ic_step; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
KerNeonDirectStride1Int8<bias_mode, filter_size>::impl( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, ow_remain, ld_dst_oc); | |||
} | |||
} | |||
} | |||
} | |||
} // namespace | |||
namespace int8x8x16_direct_nchw44 { | |||
template <BiasMode bias_mode, int filter_size> | |||
struct ConvDirectInt8Nchw44Choose<bias_mode, filter_size, 1> { | |||
static void impl(const int8_t* src, const int8_t* filter, | |||
const int16_t* bias, int16_t* dst, const size_t oc, | |||
const size_t ic, const size_t ih, const size_t iw, | |||
const size_t oh, const size_t ow) { | |||
conv_direct_stride1_int8_nchw44_kern<bias_mode, filter_size>( | |||
src, filter, bias, dst, oc, ic, ih, iw, oh, ow); | |||
} | |||
}; | |||
template <BiasMode bias_mode> | |||
struct ConvDirectInt8Nchw44Choose<bias_mode, 2, 1> { | |||
static void impl(const int8_t* src, const int8_t* filter, | |||
const int16_t* bias, int16_t* dst, const size_t oc, | |||
const size_t ic, const size_t ih, const size_t iw, | |||
const size_t oh, const size_t ow) { | |||
conv_direct_stride1_2x2_int8_nchw44<bias_mode>(src, filter, bias, dst, | |||
oc, ic, ih, iw, oh, ow); | |||
} | |||
}; | |||
template <BiasMode bias_mode> | |||
struct ConvDirectInt8Nchw44Choose<bias_mode, 3, 1> { | |||
static void impl(const int8_t* src, const int8_t* filter, | |||
const int16_t* bias, int16_t* dst, const size_t oc, | |||
const size_t ic, const size_t ih, const size_t iw, | |||
const size_t oh, const size_t ow) { | |||
conv_direct_stride1_3x3_int8x8x16_oh2_nchw44<bias_mode>( | |||
src, filter, bias, dst, oc, ic, ih, iw, oh, ow); | |||
} | |||
}; | |||
#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode) \ | |||
template struct ConvDirectInt8Nchw44Choose<bias_mode, filter_size, stride>; | |||
#define GET_OP_PARAM(stride, filter, bias_mode) \ | |||
DO_CONV_KERN_FUN(stride, filter, bias_mode) | |||
#define GET_BIAS_MODE_PARAM(stride, filter) \ | |||
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ | |||
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | |||
#define DISPATCH_CONV_KERN(stride) \ | |||
GET_BIAS_MODE_PARAM(stride, 2) \ | |||
GET_BIAS_MODE_PARAM(stride, 3) \ | |||
GET_BIAS_MODE_PARAM(stride, 5) \ | |||
GET_BIAS_MODE_PARAM(stride, 7) | |||
DISPATCH_CONV_KERN(1); | |||
} // namespace int8x8x16_direct_nchw44 | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,854 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8_direct_nchw44_s1_armv7.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/common/utils.h" | |||
#if MEGDNN_ARMV7 | |||
#include "src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace { | |||
#define INIT_SUM() \ | |||
int16x4_t init_sum; \ | |||
if (bias_mode == BiasMode::NO_BIAS) { \ | |||
init_sum = vdup_n_s16(0); \ | |||
} else { \ | |||
init_sum = vld1_s16(bias_ptr); \ | |||
} | |||
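// INIT_SUM seeds every accumulator with the per-channel bias (or zero for | |||
// NO_BIAS); the four int16 lanes map to the four packed output channels of | |||
// the NCHW44 block. | |||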
#define STORE_1_LINE_RESULT() \ | |||
switch (remain_w) { \ | |||
case 8: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ | |||
vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \ | |||
break; \ | |||
case 1: \ | |||
vst1_s16(dst_ptr, c[0][0]); \ | |||
break; \ | |||
case 2: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
break; \ | |||
case 3: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1_s16(dst_ptr + 8, c[0][2]); \ | |||
break; \ | |||
case 4: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
break; \ | |||
case 5: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
vst1_s16(dst_ptr + 16, c[0][4]); \ | |||
break; \ | |||
case 6: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ | |||
break; \ | |||
case 7: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ | |||
vst1_s16(dst_ptr + 24, c[0][6]); \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0, "oc 1 error remainw"); \ | |||
}; | |||
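// STORE_1_LINE_RESULT writes remain_w (1..8) outputs of one 4-channel group: | |||
// pairs of accumulators go out as full int16x8 stores and an odd trailing | |||
// element as a single int16x4 store. | |||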
#define STORE_2_LINE_RESULT() \ | |||
switch (remain_w) { \ | |||
case 8: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ | |||
vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc + 8, \ | |||
vcombine_s16(c[1][2], c[1][3])); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc + 16, \ | |||
vcombine_s16(c[1][4], c[1][5])); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc + 24, \ | |||
vcombine_s16(c[1][6], c[1][7])); \ | |||
break; \ | |||
case 1: \ | |||
vst1_s16(dst_ptr, c[0][0]); \ | |||
vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \ | |||
break; \ | |||
case 2: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||
break; \ | |||
case 3: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1_s16(dst_ptr + 8, c[0][2]); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||
vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \ | |||
break; \ | |||
case 4: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc + 8, \ | |||
vcombine_s16(c[1][2], c[1][3])); \ | |||
break; \ | |||
case 5: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
vst1_s16(dst_ptr + 16, c[0][4]); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc + 8, \ | |||
vcombine_s16(c[1][2], c[1][3])); \ | |||
vst1_s16(dst_ptr + ld_dst_oc + 16, c[1][4]); \ | |||
break; \ | |||
case 6: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc + 8, \ | |||
vcombine_s16(c[1][2], c[1][3])); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc + 16, \ | |||
vcombine_s16(c[1][4], c[1][5])); \ | |||
break; \ | |||
case 7: \ | |||
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ | |||
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ | |||
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ | |||
vst1_s16(dst_ptr + 24, c[0][6]); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc + 8, \ | |||
vcombine_s16(c[1][2], c[1][3])); \ | |||
vst1q_s16(dst_ptr + ld_dst_oc + 16, \ | |||
vcombine_s16(c[1][4], c[1][5])); \ | |||
vst1_s16(dst_ptr + ld_dst_oc + 24, c[1][6]); \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0, "oc 2 error remainw"); \ | |||
break; \ | |||
} | |||
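// STORE_2_LINE_RESULT is the two-group analogue: the second output-channel | |||
// group lands at dst_ptr + ld_dst_oc, the stride between channel groups. | |||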
template <BiasMode bias_mode, int remain_w, int filter_size> | |||
static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr, | |||
const int8_t* weight_ptr, | |||
const int16_t* bias_ptr, | |||
int16_t* dst_ptr, int ic, int ih, | |||
int iw, int ld_dst_oc) { | |||
constexpr int fh = filter_size; | |||
constexpr int fw = filter_size; | |||
constexpr int ic_step = 4; | |||
constexpr int oc_step = 4; | |||
constexpr int loop_ic_step = 4; | |||
constexpr int ld_weight_ic4 = 16; | |||
constexpr int src_expand_size = 4; | |||
const int ic_stride = ih * iw * src_expand_size; | |||
const int ld_weight_oc4 = oc_step * fh * fw * ic; | |||
int16x4_t c[2][8]; | |||
int8x16_t weight[2][2]; | |||
int8x16_t src[4]; | |||
INIT_SUM(); | |||
#define cb(_i) \ | |||
c[0][_i] = init_sum; \ | |||
c[1][_i] = init_sum; | |||
UNROLL_CALL_RAW(8, cb); | |||
#undef cb | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
const int8_t* src_row0 = src_ptr + ic_idx * ic_stride + | |||
0 * iw * ic_step * src_expand_size; | |||
const int8_t* src_row1 = src_ptr + ic_idx * ic_stride + | |||
1 * iw * ic_step * src_expand_size; | |||
src[0] = vld1q_s8(src_row0); | |||
src[1] = vld1q_s8(src_row0 + 16); | |||
weight[0][0] = vld1q_s8(weight_ptr); | |||
weight[0][1] = vld1q_s8(weight_ptr + 16); | |||
weight[1][0] = vld1q_s8(weight_ptr + ld_weight_oc4); | |||
weight[1][1] = vld1q_s8(weight_ptr + ld_weight_oc4 + 16); | |||
#define CALC_ONE_RESULT(_src0, _src1, _w0, _w1, _c) \ | |||
do { \ | |||
tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w0)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w1)); \ | |||
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ | |||
} while (0); | |||
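// CALC_ONE_RESULT accumulates one output pixel for one 4-channel oc group. | |||
// The source is expanded so each int8 value is repeated four times and lines | |||
// up with the four oc weights of the same ic; vmull_s8/vmlal_s8 widen the | |||
// products to int16 and the final vadd_s16 folds the eight lanes down to the | |||
// four oc accumulators. An illustrative scalar sketch (not part of the | |||
// original kernel): | |||
//   for (int oc = 0; oc < 4; ++oc) | |||
//       for (int icc = 0; icc < 4; ++icc) | |||
//           for (int kw = 0; kw < 2; ++kw) | |||
//               c_oc[oc] += int16_t(s[kw][icc]) * int16_t(w[kw][icc][oc]); | |||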
int16x8_t tmp0; | |||
src[2] = vld1q_s8(src_row0 + 2 * 16); | |||
src[3] = vld1q_s8(src_row0 + 3 * 16); | |||
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][0]); | |||
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][0]); | |||
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][1]); | |||
CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][1]); | |||
src[0] = vld1q_s8(src_row0 + 4 * 16); | |||
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][2]); | |||
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][2]); | |||
src[1] = vld1q_s8(src_row0 + 5 * 16); | |||
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][3]); | |||
CALC_ONE_RESULT(src[3], src[0], weight[1][0], weight[1][1], c[1][3]); | |||
src[2] = vld1q_s8(src_row0 + 6 * 16); | |||
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][4]); | |||
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][4]); | |||
src[3] = vld1q_s8(src_row0 + 7 * 16); | |||
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][5]); | |||
CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][5]); | |||
src[0] = vld1q_s8(src_row0 + 8 * 16); | |||
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][6]); | |||
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][6]); | |||
src[1] = vld1q_s8(src_row1 + 0 * 16); | |||
src[2] = vld1q_s8(src_row1 + 1 * 16); | |||
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][7]); | |||
CALC_ONE_RESULT(src[3], src[0], weight[1][0], weight[1][1], c[1][7]); | |||
weight[0][0] = vld1q_s8(weight_ptr + 32); | |||
weight[0][1] = vld1q_s8(weight_ptr + 48); | |||
src[3] = vld1q_s8(src_row1 + 2 * 16); | |||
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][0]); | |||
weight[1][0] = vld1q_s8(weight_ptr + ld_weight_oc4 + 32); | |||
weight[1][1] = vld1q_s8(weight_ptr + ld_weight_oc4 + 48); | |||
src[0] = vld1q_s8(src_row1 + 3 * 16); | |||
CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][0]); | |||
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][1]); | |||
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][1]); | |||
src[1] = vld1q_s8(src_row1 + 4 * 16); | |||
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][2]); | |||
CALC_ONE_RESULT(src[3], src[0], weight[1][0], weight[1][1], c[1][2]); | |||
src[2] = vld1q_s8(src_row1 + 5 * 16); | |||
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][3]); | |||
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][3]); | |||
src[3] = vld1q_s8(src_row1 + 6 * 16); | |||
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][4]); | |||
CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][4]); | |||
src[0] = vld1q_s8(src_row1 + 7 * 16); | |||
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][5]); | |||
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][5]); | |||
src[1] = vld1q_s8(src_row1 + 8 * 16); | |||
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][6]); | |||
CALC_ONE_RESULT(src[3], src[0], weight[1][0], weight[1][1], c[1][6]); | |||
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][7]); | |||
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][7]); | |||
weight_ptr += fh * fw * ld_weight_ic4; | |||
} | |||
STORE_2_LINE_RESULT(); | |||
} | |||
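// The oc8 path keeps two weight sets and two accumulator banks live so one | |||
// pass over the expanded source feeds eight output channels; source and | |||
// weight loads are interleaved with the MACs to hide load latency. | |||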
template <BiasMode bias_mode, int remain_w, int filter_size> | |||
static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, | |||
const int8_t* weight_ptr, | |||
const int16_t* bias_ptr, | |||
int16_t* dst_ptr, int ic, int ih, | |||
int iw, int /*ld_dst_oc*/) { | |||
constexpr int fh = filter_size; | |||
constexpr int fw = filter_size; | |||
constexpr int ic_step = 4; | |||
constexpr int loop_ic_step = 4; | |||
constexpr int ld_weight_ic4 = 16; | |||
constexpr int src_expand_size = 4; | |||
const int ic_stride = ih * iw * src_expand_size; | |||
int16x4_t c[1][8]; | |||
int8x16_t weight[1][2]; | |||
int8x16_t src[4]; | |||
INIT_SUM(); | |||
#define cb(_i) c[0][_i] = init_sum; | |||
UNROLL_CALL_RAW(8, cb); | |||
#undef cb | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + | |||
fh_idx * iw * ic_step * src_expand_size; | |||
src[0] = vld1q_s8(src_ic_0_3); | |||
src[1] = vld1q_s8(src_ic_0_3 + 16); | |||
const int8_t* read_weight_ptr = | |||
weight_ptr + fh_idx * fw * ld_weight_ic4; | |||
weight[0][0] = vld1q_s8(read_weight_ptr); | |||
weight[0][1] = vld1q_s8(read_weight_ptr + 16); | |||
int16x8_t tmp0; | |||
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); | |||
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); | |||
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], | |||
c[0][0]); | |||
src[0] = vld1q_s8(src_ic_0_3 + 4 * 16); | |||
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], | |||
c[0][1]); | |||
src[1] = vld1q_s8(src_ic_0_3 + 5 * 16); | |||
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], | |||
c[0][2]); | |||
src[2] = vld1q_s8(src_ic_0_3 + 6 * 16); | |||
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], | |||
c[0][3]); | |||
src[3] = vld1q_s8(src_ic_0_3 + 7 * 16); | |||
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], | |||
c[0][4]); | |||
src[0] = vld1q_s8(src_ic_0_3 + 8 * 16); | |||
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], | |||
c[0][5]); | |||
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], | |||
c[0][6]); | |||
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], | |||
c[0][7]); | |||
} | |||
weight_ptr += fh * fw * ld_weight_ic4; | |||
} | |||
STORE_1_LINE_RESULT(); | |||
} | |||
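// The oc4 variant covers the output-channel remainder (oc % 8 == 4) with a | |||
// single accumulator bank and a plain fh loop. | |||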
#undef CALC_ONE_RESULT | |||
template <BiasMode bias_mode, int remain_w, int filter_size> | |||
struct KerNeonDirectStride1Int8 { | |||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, | |||
int iw, int ld_dst_oc); | |||
}; | |||
template <BiasMode bias_mode, int remain_w> | |||
struct KerNeonDirectStride1Int8<bias_mode, remain_w, 3> { | |||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, | |||
int iw, int /*ld_dst_oc*/) { | |||
constexpr int filter_size = 3; | |||
constexpr int fh = filter_size; | |||
constexpr int fw = filter_size; | |||
constexpr int ic_step = 4; | |||
constexpr int loop_ic_step = 4; | |||
constexpr int ld_weight_ic4 = 16; | |||
constexpr int src_expand_size = 4; | |||
const int ic_stride = ih * iw * src_expand_size; | |||
int16x4_t c[1][8]; | |||
int8x16_t weight[3]; | |||
int8x16_t src[5]; | |||
INIT_SUM(); | |||
#define cb(_i) c[0][_i] = init_sum; | |||
UNROLL_CALL_RAW(8, cb); | |||
#undef cb | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||
const int8_t* src_ic_0_3 = | |||
src_ptr + ic_idx * ic_stride + | |||
fh_idx * iw * ic_step * src_expand_size; | |||
src[0] = vld1q_s8(src_ic_0_3 + 0 * 16); | |||
src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); | |||
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); | |||
const int8_t* read_weight_ptr = | |||
weight_ptr + fh_idx * fw * ld_weight_ic4; | |||
weight[0] = vld1q_s8(read_weight_ptr); | |||
weight[1] = vld1q_s8(read_weight_ptr + 16); | |||
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); | |||
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); | |||
#define CALC_ONE_RESULT(_src0, _src1, _src2, _w0, _w1, _w2, _c) \ | |||
do { \ | |||
tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w0)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w1)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w2)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src2), vget_high_s8(_w2)); \ | |||
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ | |||
} while (0); | |||
int16x8_t tmp0; | |||
CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], weight[1], | |||
weight[2], c[0][0]); | |||
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); | |||
CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], weight[1], | |||
weight[2], c[0][1]); | |||
src[0] = vld1q_s8(src_ic_0_3 + 5 * 16); | |||
CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], weight[1], | |||
weight[2], c[0][2]); | |||
src[1] = vld1q_s8(src_ic_0_3 + 6 * 16); | |||
CALC_ONE_RESULT(src[3], src[4], src[0], weight[0], weight[1], | |||
weight[2], c[0][3]); | |||
src[2] = vld1q_s8(src_ic_0_3 + 7 * 16); | |||
CALC_ONE_RESULT(src[4], src[0], src[1], weight[0], weight[1], | |||
weight[2], c[0][4]); | |||
src[3] = vld1q_s8(src_ic_0_3 + 8 * 16); | |||
CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], weight[1], | |||
weight[2], c[0][5]); | |||
src[4] = vld1q_s8(src_ic_0_3 + 9 * 16); | |||
CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], weight[1], | |||
weight[2], c[0][6]); | |||
CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], weight[1], | |||
weight[2], c[0][7]); | |||
} | |||
weight_ptr += fh * fw * ld_weight_ic4; | |||
} | |||
STORE_1_LINE_RESULT(); | |||
} | |||
}; | |||
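// The 3x3 specialization keeps a five-register rotating window over the | |||
// source row, so each loaded vector is reused by up to three filter taps | |||
// before being overwritten. | |||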
#undef CALC_ONE_RESULT | |||
template <BiasMode bias_mode, int remain_w> | |||
struct KerNeonDirectStride1Int8<bias_mode, remain_w, 5> { | |||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, | |||
int iw, int /*ld_dst_oc*/) { | |||
constexpr int filter_size = 5; | |||
constexpr int fh = filter_size; | |||
constexpr int fw = filter_size; | |||
constexpr int ic_step = 4; | |||
constexpr int loop_ic_step = 4; | |||
constexpr int ld_weight_ic4 = 16; | |||
constexpr int src_expand_size = 4; | |||
const int ic_stride = ih * iw * src_expand_size; | |||
int16x4_t c[1][8]; | |||
int8x16_t weight[5]; | |||
int8x16_t src[8 + 2]; | |||
INIT_SUM(); | |||
#define cb(_i) c[0][_i] = init_sum; | |||
UNROLL_CALL_RAW(8, cb); | |||
#undef cb | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||
const int8_t* src_ic_0_3 = | |||
src_ptr + ic_idx * ic_stride + | |||
fh_idx * iw * ic_step * src_expand_size; | |||
src[0] = vld1q_s8(src_ic_0_3 + 0 * 16); | |||
src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); | |||
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); | |||
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); | |||
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); | |||
src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); | |||
src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); | |||
src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); | |||
src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); | |||
src[9] = vld1q_s8(src_ic_0_3 + 9 * 16); | |||
const int8_t* read_weight_ptr = | |||
weight_ptr + fh_idx * fw * ld_weight_ic4; | |||
weight[0] = vld1q_s8(read_weight_ptr); | |||
weight[1] = vld1q_s8(read_weight_ptr + 16); | |||
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); | |||
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); | |||
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); | |||
#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _w0, _w1, _w2, _w3, \ | |||
_w4, _c) \ | |||
do { \ | |||
int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ | |||
int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w0)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ | |||
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w1)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w2)); \ | |||
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w2)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w3)); \ | |||
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w3)); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w4)); \ | |||
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w4)); \ | |||
tmp0 = vaddq_s16(tmp0, tmp1); \ | |||
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ | |||
} while (0); | |||
CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], | |||
weight[0], weight[1], weight[2], weight[3], | |||
weight[4], c[0][0]); | |||
CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], | |||
weight[0], weight[1], weight[2], weight[3], | |||
weight[4], c[0][1]); | |||
CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], | |||
weight[0], weight[1], weight[2], weight[3], | |||
weight[4], c[0][2]); | |||
CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], | |||
weight[0], weight[1], weight[2], weight[3], | |||
weight[4], c[0][3]); | |||
CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], | |||
weight[0], weight[1], weight[2], weight[3], | |||
weight[4], c[0][4]); | |||
CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], | |||
weight[0], weight[1], weight[2], weight[3], | |||
weight[4], c[0][5]); | |||
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); | |||
src[1] = vld1q_s8(src_ic_0_3 + 11 * 16); | |||
CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], | |||
weight[0], weight[1], weight[2], weight[3], | |||
weight[4], c[0][6]); | |||
CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], | |||
weight[0], weight[1], weight[2], weight[3], | |||
weight[4], c[0][7]); | |||
} | |||
weight_ptr += fh * fw * ld_weight_ic4; | |||
} | |||
STORE_1_LINE_RESULT(); | |||
} | |||
}; | |||
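// The 5x5 specialization preloads the full ten-vector window for eight | |||
// outputs and only fetches the two overflow vectors (offsets 10 and 11) | |||
// before computing the last two results. | |||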
#undef CALC_ONE_RESULT | |||
template <BiasMode bias_mode, int remain_w> | |||
struct KerNeonDirectStride1Int8<bias_mode, remain_w, 7> { | |||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, | |||
int iw, int /*ld_dst_oc*/) { | |||
constexpr int filter_size = 7; | |||
constexpr int fh = filter_size; | |||
constexpr int fw = filter_size; | |||
constexpr int ic_step = 4; | |||
constexpr int loop_ic_step = 4; | |||
constexpr int ld_weight_ic4 = 16; | |||
constexpr int src_expand_size = 4; | |||
const int ic_stride = ih * iw * src_expand_size; | |||
int16x4_t c[1][8]; | |||
int8x16_t weight[7]; | |||
int8x16_t src[8 + 2]; | |||
INIT_SUM(); | |||
#define cb(_i) c[0][_i] = init_sum; | |||
UNROLL_CALL_RAW(8, cb); | |||
#undef cb | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||
const int8_t* src_ic_0_3 = | |||
src_ptr + ic_idx * ic_stride + | |||
fh_idx * iw * ic_step * src_expand_size; | |||
src[0] = vld1q_s8(src_ic_0_3 + 0 * 16); | |||
src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); | |||
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); | |||
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); | |||
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); | |||
src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); | |||
src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); | |||
src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); | |||
src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); | |||
src[9] = vld1q_s8(src_ic_0_3 + 9 * 16); | |||
const int8_t* read_weight_ptr = | |||
weight_ptr + fh_idx * fw * ld_weight_ic4; | |||
weight[0] = vld1q_s8(read_weight_ptr); | |||
weight[1] = vld1q_s8(read_weight_ptr + 16); | |||
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); | |||
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); | |||
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); | |||
weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); | |||
weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); | |||
#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _src5, _src6, _w, \ | |||
_c) \ | |||
do { \ | |||
tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w[0])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w[1])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src2), vget_high_s8(_w[2])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w[3])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src3), vget_high_s8(_w[3])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src4), vget_high_s8(_w[4])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src5), vget_low_s8(_w[5])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src5), vget_high_s8(_w[5])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src6), vget_low_s8(_w[6])); \ | |||
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src6), vget_high_s8(_w[6])); \ | |||
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ | |||
} while (0); | |||
int16x8_t tmp0; | |||
CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], src[5], | |||
src[6], weight, c[0][0]); | |||
CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], src[6], | |||
src[7], weight, c[0][1]); | |||
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); | |||
src[1] = vld1q_s8(src_ic_0_3 + 11 * 16); | |||
CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], src[7], | |||
src[8], weight, c[0][2]); | |||
CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], src[8], | |||
src[9], weight, c[0][3]); | |||
src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); | |||
src[3] = vld1q_s8(src_ic_0_3 + 13 * 16); | |||
CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], src[9], | |||
src[0], weight, c[0][4]); | |||
CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], src[0], | |||
src[1], weight, c[0][5]); | |||
CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], src[1], | |||
src[2], weight, c[0][6]); | |||
CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], src[2], | |||
src[3], weight, c[0][7]); | |||
} | |||
weight_ptr += fh * fw * ld_weight_ic4; | |||
} | |||
STORE_1_LINE_RESULT(); | |||
} | |||
}; | |||
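// The 7x7 specialization needs fourteen source vectors for eight outputs; | |||
// the four extra loads (offsets 10..13) are folded between the MACs and the | |||
// whole weight array is handed to the macro at once. | |||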
template <BiasMode bias_mode> | |||
void conv_direct_stride1_2x2_int8_oc8_ow8_nchw44( | |||
const int8_t* src, const int8_t* filter, const int16_t* bias, | |||
int16_t* dst, const size_t oc, const size_t ic, const size_t ih, | |||
const size_t iw, const size_t oh, const size_t ow) { | |||
constexpr size_t filter_size = 2; | |||
constexpr size_t fh = filter_size; | |||
constexpr size_t fw = filter_size; | |||
constexpr size_t ic_step = 4; | |||
constexpr size_t oc_step = 4; | |||
constexpr size_t big_oc_step = 8; | |||
constexpr size_t oh_step = 1; | |||
constexpr size_t ow_step = 8; | |||
constexpr size_t src_expand_size = 4; | |||
const size_t img_stride = oh * ow; | |||
const size_t ow_end = ow / ow_step * ow_step; | |||
const size_t ow_remain = ow - ow_end; | |||
const size_t oc_end = oc / big_oc_step * big_oc_step; | |||
const size_t oc_remain = oc - oc_end; | |||
const int ld_oc = oh * ow * oc_step; | |||
using remain_fun = | |||
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int16_t* bias_ptr, int16_t* dst_ptr, | |||
int ic, int ih, int iw, int ld_dst_oc)>; | |||
remain_fun kern_big_oc_remain = nullptr; | |||
remain_fun kern_small_oc_remain = nullptr; | |||
switch (ow_remain) { | |||
#define cb(step) \ | |||
case step: \ | |||
kern_big_oc_remain = ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, step, \ | |||
filter_size>; \ | |||
kern_small_oc_remain = \ | |||
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, step, \ | |||
filter_size>; \ | |||
break; | |||
UNROLL_CALL_RAW(8, cb); | |||
default: | |||
megdnn_assert(0, "no remain %zu for kern", ow_remain); | |||
} | |||
#undef cb | |||
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
size_t oh_idx = 0; | |||
for (; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = | |||
(oh_idx * iw + ow_idx) * ic_step * src_expand_size; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, ow_step, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, ld_oc); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = | |||
(oh_idx * iw + ow_end) * ic_step * src_expand_size; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
kern_big_oc_remain(src + src_offset, filter + weight_offset, | |||
bias + oc_idx, dst + dst_offset, ic, ih, iw, | |||
ld_oc); | |||
} | |||
} | |||
} | |||
if (oc_remain > 0) { | |||
const size_t oc_idx = oc_end; | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = | |||
(oh_idx * iw + ow_idx) * ic_step * src_expand_size; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, ow_step, filter_size>( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, ld_oc); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = | |||
(oh_idx * iw + ow_end) * ic_step * src_expand_size; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
kern_small_oc_remain(src + src_offset, filter + weight_offset, | |||
bias + oc_idx, dst + dst_offset, ic, ih, | |||
iw, ld_oc); | |||
} | |||
} | |||
} | |||
} | |||
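// Remainder dispatch sketch: ow is tiled by 8 and the UNROLL_CALL_RAW(8, cb) | |||
// switch binds the tail kernel once per call instead of branching in the hot | |||
// loop. For example, ow == 20 gives ow_end == 16 and ow_remain == 4, so the | |||
// <bias_mode, 4, filter_size> instantiation handles the last four columns. | |||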
#undef CALC_ONE_RESULT | |||
template <BiasMode bias_mode, int filter_size> | |||
void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, | |||
const int8_t* filter, | |||
const int16_t* bias, int16_t* dst, | |||
const size_t oc, const size_t ic, | |||
const size_t ih, const size_t iw, | |||
const size_t oh, const size_t ow) { | |||
constexpr size_t fh = filter_size; | |||
constexpr size_t fw = filter_size; | |||
constexpr size_t ic_step = 4; | |||
constexpr size_t oc_step = 4; | |||
constexpr size_t oh_step = 1; | |||
constexpr size_t ow_step = 8; | |||
constexpr size_t src_expand_size = 4; | |||
const size_t img_stride = oh * ow; | |||
const int ld_dst_oc = oh * ow * oc_step; | |||
const size_t ow_end = ow / ow_step * ow_step; | |||
const size_t ow_remain = ow - ow_end; | |||
using remain_fun = | |||
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int16_t* bias_ptr, int16_t* dst_ptr, | |||
int ic, int ih, int iw, int ld_dst_oc)>; | |||
remain_fun kern_small_oc_remain = nullptr; | |||
switch (ow_remain) { | |||
#define cb(step) \ | |||
case step: \ | |||
kern_small_oc_remain = KerNeonDirectStride1Int8<bias_mode, step, \ | |||
filter_size>::impl; \ | |||
break; | |||
UNROLL_CALL_RAW(8, cb); | |||
default: | |||
megdnn_assert(0, "no remain %zu for kern", ow_remain); | |||
} | |||
#undef cb | |||
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = | |||
(oh_idx * iw + ow_idx) * ic_step * src_expand_size; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
KerNeonDirectStride1Int8<bias_mode, ow_step, filter_size>::impl( | |||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||
dst + dst_offset, ic, ih, iw, ld_dst_oc); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = | |||
(oh_idx * iw + ow_end) * ic_step * src_expand_size; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
kern_small_oc_remain(src + src_offset, filter + weight_offset, | |||
bias + oc_idx, dst + dst_offset, ic, ih, | |||
iw, ld_dst_oc); | |||
} | |||
} | |||
} | |||
} | |||
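// This generic driver routes filter sizes 3/5/7 through the | |||
// KerNeonDirectStride1Int8 specializations above; 2x2 takes the dedicated | |||
// oc8/oc4 path instead (see the Choose specializations below). | |||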
} // namespace | |||
namespace int8x8x16_direct_nchw44 { | |||
template <BiasMode bias_mode, int filter_size> | |||
struct ConvDirectInt8Nchw44Choose<bias_mode, filter_size, 1> { | |||
static void impl(const int8_t* src, const int8_t* filter, | |||
const int16_t* bias, int16_t* dst, const size_t oc, | |||
const size_t ic, const size_t ih, const size_t iw, | |||
const size_t oh, const size_t ow) { | |||
conv_direct_stride1_int8_nchw44_kern<bias_mode, filter_size>( | |||
src, filter, bias, dst, oc, ic, ih, iw, oh, ow); | |||
} | |||
}; | |||
template <BiasMode bias_mode> | |||
struct ConvDirectInt8Nchw44Choose<bias_mode, 2, 1> { | |||
static void impl(const int8_t* src, const int8_t* filter, | |||
const int16_t* bias, int16_t* dst, const size_t oc, | |||
const size_t ic, const size_t ih, const size_t iw, | |||
const size_t oh, const size_t ow) { | |||
conv_direct_stride1_2x2_int8_oc8_ow8_nchw44<bias_mode>( | |||
src, filter, bias, dst, oc, ic, ih, iw, oh, ow); | |||
} | |||
}; | |||
#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode) \ | |||
template struct ConvDirectInt8Nchw44Choose<bias_mode, filter_size, stride>; | |||
#define GET_OP_PARAM(stride, filter, bias_mode) \ | |||
DO_CONV_KERN_FUN(stride, filter, bias_mode) | |||
#define GET_BIAS_MODE_PARAM(stride, filter) \ | |||
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ | |||
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | |||
#define DISPATCH_CONV_KERN(stride) \ | |||
GET_BIAS_MODE_PARAM(stride, 2) \ | |||
GET_BIAS_MODE_PARAM(stride, 3) \ | |||
GET_BIAS_MODE_PARAM(stride, 5) \ | |||
GET_BIAS_MODE_PARAM(stride, 7) | |||
DISPATCH_CONV_KERN(1); | |||
} // namespace int8x8x16_direct_nchw44 | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -44,6 +44,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
AlgoQU8DirectStride1 qu8_direct_stride1; | |||
AlgoS8DirectStride2 s8_direct_stride2; | |||
AlgoS8DirectNCHW44 s8_direct_nchw44; | |||
AlgoS8x8x16DirectNCHW44 s8x8x16_direct_nchw44; | |||
AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44; | |||
AlgoS8DirectStride1 s8_direct_stride1; | |||
AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; | |||
@@ -94,6 +95,7 @@ public: | |||
direct_algos.emplace_back(&qu8_direct_stride1); | |||
direct_algos.emplace_back(&s8_direct_stride2); | |||
direct_algos.emplace_back(&s8_direct_nchw44); | |||
direct_algos.emplace_back(&s8x8x16_direct_nchw44); | |||
direct_algos.emplace_back(&s8_direct_nchw_nchw44); | |||
direct_algos.emplace_back(&s8_direct_stride1); | |||
@@ -39,6 +39,7 @@ private: | |||
class AlgoS8DirectStride1; | |||
class AlgoS8DirectStride2; | |||
class AlgoS8DirectNCHW44; | |||
class AlgoS8x8x16DirectNCHW44; | |||
class AlgoS8DirectNCHWNCHW44; | |||
class AlgoQU8DirectStride1; | |||
class AlgoQU8DirectStride2; | |||
@@ -518,6 +518,116 @@ void benchmark_im2col_single_algo(const char* im2col_name, Handle* handle, | |||
} | |||
} | |||
void benchmark_nchw44_8x8x16_vs_8x8x32(const char* im2col_name, Handle* handle, | |||
size_t kernel, size_t stride, | |||
size_t pack_size = 1) { | |||
megdnn_assert(stride == 1 || stride == 2, "only support stride 1 or 2"); | |||
std::vector<conv_bias::TestArg> args; | |||
auto pack = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, | |||
size_t p) { | |||
if (ic % pack_size != 0 || oc % pack_size != 0) | |||
return; | |||
if (w + 2 * p < kernel || h + 2 * p < kernel) | |||
return; | |||
param::ConvBias param; | |||
param.format = param::ConvBias::Format::NCHW44; | |||
param.stride_h = stride; | |||
param.stride_w = stride; | |||
param.pad_h = p; | |||
param.pad_w = p; | |||
param.sparse = param::ConvBias::Sparse::DENSE; | |||
args.push_back(conv_bias::TestArg{ | |||
param, | |||
TensorShape{1, ic / 4, h, w, 4}, | |||
TensorShape{oc / 4, ic / 4, kernel, kernel, 4, 4}, | |||
{1, oc / 4, 1, 1, 4}}); | |||
}; | |||
pack(1, 64, 56, 56, kernel, 0); | |||
pack(8, 64, 56, 56, kernel, 0); | |||
pack(16, 64, 56, 56, kernel, 1); | |||
pack(32, 64, 56, 56, kernel, 1); | |||
pack(1, 64, 100, 100, kernel, 1); | |||
pack(8, 64, 100, 100, kernel, 1); | |||
pack(1, 64, 100, 100, kernel, 0); | |||
pack(8, 64, 100, 100, kernel, 0); | |||
pack(16, 64, 100, 100, kernel, 1); | |||
pack(32, 64, 100, 100, kernel, 1); | |||
pack(64, 64, 100, 100, kernel, 1); | |||
pack(128, 64, 100, 100, kernel, 1); | |||
pack(256, 64, 100, 100, kernel, 1); | |||
pack(512, 64, 100, 100, kernel, 1); | |||
pack(1024, 64, 100, 100, kernel, 1); | |||
pack(1, 32, 200, 200, kernel, 1); | |||
pack(8, 64, 200, 200, kernel, 1); | |||
pack(1, 32, 200, 200, kernel, 0); | |||
pack(8, 64, 200, 200, kernel, 0); | |||
pack(16, 96, 200, 200, kernel, 1); | |||
pack(32, 32, 200, 200, kernel, 1); | |||
pack(64, 64, 200, 200, kernel, 1); | |||
pack(128, 96, 200, 200, kernel, 1); | |||
pack(1, 64, 10, 10, kernel, 1); | |||
pack(8, 64, 10, 10, kernel, 1); | |||
pack(16, 64, 10, 10, kernel, 1); | |||
pack(32, 64, 10, 10, kernel, 1); | |||
pack(64, 64, 10, 10, kernel, 1); | |||
pack(128, 64, 10, 10, kernel, 1); | |||
pack(256, 64, 10, 10, kernel, 1); | |||
pack(512, 64, 10, 10, kernel, 1); | |||
pack(1024, 64, 10, 10, kernel, 1); | |||
using namespace conv_bias; | |||
constexpr size_t RUN = 20; | |||
Benchmarker<ConvBias> benchmark_im2col(handle); | |||
benchmark_im2col.set_display(false); | |||
benchmark_im2col.set_times(RUN); | |||
Benchmarker<ConvBias> benchmark_8832(handle); | |||
benchmark_8832.set_display(false); | |||
benchmark_8832.set_times(RUN); | |||
for (auto&& arg : args) { | |||
TensorLayout dst_layout; | |||
auto opr = handle->create_operator<ConvBias>(); | |||
opr->param() = arg.param; | |||
opr->deduce_layout({arg.src, dtype::Float32()}, | |||
{arg.filter, dtype::Float32()}, | |||
{arg.bias, dtype::Float32()}, {}, dst_layout); | |||
        //! FLOPs = dst.nr_elems * IC * FH * FW * 2 (filter[1] is IC/4; the | |||
        //! x4 below restores the packing), scaled so computations / ms = GFlops | |||
float computations = dst_layout.total_nr_elems() * arg.filter[1] * | |||
arg.filter[2] * arg.filter[3] * 2.0 * 4 / | |||
(1024 * 1024 * 1024) * 1e3; | |||
benchmark_im2col.set_param(arg.param); | |||
benchmark_im2col.set_dtype(0, dtype::Int8()); | |||
benchmark_im2col.set_dtype(1, dtype::Int8()); | |||
benchmark_im2col.set_dtype(2, dtype::Int16()); | |||
benchmark_im2col.set_dtype(4, dtype::Int16()); | |||
auto used_8816 = | |||
algo_benchmark<ConvBias>(benchmark_im2col, | |||
{arg.src, arg.filter, {}, {}, {}}, | |||
im2col_name) / | |||
RUN; | |||
benchmark_8832.set_param(arg.param); | |||
benchmark_8832.set_dtype(0, dtype::QuantizedS8(2.5)); | |||
benchmark_8832.set_dtype(1, dtype::QuantizedS8(2.5)); | |||
benchmark_8832.set_dtype(2, dtype::QuantizedS32(6.25)); | |||
benchmark_8832.set_dtype(4, {}); | |||
auto used_8832 = | |||
algo_benchmark<ConvBias>(benchmark_8832, | |||
{arg.src, arg.filter, {}, {}, {}}, | |||
"S8_NCHW44_DIRECT") / | |||
RUN; | |||
printf("%s %s: 8816: %f ms %f GFlops ", arg.src.to_string().c_str(), | |||
arg.filter.to_string().c_str(), used_8816, | |||
computations / used_8816); | |||
printf("%s %s: 8832: %f ms %f GFlops ", arg.src.to_string().c_str(), | |||
arg.filter.to_string().c_str(), used_8832, | |||
computations / used_8832); | |||
printf("speedup %f \n", used_8832 / used_8816); | |||
} | |||
} | |||
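// In the "speedup" line printed above, used_8832 / used_8816 > 1 means the | |||
// int8x8x16 direct kernel beat the int8x8x32 S8_NCHW44_DIRECT baseline on | |||
// that shape. | |||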
void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name, | |||
const char* im2col_name, Handle* handle, | |||
size_t kernel, DType src_type, | |||
@@ -872,6 +982,28 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_MATMUL) { | |||
#endif | |||
#if MEGDNN_WITH_BENCHMARK | |||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_8X8X16_DIRECT_STRIDE1) { | |||
benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 2, 1, | |||
4); | |||
benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 3, 1, | |||
4); | |||
benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 5, 1, | |||
4); | |||
benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 7, 1, | |||
4); | |||
} | |||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_8X8X16_DIRECT_STRIDE2) { | |||
benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 2, 2, | |||
4); | |||
benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 3, 2, | |||
4); | |||
benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 5, 2, | |||
4); | |||
benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 7, 2, | |||
4); | |||
} | |||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23) { | |||
#if MEGDNN_AARCH64 | |||
benchmark_winograd("WINOGRAD:AARCH64_F32:1:2", handle(), 3); | |||
@@ -534,11 +534,25 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) { | |||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), | |||
handle(), "S8_NCHW44_DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8816) { | |||
checker_conv_bias_int8x8x16( | |||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true), | |||
handle(), "S8x8x16_NCHW44_DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8816) { | |||
checker_conv_bias_int8x8x16( | |||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true), | |||
handle(), "S8x8x16_NCHW44_DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) { | |||
checker_conv_bias_qint8x8x32( | |||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true), | |||
handle(), "S8_NCHW44_DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) { | |||
checker_conv_bias_qint8x8x32( | |||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true), | |||