@@ -38,23 +38,6 @@ public: | |||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
}; | }; | ||||
class ConvBiasImpl::AlgoS8DirectStride1NCHW44 final : public AlgoBase { | |||||
public: | |||||
AlgoS8DirectStride1NCHW44() {} | |||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "S8_NCHW44_DIRECT_STRD1"; } | |||||
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
size_t get_workspace(fallback::ConvBiasImpl*, | |||||
const NCBKernSizeParam& param) const override; | |||||
virtual SmallVector<NCBKern> dispatch_kerns( | |||||
fallback::ConvBiasImpl* opr, | |||||
const NCBKernSizeParam& param) const override; | |||||
bool is_preferred(megdnn::fallback::ConvBiasImpl*, | |||||
const NCBKernSizeParam& param) const override; | |||||
}; | |||||
class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | ||||
bool m_large_group; | bool m_large_group; | ||||
@@ -74,11 +57,11 @@ public: | |||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
}; | }; | ||||
class ConvBiasImpl::AlgoS8DirectStride2NCHW44 final : public AlgoBase { | |||||
class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase { | |||||
public: | public: | ||||
AlgoS8DirectStride2NCHW44() {} | |||||
AlgoS8DirectNCHW44() {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { return "S8_NCHW44_DIRECT_STRD2"; } | |||||
const char* name() const override { return "S8_NCHW44_DIRECT"; } | |||||
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
size_t get_workspace(fallback::ConvBiasImpl*, | size_t get_workspace(fallback::ConvBiasImpl*, | ||||
@@ -245,8 +228,8 @@ private: | |||||
//=======================input int8 compute fp32 output int8============ | //=======================input int8 compute fp32 output int8============ | ||||
class ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoS8CF32WinogradF23_4x4_NCHW44(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
uint32_t tile_size) | |||||
AlgoS8CF32WinogradF23_4x4_NCHW44( | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) | |||||
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | ||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | const char* name() const override { | ||||
@@ -277,7 +260,7 @@ private: | |||||
class ConvBiasImpl::AlgoS8WinogradF23_8x8_NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8WinogradF23_8x8_NCHW44 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoS8WinogradF23_8x8_NCHW44(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | AlgoS8WinogradF23_8x8_NCHW44(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
uint32_t tile_size) | |||||
uint32_t tile_size) | |||||
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | ||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | const char* name() const override { | ||||
@@ -36,26 +36,6 @@ KERN(stride2, 7, nchw) | |||||
#undef KERN | #undef KERN | ||||
#define KERN(stride, i, layout) \ | |||||
template <BiasMode bias_mode, typename Op, int remain_w> \ | |||||
void conv_direct_##stride##_##i##x##i##_int8_##layout( \ | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, \ | |||||
int32_t* temp, int8_t* dst, const size_t OC, const size_t IC, \ | |||||
const size_t IH, const size_t IW, const size_t OH, \ | |||||
const size_t OW, const Op& op); | |||||
KERN(stride1, 2, nchw44) | |||||
KERN(stride1, 3, nchw44) | |||||
KERN(stride1, 5, nchw44) | |||||
KERN(stride1, 7, nchw44) | |||||
KERN(stride2, 2, nchw44) | |||||
KERN(stride2, 3, nchw44) | |||||
KERN(stride2, 5, nchw44) | |||||
KERN(stride2, 7, nchw44) | |||||
#undef KERN | |||||
void nchw44_pack_filter(const int8_t* src, int8_t* dst, int filter); | |||||
void nchw44_pack_src(const int8_t* src, int8_t* dst, int length); | |||||
} // namespace conv_bias | } // namespace conv_bias | ||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_algo.cpp | |||||
* \file dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
@@ -13,6 +13,7 @@ | |||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/arm_common/conv_bias/int8/algos.h" | #include "src/arm_common/conv_bias/int8/algos.h" | ||||
#include "src/arm_common/conv_bias/int8/direct.h" | #include "src/arm_common/conv_bias/int8/direct.h" | ||||
#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | |||||
#include "src/arm_common/conv_bias/int8/strategy.h" | #include "src/arm_common/conv_bias/int8/strategy.h" | ||||
#include "src/arm_common/elemwise_op.h" | #include "src/arm_common/elemwise_op.h" | ||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
@@ -25,28 +26,19 @@ using conv_fun = std::function<void( | |||||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | ||||
const ConvBiasImpl::NCBKernIndex& ncb_index, | const ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | ||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_stride2) | |||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44) | |||||
static void get_rectified_size( | static void get_rectified_size( | ||||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||||
size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { | |||||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, | |||||
int& iw2) { | |||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
size_t SW = fm.stride[1]; | |||||
size_t IH = param.isz[0]; | |||||
size_t IW = param.isz[1]; | |||||
size_t OH = param.osz[0]; | |||||
size_t OW = param.osz[1]; | |||||
size_t FH = fm.spatial[0]; | |||||
size_t FW = fm.spatial[1]; | |||||
int ih = param.isz[0]; | |||||
int iw = param.isz[1]; | |||||
int ph = fm.padding[0]; | |||||
int pw = fm.padding[1]; | |||||
OH2 = OH; | |||||
OW2 = (OW + 7) & ~7; | |||||
IH2 = SW * OH + FH - SW; | |||||
IW2 = SW * OW2 + FW - SW; | |||||
// Because stride is 2, sometimes IW == IW2+1. Do a max update to | |||||
// handle this case. | |||||
IH2 = std::max(IH2, IH); | |||||
IW2 = std::max(IW2, IW); | |||||
ih2 = ih + ph * 2; | |||||
iw2 = iw + pw * 2; | |||||
} | } | ||||
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | ||||
constexpr size_t src_expand = 4; | constexpr size_t src_expand = 4; | ||||
@@ -57,8 +49,8 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||||
size_t OC = fm.ocpg; | size_t OC = fm.ocpg; | ||||
size_t FH = fm.spatial[0]; | size_t FH = fm.spatial[0]; | ||||
size_t FW = fm.spatial[1]; | size_t FW = fm.spatial[1]; | ||||
size_t IH2, IW2, OH2, OW2; | |||||
get_rectified_size(param, IH2, IW2, OH2, OW2); | |||||
int IH2, IW2; | |||||
get_rectified_size(param, IH2, IW2); | |||||
if (group == 1) { | if (group == 1) { | ||||
size_t src_size = | size_t src_size = | ||||
batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; | batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; | ||||
@@ -76,16 +68,16 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||||
const ConvBiasImpl::NCBKernParam& kern_param, | const ConvBiasImpl::NCBKernParam& kern_param, | ||||
const ConvBiasImpl::NCBKernIndex& ncb_index, | const ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
const CpuNDRange& workspace_ids) { | const CpuNDRange& workspace_ids) { | ||||
size_t IH = kern_param.isz[0]; | |||||
size_t IW = kern_param.isz[1]; | |||||
size_t IC = kern_param.filter_meta.icpg; | |||||
size_t PH = kern_param.filter_meta.padding[0]; | |||||
size_t PW = kern_param.filter_meta.padding[1]; | |||||
size_t GROUP = kern_param.filter_meta.group; | |||||
int IH = kern_param.isz[0]; | |||||
int IW = kern_param.isz[1]; | |||||
int IC = kern_param.filter_meta.icpg; | |||||
int PH = kern_param.filter_meta.padding[0]; | |||||
int PW = kern_param.filter_meta.padding[1]; | |||||
int GROUP = kern_param.filter_meta.group; | |||||
size_t IH2, IW2, OH2, OW2; | |||||
get_rectified_size(kern_param, IH2, IW2, OH2, OW2); | |||||
size_t padding_group_size = IH2 * IW2 * IC; | |||||
int IH2, IW2; | |||||
get_rectified_size(kern_param, IH2, IW2); | |||||
int padding_group_size = IH2 * IW2 * IC; | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
//! Used for get the workspace offset | //! Used for get the workspace offset | ||||
constexpr int pack_ic = 4; | constexpr int pack_ic = 4; | ||||
@@ -100,16 +92,10 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||||
size_t group_id = ncb_index.ndrange_id[1]; | size_t group_id = ncb_index.ndrange_id[1]; | ||||
size_t group_pack_size = 1; | size_t group_pack_size = 1; | ||||
int nr_pad_h = PH * IW2 * pack_ic * expend_element; | |||||
int nr_pad_w = PW * pack_ic * expend_element; | int nr_pad_w = PW * pack_ic * expend_element; | ||||
int over_pad = std::max(0_z, IW2 - IW - 2 * PW) * pack_ic * expend_element; | |||||
int row_last_pad = ((int)IW2 - (int)IW - 2 * (int)PW) >= 0 | |||||
? nr_pad_w + over_pad | |||||
: (IW2 - IW - PW) * pack_ic * expend_element; | |||||
int col_last_pad = | |||||
((int)IH2 - (int)IH - 2 * (int)PH) >= 0 | |||||
? nr_pad_h | |||||
: (IH2 - IH - PH) * IW2 * pack_ic * expend_element; | |||||
int nr_pad_h = PH * IW2 * pack_ic * expend_element; | |||||
int row_last_pad = (IW2 - IW - PW) * pack_ic * expend_element; | |||||
int col_last_pad = (IH2 - IH - PH) * IW2 * pack_ic * expend_element; | |||||
const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>( | const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>( | ||||
batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); | batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); | ||||
@@ -129,7 +115,7 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||||
rep(ih_idx, IH) { | rep(ih_idx, IH) { | ||||
std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); | std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); | ||||
sptr_base += nr_pad_w; | sptr_base += nr_pad_w; | ||||
conv_bias::nchw44_pack_src(sptr, sptr_base, IW); | |||||
nchw44_pack_src(sptr, sptr_base, IW); | |||||
sptr_base += IW * pack_ic * expend_element; | sptr_base += IW * pack_ic * expend_element; | ||||
sptr += IW * pack_ic; | sptr += IW * pack_ic; | ||||
std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t)); | std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t)); | ||||
@@ -140,7 +126,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||||
} | } | ||||
} | } | ||||
template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain> | |||||
template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain, | |||||
typename DstType, int stride> | |||||
static void do_conv_kern(WorkspaceBundle bundle, | static void do_conv_kern(WorkspaceBundle bundle, | ||||
const ConvBiasImpl::NCBKernParam& kern_param, | const ConvBiasImpl::NCBKernParam& kern_param, | ||||
const ConvBiasImpl::NCBKernIndex& ncb_index, | const ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
@@ -153,12 +140,12 @@ static void do_conv_kern(WorkspaceBundle bundle, | |||||
size_t IC = kern_param.filter_meta.icpg; | size_t IC = kern_param.filter_meta.icpg; | ||||
size_t OC = kern_param.filter_meta.ocpg; | size_t OC = kern_param.filter_meta.ocpg; | ||||
size_t GROUP = kern_param.filter_meta.group; | size_t GROUP = kern_param.filter_meta.group; | ||||
size_t IH2, IW2, OH2, OW2; | |||||
get_rectified_size(kern_param, IH2, IW2, OH2, OW2); | |||||
int IH2, IW2; | |||||
get_rectified_size(kern_param, IH2, IW2); | |||||
bool need_post_process = | bool need_post_process = | ||||
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; | kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; | ||||
//! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) | //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) | ||||
Op op = Op(1.0f, 4.0f); | |||||
Op op(1.f, 4.f); | |||||
if (need_post_process) { | if (need_post_process) { | ||||
float scale_bias = | float scale_bias = | ||||
kern_param.bias_type.param<dtype::QuantizedS32>().scale; | kern_param.bias_type.param<dtype::QuantizedS32>().scale; | ||||
@@ -191,49 +178,43 @@ static void do_conv_kern(WorkspaceBundle bundle, | |||||
const int8_t* fptr = | const int8_t* fptr = | ||||
kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC; | kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC; | ||||
void* dst = reinterpret_cast<void*>( | |||||
reinterpret_cast<ptrdiff_t>( | |||||
kern_param.dst<void>(batch_id, group_id)) + | |||||
oc_idx * OH * OW); | |||||
DstType* dst = reinterpret_cast<DstType*>( | |||||
kern_param.dst<void>(batch_id, group_id, oc_idx)); | |||||
const int32_t* bptr = | const int32_t* bptr = | ||||
kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx; | kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx; | ||||
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) + | auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) + | ||||
group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; | group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; | ||||
conv_bias::nchw44_pack_filter(fptr, packed_weight, | |||||
oc_block / 4 * IC / 4 * FH * FW); | |||||
#define KERN1_NCHW44_CONV(filter) \ | |||||
conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw44< \ | |||||
bias_mode, Op, ow_remain>(sptr, packed_weight, bptr, nullptr, \ | |||||
static_cast<int8_t*>(dst), oc_block, IC, \ | |||||
IH2, IW2, OH, OW, op) | |||||
DISPATCH_FILTER(filter, KERN1_NCHW44_CONV) | |||||
#undef KERN1_NCHW44_CONV | |||||
nchw44_pack_filter(fptr, packed_weight, oc_block / 4 * IC / 4 * FH * FW); | |||||
conv_direct_int8_nchw44<bias_mode, Op, ow_remain, filter, DstType, stride>( | |||||
sptr, packed_weight, bptr, nullptr, static_cast<DstType*>(dst), | |||||
oc_block, IC, IH2, IW2, OH, OW, op); | |||||
} | } | ||||
/* ===================== stride2 algo ===================== */ | |||||
bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::usable( | |||||
bool ConvBiasImpl::AlgoS8DirectNCHW44::usable( | |||||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const { | AlgoSelectionStrategy algo_selection_strategy) const { | ||||
MEGDNN_MARK_USED_VAR(algo_selection_strategy); | MEGDNN_MARK_USED_VAR(algo_selection_strategy); | ||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | |||||
auto OC = fm.ocpg; | |||||
auto IC = fm.icpg; | |||||
bool avaible = //! src and filter are qint8, dst is qint8 or qint32 | |||||
const int fh = fm.spatial[0]; | |||||
const int fw = fm.spatial[1]; | |||||
const int oc = fm.ocpg; | |||||
const int ic = fm.icpg; | |||||
const bool avaible = //! src and filter are qint8, dst is qint8 or qint32 | |||||
((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | ||||
param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | ||||
(param.dst_type.enumv() == DTypeEnum::QuantizedS8 || | (param.dst_type.enumv() == DTypeEnum::QuantizedS8 || | ||||
param.dst_type.enumv() == DTypeEnum::QuantizedS32))) && | param.dst_type.enumv() == DTypeEnum::QuantizedS32))) && | ||||
(fm.format == param::Convolution::Format::NCHW44) && | (fm.format == param::Convolution::Format::NCHW44) && | ||||
(OC % 4 == 0 && IC % 4 == 0 && OC >= 4) && !fm.should_flip && | |||||
(oc % 4 == 0 && ic % 4 == 0 && oc >= 4) && !fm.should_flip && | |||||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | ||||
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||||
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | |||||
fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] && | |||||
(fm.stride[0] == 2 || fm.stride[0] == 1) && fh == fw && | |||||
(fh == 2 || fh == 3 || fh == 5 || fh == 7) && | |||||
param.bias_mode != BiasMode::BIAS; | param.bias_mode != BiasMode::BIAS; | ||||
return avaible; | return avaible; | ||||
} | } | ||||
bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::is_preferred( | |||||
bool ConvBiasImpl::AlgoS8DirectNCHW44::is_preferred( | |||||
megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr, | megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr, | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
// TODO: benchmark and fix | // TODO: benchmark and fix | ||||
@@ -242,13 +223,13 @@ bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::is_preferred( | |||||
return false; | return false; | ||||
} | } | ||||
size_t ConvBiasImpl::AlgoS8DirectStride2NCHW44::get_workspace( | |||||
size_t ConvBiasImpl::AlgoS8DirectNCHW44::get_workspace( | |||||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | ||||
return get_bundle(param).total_size_in_bytes(); | return get_bundle(param).total_size_in_bytes(); | ||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
ConvBiasImpl::AlgoS8DirectStride2NCHW44::dispatch_kerns( | |||||
ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns( | |||||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | ||||
auto fm = param.filter_meta; | auto fm = param.filter_meta; | ||||
size_t N = param.n; | size_t N = param.n; | ||||
@@ -261,97 +242,129 @@ ConvBiasImpl::AlgoS8DirectStride2NCHW44::dispatch_kerns( | |||||
WorkspaceBundle wbundle = get_bundle(param); | WorkspaceBundle wbundle = get_bundle(param); | ||||
conv_fun do_conv_fun = nullptr; | conv_fun do_conv_fun = nullptr; | ||||
int ow_remain = OW % 8; | int ow_remain = OW % 8; | ||||
bool need_post_process = param.dst_type.enumv() == DTypeEnum::QuantizedS8; | |||||
// NOTE: remain_w is not used to gen hash of midout for compatible with changing | // NOTE: remain_w is not used to gen hash of midout for compatible with changing | ||||
// shape runtime | // shape runtime | ||||
#define DO_CONV_KERN_FUN(filter, bias_mode, remain_w, op) \ | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride2, \ | |||||
midout_iv(#filter #bias_mode #op##_hash)) { \ | |||||
do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w>; \ | |||||
} \ | |||||
#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, remain_w, op) \ | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44, \ | |||||
midout_iv(#stride #dst_type #filter #bias_mode #op##_hash)) { \ | |||||
do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w, dst_type, \ | |||||
stride>; \ | |||||
} \ | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
#define GET_OP_PARAM(filter, bias_mode, remain_w) \ | |||||
switch (param.nonlineMode) { \ | |||||
case param::ConvBias::NonlineMode::IDENTITY: \ | |||||
DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ | |||||
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
break; \ | |||||
case param::ConvBias::NonlineMode::RELU: \ | |||||
DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ | |||||
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
break; \ | |||||
case param::ConvBias::NonlineMode::H_SWISH: \ | |||||
DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ | |||||
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
break; \ | |||||
default: \ | |||||
megdnn_assert(0); \ | |||||
break; \ | |||||
#define GET_OP_PARAM(stride, filter, bias_mode, remain_w) \ | |||||
if (need_post_process) { \ | |||||
switch (param.nonlineMode) { \ | |||||
case param::ConvBias::NonlineMode::IDENTITY: \ | |||||
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ | |||||
remain_w, \ | |||||
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
break; \ | |||||
case param::ConvBias::NonlineMode::RELU: \ | |||||
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ | |||||
remain_w, \ | |||||
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
break; \ | |||||
case param::ConvBias::NonlineMode::H_SWISH: \ | |||||
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ | |||||
remain_w, \ | |||||
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
break; \ | |||||
default: \ | |||||
megdnn_assert(0, "no supported noline mode"); \ | |||||
break; \ | |||||
} \ | |||||
} else { \ | |||||
switch (param.nonlineMode) { \ | |||||
case param::ConvBias::NonlineMode::IDENTITY: \ | |||||
DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, \ | |||||
remain_w, NoneOp<dt_int32>) \ | |||||
break; \ | |||||
default: \ | |||||
megdnn_assert( \ | |||||
0, \ | |||||
"only support IDENTITY mode when dst is not qint8"); \ | |||||
break; \ | |||||
} \ | |||||
} | } | ||||
#define GET_REMAIN_W_PARAM(filter, bias_mode) \ | |||||
switch (ow_remain) { \ | |||||
case 0: \ | |||||
GET_OP_PARAM(filter, bias_mode, 0); \ | |||||
break; \ | |||||
case 1: \ | |||||
GET_OP_PARAM(filter, bias_mode, 1); \ | |||||
break; \ | |||||
case 2: \ | |||||
GET_OP_PARAM(filter, bias_mode, 2); \ | |||||
break; \ | |||||
case 3: \ | |||||
GET_OP_PARAM(filter, bias_mode, 3); \ | |||||
break; \ | |||||
case 4: \ | |||||
GET_OP_PARAM(filter, bias_mode, 4); \ | |||||
break; \ | |||||
case 5: \ | |||||
GET_OP_PARAM(filter, bias_mode, 5); \ | |||||
break; \ | |||||
case 6: \ | |||||
GET_OP_PARAM(filter, bias_mode, 6); \ | |||||
break; \ | |||||
case 7: \ | |||||
GET_OP_PARAM(filter, bias_mode, 7); \ | |||||
break; \ | |||||
default: \ | |||||
megdnn_assert(0); \ | |||||
#define GET_REMAIN_W_PARAM(stride, filter, bias_mode) \ | |||||
switch (ow_remain) { \ | |||||
case 0: \ | |||||
GET_OP_PARAM(stride, filter, bias_mode, 0); \ | |||||
break; \ | |||||
case 1: \ | |||||
GET_OP_PARAM(stride, filter, bias_mode, 1); \ | |||||
break; \ | |||||
case 2: \ | |||||
GET_OP_PARAM(stride, filter, bias_mode, 2); \ | |||||
break; \ | |||||
case 3: \ | |||||
GET_OP_PARAM(stride, filter, bias_mode, 3); \ | |||||
break; \ | |||||
case 4: \ | |||||
GET_OP_PARAM(stride, filter, bias_mode, 4); \ | |||||
break; \ | |||||
case 5: \ | |||||
GET_OP_PARAM(stride, filter, bias_mode, 5); \ | |||||
break; \ | |||||
case 6: \ | |||||
GET_OP_PARAM(stride, filter, bias_mode, 6); \ | |||||
break; \ | |||||
case 7: \ | |||||
GET_OP_PARAM(stride, filter, bias_mode, 7); \ | |||||
break; \ | |||||
default: \ | |||||
megdnn_assert(0); \ | |||||
} | } | ||||
#define GET_BIAS_MODE_PARAM(filter) \ | |||||
switch (param.bias_mode) { \ | |||||
case BiasMode::NO_BIAS: \ | |||||
GET_REMAIN_W_PARAM(filter, BiasMode::NO_BIAS) \ | |||||
break; \ | |||||
case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||||
GET_REMAIN_W_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||||
break; \ | |||||
default: \ | |||||
megdnn_assert(0); \ | |||||
break; \ | |||||
#define GET_BIAS_MODE_PARAM(stride, filter) \ | |||||
switch (param.bias_mode) { \ | |||||
case BiasMode::NO_BIAS: \ | |||||
GET_REMAIN_W_PARAM(stride, filter, BiasMode::NO_BIAS) \ | |||||
break; \ | |||||
case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||||
GET_REMAIN_W_PARAM(stride, filter, \ | |||||
BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||||
break; \ | |||||
default: \ | |||||
megdnn_assert(0); \ | |||||
break; \ | |||||
} | } | ||||
#define DISPATCH_CONV_KERN() \ | |||||
#define DISPATCH_CONV_KERN(stride) \ | |||||
switch (param.filter_meta.spatial[0]) { \ | switch (param.filter_meta.spatial[0]) { \ | ||||
case 2: \ | case 2: \ | ||||
GET_BIAS_MODE_PARAM(2) \ | |||||
GET_BIAS_MODE_PARAM(stride, 2) \ | |||||
break; \ | break; \ | ||||
case 3: \ | case 3: \ | ||||
GET_BIAS_MODE_PARAM(3) \ | |||||
GET_BIAS_MODE_PARAM(stride, 3) \ | |||||
break; \ | break; \ | ||||
case 5: \ | case 5: \ | ||||
GET_BIAS_MODE_PARAM(5) \ | |||||
GET_BIAS_MODE_PARAM(stride, 5) \ | |||||
break; \ | break; \ | ||||
case 7: \ | case 7: \ | ||||
GET_BIAS_MODE_PARAM(7) \ | |||||
GET_BIAS_MODE_PARAM(stride, 7) \ | |||||
break; \ | break; \ | ||||
default: \ | default: \ | ||||
megdnn_assert(0); \ | megdnn_assert(0); \ | ||||
break; \ | break; \ | ||||
} | } | ||||
DISPATCH_CONV_KERN(); | |||||
switch (param.filter_meta.stride[0]) { | |||||
case 1: | |||||
DISPATCH_CONV_KERN(1); | |||||
break; | |||||
case 2: | |||||
DISPATCH_CONV_KERN(2); | |||||
break; | |||||
default: | |||||
megdnn_throw(ssprintf("Unsupport stride size %u for the first conv", | |||||
param.filter_meta.stride[0]) | |||||
.c_str()); | |||||
break; | |||||
} | |||||
#undef DO_CONV_KERN_FUN | #undef DO_CONV_KERN_FUN | ||||
#undef GET_REMAIN_W_PARAM | #undef GET_REMAIN_W_PARAM |
@@ -1,393 +0,0 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_algo.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megdnn/oprs.h" | |||||
#include "src/arm_common/conv_bias/int8/algos.h" | |||||
#include "src/arm_common/conv_bias/int8/direct.h" | |||||
#include "src/arm_common/conv_bias/int8/strategy.h" | |||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | |||||
#include "midout.h" | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
using conv_fun = std::function<void( | |||||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | |||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_stride1) | |||||
static void get_rectified_size( | |||||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||||
size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { | |||||
auto&& fm = param.filter_meta; | |||||
auto SW = fm.stride[1]; | |||||
auto OH = param.osz[0]; | |||||
auto OW = param.osz[1]; | |||||
auto FH = fm.spatial[0]; | |||||
auto FW = fm.spatial[1]; | |||||
OH2 = OH; | |||||
OW2 = (OW + 7) & ~7; | |||||
IH2 = SW * OH + FH - SW; | |||||
IW2 = SW * OW2 + FW - SW; | |||||
} | |||||
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||||
constexpr size_t src_expand = 4; | |||||
auto&& fm = param.filter_meta; | |||||
size_t group = fm.group; | |||||
size_t batch = param.n; | |||||
size_t IC = fm.icpg; | |||||
size_t OC = fm.ocpg; | |||||
size_t FH = fm.spatial[0]; | |||||
size_t FW = fm.spatial[1]; | |||||
size_t IH2, IW2, OH2, OW2; | |||||
get_rectified_size(param, IH2, IW2, OH2, OW2); | |||||
if (group == 1) { | |||||
size_t src_size = | |||||
batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; | |||||
size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); | |||||
return {nullptr, {src_size, weight_size}}; | |||||
} else { | |||||
size_t src_size = | |||||
param.nr_threads * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; | |||||
size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); | |||||
return {nullptr, {src_size, weight_size}}; | |||||
} | |||||
}; | |||||
static void copy_padding_kern(WorkspaceBundle bundle, | |||||
const ConvBiasImpl::NCBKernParam& kern_param, | |||||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const CpuNDRange& workspace_ids) { | |||||
size_t IH = kern_param.isz[0]; | |||||
size_t IW = kern_param.isz[1]; | |||||
size_t IC = kern_param.filter_meta.icpg; | |||||
size_t PH = kern_param.filter_meta.padding[0]; | |||||
size_t PW = kern_param.filter_meta.padding[1]; | |||||
size_t GROUP = kern_param.filter_meta.group; | |||||
size_t IH2, IW2, OH2, OW2; | |||||
get_rectified_size(kern_param, IH2, IW2, OH2, OW2); | |||||
size_t padding_group_size = IH2 * IW2 * IC; | |||||
bundle.set(kern_param.workspace_ptr); | |||||
//! Used for get the workspace offset | |||||
constexpr int pack_ic = 4; | |||||
constexpr int expend_element = 4; | |||||
// TODO: block dim is better to get from arg | |||||
size_t workspace_ic_block = 4; | |||||
size_t workspace_batch_id = workspace_ids[0]; | |||||
size_t workspace_group_id = workspace_ids[1]; | |||||
size_t workspace_ic_id = workspace_ids[2]; | |||||
size_t workspace_ic = workspace_ic_id * workspace_ic_block; | |||||
size_t batch_id = ncb_index.ndrange_id[0]; | |||||
size_t group_id = ncb_index.ndrange_id[1]; | |||||
size_t group_pack_size = 1; | |||||
int nr_pad_h = PH * IW2 * pack_ic * expend_element; | |||||
int nr_pad_w = PW * pack_ic * expend_element; | |||||
int over_pad = std::max(0_z, IW2 - IW - 2 * PW) * pack_ic * expend_element; | |||||
//! copy to sptr_base to eliminate padding effect | |||||
const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>( | |||||
batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); | |||||
int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) + | |||||
(workspace_batch_id * GROUP * padding_group_size + | |||||
workspace_group_id * padding_group_size + | |||||
workspace_ic * IH2 * IW2) * | |||||
expend_element; | |||||
size_t nr_ic = workspace_ic_block; | |||||
if (GROUP > 1) { | |||||
nr_ic = IC; | |||||
} | |||||
rep_step(ic_idx, nr_ic, pack_ic) { | |||||
std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t)); | |||||
sptr_base += nr_pad_h; | |||||
rep(ih_idx, IH) { | |||||
std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); | |||||
sptr_base += nr_pad_w; | |||||
conv_bias::nchw44_pack_src(sptr, sptr_base, IW); | |||||
sptr_base += IW * pack_ic * expend_element; | |||||
sptr += IW * pack_ic; | |||||
std::memset(sptr_base, 0, (nr_pad_w + over_pad) * sizeof(int8_t)); | |||||
sptr_base += nr_pad_w + over_pad; | |||||
} | |||||
std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t)); | |||||
sptr_base += nr_pad_h; | |||||
} | |||||
} | |||||
//! Compute one (batch, group, oc-block) tile of the stride-1 NCHW44 int8
//! direct convolution. Reads the zero-padded, 4x-expanded source prepared by
//! copy_padding_kern from bundle.get(0), repacks the filter slice into
//! bundle.get(1), and writes the convolution result (optionally
//! quantized/activated by \p Op) to the destination tensor.
//!
//! \tparam filter     spatial filter size (2/3/5/7)
//! \tparam bias_mode  NO_BIAS or BROADCAST_CHANNEL_BIAS
//! \tparam Op         post-process op (TypeCvt/Relu/HSwish on qint32->qint8)
//! \tparam ow_remain  OW % 8, selects the store tail of the inner kernel
template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain>
static void do_conv_kern(WorkspaceBundle bundle,
                         const ConvBiasImpl::NCBKernParam& kern_param,
                         const ConvBiasImpl::NCBKernIndex& ncb_index,
                         const CpuNDRange& workspace_ids,
                         const CpuNDRange& ncb_range) {
    size_t OH = kern_param.osz[0];
    size_t OW = kern_param.osz[1];
    size_t FH = kern_param.filter_meta.spatial[0];
    size_t FW = kern_param.filter_meta.spatial[1];
    size_t IC = kern_param.filter_meta.icpg;
    size_t OC = kern_param.filter_meta.ocpg;
    size_t GROUP = kern_param.filter_meta.group;
    size_t IH2, IW2, OH2, OW2;
    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
    //! only qint8 output needs the quantize+activation post process; a qint32
    //! destination is stored raw
    bool need_post_process =
            kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
    //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f)
    Op op = Op(1.0f, 4.0f);
    if (need_post_process) {
        float scale_bias =
                kern_param.bias_type.param<dtype::QuantizedS32>().scale;
        float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
        op = Op(scale_bias, scale_dst);
    }
    size_t padding_group_size = IH2 * IW2 * IC;
    bundle.set(kern_param.workspace_ptr);
    constexpr size_t pack_c = 4;
    constexpr size_t src_expand_size = 4;
    const size_t workspace_batch_id = workspace_ids[0];
    const size_t workspace_group_id = workspace_ids[1];
    const size_t batch_id = ncb_index.ndrange_id[0];
    const size_t group_id = ncb_index.ndrange_id[1];
    const size_t oc_id = ncb_index.ndrange_id[2];
    const size_t oc_block_num = ncb_range[2];
    //! split OC into oc_block_num chunks of whole 4-channel packs; the last
    //! chunk takes whatever channels remain
    size_t nr_pack_per_step = div_ceil(div_ceil(OC, pack_c), oc_block_num);
    size_t oc_block = nr_pack_per_step * pack_c;
    const size_t oc_idx = oc_id * oc_block;
    if (oc_id == (oc_block_num - 1)) {
        oc_block = OC - oc_id * nr_pack_per_step * pack_c;
    }
    megdnn_assert(oc_block % pack_c == 0,
                  "oc must be devisible by 4, but oc = %zu", oc_block);
    //! expanded source for this (batch, group) inside the padding workspace
    const int8_t* sptr =
            static_cast<int8_t*>(bundle.get(0)) +
            workspace_batch_id * GROUP * padding_group_size * src_expand_size +
            workspace_group_id * padding_group_size * src_expand_size;
    const int8_t* fptr =
            kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC;
    void* dst = reinterpret_cast<void*>(
            reinterpret_cast<ptrdiff_t>(
                    kern_param.dst<void>(batch_id, group_id)) +
            oc_idx * OH * OW);
    const int32_t* bptr =
            kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx;
    //! repack this oc block's filter into the layout the kernels expect
    auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
                         group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW;
    conv_bias::nchw44_pack_filter(fptr, packed_weight,
                                  oc_block / 4 * IC / 4 * FH * FW);
#define KERN1_NCHW44_CONV(filter)                                           \
    conv_bias::conv_direct_stride1_##filter##x##filter##_int8_nchw44<       \
            bias_mode, Op, ow_remain>(sptr, packed_weight, bptr, nullptr,   \
                                      static_cast<int8_t*>(dst), oc_block,  \
                                      IC, IH2, IW2, OH, OW, op)
    DISPATCH_FILTER(filter, KERN1_NCHW44_CONV)
#undef KERN1_NCHW44_CONV
}
/* ===================== stride1 algo ===================== */ | |||||
bool ConvBiasImpl::AlgoS8DirectStride1NCHW44::usable( | |||||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
MEGDNN_MARK_USED_VAR(algo_selection_strategy); | |||||
auto&& fm = param.filter_meta; | |||||
auto FH = fm.spatial[0]; | |||||
auto OC = fm.ocpg; | |||||
auto IC = fm.icpg; | |||||
bool avaible = //! src and filter are qint8, dst is qint8 or qint32 | |||||
((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
(param.dst_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
param.dst_type.enumv() == DTypeEnum::QuantizedS32))) && | |||||
(fm.format == param::Convolution::Format::NCHW44) && | |||||
(OC % 4 == 0 && IC % 4 == 0 && OC >= 4) && !fm.should_flip && | |||||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && | |||||
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | |||||
param.bias_mode != BiasMode::BIAS; | |||||
return avaible; | |||||
} | |||||
//! Never reported as preferred for now; selection falls through to other
//! algos until benchmark data justifies prioritizing this one.
bool ConvBiasImpl::AlgoS8DirectStride1NCHW44::is_preferred(
        megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr,
        const NCBKernSizeParam& param) const {
    //! TODO: benchmark and fix
    MEGDNN_MARK_USED_VAR(conv_bias_impl_ptr);
    MEGDNN_MARK_USED_VAR(param);
    return false;
}
//! Total workspace: padded/expanded source buffer plus packed-filter buffer,
//! as laid out by get_bundle().
size_t ConvBiasImpl::AlgoS8DirectStride1NCHW44::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    auto bundle = get_bundle(param);
    return bundle.total_size_in_bytes();
}
//! Select the concrete do_conv_kern instantiation (filter size, bias mode,
//! OW tail, nonlinearity) via the macro ladder below, then build the NCB
//! kernel list: a separate padding pass plus a conv pass when group == 1,
//! or one fused padding+conv kernel per group otherwise.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectStride1NCHW44::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    auto fm = param.filter_meta;
    size_t N = param.n;
    size_t IC = fm.icpg;
    size_t OC = fm.ocpg;
    size_t OW = param.osz[1];
    size_t group = fm.group;
    size_t fh = fm.spatial[0];
    size_t fw = fm.spatial[1];
    WorkspaceBundle wbundle = get_bundle(param);
    conv_fun do_conv_fun = nullptr;
    //! the inner kernels process 8 output columns per step; ow_remain picks
    //! the compile-time store tail
    int ow_remain = OW % 8;
