Browse Source

feat(dnn/fallback): add im2col none dot int8 nchw44 support

GitOrigin-RevId: d326035202
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
d345c86277
10 changed files with 834 additions and 461 deletions
  1. +57
    -38
      dnn/src/fallback/conv_bias/im2col/algos.cpp
  2. +0
    -1
      dnn/src/fallback/conv_bias/im2col/algos.h
  3. +103
    -165
      dnn/src/fallback/conv_bias/im2col/factory.h
  4. +72
    -36
      dnn/src/fallback/conv_bias/im2col/strategy_base.h
  5. +72
    -86
      dnn/src/fallback/conv_bias/im2col/strategy_default.cpp
  6. +118
    -0
      dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp
  7. +43
    -66
      dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp
  8. +42
    -64
      dnn/src/fallback/conv_bias/im2col/strategy_onlypacka.cpp
  9. +319
    -1
      dnn/src/fallback/convolution/img2col_helper.h
  10. +8
    -4
      dnn/src/x86/conv_bias/postprocess_helper.h

+ 57
- 38
dnn/src/fallback/conv_bias/im2col/algos.cpp View File

@@ -17,18 +17,14 @@
#include "src/fallback/conv_bias/opr_impl.h" #include "src/fallback/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/winograd/strategy.h" #include "src/fallback/conv_bias/winograd/strategy.h"
#include "src/naive/convolution/helper.h" #include "src/naive/convolution/helper.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#endif

#include "midout.h" #include "midout.h"

MIDOUT_DECL(megdnn_fallback_im2col) MIDOUT_DECL(megdnn_fallback_im2col)


using namespace megdnn; using namespace megdnn;
using namespace fallback; using namespace fallback;
using namespace im2col; using namespace im2col;
#if MEGDNN_X86
using namespace x86;
#endif


