GitOrigin-RevId: 9d718eb7a4
tags/v0.5.0
@@ -37,7 +37,7 @@ static inline size_t get_perthread_cache_bytes(const int ic, const int ih2, | |||
static void get_rectified_size( | |||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, | |||
int& iw2, int& oh2, int& ow2) { | |||
constexpr int cacheline = 64 / sizeof(float); | |||
constexpr int nr_elements_in_cacheline = 64 / sizeof(float); | |||
int ic = param.filter_meta.icpg; | |||
int iw = param.isz[1]; | |||
int oh = param.osz[0]; | |||
@@ -52,7 +52,8 @@ static void get_rectified_size( | |||
int block_oh = l2_block_helper(param.nr_threads, oh, | |||
ic * iw * sizeof(float) * stride_h); | |||
ih2 = block_oh * stride_h + filter_h - stride_h; | |||
iw2 = round_up(iw + 2 * static_cast<int>(fm.padding[1]), cacheline); | |||
iw2 = round_up(iw + 2 * static_cast<int>(fm.padding[1]), | |||
nr_elements_in_cacheline); | |||
} | |||
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||
@@ -90,9 +90,9 @@ public: | |||
const NCBKernSizeParam& param) const override; | |||
}; | |||
class ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44 final : public AlgoBase { | |||
class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase { | |||
public: | |||
AlgoS8DirectStride2NCHWNCHW44() {} | |||
AlgoS8DirectNCHWNCHW44() {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "S8_CONV_NCHW_NCHW44"; } | |||
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||
@@ -0,0 +1,373 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "src/arm_common/conv_bias/int8/algos.h" | |||
#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" | |||
#include "src/arm_common/conv_bias/int8/strategy.h" | |||
#include "src/arm_common/elemwise_op.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "midout.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using conv_fun = std::function<void( | |||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw_nchw44) | |||
static void get_rectified_size( | |||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, | |||
int& iw2, int& oh2, int& ow2) { | |||
auto&& fm = param.filter_meta; | |||
int ih = param.isz[0]; | |||
int iw = param.isz[1]; | |||
int oh = param.osz[0]; | |||
int ow = param.osz[1]; | |||
int ph = fm.padding[0]; | |||
int pw = fm.padding[1]; | |||
int stride_h = fm.stride[0]; | |||
oh2 = oh; | |||
ow2 = ow; | |||
ih2 = stride_h == 2 ? round_up(ih + 2 * ph, 2) : ih + 2 * ph; | |||
iw2 = iw + 2 * pw; | |||
} | |||
static inline size_t get_temp_bytes(const int iw, const int pw) { | |||
//! border_size is used to avoid read illegal memory | |||
constexpr int cacheline_size = 64; | |||
constexpr int border_size = 1 * cacheline_size; | |||
return round_up(iw + pw * 2, cacheline_size) + border_size; | |||
} | |||
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||
auto&& fm = param.filter_meta; | |||
int group = fm.group; | |||
int batch = param.n; | |||
int ic = fm.icpg; | |||
int oc = fm.ocpg; | |||
int fh = fm.spatial[0]; | |||
int fw = fm.spatial[1]; | |||
int stride_h = fm.stride[0]; | |||
int iw = param.isz[1]; | |||
int pw = fm.padding[1]; | |||
int ih2, iw2, oh2, ow2; | |||
const size_t src_expand = stride_h == 2 ? 4 : 16; | |||
get_rectified_size(param, ih2, iw2, oh2, ow2); | |||
megdnn_assert(group == 1, "only support group == 1 now"); | |||
size_t src_size = | |||
batch * group * ic * ih2 * iw2 * sizeof(int8_t) * src_expand; | |||
size_t weight_size = group * oc * ic * fh * fw * sizeof(int8_t); | |||
size_t tmp_size = 0; | |||
if (stride_h == 1) { | |||
weight_size = group * oc * ic * fh * round_up(fw, 4) * sizeof(int8_t); | |||
tmp_size = get_temp_bytes(iw, pw); | |||
} | |||
return {nullptr, {src_size, weight_size, tmp_size * param.nr_threads}}; | |||
}; | |||
static void copy_padding_kern(WorkspaceBundle bundle, | |||
const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids) { | |||
int ih = kern_param.isz[0]; | |||
int iw = kern_param.isz[1]; | |||
int ic = kern_param.filter_meta.icpg; | |||
int ph = kern_param.filter_meta.padding[0]; | |||
int pw = kern_param.filter_meta.padding[1]; | |||
int group = kern_param.filter_meta.group; | |||
int stride_h = kern_param.filter_meta.stride[0]; | |||
int ih2, iw2, oh2, ow2; | |||
get_rectified_size(kern_param, ih2, iw2, oh2, ow2); | |||
int padding_group_size = ih2 * iw2 * ic; | |||
bundle.set(kern_param.workspace_ptr); | |||
//! Used for get the workspace offset | |||
const int src_expand = stride_h == 2 ? 4 : 16; | |||
//! TODO: block dim is better to get from arg | |||
int workspace_ic_block = 1; | |||
int workspace_batch_id = workspace_ids[0]; | |||
int workspace_group_id = workspace_ids[1]; | |||
int workspace_ic_id = workspace_ids[2]; | |||
int workspace_ic = workspace_ic_id * workspace_ic_block; | |||
int batch_id = ncb_index.ndrange_id[0]; | |||
int group_id = ncb_index.ndrange_id[1]; | |||
const int8_t* sptr = static_cast<const int8_t*>( | |||
kern_param.src<int8_t>(batch_id, group_id, workspace_ic_id, 1, 1)); | |||
//! copy to sptr_base to eliminate padding effect | |||
int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) + | |||
(workspace_batch_id * group * padding_group_size + | |||
workspace_group_id * padding_group_size + | |||
workspace_ic * ih2 * iw2) * | |||
src_expand; | |||
if (stride_h == 1) { | |||
const size_t tmp_size = get_temp_bytes(iw, pw); | |||
int8_t* tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) + | |||
ncb_index.thread_id * tmp_size; | |||
pack_nchw_src_for_nchw44_conv<1>(sptr, sptr_base, 1, ph, ph, pw, pw, ih, | |||
iw, iw2, pw, tmp_ptr); | |||
} else { | |||
pack_nchw_src_for_nchw44_conv<2>(sptr, sptr_base, 1, ph, ph, pw, pw, ih, | |||
iw, iw2, pw, nullptr); | |||
} | |||
} | |||
static void pack_weight(WorkspaceBundle bundle, | |||
const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
bundle.set(kern_param.workspace_ptr); | |||
const int group_id = ncb_index.ndrange_id[0]; | |||
int fh = kern_param.filter_meta.spatial[0]; | |||
int fw = kern_param.filter_meta.spatial[1]; | |||
int oc = kern_param.filter_meta.ocpg; | |||
int ic = kern_param.filter_meta.icpg; | |||
int stride_h = kern_param.filter_meta.stride[0]; | |||
int fw2 = stride_h == 2 ? fw : round_up(fw, 4); | |||
int oc_block = oc; | |||
int oc_idx = 0; | |||
const int8_t* fptr = | |||
kern_param.filter<dt_int8>(group_id) + oc_idx * fh * fw * ic; | |||
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) + | |||
group_id * oc * ic * fh * fw2 + oc_idx * ic * fh * fw2; | |||
if (stride_h == 1) { | |||
pack_nchw44_weight_for_nchw_conv<1>(fptr, packed_weight, ic, fh, fw, | |||
oc_block); | |||
} else { | |||
pack_nchw44_weight_for_nchw_conv<2>(fptr, packed_weight, ic, fh, fw, | |||
oc_block); | |||
} | |||
} | |||
template <size_t filter, BiasMode bias_mode, typename Op, int stride> | |||
static void do_conv_kern(WorkspaceBundle bundle, | |||
const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids, | |||
const CpuNDRange& ncb_range) { | |||
int oh = kern_param.osz[0]; | |||
int ow = kern_param.osz[1]; | |||
int fh = kern_param.filter_meta.spatial[0]; | |||
int fw = kern_param.filter_meta.spatial[1]; | |||
int fw2 = stride == 2 ? fw : round_up(fw, 4); | |||
int ic = kern_param.filter_meta.icpg; | |||
int oc = kern_param.filter_meta.ocpg; | |||
int group = kern_param.filter_meta.group; | |||
int ih2, iw2, oh2, ow2; | |||
get_rectified_size(kern_param, ih2, iw2, oh2, ow2); | |||
bool need_post_process = | |||
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; | |||
//! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) | |||
Op op = Op(1.0f, 4.0f); | |||
if (need_post_process) { | |||
float scale_bias = | |||
kern_param.bias_type.param<dtype::QuantizedS32>().scale; | |||
float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale; | |||
op = Op(scale_bias, scale_dst); | |||
} | |||
int padding_group_size = ih2 * iw2 * ic; | |||
bundle.set(kern_param.workspace_ptr); | |||
constexpr int pack_c = 4; | |||
constexpr int src_expand_size = stride == 2 ? 4 : 16; | |||
const int workspace_batch_id = workspace_ids[0]; | |||
const int workspace_group_id = workspace_ids[1]; | |||
const int batch_id = ncb_index.ndrange_id[0]; | |||
const int group_id = ncb_index.ndrange_id[1]; | |||
const int oc_id = ncb_index.ndrange_id[2]; | |||
const int oc_block_num = ncb_range[2]; | |||
int nr_pack_per_step = div_ceil(div_ceil(oc, pack_c), oc_block_num); | |||
int oc_block = nr_pack_per_step * pack_c; | |||
const int oc_idx = oc_id * oc_block; | |||
if (oc_id == (oc_block_num - 1)) { | |||
oc_block = oc - oc_id * nr_pack_per_step * pack_c; | |||
} | |||
megdnn_assert(oc_block % pack_c == 0, | |||
"oc must be devisible by 4, but oc = %d", oc_block); | |||
const int8_t* sptr = | |||
static_cast<int8_t*>(bundle.get(0)) + | |||
workspace_batch_id * group * padding_group_size * src_expand_size + | |||
workspace_group_id * padding_group_size * src_expand_size; | |||
int8_t* dst = reinterpret_cast<int8_t*>( | |||
reinterpret_cast<ptrdiff_t>( | |||
kern_param.dst<void>(batch_id, group_id)) + | |||
oc_idx * oh * ow); | |||
const int32_t* bptr = | |||
kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx; | |||
int8_t* packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) + | |||
group_id * oc * ic * fh * fw2 + | |||
oc_idx * ic * fh * fw2; | |||
conv_direct_int8_nchw_nchw44<bias_mode, Op, filter, stride>( | |||
sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih2, iw2, oh, | |||
ow, op); | |||
} | |||
bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
MEGDNN_MARK_USED_VAR(algo_selection_strategy); | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
auto OC = fm.ocpg; | |||
bool avaible = //! src and filter are qint8, dst is qint8 | |||
fm.icpg < 4 && // must be nchw input | |||
((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||
(param.dst_type.enumv() == DTypeEnum::QuantizedS8))) && | |||
(fm.format == param::Convolution::Format::NCHW44) && | |||
(OC % 4 == 0 && OC >= 4) && !fm.should_flip && fm.group == 1 && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] && | |||
(fm.stride[0] == 1 || fm.stride[0] == 2) && FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7) && fm.group == 1 && | |||
param.bias_mode != BiasMode::BIAS; | |||
return avaible; | |||
} | |||
bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::is_preferred( | |||
megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr, | |||
const NCBKernSizeParam& param) const { | |||
// TODO: benchmark and fix | |||
MEGDNN_MARK_USED_VAR(conv_bias_impl_ptr); | |||
MEGDNN_MARK_USED_VAR(param); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoS8DirectNCHWNCHW44::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
auto fm = param.filter_meta; | |||
size_t N = param.n; | |||
size_t OC = fm.ocpg; | |||
size_t group = fm.group; | |||
WorkspaceBundle wbundle = get_bundle(param); | |||
conv_fun do_conv_fun = nullptr; | |||
// NOTE: remain_w is not used to gen hash of midout for compatible with changing | |||
// shape runtime | |||
#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44, \ | |||
midout_iv(#stride #filter #bias_mode #op##_hash)) { \ | |||
do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \ | |||
} \ | |||
MIDOUT_END(); | |||
#define GET_OP_PARAM(stride, filter, bias_mode) \ | |||
switch (param.nonlineMode) { \ | |||
case param::ConvBias::NonlineMode::IDENTITY: \ | |||
DO_CONV_KERN_FUN(stride, filter, bias_mode, \ | |||
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
break; \ | |||
case param::ConvBias::NonlineMode::RELU: \ | |||
DO_CONV_KERN_FUN(stride, filter, bias_mode, \ | |||
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
break; \ | |||
case param::ConvBias::NonlineMode::H_SWISH: \ | |||
DO_CONV_KERN_FUN(stride, filter, bias_mode, \ | |||
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0); \ | |||
break; \ | |||
} | |||
#define GET_BIAS_MODE_PARAM(stride, filter) \ | |||
switch (param.bias_mode) { \ | |||
case BiasMode::NO_BIAS: \ | |||
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ | |||
break; \ | |||
case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0); \ | |||
break; \ | |||
} | |||
#define DISPATCH_CONV_KERN(stride) \ | |||
switch (param.filter_meta.spatial[0]) { \ | |||
case 2: \ | |||
GET_BIAS_MODE_PARAM(stride, 2) \ | |||
break; \ | |||
case 3: \ | |||
GET_BIAS_MODE_PARAM(stride, 3) \ | |||
break; \ | |||
case 5: \ | |||
GET_BIAS_MODE_PARAM(stride, 5) \ | |||
break; \ | |||
case 7: \ | |||
GET_BIAS_MODE_PARAM(stride, 7) \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0); \ | |||
break; \ | |||
} | |||
switch (param.filter_meta.stride[0]) { | |||
case 1: | |||
DISPATCH_CONV_KERN(1); | |||
break; | |||
case 2: | |||
DISPATCH_CONV_KERN(2); | |||
break; | |||
default: | |||
megdnn_throw(ssprintf("Unsupport stride size %u for the first conv", | |||
param.filter_meta.stride[0]) | |||
.c_str()); | |||
break; | |||
} | |||
#undef DO_CONV_KERN_FUN | |||
#undef GET_REMAIN_W_PARAM | |||
#undef GET_OP_PARAM | |||
#undef GET_BIAS_MODE_PARAM | |||
#undef DISPATCH_CONV_KERN | |||
megdnn_assert(do_conv_fun); | |||
SmallVector<ConvBiasImpl::NCBKern> ret_kerns; | |||
WorkspaceBundle bundle = wbundle; | |||
constexpr size_t pack_oc = 8; | |||
size_t oc_step = pack_oc; | |||
auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({copy_padding, {N, group, fm.icpg}}); | |||
auto do_pack_weight = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
pack_weight(bundle, kern_param, ncb_index); | |||
}; | |||
ret_kerns.push_back({do_pack_weight, {static_cast<size_t>(group)}}); | |||
CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)}; | |||
auto do_conv = [bundle, do_conv_fun, ncb_range]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, | |||
ncb_range); | |||
}; | |||
ret_kerns.push_back({do_conv, ncb_range}); | |||
return ret_kerns; | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -1,305 +0,0 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_algo.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "src/arm_common/conv_bias/int8/algos.h" | |||
#include "src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h" | |||
#include "src/arm_common/conv_bias/int8/strategy.h" | |||
#include "src/arm_common/elemwise_op.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "midout.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using conv_fun = std::function<void( | |||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw_nchw44_stride2) | |||
static void get_rectified_size( | |||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { | |||
auto&& fm = param.filter_meta; | |||
size_t IH = param.isz[0]; | |||
size_t IW = param.isz[1]; | |||
size_t OH = param.osz[0]; | |||
size_t OW = param.osz[1]; | |||
OH2 = OH; | |||
OW2 = OW; | |||
IH2 = round_up(IH + 2 * fm.padding[0], static_cast<size_t>(2)); | |||
IW2 = IW + 2 * fm.padding[1]; | |||
} | |||
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||
constexpr size_t src_expand = 4; | |||
auto&& fm = param.filter_meta; | |||
size_t group = fm.group; | |||
size_t batch = param.n; | |||
size_t IC = fm.icpg; | |||
size_t OC = fm.ocpg; | |||
size_t FH = fm.spatial[0]; | |||
size_t FW = fm.spatial[1]; | |||
size_t IH2, IW2, OH2, OW2; | |||
get_rectified_size(param, IH2, IW2, OH2, OW2); | |||
megdnn_assert(group == 1, "only support group == 1 now"); | |||
size_t src_size = | |||
batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; | |||
size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); | |||
return {nullptr, {src_size, weight_size}}; | |||
}; | |||
static void copy_padding_kern(WorkspaceBundle bundle, | |||
const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids) { | |||
size_t IH = kern_param.isz[0]; | |||
size_t IW = kern_param.isz[1]; | |||
size_t IC = kern_param.filter_meta.icpg; | |||
size_t PH = kern_param.filter_meta.padding[0]; | |||
size_t PW = kern_param.filter_meta.padding[1]; | |||
size_t GROUP = kern_param.filter_meta.group; | |||
size_t IH2, IW2, OH2, OW2; | |||
get_rectified_size(kern_param, IH2, IW2, OH2, OW2); | |||
size_t padding_group_size = IH2 * IW2 * IC; | |||
bundle.set(kern_param.workspace_ptr); | |||
//! Used for get the workspace offset | |||
constexpr int expend_element = 4; | |||
// TODO: block dim is better to get from arg | |||
size_t workspace_ic_block = 1; | |||
size_t workspace_batch_id = workspace_ids[0]; | |||
size_t workspace_group_id = workspace_ids[1]; | |||
size_t workspace_ic_id = workspace_ids[2]; | |||
size_t workspace_ic = workspace_ic_id * workspace_ic_block; | |||
size_t batch_id = ncb_index.ndrange_id[0]; | |||
size_t group_id = ncb_index.ndrange_id[1]; | |||
const int8_t* sptr = static_cast<const int8_t*>( | |||
kern_param.src<int8_t>(batch_id, group_id, workspace_ic_id, 1, 1)); | |||
//! copy to sptr_base to eliminate padding effect | |||
int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) + | |||
(workspace_batch_id * GROUP * padding_group_size + | |||
workspace_group_id * padding_group_size + | |||
workspace_ic * IH2 * IW2) * | |||
expend_element; | |||
conv_bias::pack_nchw_src_for_nchw44_conv(sptr, sptr_base, 1, PH, PH, PW, PW, | |||
IH, IW); | |||
} | |||
template <size_t filter, BiasMode bias_mode, typename Op> | |||
static void do_conv_kern(WorkspaceBundle bundle, | |||
const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids, | |||
const CpuNDRange& ncb_range) { | |||
size_t OH = kern_param.osz[0]; | |||
size_t OW = kern_param.osz[1]; | |||
size_t FH = kern_param.filter_meta.spatial[0]; | |||
size_t FW = kern_param.filter_meta.spatial[1]; | |||
size_t IC = kern_param.filter_meta.icpg; | |||
size_t OC = kern_param.filter_meta.ocpg; | |||
size_t GROUP = kern_param.filter_meta.group; | |||
size_t IH2, IW2, OH2, OW2; | |||
get_rectified_size(kern_param, IH2, IW2, OH2, OW2); | |||
bool need_post_process = | |||
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; | |||
//! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) | |||
Op op = Op(1.0f, 4.0f); | |||
if (need_post_process) { | |||
float scale_bias = | |||
kern_param.bias_type.param<dtype::QuantizedS32>().scale; | |||
float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale; | |||
op = Op(scale_bias, scale_dst); | |||
} | |||
size_t padding_group_size = IH2 * IW2 * IC; | |||
bundle.set(kern_param.workspace_ptr); | |||
constexpr size_t pack_c = 4; | |||
constexpr size_t src_expand_size = 4; | |||
const size_t workspace_batch_id = workspace_ids[0]; | |||
const size_t workspace_group_id = workspace_ids[1]; | |||
const size_t batch_id = ncb_index.ndrange_id[0]; | |||
const size_t group_id = ncb_index.ndrange_id[1]; | |||
const size_t oc_id = ncb_index.ndrange_id[2]; | |||
const size_t oc_block_num = ncb_range[2]; | |||
size_t nr_pack_per_step = div_ceil(div_ceil(OC, pack_c), oc_block_num); | |||
size_t oc_block = nr_pack_per_step * pack_c; | |||
const size_t oc_idx = oc_id * oc_block; | |||
if (oc_id == (oc_block_num - 1)) { | |||
oc_block = OC - oc_id * nr_pack_per_step * pack_c; | |||
} | |||
megdnn_assert(oc_block % pack_c == 0, | |||
"oc must be devisible by 4, but oc = %zu", oc_block); | |||
const int8_t* sptr = | |||
static_cast<int8_t*>(bundle.get(0)) + | |||
workspace_batch_id * GROUP * padding_group_size * src_expand_size + | |||
workspace_group_id * padding_group_size * src_expand_size; | |||
const int8_t* fptr = | |||
kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC; | |||
void* dst = reinterpret_cast<void*>( | |||
reinterpret_cast<ptrdiff_t>( | |||
kern_param.dst<void>(batch_id, group_id)) + | |||
oc_idx * OH * OW); | |||
const int32_t* bptr = | |||
kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx; | |||
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) + | |||
group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; | |||
conv_bias::pack_nchw44_weight_for_nchw_conv(fptr, packed_weight, IC, FH, FW, | |||
oc_block); | |||
#define KERN1_NCHW44_CONV(filter) \ | |||
conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw_nchw44< \ | |||
bias_mode, Op>(sptr, packed_weight, bptr, nullptr, \ | |||
static_cast<int8_t*>(dst), oc_block, IC, IH2, IW2, \ | |||
OH, OW, op) | |||
DISPATCH_FILTER(filter, KERN1_NCHW44_CONV); | |||
#undef KERN1_NCHW44_CONV | |||
} | |||
/* ===================== stride2 algo ===================== */ | |||
bool ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::usable( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
MEGDNN_MARK_USED_VAR(algo_selection_strategy); | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
auto OC = fm.ocpg; | |||
bool avaible = //! src and filter are qint8, dst is qint8 | |||
fm.icpg < 4 && // must be nchw input | |||
((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||
(param.dst_type.enumv() == DTypeEnum::QuantizedS8))) && | |||
(fm.format == param::Convolution::Format::NCHW44) && | |||
(OC % 4 == 0 && OC >= 4) && !fm.should_flip && fm.group == 1 && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||
FH == fm.spatial[1] && (FH == 3 || FH == 5 || FH == 7) && | |||
fm.group == 1 && param.bias_mode != BiasMode::BIAS; | |||
return avaible; | |||
} | |||
bool ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::is_preferred( | |||
megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr, | |||
const NCBKernSizeParam& param) const { | |||
// TODO: benchmark and fix | |||
MEGDNN_MARK_USED_VAR(conv_bias_impl_ptr); | |||
MEGDNN_MARK_USED_VAR(param); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
auto fm = param.filter_meta; | |||
size_t N = param.n; | |||
size_t OC = fm.ocpg; | |||
size_t group = fm.group; | |||
WorkspaceBundle wbundle = get_bundle(param); | |||
conv_fun do_conv_fun = nullptr; | |||
// NOTE: remain_w is not used to gen hash of midout for compatible with changing | |||
// shape runtime | |||
#define DO_CONV_KERN_FUN(filter, bias_mode, op) \ | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44_stride2, \ | |||
midout_iv(#filter #bias_mode #op##_hash)) { \ | |||
do_conv_fun = do_conv_kern<filter, bias_mode, op>; \ | |||
} \ | |||
MIDOUT_END(); | |||
#define GET_OP_PARAM(filter, bias_mode) \ | |||
switch (param.nonlineMode) { \ | |||
case param::ConvBias::NonlineMode::IDENTITY: \ | |||
DO_CONV_KERN_FUN(filter, bias_mode, \ | |||
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
break; \ | |||
case param::ConvBias::NonlineMode::RELU: \ | |||
DO_CONV_KERN_FUN(filter, bias_mode, \ | |||
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
break; \ | |||
case param::ConvBias::NonlineMode::H_SWISH: \ | |||
DO_CONV_KERN_FUN(filter, bias_mode, \ | |||
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0); \ | |||
break; \ | |||
} | |||
#define GET_BIAS_MODE_PARAM(filter) \ | |||
switch (param.bias_mode) { \ | |||
case BiasMode::NO_BIAS: \ | |||
GET_OP_PARAM(filter, BiasMode::NO_BIAS) \ | |||
break; \ | |||
case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||
GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0); \ | |||
break; \ | |||
} | |||
#define DISPATCH_CONV_KERN() \ | |||
switch (param.filter_meta.spatial[0]) { \ | |||
case 3: \ | |||
GET_BIAS_MODE_PARAM(3) \ | |||
break; \ | |||
case 5: \ | |||
GET_BIAS_MODE_PARAM(5) \ | |||
break; \ | |||
case 7: \ | |||
GET_BIAS_MODE_PARAM(7) \ | |||
break; \ | |||
default: \ | |||
megdnn_assert(0); \ | |||
break; \ | |||
} | |||
DISPATCH_CONV_KERN(); | |||
#undef DO_CONV_KERN_FUN | |||
#undef GET_REMAIN_W_PARAM | |||
#undef GET_OP_PARAM | |||
#undef GET_BIAS_MODE_PARAM | |||
#undef DISPATCH_CONV_KERN | |||
megdnn_assert(do_conv_fun); | |||
SmallVector<ConvBiasImpl::NCBKern> ret_kerns; | |||
WorkspaceBundle bundle = wbundle; | |||
constexpr size_t pack_oc = 8; | |||
size_t oc_step = pack_oc; | |||
auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({copy_padding, {N, group, fm.icpg}}); | |||
CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)}; | |||
auto do_conv = [bundle, do_conv_fun, ncb_range]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, | |||
ncb_range); | |||
}; | |||
ret_kerns.push_back({do_conv, ncb_range}); | |||
return ret_kerns; | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -1,789 +0,0 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern_nchw.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h" | |||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
#include "src/arm_common/elemwise_op.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
namespace { | |||
template <int src_idx, int weight_idx, int c_dim, typename Func, typename T, | |||
typename T2, typename T3, typename T4> | |||
struct ShiftCalHelper { | |||
static void impl(T& c, T2& src, T3& weight, T4& temp); | |||
static void impl(T& c, T2& src, T3& weight); | |||
}; | |||
template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||
typename T3, typename T4> | |||
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, T, T2, T3, T4> { | |||
static void impl(T& c, T2& src, T3& weight, T4& temp) { | |||
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], | |||
temp[0]); | |||
c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0], | |||
temp[1]); | |||
c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], | |||
temp[2]); | |||
c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1], | |||
temp[3]); | |||
c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], | |||
temp[0]); | |||
c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2], | |||
temp[1]); | |||
c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], | |||
temp[2]); | |||
c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3], | |||
temp[3]); | |||
} | |||
static void impl(T& c, T2& src, T3& weight) { | |||
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); | |||
c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0]); | |||
c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]); | |||
c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1]); | |||
c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]); | |||
c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2]); | |||
c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3]); | |||
c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3]); | |||
} | |||
}; | |||
template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||
typename T3, typename T4> | |||
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, T, T2, T3, T4> { | |||
static void impl(T& c, T2& src, T3& weight, T4& temp) { | |||
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], | |||
temp[0]); | |||
c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], | |||
temp[2]); | |||
c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], | |||
temp[0]); | |||
c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], | |||
temp[2]); | |||
} | |||
static void impl(T& c, T2& src, T3& weight) { | |||
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); | |||
c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]); | |||
c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]); | |||
c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3]); | |||
} | |||
}; | |||
template <int src_idx, int weight_idx, int c_dim, typename FUNC, typename T, | |||
typename T2, typename T3, typename T4> | |||
inline void cal_helper(T& c, T2& src, T3& weight, T4& temp) { | |||
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, T, T2, T3, T4>::impl( | |||
c, src, weight, temp); | |||
} | |||
template <int src_idx, int weight_idx, int c_dim, typename FUNC, typename T, | |||
typename T2, typename T3> | |||
inline void cal_helper(T& c, T2& src, T3& weight) { | |||
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, T, T2, T3, int>::impl( | |||
c, src, weight); | |||
}; | |||
template <int oc> | |||
struct OCHelper { | |||
public: | |||
static const int val = 0; | |||
}; | |||
template <> | |||
struct OCHelper<4> { | |||
public: | |||
static const int val = 1; | |||
}; | |||
template <> | |||
struct OCHelper<8> { | |||
public: | |||
static const int val = 2; | |||
}; | |||
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size, | |||
int oc_block> | |||
struct KerNeonXXs2NchwNchw44 { | |||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||
int iw, int ld_dst_oc, const Op& op); | |||
}; | |||
/** | |||
* filter shape = (oc/4, ic, 7, 7, 4), first 4 oc is f0 = filter[0, 0, :, :, :] | |||
* calculate sequence \ | |||
* f0[0:1, 0:1, 4] dot4, \ | |||
* f0[0:1, 2:3, 4] dot4, \ | |||
* f0[0:1, 4:5, 4] dot4, \ | |||
* f0[0:1, 6, 4] dot2, \ | |||
* ... | |||
* f0[6, 0:1, 4] dot2, \ | |||
* f0[6, 2:3, 4] dot2, \ | |||
* f0[6, 4:5, 4] dot2, \ | |||
* f0[6, 6, 4] dot1, \ | |||
* look like: | |||
* |---|---|---|-| | |||
* |x x|x x|x x|x| | |||
* |x x|x x|x x|x| | |||
* |---|---|---|-| | |||
* |x x|x x|x x|x| | |||
* |x x|x x|x x|x| | |||
* |---|---|---|-| | |||
* |x x|x x|x x|x| | |||
* |x x|x x|x x|x| | |||
* |---|---|---|-| | |||
* |x x|x x|x x|x| | |||
* |---|---|---|-| | |||
**/ | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block> { | |||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||
int iw, int ld_dst_oc, const Op& op) { | |||
static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, | |||
0, 8, 0, 8, 0, 8, 0, 8}; | |||
constexpr int filter_size = 7; | |||
constexpr int ic_step = 1; | |||
constexpr int oc_step = 4; | |||
constexpr int pack_iw_len = 4; | |||
constexpr int fh_step = 2; | |||
constexpr int fh_end = filter_size / fh_step * fh_step; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
const int ic_stride = ih * iw * pack_iw_len; | |||
const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic; | |||
int32x4_t c[c_dim][4]; | |||
init_ocx_ow4<c_dim, bias_mode>(c, bias_ptr, oc_step); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||
for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { | |||
const int8_t* nchw_src_ptr = | |||
src_ptr + ic_idx * ic_stride + | |||
fh_idx * iw * ic_step * pack_iw_len; | |||
int8x16_t src[6]; | |||
int8x16_t dot4_weight[c_dim][3]; | |||
int16x8_t temp_c[4]; | |||
load_helper<3, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, | |||
ld_dot4_weight_oc); | |||
load_helper<6, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); | |||
cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight, | |||
temp_c); | |||
cal_helper<1, 1, c_dim, Vdotq_s32_h>(c, src, dot4_weight, | |||
temp_c); | |||
cal_helper<2, 2, c_dim, Vdotq_s32_h>(c, src, dot4_weight, | |||
temp_c); | |||
int8x8_t src_dot2[4]; | |||
int8x8_t dot2_weight[c_dim][1]; | |||
load_helper<1, 3 * 16, 8, c_dim, Vld1_s8>( | |||
dot2_weight, weight_ptr, ld_dot4_weight_oc); | |||
load_helper<4, 3 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, | |||
0); | |||
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, | |||
temp_c); | |||
weight_ptr += filter_size * pack_iw_len * fh_step; | |||
} | |||
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + | |||
6 * iw * ic_step * pack_iw_len; | |||
int8x8_t dot2_weight[c_dim][3]; | |||
int16x8_t temp_c[4]; | |||
int8x8_t src_dot2[6]; | |||
uint8x16_t tbl = vld1q_u8(src_idx_buffer); | |||
load_helper<3, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, | |||
ld_dot4_weight_oc); | |||
load_helper_x<6, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr, | |||
0, tbl); | |||
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, | |||
temp_c); | |||
cal_helper<1, 1, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, | |||
temp_c); | |||
cal_helper<2, 2, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, | |||
temp_c); | |||
int16x8_t dot1_weight[c_dim][1]; | |||
int16x8_t src_dot1[4]; | |||
load_helper<1, 3 * 8, 8, c_dim, Vldq_dup_4s8_8s16>( | |||
dot1_weight, weight_ptr, ld_dot4_weight_oc); | |||
load_helper<4, 3 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, | |||
nchw_src_ptr, 0); | |||
cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight); | |||
weight_ptr += filter_size * pack_iw_len; | |||
} | |||
store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc); | |||
} | |||
}; | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block> { | |||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||
int iw, int ld_dst_oc, const Op& op) { | |||
constexpr int filter_size = 5; | |||
static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, | |||
0, 8, 0, 8, 0, 8, 0, 8}; | |||
constexpr int ih_step = 2; | |||
constexpr int ic_step = 1; | |||
constexpr int oc_step = 4; | |||
constexpr int pack_iw_len = 4; | |||
constexpr int fh_end = filter_size / ih_step * ih_step; | |||
const int ic_stride = ih * iw * pack_iw_len; | |||
const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
int32x4_t c[c_dim][4]; | |||
init_ocx_ow4<c_dim, bias_mode>(c, bias_ptr, oc_step); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||
for (int fh_idx = 0; fh_idx < fh_end; fh_idx += ih_step) { | |||
const int8_t* nchw_src_ptr = | |||
src_ptr + ic_idx * ic_stride + | |||
fh_idx * iw * ic_step * pack_iw_len; | |||
int8x16_t src[5]; | |||
int8x16_t dot4_weight[c_dim][2]; | |||
int16x8_t temp_c[4]; | |||
load_helper<2, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, | |||
ld_dot4_weight_oc); | |||
load_helper<5, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); | |||
cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight, | |||
temp_c); | |||
cal_helper<1, 1, c_dim, Vdotq_s32_h>(c, src, dot4_weight, | |||
temp_c); | |||
int8x8_t src_dot2[4]; | |||
int8x8_t dot2_weight[c_dim][1]; | |||
load_helper<1, 2 * 16, 8, c_dim, Vld1_s8>( | |||
dot2_weight, weight_ptr, ld_dot4_weight_oc); | |||
load_helper<4, 2 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, | |||
0); | |||
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, | |||
temp_c); | |||
weight_ptr += filter_size * pack_iw_len * ih_step; | |||
} | |||
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + | |||
fh_end * iw * ic_step * pack_iw_len; | |||
int8x8_t dot2_weight[c_dim][2]; | |||
int16x8_t temp_c[4]; | |||
int8x8_t src_dot2[5]; | |||
uint8x16_t tbl = vld1q_u8(src_idx_buffer); | |||
load_helper<2, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, | |||
ld_dot4_weight_oc); | |||
load_helper_x<5, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr, | |||
0, tbl); | |||
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, | |||
temp_c); | |||
cal_helper<1, 1, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, | |||
temp_c); | |||
int16x8_t dot1_weight[c_dim][1]; | |||
int16x8_t src_dot1[4]; | |||
load_helper<1, 2 * 8, 8, c_dim, Vldq_dup_4s8_8s16>( | |||
dot1_weight, weight_ptr, ld_dot4_weight_oc); | |||
load_helper<4, 2 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, | |||
nchw_src_ptr, 0); | |||
cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight); | |||
weight_ptr += filter_size * pack_iw_len; | |||
} | |||
store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc); | |||
} | |||
}; | |||
/** | |||
* filter shape = (oc/4, ic, 3, 3, 4), first 4 oc is f0 = filter[0, 0, :, :, :] | |||
* calculate sequence \ | |||
* f0[0:1, 0:1, 4] dot4, \ | |||
* f0[0:1, 2, 4] dot2, \ | |||
* f0[2, 0:1, 4] dot2, \ | |||
* f0[2, 2, 4] dot1 \ | |||
* look like: | |||
* |---|-| | |||
* |x x|x| | |||
* |x x|x| | |||
* |-----| | |||
* |x x|x| | |||
* |-----| | |||
**/ | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block> { | |||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||
int iw, int ld_dst_oc, const Op& op) { | |||
constexpr int filter_size = 3; | |||
static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, | |||
0, 8, 0, 8, 0, 8, 0, 8}; | |||
constexpr int oc_step = 4; | |||
constexpr int ic_step = 1; | |||
constexpr int loop_ic_step = 1; | |||
constexpr int pack_iw_len = 4; | |||
const int ic_stride = ih * iw * pack_iw_len; | |||
const int ld_weight_oc = oc_step * filter_size * filter_size * ic; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
int32x4_t c[c_dim][4]; | |||
init_ocx_ow4<c_dim, bias_mode>(c, bias_ptr, oc_step); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
// first 2 line | |||
{ | |||
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | |||
int8x16_t src[4]; | |||
int8x16_t dot4_weight[c_dim][1]; | |||
int16x8_t temp_c[4]; | |||
load_helper<1, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, | |||
ld_weight_oc); | |||
load_helper<4, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); | |||
cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight, | |||
temp_c); | |||
int8x8_t src_dot2[4]; | |||
int8x8_t dot2_weight[c_dim][1]; | |||
load_helper<1, 1 * 16, 8, c_dim, Vld1_s8>( | |||
dot2_weight, weight_ptr, ld_weight_oc); | |||
load_helper<4, 1 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, | |||
0); | |||
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, | |||
temp_c); | |||
} | |||
// last line | |||
{ | |||
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + | |||
2 * iw * ic_step * pack_iw_len; | |||
int16x8_t temp_c[4]; | |||
int8x8_t src_dot2[4]; | |||
int8x8_t dot2_weight[c_dim][1]; | |||
uint8x16_t tbl = vld1q_u8(src_idx_buffer); | |||
load_helper<1, 24, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, | |||
ld_weight_oc); | |||
load_helper_x<4, 0, 16, 0, Vldq_tbl_low_s8>( | |||
src_dot2, nchw_src_ptr, 0, tbl); | |||
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, | |||
temp_c); | |||
int16x8_t dot1_weight[c_dim][1]; | |||
int16x8_t src_dot1[4]; | |||
load_helper<1, 32, 8, c_dim, Vldq_dup_4s8_8s16>( | |||
dot1_weight, weight_ptr, ld_weight_oc); | |||
load_helper<4, 1 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, | |||
nchw_src_ptr, 0); | |||
cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight); | |||
weight_ptr += filter_size * filter_size * pack_iw_len; | |||
} | |||
} | |||
store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc); | |||
} | |||
}; | |||
} // namespace | |||
enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 }; | |||
template <PACK_MODE mode> | |||
inline void pack_src_one_line(const int8_t* inptr, int8_t* outptr, int left_pad, | |||
int right_pad, const int iw) { | |||
const int8_t* src_row_0 = inptr; | |||
const int8_t* src_row_1 = inptr + iw; | |||
constexpr int combine_row = 2; | |||
constexpr int iw_step = 16; | |||
constexpr int src_expand = 4; | |||
constexpr int out_gap = iw_step * src_expand; | |||
const int iw_end = iw / iw_step * iw_step; | |||
memset(outptr, 0, combine_row * left_pad * src_expand * sizeof(int8_t)); | |||
outptr += combine_row * left_pad * src_expand; | |||
for (int iw_idx = 0; iw_idx < iw_end; iw_idx += iw_step) { | |||
int8x16_t row0 = vld1q_s8(src_row_0 + iw_idx); | |||
int8x16_t row1 = vdupq_n_s8(0); | |||
if (mode == PACK_MODE::NO_PAD) { | |||
row1 = vld1q_s8(src_row_1 + iw_idx); | |||
} else if (mode == PACK_MODE::FIRST_PAD) { | |||
row1 = row0; | |||
row0 = vdupq_n_s8(0); | |||
} | |||
int8x16x2_t pack_rows = vzipq_s8(row0, row1); | |||
#define STORE_8S8(step) \ | |||
vst1_s8(outptr + step * 8, \ | |||
vreinterpret_s8_s16(vdup_laneq_s16( \ | |||
vreinterpretq_s16_s8(pack_rows.val[0]), step))); | |||
UNROLL_CALL_RAW(8, STORE_8S8); | |||
#undef STORE_8S8 | |||
#define STORE_8S8(step) \ | |||
vst1_s8(outptr + out_gap + step * 8, \ | |||
vreinterpret_s8_s16(vdup_laneq_s16( \ | |||
vreinterpretq_s16_s8(pack_rows.val[1]), step))); | |||
UNROLL_CALL_RAW(8, STORE_8S8); | |||
#undef STORE_8S8 | |||
outptr += out_gap * combine_row; | |||
} | |||
for (int iw_idx = iw_end; iw_idx < iw; iw_idx++) { | |||
int8x8_t row0 = vld1_dup_s8(src_row_0 + iw_idx); | |||
int8x8_t row1 = vdup_n_s8(0); | |||
if (mode == PACK_MODE::NO_PAD) { | |||
row1 = vld1_dup_s8(src_row_1 + iw_idx); | |||
} else if (mode == PACK_MODE::FIRST_PAD) { | |||
row1 = row0; | |||
row0 = vdup_n_s8(0); | |||
} | |||
int8x8x2_t pack_rows = vzip_s8(row0, row1); | |||
vst1_s8(outptr, pack_rows.val[0]); | |||
outptr += src_expand * combine_row; | |||
} | |||
memset(outptr, 0, combine_row * right_pad * src_expand * sizeof(int8_t)); | |||
outptr += combine_row * right_pad * src_expand; | |||
} | |||
/** | |||
* pack (ic, h, w) to (ic, h / 2, 2 * w) | |||
* pack interleave two adjacent row in src and repeat 4 times, store to one row | |||
* */ | |||
void conv_bias::pack_nchw_src_for_nchw44_conv( | |||
const int8_t* inptr, int8_t* outptr, const int ic, const int top_pad, | |||
const int bottom_pad, const int left_pad, const int right_pad, | |||
const int ih, const int iw) { | |||
constexpr int src_expand = 4; | |||
constexpr int oh_step = 2; | |||
const int oh = ih + top_pad + bottom_pad; | |||
const int oh_end = div_floor(ih + top_pad, oh_step) * oh_step; | |||
const int ow = (iw + left_pad + right_pad) * src_expand; | |||
for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { | |||
int oh_idx = 0; | |||
for (; oh_idx < top_pad; oh_idx += oh_step) { | |||
if (top_pad - oh_idx >= oh_step) { | |||
memset(outptr, 0, oh_step * ow * sizeof(int8_t)); | |||
} else { | |||
pack_src_one_line<PACK_MODE::FIRST_PAD>(inptr, outptr, left_pad, | |||
right_pad, iw); | |||
inptr += iw; | |||
} | |||
outptr += oh_step * ow; | |||
} | |||
for (; oh_idx < oh_end; oh_idx += oh_step) { | |||
pack_src_one_line<PACK_MODE::NO_PAD>(inptr, outptr, left_pad, | |||
right_pad, iw); | |||
inptr += oh_step * iw; | |||
outptr += oh_step * ow; | |||
} | |||
for (; oh_idx < oh; oh_idx += oh_step) { | |||
const int last_pad = oh_idx - ih - top_pad; | |||
if (last_pad >= 0) { | |||
memset(outptr, 0, oh_step * ow * sizeof(int8_t)); | |||
} else { | |||
pack_src_one_line<PACK_MODE::LAST_PAD>(inptr, outptr, left_pad, | |||
right_pad, iw); | |||
inptr += iw; | |||
} | |||
outptr += oh_step * ow; | |||
} | |||
} | |||
} | |||
/** | |||
* pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh * fw, 4(oc)} | |||
* pack interleave two adjacent row in filter to one row | |||
* */ | |||
void conv_bias::pack_nchw44_weight_for_nchw_conv(const int8_t* inptr, | |||
int8_t* outptr, const int ic, | |||
const int fh, const int fw, | |||
const int oc) { | |||
constexpr int oc_step = 4; | |||
constexpr int ic_step = 2; | |||
constexpr int fh_step = 2; | |||
constexpr int fw_step = 2; | |||
const int ic_end = ic / ic_step * ic_step; | |||
const int ic_remain = ic - ic_end; | |||
const int fh_end = fh / fh_step * fh_step; | |||
const int fh_remain = fh - fh_end; | |||
const int fw_end = fw / fw_step * fw_step; | |||
const int fw_remain = fw - fw_end; | |||
const int filter_stride = ic * oc_step; | |||
static const uint8_t ic2_idx_h_buffer[16] = {0, 8, 1, 9, 2, 10, 3, 11, | |||
4, 12, 5, 13, 6, 14, 7, 15}; | |||
uint8x16_t ic2_idx_h = vld1q_u8(ic2_idx_h_buffer); | |||
for (int oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||
for (int ic_idx = 0; ic_idx < ic_end; ic_idx += ic_step) { | |||
const int ic_offset = ic_idx * oc_step; | |||
int8_t* output_ic0 = outptr + ic_idx * fh * fw * oc_step; | |||
int8_t* output_ic1 = output_ic0 + fh * fw * oc_step; | |||
for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { | |||
const int fh_offset = fh_idx * fw * filter_stride; | |||
for (int fw_idx = 0; fw_idx < fw; ++fw_idx) { | |||
const int8_t* filter_ptr = inptr + fh_offset + | |||
fw_idx * filter_stride + | |||
ic_offset; | |||
int8x8_t row_0 = vld1_s8(filter_ptr); | |||
int8x8_t row_1 = vld1_s8(filter_ptr + fw * filter_stride); | |||
int8x16_t combine_row = vcombine_s8(row_0, row_1); | |||
combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); | |||
vst1_s8(output_ic0, vget_low_s8(combine_row)); | |||
vst1_s8(output_ic1, vget_high_s8(combine_row)); | |||
output_ic0 += 8; | |||
output_ic1 += 8; | |||
} | |||
} | |||
if (fh_remain > 0) { | |||
const int fh_offset = fh_end * fw * filter_stride; | |||
for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) { | |||
const int8_t* filter_ptr = inptr + fh_offset + | |||
fw_idx * filter_stride + | |||
ic_offset; | |||
int8x8_t row_0 = vld1_s8(filter_ptr); | |||
int8x8_t row_1 = vld1_s8(filter_ptr + filter_stride); | |||
int8x16_t combine_row = vcombine_s8(row_0, row_1); | |||
combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); | |||
vst1_s8(output_ic0, vget_low_s8(combine_row)); | |||
vst1_s8(output_ic1, vget_high_s8(combine_row)); | |||
output_ic0 += 8; | |||
output_ic1 += 8; | |||
} | |||
if (fw_remain > 0) { | |||
const int8_t* filter_ptr = inptr + fh_offset + | |||
fw_end * filter_stride + | |||
ic_offset; | |||
int8x8_t row_0 = vld1_s8(filter_ptr); | |||
vst1_lane_s32((int32_t*)output_ic0, | |||
vreinterpret_s32_s8(row_0), 0); | |||
vst1_lane_s32((int32_t*)output_ic1, | |||
vreinterpret_s32_s8(row_0), 1); | |||
output_ic0 += 4; | |||
output_ic1 += 4; | |||
} | |||
} | |||
} | |||
if (ic_remain > 0) { | |||
const int ic_offset = ic_end * oc_step; | |||
int8_t* output_ic0 = outptr + ic_end * fh * fw * oc_step; | |||
for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { | |||
const int fh_offset = fh_idx * fw * filter_stride; | |||
for (int fw_idx = 0; fw_idx < fw; ++fw_idx) { | |||
const int8_t* filter_ptr = inptr + fh_offset + | |||
fw_idx * filter_stride + | |||
ic_offset; | |||
int8x8_t row_0 = vreinterpret_s8_s32( | |||
vld1_dup_s32((const int32_t*)(filter_ptr))); | |||
int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32( | |||
(const int32_t*)(filter_ptr + fw * filter_stride))); | |||
int8x16_t combine_row = vcombine_s8(row_0, row_1); | |||
combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); | |||
vst1_s8(output_ic0, vget_low_s8(combine_row)); | |||
output_ic0 += 8; | |||
} | |||
} | |||
if (fh_remain > 0) { | |||
const int fh_offset = fh_end * fw * filter_stride; | |||
for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) { | |||
const int8_t* filter_ptr = inptr + fh_offset + | |||
fw_idx * filter_stride + | |||
ic_offset; | |||
int8x8_t row_0 = vreinterpret_s8_s32( | |||
vld1_dup_s32((const int32_t*)(filter_ptr))); | |||
int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32( | |||
(const int32_t*)(filter_ptr + filter_stride))); | |||
int8x16_t combine_row = vcombine_s8(row_0, row_1); | |||
combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); | |||
vst1_s8(output_ic0, vget_low_s8(combine_row)); | |||
output_ic0 += 8; | |||
} | |||
if (fw_remain > 0) { | |||
const int8_t* filter_ptr = inptr + fh_offset + | |||
fw_end * filter_stride + | |||
ic_offset; | |||
*(int32_t*)(output_ic0) = *(const int32_t*)(filter_ptr); | |||
output_ic0 += 4; | |||
} | |||
} | |||
} | |||
inptr += oc_step * fh * fw * ic; | |||
outptr += oc_step * fh * fw * ic; | |||
} | |||
} | |||
template <BiasMode bias_mode, typename Op, size_t filter_size> | |||
static void conv_direct_stride2_int8_nchw_nchw44( | |||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||
const Op& op) { | |||
MEGDNN_MARK_USED_VAR(temp); | |||
constexpr size_t fh = filter_size; | |||
constexpr size_t fw = filter_size; | |||
constexpr size_t ic_step = 1; | |||
constexpr size_t big_oc_step = 8; | |||
constexpr size_t oc_step = 4; | |||
constexpr size_t ih_step = 2; | |||
constexpr size_t oh_step = 1; | |||
constexpr size_t ow_step = 4; | |||
constexpr size_t stride_h = 2; | |||
constexpr size_t stride_w = 2; | |||
constexpr int pack_iw_len = 4; | |||
const size_t img_stride = oh * ow; | |||
const size_t ow_end = ow / ow_step * ow_step; | |||
const size_t ow_remain = ow - ow_end; | |||
const size_t oc_end = oc / big_oc_step * big_oc_step; | |||
const size_t oc_remain = oc - oc_end; | |||
const int ld_dst_oc = oc_step * img_stride; | |||
using remain_fun = | |||
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, | |||
int ih, int iw, int ld_dst_oc, const Op& op)>; | |||
remain_fun kern_big_oc_remain = nullptr; | |||
remain_fun kern_small_oc_remain = nullptr; | |||
switch (ow_remain) { | |||
#define cb(step) \ | |||
case step: \ | |||
kern_big_oc_remain = \ | |||
KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \ | |||
big_oc_step>::impl; \ | |||
kern_small_oc_remain = \ | |||
KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \ | |||
oc_step>::impl; \ | |||
break; | |||
UNROLL_CALL_RAW(4, cb); | |||
default: | |||
megdnn_assert(0, "no remain %zu for kern", ow_remain); | |||
} | |||
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = | |||
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||
ic_step * pack_iw_len; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
KerNeonXXs2NchwNchw44<bias_mode, Op, 0, filter_size, | |||
big_oc_step>::impl(src + src_offset, | |||
filter + weight_offset, | |||
bias + oc_idx, | |||
dst + dst_offset, ic, | |||
ih, iw, ld_dst_oc, op); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = | |||
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||
ic_step * pack_iw_len; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
kern_big_oc_remain(src + src_offset, filter + weight_offset, | |||
bias + oc_idx, dst + dst_offset, ic, ih, iw, | |||
ld_dst_oc, op); | |||
} | |||
} | |||
} | |||
if (oc_remain > 0) { | |||
size_t oc_idx = oc_end; | |||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
const size_t src_offset = | |||
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||
ic_step * pack_iw_len; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
KerNeonXXs2NchwNchw44<bias_mode, Op, 0, filter_size, | |||
oc_step>::impl(src + src_offset, | |||
filter + weight_offset, | |||
bias + oc_idx, | |||
dst + dst_offset, ic, ih, | |||
iw, ld_dst_oc, op); | |||
} | |||
if (ow_remain > 0) { | |||
const size_t src_offset = | |||
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||
ic_step * pack_iw_len; | |||
const size_t dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
kern_small_oc_remain(src + src_offset, filter + weight_offset, | |||
bias + oc_idx, dst + dst_offset, ic, ih, | |||
iw, ld_dst_oc, op); | |||
} | |||
} | |||
} | |||
} | |||
#define CONSTRUCT_FUNC(filter_size) \ | |||
template <BiasMode bias_mode, typename Op> \ | |||
void conv_bias:: \ | |||
conv_direct_stride2_##filter_size##x##filter_size##_int8_nchw_nchw44( \ | |||
const int8_t* src, const int8_t* filter, \ | |||
const int32_t* bias, int32_t* temp, int8_t* dst, \ | |||
const size_t oc, const size_t ic, const size_t ih, \ | |||
const size_t iw, const size_t oh, const size_t ow, \ | |||
const Op& op) { \ | |||
conv_direct_stride2_int8_nchw_nchw44<bias_mode, Op, filter_size>( \ | |||
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); \ | |||
} | |||
CONSTRUCT_FUNC(3); | |||
CONSTRUCT_FUNC(5); | |||
CONSTRUCT_FUNC(7); | |||
#undef CONSTRUCT_FUNC | |||
template <BiasMode bias_mode, typename Op> | |||
void conv_bias::conv_direct_stride2_2x2_int8_nchw_nchw44( | |||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||
const Op& op) { | |||
MEGDNN_MARK_USED_VAR(src); | |||
MEGDNN_MARK_USED_VAR(filter); | |||
MEGDNN_MARK_USED_VAR(bias); | |||
MEGDNN_MARK_USED_VAR(temp); | |||
MEGDNN_MARK_USED_VAR(dst); | |||
MEGDNN_MARK_USED_VAR(oc); | |||
MEGDNN_MARK_USED_VAR(ic); | |||
MEGDNN_MARK_USED_VAR(ih); | |||
MEGDNN_MARK_USED_VAR(iw); | |||
MEGDNN_MARK_USED_VAR(oh); | |||
MEGDNN_MARK_USED_VAR(ow); | |||
MEGDNN_MARK_USED_VAR(op); | |||
megdnn_assert(0, "not imple nchw_nchw44 2x2s2 conv"); | |||
} | |||
#define INSTANTIATION(stride, i, bias, Op) \ | |||
template void conv_bias:: \ | |||
conv_direct_##stride##_##i##x##i##_int8_nchw_nchw44<bias, Op>( \ | |||
const int8_t*, const int8_t*, const int32_t*, int32_t*, \ | |||
int8_t*, const size_t, const size_t, const size_t, \ | |||
const size_t, const size_t, const size_t, const Op&); | |||
#define FOR_OP(stride, i, bias) \ | |||
INSTANTIATION(stride, i, bias, TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
INSTANTIATION(stride, i, bias, ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
INSTANTIATION(stride, i, bias, HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) | |||
#define FOR_BIAS(stride, i) \ | |||
FOR_OP(stride, i, BiasMode::NO_BIAS) \ | |||
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) | |||
#define FOR_FILTER(stride) \ | |||
FOR_BIAS(stride, 2) \ | |||
FOR_BIAS(stride, 3) \ | |||
FOR_BIAS(stride, 5) \ | |||
FOR_BIAS(stride, 7) | |||
FOR_FILTER(stride2) | |||
#undef FOR_STRIDE | |||
#undef FOR_FILTER | |||
#undef FOR_IC | |||
#undef FOR_BIAS | |||
#undef FOR_NONLINEAR | |||
#undef FOR_REMAIN | |||
#undef INSTANTIATION |
@@ -1,44 +0,0 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/opr_impl.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace conv_bias { | |||
#define KERN(stride, i, layout) \ | |||
template <BiasMode bias_mode, typename Op> \ | |||
void conv_direct_##stride##_##i##x##i##_int8_nchw_##layout( \ | |||
const int8_t* src, const int8_t* filter, const int32_t* bias, \ | |||
int32_t* temp, int8_t* dst, const size_t OC, const size_t IC, \ | |||
const size_t IH, const size_t IW, const size_t OH, \ | |||
const size_t OW, const Op& op); | |||
KERN(stride2, 2, nchw44) | |||
KERN(stride2, 3, nchw44) | |||
KERN(stride2, 5, nchw44) | |||
KERN(stride2, 7, nchw44) | |||
#undef KERN | |||
void pack_nchw44_weight_for_nchw_conv(const int8_t* inptr, int8_t* outptr, | |||
const int ic, const int fh, const int fw, | |||
const int oc); | |||
void pack_nchw_src_for_nchw44_conv(const int8_t* inptr, int8_t* outptr, | |||
const int ic, const int top_pad, | |||
const int bottom_pad, const int left_pad, | |||
const int right_pad, const int ih, | |||
const int iw); | |||
} // namespace conv_bias | |||
} // namespace arm_common | |||
} // namespace megdnn |
@@ -47,7 +47,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
AlgoS8DirectStride2 s8_direct_stride2_large_group{true}; | |||
AlgoS8DirectStride2 s8_direct_stride2_small_group{false}; | |||
AlgoS8DirectStride2NCHW44 s8_direct_stride2_nchw44; | |||
AlgoS8DirectStride2NCHWNCHW44 s8_direct_stride2_nchw_nchw44; | |||
AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44; | |||
AlgoS8DirectStride1 s8_direct_stride1_large_group{true}; | |||
AlgoS8DirectStride1 s8_direct_stride1_small_group{false}; | |||
AlgoS8DirectStride1NCHW44 s8_direct_stride1_nchw44; | |||
@@ -115,7 +115,7 @@ public: | |||
direct_algos.emplace_back(&s8_direct_stride2_large_group); | |||
direct_algos.emplace_back(&s8_direct_stride2_small_group); | |||
direct_algos.emplace_back(&s8_direct_stride2_nchw44); | |||
direct_algos.emplace_back(&s8_direct_stride2_nchw_nchw44); | |||
direct_algos.emplace_back(&s8_direct_nchw_nchw44); | |||
direct_algos.emplace_back(&s8_direct_stride1_large_group); | |||
direct_algos.emplace_back(&s8_direct_stride1_small_group); | |||
direct_algos.emplace_back(&s8_direct_stride1_nchw44); | |||
@@ -40,7 +40,7 @@ private: | |||
class AlgoS8DirectStride1NCHW44; | |||
class AlgoS8DirectStride2; | |||
class AlgoS8DirectStride2NCHW44; | |||
class AlgoS8DirectStride2NCHWNCHW44; | |||
class AlgoS8DirectNCHWNCHW44; | |||
class AlgoQU8DirectStride1; | |||
class AlgoQU8DirectStride2; | |||
class AlgoFP32WinogradF23_4x4; | |||
@@ -244,18 +244,26 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) { | |||
#if MEGDNN_AARCH64 | |||
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | |||
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", true); | |||
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | |||
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", false); | |||
#else | |||
benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384", | |||
"IM2COLMATMUL:ARMV7_F32:192", true); | |||
benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384", | |||
"IM2COLMATMUL:ARMV7_F32:192", false); | |||
#endif | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) { | |||
#if MEGDNN_AARCH64 | |||
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | |||
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", true); | |||
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | |||
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", false); | |||
#else | |||
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | |||
"IM2COLMATMUL:ARMV7_F32:192", true); | |||
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | |||
"IM2COLMATMUL:ARMV7_F32:192", false); | |||
#endif | |||
} | |||
@@ -541,7 +541,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44) { | |||
checker_conv_bias_qint8x8x8( | |||
get_nchw44_conv_bias_args({3, 5, 7}, 2, false, false, false, true), | |||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false, | |||
true), | |||
handle(), "S8_CONV_NCHW_NCHW44"); | |||
checker_conv_bias_qint8x8x8( | |||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false, | |||
true), | |||
handle(), "S8_CONV_NCHW_NCHW44"); | |||
} | |||