// NOTE: remain_w is not used to gen hash of midout for compatible with changing
// shape runtime
#define DO_CONV_KERN_FUN(filter, bias_mode, remain_w, op)         \
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride1, \
                 midout_iv(#filter #bias_mode #op##_hash)) {      \
        do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w>; \
    }                                                             \
    MIDOUT_END();

#define GET_OP_PARAM(filter, bias_mode, remain_w)                    \
    switch (param.nonlineMode) {                                     \
        case param::ConvBias::NonlineMode::IDENTITY:                 \
            DO_CONV_KERN_FUN(filter, bias_mode, remain_w,            \
                             TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
            break;                                                   \
        case param::ConvBias::NonlineMode::RELU:                     \
            DO_CONV_KERN_FUN(filter, bias_mode, remain_w,            \
                             ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
            break;                                                   \
        case param::ConvBias::NonlineMode::H_SWISH:                  \
            DO_CONV_KERN_FUN(filter, bias_mode, remain_w,            \
                             HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
            break;                                                   \
        default:                                                     \
            megdnn_assert(0);                                        \
            break;                                                   \
    }
#define GET_REMAIN_W_PARAM(filter, bias_mode)   \
    switch (ow_remain) {                        \
        case 0:                                 \
            GET_OP_PARAM(filter, bias_mode, 0); \
            break;                              \
        case 1:                                 \
            GET_OP_PARAM(filter, bias_mode, 1); \
            break;                              \
        case 2:                                 \
            GET_OP_PARAM(filter, bias_mode, 2); \
            break;                              \
        case 3:                                 \
            GET_OP_PARAM(filter, bias_mode, 3); \
            break;                              \
        case 4:                                 \
            GET_OP_PARAM(filter, bias_mode, 4); \
            break;                              \
        case 5:                                 \
            GET_OP_PARAM(filter, bias_mode, 5); \
            break;                              \
        case 6:                                 \
            GET_OP_PARAM(filter, bias_mode, 6); \
            break;                              \
        case 7:                                 \
            GET_OP_PARAM(filter, bias_mode, 7); \
            break;                              \
        default:                                \
            megdnn_assert(0);                   \
    }
#define GET_BIAS_MODE_PARAM(filter)                                \
    switch (param.bias_mode) {                                     \
        case BiasMode::NO_BIAS:                                    \
            GET_REMAIN_W_PARAM(filter, BiasMode::NO_BIAS)          \
            break;                                                 \
        case BiasMode::BROADCAST_CHANNEL_BIAS:                     \
            GET_REMAIN_W_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
            break;                                                 \
        default:                                                   \
            megdnn_assert(0);                                      \
            break;                                                 \
    }
#define DISPATCH_CONV_KERN()                \
    switch (param.filter_meta.spatial[0]) { \
        case 2:                             \
            GET_BIAS_MODE_PARAM(2)          \
            break;                          \
        case 3:                             \
            GET_BIAS_MODE_PARAM(3)          \
            break;                          \
        case 5:                             \
            GET_BIAS_MODE_PARAM(5)          \
            break;                          \
        case 7:                             \
            GET_BIAS_MODE_PARAM(7)          \
            break;                          \
        default:                            \
            megdnn_assert(0);               \
            break;                          \
    }
    DISPATCH_CONV_KERN();
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN
    megdnn_assert(do_conv_fun);
    SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
    WorkspaceBundle bundle = wbundle;
    constexpr size_t pack_oc = 4;
    //! 2x2 has a faster 8-oc kernel, so use a larger oc step when possible
    size_t oc_step = pack_oc;
    if (fh == 2 && fw == 2 && OC >= 8) {
        oc_step = 8;
    }
    if (group == 1) {
        //! single group: pad all (batch, ic-pack) slices first, then run the
        //! conv kernels parallelized over (batch, group, oc-block)
        CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)};
        auto copy_padding = [bundle](const NCBKernParam& kern_param,
                                     const NCBKernIndex& ncb_index) {
            copy_padding_kern(bundle, kern_param, ncb_index,
                              ncb_index.ndrange_id);
        };
        constexpr size_t pack_ic = 4;
        ret_kerns.push_back({copy_padding, {N, group, div_ceil(IC, pack_ic)}});
        auto do_conv = [bundle, do_conv_fun, ncb_range](
                               const NCBKernParam& kern_param,
                               const NCBKernIndex& ncb_index) {
            do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id,
                        ncb_range);
        };
        ret_kerns.push_back({do_conv, ncb_range});
    } else {
        //! multi group: each thread pads and convolves its own group slice,
        //! indexing the workspace by thread id
        CpuNDRange ncb_range = {N, group, 1};
        auto do_conv = [bundle, do_conv_fun, ncb_range](
                               const NCBKernParam& kern_param,
                               const NCBKernIndex& ncb_index) {
            copy_padding_kern(bundle, kern_param, ncb_index,
                              {0, ncb_index.thread_id, 0});
            do_conv_fun(bundle, kern_param, ncb_index,
                        {0, ncb_index.thread_id, 0}, ncb_range);
        };
        ret_kerns.push_back({do_conv, ncb_range});
    }
    return ret_kerns;
}
// vim: syntax=cpp.doxygen |
@@ -1,791 +0,0 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#include "src/arm_common/conv_bias/int8/direct.h" | |||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | |||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
/** | |||||
dot-like impl. dot 4 ic to 1 oc, accumulate to c <ow, oc>
example: (format like weight<oc, ic>) | |||||
packed weight | |||||
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> | |||||
--------------------------------------------------------------------- | |||||
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> | |||||
dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> | |||||
**/ | |||||
// TODO: can try oh = 2 impl, oc = 8 impl | |||||
//! 3x3 stride-1 micro kernel: accumulates a 4-oc x 8-ow output tile.
//! Each c[k] holds the 4 packed output channels of output column k; the
//! source is the 4x-expanded NCHW44 layout (16 bytes per input column),
//! and remain_w selects the compile-time store tail.
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_3x3s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    //! 16 bytes of packed weight per (fh, fw) position covering 4 ic
    constexpr int ld_weight_ic4 = 16;
    //! source was expanded 4x by nchw44_pack_src
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];
    int8x16_t weight[3];
    int8x16_t src[8 + 2];
    int16x8_t temp_c[2];
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            //! load 10 expanded input columns (ow 0..9) of this filter row
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            //! c[k] += sum over fw of weight[fw_idx] . src[k + fw_idx]
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]);
            c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]);
            c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]);
            c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