/*======================== AlgoIm2col=======================*/ /*======================== AlgoIm2col=======================*/
/*! /*!
@@ -47,8 +43,8 @@ using Pack_Mode=fallback::MatrixMulImpl::AlgoBase::PackMode;
static void copy_padding_kern(WorkspaceBundle bundle, static void copy_padding_kern(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& param, const ConvBiasImpl::NCBKernParam& param,
const ConvBiasImpl::NCBKernIndex& ncb_index, const ConvBiasImpl::NCBKernIndex& ncb_index,
StrategyBase* im2colstrategy) {
im2colstrategy->copy_padding_kern(bundle, param, ncb_index);
StrategyBase* im2colstrategy, size_t pack_oc_size) {
im2colstrategy->copy_padding_kern(bundle, param, ncb_index, pack_oc_size);
} }


//! packA_kern //! packA_kern
@@ -57,9 +53,9 @@ static void packA_kern(WorkspaceBundle bundle,
fallback::MatrixMulImpl::KernSizeParam matmulparam, fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
StrategyBase* im2colstrategy) {
StrategyBase* im2colstrategy, size_t pack_oc_size) {
im2colstrategy->packA_kern(bundle, param, matmulparam, matmul_algo, im2colstrategy->packA_kern(bundle, param, matmulparam, matmul_algo,
ncb_index);
ncb_index, pack_oc_size);
} }


/*! /*!
@@ -129,14 +125,17 @@ public:
size_t oc_tile_size) { size_t oc_tile_size) {
size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0], size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0],
FW = param.filter_meta.spatial[1]; FW = param.filter_meta.spatial[1];
size_t pack_oc_size = 1;
size_t im2col = 0, packb = 0, bias_temp = 0; size_t im2col = 0, packb = 0, bias_temp = 0;
bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT; bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT;
megdnn_assert(default_pack, "only support default packa"); megdnn_assert(default_pack, "only support default packa");
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) {
pack_oc_size = 4;
}
size_t im2col_dst_size = size_t im2col_dst_size =
IC * FH * FW * ohw_tile_size * sizeof(param.src_type); IC * FH * FW * ohw_tile_size * sizeof(param.src_type);
size_t matmul_dst_size =
oc_tile_size * ohw_tile_size * sizeof(param.bias_type);
size_t matmul_dst_size = pack_oc_size * oc_tile_size * ohw_tile_size *
sizeof(param.bias_type);
//! matmul_dst and im2col_dst use the same memory //! matmul_dst and im2col_dst use the same memory
WorkspaceBundle wb = matmul_algo->get_bundle(im2col_kern_param); WorkspaceBundle wb = matmul_algo->get_bundle(im2col_kern_param);
packb = wb.get_size(1); packb = wb.get_size(1);
@@ -318,17 +317,18 @@ public:
} }
}; };


#undef FILL_IM2COL_STRATEGY_PARAM

fallback::MatrixMulImpl::KernSizeParam fallback::MatrixMulImpl::KernSizeParam
ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param,
size_t ohw_tile_size, size_t ohw_tile_size,
size_t oc_tile_size) const { size_t oc_tile_size) const {
bool is_nchw44 =
param.filter_meta.format == param::ConvBias::Format::NCHW44;
size_t M = oc_tile_size; size_t M = oc_tile_size;
size_t N = ohw_tile_size; size_t N = ohw_tile_size;
size_t K = param.filter_meta.icpg * param.filter_meta.spatial[0] * size_t K = param.filter_meta.icpg * param.filter_meta.spatial[0] *
param.filter_meta.spatial[1]; param.filter_meta.spatial[1];
size_t LDA = K, LDB = N, LDC = N;
size_t pack_oc_size = is_nchw44 ? 4 : 1;
size_t LDA = pack_oc_size * K, LDB = pack_oc_size * N, LDC = N;
bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 && bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) || param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm && (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
@@ -345,7 +345,8 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param,
false, false,
false, false,
param::MatrixMul::ComputeMode::DEFAULT, param::MatrixMul::ComputeMode::DEFAULT,
param::MatrixMul::Format::DEFAULT};
is_nchw44 ? param::MatrixMul::Format::MK4
: param::MatrixMul::Format::DEFAULT};
} }


void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block(
@@ -405,6 +406,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
size_t GROUP = param.filter_meta.group; size_t GROUP = param.filter_meta.group;
bool need_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT; bool need_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT;
bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA; bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA;

if (need_pack || only_packA) { if (need_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size(); auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, inner_block.m, inner_block.n, need_pack); choice_ohw_oc_block(param, inner_block.m, inner_block.n, need_pack);
@@ -421,16 +423,19 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
need_pack); need_pack);
packa_group_size = 0; packa_group_size = 0;
} }

if (no_need_pading) { if (no_need_pading) {
padding = 0; //! not need padding padding = 0; //! not need padding
} else { } else {
padding = (GROUP * N * IC * IH2 * IW2) * padding = (GROUP * N * IC * IH2 * IW2) *
sizeof(param.src_type); //! for padding sizeof(param.src_type); //! for padding
} }

packa_size = GROUP * packa_group_size; //! for packA size = GROUP * a_size packa_size = GROUP * packa_group_size; //! for packA size = GROUP * a_size
WorkspaceBundle ws = {nullptr, {}}; WorkspaceBundle ws = {nullptr, {}};
auto im2col_kern_param = auto im2col_kern_param =
get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size); get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size);

if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) { if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) {
Im2colKerns<Pack_Mode::DEFAULT> defaultkern; Im2colKerns<Pack_Mode::DEFAULT> defaultkern;
ws = defaultkern.get_thread_bundle(param, im2col_kern_param, ws = defaultkern.get_thread_bundle(param, im2col_kern_param,
@@ -447,6 +452,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
m_matmul_algo, m_ohw_tile_size, m_matmul_algo, m_ohw_tile_size,
m_oc_tile_size); m_oc_tile_size);
} }

return {nullptr, return {nullptr,
{padding, packa_size, ws.total_size_in_bytes() * nr_threads}}; {padding, packa_size, ws.total_size_in_bytes() * nr_threads}};
} }
@@ -461,7 +467,7 @@ size_t ConvBiasImpl::AlgoIm2col::get_workspace(
} }


SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
ConvBiasImpl* opr, const NCBKernSizeParam& param) const {
ConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 1) { MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 1) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param); UNPACK_CONV_F32_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(SH); MEGDNN_MARK_USED_VAR(SH);
@@ -473,7 +479,6 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
size_t ohw = OH * OW; size_t ohw = OH * OW;
size_t ohw_parallel_times = div_ceil(ohw, m_ohw_tile_size); size_t ohw_parallel_times = div_ceil(ohw, m_ohw_tile_size);
size_t GROUP = param.filter_meta.group; size_t GROUP = param.filter_meta.group;

WorkspaceBundle bundle = get_bundle(param); WorkspaceBundle bundle = get_bundle(param);
WorkspaceBundle bundle_thread = {nullptr, {}}; WorkspaceBundle bundle_thread = {nullptr, {}};
size_t oc_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size); size_t oc_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size);
@@ -483,11 +488,14 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
bool no_pack = packmode == Pack_Mode::NO_PACK; bool no_pack = packmode == Pack_Mode::NO_PACK;
bool only_packA = packmode == Pack_Mode::ONLY_PACKA; bool only_packA = packmode == Pack_Mode::ONLY_PACKA;
size_t packa_parallel_times = 0; size_t packa_parallel_times = 0;
size_t pack_oc_size =
(param.filter_meta.format == param::ConvBias::Format::NCHW ? 1
: 4);
if (only_packA) { if (only_packA) {
packa_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size); packa_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size);
} else if (default_pack) { } else if (default_pack) {
packa_parallel_times = div_ceil<size_t>( packa_parallel_times = div_ceil<size_t>(
OC, m_matmul_algo->get_inner_block_size().m);
OC, m_matmul_algo->get_inner_block_size().m * pack_oc_size);
} }


auto matmul_param = get_matmul_kern_param( auto matmul_param = get_matmul_kern_param(
@@ -520,25 +528,29 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
strategyparam.skip_copy_dst = strategyparam.skip_copy_dst =
strategyparam.is_ohw_size_bigger && !strategyparam.is_dst_8bit; strategyparam.is_ohw_size_bigger && !strategyparam.is_dst_8bit;
strategyparam.oc_tile_size = m_oc_tile_size; strategyparam.oc_tile_size = m_oc_tile_size;
strategyparam.pack_oc_size = pack_oc_size;


SmallVector<ConvBiasImpl::NCBKern> ret_kern; SmallVector<ConvBiasImpl::NCBKern> ret_kern;
MIDOUT_BEGIN( MIDOUT_BEGIN(
megdnn_fallback_im2col, megdnn_fallback_im2col,
midout_iv("ConvBiasImpl::AlgoIm2col::dispatch_kerns"_hash)) { midout_iv("ConvBiasImpl::AlgoIm2col::dispatch_kerns"_hash)) {
StrategyBase* im2colstrategy = Factory::get_im2col_strategy(
param, m_matmul_algo, opr->param().format);
auto kern_padding = [bundle, im2colstrategy](
StrategyBase* im2colstrategy =
Factory::get_im2col_strategy(param, m_matmul_algo);
auto kern_padding = [bundle, im2colstrategy,
pack_oc_size = pack_oc_size](
const NCBKernParam& param, const NCBKernParam& param,
const NCBKernIndex& ncb_index) { const NCBKernIndex& ncb_index) {
copy_padding_kern(bundle, param, ncb_index, im2colstrategy);
copy_padding_kern(bundle, param, ncb_index, im2colstrategy,
pack_oc_size);
}; };


auto kern_packA = [bundle, matmul_algo = m_matmul_algo, auto kern_packA = [bundle, matmul_algo = m_matmul_algo,
matmul_param,
im2colstrategy](const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
matmul_param, im2colstrategy,
pack_oc_size = pack_oc_size](
const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
packA_kern(bundle, param, matmul_param, matmul_algo, ncb_index, packA_kern(bundle, param, matmul_param, matmul_algo, ncb_index,
im2colstrategy);
im2colstrategy, pack_oc_size);
}; };
if (default_pack) { if (default_pack) {
auto kern_compute_default = auto kern_compute_default =
@@ -556,7 +568,8 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}}); ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}});


if (need_padding) { if (need_padding) {
ret_kern.push_back({kern_padding, {param.n, GROUP, IC}});
ret_kern.push_back({kern_padding,
{param.n, GROUP, IC / pack_oc_size}});
} }
ret_kern.push_back( ret_kern.push_back(
{kern_compute_default, {kern_compute_default,
@@ -629,19 +642,25 @@ bool ConvBiasImpl::AlgoIm2col::usable(
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
return false; return false;
} }
//! current now im2col only support int8 quantized s8 nchw44
if (opr->param().format == param::ConvBias::Format::NCHW44 &&
(param.src_type.enumv() == param.filter_type.enumv() &&
(param.src_type.enumv() != DTypeEnum::Int8) &&
(param.src_type.enumv() != DTypeEnum::QuantizedS8))) {
return false;
}

fallback::MatrixMulImpl::KernSizeParam matmul_param = fallback::MatrixMulImpl::KernSizeParam matmul_param =
get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size); get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size);
bool matmulusable = m_matmul_algo->usable(matmul_param); bool matmulusable = m_matmul_algo->usable(matmul_param);
return matmulusable && return matmulusable &&
(opr->param().format == param::ConvBias::Format::NCHW) &&
((param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
(param.filter_meta.spatial[0] <= 7) &&
(param.filter_meta.spatial[0] >= 2)) ||
(param.filter_meta.spatial[0] != param.filter_meta.spatial[1] &&
(param.filter_meta.spatial[0] <= 7) &&
(param.filter_meta.spatial[0] >= 1) &&
(param.filter_meta.spatial[1] <= 7) &&
(param.filter_meta.spatial[1] >= 1))) &&
(opr->param().format == param::ConvBias::Format::NCHW ||
opr->param().format == param::ConvBias::Format::NCHW44) &&
(!(param.filter_meta.spatial[0] ==
param.filter_meta.spatial[1] &&
(param.filter_meta.spatial[0] == 1) &&
param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1)) &&
(param.filter_meta.dilation[0] == (param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] && param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) && param.filter_meta.dilation[0] == 1) &&


+ 0
- 1
dnn/src/fallback/conv_bias/im2col/algos.h View File

@@ -36,7 +36,6 @@ class ConvBiasImpl::AlgoIm2col final : public AlgoBase {
const NCBKernSizeParam& param, size_t ohw_tile_size, const NCBKernSizeParam& param, size_t ohw_tile_size,
size_t oc_tile_size) const; size_t oc_tile_size) const;
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const;
WorkspaceBundle get_thread_bundle(const NCBKernSizeParam& param) const;
void choice_ohw_oc_block(const NCBKernSizeParam& param, size_t block_m, void choice_ohw_oc_block(const NCBKernSizeParam& param, size_t block_m,
size_t block_n, bool pack_default) const; size_t block_n, bool pack_default) const;




+ 103
- 165
dnn/src/fallback/conv_bias/im2col/factory.h View File

@@ -23,19 +23,11 @@ namespace im2col {


enum class StrategyType : uint32_t { enum class StrategyType : uint32_t {
FLOAT = 0, FLOAT = 0,
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
FLOAT_FP16 = 1,
#else
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
FLOAT16_FLOAT16 = 2, FLOAT16_FLOAT16 = 2,
#endif #endif
#endif
INT8x8x32 = 3, INT8x8x32 = 3,
INT8x8x16 = 4, INT8x8x16 = 4,
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
QUINT8x8x32 = 5,
QUINT8x8x32x8 = 6,
#endif
QINT8x8x32 = 7, QINT8x8x32 = 7,
QINT8x8x32x8 = 8 QINT8x8x32x8 = 8
}; };
@@ -107,8 +99,7 @@ public:
~StrategyDelegationStorage() = default; ~StrategyDelegationStorage() = default;


template <typename Strategy> template <typename Strategy>
Strategy* get(param::ConvBias::Format format,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
Strategy* get(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernSizeParam& param, const fallback::ConvBiasImpl::NCBKernSizeParam& param,
StrategyType stype); StrategyType stype);
}; };
@@ -117,12 +108,10 @@ class Factory {
public: public:
static StrategyBase* get_im2col_strategy( static StrategyBase* get_im2col_strategy(
const fallback::ConvBiasImpl::NCBKernSizeParam& param, const fallback::ConvBiasImpl::NCBKernSizeParam& param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
param::ConvBias::Format format) {
fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
static StrategyDelegationStorage storage; static StrategyDelegationStorage storage;
StrategyType strategytype = get_strategy_type(param); StrategyType strategytype = get_strategy_type(param);
return storage.get<StrategyBase>(format, matmul_algo, param,
strategytype);
return storage.get<StrategyBase>(matmul_algo, param, strategytype);
} }


static StrategyType get_strategy_type( static StrategyType get_strategy_type(
@@ -141,13 +130,9 @@ public:
} }


cb1(dt_float32, dt_float32, StrategyType::FLOAT); cb1(dt_float32, dt_float32, StrategyType::FLOAT);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
cb1(dt_float16, __fp16, StrategyType::FLOAT_FP16);
#else
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
cb1(dt_float16, dt_float16, StrategyType::FLOAT16_FLOAT16); cb1(dt_float16, dt_float16, StrategyType::FLOAT16_FLOAT16);
#endif #endif
#endif


cb2(dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32, cb2(dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32,
StrategyType::INT8x8x32); StrategyType::INT8x8x32);
@@ -155,13 +140,6 @@ public:
cb2(dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, dt_int16, cb2(dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, dt_int16,
StrategyType::INT8x8x16); StrategyType::INT8x8x16);


#if MEGDNN_AARCH64 || MEGDNN_ARMV7
cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::QuantizedS32,
dt_uint8, dt_int32, dt_int32, StrategyType::QUINT8x8x32);

cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::Quantized8Asymm,
dt_uint8, dt_int32, dt_uint8, StrategyType::QUINT8x8x32x8);
#endif
cb2(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32, cb2(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32,
dt_int8, dt_int32, dt_int32, StrategyType::QINT8x8x32); dt_int8, dt_int32, dt_int32, StrategyType::QINT8x8x32);


@@ -172,98 +150,106 @@ public:
megdnn_throw("not support datatype in im2col strategy\n"); megdnn_throw("not support datatype in im2col strategy\n");
} }


#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \
midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \
return std::make_unique< \
Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \
_postprocess_mode, PackMode::_packmode>>(); \
} \
} \
MIDOUT_END(); \
#define cb1(_format, _packmode, _dt, _post_ctype, _postprocess_mode, \
_midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \
midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \
return std::make_unique< \
Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \
_postprocess_mode, PackMode::_packmode, \
FormatMode::_format>>(); \
} \
} \
MIDOUT_END(); \
return {}; return {};


#define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \
midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == param.src_type.enumv() && \
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \
return std::make_unique< \
Strategy<_src_ctype, _bias_ctype, _dst_ctype, \
DTypeTrait<_i_bias_type>::ctype, \
DTypeTrait<_i_dst_type>::ctype, \
_postprocess_mode, PackMode::_packmode>>(); \
} \
} \
MIDOUT_END(); \
#define cb2(_format, _packmode, _i_src_type, _i_bias_type, _i_dst_type, \
_src_ctype, _bias_ctype, _dst_ctype, _postprocess_mode, \
_midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \
midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == param.src_type.enumv() && \
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \
return std::make_unique<Strategy< \
_src_ctype, _bias_ctype, _dst_ctype, \
DTypeTrait<_i_bias_type>::ctype, \
DTypeTrait<_i_dst_type>::ctype, _postprocess_mode, \
PackMode::_packmode, FormatMode::_format>>(); \
} \
} \
MIDOUT_END(); \
return {}; return {};


static std::unique_ptr<StrategyBase> make_default_strategy( static std::unique_ptr<StrategyBase> make_default_strategy(
fallback::MatrixMulImpl::AlgoBase* matmul_algo, fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernSizeParam& param, const fallback::ConvBiasImpl::NCBKernSizeParam& param,
param::ConvBias::Format format, StrategyType strategytype) {
StrategyType strategytype) {
MEGDNN_MARK_USED_VAR(matmul_algo); MEGDNN_MARK_USED_VAR(matmul_algo);
MEGDNN_MARK_USED_VAR(format);
param::ConvBias::Format format = param.filter_meta.format;
switch (strategytype) { switch (strategytype) {
case StrategyType::FLOAT: case StrategyType::FLOAT:
cb1(DEFAULT, dt_float32, dt_float32, PostprocessMode::FLOAT,
"DefaultStrategyType::FLOAT"_hash);
break;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case StrategyType::FLOAT_FP16:
cb1(DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT,
"DefaultStrategyType::FLOAT_FP16"_hash);
cb1(NCHW, DEFAULT, dt_float32, dt_float32,
PostprocessMode::FLOAT, "DefaultStrategyType::FLOAT"_hash);
break; break;
#else
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
case StrategyType::FLOAT16_FLOAT16: case StrategyType::FLOAT16_FLOAT16:
cb1(DEFAULT, dt_float16, dt_float16,
cb1(NCHW, DEFAULT, dt_float16, dt_float16,
PostprocessMode::NO_PROCESS, PostprocessMode::NO_PROCESS,
"DefaultStrategyType::FLOAT16_FLOAT16"_hash); "DefaultStrategyType::FLOAT16_FLOAT16"_hash);
break; break;
#endif #endif
#endif
case StrategyType::INT8x8x32: case StrategyType::INT8x8x32:
cb2(DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32,
dt_int32, PostprocessMode::NO_PROCESS,
"DefaultStrategyType::INT8x8x32"_hash);
if (format == param::ConvBias::Format::NCHW) {
cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"DefaultStrategyType::INT8x8x32"_hash);
} else if (format == param::ConvBias::Format::NCHW44) {
cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"DefaultStrategyType::INT8x8x32"_hash);
} else {
megdnn_throw("not support format except nchw44 and nchw\n");
}

break; break;


case StrategyType::INT8x8x16: case StrategyType::INT8x8x16:
cb2(DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, dt_int16,
dt_int16, PostprocessMode::NO_PROCESS,
cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8,
dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
"DefaultStrategyType::INT8x8x16"_hash); "DefaultStrategyType::INT8x8x16"_hash);
break; break;
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
case StrategyType::QUINT8x8x32:
cb2(DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32,
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
"DefaultStrategyType::QUINT8x8x32"_hash);
break;

case StrategyType::QUINT8x8x32x8:
cb2(DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32,
dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8,
PostprocessMode::QUANTIZED,
"DefaultStrategyType::QUINT8x8x32x8"_hash);
break;
#endif
case StrategyType::QINT8x8x32: case StrategyType::QINT8x8x32:
cb2(DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32,
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
"DefaultStrategyType::QINT8x8x32"_hash);
if (format == param::ConvBias::Format::NCHW) {
cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32,
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
"DefaultStrategyTypeNCHW::QINT8x8x32"_hash);
} else if (format == param::ConvBias::Format::NCHW44) {
cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"DefaultStrategyTypeHCHW44::QINT8x8x32"_hash);
} else {
megdnn_throw("not support format except nchw44 and nchw\n");
}
break; break;


case StrategyType::QINT8x8x32x8: case StrategyType::QINT8x8x32x8:
cb2(DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32,
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
PostprocessMode::QUANTIZED,
"DefaultStrategyType::QINT8x8x32x8"_hash);
if (format == param::ConvBias::Format::NCHW) {
cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32,
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
PostprocessMode::QUANTIZED,
"DefaultStrategyType::QINT8x8x32x8"_hash);
} else if (format == param::ConvBias::Format::NCHW44) {
cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8,
dt_int32, dt_int8, PostprocessMode::QUANTIZED,
"DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash);
} else {
megdnn_throw("not support format except nchw44 and nchw\n");
}
break; break;
} }
megdnn_throw("error not support strategy type "); megdnn_throw("error not support strategy type ");
@@ -272,63 +258,41 @@ public:
static std::unique_ptr<StrategyBase> make_nopack_strategy( static std::unique_ptr<StrategyBase> make_nopack_strategy(
fallback::MatrixMulImpl::AlgoBase* matmul_algo, fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernSizeParam& param, const fallback::ConvBiasImpl::NCBKernSizeParam& param,
param::ConvBias::Format format, StrategyType strategytype) {
StrategyType strategytype) {
MEGDNN_MARK_USED_VAR(matmul_algo); MEGDNN_MARK_USED_VAR(matmul_algo);
MEGDNN_MARK_USED_VAR(format);
switch (strategytype) { switch (strategytype) {
case StrategyType::FLOAT: case StrategyType::FLOAT:
cb1(NO_PACK, dt_float32, dt_float32, PostprocessMode::FLOAT,
"NoPackStrategyType::FLOAT"_hash);
break;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case StrategyType::FLOAT_FP16:
cb1(NO_PACK, dt_float16, __fp16, PostprocessMode::FLOAT,
"NoPackStrategyType::FLOAT_FP16"_hash);
cb1(NCHW, NO_PACK, dt_float32, dt_float32,
PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash);
break; break;
#else
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
case StrategyType::FLOAT16_FLOAT16: case StrategyType::FLOAT16_FLOAT16:
cb1(NO_PACK, dt_float16, dt_float16, PostprocessMode::NO_PROCESS,
cb1(NCHW, NO_PACK, dt_float16, dt_float16,
PostprocessMode::NO_PROCESS,
"NoPackStrategyType::FLOAT16_FLOAT16"_hash); "NoPackStrategyType::FLOAT16_FLOAT16"_hash);
break; break;
#endif #endif
#endif
case StrategyType::INT8x8x32: case StrategyType::INT8x8x32:
cb2(NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32,
dt_int32, PostprocessMode::NO_PROCESS,
cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"NoPackStrategyType::INT8x8x32"_hash); "NoPackStrategyType::INT8x8x32"_hash);
break; break;


case StrategyType::INT8x8x16: case StrategyType::INT8x8x16:
cb2(NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, dt_int16,
dt_int16, PostprocessMode::NO_PROCESS,
cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8,
dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
"NoPackStrategyType::INT8x8x16"_hash); "NoPackStrategyType::INT8x8x16"_hash);
break; break;


#if MEGDNN_AARCH64 || MEGDNN_ARMV7
case StrategyType::QUINT8x8x32:
cb2(NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32,
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
"NoPackStrategyType::QUINT8x8x32"_hash);
break;

case StrategyType::QUINT8x8x32x8:
cb2(NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32,
dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8,
PostprocessMode::QUANTIZED,
"NoPackStrategyType::QUINT8x8x32x8"_hash);
break;
#endif
case StrategyType::QINT8x8x32: case StrategyType::QINT8x8x32:
cb2(NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32,
cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32,
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS, PostprocessMode::NO_PROCESS,
"NoPackStrategyType::QINT8x8x32"_hash); "NoPackStrategyType::QINT8x8x32"_hash);
break; break;


case StrategyType::QINT8x8x32x8: case StrategyType::QINT8x8x32x8:
cb2(NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32,
cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32,
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
PostprocessMode::QUANTIZED, PostprocessMode::QUANTIZED,
"NoPackStrategyType::QINT8x8x32x8"_hash); "NoPackStrategyType::QINT8x8x32x8"_hash);
@@ -340,64 +304,42 @@ public:
static std::unique_ptr<StrategyBase> make_onlypacka_strategy( static std::unique_ptr<StrategyBase> make_onlypacka_strategy(
fallback::MatrixMulImpl::AlgoBase* matmul_algo, fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernSizeParam& param, const fallback::ConvBiasImpl::NCBKernSizeParam& param,
param::ConvBias::Format format, StrategyType strategytype) {
StrategyType strategytype) {
MEGDNN_MARK_USED_VAR(matmul_algo); MEGDNN_MARK_USED_VAR(matmul_algo);
MEGDNN_MARK_USED_VAR(format);
switch (strategytype) { switch (strategytype) {
case StrategyType::FLOAT: case StrategyType::FLOAT:
cb1(ONLY_PACKA, dt_float32, dt_float32, PostprocessMode::FLOAT,
cb1(NCHW, ONLY_PACKA, dt_float32, dt_float32,
PostprocessMode::FLOAT,
"OnlyPackaStrategyType::FLOAT"_hash); "OnlyPackaStrategyType::FLOAT"_hash);
break; break;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case StrategyType::FLOAT_FP16:
cb1(ONLY_PACKA, dt_float16, __fp16, PostprocessMode::FLOAT,
"OnlyPackaStrategyType::FLOAT_FP16"_hash);
break;
#else
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
case StrategyType::FLOAT16_FLOAT16: case StrategyType::FLOAT16_FLOAT16:
cb1(ONLY_PACKA, dt_float16, dt_float16,
cb1(NCHW, ONLY_PACKA, dt_float16, dt_float16,
PostprocessMode::NO_PROCESS, PostprocessMode::NO_PROCESS,
"OnlyPackaStrategyType::FLOAT16_FLOAT16"_hash); "OnlyPackaStrategyType::FLOAT16_FLOAT16"_hash);
break; break;
#endif #endif
#endif
case StrategyType::INT8x8x32: case StrategyType::INT8x8x32:
cb2(ONLY_PACKA, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32,
dt_int32, PostprocessMode::NO_PROCESS,
cb2(NCHW, ONLY_PACKA, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"OnlyPackaStrategyType::INT8x8x32"_hash); "OnlyPackaStrategyType::INT8x8x32"_hash);
break; break;


case StrategyType::INT8x8x16: case StrategyType::INT8x8x16:
cb2(ONLY_PACKA, dt_int8, dt_int16, dt_int16, dt_int8, dt_int16,
dt_int16, PostprocessMode::NO_PROCESS,
cb2(NCHW, ONLY_PACKA, dt_int8, dt_int16, dt_int16, dt_int8,
dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
"OnlyPackaStrategyType::INT8x8x16"_hash); "OnlyPackaStrategyType::INT8x8x16"_hash);
break; break;


#if MEGDNN_AARCH64 || MEGDNN_ARMV7
case StrategyType::QUINT8x8x32:
cb2(ONLY_PACKA, dtype::Quantized8Asymm, dtype::QuantizedS32,
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
"OnlyPackaStrategyType::QUINT8x8x32"_hash);
break;

case StrategyType::QUINT8x8x32x8:
cb2(ONLY_PACKA, dtype::Quantized8Asymm, dtype::QuantizedS32,
dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8,
PostprocessMode::QUANTIZED,
"OnlyPackaStrategyType::QUINT8x8x32x8"_hash);
break;
#endif
case StrategyType::QINT8x8x32: case StrategyType::QINT8x8x32:
cb2(ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32,
cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32,
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS, PostprocessMode::NO_PROCESS,
"OnlyPackaStrategyType::QINT8x8x32"_hash); "OnlyPackaStrategyType::QINT8x8x32"_hash);
break; break;


case StrategyType::QINT8x8x32x8: case StrategyType::QINT8x8x32x8:
cb2(ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32,
cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32,
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
PostprocessMode::QUANTIZED, PostprocessMode::QUANTIZED,
"OnlyPackaStrategyType::QINT8x8x32x8"_hash); "OnlyPackaStrategyType::QINT8x8x32x8"_hash);
@@ -410,21 +352,19 @@ public:
#undef cb2 #undef cb2


static std::unique_ptr<StrategyBase> make_strategy( static std::unique_ptr<StrategyBase> make_strategy(
param::ConvBias::Format format,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, fallback::MatrixMulImpl::AlgoBase* matmul_algo,
fallback::MatrixMulImpl::AlgoBase::PackMode packmode, fallback::MatrixMulImpl::AlgoBase::PackMode packmode,
const fallback::ConvBiasImpl::NCBKernSizeParam& param, const fallback::ConvBiasImpl::NCBKernSizeParam& param,
StrategyType stype) { StrategyType stype) {
switch (packmode) { switch (packmode) {
case MatrixMulImpl::AlgoBase::PackMode::DEFAULT: case MatrixMulImpl::AlgoBase::PackMode::DEFAULT:
return make_default_strategy(matmul_algo, param, format, stype);
return make_default_strategy(matmul_algo, param, stype);
break; break;
case MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA: case MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA:
return make_onlypacka_strategy(matmul_algo, param, format,
stype);
return make_onlypacka_strategy(matmul_algo, param, stype);
break; break;
case MatrixMulImpl::AlgoBase::PackMode::NO_PACK: case MatrixMulImpl::AlgoBase::PackMode::NO_PACK:
return make_nopack_strategy(matmul_algo, param, format, stype);
return make_nopack_strategy(matmul_algo, param, stype);
break; break;
default: default:
megdnn_throw( megdnn_throw(
@@ -432,14 +372,12 @@ public:
"nopack"); "nopack");
break; break;
} }
megdnn_throw(
"factory make Strategy error please check your code");
megdnn_throw("factory make Strategy error please check your code");
} }
}; };


template <typename Strategy> template <typename Strategy>
Strategy* StrategyDelegationStorage::get( Strategy* StrategyDelegationStorage::get(
param::ConvBias::Format format,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernSizeParam& param, const fallback::ConvBiasImpl::NCBKernSizeParam& param,
StrategyType stype) { StrategyType stype) {
@@ -455,14 +393,14 @@ Strategy* StrategyDelegationStorage::get(
} }
StrategyHashParam sparam; StrategyHashParam sparam;
sparam.param = param; sparam.param = param;
sparam.format = format;
sparam.format = param.filter_meta.format;
sparam.packmode = packmode; sparam.packmode = packmode;
sparam.block_m = block_m; sparam.block_m = block_m;
sparam.block_n = block_n; sparam.block_n = block_n;
sparam.block_k = block_k; sparam.block_k = block_k;
if (map_strategys.find(sparam) == map_strategys.end()) { if (map_strategys.find(sparam) == map_strategys.end()) {
MEGDNN_LOCK_GUARD(m_mtx); MEGDNN_LOCK_GUARD(m_mtx);
auto strategy = Factory::make_strategy(format, matmul_algo, packmode,
auto strategy = Factory::make_strategy(matmul_algo, packmode,
param, stype); param, stype);
map_strategys[sparam] = std::move(strategy); map_strategys[sparam] = std::move(strategy);
} }


+ 72
- 36
dnn/src/fallback/conv_bias/im2col/strategy_base.h View File

@@ -14,6 +14,7 @@
namespace megdnn { namespace megdnn {


using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
using FormatMode = param::ConvBias::Format;


struct StrategyParam { struct StrategyParam {
size_t batch_id; size_t batch_id;
@@ -28,6 +29,7 @@ struct StrategyParam {
size_t block_m; size_t block_m;
size_t block_n; size_t block_n;
size_t block_k; size_t block_k;
size_t pack_oc_size;
bool skip_copy_dst; bool skip_copy_dst;
bool is_dst_8bit; bool is_dst_8bit;
bool is_ohw_size_bigger; bool is_ohw_size_bigger;
@@ -40,13 +42,15 @@ public:
virtual void copy_padding_kern( virtual void copy_padding_kern(
WorkspaceBundle bundle, WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) = 0;
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) = 0;
virtual void packA_kern( virtual void packA_kern(
WorkspaceBundle bundle, WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam, fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) = 0;
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) = 0;


virtual void exec_im2col( virtual void exec_im2col(
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
@@ -70,14 +74,16 @@ public:


template <typename src_ctype, typename bias_ctype, typename dst_ctype, template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode, PackMode packmode>
megdnn::PostprocessMode postprocess_mode, PackMode packmode,
FormatMode format>
class Strategy; class Strategy;


template <typename src_ctype, typename bias_ctype, typename dst_ctype, template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT> : public StrategyBase {
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>
: public StrategyBase {
public: public:
constexpr static size_t BUNDLE_PADDING_INDEX = 0; constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1; constexpr static size_t BUNDLE_PACKA_INDEX = 1;
@@ -85,24 +91,26 @@ public:
constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1; constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1;
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2;


Strategy();
Strategy() = default;


void copy_padding_kern( void copy_padding_kern(
WorkspaceBundle bundle, WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;

void packA_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;

void packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;
virtual void exec_im2col(
WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;


void exec_matmul( void exec_matmul(
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
@@ -132,7 +140,32 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::NO_PACK> : public StrategyBase {
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW44>
: public Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT,
FormatMode::NCHW> {
public:
const size_t BUNDLE_PADDING_INDEX = 0;
const size_t BUNDLE_PACKA_INDEX = 1;
const size_t THREAD_BUNDLE_PACKB_INDEX = 0;
const size_t THREAD_BUNDLE_IM2COL_INDEX = 1;
const size_t THREAD_BUNDLE_BIAS_INDEX = 2;

Strategy() = default;

void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
};

template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>
: public StrategyBase {
public: public:
constexpr static size_t BUNDLE_PADDING_INDEX = 0; constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1; constexpr static size_t BUNDLE_PACKA_INDEX = 1;
@@ -141,19 +174,20 @@ public:
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2;
constexpr static size_t THREAD_BUNDLE_MATCOMP_INDEX = 3; constexpr static size_t THREAD_BUNDLE_MATCOMP_INDEX = 3;


Strategy();
Strategy() = default;


void copy_padding_kern( void copy_padding_kern(
WorkspaceBundle bundle, WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;


void packA_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;
void packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;


void exec_matmul( void exec_matmul(
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
@@ -197,7 +231,8 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::ONLY_PACKA> : public StrategyBase {
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>
: public StrategyBase {
public: public:
constexpr static size_t BUNDLE_PADDING_INDEX = 0; constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1; constexpr static size_t BUNDLE_PACKA_INDEX = 1;
@@ -206,19 +241,20 @@ public:
constexpr static size_t THREAD_BUNDLE_MATMULDST_INDEX = 2; constexpr static size_t THREAD_BUNDLE_MATMULDST_INDEX = 2;
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 3; constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 3;


Strategy();
Strategy() = default;


void copy_padding_kern( void copy_padding_kern(
WorkspaceBundle bundle, WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;

void packA_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;

void packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;


void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam, const StrategyParam& sparam,


+ 72
- 86
dnn/src/fallback/conv_bias/im2col/strategy_default.cpp View File

@@ -8,8 +8,6 @@
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "megdnn/opr_param_defs.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/im2col/strategy_base.h" #include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h" #include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_X86 #if MEGDNN_X86
@@ -25,19 +23,12 @@ namespace megdnn {
template <typename src_ctype, typename bias_ctype, typename dst_ctype, template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::Strategy()
: StrategyBase() {}

template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
copy_padding_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
copy_padding_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_oc_size) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param); UNPACK_CONV_F32_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N); MEGDNN_MARK_USED_VAR(N);
MEGDNN_MARK_USED_VAR(OC); MEGDNN_MARK_USED_VAR(OC);
@@ -53,9 +44,13 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
size_t batch_id = ncb_index.ndrange_id[0]; size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1]; size_t group_id = ncb_index.ndrange_id[1];
size_t channel_id = ncb_index.ndrange_id[2]; size_t channel_id = ncb_index.ndrange_id[2];
size_t PH_SIZE = PH * IW2 * pack_oc_size;

PW = PW * pack_oc_size;
IW = IW * pack_oc_size;


size_t padding_group_size = IH2 * IW2 * IC; size_t padding_group_size = IH2 * IW2 * IC;
size_t workspace_channel_offset = IH2 * IW2 * channel_id;
size_t workspace_channel_offset = pack_oc_size * IH2 * IW2 * channel_id;
size_t workspace_group_offset = group_id * padding_group_size; size_t workspace_group_offset = group_id * padding_group_size;
size_t workspace_batch_offset = size_t workspace_batch_offset =
param.filter_meta.group * batch_id * padding_group_size; param.filter_meta.group * batch_id * padding_group_size;
@@ -65,8 +60,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
} }
src_ctype* src = const_cast<src_ctype*>(
param.src<src_ctype>(batch_id, group_id, channel_id));
src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>(
batch_id, group_id, channel_id, 1, pack_oc_size));
src_ctype* src2; src_ctype* src2;
src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) +
workspace_group_offset + workspace_batch_offset + workspace_group_offset + workspace_batch_offset +
@@ -74,8 +69,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
src_ctype* src2_ptr = src2; src_ctype* src2_ptr = src2;
const src_ctype* src_ptr = src; const src_ctype* src_ptr = src;
if (PH != 0) { if (PH != 0) {
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2);
src2_ptr += PH * IW2;
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE);
src2_ptr += PH_SIZE;
} }
rep(ih, IH) { rep(ih, IH) {
if (PW != 0) if (PW != 0)
@@ -87,8 +82,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
rep(pw, PW) * (src2_ptr++) = src_zp; rep(pw, PW) * (src2_ptr++) = src_zp;
} }
if (PH != 0) { if (PH != 0) {
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2);
src2_ptr += PH * IW2;
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE);
src2_ptr += PH_SIZE;
} }
} }


@@ -96,12 +91,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
packA_kern(WorkspaceBundle bundle, packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam, fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_oc_size) {
bundle.set(param.workspace_ptr); bundle.set(param.workspace_ptr);
fallback::MatrixMulImpl::KernParam matmul_param; fallback::MatrixMulImpl::KernParam matmul_param;
size_t group_id = ncb_index.ndrange_id[0]; size_t group_id = ncb_index.ndrange_id[0];
@@ -114,38 +110,38 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
matmul_algo->get_packA_type_size(); matmul_algo->get_packA_type_size();
size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size; size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size;
int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) + int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) +
group_id * packA_group_size + a_panel_offset;
group_id * packA_group_size +
(pack_oc_size == 4 ? 0 : a_panel_offset);
matmul_param.A_ptr = matmul_param.A_ptr =
const_cast<src_ctype*>(param.filter<src_ctype>(group_id)); const_cast<src_ctype*>(param.filter<src_ctype>(group_id));
matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1], matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1],
matmul_algo->get_inner_block_size().m);
matmul_algo->get_inner_block_size().m * pack_oc_size);
} }


template <typename src_ctype, typename bias_ctype, typename dst_ctype, template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo
) {
size_t m_sh = param.filter_meta.stride[0];
size_t m_sw = param.filter_meta.stride[1];
size_t m_oc = param.filter_meta.ocpg;
size_t m_oh = param.osz[0];
size_t m_ow = param.osz[1];
size_t m_ic = param.filter_meta.icpg;
size_t m_ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t m_iw = param.isz[1] + param.filter_meta.padding[1] * 2;
size_t m_fh = param.filter_meta.spatial[0];
size_t m_fw = param.filter_meta.spatial[1];
size_t m_is_xcorr = !param.filter_meta.should_flip;
fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
size_t sh = param.filter_meta.stride[0];
size_t sw = param.filter_meta.stride[1];
size_t oc = param.filter_meta.ocpg;
size_t oh = param.osz[0];
size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2;
size_t fh = param.filter_meta.spatial[0];
size_t fw = param.filter_meta.spatial[1];
size_t is_xcorr = !param.filter_meta.should_flip;


size_t input_offset = size_t input_offset =
m_ih * m_iw * m_ic *
ih * iw * ic *
(sparam.group_id + param.filter_meta.group * sparam.batch_id) * (sparam.group_id + param.filter_meta.group * sparam.batch_id) *
sizeof(src_ctype); sizeof(src_ctype);


@@ -160,26 +156,22 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
} }
src_ctype* im2col_dst = static_cast<src_ctype*>( src_ctype* im2col_dst = static_cast<src_ctype*>(
bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));
if (m_sh == 1 && m_sw == 1) {
if (m_is_xcorr) {
img2col<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw,
m_fh, m_fw, sparam.ohw_cur_index,
sparam.output_block_size);
if (sh == 1 && sw == 1) {
if (is_xcorr) {
img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw,
sparam.ohw_cur_index, sparam.output_block_size);
} else { } else {
img2col<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw,
m_fh, m_fw, sparam.ohw_cur_index,
sparam.output_block_size);
img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw,
sparam.ohw_cur_index, sparam.output_block_size);
} }
} else { } else {
if (m_is_xcorr) {
img2col_stride<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih,
m_iw, m_fh, m_fw, m_sh, m_sw,
sparam.ohw_cur_index,
if (is_xcorr) {
img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
fw, sh, sw, sparam.ohw_cur_index,
sparam.output_block_size); sparam.output_block_size);
} else { } else {
img2col_stride<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic,
m_ih, m_iw, m_fh, m_fw, m_sh, m_sw,
sparam.ohw_cur_index,
img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
fw, sh, sw, sparam.ohw_cur_index,
sparam.output_block_size); sparam.output_block_size);
} }
} }
@@ -199,7 +191,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread, const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam) { const StrategyParam& sparam) {
@@ -218,7 +210,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, WorkspaceBundle bundle, const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread, WorkspaceBundle bundle_thread,
@@ -240,11 +232,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
src_ctype* b_panel = src_ctype* b_panel =
reinterpret_cast<src_ctype*>(reinterpret_cast<uintptr_t>( reinterpret_cast<src_ctype*>(reinterpret_cast<uintptr_t>(
bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX))); bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX)));
size_t pack_oc_size = sparam.pack_oc_size;
matmul_param.M = sparam.output_block_oc_size; matmul_param.M = sparam.output_block_oc_size;
matmul_param.N = sparam.output_block_size; matmul_param.N = sparam.output_block_size;
matmul_param.LDB = sparam.output_block_size;
matmul_param.LDC = sparam.output_block_size;
matmul_param.LDB = pack_oc_size * sparam.output_block_size;
matmul_param.LDC = pack_oc_size * sparam.output_block_size;
matmul_param.C_ptr = matmul_dst; matmul_param.C_ptr = matmul_dst;


auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param);
@@ -255,7 +247,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, const StrategyParam& sparam,
WorkspaceBundle bundle_thread) { WorkspaceBundle bundle_thread) {
@@ -274,7 +266,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode, matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode,
param.nonlineMode, param.bias_type, param.dst_type, 1_z, param.nonlineMode, param.bias_type, param.dst_type, 1_z,
sparam.output_block_oc_size, 1_z, sparam.output_block_size);
sparam.output_block_oc_size, 1_z, sparam.output_block_size,
sparam.pack_oc_size);
copy_dst(param, matmul_dst, sparam); copy_dst(param, matmul_dst, sparam);
} }


@@ -282,20 +275,24 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
const void* matmul_dst, const StrategyParam& sparam) { const void* matmul_dst, const StrategyParam& sparam) {
if (!sparam.skip_copy_dst) { if (!sparam.skip_copy_dst) {
size_t pack_oc_size = sparam.pack_oc_size;
dst_ctype* dst_tmp_ptr = dst_ctype* dst_tmp_ptr =
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst));
dst_ctype* dst = dst_ctype* dst =
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) +
sparam.oc_cur_index * sparam.ohw + sparam.ohw_cur_index;
for (size_t oc = 0; oc < sparam.output_block_oc_size; oc++) {
sparam.oc_cur_index * sparam.ohw +
sparam.ohw_cur_index * pack_oc_size;
size_t oc_loop = sparam.output_block_oc_size / pack_oc_size;
for (size_t oc = 0; oc < oc_loop; oc++) {
std::memcpy(dst, dst_tmp_ptr, std::memcpy(dst, dst_tmp_ptr,
sizeof(dst_ctype) * sparam.output_block_size);
dst_tmp_ptr += sparam.output_block_size;
dst += sparam.ohw;
sizeof(dst_ctype) * sparam.output_block_size *
pack_oc_size);
dst_tmp_ptr += sparam.output_block_size * pack_oc_size;
dst += sparam.ohw * pack_oc_size;
} }
} }
} }
@@ -304,7 +301,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread) { const WorkspaceBundle& bundle_thread) {
bias_ctype* bias_tmp_ptr = bias_ctype* bias_tmp_ptr =
@@ -319,7 +316,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param,
WorkspaceBundle bundle_thread, const StrategyParam& sparam) { WorkspaceBundle bundle_thread, const StrategyParam& sparam) {
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
@@ -340,31 +337,20 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
} }
} }


#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode) \
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode, PackMode::DEFAULT>;
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode) \
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode, PackMode::DEFAULT, \
FormatMode::NCHW>;


INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
megdnn::PostprocessMode::FLOAT) megdnn::PostprocessMode::FLOAT)


#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16,
megdnn::PostprocessMode::FLOAT)
#else
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
megdnn::PostprocessMode::NO_PROCESS) megdnn::PostprocessMode::NO_PROCESS)
#endif #endif
#endif


#if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
#endif


INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
megdnn::PostprocessMode::QUANTIZED) megdnn::PostprocessMode::QUANTIZED)


+ 118
- 0
dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp View File

@@ -0,0 +1,118 @@
/**
 * \file dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#endif

using namespace megdnn;
#if MEGDNN_X86
using namespace x86;
#endif
namespace megdnn {

template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
              postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW44>::
        exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
                    const StrategyParam& sparam,
                    const fallback::ConvBiasImpl::NCBKernParam& param,
                    fallback::MatrixMulImpl::KernParam matmul_param,
                    fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
    //! Unfold one output block of the NCHW44-layout input into the im2col
    //! matrix, then pre-pack it as matrix B for the following matmul.
    //! Note: the input extents below already account for zero padding on
    //! both sides.
    const size_t stride_h = param.filter_meta.stride[0];
    const size_t stride_w = param.filter_meta.stride[1];
    const size_t out_c = param.filter_meta.ocpg;
    const size_t out_h = param.osz[0];
    const size_t out_w = param.osz[1];
    const size_t in_c = param.filter_meta.icpg;
    const size_t in_h = param.isz[0] + param.filter_meta.padding[0] * 2;
    const size_t in_w = param.isz[1] + param.filter_meta.padding[1] * 2;
    const size_t flt_h = param.filter_meta.spatial[0];
    const size_t flt_w = param.filter_meta.spatial[1];
    const bool is_xcorr = !param.filter_meta.should_flip;
    //! NCHW44 stores 4 channels interleaved in the innermost dimension.
    constexpr static size_t pack_size = 4;

    //! Byte offset of this (batch, group) slice inside the padded-source
    //! workspace.
    size_t input_offset =
            in_h * in_w * in_c *
            (sparam.group_id + param.filter_meta.group * sparam.batch_id) *
            sizeof(src_ctype);

    src_ctype* src2 = reinterpret_cast<src_ctype*>(
            reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PADDING_INDEX)) +
            input_offset);
    //! When there is no padding at all, read the original input directly
    //! instead of the copy-padded workspace.
    if (param.filter_meta.padding[0] == 0 &&
        param.filter_meta.padding[1] == 0) {
        src2 = const_cast<src_ctype*>(
                param.src<src_ctype>(sparam.batch_id, sparam.group_id));
    }
    src_ctype* im2col_dst = static_cast<src_ctype*>(
            bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));

    //! Dispatch on stride first, then on correlation vs. convolution; the
    //! boolean template parameter selects filter flipping at compile time.
    const bool unit_stride = (stride_h == stride_w && stride_h == 1);
    if (unit_stride) {
        if (is_xcorr) {
            img2col_nchw4<true>(src2, im2col_dst, out_c, out_h, out_w, in_c,
                                in_h, in_w, flt_h, flt_w, stride_h, stride_w,
                                sparam.ohw_cur_index,
                                sparam.output_block_size);
        } else {
            img2col_nchw4<false>(src2, im2col_dst, out_c, out_h, out_w, in_c,
                                 in_h, in_w, flt_h, flt_w, stride_h, stride_w,
                                 sparam.ohw_cur_index,
                                 sparam.output_block_size);
        }
    } else {
        if (is_xcorr) {
            img2col_stride_nchw4<true>(src2, im2col_dst, out_c, out_h, out_w,
                                       in_c, in_h, in_w, flt_h, flt_w,
                                       stride_h, stride_w,
                                       sparam.ohw_cur_index,
                                       sparam.output_block_size);
        } else {
            img2col_stride_nchw4<false>(src2, im2col_dst, out_c, out_h, out_w,
                                        in_c, in_h, in_w, flt_h, flt_w,
                                        stride_h, stride_w,
                                        sparam.ohw_cur_index,
                                        sparam.output_block_size);
        }
    }

    //! Leading dimensions are scaled by pack_size because each "column" of
    //! the im2col matrix carries 4 interleaved channels in NCHW44.
    matmul_param.M = sparam.output_block_oc_size;
    matmul_param.N = sparam.output_block_size;
    matmul_param.LDB = pack_size * sparam.output_block_size;
    matmul_param.LDC = pack_size * sparam.output_block_size;
    matmul_param.B_ptr = im2col_dst;

    src_ctype* b_panel =
            reinterpret_cast<src_ctype*>(reinterpret_cast<uintptr_t>(
                    bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX)));
    matmul_algo->pack_B(matmul_param, b_panel, 0, matmul_param.N);
}

//! Explicitly instantiate the NCHW44 DEFAULT-pack-mode Strategy for one
//! (source, bias, destination, op-ctype, op-dtype, postprocess) combination.
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype,  \
                         _op_dtype, _postprocess_mode)                    \
    template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
                            _op_dtype, _postprocess_mode, PackMode::DEFAULT, \
                            FormatMode::NCHW44>;

//! float32 with fused elemwise postprocess
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
                 megdnn::PostprocessMode::FLOAT)
#if !MEGDNN_DISABLE_FLOAT16
//! float16 path (no postprocess fused)
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
                 megdnn::PostprocessMode::NO_PROCESS)
#endif


//! quantized int8 -> int8 with requantization postprocess
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
                 megdnn::PostprocessMode::QUANTIZED)
//! int8 accumulating to plain int32, result returned as-is
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
                 megdnn::PostprocessMode::NO_PROCESS)
//! int8 accumulating to int16, result returned as-is
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16,
                 megdnn::PostprocessMode::NO_PROCESS)
//! int8 accumulating to quantized int32, result returned as-is
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32,
                 megdnn::PostprocessMode::NO_PROCESS)

#undef INSTANTIAL_CLASS
}  // namespace megdnn

+ 43
- 66
dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp View File

@@ -9,8 +9,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */


