GitOrigin-RevId: d326035202
tags/v0.4.0
@@ -17,18 +17,14 @@ | |||||
#include "src/fallback/conv_bias/opr_impl.h" | #include "src/fallback/conv_bias/opr_impl.h" | ||||
#include "src/fallback/conv_bias/winograd/strategy.h" | #include "src/fallback/conv_bias/winograd/strategy.h" | ||||
#include "src/naive/convolution/helper.h" | #include "src/naive/convolution/helper.h" | ||||
#if MEGDNN_X86 | |||||
#include "src/x86/conv_bias/postprocess_helper.h" | |||||
#endif | |||||
#include "midout.h" | #include "midout.h" | ||||
MIDOUT_DECL(megdnn_fallback_im2col) | MIDOUT_DECL(megdnn_fallback_im2col) | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace fallback; | using namespace fallback; | ||||
using namespace im2col; | using namespace im2col; | ||||
#if MEGDNN_X86 | |||||
using namespace x86; | |||||
#endif | |||||
/*======================== AlgoIm2col=======================*/ | /*======================== AlgoIm2col=======================*/ | ||||
/*! | /*! | ||||
@@ -47,8 +43,8 @@ using Pack_Mode=fallback::MatrixMulImpl::AlgoBase::PackMode; | |||||
static void copy_padding_kern(WorkspaceBundle bundle, | static void copy_padding_kern(WorkspaceBundle bundle, | ||||
const ConvBiasImpl::NCBKernParam& param, | const ConvBiasImpl::NCBKernParam& param, | ||||
const ConvBiasImpl::NCBKernIndex& ncb_index, | const ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
StrategyBase* im2colstrategy) { | |||||
im2colstrategy->copy_padding_kern(bundle, param, ncb_index); | |||||
StrategyBase* im2colstrategy, size_t pack_oc_size) { | |||||
im2colstrategy->copy_padding_kern(bundle, param, ncb_index, pack_oc_size); | |||||
} | } | ||||
//! packA_kern | //! packA_kern | ||||
@@ -57,9 +53,9 @@ static void packA_kern(WorkspaceBundle bundle, | |||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
StrategyBase* im2colstrategy) { | |||||
StrategyBase* im2colstrategy, size_t pack_oc_size) { | |||||
im2colstrategy->packA_kern(bundle, param, matmulparam, matmul_algo, | im2colstrategy->packA_kern(bundle, param, matmulparam, matmul_algo, | ||||
ncb_index); | |||||
ncb_index, pack_oc_size); | |||||
} | } | ||||
/*! | /*! | ||||
@@ -129,14 +125,17 @@ public: | |||||
size_t oc_tile_size) { | size_t oc_tile_size) { | ||||
size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0], | size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0], | ||||
FW = param.filter_meta.spatial[1]; | FW = param.filter_meta.spatial[1]; | ||||
size_t pack_oc_size = 1; | |||||
size_t im2col = 0, packb = 0, bias_temp = 0; | size_t im2col = 0, packb = 0, bias_temp = 0; | ||||
bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT; | bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT; | ||||
megdnn_assert(default_pack, "only support default packa"); | megdnn_assert(default_pack, "only support default packa"); | ||||
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { | |||||
pack_oc_size = 4; | |||||
} | |||||
size_t im2col_dst_size = | size_t im2col_dst_size = | ||||
IC * FH * FW * ohw_tile_size * sizeof(param.src_type); | IC * FH * FW * ohw_tile_size * sizeof(param.src_type); | ||||
size_t matmul_dst_size = | |||||
oc_tile_size * ohw_tile_size * sizeof(param.bias_type); | |||||
size_t matmul_dst_size = pack_oc_size * oc_tile_size * ohw_tile_size * | |||||
sizeof(param.bias_type); | |||||
//! matmul_dst and im2col_dst use the same memory | //! matmul_dst and im2col_dst use the same memory | ||||
WorkspaceBundle wb = matmul_algo->get_bundle(im2col_kern_param); | WorkspaceBundle wb = matmul_algo->get_bundle(im2col_kern_param); | ||||
packb = wb.get_size(1); | packb = wb.get_size(1); | ||||
@@ -318,17 +317,18 @@ public: | |||||
} | } | ||||
}; | }; | ||||
#undef FILL_IM2COL_STRATEGY_PARAM | |||||
fallback::MatrixMulImpl::KernSizeParam | fallback::MatrixMulImpl::KernSizeParam | ||||
ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, | ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, | ||||
size_t ohw_tile_size, | size_t ohw_tile_size, | ||||
size_t oc_tile_size) const { | size_t oc_tile_size) const { | ||||
bool is_nchw44 = | |||||
param.filter_meta.format == param::ConvBias::Format::NCHW44; | |||||
size_t M = oc_tile_size; | size_t M = oc_tile_size; | ||||
size_t N = ohw_tile_size; | size_t N = ohw_tile_size; | ||||
size_t K = param.filter_meta.icpg * param.filter_meta.spatial[0] * | size_t K = param.filter_meta.icpg * param.filter_meta.spatial[0] * | ||||
param.filter_meta.spatial[1]; | param.filter_meta.spatial[1]; | ||||
size_t LDA = K, LDB = N, LDC = N; | |||||
size_t pack_oc_size = is_nchw44 ? 4 : 1; | |||||
size_t LDA = pack_oc_size * K, LDB = pack_oc_size * N, LDC = N; | |||||
bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 && | bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 && | ||||
param.dst_type.enumv() == DTypeEnum::QuantizedS8) || | param.dst_type.enumv() == DTypeEnum::QuantizedS8) || | ||||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm && | (param.src_type.enumv() == DTypeEnum::Quantized8Asymm && | ||||
@@ -345,7 +345,8 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, | |||||
false, | false, | ||||
false, | false, | ||||
param::MatrixMul::ComputeMode::DEFAULT, | param::MatrixMul::ComputeMode::DEFAULT, | ||||
param::MatrixMul::Format::DEFAULT}; | |||||
is_nchw44 ? param::MatrixMul::Format::MK4 | |||||
: param::MatrixMul::Format::DEFAULT}; | |||||
} | } | ||||
void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( | void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( | ||||
@@ -405,6 +406,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||||
size_t GROUP = param.filter_meta.group; | size_t GROUP = param.filter_meta.group; | ||||
bool need_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT; | bool need_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT; | ||||
bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA; | bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA; | ||||
if (need_pack || only_packA) { | if (need_pack || only_packA) { | ||||
auto inner_block = m_matmul_algo->get_inner_block_size(); | auto inner_block = m_matmul_algo->get_inner_block_size(); | ||||
choice_ohw_oc_block(param, inner_block.m, inner_block.n, need_pack); | choice_ohw_oc_block(param, inner_block.m, inner_block.n, need_pack); | ||||
@@ -421,16 +423,19 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||||
need_pack); | need_pack); | ||||
packa_group_size = 0; | packa_group_size = 0; | ||||
} | } | ||||
if (no_need_pading) { | if (no_need_pading) { | ||||
padding = 0; //! not need padding | padding = 0; //! not need padding | ||||
} else { | } else { | ||||
padding = (GROUP * N * IC * IH2 * IW2) * | padding = (GROUP * N * IC * IH2 * IW2) * | ||||
sizeof(param.src_type); //! for padding | sizeof(param.src_type); //! for padding | ||||
} | } | ||||
packa_size = GROUP * packa_group_size; //! for packA size = GROUP * a_size | packa_size = GROUP * packa_group_size; //! for packA size = GROUP * a_size | ||||
WorkspaceBundle ws = {nullptr, {}}; | WorkspaceBundle ws = {nullptr, {}}; | ||||
auto im2col_kern_param = | auto im2col_kern_param = | ||||
get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size); | get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size); | ||||
if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) { | if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) { | ||||
Im2colKerns<Pack_Mode::DEFAULT> defaultkern; | Im2colKerns<Pack_Mode::DEFAULT> defaultkern; | ||||
ws = defaultkern.get_thread_bundle(param, im2col_kern_param, | ws = defaultkern.get_thread_bundle(param, im2col_kern_param, | ||||
@@ -447,6 +452,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||||
m_matmul_algo, m_ohw_tile_size, | m_matmul_algo, m_ohw_tile_size, | ||||
m_oc_tile_size); | m_oc_tile_size); | ||||
} | } | ||||
return {nullptr, | return {nullptr, | ||||
{padding, packa_size, ws.total_size_in_bytes() * nr_threads}}; | {padding, packa_size, ws.total_size_in_bytes() * nr_threads}}; | ||||
} | } | ||||
@@ -461,7 +467,7 @@ size_t ConvBiasImpl::AlgoIm2col::get_workspace( | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | ||||
ConvBiasImpl* opr, const NCBKernSizeParam& param) const { | |||||
ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||||
MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 1) { | MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 1) { | ||||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | UNPACK_CONV_F32_NCB_KERN_SIZES(param); | ||||
MEGDNN_MARK_USED_VAR(SH); | MEGDNN_MARK_USED_VAR(SH); | ||||
@@ -473,7 +479,6 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||||
size_t ohw = OH * OW; | size_t ohw = OH * OW; | ||||
size_t ohw_parallel_times = div_ceil(ohw, m_ohw_tile_size); | size_t ohw_parallel_times = div_ceil(ohw, m_ohw_tile_size); | ||||
size_t GROUP = param.filter_meta.group; | size_t GROUP = param.filter_meta.group; | ||||
WorkspaceBundle bundle = get_bundle(param); | WorkspaceBundle bundle = get_bundle(param); | ||||
WorkspaceBundle bundle_thread = {nullptr, {}}; | WorkspaceBundle bundle_thread = {nullptr, {}}; | ||||
size_t oc_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size); | size_t oc_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size); | ||||
@@ -483,11 +488,14 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||||
bool no_pack = packmode == Pack_Mode::NO_PACK; | bool no_pack = packmode == Pack_Mode::NO_PACK; | ||||
bool only_packA = packmode == Pack_Mode::ONLY_PACKA; | bool only_packA = packmode == Pack_Mode::ONLY_PACKA; | ||||
size_t packa_parallel_times = 0; | size_t packa_parallel_times = 0; | ||||
size_t pack_oc_size = | |||||
(param.filter_meta.format == param::ConvBias::Format::NCHW ? 1 | |||||
: 4); | |||||
if (only_packA) { | if (only_packA) { | ||||
packa_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size); | packa_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size); | ||||
} else if (default_pack) { | } else if (default_pack) { | ||||
packa_parallel_times = div_ceil<size_t>( | packa_parallel_times = div_ceil<size_t>( | ||||
OC, m_matmul_algo->get_inner_block_size().m); | |||||
OC, m_matmul_algo->get_inner_block_size().m * pack_oc_size); | |||||
} | } | ||||
auto matmul_param = get_matmul_kern_param( | auto matmul_param = get_matmul_kern_param( | ||||
@@ -520,25 +528,29 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||||
strategyparam.skip_copy_dst = | strategyparam.skip_copy_dst = | ||||
strategyparam.is_ohw_size_bigger && !strategyparam.is_dst_8bit; | strategyparam.is_ohw_size_bigger && !strategyparam.is_dst_8bit; | ||||
strategyparam.oc_tile_size = m_oc_tile_size; | strategyparam.oc_tile_size = m_oc_tile_size; | ||||
strategyparam.pack_oc_size = pack_oc_size; | |||||
SmallVector<ConvBiasImpl::NCBKern> ret_kern; | SmallVector<ConvBiasImpl::NCBKern> ret_kern; | ||||
MIDOUT_BEGIN( | MIDOUT_BEGIN( | ||||
megdnn_fallback_im2col, | megdnn_fallback_im2col, | ||||
midout_iv("ConvBiasImpl::AlgoIm2col::dispatch_kerns"_hash)) { | midout_iv("ConvBiasImpl::AlgoIm2col::dispatch_kerns"_hash)) { | ||||
StrategyBase* im2colstrategy = Factory::get_im2col_strategy( | |||||
param, m_matmul_algo, opr->param().format); | |||||
auto kern_padding = [bundle, im2colstrategy]( | |||||
StrategyBase* im2colstrategy = | |||||
Factory::get_im2col_strategy(param, m_matmul_algo); | |||||
auto kern_padding = [bundle, im2colstrategy, | |||||
pack_oc_size = pack_oc_size]( | |||||
const NCBKernParam& param, | const NCBKernParam& param, | ||||
const NCBKernIndex& ncb_index) { | const NCBKernIndex& ncb_index) { | ||||
copy_padding_kern(bundle, param, ncb_index, im2colstrategy); | |||||
copy_padding_kern(bundle, param, ncb_index, im2colstrategy, | |||||
pack_oc_size); | |||||
}; | }; | ||||
auto kern_packA = [bundle, matmul_algo = m_matmul_algo, | auto kern_packA = [bundle, matmul_algo = m_matmul_algo, | ||||
matmul_param, | |||||
im2colstrategy](const NCBKernParam& param, | |||||
const NCBKernIndex& ncb_index) { | |||||
matmul_param, im2colstrategy, | |||||
pack_oc_size = pack_oc_size]( | |||||
const NCBKernParam& param, | |||||
const NCBKernIndex& ncb_index) { | |||||
packA_kern(bundle, param, matmul_param, matmul_algo, ncb_index, | packA_kern(bundle, param, matmul_param, matmul_algo, ncb_index, | ||||
im2colstrategy); | |||||
im2colstrategy, pack_oc_size); | |||||
}; | }; | ||||
if (default_pack) { | if (default_pack) { | ||||
auto kern_compute_default = | auto kern_compute_default = | ||||
@@ -556,7 +568,8 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||||
ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}}); | ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}}); | ||||
if (need_padding) { | if (need_padding) { | ||||
ret_kern.push_back({kern_padding, {param.n, GROUP, IC}}); | |||||
ret_kern.push_back({kern_padding, | |||||
{param.n, GROUP, IC / pack_oc_size}}); | |||||
} | } | ||||
ret_kern.push_back( | ret_kern.push_back( | ||||
{kern_compute_default, | {kern_compute_default, | ||||
@@ -629,19 +642,25 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | ||||
return false; | return false; | ||||
} | } | ||||
//! current now im2col only support int8 quantized s8 nchw44 | |||||
if (opr->param().format == param::ConvBias::Format::NCHW44 && | |||||
(param.src_type.enumv() == param.filter_type.enumv() && | |||||
(param.src_type.enumv() != DTypeEnum::Int8) && | |||||
(param.src_type.enumv() != DTypeEnum::QuantizedS8))) { | |||||
return false; | |||||
} | |||||
fallback::MatrixMulImpl::KernSizeParam matmul_param = | fallback::MatrixMulImpl::KernSizeParam matmul_param = | ||||
get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size); | get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size); | ||||
bool matmulusable = m_matmul_algo->usable(matmul_param); | bool matmulusable = m_matmul_algo->usable(matmul_param); | ||||
return matmulusable && | return matmulusable && | ||||
(opr->param().format == param::ConvBias::Format::NCHW) && | |||||
((param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||||
(param.filter_meta.spatial[0] <= 7) && | |||||
(param.filter_meta.spatial[0] >= 2)) || | |||||
(param.filter_meta.spatial[0] != param.filter_meta.spatial[1] && | |||||
(param.filter_meta.spatial[0] <= 7) && | |||||
(param.filter_meta.spatial[0] >= 1) && | |||||
(param.filter_meta.spatial[1] <= 7) && | |||||
(param.filter_meta.spatial[1] >= 1))) && | |||||
(opr->param().format == param::ConvBias::Format::NCHW || | |||||
opr->param().format == param::ConvBias::Format::NCHW44) && | |||||
(!(param.filter_meta.spatial[0] == | |||||
param.filter_meta.spatial[1] && | |||||
(param.filter_meta.spatial[0] == 1) && | |||||
param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||||
param.filter_meta.stride[0] == 1)) && | |||||
(param.filter_meta.dilation[0] == | (param.filter_meta.dilation[0] == | ||||
param.filter_meta.dilation[1] && | param.filter_meta.dilation[1] && | ||||
param.filter_meta.dilation[0] == 1) && | param.filter_meta.dilation[0] == 1) && | ||||
@@ -36,7 +36,6 @@ class ConvBiasImpl::AlgoIm2col final : public AlgoBase { | |||||
const NCBKernSizeParam& param, size_t ohw_tile_size, | const NCBKernSizeParam& param, size_t ohw_tile_size, | ||||
size_t oc_tile_size) const; | size_t oc_tile_size) const; | ||||
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; | WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; | ||||
WorkspaceBundle get_thread_bundle(const NCBKernSizeParam& param) const; | |||||
void choice_ohw_oc_block(const NCBKernSizeParam& param, size_t block_m, | void choice_ohw_oc_block(const NCBKernSizeParam& param, size_t block_m, | ||||
size_t block_n, bool pack_default) const; | size_t block_n, bool pack_default) const; | ||||
@@ -23,19 +23,11 @@ namespace im2col { | |||||
enum class StrategyType : uint32_t { | enum class StrategyType : uint32_t { | ||||
FLOAT = 0, | FLOAT = 0, | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
FLOAT_FP16 = 1, | |||||
#else | |||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
FLOAT16_FLOAT16 = 2, | FLOAT16_FLOAT16 = 2, | ||||
#endif | #endif | ||||
#endif | |||||
INT8x8x32 = 3, | INT8x8x32 = 3, | ||||
INT8x8x16 = 4, | INT8x8x16 = 4, | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
QUINT8x8x32 = 5, | |||||
QUINT8x8x32x8 = 6, | |||||
#endif | |||||
QINT8x8x32 = 7, | QINT8x8x32 = 7, | ||||
QINT8x8x32x8 = 8 | QINT8x8x32x8 = 8 | ||||
}; | }; | ||||
@@ -107,8 +99,7 @@ public: | |||||
~StrategyDelegationStorage() = default; | ~StrategyDelegationStorage() = default; | ||||
template <typename Strategy> | template <typename Strategy> | ||||
Strategy* get(param::ConvBias::Format format, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
Strategy* get(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernSizeParam& param, | const fallback::ConvBiasImpl::NCBKernSizeParam& param, | ||||
StrategyType stype); | StrategyType stype); | ||||
}; | }; | ||||
@@ -117,12 +108,10 @@ class Factory { | |||||
public: | public: | ||||
static StrategyBase* get_im2col_strategy( | static StrategyBase* get_im2col_strategy( | ||||
const fallback::ConvBiasImpl::NCBKernSizeParam& param, | const fallback::ConvBiasImpl::NCBKernSizeParam& param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
param::ConvBias::Format format) { | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||||
static StrategyDelegationStorage storage; | static StrategyDelegationStorage storage; | ||||
StrategyType strategytype = get_strategy_type(param); | StrategyType strategytype = get_strategy_type(param); | ||||
return storage.get<StrategyBase>(format, matmul_algo, param, | |||||
strategytype); | |||||
return storage.get<StrategyBase>(matmul_algo, param, strategytype); | |||||
} | } | ||||
static StrategyType get_strategy_type( | static StrategyType get_strategy_type( | ||||
@@ -141,13 +130,9 @@ public: | |||||
} | } | ||||
cb1(dt_float32, dt_float32, StrategyType::FLOAT); | cb1(dt_float32, dt_float32, StrategyType::FLOAT); | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
cb1(dt_float16, __fp16, StrategyType::FLOAT_FP16); | |||||
#else | |||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
cb1(dt_float16, dt_float16, StrategyType::FLOAT16_FLOAT16); | cb1(dt_float16, dt_float16, StrategyType::FLOAT16_FLOAT16); | ||||
#endif | #endif | ||||
#endif | |||||
cb2(dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32, | cb2(dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32, | ||||
StrategyType::INT8x8x32); | StrategyType::INT8x8x32); | ||||
@@ -155,13 +140,6 @@ public: | |||||
cb2(dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, dt_int16, | cb2(dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, dt_int16, | ||||
StrategyType::INT8x8x16); | StrategyType::INT8x8x16); | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::QuantizedS32, | |||||
dt_uint8, dt_int32, dt_int32, StrategyType::QUINT8x8x32); | |||||
cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::Quantized8Asymm, | |||||
dt_uint8, dt_int32, dt_uint8, StrategyType::QUINT8x8x32x8); | |||||
#endif | |||||
cb2(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32, | cb2(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32, | ||||
dt_int8, dt_int32, dt_int32, StrategyType::QINT8x8x32); | dt_int8, dt_int32, dt_int32, StrategyType::QINT8x8x32); | ||||
@@ -172,98 +150,106 @@ public: | |||||
megdnn_throw("not support datatype in im2col strategy\n"); | megdnn_throw("not support datatype in im2col strategy\n"); | ||||
} | } | ||||
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||||
MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ | |||||
midout_iv(_midout_tag)) { \ | |||||
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||||
return std::make_unique< \ | |||||
Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||||
_postprocess_mode, PackMode::_packmode>>(); \ | |||||
} \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
#define cb1(_format, _packmode, _dt, _post_ctype, _postprocess_mode, \ | |||||
_midout_tag) \ | |||||
MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ | |||||
midout_iv(_midout_tag)) { \ | |||||
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||||
return std::make_unique< \ | |||||
Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||||
_postprocess_mode, PackMode::_packmode, \ | |||||
FormatMode::_format>>(); \ | |||||
} \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
return {}; | return {}; | ||||
#define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||||
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||||
MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ | |||||
midout_iv(_midout_tag)) { \ | |||||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||||
return std::make_unique< \ | |||||
Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||||
DTypeTrait<_i_bias_type>::ctype, \ | |||||
DTypeTrait<_i_dst_type>::ctype, \ | |||||
_postprocess_mode, PackMode::_packmode>>(); \ | |||||
} \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
#define cb2(_format, _packmode, _i_src_type, _i_bias_type, _i_dst_type, \ | |||||
_src_ctype, _bias_ctype, _dst_ctype, _postprocess_mode, \ | |||||
_midout_tag) \ | |||||
MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ | |||||
midout_iv(_midout_tag)) { \ | |||||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||||
return std::make_unique<Strategy< \ | |||||
_src_ctype, _bias_ctype, _dst_ctype, \ | |||||
DTypeTrait<_i_bias_type>::ctype, \ | |||||
DTypeTrait<_i_dst_type>::ctype, _postprocess_mode, \ | |||||
PackMode::_packmode, FormatMode::_format>>(); \ | |||||
} \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
return {}; | return {}; | ||||
static std::unique_ptr<StrategyBase> make_default_strategy( | static std::unique_ptr<StrategyBase> make_default_strategy( | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
const fallback::ConvBiasImpl::NCBKernSizeParam& param, | const fallback::ConvBiasImpl::NCBKernSizeParam& param, | ||||
param::ConvBias::Format format, StrategyType strategytype) { | |||||
StrategyType strategytype) { | |||||
MEGDNN_MARK_USED_VAR(matmul_algo); | MEGDNN_MARK_USED_VAR(matmul_algo); | ||||
MEGDNN_MARK_USED_VAR(format); | |||||
param::ConvBias::Format format = param.filter_meta.format; | |||||
switch (strategytype) { | switch (strategytype) { | ||||
case StrategyType::FLOAT: | case StrategyType::FLOAT: | ||||
cb1(DEFAULT, dt_float32, dt_float32, PostprocessMode::FLOAT, | |||||
"DefaultStrategyType::FLOAT"_hash); | |||||
break; | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
case StrategyType::FLOAT_FP16: | |||||
cb1(DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT, | |||||
"DefaultStrategyType::FLOAT_FP16"_hash); | |||||
cb1(NCHW, DEFAULT, dt_float32, dt_float32, | |||||
PostprocessMode::FLOAT, "DefaultStrategyType::FLOAT"_hash); | |||||
break; | break; | ||||
#else | |||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
case StrategyType::FLOAT16_FLOAT16: | case StrategyType::FLOAT16_FLOAT16: | ||||
cb1(DEFAULT, dt_float16, dt_float16, | |||||
cb1(NCHW, DEFAULT, dt_float16, dt_float16, | |||||
PostprocessMode::NO_PROCESS, | PostprocessMode::NO_PROCESS, | ||||
"DefaultStrategyType::FLOAT16_FLOAT16"_hash); | "DefaultStrategyType::FLOAT16_FLOAT16"_hash); | ||||
break; | break; | ||||
#endif | #endif | ||||
#endif | |||||
case StrategyType::INT8x8x32: | case StrategyType::INT8x8x32: | ||||
cb2(DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, | |||||
dt_int32, PostprocessMode::NO_PROCESS, | |||||
"DefaultStrategyType::INT8x8x32"_hash); | |||||
if (format == param::ConvBias::Format::NCHW) { | |||||
cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
"DefaultStrategyType::INT8x8x32"_hash); | |||||
} else if (format == param::ConvBias::Format::NCHW44) { | |||||
cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
"DefaultStrategyType::INT8x8x32"_hash); | |||||
} else { | |||||
megdnn_throw("not support format except nchw44 and nchw\n"); | |||||
} | |||||
break; | break; | ||||
case StrategyType::INT8x8x16: | case StrategyType::INT8x8x16: | ||||
cb2(DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, | |||||
dt_int16, PostprocessMode::NO_PROCESS, | |||||
cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||||
"DefaultStrategyType::INT8x8x16"_hash); | "DefaultStrategyType::INT8x8x16"_hash); | ||||
break; | break; | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
case StrategyType::QUINT8x8x32: | |||||
cb2(DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||||
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | |||||
PostprocessMode::NO_PROCESS, | |||||
"DefaultStrategyType::QUINT8x8x32"_hash); | |||||
break; | |||||
case StrategyType::QUINT8x8x32x8: | |||||
cb2(DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||||
dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | |||||
PostprocessMode::QUANTIZED, | |||||
"DefaultStrategyType::QUINT8x8x32x8"_hash); | |||||
break; | |||||
#endif | |||||
case StrategyType::QINT8x8x32: | case StrategyType::QINT8x8x32: | ||||
cb2(DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||||
PostprocessMode::NO_PROCESS, | |||||
"DefaultStrategyType::QINT8x8x32"_hash); | |||||
if (format == param::ConvBias::Format::NCHW) { | |||||
cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||||
PostprocessMode::NO_PROCESS, | |||||
"DefaultStrategyTypeNCHW::QINT8x8x32"_hash); | |||||
} else if (format == param::ConvBias::Format::NCHW44) { | |||||
cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | |||||
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, | |||||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
"DefaultStrategyTypeHCHW44::QINT8x8x32"_hash); | |||||
} else { | |||||
megdnn_throw("not support format except nchw44 and nchw\n"); | |||||
} | |||||
break; | break; | ||||
case StrategyType::QINT8x8x32x8: | case StrategyType::QINT8x8x32x8: | ||||
cb2(DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||||
PostprocessMode::QUANTIZED, | |||||
"DefaultStrategyType::QINT8x8x32x8"_hash); | |||||
if (format == param::ConvBias::Format::NCHW) { | |||||
cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||||
PostprocessMode::QUANTIZED, | |||||
"DefaultStrategyType::QINT8x8x32x8"_hash); | |||||
} else if (format == param::ConvBias::Format::NCHW44) { | |||||
cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | |||||
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, | |||||
dt_int32, dt_int8, PostprocessMode::QUANTIZED, | |||||
"DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash); | |||||
} else { | |||||
megdnn_throw("not support format except nchw44 and nchw\n"); | |||||
} | |||||
break; | break; | ||||
} | } | ||||
megdnn_throw("error not support strategy type "); | megdnn_throw("error not support strategy type "); | ||||
@@ -272,63 +258,41 @@ public: | |||||
static std::unique_ptr<StrategyBase> make_nopack_strategy( | static std::unique_ptr<StrategyBase> make_nopack_strategy( | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
const fallback::ConvBiasImpl::NCBKernSizeParam& param, | const fallback::ConvBiasImpl::NCBKernSizeParam& param, | ||||
param::ConvBias::Format format, StrategyType strategytype) { | |||||
StrategyType strategytype) { | |||||
MEGDNN_MARK_USED_VAR(matmul_algo); | MEGDNN_MARK_USED_VAR(matmul_algo); | ||||
MEGDNN_MARK_USED_VAR(format); | |||||
switch (strategytype) { | switch (strategytype) { | ||||
case StrategyType::FLOAT: | case StrategyType::FLOAT: | ||||
cb1(NO_PACK, dt_float32, dt_float32, PostprocessMode::FLOAT, | |||||
"NoPackStrategyType::FLOAT"_hash); | |||||
break; | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
case StrategyType::FLOAT_FP16: | |||||
cb1(NO_PACK, dt_float16, __fp16, PostprocessMode::FLOAT, | |||||
"NoPackStrategyType::FLOAT_FP16"_hash); | |||||
cb1(NCHW, NO_PACK, dt_float32, dt_float32, | |||||
PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash); | |||||
break; | break; | ||||
#else | |||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
case StrategyType::FLOAT16_FLOAT16: | case StrategyType::FLOAT16_FLOAT16: | ||||
cb1(NO_PACK, dt_float16, dt_float16, PostprocessMode::NO_PROCESS, | |||||
cb1(NCHW, NO_PACK, dt_float16, dt_float16, | |||||
PostprocessMode::NO_PROCESS, | |||||
"NoPackStrategyType::FLOAT16_FLOAT16"_hash); | "NoPackStrategyType::FLOAT16_FLOAT16"_hash); | ||||
break; | break; | ||||
#endif | #endif | ||||
#endif | |||||
case StrategyType::INT8x8x32: | case StrategyType::INT8x8x32: | ||||
cb2(NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, | |||||
dt_int32, PostprocessMode::NO_PROCESS, | |||||
cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
"NoPackStrategyType::INT8x8x32"_hash); | "NoPackStrategyType::INT8x8x32"_hash); | ||||
break; | break; | ||||
case StrategyType::INT8x8x16: | case StrategyType::INT8x8x16: | ||||
cb2(NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, | |||||
dt_int16, PostprocessMode::NO_PROCESS, | |||||
cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||||
"NoPackStrategyType::INT8x8x16"_hash); | "NoPackStrategyType::INT8x8x16"_hash); | ||||
break; | break; | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
case StrategyType::QUINT8x8x32: | |||||
cb2(NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||||
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | |||||
PostprocessMode::NO_PROCESS, | |||||
"NoPackStrategyType::QUINT8x8x32"_hash); | |||||
break; | |||||
case StrategyType::QUINT8x8x32x8: | |||||
cb2(NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||||
dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | |||||
PostprocessMode::QUANTIZED, | |||||
"NoPackStrategyType::QUINT8x8x32x8"_hash); | |||||
break; | |||||
#endif | |||||
case StrategyType::QINT8x8x32: | case StrategyType::QINT8x8x32: | ||||
cb2(NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | ||||
PostprocessMode::NO_PROCESS, | PostprocessMode::NO_PROCESS, | ||||
"NoPackStrategyType::QINT8x8x32"_hash); | "NoPackStrategyType::QINT8x8x32"_hash); | ||||
break; | break; | ||||
case StrategyType::QINT8x8x32x8: | case StrategyType::QINT8x8x32x8: | ||||
cb2(NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | ||||
PostprocessMode::QUANTIZED, | PostprocessMode::QUANTIZED, | ||||
"NoPackStrategyType::QINT8x8x32x8"_hash); | "NoPackStrategyType::QINT8x8x32x8"_hash); | ||||
@@ -340,64 +304,42 @@ public: | |||||
static std::unique_ptr<StrategyBase> make_onlypacka_strategy( | static std::unique_ptr<StrategyBase> make_onlypacka_strategy( | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
const fallback::ConvBiasImpl::NCBKernSizeParam& param, | const fallback::ConvBiasImpl::NCBKernSizeParam& param, | ||||
param::ConvBias::Format format, StrategyType strategytype) { | |||||
StrategyType strategytype) { | |||||
MEGDNN_MARK_USED_VAR(matmul_algo); | MEGDNN_MARK_USED_VAR(matmul_algo); | ||||
MEGDNN_MARK_USED_VAR(format); | |||||
switch (strategytype) { | switch (strategytype) { | ||||
case StrategyType::FLOAT: | case StrategyType::FLOAT: | ||||
cb1(ONLY_PACKA, dt_float32, dt_float32, PostprocessMode::FLOAT, | |||||
cb1(NCHW, ONLY_PACKA, dt_float32, dt_float32, | |||||
PostprocessMode::FLOAT, | |||||
"OnlyPackaStrategyType::FLOAT"_hash); | "OnlyPackaStrategyType::FLOAT"_hash); | ||||
break; | break; | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
case StrategyType::FLOAT_FP16: | |||||
cb1(ONLY_PACKA, dt_float16, __fp16, PostprocessMode::FLOAT, | |||||
"OnlyPackaStrategyType::FLOAT_FP16"_hash); | |||||
break; | |||||
#else | |||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
case StrategyType::FLOAT16_FLOAT16: | case StrategyType::FLOAT16_FLOAT16: | ||||
cb1(ONLY_PACKA, dt_float16, dt_float16, | |||||
cb1(NCHW, ONLY_PACKA, dt_float16, dt_float16, | |||||
PostprocessMode::NO_PROCESS, | PostprocessMode::NO_PROCESS, | ||||
"OnlyPackaStrategyType::FLOAT16_FLOAT16"_hash); | "OnlyPackaStrategyType::FLOAT16_FLOAT16"_hash); | ||||
break; | break; | ||||
#endif | #endif | ||||
#endif | |||||
case StrategyType::INT8x8x32: | case StrategyType::INT8x8x32: | ||||
cb2(ONLY_PACKA, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, | |||||
dt_int32, PostprocessMode::NO_PROCESS, | |||||
cb2(NCHW, ONLY_PACKA, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
"OnlyPackaStrategyType::INT8x8x32"_hash); | "OnlyPackaStrategyType::INT8x8x32"_hash); | ||||
break; | break; | ||||
case StrategyType::INT8x8x16: | case StrategyType::INT8x8x16: | ||||
cb2(ONLY_PACKA, dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, | |||||
dt_int16, PostprocessMode::NO_PROCESS, | |||||
cb2(NCHW, ONLY_PACKA, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||||
"OnlyPackaStrategyType::INT8x8x16"_hash); | "OnlyPackaStrategyType::INT8x8x16"_hash); | ||||
break; | break; | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
case StrategyType::QUINT8x8x32: | |||||
cb2(ONLY_PACKA, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||||
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | |||||
PostprocessMode::NO_PROCESS, | |||||
"OnlyPackaStrategyType::QUINT8x8x32"_hash); | |||||
break; | |||||
case StrategyType::QUINT8x8x32x8: | |||||
cb2(ONLY_PACKA, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||||
dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | |||||
PostprocessMode::QUANTIZED, | |||||
"OnlyPackaStrategyType::QUINT8x8x32x8"_hash); | |||||
break; | |||||
#endif | |||||
case StrategyType::QINT8x8x32: | case StrategyType::QINT8x8x32: | ||||
cb2(ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | ||||
PostprocessMode::NO_PROCESS, | PostprocessMode::NO_PROCESS, | ||||
"OnlyPackaStrategyType::QINT8x8x32"_hash); | "OnlyPackaStrategyType::QINT8x8x32"_hash); | ||||
break; | break; | ||||
case StrategyType::QINT8x8x32x8: | case StrategyType::QINT8x8x32x8: | ||||
cb2(ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | ||||
PostprocessMode::QUANTIZED, | PostprocessMode::QUANTIZED, | ||||
"OnlyPackaStrategyType::QINT8x8x32x8"_hash); | "OnlyPackaStrategyType::QINT8x8x32x8"_hash); | ||||
@@ -410,21 +352,19 @@ public: | |||||
#undef cb2 | #undef cb2 | ||||
static std::unique_ptr<StrategyBase> make_strategy( | static std::unique_ptr<StrategyBase> make_strategy( | ||||
param::ConvBias::Format format, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
fallback::MatrixMulImpl::AlgoBase::PackMode packmode, | fallback::MatrixMulImpl::AlgoBase::PackMode packmode, | ||||
const fallback::ConvBiasImpl::NCBKernSizeParam& param, | const fallback::ConvBiasImpl::NCBKernSizeParam& param, | ||||
StrategyType stype) { | StrategyType stype) { | ||||
switch (packmode) { | switch (packmode) { | ||||
case MatrixMulImpl::AlgoBase::PackMode::DEFAULT: | case MatrixMulImpl::AlgoBase::PackMode::DEFAULT: | ||||
return make_default_strategy(matmul_algo, param, format, stype); | |||||
return make_default_strategy(matmul_algo, param, stype); | |||||
break; | break; | ||||
case MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA: | case MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA: | ||||
return make_onlypacka_strategy(matmul_algo, param, format, | |||||
stype); | |||||
return make_onlypacka_strategy(matmul_algo, param, stype); | |||||
break; | break; | ||||
case MatrixMulImpl::AlgoBase::PackMode::NO_PACK: | case MatrixMulImpl::AlgoBase::PackMode::NO_PACK: | ||||
return make_nopack_strategy(matmul_algo, param, format, stype); | |||||
return make_nopack_strategy(matmul_algo, param, stype); | |||||
break; | break; | ||||
default: | default: | ||||
megdnn_throw( | megdnn_throw( | ||||
@@ -432,14 +372,12 @@ public: | |||||
"nopack"); | "nopack"); | ||||
break; | break; | ||||
} | } | ||||
megdnn_throw( | |||||
"factory make Strategy error please check your code"); | |||||
megdnn_throw("factory make Strategy error please check your code"); | |||||
} | } | ||||
}; | }; | ||||
template <typename Strategy> | template <typename Strategy> | ||||
Strategy* StrategyDelegationStorage::get( | Strategy* StrategyDelegationStorage::get( | ||||
param::ConvBias::Format format, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
const fallback::ConvBiasImpl::NCBKernSizeParam& param, | const fallback::ConvBiasImpl::NCBKernSizeParam& param, | ||||
StrategyType stype) { | StrategyType stype) { | ||||
@@ -455,14 +393,14 @@ Strategy* StrategyDelegationStorage::get( | |||||
} | } | ||||
StrategyHashParam sparam; | StrategyHashParam sparam; | ||||
sparam.param = param; | sparam.param = param; | ||||
sparam.format = format; | |||||
sparam.format = param.filter_meta.format; | |||||
sparam.packmode = packmode; | sparam.packmode = packmode; | ||||
sparam.block_m = block_m; | sparam.block_m = block_m; | ||||
sparam.block_n = block_n; | sparam.block_n = block_n; | ||||
sparam.block_k = block_k; | sparam.block_k = block_k; | ||||
if (map_strategys.find(sparam) == map_strategys.end()) { | if (map_strategys.find(sparam) == map_strategys.end()) { | ||||
MEGDNN_LOCK_GUARD(m_mtx); | MEGDNN_LOCK_GUARD(m_mtx); | ||||
auto strategy = Factory::make_strategy(format, matmul_algo, packmode, | |||||
auto strategy = Factory::make_strategy(matmul_algo, packmode, | |||||
param, stype); | param, stype); | ||||
map_strategys[sparam] = std::move(strategy); | map_strategys[sparam] = std::move(strategy); | ||||
} | } | ||||
@@ -14,6 +14,7 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; | using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; | ||||
using FormatMode = param::ConvBias::Format; | |||||
struct StrategyParam { | struct StrategyParam { | ||||
size_t batch_id; | size_t batch_id; | ||||
@@ -28,6 +29,7 @@ struct StrategyParam { | |||||
size_t block_m; | size_t block_m; | ||||
size_t block_n; | size_t block_n; | ||||
size_t block_k; | size_t block_k; | ||||
size_t pack_oc_size; | |||||
bool skip_copy_dst; | bool skip_copy_dst; | ||||
bool is_dst_8bit; | bool is_dst_8bit; | ||||
bool is_ohw_size_bigger; | bool is_ohw_size_bigger; | ||||
@@ -40,13 +42,15 @@ public: | |||||
virtual void copy_padding_kern( | virtual void copy_padding_kern( | ||||
WorkspaceBundle bundle, | WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) = 0; | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_size) = 0; | |||||
virtual void packA_kern( | virtual void packA_kern( | ||||
WorkspaceBundle bundle, | WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) = 0; | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_size) = 0; | |||||
virtual void exec_im2col( | virtual void exec_im2col( | ||||
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | ||||
@@ -70,14 +74,16 @@ public: | |||||
//! Primary Strategy template, declared here without a definition. This hunk
//! appends a FormatMode non-type parameter after PackMode; the partial
//! specializations in this header select on (PackMode, FormatMode).
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode, PackMode packmode> | |||||
megdnn::PostprocessMode postprocess_mode, PackMode packmode, | |||||
FormatMode format> | |||||
class Strategy; | class Strategy; | ||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::DEFAULT> : public StrategyBase { | |||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW> | |||||
: public StrategyBase { | |||||
public: | public: | ||||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | constexpr static size_t BUNDLE_PADDING_INDEX = 0; | ||||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | constexpr static size_t BUNDLE_PACKA_INDEX = 1; | ||||
@@ -85,24 +91,26 @@ public: | |||||
constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1; | constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1; | ||||
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; | constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; | ||||
Strategy(); | |||||
Strategy() = default; | |||||
void copy_padding_kern( | void copy_padding_kern( | ||||
WorkspaceBundle bundle, | WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||||
void packA_kern( | |||||
WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_size) override; | |||||
void packA_kern(WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_size) override; | |||||
virtual void exec_im2col( | |||||
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||||
const StrategyParam& sparam, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||||
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||||
const StrategyParam& sparam, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernParam matmul_param, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
fallback::MatrixMulImpl::KernParam matmul_param, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
void exec_matmul( | void exec_matmul( | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
//! Partial specialization for PackMode::DEFAULT with FormatMode::NCHW44
//! (the added side of this hunk). It derives from the
//! PackMode::DEFAULT / FormatMode::NCHW specialization and declares only a
//! defaulted constructor and an exec_im2col override of its own.
@@ -132,7 +140,32 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::NO_PACK> : public StrategyBase { | |||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW44> | |||||
: public Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::DEFAULT, | |||||
FormatMode::NCHW> { | |||||
public: | |||||
//! Indices passed to WorkspaceBundle::get(). They are redeclared here as
//! per-object const members.
//! NOTE(review): the DEFAULT/NCHW specialization declares the index names
//! visible in this diff as constexpr static -- confirm whether the
//! non-static form here is intentional.
const size_t BUNDLE_PADDING_INDEX = 0; | |||||
const size_t BUNDLE_PACKA_INDEX = 1; | |||||
const size_t THREAD_BUNDLE_PACKB_INDEX = 0; | |||||
const size_t THREAD_BUNDLE_IM2COL_INDEX = 1; | |||||
const size_t THREAD_BUNDLE_BIAS_INDEX = 2; | |||||
Strategy() = default; | |||||
//! Replaces the inherited exec_im2col; the definition is not part of this
//! header hunk.
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||||
const StrategyParam& sparam, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernParam matmul_param, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
}; | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW> | |||||
: public StrategyBase { | |||||
public: | public: | ||||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | constexpr static size_t BUNDLE_PADDING_INDEX = 0; | ||||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | constexpr static size_t BUNDLE_PACKA_INDEX = 1; | ||||
@@ -141,19 +174,20 @@ public: | |||||
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; | constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; | ||||
constexpr static size_t THREAD_BUNDLE_MATCOMP_INDEX = 3; | constexpr static size_t THREAD_BUNDLE_MATCOMP_INDEX = 3; | ||||
Strategy(); | |||||
Strategy() = default; | |||||
void copy_padding_kern( | void copy_padding_kern( | ||||
WorkspaceBundle bundle, | WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_size) override; | |||||
void packA_kern( | |||||
WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||||
void packA_kern(WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_size) override; | |||||
void exec_matmul( | void exec_matmul( | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
@@ -197,7 +231,8 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::ONLY_PACKA> : public StrategyBase { | |||||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW> | |||||
: public StrategyBase { | |||||
public: | public: | ||||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | constexpr static size_t BUNDLE_PADDING_INDEX = 0; | ||||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | constexpr static size_t BUNDLE_PACKA_INDEX = 1; | ||||
@@ -206,19 +241,20 @@ public: | |||||
constexpr static size_t THREAD_BUNDLE_MATMULDST_INDEX = 2; | constexpr static size_t THREAD_BUNDLE_MATMULDST_INDEX = 2; | ||||
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 3; | constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 3; | ||||
Strategy(); | |||||
Strategy() = default; | |||||
void copy_padding_kern( | void copy_padding_kern( | ||||
WorkspaceBundle bundle, | WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||||
void packA_kern( | |||||
WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_size) override; | |||||
void packA_kern(WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_size) override; | |||||
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
@@ -8,8 +8,6 @@ | |||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "megdnn/opr_param_defs.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | #include "src/fallback/conv_bias/im2col/strategy_base.h" | ||||
#include "src/fallback/convolution/img2col_helper.h" | #include "src/fallback/convolution/img2col_helper.h" | ||||
#if MEGDNN_X86 | #if MEGDNN_X86 | ||||
@@ -25,19 +23,12 @@ namespace megdnn { | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode,PackMode::DEFAULT>::Strategy() | |||||
: StrategyBase() {} | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::DEFAULT>:: | |||||
copy_padding_kern( | |||||
WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
copy_padding_kern(WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_oc_size) { | |||||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | UNPACK_CONV_F32_NCB_KERN_SIZES(param); | ||||
MEGDNN_MARK_USED_VAR(N); | MEGDNN_MARK_USED_VAR(N); | ||||
MEGDNN_MARK_USED_VAR(OC); | MEGDNN_MARK_USED_VAR(OC); | ||||
@@ -53,9 +44,13 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
size_t batch_id = ncb_index.ndrange_id[0]; | size_t batch_id = ncb_index.ndrange_id[0]; | ||||
size_t group_id = ncb_index.ndrange_id[1]; | size_t group_id = ncb_index.ndrange_id[1]; | ||||
size_t channel_id = ncb_index.ndrange_id[2]; | size_t channel_id = ncb_index.ndrange_id[2]; | ||||
size_t PH_SIZE = PH * IW2 * pack_oc_size; | |||||
PW = PW * pack_oc_size; | |||||
IW = IW * pack_oc_size; | |||||
size_t padding_group_size = IH2 * IW2 * IC; | size_t padding_group_size = IH2 * IW2 * IC; | ||||
size_t workspace_channel_offset = IH2 * IW2 * channel_id; | |||||
size_t workspace_channel_offset = pack_oc_size * IH2 * IW2 * channel_id; | |||||
size_t workspace_group_offset = group_id * padding_group_size; | size_t workspace_group_offset = group_id * padding_group_size; | ||||
size_t workspace_batch_offset = | size_t workspace_batch_offset = | ||||
param.filter_meta.group * batch_id * padding_group_size; | param.filter_meta.group * batch_id * padding_group_size; | ||||
@@ -65,8 +60,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | ||||
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | ||||
} | } | ||||
src_ctype* src = const_cast<src_ctype*>( | |||||
param.src<src_ctype>(batch_id, group_id, channel_id)); | |||||
src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>( | |||||
batch_id, group_id, channel_id, 1, pack_oc_size)); | |||||
src_ctype* src2; | src_ctype* src2; | ||||
src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + | src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + | ||||
workspace_group_offset + workspace_batch_offset + | workspace_group_offset + workspace_batch_offset + | ||||
@@ -74,8 +69,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
src_ctype* src2_ptr = src2; | src_ctype* src2_ptr = src2; | ||||
const src_ctype* src_ptr = src; | const src_ctype* src_ptr = src; | ||||
if (PH != 0) { | if (PH != 0) { | ||||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||||
src2_ptr += PH * IW2; | |||||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||||
src2_ptr += PH_SIZE; | |||||
} | } | ||||
rep(ih, IH) { | rep(ih, IH) { | ||||
if (PW != 0) | if (PW != 0) | ||||
@@ -87,8 +82,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
rep(pw, PW) * (src2_ptr++) = src_zp; | rep(pw, PW) * (src2_ptr++) = src_zp; | ||||
} | } | ||||
if (PH != 0) { | if (PH != 0) { | ||||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||||
src2_ptr += PH * IW2; | |||||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||||
src2_ptr += PH_SIZE; | |||||
} | } | ||||
} | } | ||||
@@ -96,12 +91,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::DEFAULT>:: | |||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
packA_kern(WorkspaceBundle bundle, | packA_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_oc_size) { | |||||
bundle.set(param.workspace_ptr); | bundle.set(param.workspace_ptr); | ||||
fallback::MatrixMulImpl::KernParam matmul_param; | fallback::MatrixMulImpl::KernParam matmul_param; | ||||
size_t group_id = ncb_index.ndrange_id[0]; | size_t group_id = ncb_index.ndrange_id[0]; | ||||
@@ -114,38 +110,38 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
matmul_algo->get_packA_type_size(); | matmul_algo->get_packA_type_size(); | ||||
size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size; | size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size; | ||||
int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) + | int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) + | ||||
group_id * packA_group_size + a_panel_offset; | |||||
group_id * packA_group_size + | |||||
(pack_oc_size == 4 ? 0 : a_panel_offset); | |||||
matmul_param.A_ptr = | matmul_param.A_ptr = | ||||
const_cast<src_ctype*>(param.filter<src_ctype>(group_id)); | const_cast<src_ctype*>(param.filter<src_ctype>(group_id)); | ||||
matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1], | matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1], | ||||
matmul_algo->get_inner_block_size().m); | |||||
matmul_algo->get_inner_block_size().m * pack_oc_size); | |||||
} | } | ||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::DEFAULT>:: | |||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo | |||||
) { | |||||
size_t m_sh = param.filter_meta.stride[0]; | |||||
size_t m_sw = param.filter_meta.stride[1]; | |||||
size_t m_oc = param.filter_meta.ocpg; | |||||
size_t m_oh = param.osz[0]; | |||||
size_t m_ow = param.osz[1]; | |||||
size_t m_ic = param.filter_meta.icpg; | |||||
size_t m_ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||||
size_t m_iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||||
size_t m_fh = param.filter_meta.spatial[0]; | |||||
size_t m_fw = param.filter_meta.spatial[1]; | |||||
size_t m_is_xcorr = !param.filter_meta.should_flip; | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||||
size_t sh = param.filter_meta.stride[0]; | |||||
size_t sw = param.filter_meta.stride[1]; | |||||
size_t oc = param.filter_meta.ocpg; | |||||
size_t oh = param.osz[0]; | |||||
size_t ow = param.osz[1]; | |||||
size_t ic = param.filter_meta.icpg; | |||||
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||||
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||||
size_t fh = param.filter_meta.spatial[0]; | |||||
size_t fw = param.filter_meta.spatial[1]; | |||||
size_t is_xcorr = !param.filter_meta.should_flip; | |||||
size_t input_offset = | size_t input_offset = | ||||
m_ih * m_iw * m_ic * | |||||
ih * iw * ic * | |||||
(sparam.group_id + param.filter_meta.group * sparam.batch_id) * | (sparam.group_id + param.filter_meta.group * sparam.batch_id) * | ||||
sizeof(src_ctype); | sizeof(src_ctype); | ||||
@@ -160,26 +156,22 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
} | } | ||||
src_ctype* im2col_dst = static_cast<src_ctype*>( | src_ctype* im2col_dst = static_cast<src_ctype*>( | ||||
bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | ||||
if (m_sh == 1 && m_sw == 1) { | |||||
if (m_is_xcorr) { | |||||
img2col<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||||
m_fh, m_fw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | |||||
if (sh == 1 && sw == 1) { | |||||
if (is_xcorr) { | |||||
img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||||
sparam.ohw_cur_index, sparam.output_block_size); | |||||
} else { | } else { | ||||
img2col<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||||
m_fh, m_fw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | |||||
img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||||
sparam.ohw_cur_index, sparam.output_block_size); | |||||
} | } | ||||
} else { | } else { | ||||
if (m_is_xcorr) { | |||||
img2col_stride<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, | |||||
m_iw, m_fh, m_fw, m_sh, m_sw, | |||||
sparam.ohw_cur_index, | |||||
if (is_xcorr) { | |||||
img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||||
fw, sh, sw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | sparam.output_block_size); | ||||
} else { | } else { | ||||
img2col_stride<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, | |||||
m_ih, m_iw, m_fh, m_fw, m_sh, m_sw, | |||||
sparam.ohw_cur_index, | |||||
img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||||
fw, sh, sw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | sparam.output_block_size); | ||||
} | } | ||||
} | } | ||||
@@ -199,7 +191,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::DEFAULT>:: | |||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const WorkspaceBundle& bundle_thread, | const WorkspaceBundle& bundle_thread, | ||||
const StrategyParam& sparam) { | const StrategyParam& sparam) { | ||||
@@ -218,7 +210,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::DEFAULT>:: | |||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, WorkspaceBundle bundle, | const StrategyParam& sparam, WorkspaceBundle bundle, | ||||
WorkspaceBundle bundle_thread, | WorkspaceBundle bundle_thread, | ||||
@@ -240,11 +232,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
src_ctype* b_panel = | src_ctype* b_panel = | ||||
reinterpret_cast<src_ctype*>(reinterpret_cast<uintptr_t>( | reinterpret_cast<src_ctype*>(reinterpret_cast<uintptr_t>( | ||||
bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX))); | bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX))); | ||||
size_t pack_oc_size = sparam.pack_oc_size; | |||||
matmul_param.M = sparam.output_block_oc_size; | matmul_param.M = sparam.output_block_oc_size; | ||||
matmul_param.N = sparam.output_block_size; | matmul_param.N = sparam.output_block_size; | ||||
matmul_param.LDB = sparam.output_block_size; | |||||
matmul_param.LDC = sparam.output_block_size; | |||||
matmul_param.LDB = pack_oc_size * sparam.output_block_size; | |||||
matmul_param.LDC = pack_oc_size * sparam.output_block_size; | |||||
matmul_param.C_ptr = matmul_dst; | matmul_param.C_ptr = matmul_dst; | ||||
auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); | auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); | ||||
@@ -255,7 +247,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::DEFAULT>:: | |||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
WorkspaceBundle bundle_thread) { | WorkspaceBundle bundle_thread) { | ||||
@@ -274,7 +266,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | ||||
matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode, | matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode, | ||||
param.nonlineMode, param.bias_type, param.dst_type, 1_z, | param.nonlineMode, param.bias_type, param.dst_type, 1_z, | ||||
sparam.output_block_oc_size, 1_z, sparam.output_block_size); | |||||
sparam.output_block_oc_size, 1_z, sparam.output_block_size, | |||||
sparam.pack_oc_size); | |||||
copy_dst(param, matmul_dst, sparam); | copy_dst(param, matmul_dst, sparam); | ||||
} | } | ||||
@@ -282,20 +275,24 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::DEFAULT>:: | |||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const void* matmul_dst, const StrategyParam& sparam) { | const void* matmul_dst, const StrategyParam& sparam) { | ||||
if (!sparam.skip_copy_dst) { | if (!sparam.skip_copy_dst) { | ||||
size_t pack_oc_size = sparam.pack_oc_size; | |||||
dst_ctype* dst_tmp_ptr = | dst_ctype* dst_tmp_ptr = | ||||
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); | reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); | ||||
dst_ctype* dst = | dst_ctype* dst = | ||||
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + | param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + | ||||
sparam.oc_cur_index * sparam.ohw + sparam.ohw_cur_index; | |||||
for (size_t oc = 0; oc < sparam.output_block_oc_size; oc++) { | |||||
sparam.oc_cur_index * sparam.ohw + | |||||
sparam.ohw_cur_index * pack_oc_size; | |||||
size_t oc_loop = sparam.output_block_oc_size / pack_oc_size; | |||||
for (size_t oc = 0; oc < oc_loop; oc++) { | |||||
std::memcpy(dst, dst_tmp_ptr, | std::memcpy(dst, dst_tmp_ptr, | ||||
sizeof(dst_ctype) * sparam.output_block_size); | |||||
dst_tmp_ptr += sparam.output_block_size; | |||||
dst += sparam.ohw; | |||||
sizeof(dst_ctype) * sparam.output_block_size * | |||||
pack_oc_size); | |||||
dst_tmp_ptr += sparam.output_block_size * pack_oc_size; | |||||
dst += sparam.ohw * pack_oc_size; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -304,7 +301,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::DEFAULT>:: | |||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const WorkspaceBundle& bundle_thread) { | const WorkspaceBundle& bundle_thread) { | ||||
bias_ctype* bias_tmp_ptr = | bias_ctype* bias_tmp_ptr = | ||||
@@ -319,7 +316,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::DEFAULT>:: | |||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | ||||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | ||||
@@ -340,31 +337,20 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
} | } | ||||
} | } | ||||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode) \ | |||||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode, PackMode::DEFAULT>; | |||||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode) \ | |||||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode, PackMode::DEFAULT, \ | |||||
FormatMode::NCHW>; | |||||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | ||||
megdnn::PostprocessMode::FLOAT) | megdnn::PostprocessMode::FLOAT) | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||||
megdnn::PostprocessMode::FLOAT) | |||||
#else | |||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | ||||
megdnn::PostprocessMode::NO_PROCESS) | megdnn::PostprocessMode::NO_PROCESS) | ||||
#endif | #endif | ||||
#endif | |||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | |||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||||
megdnn::PostprocessMode::QUANTIZED) | |||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
#endif | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | ||||
megdnn::PostprocessMode::QUANTIZED) | megdnn::PostprocessMode::QUANTIZED) | ||||
@@ -0,0 +1,118 @@ | |||||
/** | |||||
* \file dnn/src/fallback/conv_bias/im2col/strategy_default.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | |||||
#include "src/fallback/convolution/img2col_helper.h" | |||||
#if MEGDNN_X86 | |||||
#include "src/x86/conv_bias/postprocess_helper.h" | |||||
#endif | |||||
using namespace megdnn; | |||||
#if MEGDNN_X86 | |||||
using namespace x86; | |||||
#endif | |||||
namespace megdnn { | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW44>:: | |||||
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||||
const StrategyParam& sparam, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernParam matmul_param, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||||
size_t sh = param.filter_meta.stride[0]; | |||||
size_t sw = param.filter_meta.stride[1]; | |||||
size_t oc = param.filter_meta.ocpg; | |||||
size_t oh = param.osz[0]; | |||||
size_t ow = param.osz[1]; | |||||
size_t ic = param.filter_meta.icpg; | |||||
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||||
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||||
size_t fh = param.filter_meta.spatial[0]; | |||||
size_t fw = param.filter_meta.spatial[1]; | |||||
size_t is_xcorr = !param.filter_meta.should_flip; | |||||
constexpr static size_t pack_size = 4; | |||||
size_t input_offset = | |||||
ih * iw * ic * | |||||
(sparam.group_id + param.filter_meta.group * sparam.batch_id) * | |||||
sizeof(src_ctype); | |||||
src_ctype* src2 = reinterpret_cast<src_ctype*>( | |||||
reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||||
input_offset); | |||||
bool is_phpwzero = param.filter_meta.padding[0] == 0 && | |||||
param.filter_meta.padding[1] == 0; | |||||
if (is_phpwzero) { | |||||
src2 = const_cast<src_ctype*>( | |||||
param.src<src_ctype>(sparam.batch_id, sparam.group_id)); | |||||
} | |||||
src_ctype* im2col_dst = static_cast<src_ctype*>( | |||||
bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | |||||
if (is_xcorr) { | |||||
if (sh == sw && sh == 1) { | |||||
img2col_nchw4<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||||
fw, sh, sw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | |||||
} else { | |||||
img2col_stride_nchw4<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, | |||||
fh, fw, sh, sw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | |||||
} | |||||
} else { | |||||
if (sh == sw && sh == 1) { | |||||
img2col_nchw4<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||||
fw, sh, sw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | |||||
} else { | |||||
img2col_stride_nchw4<false>( | |||||
src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, sh, sw, | |||||
sparam.ohw_cur_index, sparam.output_block_size); | |||||
} | |||||
} | |||||
matmul_param.M = sparam.output_block_oc_size; | |||||
matmul_param.N = sparam.output_block_size; | |||||
matmul_param.LDB = pack_size * sparam.output_block_size; | |||||
matmul_param.LDC = pack_size * sparam.output_block_size; | |||||
matmul_param.B_ptr = im2col_dst; | |||||
src_ctype* b_panel = | |||||
reinterpret_cast<src_ctype*>(reinterpret_cast<uintptr_t>( | |||||
bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX))); | |||||
matmul_algo->pack_B(matmul_param, b_panel, 0, matmul_param.N); | |||||
} | |||||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode) \ | |||||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode, PackMode::DEFAULT, \ | |||||
FormatMode::NCHW44>; | |||||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | |||||
megdnn::PostprocessMode::FLOAT) | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
#endif | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||||
megdnn::PostprocessMode::QUANTIZED) | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
#undef INSTANTIAL_CLASS | |||||
} // namespace megdnn |
@@ -9,8 +9,6 @@ | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "megdnn/opr_param_defs.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | #include "src/fallback/conv_bias/im2col/strategy_base.h" | ||||
#include "src/fallback/convolution/img2col_helper.h" | #include "src/fallback/convolution/img2col_helper.h" | ||||
#if MEGDNN_X86 | #if MEGDNN_X86 | ||||
@@ -22,22 +20,16 @@ using namespace megdnn; | |||||
using namespace x86; | using namespace x86; | ||||
#endif | #endif | ||||
namespace megdnn { | namespace megdnn { | ||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode,PackMode::NO_PACK>::Strategy() | |||||
: StrategyBase() {} | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::NO_PACK>:: | |||||
copy_padding_kern( | |||||
WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||||
copy_padding_kern(WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t) { | |||||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | UNPACK_CONV_F32_NCB_KERN_SIZES(param); | ||||
MEGDNN_MARK_USED_VAR(N); | MEGDNN_MARK_USED_VAR(N); | ||||
MEGDNN_MARK_USED_VAR(OC); | MEGDNN_MARK_USED_VAR(OC); | ||||
@@ -96,12 +88,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::NO_PACK>:: | |||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||||
packA_kern(WorkspaceBundle bundle, | packA_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t) { | |||||
MEGDNN_MARK_USED_VAR(bundle); | MEGDNN_MARK_USED_VAR(bundle); | ||||
MEGDNN_MARK_USED_VAR(param); | MEGDNN_MARK_USED_VAR(param); | ||||
MEGDNN_MARK_USED_VAR(matmulparam); | MEGDNN_MARK_USED_VAR(matmulparam); | ||||
@@ -115,7 +108,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::NO_PACK>:: | |||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||||
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const WorkspaceBundle& bundle_thread, | const WorkspaceBundle& bundle_thread, | ||||
const StrategyParam& sparam) { | const StrategyParam& sparam) { | ||||
@@ -134,7 +127,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::NO_PACK>:: | |||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||||
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, WorkspaceBundle bundle, | const StrategyParam& sparam, WorkspaceBundle bundle, | ||||
WorkspaceBundle bundle_thread, | WorkspaceBundle bundle_thread, | ||||
@@ -167,29 +160,28 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::NO_PACK>:: | |||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||||
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo | |||||
) { | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||||
MEGDNN_MARK_USED_VAR(matmul_param); | MEGDNN_MARK_USED_VAR(matmul_param); | ||||
MEGDNN_MARK_USED_VAR(matmul_algo); | MEGDNN_MARK_USED_VAR(matmul_algo); | ||||
size_t m_sh = param.filter_meta.stride[0]; | |||||
size_t m_sw = param.filter_meta.stride[1]; | |||||
size_t m_oc = param.filter_meta.ocpg; | |||||
size_t m_oh = param.osz[0]; | |||||
size_t m_ow = param.osz[1]; | |||||
size_t m_ic = param.filter_meta.icpg; | |||||
size_t m_ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||||
size_t m_iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||||
size_t m_fh = param.filter_meta.spatial[0]; | |||||
size_t m_fw = param.filter_meta.spatial[1]; | |||||
size_t m_is_xcorr = !param.filter_meta.should_flip; | |||||
size_t sh = param.filter_meta.stride[0]; | |||||
size_t sw = param.filter_meta.stride[1]; | |||||
size_t oc = param.filter_meta.ocpg; | |||||
size_t oh = param.osz[0]; | |||||
size_t ow = param.osz[1]; | |||||
size_t ic = param.filter_meta.icpg; | |||||
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||||
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||||
size_t fh = param.filter_meta.spatial[0]; | |||||
size_t fw = param.filter_meta.spatial[1]; | |||||
size_t is_xcorr = !param.filter_meta.should_flip; | |||||
size_t input_offset = | size_t input_offset = | ||||
m_ih * m_iw * m_ic * | |||||
ih * iw * ic * | |||||
(sparam.group_id + param.filter_meta.group * sparam.batch_id) * | (sparam.group_id + param.filter_meta.group * sparam.batch_id) * | ||||
sizeof(src_ctype); | sizeof(src_ctype); | ||||
@@ -205,26 +197,22 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
} | } | ||||
src_ctype* im2col_dst = static_cast<src_ctype*>( | src_ctype* im2col_dst = static_cast<src_ctype*>( | ||||
bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | ||||
if (m_sh == 1 && m_sw == 1) { | |||||
if (m_is_xcorr) { | |||||
img2col<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||||
m_fh, m_fw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | |||||
if (sh == 1 && sw == 1) { | |||||
if (is_xcorr) { | |||||
img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||||
sparam.ohw_cur_index, sparam.output_block_size); | |||||
} else { | } else { | ||||
img2col<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||||
m_fh, m_fw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | |||||
img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||||
sparam.ohw_cur_index, sparam.output_block_size); | |||||
} | } | ||||
} else { | } else { | ||||
if (m_is_xcorr) { | |||||
img2col_stride<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, | |||||
m_iw, m_fh, m_fw, m_sh, m_sw, | |||||
sparam.ohw_cur_index, | |||||
if (is_xcorr) { | |||||
img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||||
fw, sh, sw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | sparam.output_block_size); | ||||
} else { | } else { | ||||
img2col_stride<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, | |||||
m_ih, m_iw, m_fh, m_fw, m_sh, m_sw, | |||||
sparam.ohw_cur_index, | |||||
img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||||
fw, sh, sw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | sparam.output_block_size); | ||||
} | } | ||||
} | } | ||||
@@ -234,7 +222,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::NO_PACK>:: | |||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||||
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
WorkspaceBundle bundle_thread) { | WorkspaceBundle bundle_thread) { | ||||
@@ -262,7 +250,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::NO_PACK>:: | |||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||||
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const void* matmul_dst, const StrategyParam& sparam) { | const void* matmul_dst, const StrategyParam& sparam) { | ||||
if (!sparam.skip_copy_dst) { | if (!sparam.skip_copy_dst) { | ||||
@@ -284,7 +272,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::NO_PACK>:: | |||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||||
copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | ||||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | ||||
@@ -305,31 +293,20 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
} | } | ||||
} | } | ||||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode) \ | |||||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode, PackMode::NO_PACK>; | |||||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode) \ | |||||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode, PackMode::NO_PACK, \ | |||||
FormatMode::NCHW>; | |||||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | ||||
megdnn::PostprocessMode::FLOAT) | megdnn::PostprocessMode::FLOAT) | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||||
megdnn::PostprocessMode::FLOAT) | |||||
#else | |||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | ||||
megdnn::PostprocessMode::NO_PROCESS) | megdnn::PostprocessMode::NO_PROCESS) | ||||
#endif | #endif | ||||
#endif | |||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | |||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||||
megdnn::PostprocessMode::QUANTIZED) | |||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
#endif | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | ||||
megdnn::PostprocessMode::QUANTIZED) | megdnn::PostprocessMode::QUANTIZED) | ||||
@@ -9,7 +9,6 @@ | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "megdnn/opr_param_defs.h" | |||||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | #include "src/fallback/conv_bias/im2col/strategy_base.h" | ||||
#include "src/fallback/convolution/img2col_helper.h" | #include "src/fallback/convolution/img2col_helper.h" | ||||
#if MEGDNN_X86 | #if MEGDNN_X86 | ||||
@@ -21,22 +20,16 @@ using namespace megdnn; | |||||
using namespace x86; | using namespace x86; | ||||
#endif | #endif | ||||
namespace megdnn { | namespace megdnn { | ||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode,PackMode::ONLY_PACKA>::Strategy() | |||||
: StrategyBase() {} | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::ONLY_PACKA>:: | |||||
copy_padding_kern( | |||||
WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||||
copy_padding_kern(WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t) { | |||||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | UNPACK_CONV_F32_NCB_KERN_SIZES(param); | ||||
MEGDNN_MARK_USED_VAR(N); | MEGDNN_MARK_USED_VAR(N); | ||||
MEGDNN_MARK_USED_VAR(OC); | MEGDNN_MARK_USED_VAR(OC); | ||||
@@ -95,12 +88,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::ONLY_PACKA>:: | |||||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||||
packA_kern(WorkspaceBundle bundle, | packA_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t) { | |||||
bundle.set(param.workspace_ptr); | bundle.set(param.workspace_ptr); | ||||
fallback::MatrixMulImpl::KernParam matmul_param; | fallback::MatrixMulImpl::KernParam matmul_param; | ||||
static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) = | static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) = | ||||
@@ -128,7 +122,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::ONLY_PACKA>:: | |||||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||||
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const WorkspaceBundle& bundle_thread, | const WorkspaceBundle& bundle_thread, | ||||
const StrategyParam& sparam) { | const StrategyParam& sparam) { | ||||
@@ -147,7 +141,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::ONLY_PACKA>:: | |||||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||||
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, WorkspaceBundle bundle, | const StrategyParam& sparam, WorkspaceBundle bundle, | ||||
WorkspaceBundle bundle_thread, | WorkspaceBundle bundle_thread, | ||||
@@ -185,29 +179,28 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::ONLY_PACKA>:: | |||||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||||
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo | |||||
) { | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||||
MEGDNN_MARK_USED_VAR(matmul_param); | MEGDNN_MARK_USED_VAR(matmul_param); | ||||
MEGDNN_MARK_USED_VAR(matmul_algo); | MEGDNN_MARK_USED_VAR(matmul_algo); | ||||
size_t m_sh = param.filter_meta.stride[0]; | |||||
size_t m_sw = param.filter_meta.stride[1]; | |||||
size_t m_oc = param.filter_meta.ocpg; | |||||
size_t m_oh = param.osz[0]; | |||||
size_t m_ow = param.osz[1]; | |||||
size_t m_ic = param.filter_meta.icpg; | |||||
size_t m_ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||||
size_t m_iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||||
size_t m_fh = param.filter_meta.spatial[0]; | |||||
size_t m_fw = param.filter_meta.spatial[1]; | |||||
size_t m_is_xcorr = !param.filter_meta.should_flip; | |||||
size_t sh = param.filter_meta.stride[0]; | |||||
size_t sw = param.filter_meta.stride[1]; | |||||
size_t oc = param.filter_meta.ocpg; | |||||
size_t oh = param.osz[0]; | |||||
size_t ow = param.osz[1]; | |||||
size_t ic = param.filter_meta.icpg; | |||||
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||||
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||||
size_t fh = param.filter_meta.spatial[0]; | |||||
size_t fw = param.filter_meta.spatial[1]; | |||||
size_t is_xcorr = !param.filter_meta.should_flip; | |||||
size_t input_offset = | size_t input_offset = | ||||
m_ih * m_iw * m_ic * | |||||
ih * iw * ic * | |||||
(sparam.group_id + param.filter_meta.group * sparam.batch_id) * | (sparam.group_id + param.filter_meta.group * sparam.batch_id) * | ||||
sizeof(src_ctype); | sizeof(src_ctype); | ||||
@@ -222,26 +215,22 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
} | } | ||||
src_ctype* im2col_dst = static_cast<src_ctype*>( | src_ctype* im2col_dst = static_cast<src_ctype*>( | ||||
bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | ||||
if (m_sh == 1 && m_sw == 1) { | |||||
if (m_is_xcorr) { | |||||
img2col<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||||
m_fh, m_fw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | |||||
if (sh == 1 && sw == 1) { | |||||
if (is_xcorr) { | |||||
img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||||
sparam.ohw_cur_index, sparam.output_block_size); | |||||
} else { | } else { | ||||
img2col<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||||
m_fh, m_fw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | |||||
img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||||
sparam.ohw_cur_index, sparam.output_block_size); | |||||
} | } | ||||
} else { | } else { | ||||
if (m_is_xcorr) { | |||||
img2col_stride<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, | |||||
m_iw, m_fh, m_fw, m_sh, m_sw, | |||||
sparam.ohw_cur_index, | |||||
if (is_xcorr) { | |||||
img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||||
fw, sh, sw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | sparam.output_block_size); | ||||
} else { | } else { | ||||
img2col_stride<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, | |||||
m_ih, m_iw, m_fh, m_fw, m_sh, m_sw, | |||||
sparam.ohw_cur_index, | |||||
img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||||
fw, sh, sw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | sparam.output_block_size); | ||||
} | } | ||||
} | } | ||||
@@ -251,7 +240,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::ONLY_PACKA>:: | |||||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||||
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
WorkspaceBundle bundle_thread) { | WorkspaceBundle bundle_thread) { | ||||
@@ -292,7 +281,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode,PackMode::ONLY_PACKA>:: | |||||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||||
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const void* matmul_dst, const StrategyParam& sparam) { | const void* matmul_dst, const StrategyParam& sparam) { | ||||
if (!sparam.skip_copy_dst) { | if (!sparam.skip_copy_dst) { | ||||
@@ -310,31 +299,20 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
} | } | ||||
} | } | ||||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode) \ | |||||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||||
_op_ctype, _op_dtype, _postprocess_mode,PackMode::ONLY_PACKA>; | |||||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode) \ | |||||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode, \ | |||||
PackMode::ONLY_PACKA, FormatMode::NCHW>; | |||||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | ||||
megdnn::PostprocessMode::FLOAT) | megdnn::PostprocessMode::FLOAT) | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||||
megdnn::PostprocessMode::FLOAT) | |||||
#else | |||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | ||||
megdnn::PostprocessMode::NO_PROCESS) | megdnn::PostprocessMode::NO_PROCESS) | ||||
#endif | #endif | ||||
#endif | |||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | |||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||||
megdnn::PostprocessMode::QUANTIZED) | |||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
#endif | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | ||||
megdnn::PostprocessMode::QUANTIZED) | megdnn::PostprocessMode::QUANTIZED) | ||||
@@ -9,7 +9,6 @@ | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
namespace { | namespace { | ||||
template <bool is_xcorr, typename dtype> | template <bool is_xcorr, typename dtype> | ||||
@@ -41,7 +40,326 @@ void img2col_stride(const dtype* __restrict src, dtype* __restrict dst, | |||||
} | } | ||||
} | } | ||||
//!add for im2col matmul multithread | //!add for im2col matmul multithread | ||||
// | |||||
//! Strided im2col gather for a source that stores four consecutive elements
//! per spatial position (IC / 4 channel groups; presumably the NCHW4 layout -
//! confirm against the caller).
//!
//! Output positions are the flattened indices
//! [cur_index, cur_index + block_size) of an OW-wide output plane. For every
//! (channel group, fh, fw) triple, in that nesting order, the four elements of
//! each covered position are appended to \p dst. When is_xcorr is false the
//! filter taps are visited mirrored (FH - fh - 1, FW - fw - 1).
//!
//! \param src  source buffer, indexed as 4 * (group * IH * IW + y * IW + x)
//! \param dst  destination buffer, written densely from its start
//! \param OC, OH  unused
//! \param SH, SW  vertical / horizontal stride applied to the output position
//!
//! NOTE(review): for one-byte element types each position is moved as a single
//! uint32_t, which presumes both buffers are 4-byte aligned - confirm.
template <bool is_xcorr, typename dtype>
void img2col_stride_nchw4(const dtype* __restrict src, dtype* __restrict dst,
                          const int OC, const int OH, const int OW,
                          const int IC, const int IH, const int IW,
                          const int FH, const int FW, const int SH,
                          const int SW, const int cur_index,
                          const int block_size) {
    static_cast<void>(OC);
    static_cast<void>(OH);

    // Split the flattened output range into (row, column) coordinates.
    const int block_end = cur_index + block_size;
    const int first_row = cur_index / OW;
    const int first_col = cur_index % OW;
    const int last_row = block_end / OW;
    const int last_col = block_end % OW;

    const size_t group_count = IC / 4;
    const bool byte_elems = sizeof(dtype) == 1;

    // Write position in dst: counts elements for wide types and 32-bit words
    // for one-byte types.
    size_t cursor = 0;

    // Appends the positions [col_begin, col_end) of output row `row` for one
    // channel group and one (already flipped, if needed) filter tap.
    auto gather_row = [&](size_t group, int ky, int kx, int row, int col_begin,
                          int col_end) {
        size_t pixel = group * IH * IW + (row * SH + ky) * IW +
                       (col_begin * SW + kx);
        if (byte_elems) {
            // Four one-byte channels are moved as one 32-bit word.
            const uint32_t* packed_src = static_cast<const uint32_t*>(
                    static_cast<const void*>(src));
            uint32_t* packed_dst =
                    static_cast<uint32_t*>(static_cast<void*>(dst));
            for (int col = col_begin; col < col_end; ++col) {
                packed_dst[cursor++] = packed_src[pixel];
                pixel += SW;
            }
        } else {
            for (int col = col_begin; col < col_end; ++col) {
                const dtype* quad = src + 4 * pixel;
                dst[cursor++] = quad[0];
                dst[cursor++] = quad[1];
                dst[cursor++] = quad[2];
                dst[cursor++] = quad[3];
                pixel += SW;
            }
        }
    };

    for (size_t group = 0; group < group_count; ++group) {
        for (int fh = 0; fh < FH; ++fh) {
            for (int fw = 0; fw < FW; ++fw) {
                const int ky = is_xcorr ? fh : FH - fh - 1;
                const int kx = is_xcorr ? fw : FW - fw - 1;
                if (first_row == last_row) {
                    // The whole block lies inside a single output row.
                    gather_row(group, ky, kx, first_row, first_col, last_col);
                } else {
                    // Tail of the first row, all full rows, head of the last.
                    gather_row(group, ky, kx, first_row, first_col, OW);
                    for (int row = first_row + 1; row < last_row; ++row) {
                        gather_row(group, ky, kx, row, 0, OW);
                    }
                    gather_row(group, ky, kx, last_row, 0, last_col);
                }
            }
        }
    }
}
//! Unit-stride im2col gather for a source that stores four consecutive
//! elements per spatial position (IC / 4 channel groups; presumably the NCHW4
//! layout - confirm against the caller).
//!
//! Output positions are the flattened indices
//! [cur_index, cur_index + block_size) of an OW-wide output plane. For every
//! (channel group, fh, fw) triple, in that nesting order, the four elements of
//! each covered position are appended to \p dst. When is_xcorr is false the
//! filter taps are visited mirrored (FH - fh - 1, FW - fw - 1).
//!
//! \param src  source buffer, indexed as 4 * (group * IH * IW + y * IW + x)
//! \param dst  destination buffer, written densely from its start
//! \param OC, OH, SH, SW  unused; the source position is the output position
//!                        plus the filter tap, with no stride applied
//!
//! NOTE(review): for one-byte element types each position is moved as a single
//! uint32_t, which presumes both buffers are 4-byte aligned - confirm.
template <bool is_xcorr, typename dtype>
void img2col_nchw4(const dtype* __restrict src, dtype* __restrict dst,
                   const int OC, const int OH, const int OW, const int IC,
                   const int IH, const int IW, const int FH, const int FW,
                   const int SH, const int SW, const int cur_index,
                   const int block_size) {
    static_cast<void>(OC);
    static_cast<void>(OH);
    static_cast<void>(SH);
    static_cast<void>(SW);

    // Split the flattened output range into (row, column) coordinates.
    const int block_end = cur_index + block_size;
    const int first_row = cur_index / OW;
    const int first_col = cur_index % OW;
    const int last_row = block_end / OW;
    const int last_col = block_end % OW;

    const size_t group_count = IC / 4;
    const bool byte_elems = sizeof(dtype) == 1;

    // Write position in dst: counts elements for wide types and 32-bit words
    // for one-byte types.
    size_t cursor = 0;

    // Appends the positions [col_begin, col_end) of output row `row` for one
    // channel group and one (already flipped, if needed) filter tap.
    auto gather_row = [&](size_t group, int ky, int kx, int row, int col_begin,
                          int col_end) {
        if (byte_elems) {
            // Four one-byte channels are moved as one 32-bit word.
            const uint32_t* packed_src = static_cast<const uint32_t*>(
                    static_cast<const void*>(src));
            uint32_t* packed_dst =
                    static_cast<uint32_t*>(static_cast<void*>(dst));
            for (int col = col_begin; col < col_end; ++col) {
                const size_t pixel =
                        group * IH * IW + (row + ky) * IW + (col + kx);
                packed_dst[cursor++] = packed_src[pixel];
            }
        } else {
            for (int col = col_begin; col < col_end; ++col) {
                const size_t pixel =
                        group * IH * IW + (row + ky) * IW + (col + kx);
                const dtype* quad = src + 4 * pixel;
                dst[cursor++] = quad[0];
                dst[cursor++] = quad[1];
                dst[cursor++] = quad[2];
                dst[cursor++] = quad[3];
            }
        }
    };

    for (size_t group = 0; group < group_count; ++group) {
        for (int fh = 0; fh < FH; ++fh) {
            for (int fw = 0; fw < FW; ++fw) {
                const int ky = is_xcorr ? fh : FH - fh - 1;
                const int kx = is_xcorr ? fw : FW - fw - 1;
                if (first_row == last_row) {
                    // The whole block lies inside a single output row.
                    gather_row(group, ky, kx, first_row, first_col, last_col);
                } else {
                    // Tail of the first row, all full rows, head of the last.
                    gather_row(group, ky, kx, first_row, first_col, OW);
                    for (int row = first_row + 1; row < last_row; ++row) {
                        gather_row(group, ky, kx, row, 0, OW);
                    }
                    gather_row(group, ky, kx, last_row, 0, last_col);
                }
            }
        }
    }
}
template <bool is_xcorr, typename dtype> | template <bool is_xcorr, typename dtype> | ||||
void img2col_stride(const dtype* __restrict src, dtype* __restrict dst, | void img2col_stride(const dtype* __restrict src, dtype* __restrict dst, | ||||
@@ -124,7 +124,8 @@ struct PostProcess { | |||||
megdnn::ConvBiasForward::BiasMode bias_mode, | megdnn::ConvBiasForward::BiasMode bias_mode, | ||||
megdnn::param::ConvBias::NonlineMode nonlineMode, | megdnn::param::ConvBias::NonlineMode nonlineMode, | ||||
DType bias_type, DType dst_type, size_t N, size_t OC, | DType bias_type, DType dst_type, size_t N, size_t OC, | ||||
size_t OH, size_t OW) { | |||||
size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||||
megdnn::param::Elemwise::Mode elem_mode = | megdnn::param::Elemwise::Mode elem_mode = | ||||
megdnn::param::Elemwise::Mode::ADD; | megdnn::param::Elemwise::Mode::ADD; | ||||
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | ||||
@@ -154,7 +155,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::FLOAT> { | |||||
megdnn::ConvBiasForward::BiasMode bias_mode, | megdnn::ConvBiasForward::BiasMode bias_mode, | ||||
megdnn::param::ConvBias::NonlineMode nonlineMode, | megdnn::param::ConvBias::NonlineMode nonlineMode, | ||||
DType bias_type, DType dst_type, size_t N, size_t OC, | DType bias_type, DType dst_type, size_t N, size_t OC, | ||||
size_t OH, size_t OW) { | |||||
size_t OH, size_t OW, size_t pack_oc_size=1) { | |||||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||||
megdnn::param::Elemwise::Mode elem_mode = | megdnn::param::Elemwise::Mode elem_mode = | ||||
megdnn::param::Elemwise::Mode::ADD; | megdnn::param::Elemwise::Mode::ADD; | ||||
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | ||||
@@ -185,7 +187,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||||
megdnn::ConvBiasForward::BiasMode bias_mode, | megdnn::ConvBiasForward::BiasMode bias_mode, | ||||
megdnn::param::ConvBias::NonlineMode nonlineMode, | megdnn::param::ConvBias::NonlineMode nonlineMode, | ||||
DType bias_type, DType dst_type, size_t N, size_t OC, | DType bias_type, DType dst_type, size_t N, size_t OC, | ||||
size_t OH, size_t OW) { | |||||
size_t OH, size_t OW,size_t pack_oc_size = 1) { | |||||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||||
MEGDNN_MARK_USED_VAR(conv_dst_ptr); | MEGDNN_MARK_USED_VAR(conv_dst_ptr); | ||||
MEGDNN_MARK_USED_VAR(bias_ptr); | MEGDNN_MARK_USED_VAR(bias_ptr); | ||||
MEGDNN_MARK_USED_VAR(dst_ptr); | MEGDNN_MARK_USED_VAR(dst_ptr); | ||||
@@ -292,7 +295,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> { | |||||
megdnn::ConvBiasForward::BiasMode bias_mode, | megdnn::ConvBiasForward::BiasMode bias_mode, | ||||
megdnn::param::ConvBiasV0::NonlineMode nonlineMode, | megdnn::param::ConvBiasV0::NonlineMode nonlineMode, | ||||
DType bias_type, DType dst_type, size_t N, size_t OC, | DType bias_type, DType dst_type, size_t N, size_t OC, | ||||
size_t OH, size_t OW) { | |||||
size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||||
megdnn::param::Elemwise::Mode elem_mode = | megdnn::param::Elemwise::Mode elem_mode = | ||||
megdnn::param::Elemwise::Mode::ADD; | megdnn::param::Elemwise::Mode::ADD; | ||||
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | ||||