//! 2x2 stride-1 micro kernel computing two oc groups (8 oc) x 8 ow at once.
//! c[g][k] holds the 4 packed channels of oc group g at output column k;
//! ld_dst_oc is the destination stride between the two oc groups.
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, int ld_dst_oc,
                                             const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int oc_step = 4;
    constexpr int loop_ic_step = 4;
    //! 16 bytes of packed weight per (fh, fw) position covering 4 ic
    constexpr int ld_weight_ic4 = 16;
    //! source was expanded 4x by nchw44_pack_src
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    //! weight stride between the two 4-oc groups
    const int ld_weight_oc4 = oc_step * fh * fw * ic;
    int32x4_t c[2][8];
    int8x16_t weight[2][2];
    int8x16_t src[8 + 1];
    int16x8_t temp_c[4];
    init_oc8_ow8<bias_mode>(c, bias_ptr, oc_step);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            //! load 9 expanded input columns (ow 0..8) of this filter row
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0][0] = vld1q_s8(read_weight_ptr);
            weight[0][1] = vld1q_s8(read_weight_ptr + 16);
            weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
            weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);
            //! c[g][k] += sum over fw of weight[g][fw_idx] . src[k + fw_idx]
            c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][0], src[1], c[1][1], temp_c[3]);
            c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][1], src[2], c[1][1], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][0], src[2], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][0], src[3], c[1][3], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][1], src[3], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][1], src[4], c[1][3], temp_c[3]);
            c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][0], src[4], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][0], src[5], c[1][5], temp_c[3]);
            c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][1], src[5], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][1], src[6], c[1][5], temp_c[3]);
            c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][0], src[6], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][0], src[7], c[1][7], temp_c[3]);
            c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][1], src[7], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][1], src[8], c[1][7], temp_c[3]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc8_ow8_remain_static<remain_w>(c, op, dst_ptr, ld_dst_oc);
}
//! 2x2 stride-1 micro kernel: accumulates a 4-oc x 8-ow output tile.
//! Each c[k] holds the 4 packed output channels of output column k; used when
//! the remaining oc is a single 4-channel group.
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    //! 16 bytes of packed weight per (fh, fw) position covering 4 ic
    constexpr int ld_weight_ic4 = 16;
    //! source was expanded 4x by nchw44_pack_src
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];
    int8x16_t weight[2];
    int8x16_t src[8 + 1];
    int16x8_t temp_c[2];
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            //! load 9 expanded input columns (ow 0..8) of this filter row
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            //! c[k] += sum over fw of weight[fw_idx] . src[k + fw_idx]
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
            c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
            c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
            c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