#include "megdnn/opr_param_defs.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/im2col/strategy_base.h" #include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h" #include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_X86 #if MEGDNN_X86
@@ -22,22 +20,16 @@ using namespace megdnn;
using namespace x86; using namespace x86;
#endif #endif
namespace megdnn { namespace megdnn {
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::Strategy()
: StrategyBase() {}


template <typename src_ctype, typename bias_ctype, typename dst_ctype, template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::
copy_padding_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
copy_padding_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param); UNPACK_CONV_F32_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N); MEGDNN_MARK_USED_VAR(N);
MEGDNN_MARK_USED_VAR(OC); MEGDNN_MARK_USED_VAR(OC);
@@ -96,12 +88,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
packA_kern(WorkspaceBundle bundle, packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam, fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t) {
MEGDNN_MARK_USED_VAR(bundle); MEGDNN_MARK_USED_VAR(bundle);
MEGDNN_MARK_USED_VAR(param); MEGDNN_MARK_USED_VAR(param);
MEGDNN_MARK_USED_VAR(matmulparam); MEGDNN_MARK_USED_VAR(matmulparam);
@@ -115,7 +108,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread, const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam) { const StrategyParam& sparam) {
@@ -134,7 +127,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, WorkspaceBundle bundle, const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread, WorkspaceBundle bundle_thread,
@@ -167,29 +160,28 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo
) {
fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
MEGDNN_MARK_USED_VAR(matmul_param); MEGDNN_MARK_USED_VAR(matmul_param);
MEGDNN_MARK_USED_VAR(matmul_algo); MEGDNN_MARK_USED_VAR(matmul_algo);
size_t m_sh = param.filter_meta.stride[0];
size_t m_sw = param.filter_meta.stride[1];
size_t m_oc = param.filter_meta.ocpg;
size_t m_oh = param.osz[0];
size_t m_ow = param.osz[1];
size_t m_ic = param.filter_meta.icpg;
size_t m_ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t m_iw = param.isz[1] + param.filter_meta.padding[1] * 2;
size_t m_fh = param.filter_meta.spatial[0];
size_t m_fw = param.filter_meta.spatial[1];
size_t m_is_xcorr = !param.filter_meta.should_flip;
size_t sh = param.filter_meta.stride[0];
size_t sw = param.filter_meta.stride[1];
size_t oc = param.filter_meta.ocpg;
size_t oh = param.osz[0];
size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2;
size_t fh = param.filter_meta.spatial[0];
size_t fw = param.filter_meta.spatial[1];
size_t is_xcorr = !param.filter_meta.should_flip;


size_t input_offset = size_t input_offset =
m_ih * m_iw * m_ic *
ih * iw * ic *
(sparam.group_id + param.filter_meta.group * sparam.batch_id) * (sparam.group_id + param.filter_meta.group * sparam.batch_id) *
sizeof(src_ctype); sizeof(src_ctype);


@@ -205,26 +197,22 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
} }
src_ctype* im2col_dst = static_cast<src_ctype*>( src_ctype* im2col_dst = static_cast<src_ctype*>(
bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));
if (m_sh == 1 && m_sw == 1) {
if (m_is_xcorr) {
img2col<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw,
m_fh, m_fw, sparam.ohw_cur_index,
sparam.output_block_size);
if (sh == 1 && sw == 1) {
if (is_xcorr) {
img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw,
sparam.ohw_cur_index, sparam.output_block_size);
} else { } else {
img2col<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw,
m_fh, m_fw, sparam.ohw_cur_index,
sparam.output_block_size);
img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw,
sparam.ohw_cur_index, sparam.output_block_size);
} }
} else { } else {
if (m_is_xcorr) {
img2col_stride<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih,
m_iw, m_fh, m_fw, m_sh, m_sw,
sparam.ohw_cur_index,
if (is_xcorr) {
img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
fw, sh, sw, sparam.ohw_cur_index,
sparam.output_block_size); sparam.output_block_size);
} else { } else {
img2col_stride<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic,
m_ih, m_iw, m_fh, m_fw, m_sh, m_sw,
sparam.ohw_cur_index,
img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
fw, sh, sw, sparam.ohw_cur_index,
sparam.output_block_size); sparam.output_block_size);
} }
} }
@@ -234,7 +222,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, const StrategyParam& sparam,
WorkspaceBundle bundle_thread) { WorkspaceBundle bundle_thread) {
@@ -262,7 +250,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
const void* matmul_dst, const StrategyParam& sparam) { const void* matmul_dst, const StrategyParam& sparam) {
if (!sparam.skip_copy_dst) { if (!sparam.skip_copy_dst) {
@@ -284,7 +272,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param,
WorkspaceBundle bundle_thread, const StrategyParam& sparam) { WorkspaceBundle bundle_thread, const StrategyParam& sparam) {
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
@@ -305,31 +293,20 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
} }
} }


