GitOrigin-RevId: 9d718eb7a4
tags/v0.5.0
@@ -37,7 +37,7 @@ static inline size_t get_perthread_cache_bytes(const int ic, const int ih2,
static void get_rectified_size(
        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2,
        int& iw2, int& oh2, int& ow2) {
-    constexpr int cacheline = 64 / sizeof(float);
+    constexpr int nr_elements_in_cacheline = 64 / sizeof(float);
    int ic = param.filter_meta.icpg;
    int iw = param.isz[1];
    int oh = param.osz[0];
@@ -52,7 +52,8 @@ static void get_rectified_size(
    int block_oh = l2_block_helper(param.nr_threads, oh,
                                   ic * iw * sizeof(float) * stride_h);
    ih2 = block_oh * stride_h + filter_h - stride_h;
-    iw2 = round_up(iw + 2 * static_cast<int>(fm.padding[1]), cacheline);
+    iw2 = round_up(iw + 2 * static_cast<int>(fm.padding[1]),
+                   nr_elements_in_cacheline);
}
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
@@ -90,9 +90,9 @@ public:
                const NCBKernSizeParam& param) const override;
};
-class ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44 final : public AlgoBase {
+class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase {
public:
-    AlgoS8DirectStride2NCHWNCHW44() {}
+    AlgoS8DirectNCHWNCHW44() {}
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "S8_CONV_NCHW_NCHW44"; }
    bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
@@ -0,0 +1,373 @@
/**
 * \file dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/oprs.h" | |||||
#include "src/arm_common/conv_bias/int8/algos.h" | |||||
#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" | |||||
#include "src/arm_common/conv_bias/int8/strategy.h" | |||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | |||||
#include "midout.h" | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
using conv_fun = std::function<void( | |||||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | |||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw_nchw44) | |||||
static void get_rectified_size( | |||||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, | |||||
int& iw2, int& oh2, int& ow2) { | |||||
auto&& fm = param.filter_meta; | |||||
int ih = param.isz[0]; | |||||
int iw = param.isz[1]; | |||||
int oh = param.osz[0]; | |||||
int ow = param.osz[1]; | |||||
int ph = fm.padding[0]; | |||||
int pw = fm.padding[1]; | |||||
int stride_h = fm.stride[0]; | |||||
oh2 = oh; | |||||
ow2 = ow; | |||||
ih2 = stride_h == 2 ? round_up(ih + 2 * ph, 2) : ih + 2 * ph; | |||||
iw2 = iw + 2 * pw; | |||||
} | |||||
static inline size_t get_temp_bytes(const int iw, const int pw) { | |||||
//! border_size is used to avoid read illegal memory | |||||
constexpr int cacheline_size = 64; | |||||
constexpr int border_size = 1 * cacheline_size; | |||||
return round_up(iw + pw * 2, cacheline_size) + border_size; | |||||
} | |||||
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||||
auto&& fm = param.filter_meta; | |||||
int group = fm.group; | |||||
int batch = param.n; | |||||
int ic = fm.icpg; | |||||
int oc = fm.ocpg; | |||||
int fh = fm.spatial[0]; | |||||
int fw = fm.spatial[1]; | |||||
int stride_h = fm.stride[0]; | |||||
int iw = param.isz[1]; | |||||
int pw = fm.padding[1]; | |||||
int ih2, iw2, oh2, ow2; | |||||
const size_t src_expand = stride_h == 2 ? 4 : 16; | |||||
get_rectified_size(param, ih2, iw2, oh2, ow2); | |||||
megdnn_assert(group == 1, "only support group == 1 now"); | |||||
size_t src_size = | |||||
batch * group * ic * ih2 * iw2 * sizeof(int8_t) * src_expand; | |||||
size_t weight_size = group * oc * ic * fh * fw * sizeof(int8_t); | |||||
size_t tmp_size = 0; | |||||
if (stride_h == 1) { | |||||
weight_size = group * oc * ic * fh * round_up(fw, 4) * sizeof(int8_t); | |||||
tmp_size = get_temp_bytes(iw, pw); | |||||
} | |||||
return {nullptr, {src_size, weight_size, tmp_size * param.nr_threads}}; | |||||
}; | |||||
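For concreteness, here is a small standalone sketch (an editorial illustration, not part of the patch) that replays the sizing formulas above for one hypothetical stride-1 shape; the sample numbers and the local round_up_i helper are assumptions made only for this example.

// Illustrative sketch only: mirrors get_rectified_size / get_temp_bytes /
// get_bundle for a hypothetical batch=1, group=1, ic=3, oc=32, 3x3, stride=1,
// 224x224 input with padding 1.
#include <cstddef>
#include <cstdint>
#include <cstdio>

static int round_up_i(int v, int align) {
    return (v + align - 1) / align * align;
}

int main() {
    const int batch = 1, group = 1, ic = 3, oc = 32, fh = 3, fw = 3;
    const int ih = 224, iw = 224, ph = 1, pw = 1, stride_h = 1, nr_threads = 4;
    // stride 1 expands every packed source byte 16x, stride 2 only 4x
    const size_t src_expand = stride_h == 2 ? 4 : 16;
    const int ih2 = stride_h == 2 ? round_up_i(ih + 2 * ph, 2) : ih + 2 * ph;
    const int iw2 = iw + 2 * pw;
    const size_t src_size =
            size_t(batch) * group * ic * ih2 * iw2 * sizeof(int8_t) * src_expand;
    // for stride 1 the packed weight pads fw up to a multiple of 4
    const size_t weight_size =
            size_t(group) * oc * ic * fh * round_up_i(fw, 4) * sizeof(int8_t);
    // per-thread temp row: padded width rounded to a cache line plus a 64B border
    const size_t tmp_size = round_up_i(iw + 2 * pw, 64) + 64;
    std::printf("src=%zu weight=%zu tmp=%zu\n", src_size, weight_size,
                tmp_size * nr_threads);
    return 0;
}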
static void copy_padding_kern(WorkspaceBundle bundle,
                              const ConvBiasImpl::NCBKernParam& kern_param,
                              const ConvBiasImpl::NCBKernIndex& ncb_index,
                              const CpuNDRange& workspace_ids) {
    int ih = kern_param.isz[0];
    int iw = kern_param.isz[1];
    int ic = kern_param.filter_meta.icpg;
    int ph = kern_param.filter_meta.padding[0];
    int pw = kern_param.filter_meta.padding[1];
    int group = kern_param.filter_meta.group;
    int stride_h = kern_param.filter_meta.stride[0];

    int ih2, iw2, oh2, ow2;
    get_rectified_size(kern_param, ih2, iw2, oh2, ow2);
    int padding_group_size = ih2 * iw2 * ic;
    bundle.set(kern_param.workspace_ptr);

    //! src_expand is used to compute the workspace offset
    const int src_expand = stride_h == 2 ? 4 : 16;
    //! TODO: the block dim should be passed in as an argument
    int workspace_ic_block = 1;
    int workspace_batch_id = workspace_ids[0];
    int workspace_group_id = workspace_ids[1];
    int workspace_ic_id = workspace_ids[2];
    int workspace_ic = workspace_ic_id * workspace_ic_block;
    int batch_id = ncb_index.ndrange_id[0];
    int group_id = ncb_index.ndrange_id[1];

    const int8_t* sptr = static_cast<const int8_t*>(
            kern_param.src<int8_t>(batch_id, group_id, workspace_ic_id, 1, 1));
    //! copy to sptr_base to eliminate the padding effect
    int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) +
                        (workspace_batch_id * group * padding_group_size +
                         workspace_group_id * padding_group_size +
                         workspace_ic * ih2 * iw2) *
                                src_expand;
    if (stride_h == 1) {
        const size_t tmp_size = get_temp_bytes(iw, pw);
        int8_t* tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) +
                          ncb_index.thread_id * tmp_size;
        pack_nchw_src_for_nchw44_conv<1>(sptr, sptr_base, 1, ph, ph, pw, pw, ih,
                                         iw, iw2, pw, tmp_ptr);
    } else {
        pack_nchw_src_for_nchw44_conv<2>(sptr, sptr_base, 1, ph, ph, pw, pw, ih,
                                         iw, iw2, pw, nullptr);
    }
}
static void pack_weight(WorkspaceBundle bundle,
                        const ConvBiasImpl::NCBKernParam& kern_param,
                        const ConvBiasImpl::NCBKernIndex& ncb_index) {
    bundle.set(kern_param.workspace_ptr);
    const int group_id = ncb_index.ndrange_id[0];
    int fh = kern_param.filter_meta.spatial[0];
    int fw = kern_param.filter_meta.spatial[1];
    int oc = kern_param.filter_meta.ocpg;
    int ic = kern_param.filter_meta.icpg;
    int stride_h = kern_param.filter_meta.stride[0];
    int fw2 = stride_h == 2 ? fw : round_up(fw, 4);
    int oc_block = oc;
    int oc_idx = 0;
    const int8_t* fptr =
            kern_param.filter<dt_int8>(group_id) + oc_idx * fh * fw * ic;
    auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
                         group_id * oc * ic * fh * fw2 + oc_idx * ic * fh * fw2;

    if (stride_h == 1) {
        pack_nchw44_weight_for_nchw_conv<1>(fptr, packed_weight, ic, fh, fw,
                                            oc_block);
    } else {
        pack_nchw44_weight_for_nchw_conv<2>(fptr, packed_weight, ic, fh, fw,
                                            oc_block);
    }
}
template <size_t filter, BiasMode bias_mode, typename Op, int stride>
static void do_conv_kern(WorkspaceBundle bundle,
                         const ConvBiasImpl::NCBKernParam& kern_param,
                         const ConvBiasImpl::NCBKernIndex& ncb_index,
                         const CpuNDRange& workspace_ids,
                         const CpuNDRange& ncb_range) {
    int oh = kern_param.osz[0];
    int ow = kern_param.osz[1];
    int fh = kern_param.filter_meta.spatial[0];
    int fw = kern_param.filter_meta.spatial[1];
    int fw2 = stride == 2 ? fw : round_up(fw, 4);
    int ic = kern_param.filter_meta.icpg;
    int oc = kern_param.filter_meta.ocpg;
    int group = kern_param.filter_meta.group;
    int ih2, iw2, oh2, ow2;
    get_rectified_size(kern_param, ih2, iw2, oh2, ow2);
    bool need_post_process =
            kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
    //! if dst_type is qint32, the op is not used; just fill with (1.0f, 4.0f)
    Op op = Op(1.0f, 4.0f);
    if (need_post_process) {
        float scale_bias =
                kern_param.bias_type.param<dtype::QuantizedS32>().scale;
        float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
        op = Op(scale_bias, scale_dst);
    }
    int padding_group_size = ih2 * iw2 * ic;
    bundle.set(kern_param.workspace_ptr);

    constexpr int pack_c = 4;
    constexpr int src_expand_size = stride == 2 ? 4 : 16;
    const int workspace_batch_id = workspace_ids[0];
    const int workspace_group_id = workspace_ids[1];
    const int batch_id = ncb_index.ndrange_id[0];
    const int group_id = ncb_index.ndrange_id[1];
    const int oc_id = ncb_index.ndrange_id[2];
    const int oc_block_num = ncb_range[2];
    int nr_pack_per_step = div_ceil(div_ceil(oc, pack_c), oc_block_num);
    int oc_block = nr_pack_per_step * pack_c;
    const int oc_idx = oc_id * oc_block;
    if (oc_id == (oc_block_num - 1)) {
        oc_block = oc - oc_id * nr_pack_per_step * pack_c;
    }
    megdnn_assert(oc_block % pack_c == 0,
                  "oc must be divisible by 4, but oc = %d", oc_block);

    const int8_t* sptr =
            static_cast<int8_t*>(bundle.get(0)) +
            workspace_batch_id * group * padding_group_size * src_expand_size +
            workspace_group_id * padding_group_size * src_expand_size;

    int8_t* dst = reinterpret_cast<int8_t*>(
            reinterpret_cast<ptrdiff_t>(
                    kern_param.dst<void>(batch_id, group_id)) +
            oc_idx * oh * ow);

    const int32_t* bptr =
            kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx;
    int8_t* packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
                            group_id * oc * ic * fh * fw2 +
                            oc_idx * ic * fh * fw2;
    conv_direct_int8_nchw_nchw44<bias_mode, Op, filter, stride>(
            sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih2, iw2, oh,
            ow, op);
}
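The output-channel blocking above is easiest to see with numbers; the following standalone sketch (editorial, not part of the patch) replays the nr_pack_per_step / oc_block arithmetic for a hypothetical OC of 20.

// Illustrative sketch only: the oc blocking used in do_conv_kern above,
// for a hypothetical OC=20 split into div_ceil(20, 8) = 3 conv kernels.
#include <cstdio>

static int div_ceil_i(int a, int b) { return (a + b - 1) / b; }

int main() {
    const int oc = 20, pack_c = 4;
    const int oc_block_num = div_ceil_i(oc, 8);  // matches oc_step in dispatch_kerns
    const int nr_pack_per_step = div_ceil_i(div_ceil_i(oc, pack_c), oc_block_num);
    for (int oc_id = 0; oc_id < oc_block_num; ++oc_id) {
        int oc_block = nr_pack_per_step * pack_c;
        const int oc_idx = oc_id * oc_block;
        if (oc_id == oc_block_num - 1) {
            // the last kernel takes whatever channels remain
            oc_block = oc - oc_id * nr_pack_per_step * pack_c;
        }
        std::printf("kernel %d: oc_idx=%d oc_block=%d\n", oc_id, oc_idx, oc_block);
    }
    return 0;
}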
bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
        AlgoSelectionStrategy algo_selection_strategy) const {
    MEGDNN_MARK_USED_VAR(algo_selection_strategy);
    auto&& fm = param.filter_meta;
    auto FH = fm.spatial[0];
    auto OC = fm.ocpg;
    bool avaible =  //! src and filter are qint8, dst is qint8
            fm.icpg < 4 &&  // must be nchw input
            ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
              param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
              (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) &&
            (fm.format == param::Convolution::Format::NCHW44) &&
            (OC % 4 == 0 && OC >= 4) && !fm.should_flip && fm.group == 1 &&
            fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
            fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] &&
            (fm.stride[0] == 1 || fm.stride[0] == 2) && FH == fm.spatial[1] &&
            (FH == 2 || FH == 3 || FH == 5 || FH == 7) && fm.group == 1 &&
            param.bias_mode != BiasMode::BIAS;
    return avaible;
}

bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::is_preferred(
        megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr,
        const NCBKernSizeParam& param) const {
    // TODO: benchmark and fix
    MEGDNN_MARK_USED_VAR(conv_bias_impl_ptr);
    MEGDNN_MARK_USED_VAR(param);
    return false;
}
size_t ConvBiasImpl::AlgoS8DirectNCHWNCHW44::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    return get_bundle(param).total_size_in_bytes();
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    auto fm = param.filter_meta;
    size_t N = param.n;
    size_t OC = fm.ocpg;
    size_t group = fm.group;
    WorkspaceBundle wbundle = get_bundle(param);
    conv_fun do_conv_fun = nullptr;
    // NOTE: remain_w is not used to generate the midout hash, so that the hash
    // stays compatible with shapes that change at runtime
#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op)              \
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44,       \
                 midout_iv(#stride #filter #bias_mode #op##_hash)) { \
        do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>;   \
    }                                                                \
    MIDOUT_END();

#define GET_OP_PARAM(stride, filter, bias_mode)                          \
    switch (param.nonlineMode) {                                         \
        case param::ConvBias::NonlineMode::IDENTITY:                     \
            DO_CONV_KERN_FUN(stride, filter, bias_mode,                  \
                             TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
            break;                                                       \
        case param::ConvBias::NonlineMode::RELU:                         \
            DO_CONV_KERN_FUN(stride, filter, bias_mode,                  \
                             ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
            break;                                                       \
        case param::ConvBias::NonlineMode::H_SWISH:                      \
            DO_CONV_KERN_FUN(stride, filter, bias_mode,                  \
                             HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)  \
            break;                                                       \
        default:                                                         \
            megdnn_assert(0);                                            \
            break;                                                       \
    }

#define GET_BIAS_MODE_PARAM(stride, filter)                                \
    switch (param.bias_mode) {                                             \
        case BiasMode::NO_BIAS:                                            \
            GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS)                \
            break;                                                         \
        case BiasMode::BROADCAST_CHANNEL_BIAS:                             \
            GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
            break;                                                         \
        default:                                                           \
            megdnn_assert(0);                                              \
            break;                                                         \
    }

#define DISPATCH_CONV_KERN(stride)              \
    switch (param.filter_meta.spatial[0]) {     \
        case 2:                                 \
            GET_BIAS_MODE_PARAM(stride, 2)      \
            break;                              \
        case 3:                                 \
            GET_BIAS_MODE_PARAM(stride, 3)      \
            break;                              \
        case 5:                                 \
            GET_BIAS_MODE_PARAM(stride, 5)      \
            break;                              \
        case 7:                                 \
            GET_BIAS_MODE_PARAM(stride, 7)      \
            break;                              \
        default:                                \
            megdnn_assert(0);                   \
            break;                              \
    }
    switch (param.filter_meta.stride[0]) {
        case 1:
            DISPATCH_CONV_KERN(1);
            break;
        case 2:
            DISPATCH_CONV_KERN(2);
            break;
        default:
            megdnn_throw(ssprintf("Unsupported stride size %u for the first conv",
                                  param.filter_meta.stride[0])
                                 .c_str());
            break;
    }
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN

    megdnn_assert(do_conv_fun);

    SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
    WorkspaceBundle bundle = wbundle;
    constexpr size_t pack_oc = 8;
    size_t oc_step = pack_oc;
    auto copy_padding = [bundle](const NCBKernParam& kern_param,
                                 const NCBKernIndex& ncb_index) {
        copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id);
    };
    ret_kerns.push_back({copy_padding, {N, group, fm.icpg}});

    auto do_pack_weight = [bundle](const NCBKernParam& kern_param,
                                   const NCBKernIndex& ncb_index) {
        pack_weight(bundle, kern_param, ncb_index);
    };
    ret_kerns.push_back({do_pack_weight, {static_cast<size_t>(group)}});

    CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)};
    auto do_conv = [bundle, do_conv_fun, ncb_range](
                           const NCBKernParam& kern_param,
                           const NCBKernIndex& ncb_index) {
        do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id,
                    ncb_range);
    };
    ret_kerns.push_back({do_conv, ncb_range});
    return ret_kerns;
}
// vim: syntax=cpp.doxygen
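As a rough editorial illustration (not part of the patch), the sketch below prints the NCB ranges that dispatch_kerns above would schedule for a hypothetical N=2, group=1, icpg=3, OC=32 case.

// Illustrative sketch only: the three kernel ranges built in dispatch_kerns.
#include <cstddef>
#include <cstdio>

static size_t div_ceil_z(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
    const size_t N = 2, group = 1, icpg = 3, OC = 32, oc_step = 8;
    // one copy_padding task per (batch, group, input channel)
    std::printf("copy_padding range {%zu, %zu, %zu} -> %zu tasks\n", N, group,
                icpg, N * group * icpg);
    // weights are packed once per group
    std::printf("pack_weight  range {%zu}       -> %zu tasks\n", group, group);
    // conv tasks split the output channels into blocks of 8
    const size_t oc_blocks = div_ceil_z(OC, oc_step);
    std::printf("do_conv      range {%zu, %zu, %zu} -> %zu tasks\n", N, group,
                oc_blocks, N * group * oc_blocks);
    return 0;
}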
@@ -1,305 +0,0 @@
/**
 * \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_algo.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/oprs.h" | |||||
#include "src/arm_common/conv_bias/int8/algos.h" | |||||
#include "src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h" | |||||
#include "src/arm_common/conv_bias/int8/strategy.h" | |||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | |||||
#include "midout.h" | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
using conv_fun = std::function<void( | |||||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | |||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw_nchw44_stride2) | |||||
static void get_rectified_size( | |||||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||||
size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { | |||||
auto&& fm = param.filter_meta; | |||||
size_t IH = param.isz[0]; | |||||
size_t IW = param.isz[1]; | |||||
size_t OH = param.osz[0]; | |||||
size_t OW = param.osz[1]; | |||||
OH2 = OH; | |||||
OW2 = OW; | |||||
IH2 = round_up(IH + 2 * fm.padding[0], static_cast<size_t>(2)); | |||||
IW2 = IW + 2 * fm.padding[1]; | |||||
} | |||||
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||||
constexpr size_t src_expand = 4; | |||||
auto&& fm = param.filter_meta; | |||||
size_t group = fm.group; | |||||
size_t batch = param.n; | |||||
size_t IC = fm.icpg; | |||||
size_t OC = fm.ocpg; | |||||
size_t FH = fm.spatial[0]; | |||||
size_t FW = fm.spatial[1]; | |||||
size_t IH2, IW2, OH2, OW2; | |||||
get_rectified_size(param, IH2, IW2, OH2, OW2); | |||||
megdnn_assert(group == 1, "only support group == 1 now"); | |||||
size_t src_size = | |||||
batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; | |||||
size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); | |||||
return {nullptr, {src_size, weight_size}}; | |||||
}; | |||||
static void copy_padding_kern(WorkspaceBundle bundle,
                              const ConvBiasImpl::NCBKernParam& kern_param,
                              const ConvBiasImpl::NCBKernIndex& ncb_index,
                              const CpuNDRange& workspace_ids) {
    size_t IH = kern_param.isz[0];
    size_t IW = kern_param.isz[1];
    size_t IC = kern_param.filter_meta.icpg;
    size_t PH = kern_param.filter_meta.padding[0];
    size_t PW = kern_param.filter_meta.padding[1];
    size_t GROUP = kern_param.filter_meta.group;

    size_t IH2, IW2, OH2, OW2;
    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
    size_t padding_group_size = IH2 * IW2 * IC;
    bundle.set(kern_param.workspace_ptr);

    //! expend_element is used to compute the workspace offset
    constexpr int expend_element = 4;
    // TODO: the block dim should be passed in as an argument
    size_t workspace_ic_block = 1;
    size_t workspace_batch_id = workspace_ids[0];
    size_t workspace_group_id = workspace_ids[1];
    size_t workspace_ic_id = workspace_ids[2];
    size_t workspace_ic = workspace_ic_id * workspace_ic_block;
    size_t batch_id = ncb_index.ndrange_id[0];
    size_t group_id = ncb_index.ndrange_id[1];

    const int8_t* sptr = static_cast<const int8_t*>(
            kern_param.src<int8_t>(batch_id, group_id, workspace_ic_id, 1, 1));
    //! copy to sptr_base to eliminate the padding effect
    int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) +
                        (workspace_batch_id * GROUP * padding_group_size +
                         workspace_group_id * padding_group_size +
                         workspace_ic * IH2 * IW2) *
                                expend_element;
    conv_bias::pack_nchw_src_for_nchw44_conv(sptr, sptr_base, 1, PH, PH, PW, PW,
                                             IH, IW);
}
template <size_t filter, BiasMode bias_mode, typename Op>
static void do_conv_kern(WorkspaceBundle bundle,
                         const ConvBiasImpl::NCBKernParam& kern_param,
                         const ConvBiasImpl::NCBKernIndex& ncb_index,
                         const CpuNDRange& workspace_ids,
                         const CpuNDRange& ncb_range) {
    size_t OH = kern_param.osz[0];
    size_t OW = kern_param.osz[1];
    size_t FH = kern_param.filter_meta.spatial[0];
    size_t FW = kern_param.filter_meta.spatial[1];
    size_t IC = kern_param.filter_meta.icpg;
    size_t OC = kern_param.filter_meta.ocpg;
    size_t GROUP = kern_param.filter_meta.group;
    size_t IH2, IW2, OH2, OW2;
    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
    bool need_post_process =
            kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
    //! if dst_type is qint32, the op is not used; just fill with (1.0f, 4.0f)
    Op op = Op(1.0f, 4.0f);
    if (need_post_process) {
        float scale_bias =
                kern_param.bias_type.param<dtype::QuantizedS32>().scale;
        float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
        op = Op(scale_bias, scale_dst);
    }
    size_t padding_group_size = IH2 * IW2 * IC;
    bundle.set(kern_param.workspace_ptr);

    constexpr size_t pack_c = 4;
    constexpr size_t src_expand_size = 4;
    const size_t workspace_batch_id = workspace_ids[0];
    const size_t workspace_group_id = workspace_ids[1];
    const size_t batch_id = ncb_index.ndrange_id[0];
    const size_t group_id = ncb_index.ndrange_id[1];
    const size_t oc_id = ncb_index.ndrange_id[2];
    const size_t oc_block_num = ncb_range[2];
    size_t nr_pack_per_step = div_ceil(div_ceil(OC, pack_c), oc_block_num);
    size_t oc_block = nr_pack_per_step * pack_c;
    const size_t oc_idx = oc_id * oc_block;
    if (oc_id == (oc_block_num - 1)) {
        oc_block = OC - oc_id * nr_pack_per_step * pack_c;
    }
    megdnn_assert(oc_block % pack_c == 0,
                  "oc must be divisible by 4, but oc = %zu", oc_block);

    const int8_t* sptr =
            static_cast<int8_t*>(bundle.get(0)) +
            workspace_batch_id * GROUP * padding_group_size * src_expand_size +
            workspace_group_id * padding_group_size * src_expand_size;

    const int8_t* fptr =
            kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC;
    void* dst = reinterpret_cast<void*>(
            reinterpret_cast<ptrdiff_t>(
                    kern_param.dst<void>(batch_id, group_id)) +
            oc_idx * OH * OW);
    const int32_t* bptr =
            kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx;
    auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
                         group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW;
    conv_bias::pack_nchw44_weight_for_nchw_conv(fptr, packed_weight, IC, FH, FW,
                                                oc_block);
#define KERN1_NCHW44_CONV(filter)                                              \
    conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw_nchw44<     \
            bias_mode, Op>(sptr, packed_weight, bptr, nullptr,                 \
                           static_cast<int8_t*>(dst), oc_block, IC, IH2, IW2,  \
                           OH, OW, op)
    DISPATCH_FILTER(filter, KERN1_NCHW44_CONV);
#undef KERN1_NCHW44_CONV
}
/* ===================== stride2 algo ===================== */
bool ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::usable(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
        AlgoSelectionStrategy algo_selection_strategy) const {
    MEGDNN_MARK_USED_VAR(algo_selection_strategy);
    auto&& fm = param.filter_meta;
    auto FH = fm.spatial[0];
    auto OC = fm.ocpg;
    bool avaible =  //! src and filter are qint8, dst is qint8
            fm.icpg < 4 &&  // must be nchw input
            ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
              param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
              (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) &&
            (fm.format == param::Convolution::Format::NCHW44) &&
            (OC % 4 == 0 && OC >= 4) && !fm.should_flip && fm.group == 1 &&
            fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
            fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
            FH == fm.spatial[1] && (FH == 3 || FH == 5 || FH == 7) &&
            fm.group == 1 && param.bias_mode != BiasMode::BIAS;
    return avaible;
}

bool ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::is_preferred(
        megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr,
        const NCBKernSizeParam& param) const {
    // TODO: benchmark and fix
    MEGDNN_MARK_USED_VAR(conv_bias_impl_ptr);
    MEGDNN_MARK_USED_VAR(param);
    return false;
}

size_t ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    return get_bundle(param).total_size_in_bytes();
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    auto fm = param.filter_meta;
    size_t N = param.n;
    size_t OC = fm.ocpg;
    size_t group = fm.group;
    WorkspaceBundle wbundle = get_bundle(param);
    conv_fun do_conv_fun = nullptr;
    // NOTE: remain_w is not used to generate the midout hash, so that the hash
    // stays compatible with shapes that change at runtime
#define DO_CONV_KERN_FUN(filter, bias_mode, op)                        \
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44_stride2, \
                 midout_iv(#filter #bias_mode #op##_hash)) {           \
        do_conv_fun = do_conv_kern<filter, bias_mode, op>;             \
    }                                                                  \
    MIDOUT_END();

#define GET_OP_PARAM(filter, bias_mode)                                  \
    switch (param.nonlineMode) {                                         \
        case param::ConvBias::NonlineMode::IDENTITY:                     \
            DO_CONV_KERN_FUN(filter, bias_mode,                          \
                             TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
            break;                                                       \
        case param::ConvBias::NonlineMode::RELU:                         \
            DO_CONV_KERN_FUN(filter, bias_mode,                          \
                             ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
            break;                                                       \
        case param::ConvBias::NonlineMode::H_SWISH:                      \
            DO_CONV_KERN_FUN(filter, bias_mode,                          \
                             HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)  \
            break;                                                       \
        default:                                                         \
            megdnn_assert(0);                                            \
            break;                                                       \
    }

#define GET_BIAS_MODE_PARAM(filter)                                \
    switch (param.bias_mode) {                                     \
        case BiasMode::NO_BIAS:                                    \
            GET_OP_PARAM(filter, BiasMode::NO_BIAS)                \
            break;                                                 \
        case BiasMode::BROADCAST_CHANNEL_BIAS:                     \
            GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
            break;                                                 \
        default:                                                   \
            megdnn_assert(0);                                      \
            break;                                                 \
    }

#define DISPATCH_CONV_KERN()                    \
    switch (param.filter_meta.spatial[0]) {     \
        case 3:                                 \
            GET_BIAS_MODE_PARAM(3)              \
            break;                              \
        case 5:                                 \
            GET_BIAS_MODE_PARAM(5)              \
            break;                              \
        case 7:                                 \
            GET_BIAS_MODE_PARAM(7)              \
            break;                              \
        default:                                \
            megdnn_assert(0);                   \
            break;                              \
    }

    DISPATCH_CONV_KERN();

#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN

    megdnn_assert(do_conv_fun);

    SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
    WorkspaceBundle bundle = wbundle;
    constexpr size_t pack_oc = 8;
    size_t oc_step = pack_oc;
    auto copy_padding = [bundle](const NCBKernParam& kern_param,
                                 const NCBKernIndex& ncb_index) {
        copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id);
    };
    ret_kerns.push_back({copy_padding, {N, group, fm.icpg}});
    CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)};
    auto do_conv = [bundle, do_conv_fun, ncb_range](
                           const NCBKernParam& kern_param,
                           const NCBKernIndex& ncb_index) {
        do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id,
                    ncb_range);
    };
    ret_kerns.push_back({do_conv, ncb_range});
    return ret_kerns;
}

// vim: syntax=cpp.doxygen
@@ -1,789 +0,0 @@
/**
 * \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern_nchw.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h" | |||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | |||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
template <int src_idx, int weight_idx, int c_dim, typename Func, typename T, | |||||
typename T2, typename T3, typename T4> | |||||
struct ShiftCalHelper { | |||||
static void impl(T& c, T2& src, T3& weight, T4& temp); | |||||
static void impl(T& c, T2& src, T3& weight); | |||||
}; | |||||
template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||||
typename T3, typename T4> | |||||
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, T, T2, T3, T4> { | |||||
static void impl(T& c, T2& src, T3& weight, T4& temp) { | |||||
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], | |||||
temp[0]); | |||||
c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0], | |||||
temp[1]); | |||||
c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], | |||||
temp[2]); | |||||
c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1], | |||||
temp[3]); | |||||
c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], | |||||
temp[0]); | |||||
c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2], | |||||
temp[1]); | |||||
c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], | |||||
temp[2]); | |||||
c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3], | |||||
temp[3]); | |||||
} | |||||
static void impl(T& c, T2& src, T3& weight) { | |||||
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); | |||||
c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0]); | |||||
c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]); | |||||
c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1]); | |||||
c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]); | |||||
c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2]); | |||||
c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3]); | |||||
c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3]); | |||||
} | |||||
}; | |||||
template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||||
typename T3, typename T4> | |||||
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, T, T2, T3, T4> { | |||||
static void impl(T& c, T2& src, T3& weight, T4& temp) { | |||||
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], | |||||
temp[0]); | |||||
c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], | |||||
temp[2]); | |||||
c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], | |||||
temp[0]); | |||||
c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], | |||||
temp[2]); | |||||
} | |||||
static void impl(T& c, T2& src, T3& weight) { | |||||
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); | |||||
c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]); | |||||
c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]); | |||||
c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3]); | |||||
} | |||||
}; | |||||
template <int src_idx, int weight_idx, int c_dim, typename FUNC, typename T, | |||||
typename T2, typename T3, typename T4> | |||||
inline void cal_helper(T& c, T2& src, T3& weight, T4& temp) { | |||||
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, T, T2, T3, T4>::impl( | |||||
c, src, weight, temp); | |||||
} | |||||
template <int src_idx, int weight_idx, int c_dim, typename FUNC, typename T, | |||||
typename T2, typename T3> | |||||
inline void cal_helper(T& c, T2& src, T3& weight) { | |||||
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, T, T2, T3, int>::impl( | |||||
c, src, weight); | |||||
}; | |||||
template <int oc> | |||||
struct OCHelper { | |||||
public: | |||||
static const int val = 0; | |||||
}; | |||||
template <> | |||||
struct OCHelper<4> { | |||||
public: | |||||
static const int val = 1; | |||||
}; | |||||
template <> | |||||
struct OCHelper<8> { | |||||
public: | |||||
static const int val = 2; | |||||
}; | |||||
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size, | |||||
int oc_block> | |||||
struct KerNeonXXs2NchwNchw44 { | |||||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
int iw, int ld_dst_oc, const Op& op); | |||||
}; | |||||
/**
 * filter shape = (oc/4, ic, 7, 7, 4), the first 4 oc form f0 = filter[0, 0, :, :, :]
 * calculation sequence:
 *     f0[0:1, 0:1, 4] dot4,
 *     f0[0:1, 2:3, 4] dot4,
 *     f0[0:1, 4:5, 4] dot4,
 *     f0[0:1, 6, 4] dot2,
 *     ...
 *     f0[6, 0:1, 4] dot2,
 *     f0[6, 2:3, 4] dot2,
 *     f0[6, 4:5, 4] dot2,
 *     f0[6, 6, 4] dot1
 * which looks like:
 * |---|---|---|-|
 * |x x|x x|x x|x|
 * |x x|x x|x x|x|
 * |---|---|---|-|
 * |x x|x x|x x|x|
 * |x x|x x|x x|x|
 * |---|---|---|-|
 * |x x|x x|x x|x|
 * |x x|x x|x x|x|
 * |---|---|---|-|
 * |x x|x x|x x|x|
 * |---|---|---|-|
 **/
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8,
                                                   0, 8, 0, 8, 0, 8, 0, 8};
        constexpr int filter_size = 7;
        constexpr int ic_step = 1;
        constexpr int oc_step = 4;
        constexpr int pack_iw_len = 4;
        constexpr int fh_step = 2;
        constexpr int fh_end = filter_size / fh_step * fh_step;
        constexpr int c_dim = OCHelper<oc_block>::val;

        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic;

        int32x4_t c[c_dim][4];
        init_ocx_ow4<c_dim, bias_mode>(c, bias_ptr, oc_step);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
            for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) {
                const int8_t* nchw_src_ptr =
                        src_ptr + ic_idx * ic_stride +
                        fh_idx * iw * ic_step * pack_iw_len;
                int8x16_t src[6];
                int8x16_t dot4_weight[c_dim][3];
                int16x8_t temp_c[4];
                load_helper<3, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr,
                                                       ld_dot4_weight_oc);
                load_helper<6, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0);
                cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
                                                     temp_c);
                cal_helper<1, 1, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
                                                     temp_c);
                cal_helper<2, 2, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
                                                     temp_c);

                int8x8_t src_dot2[4];
                int8x8_t dot2_weight[c_dim][1];
                load_helper<1, 3 * 16, 8, c_dim, Vld1_s8>(
                        dot2_weight, weight_ptr, ld_dot4_weight_oc);
                load_helper<4, 3 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr,
                                                       0);
                cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
                                                     temp_c);
                weight_ptr += filter_size * pack_iw_len * fh_step;
            }

            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride +
                                         6 * iw * ic_step * pack_iw_len;
            int8x8_t dot2_weight[c_dim][3];
            int16x8_t temp_c[4];
            int8x8_t src_dot2[6];
            uint8x16_t tbl = vld1q_u8(src_idx_buffer);
            load_helper<3, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr,
                                                 ld_dot4_weight_oc);
            load_helper_x<6, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr,
                                                        0, tbl);
            cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
                                                 temp_c);
            cal_helper<1, 1, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
                                                 temp_c);
            cal_helper<2, 2, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
                                                 temp_c);

            int16x8_t dot1_weight[c_dim][1];
            int16x8_t src_dot1[4];
            load_helper<1, 3 * 8, 8, c_dim, Vldq_dup_4s8_8s16>(
                    dot1_weight, weight_ptr, ld_dot4_weight_oc);
            load_helper<4, 3 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1,
                                                           nchw_src_ptr, 0);
            cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight);
            weight_ptr += filter_size * pack_iw_len;
        }
        store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc);
    }
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int filter_size = 5;
        static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8,
                                                   0, 8, 0, 8, 0, 8, 0, 8};
        constexpr int ih_step = 2;
        constexpr int ic_step = 1;
        constexpr int oc_step = 4;
        constexpr int pack_iw_len = 4;
        constexpr int fh_end = filter_size / ih_step * ih_step;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;

        int32x4_t c[c_dim][4];
        init_ocx_ow4<c_dim, bias_mode>(c, bias_ptr, oc_step);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
            for (int fh_idx = 0; fh_idx < fh_end; fh_idx += ih_step) {
                const int8_t* nchw_src_ptr =
                        src_ptr + ic_idx * ic_stride +
                        fh_idx * iw * ic_step * pack_iw_len;
                int8x16_t src[5];
                int8x16_t dot4_weight[c_dim][2];
                int16x8_t temp_c[4];
                load_helper<2, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr,
                                                       ld_dot4_weight_oc);
                load_helper<5, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0);
                cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
                                                     temp_c);
                cal_helper<1, 1, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
                                                     temp_c);

                int8x8_t src_dot2[4];
                int8x8_t dot2_weight[c_dim][1];
                load_helper<1, 2 * 16, 8, c_dim, Vld1_s8>(
                        dot2_weight, weight_ptr, ld_dot4_weight_oc);
                load_helper<4, 2 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr,
                                                       0);
                cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
                                                     temp_c);
                weight_ptr += filter_size * pack_iw_len * ih_step;
            }

            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride +
                                         fh_end * iw * ic_step * pack_iw_len;
            int8x8_t dot2_weight[c_dim][2];
            int16x8_t temp_c[4];
            int8x8_t src_dot2[5];
            uint8x16_t tbl = vld1q_u8(src_idx_buffer);
            load_helper<2, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr,
                                                 ld_dot4_weight_oc);
            load_helper_x<5, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr,
                                                        0, tbl);
            cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
                                                 temp_c);
            cal_helper<1, 1, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
                                                 temp_c);

            int16x8_t dot1_weight[c_dim][1];
            int16x8_t src_dot1[4];
            load_helper<1, 2 * 8, 8, c_dim, Vldq_dup_4s8_8s16>(
                    dot1_weight, weight_ptr, ld_dot4_weight_oc);
            load_helper<4, 2 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1,
                                                           nchw_src_ptr, 0);
            cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight);
            weight_ptr += filter_size * pack_iw_len;
        }
        store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc);
    }
};
/**
 * filter shape = (oc/4, ic, 3, 3, 4), the first 4 oc form f0 = filter[0, 0, :, :, :]
 * calculation sequence:
 *     f0[0:1, 0:1, 4] dot4,
 *     f0[0:1, 2, 4] dot2,
 *     f0[2, 0:1, 4] dot2,
 *     f0[2, 2, 4] dot1
 * which looks like:
 * |---|-|
 * |x x|x|
 * |x x|x|
 * |-----|
 * |x x|x|
 * |-----|
 **/
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int filter_size = 3;
        static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8,
                                                   0, 8, 0, 8, 0, 8, 0, 8};
        constexpr int oc_step = 4;
        constexpr int ic_step = 1;
        constexpr int loop_ic_step = 1;
        constexpr int pack_iw_len = 4;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;

        int32x4_t c[c_dim][4];
        init_ocx_ow4<c_dim, bias_mode>(c, bias_ptr, oc_step);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            // first 2 line
            {
                const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
                int8x16_t src[4];
                int8x16_t dot4_weight[c_dim][1];
                int16x8_t temp_c[4];
                load_helper<1, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr,
                                                       ld_weight_oc);
                load_helper<4, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0);
                cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
                                                     temp_c);

                int8x8_t src_dot2[4];
                int8x8_t dot2_weight[c_dim][1];
                load_helper<1, 1 * 16, 8, c_dim, Vld1_s8>(
                        dot2_weight, weight_ptr, ld_weight_oc);
                load_helper<4, 1 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr,
                                                       0);
                cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
                                                     temp_c);
            }
            // last line
            {
                const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride +
                                             2 * iw * ic_step * pack_iw_len;
                int16x8_t temp_c[4];
                int8x8_t src_dot2[4];
                int8x8_t dot2_weight[c_dim][1];
                uint8x16_t tbl = vld1q_u8(src_idx_buffer);
                load_helper<1, 24, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr,
                                                      ld_weight_oc);
                load_helper_x<4, 0, 16, 0, Vldq_tbl_low_s8>(
                        src_dot2, nchw_src_ptr, 0, tbl);
                cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
                                                     temp_c);

                int16x8_t dot1_weight[c_dim][1];
                int16x8_t src_dot1[4];
                load_helper<1, 32, 8, c_dim, Vldq_dup_4s8_8s16>(
                        dot1_weight, weight_ptr, ld_weight_oc);
                load_helper<4, 1 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1,
                                                               nchw_src_ptr, 0);
                cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight);
                weight_ptr += filter_size * filter_size * pack_iw_len;
            }
        }
        store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc);
    }
};

}  // namespace

enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 };

template <PACK_MODE mode>
inline void pack_src_one_line(const int8_t* inptr, int8_t* outptr, int left_pad,
                              int right_pad, const int iw) {
    const int8_t* src_row_0 = inptr;
    const int8_t* src_row_1 = inptr + iw;
    constexpr int combine_row = 2;
    constexpr int iw_step = 16;
    constexpr int src_expand = 4;
    constexpr int out_gap = iw_step * src_expand;
    const int iw_end = iw / iw_step * iw_step;

    memset(outptr, 0, combine_row * left_pad * src_expand * sizeof(int8_t));
    outptr += combine_row * left_pad * src_expand;

    for (int iw_idx = 0; iw_idx < iw_end; iw_idx += iw_step) {
        int8x16_t row0 = vld1q_s8(src_row_0 + iw_idx);
        int8x16_t row1 = vdupq_n_s8(0);
        if (mode == PACK_MODE::NO_PAD) {
            row1 = vld1q_s8(src_row_1 + iw_idx);
        } else if (mode == PACK_MODE::FIRST_PAD) {
            row1 = row0;
            row0 = vdupq_n_s8(0);
        }
        int8x16x2_t pack_rows = vzipq_s8(row0, row1);
#define STORE_8S8(step)                                               \
    vst1_s8(outptr + step * 8,                                        \
            vreinterpret_s8_s16(vdup_laneq_s16(                       \
                    vreinterpretq_s16_s8(pack_rows.val[0]), step)));
        UNROLL_CALL_RAW(8, STORE_8S8);
#undef STORE_8S8
#define STORE_8S8(step)                                               \
    vst1_s8(outptr + out_gap + step * 8,                              \
            vreinterpret_s8_s16(vdup_laneq_s16(                       \
                    vreinterpretq_s16_s8(pack_rows.val[1]), step)));
        UNROLL_CALL_RAW(8, STORE_8S8);
#undef STORE_8S8
        outptr += out_gap * combine_row;
    }
    for (int iw_idx = iw_end; iw_idx < iw; iw_idx++) {
        int8x8_t row0 = vld1_dup_s8(src_row_0 + iw_idx);
        int8x8_t row1 = vdup_n_s8(0);
        if (mode == PACK_MODE::NO_PAD) {
            row1 = vld1_dup_s8(src_row_1 + iw_idx);
        } else if (mode == PACK_MODE::FIRST_PAD) {
            row1 = row0;
            row0 = vdup_n_s8(0);
        }
        int8x8x2_t pack_rows = vzip_s8(row0, row1);
        vst1_s8(outptr, pack_rows.val[0]);
        outptr += src_expand * combine_row;
    }
    memset(outptr, 0, combine_row * right_pad * src_expand * sizeof(int8_t));
    outptr += combine_row * right_pad * src_expand;
}
/**
 * pack (ic, h, w) to (ic, h / 2, 2 * w)
 * interleaves two adjacent rows of src, repeats each byte pair 4 times, and
 * stores them into one output row
 */
void conv_bias::pack_nchw_src_for_nchw44_conv(
        const int8_t* inptr, int8_t* outptr, const int ic, const int top_pad,
        const int bottom_pad, const int left_pad, const int right_pad,
        const int ih, const int iw) {
    constexpr int src_expand = 4;
    constexpr int oh_step = 2;
    const int oh = ih + top_pad + bottom_pad;
    const int oh_end = div_floor(ih + top_pad, oh_step) * oh_step;
    const int ow = (iw + left_pad + right_pad) * src_expand;

    for (int ic_idx = 0; ic_idx < ic; ++ic_idx) {
        int oh_idx = 0;
        for (; oh_idx < top_pad; oh_idx += oh_step) {
            if (top_pad - oh_idx >= oh_step) {
                memset(outptr, 0, oh_step * ow * sizeof(int8_t));
            } else {
                pack_src_one_line<PACK_MODE::FIRST_PAD>(inptr, outptr, left_pad,
                                                        right_pad, iw);
                inptr += iw;
            }
            outptr += oh_step * ow;
        }

        for (; oh_idx < oh_end; oh_idx += oh_step) {
            pack_src_one_line<PACK_MODE::NO_PAD>(inptr, outptr, left_pad,
                                                 right_pad, iw);
            inptr += oh_step * iw;
            outptr += oh_step * ow;
        }

        for (; oh_idx < oh; oh_idx += oh_step) {
            const int last_pad = oh_idx - ih - top_pad;
            if (last_pad >= 0) {
                memset(outptr, 0, oh_step * ow * sizeof(int8_t));
            } else {
                pack_src_one_line<PACK_MODE::LAST_PAD>(inptr, outptr, left_pad,
                                                       right_pad, iw);
                inptr += iw;
            }
            outptr += oh_step * ow;
        }
    }
}
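For readers of the packing code above, here is a scalar sketch (editorial, not part of the patch) of what one unpadded interior row pair looks like after packing; pack_two_rows_scalar is a hypothetical helper introduced only for illustration.

// Illustrative sketch only: rows 2*r and 2*r+1 are interleaved per column and
// each (row0, row1) byte pair is repeated 4 times, matching src_expand = 4.
#include <cstdint>
#include <vector>

std::vector<int8_t> pack_two_rows_scalar(const int8_t* row0, const int8_t* row1,
                                         int iw) {
    std::vector<int8_t> out;
    out.reserve(static_cast<size_t>(iw) * 8);
    for (int w = 0; w < iw; ++w) {
        for (int repeat = 0; repeat < 4; ++repeat) {
            out.push_back(row0[w]);  // value from the even source row
            out.push_back(row1[w]);  // value from the odd source row
        }
    }
    return out;  // 8 bytes per input column, i.e. 2 rows * 4x expansion
}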
/**
 * pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh * fw, 4(oc)}
 * interleaves two adjacent filter rows into one packed row
 */
void conv_bias::pack_nchw44_weight_for_nchw_conv(const int8_t* inptr,
                                                 int8_t* outptr, const int ic,
                                                 const int fh, const int fw,
                                                 const int oc) {
    constexpr int oc_step = 4;
    constexpr int ic_step = 2;
    constexpr int fh_step = 2;
    constexpr int fw_step = 2;
    const int ic_end = ic / ic_step * ic_step;
    const int ic_remain = ic - ic_end;
    const int fh_end = fh / fh_step * fh_step;
    const int fh_remain = fh - fh_end;
    const int fw_end = fw / fw_step * fw_step;
    const int fw_remain = fw - fw_end;
    const int filter_stride = ic * oc_step;
    static const uint8_t ic2_idx_h_buffer[16] = {0, 8,  1, 9,  2, 10, 3, 11,
                                                 4, 12, 5, 13, 6, 14, 7, 15};
    uint8x16_t ic2_idx_h = vld1q_u8(ic2_idx_h_buffer);

    for (int oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
        for (int ic_idx = 0; ic_idx < ic_end; ic_idx += ic_step) {
            const int ic_offset = ic_idx * oc_step;
            int8_t* output_ic0 = outptr + ic_idx * fh * fw * oc_step;
            int8_t* output_ic1 = output_ic0 + fh * fw * oc_step;
            for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) {
                const int fh_offset = fh_idx * fw * filter_stride;
                for (int fw_idx = 0; fw_idx < fw; ++fw_idx) {
                    const int8_t* filter_ptr = inptr + fh_offset +
                                               fw_idx * filter_stride +
                                               ic_offset;
                    int8x8_t row_0 = vld1_s8(filter_ptr);
                    int8x8_t row_1 = vld1_s8(filter_ptr + fw * filter_stride);
                    int8x16_t combine_row = vcombine_s8(row_0, row_1);
                    combine_row = vqtbl1q_s8(combine_row, ic2_idx_h);
                    vst1_s8(output_ic0, vget_low_s8(combine_row));
                    vst1_s8(output_ic1, vget_high_s8(combine_row));
                    output_ic0 += 8;
                    output_ic1 += 8;
                }
            }
            if (fh_remain > 0) {
                const int fh_offset = fh_end * fw * filter_stride;
                for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) {
                    const int8_t* filter_ptr = inptr + fh_offset +
                                               fw_idx * filter_stride +
                                               ic_offset;
                    int8x8_t row_0 = vld1_s8(filter_ptr);
                    int8x8_t row_1 = vld1_s8(filter_ptr + filter_stride);
                    int8x16_t combine_row = vcombine_s8(row_0, row_1);
                    combine_row = vqtbl1q_s8(combine_row, ic2_idx_h);
                    vst1_s8(output_ic0, vget_low_s8(combine_row));
                    vst1_s8(output_ic1, vget_high_s8(combine_row));
                    output_ic0 += 8;
                    output_ic1 += 8;
                }
                if (fw_remain > 0) {
                    const int8_t* filter_ptr = inptr + fh_offset +
                                               fw_end * filter_stride +
                                               ic_offset;
                    int8x8_t row_0 = vld1_s8(filter_ptr);
                    vst1_lane_s32((int32_t*)output_ic0,
                                  vreinterpret_s32_s8(row_0), 0);
                    vst1_lane_s32((int32_t*)output_ic1,
                                  vreinterpret_s32_s8(row_0), 1);
                    output_ic0 += 4;
                    output_ic1 += 4;
                }
            }
        }
        if (ic_remain > 0) {
            const int ic_offset = ic_end * oc_step;
            int8_t* output_ic0 = outptr + ic_end * fh * fw * oc_step;
            for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) {
                const int fh_offset = fh_idx * fw * filter_stride;
                for (int fw_idx = 0; fw_idx < fw; ++fw_idx) {
                    const int8_t* filter_ptr = inptr + fh_offset +
                                               fw_idx * filter_stride +
                                               ic_offset;
                    int8x8_t row_0 = vreinterpret_s8_s32(
                            vld1_dup_s32((const int32_t*)(filter_ptr)));
                    int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32(
                            (const int32_t*)(filter_ptr + fw * filter_stride)));
                    int8x16_t combine_row = vcombine_s8(row_0, row_1);
                    combine_row = vqtbl1q_s8(combine_row, ic2_idx_h);
                    vst1_s8(output_ic0, vget_low_s8(combine_row));
                    output_ic0 += 8;
                }
            }
            if (fh_remain > 0) {
                const int fh_offset = fh_end * fw * filter_stride;
                for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) {
                    const int8_t* filter_ptr = inptr + fh_offset +
                                               fw_idx * filter_stride +
                                               ic_offset;
                    int8x8_t row_0 = vreinterpret_s8_s32(
                            vld1_dup_s32((const int32_t*)(filter_ptr)));
                    int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32(
                            (const int32_t*)(filter_ptr + filter_stride)));
                    int8x16_t combine_row = vcombine_s8(row_0, row_1);
                    combine_row = vqtbl1q_s8(combine_row, ic2_idx_h);
                    vst1_s8(output_ic0, vget_low_s8(combine_row));
                    output_ic0 += 8;
                }
                if (fw_remain > 0) {
                    const int8_t* filter_ptr = inptr + fh_offset +
                                               fw_end * filter_stride +
                                               ic_offset;
                    *(int32_t*)(output_ic0) = *(const int32_t*)(filter_ptr);
                    output_ic0 += 4;
                }
            }
        }
        inptr += oc_step * fh * fw * ic;
        outptr += oc_step * fh * fw * ic;
    }
}
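Similarly, the following scalar sketch (editorial, not part of the patch) reproduces the per-channel byte order that the NEON weight repack above emits for one group of 4 output channels; pack_weight_scalar and the W accessor are illustrative names only.

// Illustrative sketch only: W(kh, kw, c, p) reads the NCHW44 weight byte
// filter[((kh * fw + kw) * ic + c) * 4 + p]; full row pairs (and, for the odd
// tail row, column pairs) are interleaved per output channel, and an odd last
// tap is stored as a plain 4-byte group.
#include <cstdint>
#include <vector>

std::vector<int8_t> pack_weight_scalar(const int8_t* filter, int ic, int fh,
                                       int fw) {
    auto W = [&](int kh, int kw, int c, int p) {
        return filter[((kh * fw + kw) * ic + c) * 4 + p];
    };
    std::vector<int8_t> out;
    for (int c = 0; c < ic; ++c) {
        for (int kh = 0; kh + 1 < fh; kh += 2)      // full row pairs
            for (int kw = 0; kw < fw; ++kw)
                for (int p = 0; p < 4; ++p) {
                    out.push_back(W(kh, kw, c, p));
                    out.push_back(W(kh + 1, kw, c, p));
                }
        if (fh % 2) {                               // odd tail row
            const int kh = fh - 1;
            for (int kw = 0; kw + 1 < fw; kw += 2)  // column pairs in that row
                for (int p = 0; p < 4; ++p) {
                    out.push_back(W(kh, kw, c, p));
                    out.push_back(W(kh, kw + 1, c, p));
                }
            if (fw % 2)                             // last single tap
                for (int p = 0; p < 4; ++p)
                    out.push_back(W(kh, fw - 1, c, p));
        }
    }
    return out;  // fh * fw * 4 bytes per input channel
}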
template <BiasMode bias_mode, typename Op, size_t filter_size> | |||||
static void conv_direct_stride2_int8_nchw_nchw44( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(temp); | |||||
constexpr size_t fh = filter_size; | |||||
constexpr size_t fw = filter_size; | |||||
constexpr size_t ic_step = 1; | |||||
constexpr size_t big_oc_step = 8; | |||||
constexpr size_t oc_step = 4; | |||||
constexpr size_t ih_step = 2; | |||||
constexpr size_t oh_step = 1; | |||||
constexpr size_t ow_step = 4; | |||||
constexpr size_t stride_h = 2; | |||||
constexpr size_t stride_w = 2; | |||||
constexpr int pack_iw_len = 4; | |||||
const size_t img_stride = oh * ow; | |||||
const size_t ow_end = ow / ow_step * ow_step; | |||||
const size_t ow_remain = ow - ow_end; | |||||
const size_t oc_end = oc / big_oc_step * big_oc_step; | |||||
const size_t oc_remain = oc - oc_end; | |||||
const int ld_dst_oc = oc_step * img_stride; | |||||
    using remain_fun =
            std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr,
                               const int32_t* bias_ptr, int8_t* dst_ptr, int ic,
                               int ih, int iw, int ld_dst_oc, const Op& op)>;
    remain_fun kern_big_oc_remain = nullptr;
    remain_fun kern_small_oc_remain = nullptr;
    switch (ow_remain) {
#define cb(step)                                                          \
    case step:                                                            \
        kern_big_oc_remain =                                              \
                KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size,   \
                                      big_oc_step>::impl;                 \
        kern_small_oc_remain =                                            \
                KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size,   \
                                      oc_step>::impl;                     \
        break;
        UNROLL_CALL_RAW(4, cb);
        default:
            megdnn_assert(0, "no remain %zu for kern", ow_remain);
    }
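    // Main loop over the output channels that fill whole blocks of 8.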
    for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
                        ic_step * pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
                KerNeonXXs2NchwNchw44<bias_mode, Op, 0, filter_size,
                                      big_oc_step>::impl(src + src_offset,
                                                         filter + weight_offset,
                                                         bias + oc_idx,
                                                         dst + dst_offset, ic,
                                                         ih, iw, ld_dst_oc, op);
            }
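            // Leftover output columns of this row go through the pre-selected
            // remainder kernel.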
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
                        ic_step * pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
                kern_big_oc_remain(src + src_offset, filter + weight_offset,
                                   bias + oc_idx, dst + dst_offset, ic, ih, iw,
                                   ld_dst_oc, op);
            }
        }
    }
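    // Output-channel tail: the remaining NCHW44 group of 4 channels reuses the
    // same tiling with the oc_step kernels.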
    if (oc_remain > 0) {
        size_t oc_idx = oc_end;
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
                        ic_step * pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
                KerNeonXXs2NchwNchw44<bias_mode, Op, 0, filter_size,
                                      oc_step>::impl(src + src_offset,
                                                     filter + weight_offset,
                                                     bias + oc_idx,
                                                     dst + dst_offset, ic, ih,
                                                     iw, ld_dst_oc, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
                        ic_step * pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
                kern_small_oc_remain(src + src_offset, filter + weight_offset,
                                     bias + oc_idx, dst + dst_offset, ic, ih,
                                     iw, ld_dst_oc, op);
            }
        }
    }
}
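// Generate the exported 3x3/5x5/7x7 entry points, which simply forward to the
// templated implementation above.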
#define CONSTRUCT_FUNC(filter_size)                                            \
    template <BiasMode bias_mode, typename Op>                                 \
    void conv_bias::                                                           \
            conv_direct_stride2_##filter_size##x##filter_size##_int8_nchw_nchw44( \
                    const int8_t* src, const int8_t* filter,                   \
                    const int32_t* bias, int32_t* temp, int8_t* dst,           \
                    const size_t oc, const size_t ic, const size_t ih,         \
                    const size_t iw, const size_t oh, const size_t ow,         \
                    const Op& op) {                                            \
        conv_direct_stride2_int8_nchw_nchw44<bias_mode, Op, filter_size>(      \
                src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);     \
    }
CONSTRUCT_FUNC(3);
CONSTRUCT_FUNC(5);
CONSTRUCT_FUNC(7);
#undef CONSTRUCT_FUNC
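// 2x2 stride-2 has no direct kernel here; the stub below only marks its
// arguments as used and asserts, so reaching it at runtime is a hard error.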
template <BiasMode bias_mode, typename Op>
void conv_bias::conv_direct_stride2_2x2_int8_nchw_nchw44(
        const int8_t* src, const int8_t* filter, const int32_t* bias,
        int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
        const size_t ih, const size_t iw, const size_t oh, const size_t ow,
        const Op& op) {
    MEGDNN_MARK_USED_VAR(src);
    MEGDNN_MARK_USED_VAR(filter);
    MEGDNN_MARK_USED_VAR(bias);
    MEGDNN_MARK_USED_VAR(temp);
    MEGDNN_MARK_USED_VAR(dst);
    MEGDNN_MARK_USED_VAR(oc);
    MEGDNN_MARK_USED_VAR(ic);
    MEGDNN_MARK_USED_VAR(ih);
    MEGDNN_MARK_USED_VAR(iw);
    MEGDNN_MARK_USED_VAR(oh);
    MEGDNN_MARK_USED_VAR(ow);
    MEGDNN_MARK_USED_VAR(op);
    megdnn_assert(0, "nchw_nchw44 2x2s2 conv is not implemented");
}
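// Explicit instantiations for every post-op (TypeCvt/Relu/HSwish), bias mode
// (no bias / broadcast channel bias) and filter size (2, 3, 5, 7) combination.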
#define INSTANTIATION(stride, i, bias, Op)                                  \
    template void conv_bias::                                               \
            conv_direct_##stride##_##i##x##i##_int8_nchw_nchw44<bias, Op>(  \
                    const int8_t*, const int8_t*, const int32_t*, int32_t*, \
                    int8_t*, const size_t, const size_t, const size_t,      \
                    const size_t, const size_t, const size_t, const Op&);
#define FOR_OP(stride, i, bias)                                                \
    INSTANTIATION(stride, i, bias, TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    INSTANTIATION(stride, i, bias, ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANTIATION(stride, i, bias, HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
#define FOR_BIAS(stride, i)              \
    FOR_OP(stride, i, BiasMode::NO_BIAS) \
    FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
#define FOR_FILTER(stride) \
    FOR_BIAS(stride, 2)    \
    FOR_BIAS(stride, 3)    \
    FOR_BIAS(stride, 5)    \
    FOR_BIAS(stride, 7)
FOR_FILTER(stride2)
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
@@ -1,44 +0,0 @@
/**
 * \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace conv_bias {
#define KERN(stride, i, layout)                                            \
    template <BiasMode bias_mode, typename Op>                             \
    void conv_direct_##stride##_##i##x##i##_int8_nchw_##layout(            \
            const int8_t* src, const int8_t* filter, const int32_t* bias,  \
            int32_t* temp, int8_t* dst, const size_t OC, const size_t IC,  \
            const size_t IH, const size_t IW, const size_t OH,             \
            const size_t OW, const Op& op);
KERN(stride2, 2, nchw44)
KERN(stride2, 3, nchw44)
KERN(stride2, 5, nchw44)
KERN(stride2, 7, nchw44)
#undef KERN
void pack_nchw44_weight_for_nchw_conv(const int8_t* inptr, int8_t* outptr,
                                      const int ic, const int fh, const int fw,
                                      const int oc);
void pack_nchw_src_for_nchw44_conv(const int8_t* inptr, int8_t* outptr,
                                   const int ic, const int top_pad,
                                   const int bottom_pad, const int left_pad,
                                   const int right_pad, const int ih,
                                   const int iw);
}  // namespace conv_bias
}  // namespace arm_common
}  // namespace megdnn
@@ -47,7 +47,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
    AlgoS8DirectStride2 s8_direct_stride2_large_group{true};
    AlgoS8DirectStride2 s8_direct_stride2_small_group{false};
    AlgoS8DirectStride2NCHW44 s8_direct_stride2_nchw44;
    AlgoS8DirectStride2NCHWNCHW44 s8_direct_stride2_nchw_nchw44;
    AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44;
    AlgoS8DirectStride1 s8_direct_stride1_large_group{true};
    AlgoS8DirectStride1 s8_direct_stride1_small_group{false};
    AlgoS8DirectStride1NCHW44 s8_direct_stride1_nchw44;
@@ -115,7 +115,7 @@ public:
        direct_algos.emplace_back(&s8_direct_stride2_large_group);
        direct_algos.emplace_back(&s8_direct_stride2_small_group);
        direct_algos.emplace_back(&s8_direct_stride2_nchw44);
        direct_algos.emplace_back(&s8_direct_stride2_nchw_nchw44);
        direct_algos.emplace_back(&s8_direct_nchw_nchw44);
        direct_algos.emplace_back(&s8_direct_stride1_large_group);
        direct_algos.emplace_back(&s8_direct_stride1_small_group);
        direct_algos.emplace_back(&s8_direct_stride1_nchw44);
@@ -40,7 +40,7 @@ private:
    class AlgoS8DirectStride1NCHW44;
    class AlgoS8DirectStride2;
    class AlgoS8DirectStride2NCHW44;
    class AlgoS8DirectStride2NCHWNCHW44;
    class AlgoS8DirectNCHWNCHW44;
    class AlgoQU8DirectStride1;
    class AlgoQU8DirectStride2;
    class AlgoFP32WinogradF23_4x4;
@@ -244,18 +244,26 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) {
#if MEGDNN_AARCH64
    benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
                       "IM2COLMATMUL:AARCH64_F32K8X12X1:192", true);
    benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
                       "IM2COLMATMUL:AARCH64_F32K8X12X1:192", false);
#else
    benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384",
                       "IM2COLMATMUL:ARMV7_F32:192", true);
    benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384",
                       "IM2COLMATMUL:ARMV7_F32:192", false);
#endif
}
TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) {
#if MEGDNN_AARCH64
    benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
                       "IM2COLMATMUL:AARCH64_F32K8X12X1:192", true);
    benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
                       "IM2COLMATMUL:AARCH64_F32K8X12X1:192", false);
#else
    benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
                       "IM2COLMATMUL:ARMV7_F32:192", true);
    benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
                       "IM2COLMATMUL:ARMV7_F32:192", false);
#endif
}
@@ -541,7 +541,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({3, 5, 7}, 2, false, false, false, true),
            get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
                                      true),
            handle(), "S8_CONV_NCHW_NCHW44");
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
                                      true),
            handle(), "S8_CONV_NCHW_NCHW44");
}