im2co and conv1x1 mk4_dot support
GitOrigin-RevId: 096b16a3ab
tags/v0.5.0
@@ -913,10 +913,10 @@ static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
*outptr++ = *inptr++; | |||
} | |||
for (; i < 4; i++) { | |||
*outptr++ = *inptr++; | |||
*outptr++ = *inptr++; | |||
*outptr++ = *inptr++; | |||
*outptr++ = *inptr++; | |||
*outptr++ = 0; | |||
*outptr++ = 0; | |||
*outptr++ = 0; | |||
*outptr++ = 0; | |||
} | |||
} | |||
@@ -39,7 +39,7 @@ namespace { | |||
megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::arm_common::VEC>:: \ | |||
run(static_cast<ctype*>(conv_dst_ptr), \ | |||
reinterpret_cast<ctype*>(dst_ptr), bias_type, dst_type, \ | |||
N* OC* OH* OW); | |||
N* OC* OH* OW* pack_oc_size); | |||
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | |||
megdnn::arm_common:: \ | |||
@@ -63,7 +63,7 @@ namespace { | |||
static_cast<ctype*>(conv_dst_ptr), \ | |||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||
dst_type, N* OC* OH* OW); | |||
dst_type, N* OC* OH* OW* pack_oc_size); | |||
#define FOR_BIAS(_mode) \ | |||
switch (_mode) { \ | |||
@@ -113,7 +113,6 @@ struct PostProcess { | |||
megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | |||
megdnn::DType bias_type, megdnn::DType dst_type, size_t N, | |||
size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
FOR_BIAS(bias_mode) | |||
} | |||
}; | |||
@@ -155,7 +154,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||
_op<opctype, opdtype>, \ | |||
megdnn::arm_common::VEC>::run(static_cast<opctype*>(conv_dst_ptr), \ | |||
reinterpret_cast<opdtype*>(dst_ptr), \ | |||
bias_type, dst_type, N* OC* OH* OW); | |||
bias_type, dst_type, \ | |||
N* OC* OH* OW* pack_oc_size); | |||
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | |||
megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, \ | |||
@@ -173,8 +173,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \ | |||
dst_type, N, OC, OH* OW, pack_oc_size); | |||
#define HANDLE_IDENTITY(_caller, _op) \ | |||
case megdnn::NonlineMode::IDENTITY: \ | |||
#define HANDLE_IDENTITY(_caller, _op) \ | |||
case megdnn::NonlineMode::IDENTITY: \ | |||
_caller(_op) break; | |||
#define FOR_NONLINEAR(_caller) \ | |||
@@ -729,10 +729,10 @@ static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
*outptr++ = *inptr++; | |||
} | |||
for (; i < 4; i++) { | |||
*outptr++ = *inptr++; | |||
*outptr++ = *inptr++; | |||
*outptr++ = *inptr++; | |||
*outptr++ = *inptr++; | |||
*outptr++ = 0; | |||
*outptr++ = 0; | |||
*outptr++ = 0; | |||
*outptr++ = 0; | |||
} | |||
} | |||
outptr_base += 24; | |||
@@ -187,7 +187,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||
AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_fallback_conv1x1, 0, 2) { | |||
if (opr->param().format != param::ConvBias::Format::NCHW && | |||
opr->param().format != param::ConvBias::Format::NCHW44) | |||
opr->param().format != param::ConvBias::Format::NCHW44 && | |||
opr->param().format != param::ConvBias::Format::NCHW44_DOT) | |||
return false; | |||
size_t FH = param.filter_meta.spatial[0], | |||
@@ -219,8 +220,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) | |||
return false; | |||
if (opr->param().format == param::ConvBias::Format::NCHW44) { | |||
//! nchw44 hybird mode and channel wise is not support | |||
if (opr->param().format == param::ConvBias::Format::NCHW44 || | |||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | |||
if (param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 || | |||
param.filter_meta.ocpg == 1) { | |||
return false; | |||
@@ -73,32 +73,34 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||
const ConvBiasImpl::NCBKernSizeParam& param, | |||
MatrixMulImpl::AlgoBase::PackMode pack_mode, | |||
param::ConvBias::Format format) { | |||
size_t pack_size = get_format_pack_size(format); | |||
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||
midout_iv(_midout_tag)) { \ | |||
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||
return std::make_unique< \ | |||
Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||
_postprocess_mode, _packmode>>(pack_size); \ | |||
} \ | |||
} \ | |||
size_t pack_c_size = pack_size(format); | |||
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||
midout_iv(_midout_tag)) { \ | |||
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||
return std::make_unique< \ | |||
Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||
_postprocess_mode, _packmode>>( \ | |||
pack_c_size); \ | |||
} \ | |||
} \ | |||
MIDOUT_END() | |||
#define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||
midout_iv(_midout_tag)) { \ | |||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||
return std::make_unique< \ | |||
Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||
DTypeTrait<_i_bias_type>::ctype, \ | |||
DTypeTrait<_i_dst_type>::ctype, \ | |||
_postprocess_mode, _packmode>>(pack_size); \ | |||
} \ | |||
} \ | |||
#define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||
midout_iv(_midout_tag)) { \ | |||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||
return std::make_unique< \ | |||
Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||
DTypeTrait<_i_bias_type>::ctype, \ | |||
DTypeTrait<_i_dst_type>::ctype, \ | |||
_postprocess_mode, _packmode>>( \ | |||
pack_c_size); \ | |||
} \ | |||
} \ | |||
MIDOUT_END() | |||
switch (pack_mode) { | |||
@@ -12,7 +12,6 @@ | |||
#pragma once | |||
#include "megdnn/opr_param_defs.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
#if MEGDNN_X86 | |||
#include "src/x86/conv_bias/postprocess_helper.h" | |||
@@ -41,12 +40,15 @@ MatrixMulImpl::KernSizeParam get_matmul_kern_param( | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS8) || | |||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm && | |||
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); | |||
size_t pack_c_size = 1_z; | |||
size_t pack_c_size = pack_size(param.filter_meta.format); | |||
auto format = param::MatrixMul::Format::DEFAULT; | |||
if(param.filter_meta.format == param::ConvBias::Format::NCHW44){ | |||
pack_c_size = 4_z; | |||
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { | |||
format = param::MatrixMul::Format::MK4; | |||
} else if (param.filter_meta.format == | |||
param::ConvBias::Format::NCHW44_DOT) { | |||
format = param::MatrixMul::Format::MK4_DOT; | |||
} | |||
return {param.filter_type, | |||
param.src_type, | |||
is_dst_8bit ? param.bias_type : param.dst_type, | |||
@@ -15,7 +15,6 @@ | |||
#include "src/common/opr_delegate.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
#include "src/fallback/conv_bias/winograd/strategy.h" | |||
#include "src/naive/convolution/helper.h" | |||
#include "midout.h" | |||
@@ -125,7 +124,7 @@ public: | |||
size_t oc_tile_size) { | |||
size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0], | |||
FW = param.filter_meta.spatial[1]; | |||
size_t pack_oc_size = get_format_pack_size(param.filter_meta.format); | |||
size_t pack_oc_size = pack_size(param.filter_meta.format); | |||
size_t im2col = 0, packb = 0, bias_temp = 0; | |||
bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT; | |||
megdnn_assert(default_pack, "only support default packa"); | |||
@@ -319,9 +318,11 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, | |||
size_t ohw_tile_size, | |||
size_t oc_tile_size) const { | |||
auto format = param::MatrixMul::Format::DEFAULT; | |||
size_t pack_oc_size = get_format_pack_size(param.filter_meta.format); | |||
size_t pack_oc_size = pack_size(param.filter_meta.format); | |||
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { | |||
format = param::MatrixMul::Format::MK4; | |||
} else if(param.filter_meta.format == param::ConvBias::Format::NCHW44_DOT){ | |||
format = param::MatrixMul::Format::MK4_DOT; | |||
} | |||
size_t M = oc_tile_size; | |||
size_t N = ohw_tile_size; | |||
@@ -351,11 +352,10 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, | |||
void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( | |||
const NCBKernSizeParam& param, size_t& oc_tile_size, | |||
size_t& ohw_tile_size, size_t block_m, size_t block_n, | |||
bool need_pack) const { | |||
fallback::MatrixMulImpl::AlgoBase::PackMode pack_mode) const { | |||
size_t nr_threads = param.nr_threads; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t ohw = param.osz[0] * param.osz[1]; | |||
oc_tile_size = DEFAULT_OC_TILE_SIZE; | |||
ohw_tile_size = m_ohw_tile_size; | |||
@@ -376,7 +376,8 @@ void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( | |||
} | |||
} | |||
} else { | |||
if (!need_pack) { //! no pack ,usually in x86 save memroy | |||
//! in no_pack mode don't do block operation when using single thread | |||
if (pack_mode == fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { | |||
ohw_tile_size = ohw; | |||
oc_tile_size = OC; | |||
} | |||
@@ -406,7 +407,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||
if (need_pack || only_packA) { | |||
auto inner_block = m_matmul_algo->get_inner_block_size(); | |||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, inner_block.m, | |||
inner_block.n, need_pack); | |||
inner_block.n, m_matmul_algo->packmode()); | |||
auto im2col_kern_param = get_matmul_kern_param( | |||
param, ohw_tile_size, only_packA ? oc_tile_size : OC); | |||
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | |||
@@ -418,7 +419,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||
size_t nopack_default_blockn = 16; | |||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||
nopack_default_blockm, nopack_default_blockn, | |||
need_pack); | |||
m_matmul_algo->packmode()); | |||
packa_group_size = 0; | |||
} | |||
@@ -488,19 +489,20 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||
if (default_pack || only_packA) { | |||
auto inner_block = m_matmul_algo->get_inner_block_size(); | |||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||
inner_block.m, inner_block.n, default_pack); | |||
} else { //! not support pack,not need pack | |||
inner_block.m, inner_block.n, | |||
m_matmul_algo->packmode()); | |||
} else { //! nopack_mode | |||
size_t nopack_default_blockm = 8; | |||
size_t nopack_default_blockn = 16; | |||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||
nopack_default_blockm, nopack_default_blockn, | |||
no_pack); | |||
m_matmul_algo->packmode()); | |||
} | |||
size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size); | |||
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | |||
size_t packa_parallel_times = 0; | |||
size_t pack_oc_size = get_format_pack_size(param.filter_meta.format); | |||
size_t pack_oc_size = pack_size(param.filter_meta.format); | |||
if (only_packA) { | |||
packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | |||
@@ -639,9 +641,15 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||
ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 2) { | |||
if (opr->param().format != param::ConvBias::Format::NCHW && | |||
opr->param().format != param::ConvBias::Format::NCHW44_DOT && | |||
opr->param().format != param::ConvBias::Format::NCHW44) { | |||
return false; | |||
} | |||
//! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode is | |||
//! identity otherwise return false mean that 8x8x32 and 8x8x16 not support | |||
//! PostProcess | |||
//! identity otherwise return false mean that 8x8x32 and 8x8x16 not | |||
//! support PostProcess | |||
if (param.src_type.enumv() == param.filter_type.enumv() && | |||
((param.src_type.enumv() == DTypeEnum::Int8 && | |||
(param.dst_type.enumv() == DTypeEnum::Int16 || | |||
@@ -653,9 +661,10 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
return false; | |||
} | |||
if (opr->param().format == param::ConvBias::Format::NCHW44) { | |||
if (opr->param().format == param::ConvBias::Format::NCHW44 || | |||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | |||
//! current NCHW44 im2col only support DEFAULT mode matmul | |||
if(m_matmul_algo->packmode() != Pack_Mode::DEFAULT) { | |||
if (m_matmul_algo->packmode() != Pack_Mode::DEFAULT) { | |||
return false; | |||
//! nchw44 hybird mode and channel wise is not support | |||
} else if (param.filter_meta.icpg < 4_z || | |||
@@ -668,29 +677,27 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||
size_t oc_tile_size = 0, ohw_tile_size = 0; | |||
Pack_Mode packmode = m_matmul_algo->packmode(); | |||
bool default_pack = packmode == Pack_Mode::DEFAULT; | |||
bool no_pack = packmode == Pack_Mode::NO_PACK; | |||
bool only_packA = packmode == Pack_Mode::ONLY_PACKA; | |||
if (default_pack || only_packA) { | |||
auto inner_block = m_matmul_algo->get_inner_block_size(); | |||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||
inner_block.m, inner_block.n, default_pack); | |||
inner_block.m, inner_block.n, | |||
m_matmul_algo->packmode()); | |||
} else { //! not support pack,not need pack | |||
size_t nopack_default_blockm = 8; | |||
size_t nopack_default_blockn = 16; | |||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||
nopack_default_blockm, nopack_default_blockn, | |||
no_pack); | |||
m_matmul_algo->packmode()); | |||
} | |||
fallback::MatrixMulImpl::KernSizeParam matmul_param = | |||
get_matmul_kern_param(param, ohw_tile_size, oc_tile_size); | |||
bool matmulusable = m_matmul_algo->usable(matmul_param); | |||
return matmulusable && | |||
(opr->param().format == param::ConvBias::Format::NCHW || | |||
opr->param().format == param::ConvBias::Format::NCHW44) && | |||
(!(param.filter_meta.spatial[0] == | |||
param.filter_meta.spatial[1] && | |||
(param.filter_meta.spatial[0] == 1) && | |||
param.filter_meta.spatial[0] == 1 && | |||
param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||
param.filter_meta.stride[0] == 1)) && | |||
(param.filter_meta.dilation[0] == | |||
@@ -36,10 +36,10 @@ class ConvBiasImpl::AlgoIm2col final : public AlgoBase { | |||
const NCBKernSizeParam& param, size_t ohw_tile_size, | |||
size_t oc_tile_size) const; | |||
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; | |||
void choice_ohw_oc_block(const NCBKernSizeParam& param, | |||
size_t& oc_tile_size, size_t& ohw_tile_size, | |||
size_t block_m, size_t block_n, | |||
bool pack_default) const; | |||
void choice_ohw_oc_block( | |||
const NCBKernSizeParam& param, size_t& oc_tile_size, | |||
size_t& ohw_tile_size, size_t block_m, size_t block_n, | |||
fallback::MatrixMulImpl::AlgoBase::PackMode pack_mode) const; | |||
public: | |||
AlgoIm2col(MatrixMulImpl::AlgoBase* matmul_algo, size_t ohw_tile_size) | |||
@@ -230,7 +230,11 @@ public: | |||
PostprocessMode::FLOAT, | |||
"DefaultStrategyTypeNCHW44::FLOAT"_hash); | |||
} else { | |||
megdnn_throw("not support format except nchw44 and nchw\n"); | |||
megdnn_throw( | |||
ssprintf("Current only support layout " | |||
"NCHW44/NCHW for im2col " | |||
"algo, but got %d\n", | |||
uint32_t(format))); | |||
} | |||
break; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
@@ -252,12 +256,17 @@ public: | |||
cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
"DefaultStrategyType::INT8x8x32"_hash); | |||
} else if (format == param::ConvBias::Format::NCHW44) { | |||
} else if (format == param::ConvBias::Format::NCHW44 || | |||
format == param::ConvBias::Format::NCHW44_DOT) { | |||
cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
"DefaultStrategyType::INT8x8x32"_hash); | |||
} else { | |||
megdnn_throw("not support format except nchw44 and nchw\n"); | |||
megdnn_throw( | |||
ssprintf("Current only support layout " | |||
"NCHW44/NCHW/NCHW_DOT for im2col " | |||
"algo, but got %d\n", | |||
uint32_t(format))); | |||
} | |||
break; | |||
@@ -288,13 +297,18 @@ public: | |||
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, | |||
"DefaultStrategyTypeNCHW::QINT8x8x32"_hash); | |||
} else if (format == param::ConvBias::Format::NCHW44) { | |||
} else if (format == param::ConvBias::Format::NCHW44 || | |||
format == param::ConvBias::Format::NCHW44_DOT) { | |||
cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | |||
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
"DefaultStrategyTypeHCHW44::QINT8x8x32"_hash); | |||
} else { | |||
megdnn_throw("not support format except nchw44 and nchw\n"); | |||
megdnn_throw( | |||
ssprintf("Current only support layout " | |||
"NCHW44/NCHW/NCHW_DOT for im2col " | |||
"algo, but got %d\n", | |||
uint32_t(format))); | |||
} | |||
break; | |||
@@ -304,17 +318,22 @@ public: | |||
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||
PostprocessMode::QUANTIZED, | |||
"DefaultStrategyType::QINT8x8x32x8"_hash); | |||
} else if (format == param::ConvBias::Format::NCHW44) { | |||
} else if (format == param::ConvBias::Format::NCHW44 || | |||
format == param::ConvBias::Format::NCHW44_DOT) { | |||
cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | |||
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, | |||
dt_int32, dt_int8, PostprocessMode::QUANTIZED, | |||
"DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash); | |||
} else { | |||
megdnn_throw("not support format except nchw44 and nchw\n"); | |||
megdnn_throw(ssprintf("Current only support layout " | |||
"NCHW44/NCHW/NCHW_DOT for im2col " | |||
"algo, but got %d\n", | |||
uint32_t(format))); | |||
} | |||
break; | |||
} | |||
megdnn_throw("error not support strategy type "); | |||
megdnn_throw(ssprintf("Unsupported strategy type %u in default mode", | |||
uint32_t(strategytype))); | |||
} | |||
static std::unique_ptr<StrategyBase> make_nopack_strategy( | |||
@@ -328,10 +347,6 @@ public: | |||
PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash); | |||
break; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
case StrategyType::FLOAT_FP16: | |||
cb1(NCHW, NO_PACK, dt_float16, __fp16, PostprocessMode::FLOAT, | |||
"NoPackStrategyType::FLOAT_FP16"_hash); | |||
break; | |||
#else | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
case StrategyType::FLOAT16_FLOAT16: | |||
@@ -341,48 +356,24 @@ public: | |||
break; | |||
#endif | |||
#endif | |||
case StrategyType::INT8x8x32: | |||
cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
"NoPackStrategyType::INT8x8x32"_hash); | |||
break; | |||
case StrategyType::INT8x8x16: | |||
cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | |||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
"NoPackStrategyType::INT8x8x16"_hash); | |||
break; | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
case StrategyType::QUINT8x8x32: | |||
cb2(NCHW, NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, | |||
"NoPackStrategyType::QUINT8x8x32"_hash); | |||
break; | |||
case StrategyType::QUINT8x8x32x8: | |||
cb2(NCHW, NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | |||
PostprocessMode::QUANTIZED, | |||
"NoPackStrategyType::QUINT8x8x32x8"_hash); | |||
break; | |||
#endif | |||
case StrategyType::QINT8x8x32: | |||
cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, | |||
"NoPackStrategyType::QINT8x8x32"_hash); | |||
case StrategyType::INT8x8x32: | |||
cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
"NoPackStrategyType::INT8x8x32"_hash); | |||
break; | |||
case StrategyType::QINT8x8x32x8: | |||
cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||
PostprocessMode::QUANTIZED, | |||
"NoPackStrategyType::QINT8x8x32x8"_hash); | |||
default: | |||
megdnn_throw( | |||
ssprintf("Unsupported strategy type %u in no_pack mode", | |||
uint32_t(strategytype))); | |||
break; | |||
} | |||
megdnn_throw("error not support strategy type "); | |||
megdnn_throw(ssprintf("Unsupported strategy type %u in no_pack mode", | |||
uint32_t(strategytype))); | |||
} | |||
static std::unique_ptr<StrategyBase> make_onlypacka_strategy( | |||
@@ -396,63 +387,14 @@ public: | |||
PostprocessMode::FLOAT, | |||
"OnlyPackaStrategyType::FLOAT"_hash); | |||
break; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
case StrategyType::FLOAT_FP16: | |||
cb1(NCHW, ONLY_PACKA, dt_float16, __fp16, | |||
PostprocessMode::FLOAT, | |||
"OnlyPackaStrategyType::FLOAT_FP16"_hash); | |||
break; | |||
#else | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
case StrategyType::FLOAT16_FLOAT16: | |||
cb1(NCHW, ONLY_PACKA, dt_float16, dt_float16, | |||
PostprocessMode::NO_PROCESS, | |||
"OnlyPackaStrategyType::FLOAT16_FLOAT16"_hash); | |||
break; | |||
#endif | |||
#endif | |||
case StrategyType::INT8x8x32: | |||
cb2(NCHW, ONLY_PACKA, dt_int8, dt_int32, dt_int32, dt_int8, | |||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
"OnlyPackaStrategyType::INT8x8x32"_hash); | |||
break; | |||
case StrategyType::INT8x8x16: | |||
cb2(NCHW, ONLY_PACKA, dt_int8, dt_int16, dt_int16, dt_int8, | |||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
"OnlyPackaStrategyType::INT8x8x16"_hash); | |||
break; | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
case StrategyType::QUINT8x8x32: | |||
cb2(NCHW, ONLY_PACKA, dtype::Quantized8Asymm, | |||
dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, | |||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
"OnlyPackaStrategyType::QUINT8x8x32"_hash); | |||
break; | |||
case StrategyType::QUINT8x8x32x8: | |||
cb2(NCHW, ONLY_PACKA, dtype::Quantized8Asymm, | |||
dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8, | |||
dt_int32, dt_uint8, PostprocessMode::QUANTIZED, | |||
"OnlyPackaStrategyType::QUINT8x8x32x8"_hash); | |||
break; | |||
#endif | |||
case StrategyType::QINT8x8x32: | |||
cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||
PostprocessMode::NO_PROCESS, | |||
"OnlyPackaStrategyType::QINT8x8x32"_hash); | |||
break; | |||
case StrategyType::QINT8x8x32x8: | |||
cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||
PostprocessMode::QUANTIZED, | |||
"OnlyPackaStrategyType::QINT8x8x32x8"_hash); | |||
default: | |||
megdnn_throw(ssprintf( | |||
"Unsupported strategy type %u in onlypacka mode", | |||
uint32_t(strategytype))); | |||
break; | |||
} | |||
megdnn_throw("error not support strategy type "); | |||
megdnn_throw(ssprintf("Unsupported strategy type %u in onlypacka mode", | |||
uint32_t(strategytype))); | |||
} | |||
#undef cb1 | |||
@@ -11,6 +11,16 @@ | |||
#pragma once | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
#if MEGDNN_X86 | |||
#include "src/x86/conv_bias/postprocess_helper.h" | |||
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#endif | |||
using namespace megdnn; | |||
#if MEGDNN_X86 | |||
using namespace x86; | |||
#endif | |||
namespace megdnn { | |||
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; | |||
@@ -75,6 +85,185 @@ public: | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode, PackMode packmode, | |||
FormatMode format> | |||
//! this class is a new base class for StrategyDefault StrategyNoPack and so on, | |||
//! in order to handle copy pad use the same code | |||
class StrategyBridge : public StrategyBase { | |||
public: | |||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | |||
StrategyBridge() = default; | |||
virtual void copy_padding_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_oc_size) override { | |||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||
MEGDNN_MARK_USED_VAR(N); | |||
MEGDNN_MARK_USED_VAR(OC); | |||
MEGDNN_MARK_USED_VAR(OH); | |||
MEGDNN_MARK_USED_VAR(OW); | |||
MEGDNN_MARK_USED_VAR(FH); | |||
MEGDNN_MARK_USED_VAR(FW); | |||
MEGDNN_MARK_USED_VAR(SH); | |||
MEGDNN_MARK_USED_VAR(SW); | |||
size_t IW2 = IW + 2 * PW; | |||
size_t IH2 = IH + 2 * PH; | |||
size_t batch_id = ncb_index.ndrange_id[0]; | |||
size_t group_id = ncb_index.ndrange_id[1]; | |||
size_t channel_id = ncb_index.ndrange_id[2]; | |||
size_t PH_SIZE = PH * IW2 * pack_oc_size; | |||
PW = PW * pack_oc_size; | |||
IW = IW * pack_oc_size; | |||
size_t padding_group_size = IH2 * IW2 * IC; | |||
size_t workspace_channel_offset = pack_oc_size * IH2 * IW2 * channel_id; | |||
size_t workspace_group_offset = group_id * padding_group_size; | |||
size_t workspace_batch_offset = | |||
param.filter_meta.group * batch_id * padding_group_size; | |||
bundle.set(param.workspace_ptr); | |||
src_ctype src_zp = static_cast<src_ctype>(0); | |||
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | |||
} | |||
src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>( | |||
batch_id, group_id, channel_id, 1, pack_oc_size)); | |||
src_ctype* src2; | |||
src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||
workspace_group_offset + workspace_batch_offset + | |||
workspace_channel_offset; | |||
src_ctype* src2_ptr = src2; | |||
const src_ctype* src_ptr = src; | |||
if (PH != 0) { | |||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||
src2_ptr += PH_SIZE; | |||
} | |||
rep(ih, IH) { | |||
if (PW != 0) | |||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||
std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW); | |||
src2_ptr += IW; | |||
src_ptr += IW; | |||
if (PW != 0) | |||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||
} | |||
if (PH != 0) { | |||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||
src2_ptr += PH_SIZE; | |||
} | |||
} | |||
}; | |||
namespace{ | |||
template <typename bias_ctype> | |||
inline void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const WorkspaceBundle& bundle_thread, | |||
const StrategyParam& sparam, | |||
size_t matmul_bundle_index) { | |||
if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) { | |||
return static_cast<void*>(bundle_thread.get(matmul_bundle_index)); | |||
} else { | |||
bias_ctype* dst = | |||
param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) + | |||
sparam.oc_cur_index * sparam.ohw; | |||
return static_cast<void*>(dst); | |||
} | |||
} | |||
template <typename bias_ctype> | |||
inline void* get_bias_temp_ptr( | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const WorkspaceBundle& bundle_thread, size_t bias_bundle_index) { | |||
bias_ctype* bias_tmp_ptr = | |||
param.bias_mode == megdnn::BiasMode::BIAS | |||
? static_cast<bias_ctype*>( | |||
bundle_thread.get(bias_bundle_index)) | |||
: nullptr; | |||
return bias_tmp_ptr; | |||
} | |||
template <typename dst_ctype> | |||
void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const void* matmul_dst, const StrategyParam& sparam) { | |||
if (!sparam.skip_copy_dst) { | |||
size_t pack_oc_size = sparam.pack_oc_size; | |||
dst_ctype* dst_tmp_ptr = | |||
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); | |||
dst_ctype* dst = | |||
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + | |||
sparam.oc_cur_index * sparam.ohw + | |||
sparam.ohw_cur_index * pack_oc_size; | |||
size_t oc_loop = sparam.output_block_oc_size / pack_oc_size; | |||
for (size_t oc = 0; oc < oc_loop; oc++) { | |||
std::memcpy(dst, dst_tmp_ptr, | |||
sizeof(dst_ctype) * sparam.output_block_size * | |||
pack_oc_size); | |||
dst_tmp_ptr += sparam.output_block_size * pack_oc_size; | |||
dst += sparam.ohw * pack_oc_size; | |||
} | |||
} | |||
} | |||
template <typename bias_ctype> | |||
void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
WorkspaceBundle bundle_thread, const StrategyParam& sparam, | |||
size_t bias_index) { | |||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||
bias_ctype* bias_temp_ptr = static_cast<bias_ctype*>( | |||
get_bias_temp_ptr<bias_ctype>(param, bundle_thread, bias_index)); | |||
if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||
bias_ctype* copy_dst = bias_temp_ptr; | |||
size_t pack_oc_size = sparam.pack_oc_size; | |||
const bias_ctype* copy_src = bias_ptr + | |||
sparam.oc_cur_index * sparam.ohw + | |||
sparam.ohw_cur_index * pack_oc_size; | |||
for (size_t oc = sparam.oc_cur_index / pack_oc_size; | |||
oc < sparam.oc_end_index / pack_oc_size; oc++) { | |||
std::memcpy(copy_dst, copy_src, | |||
sizeof(bias_ctype) * sparam.output_block_size * | |||
pack_oc_size); | |||
copy_dst += sparam.output_block_size * pack_oc_size; | |||
copy_src += sparam.ohw * pack_oc_size; | |||
} | |||
} | |||
} | |||
template <typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void do_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, WorkspaceBundle bundle_thread, | |||
size_t matmul_bundle_index, size_t bias_bundle_index) { | |||
copy_bias<bias_ctype>(param, bundle_thread, sparam, bias_bundle_index); | |||
void* matmul_dst = get_matmul_dst_ptr<bias_ctype>( | |||
param, bundle_thread, sparam, matmul_bundle_index); | |||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||
void* bias_temp_ptr = get_bias_temp_ptr<bias_ctype>(param, bundle_thread, | |||
bias_bundle_index); | |||
void* bias_preprocess_ptr = const_cast<void*>( | |||
param.bias_mode == megdnn::BiasMode::BIAS | |||
? bias_temp_ptr | |||
: static_cast<void*>(const_cast<bias_ctype*>( | |||
bias_ptr + sparam.oc_cur_index))); | |||
size_t pack_oc_size = sparam.pack_oc_size; | |||
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||
matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode, | |||
param.nonlineMode, param.bias_type, param.dst_type, 1_z, | |||
sparam.output_block_oc_size / pack_oc_size, 1_z, | |||
sparam.output_block_size, pack_oc_size); | |||
copy_dst<dst_ctype>(param, matmul_dst, sparam); | |||
} | |||
} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode, PackMode packmode, | |||
FormatMode format = FormatMode::NCHW> | |||
class Strategy; | |||
@@ -82,7 +271,10 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::DEFAULT> : public StrategyBase { | |||
postprocess_mode, PackMode::DEFAULT> | |||
: public StrategyBridge<src_ctype, bias_ctype, dst_ctype, op_ctype, | |||
op_dtype, postprocess_mode, PackMode::DEFAULT, | |||
FormatMode::NCHW> { | |||
public: | |||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | |||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | |||
@@ -92,13 +284,7 @@ public: | |||
Strategy() = default; | |||
void copy_padding_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_size) override; | |||
void packA_kern(WorkspaceBundle bundle, | |||
virtual void packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
@@ -120,16 +306,13 @@ public: | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, | |||
WorkspaceBundle bundle_thread) override; | |||
void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const void* matmul_dst, const StrategyParam& sparam); | |||
void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
WorkspaceBundle bundle_thread, const StrategyParam& sparam); | |||
WorkspaceBundle bundle_thread) override { | |||
do_postprocess<bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode>(param, sparam, bundle_thread, | |||
THREAD_BUNDLE_IM2COL_INDEX, | |||
THREAD_BUNDLE_BIAS_INDEX); | |||
} | |||
void* get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const WorkspaceBundle& bundle_thread); | |||
void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const WorkspaceBundle& bundle_thread, | |||
const StrategyParam& sparam); | |||
@@ -162,7 +345,10 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::NO_PACK> : public StrategyBase { | |||
postprocess_mode, PackMode::NO_PACK> | |||
: public StrategyBridge<src_ctype, bias_ctype, dst_ctype, op_ctype, | |||
op_dtype, postprocess_mode, PackMode::NO_PACK, | |||
FormatMode::NCHW> { | |||
public: | |||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | |||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | |||
@@ -173,12 +359,6 @@ public: | |||
Strategy() = default; | |||
void copy_padding_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_size) override; | |||
void packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
@@ -198,17 +378,6 @@ public: | |||
const WorkspaceBundle& bundle_thread, | |||
const StrategyParam& sparam); | |||
inline void* get_bias_temp_ptr( | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const WorkspaceBundle& bundle_thread) { | |||
bias_ctype* bias_tmp_ptr = | |||
param.bias_mode == megdnn::BiasMode::BIAS | |||
? static_cast<bias_ctype*>( | |||
bundle_thread.get(THREAD_BUNDLE_BIAS_INDEX)) | |||
: nullptr; | |||
return bias_tmp_ptr; | |||
} | |||
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
@@ -216,19 +385,22 @@ public: | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, | |||
WorkspaceBundle bundle_thread) override; | |||
void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const void* matmul_dst, const StrategyParam& sparam); | |||
void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
WorkspaceBundle bundle_thread, const StrategyParam& sparam); | |||
WorkspaceBundle bundle_thread) override { | |||
do_postprocess<bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode>(param, sparam, bundle_thread, | |||
THREAD_BUNDLE_MATMULDST_INDEX, | |||
THREAD_BUNDLE_BIAS_INDEX); | |||
} | |||
}; | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::ONLY_PACKA> : public StrategyBase { | |||
postprocess_mode, PackMode::ONLY_PACKA> | |||
: public StrategyBridge<src_ctype, bias_ctype, dst_ctype, op_ctype, | |||
op_dtype, postprocess_mode, | |||
PackMode::ONLY_PACKA,FormatMode::NCHW> { | |||
public: | |||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | |||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | |||
@@ -239,12 +411,6 @@ public: | |||
Strategy() = default; | |||
void copy_padding_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_size) override; | |||
void packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
@@ -269,24 +435,15 @@ public: | |||
void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const WorkspaceBundle& bundle_thread, | |||
const StrategyParam& sparam); | |||
inline void* get_bias_temp_ptr( | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const WorkspaceBundle& bundle_thread) { | |||
bias_ctype* bias_tmp_ptr = | |||
param.bias_mode == megdnn::BiasMode::BIAS | |||
? static_cast<bias_ctype*>( | |||
bundle_thread.get(THREAD_BUNDLE_BIAS_INDEX)) | |||
: nullptr; | |||
return bias_tmp_ptr; | |||
} | |||
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, | |||
WorkspaceBundle bundle_thread) override; | |||
void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const void* matmul_dst, const StrategyParam& sparam); | |||
void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
WorkspaceBundle bundle_thread, const StrategyParam& sparam); | |||
WorkspaceBundle bundle_thread) override { | |||
do_postprocess<bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode>(param, sparam, bundle_thread, | |||
THREAD_BUNDLE_MATMULDST_INDEX, | |||
THREAD_BUNDLE_BIAS_INDEX); | |||
} | |||
}; | |||
} // namespace megdnn | |||
@@ -10,16 +10,7 @@ | |||
*/ | |||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | |||
#include "src/fallback/convolution/img2col_helper.h" | |||
#if MEGDNN_X86 | |||
#include "src/x86/conv_bias/postprocess_helper.h" | |||
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#endif | |||
using namespace megdnn; | |||
#if MEGDNN_X86 | |||
using namespace x86; | |||
#endif | |||
namespace megdnn { | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
@@ -27,73 +18,6 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::DEFAULT>:: | |||
copy_padding_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_oc_size) { | |||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||
MEGDNN_MARK_USED_VAR(N); | |||
MEGDNN_MARK_USED_VAR(OC); | |||
MEGDNN_MARK_USED_VAR(OH); | |||
MEGDNN_MARK_USED_VAR(OW); | |||
MEGDNN_MARK_USED_VAR(FH); | |||
MEGDNN_MARK_USED_VAR(FW); | |||
MEGDNN_MARK_USED_VAR(SH); | |||
MEGDNN_MARK_USED_VAR(SW); | |||
size_t IW2 = IW + 2 * PW; | |||
size_t IH2 = IH + 2 * PH; | |||
size_t batch_id = ncb_index.ndrange_id[0]; | |||
size_t group_id = ncb_index.ndrange_id[1]; | |||
size_t channel_id = ncb_index.ndrange_id[2]; | |||
size_t PH_SIZE = PH * IW2 * pack_oc_size; | |||
PW = PW * pack_oc_size; | |||
IW = IW * pack_oc_size; | |||
size_t padding_group_size = IH2 * IW2 * IC; | |||
size_t workspace_channel_offset = pack_oc_size * IH2 * IW2 * channel_id; | |||
size_t workspace_group_offset = group_id * padding_group_size; | |||
size_t workspace_batch_offset = | |||
param.filter_meta.group * batch_id * padding_group_size; | |||
bundle.set(param.workspace_ptr); | |||
src_ctype src_zp = static_cast<src_ctype>(0); | |||
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | |||
} | |||
src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>( | |||
batch_id, group_id, channel_id, 1, pack_oc_size)); | |||
src_ctype* src2; | |||
src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||
workspace_group_offset + workspace_batch_offset + | |||
workspace_channel_offset; | |||
src_ctype* src2_ptr = src2; | |||
const src_ctype* src_ptr = src; | |||
if (PH != 0) { | |||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||
src2_ptr += PH_SIZE; | |||
} | |||
rep(ih, IH) { | |||
if (PW != 0) | |||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||
std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW); | |||
src2_ptr += IW; | |||
src_ptr += IW; | |||
if (PW != 0) | |||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||
} | |||
if (PH != 0) { | |||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||
src2_ptr += PH_SIZE; | |||
} | |||
} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::DEFAULT>:: | |||
packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
@@ -244,100 +168,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
matmul_kern_naked(matmul_param, a_panel, b_panel); | |||
} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::DEFAULT>:: | |||
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, | |||
WorkspaceBundle bundle_thread) { | |||
copy_bias(param, bundle_thread, sparam); | |||
void* matmul_dst = get_matmul_dst_ptr(param, bundle_thread, sparam); | |||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||
void* bias_temp_ptr = get_bias_temp_ptr(param, bundle_thread); | |||
void* bias_preprocess_ptr = const_cast<void*>( | |||
param.bias_mode == megdnn::BiasMode::BIAS | |||
? bias_temp_ptr | |||
: static_cast<void*>(const_cast<bias_ctype*>( | |||
bias_ptr + sparam.oc_cur_index))); | |||
size_t pack_oc_size = sparam.pack_oc_size; | |||
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||
matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode, | |||
param.nonlineMode, param.bias_type, param.dst_type, 1_z, | |||
sparam.output_block_oc_size / pack_oc_size, 1_z, | |||
sparam.output_block_size, pack_oc_size); | |||
copy_dst(param, matmul_dst, sparam); | |||
} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::DEFAULT>:: | |||
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const void* matmul_dst, const StrategyParam& sparam) { | |||
if (!sparam.skip_copy_dst) { | |||
size_t pack_oc_size = sparam.pack_oc_size; | |||
dst_ctype* dst_tmp_ptr = | |||
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); | |||
dst_ctype* dst = | |||
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + | |||
sparam.oc_cur_index * sparam.ohw + | |||
sparam.ohw_cur_index * pack_oc_size; | |||
size_t oc_loop = sparam.output_block_oc_size / pack_oc_size; | |||
for (size_t oc = 0; oc < oc_loop; oc++) { | |||
std::memcpy(dst, dst_tmp_ptr, | |||
sizeof(dst_ctype) * sparam.output_block_size * | |||
pack_oc_size); | |||
dst_tmp_ptr += sparam.output_block_size * pack_oc_size; | |||
dst += sparam.ohw * pack_oc_size; | |||
} | |||
} | |||
} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::DEFAULT>:: | |||
get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const WorkspaceBundle& bundle_thread) { | |||
bias_ctype* bias_tmp_ptr = | |||
param.bias_mode == megdnn::BiasMode::BIAS | |||
? static_cast<bias_ctype*>( | |||
bundle_thread.get(THREAD_BUNDLE_BIAS_INDEX)) | |||
: nullptr; | |||
return bias_tmp_ptr; | |||
} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::DEFAULT>:: | |||
copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | |||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||
bias_ctype* bias_temp_ptr = | |||
static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread)); | |||
if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||
bias_ctype* copy_dst = bias_temp_ptr; | |||
const bias_ctype* copy_src = bias_ptr + | |||
sparam.oc_cur_index * sparam.ohw + | |||
sparam.ohw_cur_index; | |||
for (size_t oc = sparam.oc_cur_index; oc < sparam.oc_end_index; oc++) { | |||
std::memcpy(copy_dst, copy_src, | |||
sizeof(bias_ctype) * sparam.output_block_size); | |||
copy_dst += sparam.output_block_size; | |||
copy_src += sparam.ohw; | |||
} | |||
} | |||
} | |||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
_op_dtype, _postprocess_mode) \ | |||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
@@ -11,16 +11,7 @@ | |||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | |||
#include "src/fallback/convolution/img2col_helper.h" | |||
#if MEGDNN_X86 | |||
#include "src/x86/conv_bias/postprocess_helper.h" | |||
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#endif | |||
using namespace megdnn; | |||
#if MEGDNN_X86 | |||
using namespace x86; | |||
#endif | |||
namespace megdnn { | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
@@ -28,69 +19,6 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::NO_PACK>:: | |||
copy_padding_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t) { | |||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||
MEGDNN_MARK_USED_VAR(N); | |||
MEGDNN_MARK_USED_VAR(OC); | |||
MEGDNN_MARK_USED_VAR(OH); | |||
MEGDNN_MARK_USED_VAR(OW); | |||
MEGDNN_MARK_USED_VAR(FH); | |||
MEGDNN_MARK_USED_VAR(FW); | |||
MEGDNN_MARK_USED_VAR(SH); | |||
MEGDNN_MARK_USED_VAR(SW); | |||
size_t IW2 = IW + 2 * PW; | |||
size_t IH2 = IH + 2 * PH; | |||
size_t batch_id = ncb_index.ndrange_id[0]; | |||
size_t group_id = ncb_index.ndrange_id[1]; | |||
size_t channel_id = ncb_index.ndrange_id[2]; | |||
size_t padding_group_size = IH2 * IW2 * IC; | |||
size_t workspace_channel_offset = IH2 * IW2 * channel_id; | |||
size_t workspace_group_offset = group_id * padding_group_size; | |||
size_t workspace_batch_offset = | |||
param.filter_meta.group * batch_id * padding_group_size; | |||
bundle.set(param.workspace_ptr); | |||
src_ctype src_zp = static_cast<src_ctype>(0); | |||
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | |||
} | |||
src_ctype* src = const_cast<src_ctype*>( | |||
param.src<src_ctype>(batch_id, group_id, channel_id)); | |||
src_ctype* src2; | |||
src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||
workspace_group_offset + workspace_batch_offset + | |||
workspace_channel_offset; | |||
src_ctype* src2_ptr = src2; | |||
const src_ctype* src_ptr = src; | |||
if (PH != 0) { | |||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||
src2_ptr += PH * IW2; | |||
} | |||
rep(ih, IH) { | |||
if (PW != 0) | |||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||
std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW); | |||
src2_ptr += IW; | |||
src_ptr += IW; | |||
if (PW != 0) | |||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||
} | |||
if (PH != 0) { | |||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||
src2_ptr += PH * IW2; | |||
} | |||
} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::NO_PACK>:: | |||
packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
@@ -220,81 +148,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
} | |||
} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::NO_PACK>:: | |||
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, | |||
WorkspaceBundle bundle_thread) { | |||
copy_bias(param, bundle_thread, sparam); | |||
void* matmul_dst = get_matmul_dst_ptr(param, bundle_thread, sparam); | |||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||
bias_ctype* bias_temp_ptr = | |||
static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread)); | |||
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||
matmul_dst, | |||
const_cast<void*>( | |||
param.bias_mode == megdnn::BiasMode::BIAS | |||
? bias_temp_ptr | |||
: static_cast<void*>(const_cast<bias_ctype*>( | |||
bias_ptr + sparam.oc_cur_index))), | |||
matmul_dst, param.bias_mode, param.nonlineMode, param.bias_type, | |||
param.dst_type, 1_z, sparam.output_block_oc_size, 1_z, | |||
sparam.output_block_size); | |||
copy_dst(param, matmul_dst, sparam); | |||
} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::NO_PACK>:: | |||
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const void* matmul_dst, const StrategyParam& sparam) { | |||
if (!sparam.skip_copy_dst) { | |||
dst_ctype* dst_tmp_ptr = | |||
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); | |||
dst_ctype* dst = | |||
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + | |||
sparam.oc_cur_index * sparam.ohw + sparam.ohw_cur_index; | |||
for (size_t oc = 0; oc < sparam.output_block_oc_size; oc++) { | |||
std::memcpy(dst, dst_tmp_ptr, | |||
sizeof(dst_ctype) * sparam.output_block_size); | |||
dst_tmp_ptr += sparam.output_block_size; | |||
dst += sparam.ohw; | |||
} | |||
} | |||
} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::NO_PACK>:: | |||
copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | |||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||
bias_ctype* bias_temp_ptr = | |||
static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread)); | |||
if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||
bias_ctype* copy_dst = bias_temp_ptr; | |||
const bias_ctype* copy_src = bias_ptr + | |||
sparam.oc_cur_index * sparam.ohw + | |||
sparam.ohw_cur_index; | |||
for (size_t oc = sparam.oc_cur_index; oc < sparam.oc_end_index; oc++) { | |||
std::memcpy(copy_dst, copy_src, | |||
sizeof(bias_ctype) * sparam.output_block_size); | |||
copy_dst += sparam.output_block_size; | |||
copy_src += sparam.ohw; | |||
} | |||
} | |||
} | |||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
_op_dtype, _postprocess_mode) \ | |||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
@@ -302,34 +155,18 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | |||
megdnn::PostprocessMode::FLOAT) | |||
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||
megdnn::PostprocessMode::FLOAT) | |||
#else | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
#endif | |||
#endif | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | |||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||
megdnn::PostprocessMode::QUANTIZED) | |||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
#endif | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||
megdnn::PostprocessMode::QUANTIZED) | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
#undef INSTANTIAL_CLASS | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -11,16 +11,7 @@ | |||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | |||
#include "src/fallback/convolution/img2col_helper.h" | |||
#if MEGDNN_X86 | |||
#include "src/x86/conv_bias/postprocess_helper.h" | |||
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#endif | |||
using namespace megdnn; | |||
#if MEGDNN_X86 | |||
using namespace x86; | |||
#endif | |||
namespace megdnn { | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
@@ -28,69 +19,6 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||
copy_padding_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t) { | |||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||
MEGDNN_MARK_USED_VAR(N); | |||
MEGDNN_MARK_USED_VAR(OC); | |||
MEGDNN_MARK_USED_VAR(OH); | |||
MEGDNN_MARK_USED_VAR(OW); | |||
MEGDNN_MARK_USED_VAR(FH); | |||
MEGDNN_MARK_USED_VAR(FW); | |||
MEGDNN_MARK_USED_VAR(SH); | |||
MEGDNN_MARK_USED_VAR(SW); | |||
size_t IW2 = IW + 2 * PW; | |||
size_t IH2 = IH + 2 * PH; | |||
size_t batch_id = ncb_index.ndrange_id[0]; | |||
size_t group_id = ncb_index.ndrange_id[1]; | |||
size_t channel_id = ncb_index.ndrange_id[2]; | |||
size_t padding_group_size = IH2 * IW2 * IC; | |||
size_t workspace_channel_offset = IH2 * IW2 * channel_id; | |||
size_t workspace_group_offset = group_id * padding_group_size; | |||
size_t workspace_batch_offset = | |||
param.filter_meta.group * batch_id * padding_group_size; | |||
bundle.set(param.workspace_ptr); | |||
src_ctype src_zp = static_cast<src_ctype>(0); | |||
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | |||
} | |||
src_ctype* src = const_cast<src_ctype*>( | |||
param.src<src_ctype>(batch_id, group_id, channel_id)); | |||
src_ctype* src2; | |||
src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||
workspace_group_offset + workspace_batch_offset + | |||
workspace_channel_offset; | |||
src_ctype* src2_ptr = src2; | |||
const src_ctype* src_ptr = src; | |||
if (PH != 0) { | |||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||
src2_ptr += PH * IW2; | |||
} | |||
rep(ih, IH) { | |||
if (PW != 0) | |||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||
std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW); | |||
src2_ptr += IW; | |||
src_ptr += IW; | |||
if (PW != 0) | |||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||
} | |||
if (PH != 0) { | |||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||
src2_ptr += PH * IW2; | |||
} | |||
} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||
packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
@@ -123,25 +51,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const WorkspaceBundle& bundle_thread, | |||
const StrategyParam& sparam) { | |||
if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) { | |||
return static_cast<void*>( | |||
bundle_thread.get(THREAD_BUNDLE_MATMULDST_INDEX)); | |||
} else { | |||
bias_ctype* dst = | |||
param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) + | |||
sparam.oc_cur_index * sparam.ohw; | |||
return static_cast<void*>(dst); | |||
} | |||
} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
@@ -241,63 +150,19 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, | |||
WorkspaceBundle bundle_thread) { | |||
void* matmul_dst = get_matmul_dst_ptr(param, bundle_thread, sparam); | |||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||
bias_ctype* bias_temp_ptr = | |||
static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread)); | |||
if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||
bias_ctype* copy_dst = bias_temp_ptr; | |||
const bias_ctype* copy_src = bias_ptr + | |||
sparam.oc_cur_index * sparam.ohw + | |||
sparam.ohw_cur_index; | |||
for (size_t oc = sparam.oc_cur_index; oc < sparam.oc_end_index; oc++) { | |||
std::memcpy(copy_dst, copy_src, | |||
sizeof(bias_ctype) * sparam.output_block_size); | |||
copy_dst += sparam.output_block_size; | |||
copy_src += sparam.ohw; | |||
} | |||
} | |||
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||
matmul_dst, | |||
const_cast<void*>( | |||
param.bias_mode == megdnn::BiasMode::BIAS | |||
? bias_temp_ptr | |||
: static_cast<void*>(const_cast<bias_ctype*>( | |||
bias_ptr + sparam.oc_cur_index))), | |||
matmul_dst, param.bias_mode, param.nonlineMode, param.bias_type, | |||
param.dst_type, 1_z, sparam.output_block_oc_size, 1_z, | |||
sparam.output_block_size); | |||
copy_dst(param, matmul_dst, sparam); | |||
} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
typename op_ctype, typename op_dtype, | |||
megdnn::PostprocessMode postprocess_mode> | |||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const void* matmul_dst, const StrategyParam& sparam) { | |||
if (!sparam.skip_copy_dst) { | |||
dst_ctype* dst_tmp_ptr = | |||
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); | |||
dst_ctype* dst = | |||
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + | |||
sparam.oc_cur_index * sparam.ohw + sparam.ohw_cur_index; | |||
for (size_t oc = 0; oc < sparam.output_block_oc_size; oc++) { | |||
std::memcpy(dst, dst_tmp_ptr, | |||
sizeof(dst_ctype) * sparam.output_block_size); | |||
dst_tmp_ptr += sparam.output_block_size; | |||
dst += sparam.ohw; | |||
} | |||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const WorkspaceBundle& bundle_thread, | |||
const StrategyParam& sparam) { | |||
if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) { | |||
return static_cast<bias_ctype*>( | |||
bundle_thread.get(THREAD_BUNDLE_MATMULDST_INDEX)); | |||
} else { | |||
bias_ctype* dst = | |||
param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) + | |||
sparam.oc_cur_index * sparam.ohw; | |||
return static_cast<void*>(dst); | |||
} | |||
} | |||
@@ -310,33 +175,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | |||
megdnn::PostprocessMode::FLOAT) | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||
megdnn::PostprocessMode::FLOAT) | |||
#else | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
#endif | |||
#endif | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | |||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||
megdnn::PostprocessMode::QUANTIZED) | |||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
#endif | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||
megdnn::PostprocessMode::QUANTIZED) | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
megdnn::PostprocessMode::NO_PROCESS) | |||
#undef INSTANTIAL_CLASS | |||
} // namespace megdnn | |||
@@ -26,7 +26,7 @@ | |||
using namespace megdnn; | |||
using namespace fallback; | |||
size_t megdnn::fallback::get_format_pack_size(param::ConvBias::Format format) { | |||
size_t megdnn::fallback::pack_size(param::ConvBias::Format format) { | |||
switch (format) { | |||
case param::ConvBias::Format::NCHW44: | |||
case param::ConvBias::Format::NCHW44_DOT: | |||
@@ -23,8 +23,10 @@ namespace fallback { | |||
/*! | |||
* \brief get the pack_size according to the format | |||
* Note TODO: when remove format from param, | |||
* may using like this "opr::param::format specify" | |||
* */ | |||
size_t get_format_pack_size(param::ConvBias::Format format); | |||
size_t pack_size(param::ConvBias::Format format); | |||
/*! | |||
* \brief fallback conv bias forward impl | |||
@@ -52,9 +52,21 @@ class GemmInterleaved<Strategy, true> { | |||
} | |||
size_t get_b_workspace_size() const { | |||
#if __ARM_FEATURE_DOTPROD | |||
size_t new_blockn = m_strategy.block_n; | |||
if (m_strategy.KERNEL_W == 6 && m_strategy.UNROLL_K == 4 && | |||
m_strategy.KERNEL_H == 8) { | |||
new_blockn = round_up<size_t>((m_strategy.block_n-1) % 6, 4) + | |||
m_strategy.block_n / 6 * 6; | |||
} | |||
size_t N = round_up(new_blockn, m_strategy.KERNEL_W); | |||
size_t K = round_up(m_strategy.block_k, m_strategy.UNROLL_K); | |||
return round_up(sizeof(stype) * N * K, CACHELINE_SIZE) + m_align_size; | |||
#else | |||
size_t N = round_up(m_strategy.block_n, m_strategy.KERNEL_W); | |||
size_t K = round_up(m_strategy.block_k, m_strategy.UNROLL_K); | |||
return round_up(sizeof(stype) * N * K, CACHELINE_SIZE) + m_align_size; | |||
#endif | |||
} | |||
//! temporary storage for output, post process such as add bias or relu will | |||
@@ -268,7 +268,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) { | |||
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | |||
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", false); | |||
#else | |||
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | |||
benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384", | |||
"IM2COLMATMUL:ARMV7_F32:192", true); | |||
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | |||
"IM2COLMATMUL:ARMV7_F32:192", false); | |||
@@ -72,10 +72,12 @@ std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args( | |||
std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | |||
std::vector<size_t> kernel_vec, size_t stride, bool no_pad = false, | |||
bool no_bias = false, bool no_nonlinemode = false, | |||
bool is_input_nchw = false, bool support_full_bias = false, | |||
bool support_sigmoid = false) { | |||
bool is_input_nchw = false, bool is_nchw44_dot = false, | |||
bool support_full_bias = false, bool support_sigmoid = false, | |||
bool only_no_bias = false) { | |||
using namespace conv_bias; | |||
using NLMode = param::ConvBias::NonlineMode; | |||
std::vector<TestArg> args; | |||
auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, | |||
@@ -102,7 +104,11 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | |||
size_t kernel_h = kernel; | |||
size_t kernel_w = kernel; | |||
param::ConvBias param; | |||
param.format = param::ConvBias::Format::NCHW44; | |||
if (!is_nchw44_dot) { | |||
param.format = param::ConvBias::Format::NCHW44; | |||
} else { | |||
param.format = param::ConvBias::Format::NCHW44_DOT; | |||
} | |||
param.stride_h = stride; | |||
param.stride_w = stride; | |||
param.pad_h = pad; | |||
@@ -155,18 +161,22 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | |||
if (support_sigmoid) { | |||
nonlinemode.emplace_back(NLMode::SIGMOID); | |||
} | |||
std::vector<megdnn::BiasMode> bias_mode = { | |||
megdnn::BiasMode::BROADCAST_CHANNEL_BIAS}; | |||
if (no_bias) { | |||
std::vector<megdnn::BiasMode> bias_mode; | |||
if (!only_no_bias) { | |||
bias_mode.emplace_back(megdnn::BiasMode::BROADCAST_CHANNEL_BIAS); | |||
if (no_bias) { | |||
bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS); | |||
} | |||
} else { | |||
bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS); | |||
} | |||
if (support_full_bias) { | |||
bias_mode.emplace_back(megdnn::BiasMode::BIAS); | |||
bias_mode.emplace_back(megdnn::BiasMode::BIAS); | |||
} | |||
for (auto bias : bias_mode) | |||
for (auto nlmode : nonlinemode) | |||
for (size_t n : {1, 2}) | |||
for (size_t n : {1,2}) | |||
for (size_t kernel : kernel_vec) | |||
for (size_t oc : {4, 12}) | |||
for (size_t ic : {1, 3, 4, 12}) | |||
@@ -361,19 +371,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) { | |||
check_conv_bias(get_nchw44_conv_bias_args({2, 3}, 1, false, false, false, | |||
false, true, true), | |||
false, false, true, true), | |||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) { | |||
check_conv_bias(get_nchw44_conv_bias_args({5}, 1, false, false, false, | |||
false, true, true), | |||
false, false, true, true), | |||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) { | |||
check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, | |||
false, false, true, true), | |||
false, false, false, true, true), | |||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||
} | |||
@@ -1420,6 +1430,111 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) { | |||
#endif | |||
#undef cb | |||
} | |||
#if __ARM_FEATURE_DOTPROD | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) { | |||
UniformIntRNG rng{-50, 50}; | |||
#define cb(name) \ | |||
checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, \ | |||
false, false, false, true), \ | |||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||
dtype::QuantizedS8(60.25f), name); \ | |||
checker_conv_bias( \ | |||
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true), \ | |||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||
dtype::QuantizedS8(60.25f), name); | |||
float epsilon = 0.001; | |||
#if MEGDNN_AARCH64 | |||
cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96"); | |||
#elif MEGDNN_ARMV7 | |||
cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X6X4_DOTPROD:96"); | |||
#endif | |||
#undef cb | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) { | |||
UniformIntRNG rng{-50, 50}; | |||
#define cb(name) \ | |||
checker_conv_bias( \ | |||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||
true, false, true, false, false, true), \ | |||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \ | |||
checker_conv_bias( \ | |||
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | |||
false, false, true), \ | |||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); | |||
float epsilon = 0.001; | |||
#if MEGDNN_AARCH64 | |||
cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96"); | |||
#elif MEGDNN_ARMV7 | |||
cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X6X4_DOTPROD:96"); | |||
#endif | |||
#undef cb | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32_MK4_DOT) { | |||
UniformIntRNG rng{-50, 50}; | |||
#define cb(name) \ | |||
checker_conv_bias( \ | |||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||
true, false, true, false, false, true), \ | |||
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | |||
dtype::Int32(), {}, name); \ | |||
checker_conv_bias( \ | |||
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | |||
false, false, true), \ | |||
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | |||
dtype::Int32(), {}, name); | |||
float epsilon = 0.001; | |||
#if MEGDNN_AARCH64 | |||
cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96"); | |||
#elif MEGDNN_ARMV7 | |||
cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X6X4_DOTPROD:96"); | |||
#endif | |||
#undef cb | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) { | |||
UniformIntRNG rng{-50, 50}; | |||
#define cb(name) \ | |||
checker_conv_bias( \ | |||
get_nchw44_conv_bias_args({1}, 1, true, true, false, false, true), \ | |||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||
dtype::QuantizedS8(60.25f), name); \ | |||
checker_conv_bias( \ | |||
get_nchw44_conv_bias_args({1}, 1, true, true, true, false, true, \ | |||
false, false, true), \ | |||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \ | |||
checker_conv_bias( \ | |||
get_nchw44_conv_bias_args({1}, 1, true, true, true, false, true, \ | |||
false, false, true), \ | |||
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | |||
dtype::Int32(), {}, name); | |||
float epsilon = 0.001; | |||
#if MEGDNN_AARCH64 | |||
cb("CONV1x1:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"); | |||
#elif MEGDNN_ARMV7 | |||
cb("CONV1x1:AARCH32_INT8_MK4_8X6X4_DOTPROD"); | |||
#endif | |||
#undef cb | |||
} | |||
#endif | |||
// clang-format on | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) { | |||
@@ -1685,8 +1800,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { | |||
using namespace conv_bias; | |||
std::vector<conv_bias::TestArg> args = | |||
get_nchw44_conv_bias_args({2, 4, 7}, 1); | |||
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | |||
{2, 4, 7}, 1, false, false, false, false, false, true,true); | |||
#if MEGDNN_AARCH64 | |||
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | |||
#elif MEGDNN_ARMV7 | |||
@@ -1696,8 +1811,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) { | |||
using namespace conv_bias; | |||
std::vector<conv_bias::TestArg> args = | |||
get_nchw44_conv_bias_args({3, 5, 6}, 2); | |||
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | |||
{3, 5, 6}, 2, false, false, false, false, false, true, true); | |||
#if MEGDNN_AARCH64 | |||
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | |||
#elif MEGDNN_ARMV7 | |||
@@ -897,6 +897,62 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32) { | |||
#undef cb | |||
} | |||
#if MEGDNN_X86_WITH_MKL || MEGDNN_X86_WITH_OPENBLAS | |||
TEST_F(X86, CONV_BIAS_IM2COLMATMUL_FP32) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args; | |||
auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, | |||
size_t p, NonlineMode nonline_mode) { | |||
if (w + 2 * p < kernel || h + 2 * p < kernel) | |||
return; | |||
param::ConvBias param; | |||
param.stride_h = 1; | |||
param.stride_w = 1; | |||
param.pad_h = p; | |||
param.pad_w = p; | |||
param.nonlineMode = nonline_mode; | |||
//! no bias | |||
args.emplace_back(param, TensorShape{1, ic, h, w}, | |||
TensorShape{oc, ic, kernel, kernel}, TensorShape{}); | |||
args.emplace_back(param, TensorShape{1, ic, h, w}, | |||
TensorShape{oc, ic, kernel, kernel}, | |||
TensorShape{1, oc, 1, 1}); | |||
args.emplace_back( | |||
param, TensorShape{1, ic, h, w}, | |||
TensorShape{oc, ic, kernel, kernel}, | |||
TensorShape{1, oc, (h + 2 * p - kernel) / param.stride_h + 1, | |||
(w + 2 * p - kernel) / param.stride_w + 1}); | |||
}; | |||
for (size_t kernel : {2, 3, 4, 5, 6, 7}) | |||
for (size_t ic : {1, 4, 8, 16}) | |||
for (size_t oc : {1, 4, 8, 16, 300}) | |||
for (size_t p : {0, 2}) | |||
for (size_t size : {8, 24}) | |||
for (NonlineMode nonline_mode : | |||
{NonlineMode::IDENTITY, NonlineMode::RELU}) { | |||
run(oc, ic, size, size, kernel, p, nonline_mode); | |||
} | |||
run(2046, 8, 20, 20, 3, 1, NonlineMode::IDENTITY); | |||
Checker<ConvBias> checker(handle()); | |||
#define cb(algo_name) \ | |||
checker.set_before_exec_callback( \ | |||
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name)); \ | |||
for (auto&& arg : args) { \ | |||
checker.set_param(arg.param).execs( \ | |||
{arg.src, arg.filter, arg.bias, {}, {}}); \ | |||
} | |||
cb("IM2COLMATMUL:X86_F32_BLAS"); | |||
#undef cb | |||
} | |||
#endif | |||
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM | |||
TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32_PACKA) { | |||
using namespace conv_bias; | |||