#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode) \
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode, PackMode::NO_PACK>;
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode) \
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode, PackMode::NO_PACK, \
FormatMode::NCHW>;


INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
megdnn::PostprocessMode::FLOAT) megdnn::PostprocessMode::FLOAT)


#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16,
megdnn::PostprocessMode::FLOAT)
#else
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
megdnn::PostprocessMode::NO_PROCESS) megdnn::PostprocessMode::NO_PROCESS)
#endif #endif
#endif


#if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
#endif


INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
megdnn::PostprocessMode::QUANTIZED) megdnn::PostprocessMode::QUANTIZED)


+ 42
- 64
dnn/src/fallback/conv_bias/im2col/strategy_onlypacka.cpp View File

@@ -9,7 +9,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */


#include "megdnn/opr_param_defs.h"
#include "src/fallback/conv_bias/im2col/strategy_base.h" #include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h" #include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_X86 #if MEGDNN_X86
@@ -21,22 +20,16 @@ using namespace megdnn;
using namespace x86; using namespace x86;
#endif #endif
namespace megdnn { namespace megdnn {
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::ONLY_PACKA>::Strategy()
: StrategyBase() {}


template <typename src_ctype, typename bias_ctype, typename dst_ctype, template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::ONLY_PACKA>::
copy_padding_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
copy_padding_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param); UNPACK_CONV_F32_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N); MEGDNN_MARK_USED_VAR(N);
MEGDNN_MARK_USED_VAR(OC); MEGDNN_MARK_USED_VAR(OC);
@@ -95,12 +88,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::ONLY_PACKA>::
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
packA_kern(WorkspaceBundle bundle, packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam, fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t) {
bundle.set(param.workspace_ptr); bundle.set(param.workspace_ptr);
fallback::MatrixMulImpl::KernParam matmul_param; fallback::MatrixMulImpl::KernParam matmul_param;
static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) = static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) =
@@ -128,7 +122,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::ONLY_PACKA>::
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread, const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam) { const StrategyParam& sparam) {
@@ -147,7 +141,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::ONLY_PACKA>::
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, WorkspaceBundle bundle, const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread, WorkspaceBundle bundle_thread,
@@ -185,29 +179,28 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::ONLY_PACKA>::
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo
) {
fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
MEGDNN_MARK_USED_VAR(matmul_param); MEGDNN_MARK_USED_VAR(matmul_param);
MEGDNN_MARK_USED_VAR(matmul_algo); MEGDNN_MARK_USED_VAR(matmul_algo);
size_t m_sh = param.filter_meta.stride[0];
size_t m_sw = param.filter_meta.stride[1];
size_t m_oc = param.filter_meta.ocpg;
size_t m_oh = param.osz[0];
size_t m_ow = param.osz[1];
size_t m_ic = param.filter_meta.icpg;
size_t m_ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t m_iw = param.isz[1] + param.filter_meta.padding[1] * 2;
size_t m_fh = param.filter_meta.spatial[0];
size_t m_fw = param.filter_meta.spatial[1];
size_t m_is_xcorr = !param.filter_meta.should_flip;
size_t sh = param.filter_meta.stride[0];
size_t sw = param.filter_meta.stride[1];
size_t oc = param.filter_meta.ocpg;
size_t oh = param.osz[0];
size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2;
size_t fh = param.filter_meta.spatial[0];
size_t fw = param.filter_meta.spatial[1];
size_t is_xcorr = !param.filter_meta.should_flip;


size_t input_offset = size_t input_offset =
m_ih * m_iw * m_ic *
ih * iw * ic *
(sparam.group_id + param.filter_meta.group * sparam.batch_id) * (sparam.group_id + param.filter_meta.group * sparam.batch_id) *
sizeof(src_ctype); sizeof(src_ctype);


@@ -222,26 +215,22 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
} }
src_ctype* im2col_dst = static_cast<src_ctype*>( src_ctype* im2col_dst = static_cast<src_ctype*>(
bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));
if (m_sh == 1 && m_sw == 1) {
if (m_is_xcorr) {
img2col<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw,
m_fh, m_fw, sparam.ohw_cur_index,
sparam.output_block_size);
if (sh == 1 && sw == 1) {
if (is_xcorr) {
img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw,
sparam.ohw_cur_index, sparam.output_block_size);
} else { } else {
img2col<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw,
m_fh, m_fw, sparam.ohw_cur_index,
sparam.output_block_size);
img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw,
sparam.ohw_cur_index, sparam.output_block_size);
} }
} else { } else {
if (m_is_xcorr) {
img2col_stride<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih,
m_iw, m_fh, m_fw, m_sh, m_sw,
sparam.ohw_cur_index,
if (is_xcorr) {
img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
fw, sh, sw, sparam.ohw_cur_index,
sparam.output_block_size); sparam.output_block_size);
} else { } else {
img2col_stride<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic,
m_ih, m_iw, m_fh, m_fw, m_sh, m_sw,
sparam.ohw_cur_index,
img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
fw, sh, sw, sparam.ohw_cur_index,
sparam.output_block_size); sparam.output_block_size);
} }
} }
@@ -251,7 +240,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::ONLY_PACKA>::
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, const StrategyParam& sparam,
WorkspaceBundle bundle_thread) { WorkspaceBundle bundle_thread) {
@@ -292,7 +281,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::ONLY_PACKA>::
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
const void* matmul_dst, const StrategyParam& sparam) { const void* matmul_dst, const StrategyParam& sparam) {
if (!sparam.skip_copy_dst) { if (!sparam.skip_copy_dst) {
@@ -310,31 +299,20 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
} }
} }


