GitOrigin-RevId: d326035202
@@ -17,18 +17,14 @@ | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
#include "src/fallback/conv_bias/winograd/strategy.h" | |||
#include "src/naive/convolution/helper.h" | |||
#if MEGDNN_X86 | |||
#include "src/x86/conv_bias/postprocess_helper.h" | |||
#endif | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_fallback_im2col) | |||
using namespace megdnn; | |||
using namespace fallback; | |||
using namespace im2col; | |||
#if MEGDNN_X86 | |||
using namespace x86; | |||
#endif | |||
/*======================== AlgoIm2col=======================*/ | |||
/*! | |||
@@ -47,8 +43,8 @@ using Pack_Mode=fallback::MatrixMulImpl::AlgoBase::PackMode; | |||
static void copy_padding_kern(WorkspaceBundle bundle, | |||
const ConvBiasImpl::NCBKernParam& param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
StrategyBase* im2colstrategy) { | |||
im2colstrategy->copy_padding_kern(bundle, param, ncb_index); | |||
StrategyBase* im2colstrategy, size_t pack_oc_size) { | |||
im2colstrategy->copy_padding_kern(bundle, param, ncb_index, pack_oc_size); | |||
} | |||
//! packA_kern | |||
@@ -57,9 +53,9 @@ static void packA_kern(WorkspaceBundle bundle, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
StrategyBase* im2colstrategy) { | |||
StrategyBase* im2colstrategy, size_t pack_oc_size) { | |||
im2colstrategy->packA_kern(bundle, param, matmulparam, matmul_algo, | |||
ncb_index); | |||
ncb_index, pack_oc_size); | |||
} | |||
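Both wrappers above now forward a pack_oc_size argument to the strategy: it is 1 for plain NCHW and 4 for NCHW44, where the tensor is laid out as (N, C/4, H, W, 4). As a reference for what that pack factor means for addressing, here is a minimal standalone sketch of the element offset in such a layout; the helper name and signature are made up for illustration and are not part of this patch.

#include <cstddef>

// Offset of logical element (n, c, h, w) in an NCHW44-packed tensor of
// logical shape (N, C, H, W); channels are grouped in blocks of 4 lanes.
static size_t nchw44_offset(size_t n, size_t c, size_t h, size_t w,
                            size_t C, size_t H, size_t W) {
    const size_t pack = 4;                 // pack_oc_size for NCHW44
    size_t cb = c / pack, lane = c % pack; // channel block and lane
    return (((n * (C / pack) + cb) * H + h) * W + w) * pack + lane;
}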
/*! | |||
@@ -129,14 +125,17 @@ public: | |||
size_t oc_tile_size) { | |||
size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0], | |||
FW = param.filter_meta.spatial[1]; | |||
size_t pack_oc_size = 1; | |||
size_t im2col = 0, packb = 0, bias_temp = 0; | |||
bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT; | |||
megdnn_assert(default_pack, "only default packA is supported"); | |||
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { | |||
pack_oc_size = 4; | |||
} | |||
size_t im2col_dst_size = | |||
IC * FH * FW * ohw_tile_size * sizeof(param.src_type); | |||
size_t matmul_dst_size = | |||
oc_tile_size * ohw_tile_size * sizeof(param.bias_type); | |||
size_t matmul_dst_size = pack_oc_size * oc_tile_size * ohw_tile_size * | |||
sizeof(param.bias_type); | |||
//! matmul_dst and im2col_dst use the same memory | |||
WorkspaceBundle wb = matmul_algo->get_bundle(im2col_kern_param); | |||
packb = wb.get_size(1); | |||
@@ -318,17 +317,18 @@ public: | |||
} | |||
}; | |||
#undef FILL_IM2COL_STRATEGY_PARAM | |||
fallback::MatrixMulImpl::KernSizeParam | |||
ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, | |||
size_t ohw_tile_size, | |||
size_t oc_tile_size) const { | |||
bool is_nchw44 = | |||
param.filter_meta.format == param::ConvBias::Format::NCHW44; | |||
size_t M = oc_tile_size; | |||
size_t N = ohw_tile_size; | |||
size_t K = param.filter_meta.icpg * param.filter_meta.spatial[0] * | |||
param.filter_meta.spatial[1]; | |||
size_t LDA = K, LDB = N, LDC = N; | |||
size_t pack_oc_size = is_nchw44 ? 4 : 1; | |||
size_t LDA = pack_oc_size * K, LDB = pack_oc_size * N, LDC = N; | |||
bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS8) || | |||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm && | |||
@@ -345,7 +345,8 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, | |||
false, | |||
false, | |||
param::MatrixMul::ComputeMode::DEFAULT, | |||
param::MatrixMul::Format::DEFAULT}; | |||
is_nchw44 ? param::MatrixMul::Format::MK4 | |||
: param::MatrixMul::Format::DEFAULT}; | |||
} | |||
void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( | |||
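For NCHW44 the matmul kern param above switches to MK4 format and scales LDA and LDB by the pack size, while M, N, K keep their per-lane meaning and LDC stays N. A small self-contained check of that geometry, with an assumed tile configuration (oc_tile = 24, ohw_tile = 192, 3x3 filter over 16 input channels) chosen purely for illustration:

#include <cassert>
#include <cstddef>

int main() {
    // Assumed tile sizes and filter shape, for illustration only.
    size_t oc_tile_size = 24, ohw_tile_size = 192;
    size_t icpg = 16, fh = 3, fw = 3;
    bool is_nchw44 = true;

    size_t M = oc_tile_size;                 // output channels per tile
    size_t N = ohw_tile_size;                // output pixels per tile
    size_t K = icpg * fh * fw;               // im2col reduction length
    size_t pack_oc_size = is_nchw44 ? 4 : 1;
    size_t LDA = pack_oc_size * K;           // 4 lanes interleaved along K
    size_t LDB = pack_oc_size * N;           // 4 lanes interleaved along N
    size_t LDC = N;

    assert(K == 144 && LDA == 576 && LDB == 768 && LDC == 192);
    (void)M;
    return 0;
}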
@@ -405,6 +406,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||
size_t GROUP = param.filter_meta.group; | |||
bool need_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT; | |||
bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA; | |||
if (need_pack || only_packA) { | |||
auto inner_block = m_matmul_algo->get_inner_block_size(); | |||
choice_ohw_oc_block(param, inner_block.m, inner_block.n, need_pack); | |||
@@ -421,16 +423,19 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||
need_pack); | |||
packa_group_size = 0; | |||
} | |||
if (no_need_pading) { | |||
padding = 0; //! not need padding | |||
} else { | |||
padding = (GROUP * N * IC * IH2 * IW2) * | |||
sizeof(param.src_type); //! for padding | |||
} | |||
packa_size = GROUP * packa_group_size; //! for packA size = GROUP * a_size | |||
WorkspaceBundle ws = {nullptr, {}}; | |||
auto im2col_kern_param = | |||
get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size); | |||
if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) { | |||
Im2colKerns<Pack_Mode::DEFAULT> defaultkern; | |||
ws = defaultkern.get_thread_bundle(param, im2col_kern_param, | |||
@@ -447,6 +452,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||
m_matmul_algo, m_ohw_tile_size, | |||
m_oc_tile_size); | |||
} | |||
return {nullptr, | |||
{padding, packa_size, ws.total_size_in_bytes() * nr_threads}}; | |||
} | |||
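The bundle returned here has three top-level parts: the padded-source buffer, the per-group packA area, and one per-thread bundle replicated nr_threads times. A rough standalone sketch of how those totals combine; the concrete byte counts below are invented stand-ins for the values computed from the param and the matmul algo's bundle, and the alignment handled by WorkspaceBundle is ignored.

#include <cstddef>
#include <vector>

int main() {
    size_t padding_size = 1 << 20;    // ~ GROUP * N * IC * IH2 * IW2 * sizeof(src)
    size_t packa_size   = 256 << 10;  // ~ GROUP * packa_group_size
    size_t thread_size  = 512 << 10;  // ~ per-thread packB + im2col + bias temp
    size_t nr_threads   = 4;

    std::vector<size_t> parts = {padding_size, packa_size,
                                 thread_size * nr_threads};
    size_t total = 0;
    for (size_t p : parts) total += p;  // WorkspaceBundle sums these parts
    return total > 0 ? 0 : 1;
}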
@@ -461,7 +467,7 @@ size_t ConvBiasImpl::AlgoIm2col::get_workspace( | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||
ConvBiasImpl* opr, const NCBKernSizeParam& param) const { | |||
ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 1) { | |||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||
MEGDNN_MARK_USED_VAR(SH); | |||
@@ -473,7 +479,6 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||
size_t ohw = OH * OW; | |||
size_t ohw_parallel_times = div_ceil(ohw, m_ohw_tile_size); | |||
size_t GROUP = param.filter_meta.group; | |||
WorkspaceBundle bundle = get_bundle(param); | |||
WorkspaceBundle bundle_thread = {nullptr, {}}; | |||
size_t oc_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size); | |||
@@ -483,11 +488,14 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||
bool no_pack = packmode == Pack_Mode::NO_PACK; | |||
bool only_packA = packmode == Pack_Mode::ONLY_PACKA; | |||
size_t packa_parallel_times = 0; | |||
size_t pack_oc_size = | |||
(param.filter_meta.format == param::ConvBias::Format::NCHW ? 1 | |||
: 4); | |||
if (only_packA) { | |||
packa_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size); | |||
} else if (default_pack) { | |||
packa_parallel_times = div_ceil<size_t>( | |||
OC, m_matmul_algo->get_inner_block_size().m); | |||
OC, m_matmul_algo->get_inner_block_size().m * pack_oc_size); | |||
} | |||
auto matmul_param = get_matmul_kern_param( | |||
@@ -520,25 +528,29 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||
strategyparam.skip_copy_dst = | |||
strategyparam.is_ohw_size_bigger && !strategyparam.is_dst_8bit; | |||
strategyparam.oc_tile_size = m_oc_tile_size; | |||
strategyparam.pack_oc_size = pack_oc_size; | |||
SmallVector<ConvBiasImpl::NCBKern> ret_kern; | |||
MIDOUT_BEGIN( | |||
megdnn_fallback_im2col, | |||
midout_iv("ConvBiasImpl::AlgoIm2col::dispatch_kerns"_hash)) { | |||
StrategyBase* im2colstrategy = Factory::get_im2col_strategy( | |||
param, m_matmul_algo, opr->param().format); | |||
auto kern_padding = [bundle, im2colstrategy]( | |||
StrategyBase* im2colstrategy = | |||
Factory::get_im2col_strategy(param, m_matmul_algo); | |||
auto kern_padding = [bundle, im2colstrategy, | |||
pack_oc_size = pack_oc_size]( | |||
const NCBKernParam& param, | |||
const NCBKernIndex& ncb_index) { | |||
copy_padding_kern(bundle, param, ncb_index, im2colstrategy); | |||
copy_padding_kern(bundle, param, ncb_index, im2colstrategy, | |||
pack_oc_size); | |||
}; | |||
auto kern_packA = [bundle, matmul_algo = m_matmul_algo, | |||
matmul_param, | |||
im2colstrategy](const NCBKernParam& param, | |||
const NCBKernIndex& ncb_index) { | |||
matmul_param, im2colstrategy, | |||
pack_oc_size = pack_oc_size]( | |||
const NCBKernParam& param, | |||
const NCBKernIndex& ncb_index) { | |||
packA_kern(bundle, param, matmul_param, matmul_algo, ncb_index, | |||
im2colstrategy); | |||
im2colstrategy, pack_oc_size); | |||
}; | |||
if (default_pack) { | |||
auto kern_compute_default = | |||
@@ -556,7 +568,8 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||
ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}}); | |||
if (need_padding) { | |||
ret_kern.push_back({kern_padding, {param.n, GROUP, IC}}); | |||
ret_kern.push_back({kern_padding, | |||
{param.n, GROUP, IC / pack_oc_size}}); | |||
} | |||
ret_kern.push_back( | |||
{kern_compute_default, | |||
@@ -629,19 +642,25 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
return false; | |||
} | |||
//! currently, im2col with nchw44 only supports int8 and quantized s8 | |||
if (opr->param().format == param::ConvBias::Format::NCHW44 && | |||
(param.src_type.enumv() == param.filter_type.enumv() && | |||
(param.src_type.enumv() != DTypeEnum::Int8) && | |||
(param.src_type.enumv() != DTypeEnum::QuantizedS8))) { | |||
return false; | |||
} | |||
fallback::MatrixMulImpl::KernSizeParam matmul_param = | |||
get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size); | |||
bool matmulusable = m_matmul_algo->usable(matmul_param); | |||
return matmulusable && | |||
(opr->param().format == param::ConvBias::Format::NCHW) && | |||
((param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||
(param.filter_meta.spatial[0] <= 7) && | |||
(param.filter_meta.spatial[0] >= 2)) || | |||
(param.filter_meta.spatial[0] != param.filter_meta.spatial[1] && | |||
(param.filter_meta.spatial[0] <= 7) && | |||
(param.filter_meta.spatial[0] >= 1) && | |||
(param.filter_meta.spatial[1] <= 7) && | |||
(param.filter_meta.spatial[1] >= 1))) && | |||
(opr->param().format == param::ConvBias::Format::NCHW || | |||
opr->param().format == param::ConvBias::Format::NCHW44) && | |||
(!(param.filter_meta.spatial[0] == | |||
param.filter_meta.spatial[1] && | |||
(param.filter_meta.spatial[0] == 1) && | |||
param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||
param.filter_meta.stride[0] == 1)) && | |||
(param.filter_meta.dilation[0] == | |||
param.filter_meta.dilation[1] && | |||
param.filter_meta.dilation[0] == 1) && | |||
@@ -36,7 +36,6 @@ class ConvBiasImpl::AlgoIm2col final : public AlgoBase { | |||
const NCBKernSizeParam& param, size_t ohw_tile_size, | |||
size_t oc_tile_size) const; | |||
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; | |||
WorkspaceBundle get_thread_bundle(const NCBKernSizeParam& param) const; | |||
void choice_ohw_oc_block(const NCBKernSizeParam& param, size_t block_m, | |||
size_t block_n, bool pack_default) const; | |||
@@ -23,19 +23,11 @@ namespace im2col { | |||
enum class StrategyType : uint32_t { | |||
FLOAT = 0, | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
FLOAT_FP16 = 1, | |||
#else | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
FLOAT16_FLOAT16 = 2, | |||
#endif | |||
#endif | |||
INT8x8x32 = 3, | |||
INT8x8x16 = 4, | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
QUINT8x8x32 = 5, | |||
QUINT8x8x32x8 = 6, | |||
#endif | |||
QINT8x8x32 = 7, | |||
QINT8x8x32x8 = 8 | |||
}; | |||
@@ -107,8 +99,7 @@ public: | |||
~StrategyDelegationStorage() = default; | |||
template <typename Strategy> | |||
Strategy* get(param::ConvBias::Format format, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
Strategy* get(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
StrategyType stype); | |||
}; | |||
@@ -117,12 +108,10 @@ class Factory { | |||
public: | |||
static StrategyBase* get_im2col_strategy( | |||
const fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
param::ConvBias::Format format) { | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
static StrategyDelegationStorage storage; | |||
StrategyType strategytype = get_strategy_type(param); | |||
return storage.get<StrategyBase>(format, matmul_algo, param, | |||
strategytype); | |||
return storage.get<StrategyBase>(matmul_algo, param, strategytype); | |||
} | |||
static StrategyType get_strategy_type( | |||
@@ -141,13 +130,9 @@ public: | |||
} | |||
cb1(dt_float32, dt_float32, StrategyType::FLOAT); | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
cb1(dt_float16, __fp16, StrategyType::FLOAT_FP16); | |||
#else | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
cb1(dt_float16, dt_float16, StrategyType::FLOAT16_FLOAT16); | |||
#endif | |||
#endif | |||
cb2(dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32, | |||
StrategyType::INT8x8x32); | |||
@@ -155,13 +140,6 @@ public: | |||
cb2(dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, dt_int16, | |||
StrategyType::INT8x8x16); | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::QuantizedS32, | |||
dt_uint8, dt_int32, dt_int32, StrategyType::QUINT8x8x32); | |||
cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::Quantized8Asymm, | |||
dt_uint8, dt_int32, dt_uint8, StrategyType::QUINT8x8x32x8); | |||
#endif | |||
cb2(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32, | |||
dt_int8, dt_int32, dt_int32, StrategyType::QINT8x8x32); | |||
@@ -172,98 +150,106 @@ public: | |||
megdnn_throw("not support datatype in im2col strategy\n"); | |||
} | |||
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||
MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ | |||
midout_iv(_midout_tag)) { \ | |||
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||
return std::make_unique< \ | |||
Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||
_postprocess_mode, PackMode::_packmode>>(); \ | |||
} \ | |||
} \ | |||
MIDOUT_END(); \ | |||
#define cb1(_format, _packmode, _dt, _post_ctype, _postprocess_mode, \ | |||
_midout_tag) \ | |||
MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ | |||
midout_iv(_midout_tag)) { \ | |||
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||
return std::make_unique< \ | |||
Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||
_postprocess_mode, PackMode::_packmode, \ | |||
FormatMode::_format>>(); \ | |||
} \ | |||
} \ | |||
MIDOUT_END(); \ | |||
return {}; | |||
#define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||
MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ | |||
midout_iv(_midout_tag)) { \ | |||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||
return std::make_unique< \ | |||
Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||
DTypeTrait<_i_bias_type>::ctype, \ | |||
DTypeTrait<_i_dst_type>::ctype, \ | |||
_postprocess_mode, PackMode::_packmode>>(); \ | |||
} \ | |||
} \ | |||
MIDOUT_END(); \ | |||
#define cb2(_format, _packmode, _i_src_type, _i_bias_type, _i_dst_type, \ | |||
_src_ctype, _bias_ctype, _dst_ctype, _postprocess_mode, \ | |||
_midout_tag) \ | |||
MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ | |||
midout_iv(_midout_tag)) { \ | |||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||
return std::make_unique<Strategy< \ | |||
_src_ctype, _bias_ctype, _dst_ctype, \ | |||
DTypeTrait<_i_bias_type>::ctype, \ | |||
DTypeTrait<_i_dst_type>::ctype, _postprocess_mode, \ | |||
PackMode::_packmode, FormatMode::_format>>(); \ | |||
} \ | |||
} \ | |||
MIDOUT_END(); \ | |||
return {}; | |||
static std::unique_ptr<StrategyBase> make_default_strategy( | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
param::ConvBias::Format format, StrategyType strategytype) { | |||
StrategyType strategytype) { | |||
MEGDNN_MARK_USED_VAR(matmul_algo); | |||
MEGDNN_MARK_USED_VAR(format); | |||
param::ConvBias::Format format = param.filter_meta.format; | |||
switch (strategytype) { | |||
case StrategyType::FLOAT: | |||
cb1(DEFAULT, dt_float32, dt_float32, PostprocessMode::FLOAT, | |||
"DefaultStrategyType::FLOAT"_hash); | |||
break; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
case StrategyType::FLOAT_FP16: | |||
cb1(DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT, | |||
"DefaultStrategyType::FLOAT_FP16"_hash); | |||
cb1(NCHW, DEFAULT, dt_float32, dt_float32, | |||
PostprocessMode::FLOAT, "DefaultStrategyType::FLOAT"_hash); | |||
break; | |||
#else | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
case StrategyType::FLOAT16_FLOAT16: | |||
cb1(DEFAULT, dt_float16, dt_float16, | |||
cb1(NCHW, DEFAULT, dt_float16, dt_float16, | |||
PostprocessMode::NO_PROCESS, | |||
"DefaultStrategyType::FLOAT16_FLOAT16"_hash); | |||
break; | |||
#endif | |||
#endif | |||
case StrategyType::INT8x8x32: | |||
cb2(DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, | |||
dt_int32, PostprocessMode::NO_PROCESS, | |||
"DefaultStrategyType::INT8x8x32"_hash); | |||
if (format == param::ConvBias::Format::NCHW) { | |||
cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
"DefaultStrategyType::INT8x8x32"_hash); | |||
} else if (format == param::ConvBias::Format::NCHW44) { | |||
cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
"DefaultStrategyType::INT8x8x32"_hash); | |||
} else { | |||
megdnn_throw("not support format except nchw44 and nchw\n"); | |||
} | |||
break; | |||
case StrategyType::INT8x8x16: | |||
cb2(DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, | |||
dt_int16, PostprocessMode::NO_PROCESS, | |||
cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
"DefaultStrategyType::INT8x8x16"_hash); | |||
break; | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
case StrategyType::QUINT8x8x32: | |||
cb2(DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, | |||
"DefaultStrategyType::QUINT8x8x32"_hash); | |||
break; | |||
case StrategyType::QUINT8x8x32x8: | |||
cb2(DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | |||
PostprocessMode::QUANTIZED, | |||
"DefaultStrategyType::QUINT8x8x32x8"_hash); | |||
break; | |||
#endif | |||
case StrategyType::QINT8x8x32: | |||
cb2(DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, | |||
"DefaultStrategyType::QINT8x8x32"_hash); | |||
if (format == param::ConvBias::Format::NCHW) { | |||
cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, | |||
"DefaultStrategyTypeNCHW::QINT8x8x32"_hash); | |||
} else if (format == param::ConvBias::Format::NCHW44) { | |||
cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | |||
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
"DefaultStrategyTypeHCHW44::QINT8x8x32"_hash); | |||
} else { | |||
megdnn_throw("not support format except nchw44 and nchw\n"); | |||
} | |||
break; | |||
case StrategyType::QINT8x8x32x8: | |||
cb2(DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||
PostprocessMode::QUANTIZED, | |||
"DefaultStrategyType::QINT8x8x32x8"_hash); | |||
if (format == param::ConvBias::Format::NCHW) { | |||
cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||
PostprocessMode::QUANTIZED, | |||
"DefaultStrategyType::QINT8x8x32x8"_hash); | |||
} else if (format == param::ConvBias::Format::NCHW44) { | |||
cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | |||
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, | |||
dt_int32, dt_int8, PostprocessMode::QUANTIZED, | |||
"DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash); | |||
} else { | |||
megdnn_throw("not support format except nchw44 and nchw\n"); | |||
} | |||
break; | |||
} | |||
megdnn_throw("error not support strategy type "); | |||
@@ -272,63 +258,41 @@ public: | |||
static std::unique_ptr<StrategyBase> make_nopack_strategy( | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
param::ConvBias::Format format, StrategyType strategytype) { | |||
StrategyType strategytype) { | |||
MEGDNN_MARK_USED_VAR(matmul_algo); | |||
MEGDNN_MARK_USED_VAR(format); | |||
switch (strategytype) { | |||
case StrategyType::FLOAT: | |||
cb1(NO_PACK, dt_float32, dt_float32, PostprocessMode::FLOAT, | |||
"NoPackStrategyType::FLOAT"_hash); | |||
break; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
case StrategyType::FLOAT_FP16: | |||
cb1(NO_PACK, dt_float16, __fp16, PostprocessMode::FLOAT, | |||
"NoPackStrategyType::FLOAT_FP16"_hash); | |||
cb1(NCHW, NO_PACK, dt_float32, dt_float32, | |||
PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash); | |||
break; | |||
#else | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
case StrategyType::FLOAT16_FLOAT16: | |||
cb1(NO_PACK, dt_float16, dt_float16, PostprocessMode::NO_PROCESS, | |||
cb1(NCHW, NO_PACK, dt_float16, dt_float16, | |||
PostprocessMode::NO_PROCESS, | |||
"NoPackStrategyType::FLOAT16_FLOAT16"_hash); | |||
break; | |||
#endif | |||
#endif | |||
case StrategyType::INT8x8x32: | |||
cb2(NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, | |||
dt_int32, PostprocessMode::NO_PROCESS, | |||
cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
"NoPackStrategyType::INT8x8x32"_hash); | |||
break; | |||
case StrategyType::INT8x8x16: | |||
cb2(NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, | |||
dt_int16, PostprocessMode::NO_PROCESS, | |||
cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | |||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
"NoPackStrategyType::INT8x8x16"_hash); | |||
break; | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
case StrategyType::QUINT8x8x32: | |||
cb2(NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, | |||
"NoPackStrategyType::QUINT8x8x32"_hash); | |||
break; | |||
case StrategyType::QUINT8x8x32x8: | |||
cb2(NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | |||
PostprocessMode::QUANTIZED, | |||
"NoPackStrategyType::QUINT8x8x32x8"_hash); | |||
break; | |||
#endif | |||
case StrategyType::QINT8x8x32: | |||
cb2(NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||
cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, | |||
"NoPackStrategyType::QINT8x8x32"_hash); | |||
break; | |||
case StrategyType::QINT8x8x32x8: | |||
cb2(NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||
cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||
PostprocessMode::QUANTIZED, | |||
"NoPackStrategyType::QINT8x8x32x8"_hash); | |||
@@ -340,64 +304,42 @@ public: | |||
static std::unique_ptr<StrategyBase> make_onlypacka_strategy( | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
param::ConvBias::Format format, StrategyType strategytype) { | |||
StrategyType strategytype) { | |||
MEGDNN_MARK_USED_VAR(matmul_algo); | |||
MEGDNN_MARK_USED_VAR(format); | |||
switch (strategytype) { | |||
case StrategyType::FLOAT: | |||
cb1(ONLY_PACKA, dt_float32, dt_float32, PostprocessMode::FLOAT, | |||
cb1(NCHW, ONLY_PACKA, dt_float32, dt_float32, | |||
PostprocessMode::FLOAT, | |||
"OnlyPackaStrategyType::FLOAT"_hash); | |||
break; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
case StrategyType::FLOAT_FP16: | |||
cb1(ONLY_PACKA, dt_float16, __fp16, PostprocessMode::FLOAT, | |||
"OnlyPackaStrategyType::FLOAT_FP16"_hash); | |||
break; | |||
#else | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
case StrategyType::FLOAT16_FLOAT16: | |||
cb1(ONLY_PACKA, dt_float16, dt_float16, | |||
cb1(NCHW, ONLY_PACKA, dt_float16, dt_float16, | |||
PostprocessMode::NO_PROCESS, | |||
"OnlyPackaStrategyType::FLOAT16_FLOAT16"_hash); | |||
break; | |||
#endif | |||
#endif | |||
case StrategyType::INT8x8x32: | |||
cb2(ONLY_PACKA, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, | |||
dt_int32, PostprocessMode::NO_PROCESS, | |||
cb2(NCHW, ONLY_PACKA, dt_int8, dt_int32, dt_int32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
"OnlyPackaStrategyType::INT8x8x32"_hash); | |||
break; | |||
case StrategyType::INT8x8x16: | |||
cb2(ONLY_PACKA, dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, | |||
dt_int16, PostprocessMode::NO_PROCESS, | |||
cb2(NCHW, ONLY_PACKA, dt_int8, dt_int16, dt_int16, dt_int8, | |||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
"OnlyPackaStrategyType::INT8x8x16"_hash); | |||
break; | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
case StrategyType::QUINT8x8x32: | |||
cb2(ONLY_PACKA, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, | |||
"OnlyPackaStrategyType::QUINT8x8x32"_hash); | |||
break; | |||
case StrategyType::QUINT8x8x32x8: | |||
cb2(ONLY_PACKA, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | |||
PostprocessMode::QUANTIZED, | |||
"OnlyPackaStrategyType::QUINT8x8x32x8"_hash); | |||
break; | |||
#endif | |||
case StrategyType::QINT8x8x32: | |||
cb2(ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||
cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, | |||
"OnlyPackaStrategyType::QINT8x8x32"_hash); | |||
break; | |||
case StrategyType::QINT8x8x32x8: | |||
cb2(ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||
cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||
PostprocessMode::QUANTIZED, | |||
"OnlyPackaStrategyType::QINT8x8x32x8"_hash); | |||
@@ -410,21 +352,19 @@ public: | |||
#undef cb2 | |||
static std::unique_ptr<StrategyBase> make_strategy( | |||
param::ConvBias::Format format, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
fallback::MatrixMulImpl::AlgoBase::PackMode packmode, | |||
const fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
StrategyType stype) { | |||
switch (packmode) { | |||
case MatrixMulImpl::AlgoBase::PackMode::DEFAULT: | |||
return make_default_strategy(matmul_algo, param, format, stype); | |||
return make_default_strategy(matmul_algo, param, stype); | |||
break; | |||
case MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA: | |||
return make_onlypacka_strategy(matmul_algo, param, format, | |||
stype); | |||
return make_onlypacka_strategy(matmul_algo, param, stype); | |||
break; | |||
case MatrixMulImpl::AlgoBase::PackMode::NO_PACK: | |||
return make_nopack_strategy(matmul_algo, param, format, stype); | |||
return make_nopack_strategy(matmul_algo, param, stype); | |||
break; | |||
default: | |||
megdnn_throw( | |||
@@ -432,14 +372,12 @@ public: | |||
"nopack"); | |||
break; | |||
} | |||
megdnn_throw( | |||
"factory make Strategy error please check your code"); | |||
megdnn_throw("factory make Strategy error please check your code"); | |||
} | |||
}; | |||
template <typename Strategy> | |||
Strategy* StrategyDelegationStorage::get( | |||
param::ConvBias::Format format, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
StrategyType stype) { | |||
@@ -455,14 +393,14 @@ Strategy* StrategyDelegationStorage::get( | |||
} | |||
StrategyHashParam sparam; | |||
sparam.param = param; | |||
sparam.format = format; | |||
sparam.format = param.filter_meta.format; | |||
sparam.packmode = packmode; | |||
sparam.block_m = block_m; | |||
sparam.block_n = block_n; | |||
sparam.block_k = block_k; | |||
if (map_strategys.find(sparam) == map_strategys.end()) { | |||
MEGDNN_LOCK_GUARD(m_mtx); | |||
auto strategy = Factory::make_strategy(format, matmul_algo, packmode, | |||
auto strategy = Factory::make_strategy(matmul_algo, packmode, | |||
param, stype); | |||
map_strategys[sparam] = std::move(strategy); | |||
} | |||
@@ -14,6 +14,7 @@ | |||
namespace megdnn { | |||
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; | |||
using FormatMode = param::ConvBias::Format; | |||
struct StrategyParam { | |||
size_t batch_id; | |||
@@ -28,6 +29,7 @@ struct StrategyParam { | |||
size_t block_m; | |||
size_t block_n; | |||
size_t block_k; | |||
size_t pack_oc_size; | |||
bool skip_copy_dst; | |||
bool is_dst_8bit; | |||
bool is_ohw_size_bigger; | |||
@@ -40,13 +42,15 @@ public: | |||
virtual void copy_padding_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) = 0; | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_size) = 0; | |||
virtual void packA_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) = 0; | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_size) = 0; | |||
virtual void exec_im2col( | |||
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
@@ -70,14 +74,16 @@ public: | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode, PackMode packmode> | |||
megdnn::PostprocessMode postprocess_mode, PackMode packmode, | |||
FormatMode format> | |||
class Strategy; | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::DEFAULT> : public StrategyBase { | |||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW> | |||
: public StrategyBase { | |||
public: | |||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | |||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | |||
@@ -85,24 +91,26 @@ public: | |||
constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1; | |||
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; | |||
Strategy(); | |||
Strategy() = default; | |||
void copy_padding_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||
void packA_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_size) override; | |||
void packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_size) override; | |||
virtual void exec_im2col( | |||
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
void exec_matmul( | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
@@ -132,7 +140,32 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::NO_PACK> : public StrategyBase { | |||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW44> | |||
: public Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::DEFAULT, | |||
FormatMode::NCHW> { | |||
public: | |||
const size_t BUNDLE_PADDING_INDEX = 0; | |||
const size_t BUNDLE_PACKA_INDEX = 1; | |||
const size_t THREAD_BUNDLE_PACKB_INDEX = 0; | |||
const size_t THREAD_BUNDLE_IM2COL_INDEX = 1; | |||
const size_t THREAD_BUNDLE_BIAS_INDEX = 2; | |||
Strategy() = default; | |||
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
}; | |||
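The class just declared is the core of the NCHW44 support in this header: it inherits the DEFAULT/NCHW strategy and overrides only exec_im2col, so padding, packA, matmul and postprocess code paths are reused. A minimal sketch of that specialization pattern, with stand-in names rather than the real megdnn types:

#include <cstdio>

enum class Fmt { NCHW, NCHW44 };

template <Fmt f>
struct Strat;

template <>
struct Strat<Fmt::NCHW> {
    virtual ~Strat() = default;
    virtual void exec_im2col() { std::puts("plain nchw im2col"); }
};

template <>
struct Strat<Fmt::NCHW44> : Strat<Fmt::NCHW> {
    void exec_im2col() override { std::puts("nchw44 im2col + pack_B"); }
};

int main() {
    Strat<Fmt::NCHW44> s;
    Strat<Fmt::NCHW>& base = s;
    base.exec_im2col();  // virtual dispatch picks the NCHW44 override
    return 0;
}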
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW> | |||
: public StrategyBase { | |||
public: | |||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | |||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | |||
@@ -141,19 +174,20 @@ public: | |||
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; | |||
constexpr static size_t THREAD_BUNDLE_MATCOMP_INDEX = 3; | |||
Strategy(); | |||
Strategy() = default; | |||
void copy_padding_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_size) override; | |||
void packA_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||
void packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_size) override; | |||
void exec_matmul( | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
@@ -197,7 +231,8 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::ONLY_PACKA> : public StrategyBase { | |||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW> | |||
: public StrategyBase { | |||
public: | |||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | |||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | |||
@@ -206,19 +241,20 @@ public: | |||
constexpr static size_t THREAD_BUNDLE_MATMULDST_INDEX = 2; | |||
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 3; | |||
Strategy(); | |||
Strategy() = default; | |||
void copy_padding_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||
void packA_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_size) override; | |||
void packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_size) override; | |||
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
const StrategyParam& sparam, | |||
@@ -8,8 +8,6 @@ | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megdnn/opr_param_defs.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | |||
#include "src/fallback/convolution/img2col_helper.h" | |||
#if MEGDNN_X86 | |||
@@ -25,19 +23,12 @@ namespace megdnn { | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::DEFAULT>::Strategy() | |||
: StrategyBase() {} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::DEFAULT>:: | |||
copy_padding_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
copy_padding_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_oc_size) { | |||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||
MEGDNN_MARK_USED_VAR(N); | |||
MEGDNN_MARK_USED_VAR(OC); | |||
@@ -53,9 +44,13 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
size_t batch_id = ncb_index.ndrange_id[0]; | |||
size_t group_id = ncb_index.ndrange_id[1]; | |||
size_t channel_id = ncb_index.ndrange_id[2]; | |||
size_t PH_SIZE = PH * IW2 * pack_oc_size; | |||
PW = PW * pack_oc_size; | |||
IW = IW * pack_oc_size; | |||
size_t padding_group_size = IH2 * IW2 * IC; | |||
size_t workspace_channel_offset = IH2 * IW2 * channel_id; | |||
size_t workspace_channel_offset = pack_oc_size * IH2 * IW2 * channel_id; | |||
size_t workspace_group_offset = group_id * padding_group_size; | |||
size_t workspace_batch_offset = | |||
param.filter_meta.group * batch_id * padding_group_size; | |||
@@ -65,8 +60,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | |||
} | |||
src_ctype* src = const_cast<src_ctype*>( | |||
param.src<src_ctype>(batch_id, group_id, channel_id)); | |||
src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>( | |||
batch_id, group_id, channel_id, 1, pack_oc_size)); | |||
src_ctype* src2; | |||
src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||
workspace_group_offset + workspace_batch_offset + | |||
@@ -74,8 +69,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
src_ctype* src2_ptr = src2; | |||
const src_ctype* src_ptr = src; | |||
if (PH != 0) { | |||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||
src2_ptr += PH * IW2; | |||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||
src2_ptr += PH_SIZE; | |||
} | |||
rep(ih, IH) { | |||
if (PW != 0) | |||
@@ -87,8 +82,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||
} | |||
if (PH != 0) { | |||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||
src2_ptr += PH * IW2; | |||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||
src2_ptr += PH_SIZE; | |||
} | |||
} | |||
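With pack_oc_size threaded in, the padded copy above works on rows of IW * pack elements and fills PH * IW2 * pack elements at the top and bottom. A minimal sketch of the same idea for one packed channel block, using a plain float buffer and a hypothetical helper name:

#include <cstring>
#include <cstddef>

// Pads one (IH, IW, pack) slab into a (IH + 2*PH, IW + 2*PW, pack) slab
// filled with zp; mirrors the loop structure of copy_padding_kern.
static void pad_packed_slab(const float* src, float* dst, size_t IH, size_t IW,
                            size_t PH, size_t PW, size_t pack, float zp) {
    size_t IW2 = IW + 2 * PW;
    float* out = dst;
    for (size_t i = 0; i < PH * IW2 * pack; ++i) *out++ = zp;  // top border
    for (size_t ih = 0; ih < IH; ++ih) {
        for (size_t i = 0; i < PW * pack; ++i) *out++ = zp;    // left border
        std::memcpy(out, src + ih * IW * pack, IW * pack * sizeof(float));
        out += IW * pack;
        for (size_t i = 0; i < PW * pack; ++i) *out++ = zp;    // right border
    }
    for (size_t i = 0; i < PH * IW2 * pack; ++i) *out++ = zp;  // bottom border
}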
@@ -96,12 +91,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::DEFAULT>:: | |||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_oc_size) { | |||
bundle.set(param.workspace_ptr); | |||
fallback::MatrixMulImpl::KernParam matmul_param; | |||
size_t group_id = ncb_index.ndrange_id[0]; | |||
@@ -114,38 +110,38 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
matmul_algo->get_packA_type_size(); | |||
size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size; | |||
int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) + | |||
group_id * packA_group_size + a_panel_offset; | |||
group_id * packA_group_size + | |||
(pack_oc_size == 4 ? 0 : a_panel_offset); | |||
matmul_param.A_ptr = | |||
const_cast<src_ctype*>(param.filter<src_ctype>(group_id)); | |||
matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1], | |||
matmul_algo->get_inner_block_size().m); | |||
matmul_algo->get_inner_block_size().m * pack_oc_size); | |||
} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::DEFAULT>:: | |||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo | |||
) { | |||
size_t m_sh = param.filter_meta.stride[0]; | |||
size_t m_sw = param.filter_meta.stride[1]; | |||
size_t m_oc = param.filter_meta.ocpg; | |||
size_t m_oh = param.osz[0]; | |||
size_t m_ow = param.osz[1]; | |||
size_t m_ic = param.filter_meta.icpg; | |||
size_t m_ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
size_t m_iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||
size_t m_fh = param.filter_meta.spatial[0]; | |||
size_t m_fw = param.filter_meta.spatial[1]; | |||
size_t m_is_xcorr = !param.filter_meta.should_flip; | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
size_t sh = param.filter_meta.stride[0]; | |||
size_t sw = param.filter_meta.stride[1]; | |||
size_t oc = param.filter_meta.ocpg; | |||
size_t oh = param.osz[0]; | |||
size_t ow = param.osz[1]; | |||
size_t ic = param.filter_meta.icpg; | |||
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||
size_t fh = param.filter_meta.spatial[0]; | |||
size_t fw = param.filter_meta.spatial[1]; | |||
size_t is_xcorr = !param.filter_meta.should_flip; | |||
size_t input_offset = | |||
m_ih * m_iw * m_ic * | |||
ih * iw * ic * | |||
(sparam.group_id + param.filter_meta.group * sparam.batch_id) * | |||
sizeof(src_ctype); | |||
@@ -160,26 +156,22 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
} | |||
src_ctype* im2col_dst = static_cast<src_ctype*>( | |||
bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | |||
if (m_sh == 1 && m_sw == 1) { | |||
if (m_is_xcorr) { | |||
img2col<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||
m_fh, m_fw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
if (sh == 1 && sw == 1) { | |||
if (is_xcorr) { | |||
img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||
sparam.ohw_cur_index, sparam.output_block_size); | |||
} else { | |||
img2col<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||
m_fh, m_fw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||
sparam.ohw_cur_index, sparam.output_block_size); | |||
} | |||
} else { | |||
if (m_is_xcorr) { | |||
img2col_stride<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, | |||
m_iw, m_fh, m_fw, m_sh, m_sw, | |||
sparam.ohw_cur_index, | |||
if (is_xcorr) { | |||
img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||
fw, sh, sw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
} else { | |||
img2col_stride<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, | |||
m_ih, m_iw, m_fh, m_fw, m_sh, m_sw, | |||
sparam.ohw_cur_index, | |||
img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||
fw, sh, sw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
} | |||
} | |||
@@ -199,7 +191,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::DEFAULT>:: | |||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const WorkspaceBundle& bundle_thread, | |||
const StrategyParam& sparam) { | |||
@@ -218,7 +210,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::DEFAULT>:: | |||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, WorkspaceBundle bundle, | |||
WorkspaceBundle bundle_thread, | |||
@@ -240,11 +232,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
src_ctype* b_panel = | |||
reinterpret_cast<src_ctype*>(reinterpret_cast<uintptr_t>( | |||
bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX))); | |||
size_t pack_oc_size = sparam.pack_oc_size; | |||
matmul_param.M = sparam.output_block_oc_size; | |||
matmul_param.N = sparam.output_block_size; | |||
matmul_param.LDB = sparam.output_block_size; | |||
matmul_param.LDC = sparam.output_block_size; | |||
matmul_param.LDB = pack_oc_size * sparam.output_block_size; | |||
matmul_param.LDC = pack_oc_size * sparam.output_block_size; | |||
matmul_param.C_ptr = matmul_dst; | |||
auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); | |||
@@ -255,7 +247,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::DEFAULT>:: | |||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, | |||
WorkspaceBundle bundle_thread) { | |||
@@ -274,7 +266,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||
matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode, | |||
param.nonlineMode, param.bias_type, param.dst_type, 1_z, | |||
sparam.output_block_oc_size, 1_z, sparam.output_block_size); | |||
sparam.output_block_oc_size, 1_z, sparam.output_block_size, | |||
sparam.pack_oc_size); | |||
copy_dst(param, matmul_dst, sparam); | |||
} | |||
@@ -282,20 +275,24 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::DEFAULT>:: | |||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const void* matmul_dst, const StrategyParam& sparam) { | |||
if (!sparam.skip_copy_dst) { | |||
size_t pack_oc_size = sparam.pack_oc_size; | |||
dst_ctype* dst_tmp_ptr = | |||
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); | |||
dst_ctype* dst = | |||
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + | |||
sparam.oc_cur_index * sparam.ohw + sparam.ohw_cur_index; | |||
for (size_t oc = 0; oc < sparam.output_block_oc_size; oc++) { | |||
sparam.oc_cur_index * sparam.ohw + | |||
sparam.ohw_cur_index * pack_oc_size; | |||
size_t oc_loop = sparam.output_block_oc_size / pack_oc_size; | |||
for (size_t oc = 0; oc < oc_loop; oc++) { | |||
std::memcpy(dst, dst_tmp_ptr, | |||
sizeof(dst_ctype) * sparam.output_block_size); | |||
dst_tmp_ptr += sparam.output_block_size; | |||
dst += sparam.ohw; | |||
sizeof(dst_ctype) * sparam.output_block_size * | |||
pack_oc_size); | |||
dst_tmp_ptr += sparam.output_block_size * pack_oc_size; | |||
dst += sparam.ohw * pack_oc_size; | |||
} | |||
} | |||
} | |||
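In NCHW44 every row group of the matmul result carries pack_oc_size interleaved channels, so copy_dst above copies output_block_size * 4 contiguous elements per packed-oc group and advances the destination by ohw * 4. A tiny standalone version of that copy with assumed sizes:

#include <cstring>
#include <vector>

int main() {
    // Assumed sizes, for illustration only; oc_block is a multiple of pack.
    size_t pack = 4, ohw = 64, ohw_block = 16, oc_block = 8;
    std::vector<float> tile(oc_block * ohw_block, 1.f);  // matmul output tile
    std::vector<float> dst(oc_block * ohw, 0.f);         // NCHW44 dst slice

    const float* src = tile.data();
    float* out = dst.data();  // assumed to already point at the tile origin
    for (size_t oc = 0; oc < oc_block / pack; ++oc) {
        std::memcpy(out, src, sizeof(float) * ohw_block * pack);
        src += ohw_block * pack;
        out += ohw * pack;
    }
    return 0;
}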
@@ -304,7 +301,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::DEFAULT>:: | |||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const WorkspaceBundle& bundle_thread) { | |||
bias_ctype* bias_tmp_ptr = | |||
@@ -319,7 +316,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::DEFAULT>:: | |||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | |||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||
@@ -340,31 +337,20 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
} | |||
} | |||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
_op_dtype, _postprocess_mode) \ | |||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
_op_dtype, _postprocess_mode, PackMode::DEFAULT>; | |||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
_op_dtype, _postprocess_mode) \ | |||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
_op_dtype, _postprocess_mode, PackMode::DEFAULT, \ | |||
FormatMode::NCHW>; | |||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | |||
megdnn::PostprocessMode::FLOAT) | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||
megdnn::PostprocessMode::FLOAT) | |||
#else | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
#endif | |||
#endif | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
//! x86 does not have uint8 matmul, so only armv7/armv8 support uint8 | |||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||
megdnn::PostprocessMode::QUANTIZED) | |||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
#endif | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||
megdnn::PostprocessMode::QUANTIZED) | |||
@@ -0,0 +1,118 @@ | |||
/** | |||
* \file dnn/src/fallback/conv_bias/im2col/strategy_default.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | |||
#include "src/fallback/convolution/img2col_helper.h" | |||
#if MEGDNN_X86 | |||
#include "src/x86/conv_bias/postprocess_helper.h" | |||
#endif | |||
using namespace megdnn; | |||
#if MEGDNN_X86 | |||
using namespace x86; | |||
#endif | |||
namespace megdnn { | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW44>:: | |||
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
size_t sh = param.filter_meta.stride[0]; | |||
size_t sw = param.filter_meta.stride[1]; | |||
size_t oc = param.filter_meta.ocpg; | |||
size_t oh = param.osz[0]; | |||
size_t ow = param.osz[1]; | |||
size_t ic = param.filter_meta.icpg; | |||
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||
size_t fh = param.filter_meta.spatial[0]; | |||
size_t fw = param.filter_meta.spatial[1]; | |||
size_t is_xcorr = !param.filter_meta.should_flip; | |||
constexpr static size_t pack_size = 4; | |||
size_t input_offset = | |||
ih * iw * ic * | |||
(sparam.group_id + param.filter_meta.group * sparam.batch_id) * | |||
sizeof(src_ctype); | |||
src_ctype* src2 = reinterpret_cast<src_ctype*>( | |||
reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||
input_offset); | |||
bool is_phpwzero = param.filter_meta.padding[0] == 0 && | |||
param.filter_meta.padding[1] == 0; | |||
if (is_phpwzero) { | |||
src2 = const_cast<src_ctype*>( | |||
param.src<src_ctype>(sparam.batch_id, sparam.group_id)); | |||
} | |||
src_ctype* im2col_dst = static_cast<src_ctype*>( | |||
bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | |||
if (is_xcorr) { | |||
if (sh == sw && sh == 1) { | |||
img2col_nchw4<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||
fw, sh, sw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
} else { | |||
img2col_stride_nchw4<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, | |||
fh, fw, sh, sw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
} | |||
} else { | |||
if (sh == sw && sh == 1) { | |||
img2col_nchw4<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||
fw, sh, sw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
} else { | |||
img2col_stride_nchw4<false>( | |||
src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, sh, sw, | |||
sparam.ohw_cur_index, sparam.output_block_size); | |||
} | |||
} | |||
matmul_param.M = sparam.output_block_oc_size; | |||
matmul_param.N = sparam.output_block_size; | |||
matmul_param.LDB = pack_size * sparam.output_block_size; | |||
matmul_param.LDC = pack_size * sparam.output_block_size; | |||
matmul_param.B_ptr = im2col_dst; | |||
src_ctype* b_panel = | |||
reinterpret_cast<src_ctype*>(reinterpret_cast<uintptr_t>( | |||
bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX))); | |||
matmul_algo->pack_B(matmul_param, b_panel, 0, matmul_param.N); | |||
} | |||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
_op_dtype, _postprocess_mode) \ | |||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
_op_dtype, _postprocess_mode, PackMode::DEFAULT, \ | |||
FormatMode::NCHW44>; | |||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | |||
megdnn::PostprocessMode::FLOAT) | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
#endif | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||
megdnn::PostprocessMode::QUANTIZED) | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
#undef INSTANTIAL_CLASS | |||
} // namespace megdnn |
@@ -9,8 +9,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megdnn/opr_param_defs.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | |||
#include "src/fallback/convolution/img2col_helper.h" | |||
#if MEGDNN_X86 | |||
@@ -22,22 +20,16 @@ using namespace megdnn; | |||
using namespace x86; | |||
#endif | |||
namespace megdnn { | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::NO_PACK>::Strategy() | |||
: StrategyBase() {} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::NO_PACK>:: | |||
copy_padding_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||
copy_padding_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t) { | |||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||
MEGDNN_MARK_USED_VAR(N); | |||
MEGDNN_MARK_USED_VAR(OC); | |||
@@ -96,12 +88,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::NO_PACK>:: | |||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||
packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t) { | |||
MEGDNN_MARK_USED_VAR(bundle); | |||
MEGDNN_MARK_USED_VAR(param); | |||
MEGDNN_MARK_USED_VAR(matmulparam); | |||
@@ -115,7 +108,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::NO_PACK>:: | |||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const WorkspaceBundle& bundle_thread, | |||
const StrategyParam& sparam) { | |||
@@ -134,7 +127,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::NO_PACK>:: | |||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, WorkspaceBundle bundle, | |||
WorkspaceBundle bundle_thread, | |||
@@ -167,29 +160,28 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::NO_PACK>:: | |||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo | |||
) { | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
MEGDNN_MARK_USED_VAR(matmul_param); | |||
MEGDNN_MARK_USED_VAR(matmul_algo); | |||
size_t m_sh = param.filter_meta.stride[0]; | |||
size_t m_sw = param.filter_meta.stride[1]; | |||
size_t m_oc = param.filter_meta.ocpg; | |||
size_t m_oh = param.osz[0]; | |||
size_t m_ow = param.osz[1]; | |||
size_t m_ic = param.filter_meta.icpg; | |||
size_t m_ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
size_t m_iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||
size_t m_fh = param.filter_meta.spatial[0]; | |||
size_t m_fw = param.filter_meta.spatial[1]; | |||
size_t m_is_xcorr = !param.filter_meta.should_flip; | |||
size_t sh = param.filter_meta.stride[0]; | |||
size_t sw = param.filter_meta.stride[1]; | |||
size_t oc = param.filter_meta.ocpg; | |||
size_t oh = param.osz[0]; | |||
size_t ow = param.osz[1]; | |||
size_t ic = param.filter_meta.icpg; | |||
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||
size_t fh = param.filter_meta.spatial[0]; | |||
size_t fw = param.filter_meta.spatial[1]; | |||
size_t is_xcorr = !param.filter_meta.should_flip; | |||
size_t input_offset = | |||
m_ih * m_iw * m_ic * | |||
ih * iw * ic * | |||
(sparam.group_id + param.filter_meta.group * sparam.batch_id) * | |||
sizeof(src_ctype); | |||
@@ -205,26 +197,22 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
} | |||
src_ctype* im2col_dst = static_cast<src_ctype*>( | |||
bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | |||
if (m_sh == 1 && m_sw == 1) { | |||
if (m_is_xcorr) { | |||
img2col<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||
m_fh, m_fw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
if (sh == 1 && sw == 1) { | |||
if (is_xcorr) { | |||
img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||
sparam.ohw_cur_index, sparam.output_block_size); | |||
} else { | |||
img2col<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||
m_fh, m_fw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||
sparam.ohw_cur_index, sparam.output_block_size); | |||
} | |||
} else { | |||
if (m_is_xcorr) { | |||
img2col_stride<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, | |||
m_iw, m_fh, m_fw, m_sh, m_sw, | |||
sparam.ohw_cur_index, | |||
if (is_xcorr) { | |||
img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||
fw, sh, sw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
} else { | |||
img2col_stride<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, | |||
m_ih, m_iw, m_fh, m_fw, m_sh, m_sw, | |||
sparam.ohw_cur_index, | |||
img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||
fw, sh, sw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
} | |||
} | |||
@@ -234,7 +222,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::NO_PACK>:: | |||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, | |||
WorkspaceBundle bundle_thread) { | |||
@@ -262,7 +250,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::NO_PACK>:: | |||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const void* matmul_dst, const StrategyParam& sparam) { | |||
if (!sparam.skip_copy_dst) { | |||
@@ -284,7 +272,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::NO_PACK>:: | |||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||
copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | |||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||
@@ -305,31 +293,20 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
} | |||
} | |||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
_op_dtype, _postprocess_mode) \ | |||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
_op_dtype, _postprocess_mode, PackMode::NO_PACK>; | |||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
_op_dtype, _postprocess_mode) \ | |||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
_op_dtype, _postprocess_mode, PackMode::NO_PACK, \ | |||
FormatMode::NCHW>; | |||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | |||
megdnn::PostprocessMode::FLOAT) | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||
megdnn::PostprocessMode::FLOAT) | |||
#else | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
#endif | |||
#endif | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
//! x86 does not have a uint8 matmul, so only armv7/armv8 support uint8 | |||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||
megdnn::PostprocessMode::QUANTIZED) | |||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
#endif | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||
megdnn::PostprocessMode::QUANTIZED) | |||
@@ -9,7 +9,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megdnn/opr_param_defs.h" | |||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | |||
#include "src/fallback/convolution/img2col_helper.h" | |||
#if MEGDNN_X86 | |||
@@ -21,22 +20,16 @@ using namespace megdnn; | |||
using namespace x86; | |||
#endif | |||
namespace megdnn { | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::ONLY_PACKA>::Strategy() | |||
: StrategyBase() {} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::ONLY_PACKA>:: | |||
copy_padding_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||
copy_padding_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t) { | |||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||
MEGDNN_MARK_USED_VAR(N); | |||
MEGDNN_MARK_USED_VAR(OC); | |||
@@ -95,12 +88,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::ONLY_PACKA>:: | |||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||
packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t) { | |||
bundle.set(param.workspace_ptr); | |||
fallback::MatrixMulImpl::KernParam matmul_param; | |||
static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) = | |||
@@ -128,7 +122,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::ONLY_PACKA>:: | |||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const WorkspaceBundle& bundle_thread, | |||
const StrategyParam& sparam) { | |||
@@ -147,7 +141,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::ONLY_PACKA>:: | |||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, WorkspaceBundle bundle, | |||
WorkspaceBundle bundle_thread, | |||
@@ -185,29 +179,28 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::ONLY_PACKA>:: | |||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo | |||
) { | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
MEGDNN_MARK_USED_VAR(matmul_param); | |||
MEGDNN_MARK_USED_VAR(matmul_algo); | |||
size_t m_sh = param.filter_meta.stride[0]; | |||
size_t m_sw = param.filter_meta.stride[1]; | |||
size_t m_oc = param.filter_meta.ocpg; | |||
size_t m_oh = param.osz[0]; | |||
size_t m_ow = param.osz[1]; | |||
size_t m_ic = param.filter_meta.icpg; | |||
size_t m_ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
size_t m_iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||
size_t m_fh = param.filter_meta.spatial[0]; | |||
size_t m_fw = param.filter_meta.spatial[1]; | |||
size_t m_is_xcorr = !param.filter_meta.should_flip; | |||
size_t sh = param.filter_meta.stride[0]; | |||
size_t sw = param.filter_meta.stride[1]; | |||
size_t oc = param.filter_meta.ocpg; | |||
size_t oh = param.osz[0]; | |||
size_t ow = param.osz[1]; | |||
size_t ic = param.filter_meta.icpg; | |||
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||
size_t fh = param.filter_meta.spatial[0]; | |||
size_t fw = param.filter_meta.spatial[1]; | |||
size_t is_xcorr = !param.filter_meta.should_flip; | |||
size_t input_offset = | |||
m_ih * m_iw * m_ic * | |||
ih * iw * ic * | |||
(sparam.group_id + param.filter_meta.group * sparam.batch_id) * | |||
sizeof(src_ctype); | |||
@@ -222,26 +215,22 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
} | |||
src_ctype* im2col_dst = static_cast<src_ctype*>( | |||
bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | |||
if (m_sh == 1 && m_sw == 1) { | |||
if (m_is_xcorr) { | |||
img2col<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||
m_fh, m_fw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
if (sh == 1 && sw == 1) { | |||
if (is_xcorr) { | |||
img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||
sparam.ohw_cur_index, sparam.output_block_size); | |||
} else { | |||
img2col<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||
m_fh, m_fw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||
sparam.ohw_cur_index, sparam.output_block_size); | |||
} | |||
} else { | |||
if (m_is_xcorr) { | |||
img2col_stride<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, | |||
m_iw, m_fh, m_fw, m_sh, m_sw, | |||
sparam.ohw_cur_index, | |||
if (is_xcorr) { | |||
img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||
fw, sh, sw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
} else { | |||
img2col_stride<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, | |||
m_ih, m_iw, m_fh, m_fw, m_sh, m_sw, | |||
sparam.ohw_cur_index, | |||
img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||
fw, sh, sw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
} | |||
} | |||
@@ -251,7 +240,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::ONLY_PACKA>:: | |||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, | |||
WorkspaceBundle bundle_thread) { | |||
@@ -292,7 +281,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode,PackMode::ONLY_PACKA>:: | |||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const void* matmul_dst, const StrategyParam& sparam) { | |||
if (!sparam.skip_copy_dst) { | |||
@@ -310,31 +299,20 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
} | |||
} | |||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
_op_dtype, _postprocess_mode) \ | |||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||
_op_ctype, _op_dtype, _postprocess_mode,PackMode::ONLY_PACKA>; | |||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
_op_dtype, _postprocess_mode) \ | |||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
_op_dtype, _postprocess_mode, \ | |||
PackMode::ONLY_PACKA, FormatMode::NCHW>; | |||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | |||
megdnn::PostprocessMode::FLOAT) | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||
megdnn::PostprocessMode::FLOAT) | |||
#else | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
#endif | |||
#endif | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
//! x86 does not have a uint8 matmul, so only armv7/armv8 support uint8 | |||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||
megdnn::PostprocessMode::QUANTIZED) | |||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
#endif | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||
megdnn::PostprocessMode::QUANTIZED) | |||
@@ -9,7 +9,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/common/utils.h" | |||
namespace { | |||
template <bool is_xcorr, typename dtype> | |||
@@ -41,7 +40,326 @@ void img2col_stride(const dtype* __restrict src, dtype* __restrict dst, | |||
} | |||
} | |||
//! added for multithreaded im2col matmul | |||
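//! note: the *_nchw4 helpers below assume NCHW44 input, i.e. element | |||
//! (ic, h, w, lane) sits at 4 * (ic * IH * IW + h * IW + w) + lane, with the | |||
//! 4 channel lanes of a pixel stored contiguously | |||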
template <bool is_xcorr, typename dtype> | |||
void img2col_stride_nchw4(const dtype* __restrict src, dtype* __restrict dst, | |||
const int OC, const int OH, const int OW, const int IC, | |||
const int IH, const int IW, const int FH, const int FW, | |||
const int SH, const int SW, const int cur_index, | |||
const int block_size) { | |||
MEGDNN_MARK_USED_VAR(OC); | |||
MEGDNN_MARK_USED_VAR(OH); | |||
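//! the output block [cur_index, cur_index + block_size) is split into a | |||
//! partial first row, full middle rows and a partial last row of the OHxOW | |||
//! output plane; same_line means the block stays within a single row | |||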
int start_h = cur_index / OW; | |||
int cur_remain_w = cur_index % OW; | |||
int end_h = (cur_index + block_size) / OW; | |||
int end_remain_w = (cur_index + block_size) % OW; | |||
bool same_line = false; | |||
if (start_h == end_h) { | |||
same_line = true; | |||
} | |||
size_t newIC = IC / 4; | |||
size_t i = 0; | |||
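//! element types wider than one byte are copied lane by lane; for 1-byte | |||
//! types the 4-lane group is moved as a single uint32_t load/store | |||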
if (sizeof(dtype) != 1) { | |||
if (same_line) { | |||
rep(ic, newIC) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2, fw2; | |||
if (is_xcorr) { | |||
fh2 = fh; | |||
fw2 = fw; | |||
} else { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
for (int w = cur_remain_w; w < end_remain_w; w++) { | |||
size_t index = 4 * (ic * IH * IW + | |||
(start_h * SH + fh2) * IW + | |||
(w * SW + fw2)); | |||
dst[i++] = src[index]; | |||
dst[i++] = src[index + 1]; | |||
dst[i++] = src[index + 2]; | |||
dst[i++] = src[index + 3]; | |||
} | |||
} | |||
} | |||
} | |||
} else { | |||
rep(ic, newIC) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2, fw2; | |||
if (is_xcorr) { | |||
fh2 = fh; | |||
fw2 = fw; | |||
} else { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
for (int w = cur_remain_w; w < OW; w++) { | |||
size_t index = 4 * (ic * IH * IW + | |||
(start_h * SH + fh2) * IW + | |||
(w * SW + fw2)); | |||
dst[i++] = src[index + 0]; | |||
dst[i++] = src[index + 1]; | |||
dst[i++] = src[index + 2]; | |||
dst[i++] = src[index + 3]; | |||
} | |||
for (int h = start_h + 1; h < end_h; h++) { | |||
rep(ow, OW) { | |||
size_t index = 4 * (ic * IH * IW + | |||
(h * SH + fh2) * IW + | |||
(ow * SW + fw2)); | |||
dst[i++] = src[index + 0]; | |||
dst[i++] = src[index + 1]; | |||
dst[i++] = src[index + 2]; | |||
dst[i++] = src[index + 3]; | |||
} | |||
} | |||
for (int w = 0; w < end_remain_w; w++) { | |||
size_t index = 4 * (ic * IH * IW + | |||
(end_h * SH + fh2) * IW + | |||
(w * SW + fw2)); | |||
dst[i++] = src[index + 0]; | |||
dst[i++] = src[index + 1]; | |||
dst[i++] = src[index + 2]; | |||
dst[i++] = src[index + 3]; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} else { | |||
uint32_t* output = nullptr; | |||
const uint32_t* uint32_src = | |||
static_cast<const uint32_t*>(static_cast<const void*>(src)); | |||
output = static_cast<uint32_t*>(static_cast<void*>(dst)); | |||
if (same_line) { | |||
rep(ic, newIC) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2, fw2; | |||
if (is_xcorr) { | |||
fh2 = fh; | |||
fw2 = fw; | |||
} else { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
size_t index = | |||
(ic * IH * IW + (start_h * SH + fh2) * IW + | |||
(cur_remain_w * SW + fw2)); | |||
for (int w = cur_remain_w; w < end_remain_w; w++) { | |||
output[i++] = uint32_src[index]; | |||
index += SW; | |||
} | |||
} | |||
} | |||
} | |||
} else { | |||
rep(ic, newIC) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2, fw2; | |||
if (is_xcorr) { | |||
fh2 = fh; | |||
fw2 = fw; | |||
} else { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
size_t index = ic * IH * IW + | |||
(start_h * SH + fh2) * IW + | |||
cur_remain_w * SW + fw2; | |||
for (int w = cur_remain_w; w < OW; w++) { | |||
output[i++] = uint32_src[index]; | |||
index += SW; | |||
} | |||
for (int h = start_h + 1; h < end_h; h++) { | |||
index = ic * IH * IW + (h * SH + fh2) * IW + fw2; | |||
rep(ow, OW) { | |||
output[i++] = uint32_src[index]; | |||
index += SW; | |||
} | |||
} | |||
index = ic * IH * IW + (end_h * SH + fh2) * IW + fw2; | |||
for (int w = 0; w < end_remain_w; w++) { | |||
output[i++] = uint32_src[index]; | |||
index += SW; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
template <bool is_xcorr, typename dtype> | |||
void img2col_nchw4(const dtype* __restrict src, dtype* __restrict dst, | |||
const int OC, const int OH, const int OW, const int IC, | |||
const int IH, const int IW, const int FH, const int FW, | |||
const int SH, const int SW, const int cur_index, | |||
const int block_size) { | |||
MEGDNN_MARK_USED_VAR(OC); | |||
MEGDNN_MARK_USED_VAR(OH); | |||
MEGDNN_MARK_USED_VAR(SH); | |||
MEGDNN_MARK_USED_VAR(SW); | |||
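//! stride-1 specialization: SH/SW are ignored and output coordinates map | |||
//! directly onto input rows and columns | |||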
int start_h = cur_index / OW; | |||
int cur_remain_w = cur_index % OW; | |||
int end_h = (cur_index + block_size) / OW; | |||
int end_remain_w = (cur_index + block_size) % OW; | |||
bool same_line = false; | |||
if (start_h == end_h) { | |||
same_line = true; | |||
} | |||
size_t newIC = IC / 4; | |||
size_t i = 0; | |||
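//! as in img2col_stride_nchw4: multi-byte types copy 4 lanes individually, | |||
//! 1-byte types reuse the packed uint32_t path | |||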
if (sizeof(dtype) != 1) { | |||
if (same_line) { | |||
rep(ic, newIC) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2, fw2; | |||
if (is_xcorr) { | |||
fh2 = fh; | |||
fw2 = fw; | |||
} else { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
for (int w = cur_remain_w; w < end_remain_w; w++) { | |||
size_t index = | |||
4 * (ic * IH * IW + (start_h + fh2) * IW + | |||
(w + fw2)); | |||
dst[i++] = src[index]; | |||
dst[i++] = src[index + 1]; | |||
dst[i++] = src[index + 2]; | |||
dst[i++] = src[index + 3]; | |||
} | |||
} | |||
} | |||
} | |||
} else { | |||
rep(ic, newIC) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2, fw2; | |||
if (is_xcorr) { | |||
fh2 = fh; | |||
fw2 = fw; | |||
} else { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
for (int w = cur_remain_w; w < OW; w++) { | |||
size_t index = ic * IH * IW + (start_h + fh2) * IW + | |||
(w + fw2); | |||
dst[i++] = src[4 * index]; | |||
dst[i++] = src[4 * index + 1]; | |||
dst[i++] = src[4 * index + 2]; | |||
dst[i++] = src[4 * index + 3]; | |||
} | |||
for (int h = start_h + 1; h < end_h; h++) { | |||
rep(ow, OW) { | |||
size_t index = | |||
4 * (ic * IH * IW + (h + fh2) * IW + | |||
(ow + fw2)); | |||
dst[i++] = src[index + 0]; | |||
dst[i++] = src[index + 1]; | |||
dst[i++] = src[index + 2]; | |||
dst[i++] = src[index + 3]; | |||
} | |||
} | |||
for (int w = 0; w < end_remain_w; w++) { | |||
size_t index = 4 * (ic * IH * IW + | |||
(end_h + fh2) * IW + (w + fw2)); | |||
dst[i++] = src[index + 0]; | |||
dst[i++] = src[index + 1]; | |||
dst[i++] = src[index + 2]; | |||
dst[i++] = src[index + 3]; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} else { | |||
uint32_t* output = nullptr; | |||
const uint32_t* uint32_src = | |||
static_cast<const uint32_t*>(static_cast<const void*>(src)); | |||
output = static_cast<uint32_t*>(static_cast<void*>(dst)); | |||
if (same_line) { | |||
rep(ic, newIC) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2, fw2; | |||
if (is_xcorr) { | |||
fh2 = fh; | |||
fw2 = fw; | |||
} else { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
for (int w = cur_remain_w; w < end_remain_w; w++) { | |||
size_t index = (ic * IH * IW + | |||
(start_h + fh2) * IW + (w + fw2)); | |||
output[i++] = uint32_src[index]; | |||
} | |||
} | |||
} | |||
} | |||
} else { | |||
rep(ic, newIC) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2, fw2; | |||
if (is_xcorr) { | |||
fh2 = fh; | |||
fw2 = fw; | |||
} else { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
for (int w = cur_remain_w; w < OW; w++) { | |||
size_t index = ic * IH * IW + (start_h + fh2) * IW + | |||
(w + fw2); | |||
output[i++] = uint32_src[index]; | |||
} | |||
for (int h = start_h + 1; h < end_h; h++) { | |||
rep(ow, OW) { | |||
size_t index = (ic * IH * IW + (h + fh2) * IW + | |||
(ow + fw2)); | |||
output[i++] = uint32_src[index]; | |||
} | |||
} | |||
for (int w = 0; w < end_remain_w; w++) { | |||
size_t index = (ic * IH * IW + (end_h + fh2) * IW + | |||
(w + fw2)); | |||
output[i++] = uint32_src[index]; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
template <bool is_xcorr, typename dtype> | |||
void img2col_stride(const dtype* __restrict src, dtype* __restrict dst, | |||
@@ -124,7 +124,8 @@ struct PostProcess { | |||
megdnn::ConvBiasForward::BiasMode bias_mode, | |||
megdnn::param::ConvBias::NonlineMode nonlineMode, | |||
DType bias_type, DType dst_type, size_t N, size_t OC, | |||
size_t OH, size_t OW) { | |||
size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
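//! pack_oc_size keeps the post-process interface uniform across formats; | |||
//! this specialization does not need it | |||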
megdnn::param::Elemwise::Mode elem_mode = | |||
megdnn::param::Elemwise::Mode::ADD; | |||
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||
@@ -154,7 +155,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::FLOAT> { | |||
megdnn::ConvBiasForward::BiasMode bias_mode, | |||
megdnn::param::ConvBias::NonlineMode nonlineMode, | |||
DType bias_type, DType dst_type, size_t N, size_t OC, | |||
size_t OH, size_t OW) { | |||
size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
megdnn::param::Elemwise::Mode elem_mode = | |||
megdnn::param::Elemwise::Mode::ADD; | |||
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||
@@ -185,7 +187,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||
megdnn::ConvBiasForward::BiasMode bias_mode, | |||
megdnn::param::ConvBias::NonlineMode nonlineMode, | |||
DType bias_type, DType dst_type, size_t N, size_t OC, | |||
size_t OH, size_t OW) { | |||
size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
MEGDNN_MARK_USED_VAR(conv_dst_ptr); | |||
MEGDNN_MARK_USED_VAR(bias_ptr); | |||
MEGDNN_MARK_USED_VAR(dst_ptr); | |||
@@ -292,7 +295,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> { | |||
megdnn::ConvBiasForward::BiasMode bias_mode, | |||
megdnn::param::ConvBiasV0::NonlineMode nonlineMode, | |||
DType bias_type, DType dst_type, size_t N, size_t OC, | |||
size_t OH, size_t OW) { | |||
size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
megdnn::param::Elemwise::Mode elem_mode = | |||
megdnn::param::Elemwise::Mode::ADD; | |||
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||