//! 5x5 stride-1 micro kernel: accumulates a 4-oc x 8-ow output tile.
//! Each c[k] holds the 4 packed output channels of output column k. Only 10
//! src registers are kept live; columns 10/11 are loaded late into src[0]/[1]
//! (already consumed) to finish c[6]/c[7].
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_5x5s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    //! 16 bytes of packed weight per (fh, fw) position covering 4 ic
    constexpr int ld_weight_ic4 = 16;
    //! source was expanded 4x by nchw44_pack_src
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];
    int8x16_t weight[5];
    int8x16_t src[8 + 2];
    int16x8_t temp_c[2];
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            //! load expanded input columns 0..9 of this filter row
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
            weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
            //! c[k] += sum over fw of weight[fw_idx] . src[k + fw_idx]
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[3], src[4], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[4], src[5], c[1], temp_c[1]);
            c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[3], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[3], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[4], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[4], src[7], c[3], temp_c[1]);
            c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[3], src[7], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[3], src[8], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[4], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[4], src[9], c[5], temp_c[1]);
            //! recycle src[0]/src[1] to hold input columns 10 and 11
            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
            c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[3], src[9], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[3], src[0], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[4], src[0], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[4], src[1], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
//! 7x7 stride-1 micro kernel: accumulates a 4-oc x 8-ow output tile.
//! Each c[k] holds the 4 packed output channels of output column k. Only 10
//! src registers are kept live; already-consumed entries are recycled to load
//! columns 10/11 (into src[0]/[1]) and 12/13 (into src[2]/[3]).
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_7x7s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    //! 16 bytes of packed weight per (fh, fw) position covering 4 ic
    constexpr int ld_weight_ic4 = 16;
    //! source was expanded 4x by nchw44_pack_src
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];
    int8x16_t weight[7];
    int8x16_t src[8 + 2];
    int16x8_t temp_c[2];
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            //! load expanded input columns 0..9 of this filter row
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
            weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
            weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
            weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);
            //! c[k] += sum over fw of weight[fw_idx] . src[k + fw_idx]
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[3], src[4], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[4], src[5], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[5], src[5], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[5], src[6], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[6], src[6], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[6], src[7], c[1], temp_c[1]);
            c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[3], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[3], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[4], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[4], src[7], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[5], src[7], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[5], src[8], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[6], src[8], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[6], src[9], c[3], temp_c[1]);
            //! recycle src[0]/src[1] to hold input columns 10 and 11
            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
            c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[3], src[7], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[3], src[8], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[4], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[4], src[9], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[5], src[9], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[5], src[0], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[6], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[6], src[1], c[5], temp_c[1]);
            //! recycle src[2]/src[3] to hold input columns 12 and 13
            src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
            src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
            c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[3], src[9], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[3], src[0], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[4], src[0], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[4], src[1], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[5], src[1], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[5], src[2], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[6], src[2], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[6], src[3], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
} // namespace | |||||
/** | |||||
origin weight shape <oc/4, ic/4, fh, fw, 4, 4> | |||||
packed weight shape <oc/4, ic/4, fh, fw, 16> | |||||
example: (format like weight<oc, ic>) | |||||
origin | |||||
<0, 0> <1, 0> <2, 0> <3, 0> | |||||
<0, 1> <1, 1> <2, 1> <3, 1> | |||||
<0, 2> <1, 2> <2, 2> <3, 2> | |||||
<0, 3> <1, 3> <2, 3> <3, 3> | |||||
packed | |||||
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> | |||||
--------------------------------------------------------------------- | |||||
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> | |||||
**/ | |||||
void conv_bias::nchw44_pack_filter(const int8_t* src, int8_t* dst, int length) { | |||||
static const uint8_t weight_idx_buffer[16] = {0, 4, 9, 13, 2, 6, 11, 15, | |||||
12, 8, 5, 1, 14, 10, 7, 3}; | |||||
constexpr int simd_len = 16; | |||||
uint8x16_t weight_idx = vld1q_u8(weight_idx_buffer); | |||||
for (int i = 0; i < length; i++) { | |||||
int8x16_t result = vldq_tbl_s8(src + i * simd_len, weight_idx); | |||||
vst1q_s8(dst + i * simd_len, result); | |||||
} | |||||
} | |||||
/** | |||||
origin src shape <n, ic/4, h, w, 4> | |||||
packed src shape <n, ic/4, h, w, 16> | |||||
example: (format like <ic>) | |||||
origin | |||||
<0> <0> <0> <0> | |||||
packed | |||||
low 64 bit <0> <1> <2> <3> | <0> <1> <2> <3> | |||||
--------------------------------------------------------------------- | |||||
high 64 bit <3> <2> <1> <0> | <3> <2> <1> <0> | |||||
**/ | |||||
void conv_bias::nchw44_pack_src(const int8_t* src, int8_t* dst, int length) { | |||||
static const uint8_t src_idx_buffer[16] = {0, 1, 2, 3, 0, 1, 2, 3, | |||||
3, 2, 1, 0, 3, 2, 1, 0}; | |||||
constexpr int pack_ic = 4; | |||||
constexpr int simd_len = 16; | |||||
uint8x16_t src_idx = vld1q_u8(src_idx_buffer); | |||||
for (int i = 0; i < length; i++) { | |||||
int8x16_t result = vld_dup_tbl_s32(src + i * pack_ic, src_idx); | |||||
vst1q_s8(dst + i * simd_len, result); | |||||
} | |||||
} | |||||
template <BiasMode bias_mode, typename Op, int remain_w> | |||||
void conv_bias::conv_direct_stride1_2x2_int8_nchw44( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(temp); | |||||
constexpr size_t filter_size = 2; | |||||
constexpr size_t fh = filter_size; | |||||
constexpr size_t fw = filter_size; | |||||
constexpr size_t ic_step = 4; | |||||
constexpr size_t oc_step = 4; | |||||
constexpr size_t big_oc_step = 8; | |||||
constexpr size_t oh_step = 1; | |||||
constexpr size_t ow_step = 8; | |||||
constexpr int pack_iw_len = 4; | |||||
const size_t img_stride = oh * ow; | |||||
const size_t ow_end = ow / ow_step * ow_step; | |||||
const size_t ow_remain = ow - ow_end; | |||||
const size_t oc_end = oc / big_oc_step * big_oc_step; | |||||
const size_t oc_remain = oc - oc_end; | |||||
const int ld_oc = oh * ow * ic_step; | |||||
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = | |||||
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, 0, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ld_oc, op); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = | |||||
(oh_idx * iw + ow_end) * ic_step * pack_iw_len; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, remain_w, | |||||
filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ld_oc, op); | |||||
} | |||||
} | |||||
} | |||||
if (oc_remain > 0) { | |||||
const size_t oc_idx = oc_end; | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = | |||||
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, op); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = | |||||
(oh_idx * iw + ow_end) * ic_step * pack_iw_len; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, remain_w, | |||||
filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, op); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
template <BiasMode bias_mode, typename Op, int remain_w> | |||||
void conv_bias::conv_direct_stride1_3x3_int8_nchw44( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(temp); | |||||
constexpr size_t filter_size = 3; | |||||
constexpr size_t fh = filter_size; | |||||
constexpr size_t fw = filter_size; | |||||
constexpr size_t ic_step = 4; | |||||
constexpr size_t oc_step = 4; | |||||
constexpr size_t oh_step = 1; | |||||
constexpr size_t ow_step = 8; | |||||
constexpr int pack_iw_len = 4; | |||||
const size_t img_stride = oh * ow; | |||||
const size_t ow_end = ow / ow_step * ow_step; | |||||
const size_t ow_remain = ow - ow_end; | |||||
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = | |||||
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
ker_neon_dirctconv_3x3s1_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, op); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = | |||||
(oh_idx * iw + ow_end) * ic_step * pack_iw_len; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
ker_neon_dirctconv_3x3s1_oc4_ow8<bias_mode, Op, remain_w, | |||||
filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, op); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
template <BiasMode bias_mode, typename Op, int remain_w> | |||||
void conv_bias::conv_direct_stride1_5x5_int8_nchw44( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(temp); | |||||
constexpr size_t filter_size = 5; | |||||
constexpr size_t fh = filter_size; | |||||
constexpr size_t fw = filter_size; | |||||
constexpr size_t ic_step = 4; | |||||
constexpr size_t oc_step = 4; | |||||
constexpr size_t oh_step = 1; | |||||
constexpr size_t ow_step = 8; | |||||
constexpr int pack_iw_len = 4; | |||||
const size_t img_stride = oh * ow; | |||||
const size_t ow_end = ow / ow_step * ow_step; | |||||
const size_t ow_remain = ow - ow_end; | |||||
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = | |||||
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
ker_neon_dirctconv_5x5s1_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, op); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = | |||||
(oh_idx * iw + ow_end) * ic_step * pack_iw_len; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
ker_neon_dirctconv_5x5s1_oc4_ow8<bias_mode, Op, remain_w, | |||||
filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, op); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
template <BiasMode bias_mode, typename Op, int remain_w> | |||||
void conv_bias::conv_direct_stride1_7x7_int8_nchw44( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(temp); | |||||
constexpr size_t filter_size = 7; | |||||
constexpr size_t fh = filter_size; | |||||
constexpr size_t fw = filter_size; | |||||
constexpr size_t ic_step = 4; | |||||
constexpr size_t oc_step = 4; | |||||
constexpr size_t oh_step = 1; | |||||
constexpr size_t ow_step = 8; | |||||
constexpr int pack_iw_len = 4; | |||||
const size_t img_stride = oh * ow; | |||||
const size_t ow_end = ow / ow_step * ow_step; | |||||
const size_t ow_remain = ow - ow_end; | |||||
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = | |||||
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
ker_neon_dirctconv_7x7s1_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, op); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = | |||||
(oh_idx * iw + ow_end) * ic_step * pack_iw_len; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
ker_neon_dirctconv_7x7s1_oc4_ow8<bias_mode, Op, remain_w, | |||||
filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, op); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
// Explicit template instantiation of
// conv_direct_<stride>_<i>x<i>_int8_nchw44 for one (bias, remain_w, Op)
// combination.
#define INSTANTIATION(stride, i, bias, remain_w, Op)                          \
    template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_nchw44<  \
            bias, Op, remain_w>(const int8_t*, const int8_t*, const int32_t*, \
                                int32_t*, int8_t*, const size_t, const size_t, \
                                const size_t, const size_t, const size_t,     \
                                const size_t, const Op&);