#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode) \
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, \
_op_ctype, _op_dtype, _postprocess_mode,PackMode::ONLY_PACKA>;
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode) \
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode, \
PackMode::ONLY_PACKA, FormatMode::NCHW>;


INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
megdnn::PostprocessMode::FLOAT) megdnn::PostprocessMode::FLOAT)


#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16,
megdnn::PostprocessMode::FLOAT)
#else
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
megdnn::PostprocessMode::NO_PROCESS) megdnn::PostprocessMode::NO_PROCESS)
#endif #endif
#endif


#if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
#endif


INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
megdnn::PostprocessMode::QUANTIZED) megdnn::PostprocessMode::QUANTIZED)


+ 319
- 1
dnn/src/fallback/convolution/img2col_helper.h View File

@@ -9,7 +9,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "src/common/utils.h" #include "src/common/utils.h"

namespace { namespace {


template <bool is_xcorr, typename dtype> template <bool is_xcorr, typename dtype>
@@ -41,7 +40,326 @@ void img2col_stride(const dtype* __restrict src, dtype* __restrict dst,
} }
} }



//!add for im2col matmul multithread //!add for im2col matmul multithread
//
template <bool is_xcorr, typename dtype>
void img2col_stride_nchw4(const dtype* __restrict src, dtype* __restrict dst,
                          const int OC, const int OH, const int OW, const int IC,
                          const int IH, const int IW, const int FH, const int FW,
                          const int SH, const int SW, const int cur_index,
                          const int block_size) {
    //! im2col for the NCHW44 layout with arbitrary stride: channels are packed
    //! in groups of 4, so each (ic, oh, ow) sample moves 4 consecutive values.
    //! Only output pixels [cur_index, cur_index + block_size) are produced,
    //! which lets multiple threads each fill one slice of the matrix.
    (void)OC;
    (void)OH;
    //! decompose the flat output window into (row, column) coordinates of the
    //! OH x OW output plane
    const int head_h = cur_index / OW;
    const int head_w = cur_index % OW;
    const int tail_h = (cur_index + block_size) / OW;
    const int tail_w = (cur_index + block_size) % OW;
    const bool one_line = (head_h == tail_h);

    const size_t group_cnt = IC / 4;  //! number of packed 4-channel groups
    size_t od = 0;                    //! running write position in dst

    if (sizeof(dtype) != 1) {
        //! generic path: copy the 4 packed channels element by element
        for (size_t g = 0; g < group_cnt; ++g) {
            for (int kh = 0; kh < FH; ++kh) {
                for (int kw = 0; kw < FW; ++kw) {
                    //! convolution flips the filter; cross-correlation does not
                    const int fh2 = is_xcorr ? kh : FH - kh - 1;
                    const int fw2 = is_xcorr ? kw : FW - kw - 1;
                    auto copy_quad = [&](int oh, int ow) {
                        const size_t is =
                                4 * (g * IH * IW + (oh * SH + fh2) * IW +
                                     (ow * SW + fw2));
                        dst[od++] = src[is + 0];
                        dst[od++] = src[is + 1];
                        dst[od++] = src[is + 2];
                        dst[od++] = src[is + 3];
                    };
                    if (one_line) {
                        for (int w = head_w; w < tail_w; ++w)
                            copy_quad(head_h, w);
                    } else {
                        //! partial first row, full middle rows, partial last row
                        for (int w = head_w; w < OW; ++w)
                            copy_quad(head_h, w);
                        for (int h = head_h + 1; h < tail_h; ++h)
                            for (int w = 0; w < OW; ++w)
                                copy_quad(h, w);
                        for (int w = 0; w < tail_w; ++w)
                            copy_quad(tail_h, w);
                    }
                }
            }
        }
    } else {
        //! 1-byte path: a packed group of 4 channels is exactly 32 bits wide,
        //! so whole groups are moved as uint32_t words
        const uint32_t* src32 =
                static_cast<const uint32_t*>(static_cast<const void*>(src));
        uint32_t* dst32 = static_cast<uint32_t*>(static_cast<void*>(dst));
        for (size_t g = 0; g < group_cnt; ++g) {
            for (int kh = 0; kh < FH; ++kh) {
                for (int kw = 0; kw < FW; ++kw) {
                    const int fh2 = is_xcorr ? kh : FH - kh - 1;
                    const int fw2 = is_xcorr ? kw : FW - kw - 1;
                    //! copy one output row segment; the source index advances
                    //! by SW uint32 words per output column
                    auto copy_row = [&](int oh, int w_begin, int w_end) {
                        size_t is = g * IH * IW + (oh * SH + fh2) * IW +
                                    w_begin * SW + fw2;
                        for (int w = w_begin; w < w_end; ++w) {
                            dst32[od++] = src32[is];
                            is += SW;
                        }
                    };
                    if (one_line) {
                        copy_row(head_h, head_w, tail_w);
                    } else {
                        copy_row(head_h, head_w, OW);
                        for (int h = head_h + 1; h < tail_h; ++h)
                            copy_row(h, 0, OW);
                        copy_row(tail_h, 0, tail_w);
                    }
                }
            }
        }
    }
}

template <bool is_xcorr, typename dtype>
void img2col_nchw4(const dtype* __restrict src, dtype* __restrict dst,
                   const int OC, const int OH, const int OW, const int IC,
                   const int IH, const int IW, const int FH, const int FW,
                   const int SH, const int SW, const int cur_index,
                   const int block_size) {
    //! im2col for the NCHW44 layout specialised to unit stride, so SH/SW are
    //! accepted only for signature symmetry with the strided variant. Channels
    //! are packed in groups of 4; only output pixels
    //! [cur_index, cur_index + block_size) are produced, allowing each thread
    //! to fill its own slice of the matrix.
    (void)OC;
    (void)OH;
    (void)SH;
    (void)SW;
    //! decompose the flat output window into (row, column) coordinates of the
    //! OH x OW output plane
    const int head_h = cur_index / OW;
    const int head_w = cur_index % OW;
    const int tail_h = (cur_index + block_size) / OW;
    const int tail_w = (cur_index + block_size) % OW;
    const bool one_line = (head_h == tail_h);

    const size_t group_cnt = IC / 4;  //! number of packed 4-channel groups
    size_t od = 0;                    //! running write position in dst

    if (sizeof(dtype) != 1) {
        //! generic path: copy the 4 packed channels element by element
        for (size_t g = 0; g < group_cnt; ++g) {
            for (int kh = 0; kh < FH; ++kh) {
                for (int kw = 0; kw < FW; ++kw) {
                    //! convolution flips the filter; cross-correlation does not
                    const int fh2 = is_xcorr ? kh : FH - kh - 1;
                    const int fw2 = is_xcorr ? kw : FW - kw - 1;
                    auto copy_quad = [&](int oh, int ow) {
                        const size_t is = 4 * (g * IH * IW + (oh + fh2) * IW +
                                               (ow + fw2));
                        dst[od++] = src[is + 0];
                        dst[od++] = src[is + 1];
                        dst[od++] = src[is + 2];
                        dst[od++] = src[is + 3];
                    };
                    if (one_line) {
                        for (int w = head_w; w < tail_w; ++w)
                            copy_quad(head_h, w);
                    } else {
                        //! partial first row, full middle rows, partial last row
                        for (int w = head_w; w < OW; ++w)
                            copy_quad(head_h, w);
                        for (int h = head_h + 1; h < tail_h; ++h)
                            for (int w = 0; w < OW; ++w)
                                copy_quad(h, w);
                        for (int w = 0; w < tail_w; ++w)
                            copy_quad(tail_h, w);
                    }
                }
            }
        }
    } else {
        //! 1-byte path: a packed group of 4 channels is exactly 32 bits wide,
        //! so whole groups are moved as uint32_t words
        const uint32_t* src32 =
                static_cast<const uint32_t*>(static_cast<const void*>(src));
        uint32_t* dst32 = static_cast<uint32_t*>(static_cast<void*>(dst));
        for (size_t g = 0; g < group_cnt; ++g) {
            for (int kh = 0; kh < FH; ++kh) {
                for (int kw = 0; kw < FW; ++kw) {
                    const int fh2 = is_xcorr ? kh : FH - kh - 1;
                    const int fw2 = is_xcorr ? kw : FW - kw - 1;
                    //! copy one output row segment; with unit stride the
                    //! source index advances by one uint32 word per column
                    auto copy_row = [&](int oh, int w_begin, int w_end) {
                        size_t is = g * IH * IW + (oh + fh2) * IW + w_begin +
                                    fw2;
                        for (int w = w_begin; w < w_end; ++w) {
                            dst32[od++] = src32[is];
                            ++is;
                        }
                    };
                    if (one_line) {
                        copy_row(head_h, head_w, tail_w);
                    } else {
                        copy_row(head_h, head_w, OW);
                        for (int h = head_h + 1; h < tail_h; ++h)
                            copy_row(h, 0, OW);
                        copy_row(tail_h, 0, tail_w);
                    }
                }
            }
        }
    }
}