// Instantiate for every supported epilogue op: plain requantize (TypeCvt),
// ReLU and H-Swish.
#define FOR_OP(stride, i, bias, remain_w)          \
    INSTANTIATION(stride, i, bias, remain_w,       \
                  TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    INSTANTIATION(stride, i, bias, remain_w,       \
                  ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANTIATION(stride, i, bias, remain_w,       \
                  HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
// Instantiate for each remain_w in [0, 8): the number of valid output
// columns in the rightmost 8-wide tile.
#define FOR_REMAIN(stride, i, bias) \
    FOR_OP(stride, i, bias, 0)      \
    FOR_OP(stride, i, bias, 1)      \
    FOR_OP(stride, i, bias, 2)      \
    FOR_OP(stride, i, bias, 3)      \
    FOR_OP(stride, i, bias, 4)      \
    FOR_OP(stride, i, bias, 5)      \
    FOR_OP(stride, i, bias, 6)      \
    FOR_OP(stride, i, bias, 7)
// Instantiate for both supported bias modes.
#define FOR_BIAS(stride, i)                  \
    FOR_REMAIN(stride, i, BiasMode::NO_BIAS) \
    FOR_REMAIN(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
// Instantiate for every supported filter size (2, 3, 5, 7).
#define FOR_FILTER(stride) \
    FOR_BIAS(stride, 2)    \
    FOR_BIAS(stride, 3)    \
    FOR_BIAS(stride, 5)    \
    FOR_BIAS(stride, 7)
FOR_FILTER(stride1)
// NOTE(review): FOR_STRIDE, FOR_IC and FOR_NONLINEAR are not defined in the
// visible part of this file; #undef of an undefined macro is harmless, but
// these look like leftovers from an earlier macro set — confirm against the
// full file.
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
@@ -1,793 +0,0 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#include "src/arm_common/conv_bias/int8/direct.h" | |||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | |||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
/** | |||||
dot like impl. dot 4 ic to 1 oc, accumulate to c <ow, oc>
example: (format like weight<oc, ic>) | |||||
packed weight | |||||
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> | |||||
--------------------------------------------------------------------- | |||||
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> | |||||
dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> | |||||
**/ | |||||
// TODO: can try oh = 2 impl, oc = 8 impl | |||||
/**
 * stride-2 3x3 direct conv kernel: one call computes 8 consecutive output
 * pixels (ow8) for one group of 4 output channels (oc4) on nchw44-packed
 * int8 data.
 *  - src_ptr:    packed input, 16 bytes per pixel (see nchw44_pack_src)
 *  - weight_ptr: packed filter, 16 bytes per tap (see nchw44_pack_filter)
 *  - bias_ptr:   per-channel bias, consumed by init_oc4_ow8 when bias_mode
 *                requests it
 *  - remain_w:   selects the partial-store variant for the rightmost tile
 *  - op:         epilogue functor applied by the store helper
 */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_3x3s2_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;         // 4 input channels packed per pixel
    constexpr int loop_ic_step = 4;    // ic consumed 4 channels at a time
    constexpr int ld_weight_ic4 = 16;  // bytes per packed filter tap
    constexpr int pack_iw_len = 4;     // width expansion from src packing
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];    // 8 accumulators, one per output pixel
    int8x16_t weight[3];   // one filter row (3 packed taps)
    int8x16_t src[8 + 2];  // sliding window of packed source pixels
    int16x8_t temp_c[2];   // int16 scratch used inside vdotq_s32_h
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            // preload the first 10 packed pixels of this input row
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
            // load the 3 taps of the current filter row
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            // stride 2: output k reads packed pixels 2k, 2k+1, 2k+2
            // outputs 0..3
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]);
            c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]);
            // recycle src[0..2] as packed pixels 10..12 for outputs 4..5
            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
            src[2] = vld1q_s8((src_ic_0_3 + 12 * 16));
            c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]);
            // recycle src[3..6] as packed pixels 13..16 for outputs 6..7
            src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 14 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 15 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 16 * 16));
            c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]);
        }
        // advance the filter to the taps of the next 4-channel input group
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
/**
 * stride-2 2x2 direct conv kernel, big-oc variant: one call computes 8
 * consecutive output pixels (ow8) for TWO groups of 4 output channels (oc8).
 * c[0][*] accumulates the first oc group, c[1][*] the second; the second
 * group's filter lives ld_weight_oc4 bytes after the first.
 *  - ld_dst_oc: element stride between the two oc groups in dst, passed to
 *               store_oc8_ow8_remain_static
 */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, int ld_dst_oc,
                                             const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;         // 4 input channels packed per pixel
    constexpr int oc_step = 4;         // channels per oc group
    constexpr int loop_ic_step = 4;    // ic consumed 4 channels at a time
    constexpr int ld_weight_ic4 = 16;  // bytes per packed filter tap
    constexpr int pack_iw_len = 4;     // width expansion from src packing
    const int ic_stride = ih * iw * pack_iw_len;
    // byte offset from the first oc group's filter to the second's
    const int ld_weight_oc4 = oc_step * fh * fw * ic;
    int32x4_t c[2][8];       // [oc group][output pixel] accumulators
    int8x16_t weight[2][2];  // [oc group][tap] for one filter row
    int8x16_t src[8 + 1];    // sliding window of packed source pixels
    int16x8_t temp_c[4];     // int16 scratch used inside vdotq_s32_h
    init_oc8_ow8<bias_mode>(c, bias_ptr, oc_step);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            // preload the first 9 packed pixels of this input row
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            // load this filter row's 2 taps for both oc groups
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0][0] = vld1q_s8(read_weight_ptr);
            weight[0][1] = vld1q_s8(read_weight_ptr + 16);
            weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
            weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);
            // stride 2: output k reads packed pixels 2k and 2k+1
            // outputs 0..3, both oc groups
            c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][0], src[2], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][0], src[2], c[1][1], temp_c[3]);
            c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][1], src[3], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][1], src[3], c[1][1], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][0], src[4], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][0], src[4], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][0], src[6], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][0], src[6], c[1][3], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][1], src[5], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][1], src[5], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][1], src[7], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][1], src[7], c[1][3], temp_c[3]);
            // recycle src[0..2] as packed pixels 9..11 for outputs 4..5
            src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
            c[0][4] = vdotq_s32_h(weight[0][0], src[8], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][0], src[8], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][0], src[1], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][0], src[1], c[1][5], temp_c[3]);
            c[0][4] = vdotq_s32_h(weight[0][1], src[0], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][1], src[0], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][1], src[2], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][1], src[2], c[1][5], temp_c[3]);
            // recycle src[3..6] as packed pixels 12..15 for outputs 6..7
            src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
            c[0][6] = vdotq_s32_h(weight[0][0], src[3], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][0], src[3], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][0], src[5], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][0], src[5], c[1][7], temp_c[3]);
            c[0][6] = vdotq_s32_h(weight[0][1], src[4], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][1], src[4], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][1], src[6], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][1], src[6], c[1][7], temp_c[3]);
        }
        // advance the filter to the taps of the next 4-channel input group
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc8_ow8_remain_static<remain_w>(c, op, dst_ptr, ld_dst_oc);
}
/**
 * stride-2 2x2 direct conv kernel: one call computes 8 consecutive output
 * pixels (ow8) for one group of 4 output channels (oc4). Tail counterpart of
 * ker_neon_dirctconv_2x2s2_oc8_ow8 for the last oc group when oc % 8 != 0.
 */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;         // 4 input channels packed per pixel
    constexpr int loop_ic_step = 4;    // ic consumed 4 channels at a time
    constexpr int ld_weight_ic4 = 16;  // bytes per packed filter tap
    constexpr int pack_iw_len = 4;     // width expansion from src packing
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];    // 8 accumulators, one per output pixel
    int8x16_t weight[2];   // one filter row (2 packed taps)
    int8x16_t src[8 + 1];  // sliding window of packed source pixels
    int16x8_t temp_c[2];   // int16 scratch used inside vdotq_s32_h
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            // preload the first 9 packed pixels of this input row
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            // load the 2 taps of the current filter row
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            // stride 2: output k reads packed pixels 2k and 2k+1
            // outputs 0..3
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
            c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
            // recycle src[0..2] as packed pixels 9..11 for outputs 4..5
            src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
            c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[1], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[2], c[5], temp_c[1]);
            // recycle src[3..6] as packed pixels 12..15 for outputs 6..7
            src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
            c[6] = vdotq_s32_h(weight[0], src[3], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[5], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[4], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[6], c[7], temp_c[1]);
        }
        // advance the filter to the taps of the next 4-channel input group
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
/**
 * stride-2 5x5 direct conv kernel: one call computes 8 consecutive output
 * pixels (ow8) for one group of 4 output channels (oc4) on nchw44-packed
 * int8 data. Same contract as ker_neon_dirctconv_3x3s2_oc4_ow8 with a
 * 5-tap filter row.
 */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_5x5s2_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;         // 4 input channels packed per pixel
    constexpr int loop_ic_step = 4;    // ic consumed 4 channels at a time
    constexpr int ld_weight_ic4 = 16;  // bytes per packed filter tap
    constexpr int pack_iw_len = 4;     // width expansion from src packing
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];    // 8 accumulators, one per output pixel
    int8x16_t weight[5];   // one filter row (5 packed taps)
    int8x16_t src[8 + 2];  // sliding window of packed source pixels
    int16x8_t temp_c[2];   // int16 scratch used inside vdotq_s32_h
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            // preload the first 10 packed pixels of this input row
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
            // load the 5 taps of the current filter row
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
            weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
            // stride 2: output k reads packed pixels 2k .. 2k+4
            // outputs 0..1
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[3], src[5], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[4], src[6], c[1], temp_c[1]);
            // recycle src[0] as packed pixel 10 for outputs 2..3
            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[3], src[7], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[3], src[9], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[4], src[8], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[4], src[0], c[3], temp_c[1]);
            // recycle src[1..4] as packed pixels 11..14 for outputs 4..5
            src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
            src[2] = vld1q_s8((src_ic_0_3 + 12 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 14 * 16));
            c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[3], src[1], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[3], src[3], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[4], src[2], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[4], src[4], c[5], temp_c[1]);
            // recycle src[5..8] as packed pixels 15..18 for outputs 6..7
            src[5] = vld1q_s8((src_ic_0_3 + 15 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 16 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 17 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 18 * 16));
            c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[3], src[5], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[3], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[4], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[4], src[8], c[7], temp_c[1]);
        }
        // advance the filter to the taps of the next 4-channel input group
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size> | |||||
static void ker_neon_dirctconv_7x7s2_oc4_ow8(const int8_t* src_ptr, | |||||
const int8_t* weight_ptr, | |||||
const int32_t* bias_ptr, | |||||
int8_t* dst_ptr, int ic, int ih, | |||||
int iw, const Op& op) { | |||||
constexpr int fh = filter_size; | |||||
constexpr int fw = filter_size; | |||||
constexpr int ic_step = 4; | |||||
constexpr int loop_ic_step = 4; | |||||
constexpr int ld_weight_ic4 = 16; | |||||
constexpr int pack_iw_len = 4; | |||||
const int ic_stride = ih * iw * pack_iw_len; | |||||
int32x4_t c[2 * 4]; | |||||
int8x16_t weight[7]; | |||||
int8x16_t src[8 + 2]; | |||||
int16x8_t temp_c[2]; | |||||
init_oc4_ow8<bias_mode>(c, bias_ptr); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||||
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + | |||||
fh_idx * iw * ic_step * pack_iw_len; | |||||
src[0] = vld1q_s8(src_ic_0_3); | |||||
src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); | |||||
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); | |||||
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); | |||||
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); | |||||
src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); | |||||
src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); | |||||
src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); | |||||
src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); | |||||
src[9] = vld1q_s8(src_ic_0_3 + 9 * 16); | |||||
// oc == 0 | |||||
const int8_t* read_weight_ptr = | |||||
weight_ptr + fh_idx * fw * ld_weight_ic4; | |||||
weight[0] = vld1q_s8(read_weight_ptr); | |||||
weight[1] = vld1q_s8(read_weight_ptr + 16); | |||||
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); | |||||
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); | |||||
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); | |||||
weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); | |||||
weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); | |||||
c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); | |||||
c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]); | |||||
c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); | |||||
c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]); | |||||
c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]); | |||||
c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]); | |||||
c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]); | |||||
c[1] = vdotq_s32_h(weight[3], src[5], c[1], temp_c[1]); | |||||
c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]); | |||||
c[1] = vdotq_s32_h(weight[4], src[6], c[1], temp_c[1]); | |||||
c[0] = vdotq_s32_h(weight[5], src[5], c[0], temp_c[0]); | |||||
c[1] = vdotq_s32_h(weight[5], src[7], c[1], temp_c[1]); | |||||
c[0] = vdotq_s32_h(weight[6], src[6], c[0], temp_c[0]); | |||||
c[1] = vdotq_s32_h(weight[6], src[8], c[1], temp_c[1]); | |||||
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); | |||||
src[1] = vld1q_s8(src_ic_0_3 + 11 * 16); | |||||
src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); | |||||
c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]); | |||||
c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]); | |||||
c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]); | |||||
c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]); | |||||
c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]); | |||||
c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]); | |||||
c[2] = vdotq_s32_h(weight[3], src[7], c[2], temp_c[0]); | |||||
c[3] = vdotq_s32_h(weight[3], src[9], c[3], temp_c[1]); | |||||
c[2] = vdotq_s32_h(weight[4], src[8], c[2], temp_c[0]); | |||||
c[3] = vdotq_s32_h(weight[4], src[0], c[3], temp_c[1]); | |||||
c[2] = vdotq_s32_h(weight[5], src[9], c[2], temp_c[0]); | |||||
c[3] = vdotq_s32_h(weight[5], src[1], c[3], temp_c[1]); | |||||
c[2] = vdotq_s32_h(weight[6], src[0], c[2], temp_c[0]); | |||||
c[3] = vdotq_s32_h(weight[6], src[2], c[3], temp_c[1]); | |||||
src[3] = vld1q_s8(src_ic_0_3 + 13 * 16); | |||||
src[4] = vld1q_s8(src_ic_0_3 + 14 * 16); | |||||
src[5] = vld1q_s8(src_ic_0_3 + 15 * 16); | |||||
src[6] = vld1q_s8(src_ic_0_3 + 16 * 16); | |||||
c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]); | |||||
c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]); | |||||
c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]); | |||||
c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]); | |||||
c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]); | |||||
c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]); | |||||
c[4] = vdotq_s32_h(weight[3], src[1], c[4], temp_c[0]); | |||||
c[5] = vdotq_s32_h(weight[3], src[3], c[5], temp_c[1]); | |||||
c[4] = vdotq_s32_h(weight[4], src[2], c[4], temp_c[0]); | |||||
c[5] = vdotq_s32_h(weight[4], src[4], c[5], temp_c[1]); | |||||
c[4] = vdotq_s32_h(weight[5], src[3], c[4], temp_c[0]); | |||||
c[5] = vdotq_s32_h(weight[5], src[5], c[5], temp_c[1]); | |||||
c[4] = vdotq_s32_h(weight[6], src[4], c[4], temp_c[0]); | |||||
c[5] = vdotq_s32_h(weight[6], src[6], c[5], temp_c[1]); | |||||
src[7] = vld1q_s8(src_ic_0_3 + 17 * 16); | |||||
src[8] = vld1q_s8(src_ic_0_3 + 18 * 16); | |||||
src[9] = vld1q_s8(src_ic_0_3 + 19 * 16); | |||||
src[0] = vld1q_s8(src_ic_0_3 + 20 * 16); | |||||
c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]); | |||||
c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]); | |||||
c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]); | |||||
c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]); | |||||
c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]); | |||||
c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]); | |||||
c[6] = vdotq_s32_h(weight[3], src[5], c[6], temp_c[0]); | |||||
c[7] = vdotq_s32_h(weight[3], src[7], c[7], temp_c[1]); | |||||
c[6] = vdotq_s32_h(weight[4], src[6], c[6], temp_c[0]); | |||||
c[7] = vdotq_s32_h(weight[4], src[8], c[7], temp_c[1]); | |||||
c[6] = vdotq_s32_h(weight[5], src[7], c[6], temp_c[0]); | |||||
c[7] = vdotq_s32_h(weight[5], src[9], c[7], temp_c[1]); | |||||
c[6] = vdotq_s32_h(weight[6], src[8], c[6], temp_c[0]); | |||||
c[7] = vdotq_s32_h(weight[6], src[0], c[7], temp_c[1]); | |||||
} | |||||
weight_ptr += fh * fw * ld_weight_ic4; | |||||
} | |||||
store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr); | |||||
} | |||||
} // namespace | |||||
template <BiasMode bias_mode, typename Op, int remain_w> | |||||
void conv_bias::conv_direct_stride2_2x2_int8_nchw44( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(temp); | |||||
constexpr size_t filter_size = 2; | |||||
constexpr size_t fh = filter_size; | |||||
constexpr size_t fw = filter_size; | |||||
constexpr size_t ic_step = 4; | |||||
constexpr size_t oc_step = 4; | |||||
constexpr size_t big_oc_step = 8; | |||||
constexpr size_t oh_step = 1; | |||||
constexpr size_t ow_step = 8; | |||||
constexpr size_t stride_h = 2; | |||||
constexpr size_t stride_w = 2; | |||||
constexpr int pack_iw_len = 4; | |||||
const size_t out_img_stride = oh * ow; | |||||
const size_t ow_end = ow / ow_step * ow_step; | |||||
const size_t ow_remain = ow - ow_end; | |||||
const size_t oc_end = oc / big_oc_step * big_oc_step; | |||||
const size_t oc_remain = oc - oc_end; | |||||
const int ld_oc = oh * ow * ic_step; | |||||
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = | |||||
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * | |||||
pack_iw_len; | |||||
const size_t dst_offset = oc_idx * out_img_stride + | |||||
(oh_idx * ow + ow_idx) * oc_step; | |||||
ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, 0, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ld_oc, op); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = | |||||
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * | |||||
pack_iw_len; | |||||
const size_t dst_offset = oc_idx * out_img_stride + | |||||
(oh_idx * ow + ow_end) * oc_step; | |||||
ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, remain_w, | |||||
filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, ld_oc, op); | |||||
} | |||||
} | |||||
} | |||||
if (oc_remain > 0) { | |||||
const size_t oc_idx = oc_end; | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = | |||||
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * | |||||
pack_iw_len; | |||||
const size_t dst_offset = oc_idx * out_img_stride + | |||||
(oh_idx * ow + ow_idx) * oc_step; | |||||
ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, op); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = | |||||
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * | |||||
pack_iw_len; | |||||
const size_t dst_offset = oc_idx * out_img_stride + | |||||
(oh_idx * ow + ow_end) * oc_step; | |||||
ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, remain_w, | |||||
filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, op); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
template <BiasMode bias_mode, typename Op, int remain_w> | |||||
void conv_bias::conv_direct_stride2_3x3_int8_nchw44( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(temp); | |||||
constexpr size_t filter_size = 3; | |||||
constexpr size_t fh = filter_size; | |||||
constexpr size_t fw = filter_size; | |||||
constexpr size_t ic_step = 4; | |||||
constexpr size_t oc_step = 4; | |||||
constexpr size_t oh_step = 1; | |||||
constexpr size_t ow_step = 8; | |||||
constexpr size_t stride_h = 2; | |||||
constexpr size_t stride_w = 2; | |||||
constexpr int pack_iw_len = 4; | |||||
const size_t img_stride = oh * ow; | |||||
const size_t ow_end = ow / ow_step * ow_step; | |||||
const size_t ow_remain = ow - ow_end; | |||||
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = | |||||
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * | |||||
pack_iw_len; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
ker_neon_dirctconv_3x3s2_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, op); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = | |||||
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * | |||||
pack_iw_len; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
ker_neon_dirctconv_3x3s2_oc4_ow8<bias_mode, Op, remain_w, | |||||
filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, op); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
template <BiasMode bias_mode, typename Op, int remain_w> | |||||
void conv_bias::conv_direct_stride2_5x5_int8_nchw44( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(temp); | |||||
constexpr size_t filter_size = 5; | |||||
constexpr size_t fh = filter_size; | |||||
constexpr size_t fw = filter_size; | |||||
constexpr size_t ic_step = 4; | |||||
constexpr size_t oc_step = 4; | |||||
constexpr size_t oh_step = 1; | |||||
constexpr size_t ow_step = 8; | |||||
constexpr size_t stride_h = 2; | |||||
constexpr size_t stride_w = 2; | |||||
constexpr int pack_iw_len = 4; | |||||
const size_t img_stride = oh * ow; | |||||
const size_t ow_end = ow / ow_step * ow_step; | |||||
const size_t ow_remain = ow - ow_end; | |||||
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = | |||||
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * | |||||
pack_iw_len; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
ker_neon_dirctconv_5x5s2_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, op); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = | |||||
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * | |||||
pack_iw_len; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
ker_neon_dirctconv_5x5s2_oc4_ow8<bias_mode, Op, remain_w, | |||||
filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, op); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
template <BiasMode bias_mode, typename Op, int remain_w> | |||||
void conv_bias::conv_direct_stride2_7x7_int8_nchw44( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(temp); | |||||
constexpr size_t filter_size = 7; | |||||
constexpr size_t fh = filter_size; | |||||
constexpr size_t fw = filter_size; | |||||
constexpr size_t ic_step = 4; | |||||
constexpr size_t oc_step = 4; | |||||
constexpr size_t oh_step = 1; | |||||
constexpr size_t ow_step = 8; | |||||
constexpr size_t stride_h = 2; | |||||
constexpr size_t stride_w = 2; | |||||
constexpr int pack_iw_len = 4; | |||||
const size_t img_stride = oh * ow; | |||||
const size_t ow_end = ow / ow_step * ow_step; | |||||
const size_t ow_remain = ow - ow_end; | |||||
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||||
const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
const size_t src_offset = | |||||
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * | |||||
pack_iw_len; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
ker_neon_dirctconv_7x7s2_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, op); | |||||
} | |||||
if (ow_remain > 0) { | |||||
const size_t src_offset = | |||||
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * | |||||
pack_iw_len; | |||||
const size_t dst_offset = | |||||
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
ker_neon_dirctconv_7x7s2_oc4_ow8<bias_mode, Op, remain_w, | |||||
filter_size>( | |||||
src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
dst + dst_offset, ic, ih, iw, op); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
#define INSTANTIATION(stride, i, bias, remain_w, Op) \ | |||||
template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_nchw44< \ | |||||
bias, Op, remain_w>(const int8_t*, const int8_t*, const int32_t*, \ | |||||
int32_t*, int8_t*, const size_t, const size_t, \ | |||||
const size_t, const size_t, const size_t, \ | |||||
const size_t, const Op&); | |||||
#define FOR_OP(stride, i, bias, remain_w) \ | |||||
INSTANTIATION(stride, i, bias, remain_w, \ | |||||
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
INSTANTIATION(stride, i, bias, remain_w, \ | |||||
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
INSTANTIATION(stride, i, bias, remain_w, \ | |||||
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) | |||||
#define FOR_REMAIN(stride, i, bias) \ | |||||
FOR_OP(stride, i, bias, 0) \ | |||||
FOR_OP(stride, i, bias, 1) \ | |||||
FOR_OP(stride, i, bias, 2) \ | |||||
FOR_OP(stride, i, bias, 3) \ | |||||
FOR_OP(stride, i, bias, 4) \ | |||||
FOR_OP(stride, i, bias, 5) \ | |||||
FOR_OP(stride, i, bias, 6) \ | |||||
FOR_OP(stride, i, bias, 7) | |||||
#define FOR_BIAS(stride, i) \ | |||||
FOR_REMAIN(stride, i, BiasMode::NO_BIAS) \ | |||||
FOR_REMAIN(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) | |||||
#define FOR_FILTER(stride) \ | |||||
FOR_BIAS(stride, 2) \ | |||||
FOR_BIAS(stride, 3) \ | |||||
FOR_BIAS(stride, 5) \ | |||||
FOR_BIAS(stride, 7) | |||||
FOR_FILTER(stride2) | |||||
#undef FOR_STRIDE | |||||
#undef FOR_FILTER | |||||
#undef FOR_IC | |||||
#undef FOR_BIAS | |||||
#undef FOR_NONLINEAR | |||||
#undef FOR_REMAIN | |||||
#undef INSTANTIATION |
@@ -46,11 +46,10 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
AlgoQU8DirectStride1 qu8_direct_stride1_small_group{false}; | AlgoQU8DirectStride1 qu8_direct_stride1_small_group{false}; | ||||
AlgoS8DirectStride2 s8_direct_stride2_large_group{true}; | AlgoS8DirectStride2 s8_direct_stride2_large_group{true}; | ||||
AlgoS8DirectStride2 s8_direct_stride2_small_group{false}; | AlgoS8DirectStride2 s8_direct_stride2_small_group{false}; | ||||
AlgoS8DirectStride2NCHW44 s8_direct_stride2_nchw44; | |||||
AlgoS8DirectNCHW44 s8_direct_nchw44; | |||||
AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44; | AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44; | ||||
AlgoS8DirectStride1 s8_direct_stride1_large_group{true}; | AlgoS8DirectStride1 s8_direct_stride1_large_group{true}; | ||||
AlgoS8DirectStride1 s8_direct_stride1_small_group{false}; | AlgoS8DirectStride1 s8_direct_stride1_small_group{false}; | ||||
AlgoS8DirectStride1NCHW44 s8_direct_stride1_nchw44; | |||||
AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; | AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; | ||||
AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; | AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; | ||||
@@ -114,11 +113,10 @@ public: | |||||
direct_algos.emplace_back(&qu8_direct_stride1_small_group); | direct_algos.emplace_back(&qu8_direct_stride1_small_group); | ||||
direct_algos.emplace_back(&s8_direct_stride2_large_group); | direct_algos.emplace_back(&s8_direct_stride2_large_group); | ||||
direct_algos.emplace_back(&s8_direct_stride2_small_group); | direct_algos.emplace_back(&s8_direct_stride2_small_group); | ||||
direct_algos.emplace_back(&s8_direct_stride2_nchw44); | |||||
direct_algos.emplace_back(&s8_direct_nchw44); | |||||
direct_algos.emplace_back(&s8_direct_nchw_nchw44); | direct_algos.emplace_back(&s8_direct_nchw_nchw44); | ||||
direct_algos.emplace_back(&s8_direct_stride1_large_group); | direct_algos.emplace_back(&s8_direct_stride1_large_group); | ||||
direct_algos.emplace_back(&s8_direct_stride1_small_group); | direct_algos.emplace_back(&s8_direct_stride1_small_group); | ||||
direct_algos.emplace_back(&s8_direct_stride1_nchw44); | |||||
direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); | direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); | ||||
direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); | direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); | ||||
@@ -37,9 +37,8 @@ protected: | |||||
private: | private: | ||||
class AlgoS8DirectStride1; | class AlgoS8DirectStride1; | ||||
class AlgoS8DirectStride1NCHW44; | |||||
class AlgoS8DirectStride2; | class AlgoS8DirectStride2; | ||||
class AlgoS8DirectStride2NCHW44; | |||||
class AlgoS8DirectNCHW44; | |||||
class AlgoS8DirectNCHWNCHW44; | class AlgoS8DirectNCHWNCHW44; | ||||
class AlgoQU8DirectStride1; | class AlgoQU8DirectStride1; | ||||
class AlgoQU8DirectStride2; | class AlgoQU8DirectStride2; | ||||
@@ -27,6 +27,8 @@ struct NoneOp; | |||||
#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ | #define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ | ||||
template <> \ | template <> \ | ||||
struct NoneOp<_ctype> : NoneOpBase<_ctype> { \ | struct NoneOp<_ctype> : NoneOpBase<_ctype> { \ | ||||
NoneOp(){}; \ | |||||
NoneOp(float, float){}; \ | |||||
using NoneOpBase::NoneOpBase; \ | using NoneOpBase::NoneOpBase; \ | ||||
using NoneOpBase::operator(); \ | using NoneOpBase::operator(); \ | ||||
constexpr static size_t SIMD_WIDTH = _simd_width; \ | constexpr static size_t SIMD_WIDTH = _simd_width; \ | ||||
@@ -226,7 +226,15 @@ static void benchmark_convbias(Handle* handle, std::string int_name, | |||||
run(1, 3, 32, 224, 224, 5, 1, true); | run(1, 3, 32, 224, 224, 5, 1, true); | ||||
run(1, 3, 64, 224, 224, 7, 1, true); | run(1, 3, 64, 224, 224, 7, 1, true); | ||||
for (size_t stride : {1, 2}) { | |||||
run(1, 64, 128, 56, 56, 3, 2, false); | |||||
run(1, 128, 256, 28, 28, 3, 2, false); | |||||
run(1, 256, 512, 14, 14, 3, 2, false); | |||||
run(1, 128, 128, 28, 28, 3, 1, false); | |||||
run(1, 256, 256, 14, 14, 3, 1, false); | |||||
run(1, 512, 512, 7, 7, 3, 1, false); | |||||
for (size_t stride : {1}) { | |||||
printf("stride %zu\n", stride); | printf("stride %zu\n", stride); | ||||
for (size_t filter_size : {2, 3, 5, 7}) { | for (size_t filter_size : {2, 3, 5, 7}) { | ||||
for (size_t img_size : {32}) { | for (size_t img_size : {32}) { | ||||
@@ -527,12 +527,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) { | ||||
checker_conv_bias_qint8x8x8( | checker_conv_bias_qint8x8x8( | ||||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), | get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), | ||||
handle(), "S8_NCHW44_DIRECT_STRD1"); | |||||
handle(), "S8_NCHW44_DIRECT"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) { | |||||
checker_conv_bias_qint8x8x32( | |||||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true), | |||||
handle(), "S8_NCHW44_DIRECT"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) { | |||||
checker_conv_bias_qint8x8x32( | |||||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true), | |||||
handle(), "S8_NCHW44_DIRECT"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) { | ||||
checker_conv_bias_qint8x8x8( | checker_conv_bias_qint8x8x8( | ||||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | ||||
handle(), "S8_NCHW44_DIRECT_STRD2"); | |||||
handle(), "S8_NCHW44_DIRECT"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) { | ||||
checker_conv_bias_qint8x8x8( | checker_conv_bias_qint8x8x8( | ||||
@@ -1085,7 +1095,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) { | |||||
dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); | dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
@@ -1096,17 +1105,17 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { | |||||
param::MatrixMul::Format format, float eps) { | param::MatrixMul::Format format, float eps) { | ||||
for (auto&& arg : args) { | for (auto&& arg : args) { | ||||
for (uint32_t m : out_size) { | for (uint32_t m : out_size) { | ||||
checker.set_extra_opr_impl(std::bind( | |||||
winograd_algo_extra_impl, std::placeholders::_1, m, | |||||
arg.param, handle, format)); | |||||
checker.set_dtype(0, A_dtype) | |||||
.set_dtype(1, B_dtype) | |||||
.set_dtype(2, C_dtype) | |||||
.set_dtype(4, D_dtype) | |||||
.set_epsilon(eps) | |||||
.set_param(arg.param) | |||||
.execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||||
} | |||||
checker.set_extra_opr_impl(std::bind( | |||||
winograd_algo_extra_impl, std::placeholders::_1, m, | |||||
arg.param, handle, format)); | |||||
checker.set_dtype(0, A_dtype) | |||||
.set_dtype(1, B_dtype) | |||||
.set_dtype(2, C_dtype) | |||||
.set_dtype(4, D_dtype) | |||||
.set_epsilon(eps) | |||||
.set_param(arg.param) | |||||
.execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||||
} | |||||
} | } | ||||
}; | }; | ||||
@@ -1118,7 +1127,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { | |||||
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | ||||
ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str())); | ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str())); | ||||
std::vector<TestArg> quantized_args = get_int8_nchw44_args (3,4); | |||||
std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4); | |||||
UniformIntRNG int_rng{-50, 50}; | UniformIntRNG int_rng{-50, 50}; | ||||
checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng); | checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng); | ||||
run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f), | run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f), | ||||
@@ -1126,8 +1135,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { | |||||
dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); | dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
Checker<ConvBiasForward> checker(handle()); | Checker<ConvBiasForward> checker(handle()); | ||||
@@ -1137,17 +1146,17 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPM | |||||
param::MatrixMul::Format format, float eps) { | param::MatrixMul::Format format, float eps) { | ||||
for (auto&& arg : args) { | for (auto&& arg : args) { | ||||
for (uint32_t m : out_size) { | for (uint32_t m : out_size) { | ||||
checker.set_extra_opr_impl(std::bind( | |||||
winograd_algo_extra_impl, std::placeholders::_1, m, | |||||
arg.param, handle, format)); | |||||
checker.set_dtype(0, A_dtype) | |||||
.set_dtype(1, B_dtype) | |||||
.set_dtype(2, C_dtype) | |||||
.set_dtype(4, D_dtype) | |||||
.set_epsilon(eps) | |||||
.set_param(arg.param) | |||||
.execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||||
} | |||||
checker.set_extra_opr_impl(std::bind( | |||||
winograd_algo_extra_impl, std::placeholders::_1, m, | |||||
arg.param, handle, format)); | |||||
checker.set_dtype(0, A_dtype) | |||||
.set_dtype(1, B_dtype) | |||||
.set_dtype(2, C_dtype) | |||||
.set_dtype(4, D_dtype) | |||||
.set_epsilon(eps) | |||||
.set_param(arg.param) | |||||
.execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||||
} | |||||
} | } | ||||
}; | }; | ||||
@@ -1168,7 +1177,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPM | |||||
dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); | dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
Checker<ConvBiasForward> checker(handle()); | Checker<ConvBiasForward> checker(handle()); | ||||
@@ -1196,21 +1206,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F | |||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
const char* matmul_name = "AARCH64_F32_MK4_4x16"; | const char* matmul_name = "AARCH64_F32_MK4_4x16"; | ||||
#else | #else | ||||
const char* matmul_name = "ARMV7_F32_MK4_4x8"; | |||||
const char* matmul_name = "ARMV7_F32_MK4_4x8"; | |||||
#endif | #endif | ||||
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | ||||
ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str())); | ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str())); | ||||
std::vector<TestArg> quantized_args = | |||||
get_int8_nchw44_args(3, 4, true); | |||||
std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true); | |||||
UniformIntRNG int_rng{-50, 50}; | UniformIntRNG int_rng{-50, 50}; | ||||
checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng); | checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng); | ||||
run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f), | run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f), | ||||
dtype::QuantizedS8(0.01887994f), | dtype::QuantizedS8(0.01887994f), | ||||
dtype::QuantizedS32(0.41113496f * 0.01887994f), | dtype::QuantizedS32(0.41113496f * 0.01887994f), | ||||
dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, epsilon); | |||||
dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, | |||||
epsilon); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
Checker<ConvBiasForward> checker(handle()); | Checker<ConvBiasForward> checker(handle()); | ||||
@@ -1238,7 +1249,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F | |||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
const char* matmul_name = "AARCH64_F32_MK4_4x16"; | const char* matmul_name = "AARCH64_F32_MK4_4x16"; | ||||
#else | #else | ||||
const char* matmul_name = "ARMV7_F32_MK4_4x8"; | |||||
const char* matmul_name = "ARMV7_F32_MK4_4x8"; | |||||
#endif | #endif | ||||
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | ||||
ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str())); | ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str())); | ||||
@@ -1249,10 +1260,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F | |||||
run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f), | run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f), | ||||
dtype::QuantizedS8(0.01887994f), | dtype::QuantizedS8(0.01887994f), | ||||
dtype::QuantizedS32(0.41113496f * 0.01887994f), | dtype::QuantizedS32(0.41113496f * 0.01887994f), | ||||
dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, epsilon); | |||||
dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, | |||||
epsilon); | |||||
} | } | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||