template <bool is_xcorr, typename dtype> template <bool is_xcorr, typename dtype>
void img2col_stride(const dtype* __restrict src, dtype* __restrict dst, void img2col_stride(const dtype* __restrict src, dtype* __restrict dst,


+ 8
- 4
dnn/src/x86/conv_bias/postprocess_helper.h View File

@@ -124,7 +124,8 @@ struct PostProcess {
megdnn::ConvBiasForward::BiasMode bias_mode, megdnn::ConvBiasForward::BiasMode bias_mode,
megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::param::ConvBias::NonlineMode nonlineMode,
DType bias_type, DType dst_type, size_t N, size_t OC, DType bias_type, DType dst_type, size_t N, size_t OC,
size_t OH, size_t OW) {
size_t OH, size_t OW, size_t pack_oc_size = 1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
megdnn::param::Elemwise::Mode elem_mode = megdnn::param::Elemwise::Mode elem_mode =
megdnn::param::Elemwise::Mode::ADD; megdnn::param::Elemwise::Mode::ADD;
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) {
@@ -154,7 +155,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::FLOAT> {
megdnn::ConvBiasForward::BiasMode bias_mode, megdnn::ConvBiasForward::BiasMode bias_mode,
megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::param::ConvBias::NonlineMode nonlineMode,
DType bias_type, DType dst_type, size_t N, size_t OC, DType bias_type, DType dst_type, size_t N, size_t OC,
size_t OH, size_t OW) {
size_t OH, size_t OW, size_t pack_oc_size=1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
megdnn::param::Elemwise::Mode elem_mode = megdnn::param::Elemwise::Mode elem_mode =
megdnn::param::Elemwise::Mode::ADD; megdnn::param::Elemwise::Mode::ADD;
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) {
@@ -185,7 +187,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
megdnn::ConvBiasForward::BiasMode bias_mode, megdnn::ConvBiasForward::BiasMode bias_mode,
megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::param::ConvBias::NonlineMode nonlineMode,
DType bias_type, DType dst_type, size_t N, size_t OC, DType bias_type, DType dst_type, size_t N, size_t OC,
size_t OH, size_t OW) {
size_t OH, size_t OW,size_t pack_oc_size = 1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
MEGDNN_MARK_USED_VAR(conv_dst_ptr); MEGDNN_MARK_USED_VAR(conv_dst_ptr);
MEGDNN_MARK_USED_VAR(bias_ptr); MEGDNN_MARK_USED_VAR(bias_ptr);
MEGDNN_MARK_USED_VAR(dst_ptr); MEGDNN_MARK_USED_VAR(dst_ptr);
@@ -292,7 +295,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> {
megdnn::ConvBiasForward::BiasMode bias_mode, megdnn::ConvBiasForward::BiasMode bias_mode,
megdnn::param::ConvBiasV0::NonlineMode nonlineMode, megdnn::param::ConvBiasV0::NonlineMode nonlineMode,
DType bias_type, DType dst_type, size_t N, size_t OC, DType bias_type, DType dst_type, size_t N, size_t OC,
size_t OH, size_t OW) {
size_t OH, size_t OW, size_t pack_oc_size = 1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
megdnn::param::Elemwise::Mode elem_mode = megdnn::param::Elemwise::Mode elem_mode =
megdnn::param::Elemwise::Mode::ADD; megdnn::param::Elemwise::Mode::ADD;
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) {


Loading…
Cancel
Save