im2co and conv1x1 mk4_dot support
GitOrigin-RevId: 096b16a3ab
tags/v0.5.0
@@ -913,10 +913,10 @@ static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
*outptr++ = *inptr++; | *outptr++ = *inptr++; | ||||
} | } | ||||
for (; i < 4; i++) { | for (; i < 4; i++) { | ||||
*outptr++ = *inptr++; | |||||
*outptr++ = *inptr++; | |||||
*outptr++ = *inptr++; | |||||
*outptr++ = *inptr++; | |||||
*outptr++ = 0; | |||||
*outptr++ = 0; | |||||
*outptr++ = 0; | |||||
*outptr++ = 0; | |||||
} | } | ||||
} | } | ||||
@@ -39,7 +39,7 @@ namespace { | |||||
megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::arm_common::VEC>:: \ | megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::arm_common::VEC>:: \ | ||||
run(static_cast<ctype*>(conv_dst_ptr), \ | run(static_cast<ctype*>(conv_dst_ptr), \ | ||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, dst_type, \ | reinterpret_cast<ctype*>(dst_ptr), bias_type, dst_type, \ | ||||
N* OC* OH* OW); | |||||
N* OC* OH* OW* pack_oc_size); | |||||
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | ||||
megdnn::arm_common:: \ | megdnn::arm_common:: \ | ||||
@@ -63,7 +63,7 @@ namespace { | |||||
static_cast<ctype*>(conv_dst_ptr), \ | static_cast<ctype*>(conv_dst_ptr), \ | ||||
reinterpret_cast<const ctype*>(bias_ptr), \ | reinterpret_cast<const ctype*>(bias_ptr), \ | ||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | ||||
dst_type, N* OC* OH* OW); | |||||
dst_type, N* OC* OH* OW* pack_oc_size); | |||||
#define FOR_BIAS(_mode) \ | #define FOR_BIAS(_mode) \ | ||||
switch (_mode) { \ | switch (_mode) { \ | ||||
@@ -113,7 +113,6 @@ struct PostProcess { | |||||
megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | ||||
megdnn::DType bias_type, megdnn::DType dst_type, size_t N, | megdnn::DType bias_type, megdnn::DType dst_type, size_t N, | ||||
size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { | size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { | ||||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||||
FOR_BIAS(bias_mode) | FOR_BIAS(bias_mode) | ||||
} | } | ||||
}; | }; | ||||
@@ -155,7 +154,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||||
_op<opctype, opdtype>, \ | _op<opctype, opdtype>, \ | ||||
megdnn::arm_common::VEC>::run(static_cast<opctype*>(conv_dst_ptr), \ | megdnn::arm_common::VEC>::run(static_cast<opctype*>(conv_dst_ptr), \ | ||||
reinterpret_cast<opdtype*>(dst_ptr), \ | reinterpret_cast<opdtype*>(dst_ptr), \ | ||||
bias_type, dst_type, N* OC* OH* OW); | |||||
bias_type, dst_type, \ | |||||
N* OC* OH* OW* pack_oc_size); | |||||
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | ||||
megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, \ | megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, \ | ||||
@@ -173,8 +173,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||||
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \ | reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \ | ||||
dst_type, N, OC, OH* OW, pack_oc_size); | dst_type, N, OC, OH* OW, pack_oc_size); | ||||
#define HANDLE_IDENTITY(_caller, _op) \ | |||||
case megdnn::NonlineMode::IDENTITY: \ | |||||
#define HANDLE_IDENTITY(_caller, _op) \ | |||||
case megdnn::NonlineMode::IDENTITY: \ | |||||
_caller(_op) break; | _caller(_op) break; | ||||
#define FOR_NONLINEAR(_caller) \ | #define FOR_NONLINEAR(_caller) \ | ||||
@@ -729,10 +729,10 @@ static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
*outptr++ = *inptr++; | *outptr++ = *inptr++; | ||||
} | } | ||||
for (; i < 4; i++) { | for (; i < 4; i++) { | ||||
*outptr++ = *inptr++; | |||||
*outptr++ = *inptr++; | |||||
*outptr++ = *inptr++; | |||||
*outptr++ = *inptr++; | |||||
*outptr++ = 0; | |||||
*outptr++ = 0; | |||||
*outptr++ = 0; | |||||
*outptr++ = 0; | |||||
} | } | ||||
} | } | ||||
outptr_base += 24; | outptr_base += 24; | ||||
@@ -187,7 +187,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||||
AlgoSelectionStrategy) const { | AlgoSelectionStrategy) const { | ||||
MIDOUT_BEGIN(megdnn_fallback_conv1x1, 0, 2) { | MIDOUT_BEGIN(megdnn_fallback_conv1x1, 0, 2) { | ||||
if (opr->param().format != param::ConvBias::Format::NCHW && | if (opr->param().format != param::ConvBias::Format::NCHW && | ||||
opr->param().format != param::ConvBias::Format::NCHW44) | |||||
opr->param().format != param::ConvBias::Format::NCHW44 && | |||||
opr->param().format != param::ConvBias::Format::NCHW44_DOT) | |||||
return false; | return false; | ||||
size_t FH = param.filter_meta.spatial[0], | size_t FH = param.filter_meta.spatial[0], | ||||
@@ -219,8 +220,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) | param.nonlineMode != megdnn::NonlineMode::IDENTITY) | ||||
return false; | return false; | ||||
if (opr->param().format == param::ConvBias::Format::NCHW44) { | |||||
//! nchw44 hybird mode and channel wise is not support | |||||
if (opr->param().format == param::ConvBias::Format::NCHW44 || | |||||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | |||||
if (param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 || | if (param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 || | ||||
param.filter_meta.ocpg == 1) { | param.filter_meta.ocpg == 1) { | ||||
return false; | return false; | ||||
@@ -73,32 +73,34 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||||
const ConvBiasImpl::NCBKernSizeParam& param, | const ConvBiasImpl::NCBKernSizeParam& param, | ||||
MatrixMulImpl::AlgoBase::PackMode pack_mode, | MatrixMulImpl::AlgoBase::PackMode pack_mode, | ||||
param::ConvBias::Format format) { | param::ConvBias::Format format) { | ||||
size_t pack_size = get_format_pack_size(format); | |||||
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||||
midout_iv(_midout_tag)) { \ | |||||
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||||
return std::make_unique< \ | |||||
Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||||
_postprocess_mode, _packmode>>(pack_size); \ | |||||
} \ | |||||
} \ | |||||
size_t pack_c_size = pack_size(format); | |||||
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||||
midout_iv(_midout_tag)) { \ | |||||
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||||
return std::make_unique< \ | |||||
Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||||
_postprocess_mode, _packmode>>( \ | |||||
pack_c_size); \ | |||||
} \ | |||||
} \ | |||||
MIDOUT_END() | MIDOUT_END() | ||||
#define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||||
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||||
midout_iv(_midout_tag)) { \ | |||||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||||
return std::make_unique< \ | |||||
Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||||
DTypeTrait<_i_bias_type>::ctype, \ | |||||
DTypeTrait<_i_dst_type>::ctype, \ | |||||
_postprocess_mode, _packmode>>(pack_size); \ | |||||
} \ | |||||
} \ | |||||
#define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||||
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||||
midout_iv(_midout_tag)) { \ | |||||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||||
return std::make_unique< \ | |||||
Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||||
DTypeTrait<_i_bias_type>::ctype, \ | |||||
DTypeTrait<_i_dst_type>::ctype, \ | |||||
_postprocess_mode, _packmode>>( \ | |||||
pack_c_size); \ | |||||
} \ | |||||
} \ | |||||
MIDOUT_END() | MIDOUT_END() | ||||
switch (pack_mode) { | switch (pack_mode) { | ||||
@@ -12,7 +12,6 @@ | |||||
#pragma once | #pragma once | ||||
#include "megdnn/opr_param_defs.h" | |||||
#include "src/fallback/conv_bias/opr_impl.h" | #include "src/fallback/conv_bias/opr_impl.h" | ||||
#if MEGDNN_X86 | #if MEGDNN_X86 | ||||
#include "src/x86/conv_bias/postprocess_helper.h" | #include "src/x86/conv_bias/postprocess_helper.h" | ||||
@@ -41,12 +40,15 @@ MatrixMulImpl::KernSizeParam get_matmul_kern_param( | |||||
param.dst_type.enumv() == DTypeEnum::QuantizedS8) || | param.dst_type.enumv() == DTypeEnum::QuantizedS8) || | ||||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm && | (param.src_type.enumv() == DTypeEnum::Quantized8Asymm && | ||||
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); | param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); | ||||
size_t pack_c_size = 1_z; | |||||
size_t pack_c_size = pack_size(param.filter_meta.format); | |||||
auto format = param::MatrixMul::Format::DEFAULT; | auto format = param::MatrixMul::Format::DEFAULT; | ||||
if(param.filter_meta.format == param::ConvBias::Format::NCHW44){ | |||||
pack_c_size = 4_z; | |||||
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { | |||||
format = param::MatrixMul::Format::MK4; | format = param::MatrixMul::Format::MK4; | ||||
} else if (param.filter_meta.format == | |||||
param::ConvBias::Format::NCHW44_DOT) { | |||||
format = param::MatrixMul::Format::MK4_DOT; | |||||
} | } | ||||
return {param.filter_type, | return {param.filter_type, | ||||
param.src_type, | param.src_type, | ||||
is_dst_8bit ? param.bias_type : param.dst_type, | is_dst_8bit ? param.bias_type : param.dst_type, | ||||
@@ -15,7 +15,6 @@ | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/opr_impl.h" | #include "src/fallback/conv_bias/opr_impl.h" | ||||
#include "src/fallback/conv_bias/winograd/strategy.h" | |||||
#include "src/naive/convolution/helper.h" | #include "src/naive/convolution/helper.h" | ||||
#include "midout.h" | #include "midout.h" | ||||
@@ -125,7 +124,7 @@ public: | |||||
size_t oc_tile_size) { | size_t oc_tile_size) { | ||||
size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0], | size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0], | ||||
FW = param.filter_meta.spatial[1]; | FW = param.filter_meta.spatial[1]; | ||||
size_t pack_oc_size = get_format_pack_size(param.filter_meta.format); | |||||
size_t pack_oc_size = pack_size(param.filter_meta.format); | |||||
size_t im2col = 0, packb = 0, bias_temp = 0; | size_t im2col = 0, packb = 0, bias_temp = 0; | ||||
bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT; | bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT; | ||||
megdnn_assert(default_pack, "only support default packa"); | megdnn_assert(default_pack, "only support default packa"); | ||||
@@ -319,9 +318,11 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, | |||||
size_t ohw_tile_size, | size_t ohw_tile_size, | ||||
size_t oc_tile_size) const { | size_t oc_tile_size) const { | ||||
auto format = param::MatrixMul::Format::DEFAULT; | auto format = param::MatrixMul::Format::DEFAULT; | ||||
size_t pack_oc_size = get_format_pack_size(param.filter_meta.format); | |||||
size_t pack_oc_size = pack_size(param.filter_meta.format); | |||||
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { | if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { | ||||
format = param::MatrixMul::Format::MK4; | format = param::MatrixMul::Format::MK4; | ||||
} else if(param.filter_meta.format == param::ConvBias::Format::NCHW44_DOT){ | |||||
format = param::MatrixMul::Format::MK4_DOT; | |||||
} | } | ||||
size_t M = oc_tile_size; | size_t M = oc_tile_size; | ||||
size_t N = ohw_tile_size; | size_t N = ohw_tile_size; | ||||
@@ -351,11 +352,10 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, | |||||
void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( | void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( | ||||
const NCBKernSizeParam& param, size_t& oc_tile_size, | const NCBKernSizeParam& param, size_t& oc_tile_size, | ||||
size_t& ohw_tile_size, size_t block_m, size_t block_n, | size_t& ohw_tile_size, size_t block_m, size_t block_n, | ||||
bool need_pack) const { | |||||
fallback::MatrixMulImpl::AlgoBase::PackMode pack_mode) const { | |||||
size_t nr_threads = param.nr_threads; | size_t nr_threads = param.nr_threads; | ||||
size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
size_t ohw = param.osz[0] * param.osz[1]; | size_t ohw = param.osz[0] * param.osz[1]; | ||||
oc_tile_size = DEFAULT_OC_TILE_SIZE; | oc_tile_size = DEFAULT_OC_TILE_SIZE; | ||||
ohw_tile_size = m_ohw_tile_size; | ohw_tile_size = m_ohw_tile_size; | ||||
@@ -376,7 +376,8 @@ void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( | |||||
} | } | ||||
} | } | ||||
} else { | } else { | ||||
if (!need_pack) { //! no pack ,usually in x86 save memroy | |||||
//! in no_pack mode don't do block operation when using single thread | |||||
if (pack_mode == fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { | |||||
ohw_tile_size = ohw; | ohw_tile_size = ohw; | ||||
oc_tile_size = OC; | oc_tile_size = OC; | ||||
} | } | ||||
@@ -406,7 +407,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||||
if (need_pack || only_packA) { | if (need_pack || only_packA) { | ||||
auto inner_block = m_matmul_algo->get_inner_block_size(); | auto inner_block = m_matmul_algo->get_inner_block_size(); | ||||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, inner_block.m, | choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, inner_block.m, | ||||
inner_block.n, need_pack); | |||||
inner_block.n, m_matmul_algo->packmode()); | |||||
auto im2col_kern_param = get_matmul_kern_param( | auto im2col_kern_param = get_matmul_kern_param( | ||||
param, ohw_tile_size, only_packA ? oc_tile_size : OC); | param, ohw_tile_size, only_packA ? oc_tile_size : OC); | ||||
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | ||||
@@ -418,7 +419,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||||
size_t nopack_default_blockn = 16; | size_t nopack_default_blockn = 16; | ||||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | ||||
nopack_default_blockm, nopack_default_blockn, | nopack_default_blockm, nopack_default_blockn, | ||||
need_pack); | |||||
m_matmul_algo->packmode()); | |||||
packa_group_size = 0; | packa_group_size = 0; | ||||
} | } | ||||
@@ -488,19 +489,20 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||||
if (default_pack || only_packA) { | if (default_pack || only_packA) { | ||||
auto inner_block = m_matmul_algo->get_inner_block_size(); | auto inner_block = m_matmul_algo->get_inner_block_size(); | ||||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | ||||
inner_block.m, inner_block.n, default_pack); | |||||
} else { //! not support pack,not need pack | |||||
inner_block.m, inner_block.n, | |||||
m_matmul_algo->packmode()); | |||||
} else { //! nopack_mode | |||||
size_t nopack_default_blockm = 8; | size_t nopack_default_blockm = 8; | ||||
size_t nopack_default_blockn = 16; | size_t nopack_default_blockn = 16; | ||||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | ||||
nopack_default_blockm, nopack_default_blockn, | nopack_default_blockm, nopack_default_blockn, | ||||
no_pack); | |||||
m_matmul_algo->packmode()); | |||||
} | } | ||||
size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size); | size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size); | ||||
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | ||||
size_t packa_parallel_times = 0; | size_t packa_parallel_times = 0; | ||||
size_t pack_oc_size = get_format_pack_size(param.filter_meta.format); | |||||
size_t pack_oc_size = pack_size(param.filter_meta.format); | |||||
if (only_packA) { | if (only_packA) { | ||||
packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | ||||
@@ -639,9 +641,15 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||||
ConvBiasImpl* opr, const NCBKernSizeParam& param, | ConvBiasImpl* opr, const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | AlgoSelectionStrategy /*algo_selection_strategy*/) const { | ||||
MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 2) { | MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 2) { | ||||
if (opr->param().format != param::ConvBias::Format::NCHW && | |||||
opr->param().format != param::ConvBias::Format::NCHW44_DOT && | |||||
opr->param().format != param::ConvBias::Format::NCHW44) { | |||||
return false; | |||||
} | |||||
//! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode is | //! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode is | ||||
//! identity otherwise return false mean that 8x8x32 and 8x8x16 not support | |||||
//! PostProcess | |||||
//! identity otherwise return false mean that 8x8x32 and 8x8x16 not | |||||
//! support PostProcess | |||||
if (param.src_type.enumv() == param.filter_type.enumv() && | if (param.src_type.enumv() == param.filter_type.enumv() && | ||||
((param.src_type.enumv() == DTypeEnum::Int8 && | ((param.src_type.enumv() == DTypeEnum::Int8 && | ||||
(param.dst_type.enumv() == DTypeEnum::Int16 || | (param.dst_type.enumv() == DTypeEnum::Int16 || | ||||
@@ -653,9 +661,10 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | ||||
return false; | return false; | ||||
} | } | ||||
if (opr->param().format == param::ConvBias::Format::NCHW44) { | |||||
if (opr->param().format == param::ConvBias::Format::NCHW44 || | |||||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | |||||
//! current NCHW44 im2col only support DEFAULT mode matmul | //! current NCHW44 im2col only support DEFAULT mode matmul | ||||
if(m_matmul_algo->packmode() != Pack_Mode::DEFAULT) { | |||||
if (m_matmul_algo->packmode() != Pack_Mode::DEFAULT) { | |||||
return false; | return false; | ||||
//! nchw44 hybird mode and channel wise is not support | //! nchw44 hybird mode and channel wise is not support | ||||
} else if (param.filter_meta.icpg < 4_z || | } else if (param.filter_meta.icpg < 4_z || | ||||
@@ -668,29 +677,27 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||||
size_t oc_tile_size = 0, ohw_tile_size = 0; | size_t oc_tile_size = 0, ohw_tile_size = 0; | ||||
Pack_Mode packmode = m_matmul_algo->packmode(); | Pack_Mode packmode = m_matmul_algo->packmode(); | ||||
bool default_pack = packmode == Pack_Mode::DEFAULT; | bool default_pack = packmode == Pack_Mode::DEFAULT; | ||||
bool no_pack = packmode == Pack_Mode::NO_PACK; | |||||
bool only_packA = packmode == Pack_Mode::ONLY_PACKA; | bool only_packA = packmode == Pack_Mode::ONLY_PACKA; | ||||
if (default_pack || only_packA) { | if (default_pack || only_packA) { | ||||
auto inner_block = m_matmul_algo->get_inner_block_size(); | auto inner_block = m_matmul_algo->get_inner_block_size(); | ||||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | ||||
inner_block.m, inner_block.n, default_pack); | |||||
inner_block.m, inner_block.n, | |||||
m_matmul_algo->packmode()); | |||||
} else { //! not support pack,not need pack | } else { //! not support pack,not need pack | ||||
size_t nopack_default_blockm = 8; | size_t nopack_default_blockm = 8; | ||||
size_t nopack_default_blockn = 16; | size_t nopack_default_blockn = 16; | ||||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | ||||
nopack_default_blockm, nopack_default_blockn, | nopack_default_blockm, nopack_default_blockn, | ||||
no_pack); | |||||
m_matmul_algo->packmode()); | |||||
} | } | ||||
fallback::MatrixMulImpl::KernSizeParam matmul_param = | fallback::MatrixMulImpl::KernSizeParam matmul_param = | ||||
get_matmul_kern_param(param, ohw_tile_size, oc_tile_size); | get_matmul_kern_param(param, ohw_tile_size, oc_tile_size); | ||||
bool matmulusable = m_matmul_algo->usable(matmul_param); | bool matmulusable = m_matmul_algo->usable(matmul_param); | ||||
return matmulusable && | return matmulusable && | ||||
(opr->param().format == param::ConvBias::Format::NCHW || | |||||
opr->param().format == param::ConvBias::Format::NCHW44) && | |||||
(!(param.filter_meta.spatial[0] == | (!(param.filter_meta.spatial[0] == | ||||
param.filter_meta.spatial[1] && | param.filter_meta.spatial[1] && | ||||
(param.filter_meta.spatial[0] == 1) && | |||||
param.filter_meta.spatial[0] == 1 && | |||||
param.filter_meta.stride[0] == param.filter_meta.stride[1] && | param.filter_meta.stride[0] == param.filter_meta.stride[1] && | ||||
param.filter_meta.stride[0] == 1)) && | param.filter_meta.stride[0] == 1)) && | ||||
(param.filter_meta.dilation[0] == | (param.filter_meta.dilation[0] == | ||||
@@ -36,10 +36,10 @@ class ConvBiasImpl::AlgoIm2col final : public AlgoBase { | |||||
const NCBKernSizeParam& param, size_t ohw_tile_size, | const NCBKernSizeParam& param, size_t ohw_tile_size, | ||||
size_t oc_tile_size) const; | size_t oc_tile_size) const; | ||||
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; | WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; | ||||
void choice_ohw_oc_block(const NCBKernSizeParam& param, | |||||
size_t& oc_tile_size, size_t& ohw_tile_size, | |||||
size_t block_m, size_t block_n, | |||||
bool pack_default) const; | |||||
void choice_ohw_oc_block( | |||||
const NCBKernSizeParam& param, size_t& oc_tile_size, | |||||
size_t& ohw_tile_size, size_t block_m, size_t block_n, | |||||
fallback::MatrixMulImpl::AlgoBase::PackMode pack_mode) const; | |||||
public: | public: | ||||
AlgoIm2col(MatrixMulImpl::AlgoBase* matmul_algo, size_t ohw_tile_size) | AlgoIm2col(MatrixMulImpl::AlgoBase* matmul_algo, size_t ohw_tile_size) | ||||
@@ -230,7 +230,11 @@ public: | |||||
PostprocessMode::FLOAT, | PostprocessMode::FLOAT, | ||||
"DefaultStrategyTypeNCHW44::FLOAT"_hash); | "DefaultStrategyTypeNCHW44::FLOAT"_hash); | ||||
} else { | } else { | ||||
megdnn_throw("not support format except nchw44 and nchw\n"); | |||||
megdnn_throw( | |||||
ssprintf("Current only support layout " | |||||
"NCHW44/NCHW for im2col " | |||||
"algo, but got %d\n", | |||||
uint32_t(format))); | |||||
} | } | ||||
break; | break; | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
@@ -252,12 +256,17 @@ public: | |||||
cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | ||||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | ||||
"DefaultStrategyType::INT8x8x32"_hash); | "DefaultStrategyType::INT8x8x32"_hash); | ||||
} else if (format == param::ConvBias::Format::NCHW44) { | |||||
} else if (format == param::ConvBias::Format::NCHW44 || | |||||
format == param::ConvBias::Format::NCHW44_DOT) { | |||||
cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | ||||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | ||||
"DefaultStrategyType::INT8x8x32"_hash); | "DefaultStrategyType::INT8x8x32"_hash); | ||||
} else { | } else { | ||||
megdnn_throw("not support format except nchw44 and nchw\n"); | |||||
megdnn_throw( | |||||
ssprintf("Current only support layout " | |||||
"NCHW44/NCHW/NCHW_DOT for im2col " | |||||
"algo, but got %d\n", | |||||
uint32_t(format))); | |||||
} | } | ||||
break; | break; | ||||
@@ -288,13 +297,18 @@ public: | |||||
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | ||||
PostprocessMode::NO_PROCESS, | PostprocessMode::NO_PROCESS, | ||||
"DefaultStrategyTypeNCHW::QINT8x8x32"_hash); | "DefaultStrategyTypeNCHW::QINT8x8x32"_hash); | ||||
} else if (format == param::ConvBias::Format::NCHW44) { | |||||
} else if (format == param::ConvBias::Format::NCHW44 || | |||||
format == param::ConvBias::Format::NCHW44_DOT) { | |||||
cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | ||||
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, | dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, | ||||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | ||||
"DefaultStrategyTypeHCHW44::QINT8x8x32"_hash); | "DefaultStrategyTypeHCHW44::QINT8x8x32"_hash); | ||||
} else { | } else { | ||||
megdnn_throw("not support format except nchw44 and nchw\n"); | |||||
megdnn_throw( | |||||
ssprintf("Current only support layout " | |||||
"NCHW44/NCHW/NCHW_DOT for im2col " | |||||
"algo, but got %d\n", | |||||
uint32_t(format))); | |||||
} | } | ||||
break; | break; | ||||
@@ -304,17 +318,22 @@ public: | |||||
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | ||||
PostprocessMode::QUANTIZED, | PostprocessMode::QUANTIZED, | ||||
"DefaultStrategyType::QINT8x8x32x8"_hash); | "DefaultStrategyType::QINT8x8x32x8"_hash); | ||||
} else if (format == param::ConvBias::Format::NCHW44) { | |||||
} else if (format == param::ConvBias::Format::NCHW44 || | |||||
format == param::ConvBias::Format::NCHW44_DOT) { | |||||
cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | ||||
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, | dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, | ||||
dt_int32, dt_int8, PostprocessMode::QUANTIZED, | dt_int32, dt_int8, PostprocessMode::QUANTIZED, | ||||
"DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash); | "DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash); | ||||
} else { | } else { | ||||
megdnn_throw("not support format except nchw44 and nchw\n"); | |||||
megdnn_throw(ssprintf("Current only support layout " | |||||
"NCHW44/NCHW/NCHW_DOT for im2col " | |||||
"algo, but got %d\n", | |||||
uint32_t(format))); | |||||
} | } | ||||
break; | break; | ||||
} | } | ||||
megdnn_throw("error not support strategy type "); | |||||
megdnn_throw(ssprintf("Unsupported strategy type %u in default mode", | |||||
uint32_t(strategytype))); | |||||
} | } | ||||
static std::unique_ptr<StrategyBase> make_nopack_strategy( | static std::unique_ptr<StrategyBase> make_nopack_strategy( | ||||
@@ -328,10 +347,6 @@ public: | |||||
PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash); | PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash); | ||||
break; | break; | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
case StrategyType::FLOAT_FP16: | |||||
cb1(NCHW, NO_PACK, dt_float16, __fp16, PostprocessMode::FLOAT, | |||||
"NoPackStrategyType::FLOAT_FP16"_hash); | |||||
break; | |||||
#else | #else | ||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
case StrategyType::FLOAT16_FLOAT16: | case StrategyType::FLOAT16_FLOAT16: | ||||
@@ -341,48 +356,24 @@ public: | |||||
break; | break; | ||||
#endif | #endif | ||||
#endif | #endif | ||||
case StrategyType::INT8x8x32: | |||||
cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
"NoPackStrategyType::INT8x8x32"_hash); | |||||
break; | |||||
case StrategyType::INT8x8x16: | case StrategyType::INT8x8x16: | ||||
cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | ||||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | ||||
"NoPackStrategyType::INT8x8x16"_hash); | "NoPackStrategyType::INT8x8x16"_hash); | ||||
break; | break; | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
case StrategyType::QUINT8x8x32: | |||||
cb2(NCHW, NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||||
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | |||||
PostprocessMode::NO_PROCESS, | |||||
"NoPackStrategyType::QUINT8x8x32"_hash); | |||||
break; | |||||
case StrategyType::QUINT8x8x32x8: | |||||
cb2(NCHW, NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||||
dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | |||||
PostprocessMode::QUANTIZED, | |||||
"NoPackStrategyType::QUINT8x8x32x8"_hash); | |||||
break; | |||||
#endif | |||||
case StrategyType::QINT8x8x32: | |||||
cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||||
PostprocessMode::NO_PROCESS, | |||||
"NoPackStrategyType::QINT8x8x32"_hash); | |||||
case StrategyType::INT8x8x32: | |||||
cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
"NoPackStrategyType::INT8x8x32"_hash); | |||||
break; | break; | ||||
case StrategyType::QINT8x8x32x8: | |||||
cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||||
PostprocessMode::QUANTIZED, | |||||
"NoPackStrategyType::QINT8x8x32x8"_hash); | |||||
default: | |||||
megdnn_throw( | |||||
ssprintf("Unsupported strategy type %u in no_pack mode", | |||||
uint32_t(strategytype))); | |||||
break; | break; | ||||
} | } | ||||
megdnn_throw("error not support strategy type "); | |||||
megdnn_throw(ssprintf("Unsupported strategy type %u in no_pack mode", | |||||
uint32_t(strategytype))); | |||||
} | } | ||||
static std::unique_ptr<StrategyBase> make_onlypacka_strategy( | static std::unique_ptr<StrategyBase> make_onlypacka_strategy( | ||||
@@ -396,63 +387,14 @@ public: | |||||
PostprocessMode::FLOAT, | PostprocessMode::FLOAT, | ||||
"OnlyPackaStrategyType::FLOAT"_hash); | "OnlyPackaStrategyType::FLOAT"_hash); | ||||
break; | break; | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
case StrategyType::FLOAT_FP16: | |||||
cb1(NCHW, ONLY_PACKA, dt_float16, __fp16, | |||||
PostprocessMode::FLOAT, | |||||
"OnlyPackaStrategyType::FLOAT_FP16"_hash); | |||||
break; | |||||
#else | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
case StrategyType::FLOAT16_FLOAT16: | |||||
cb1(NCHW, ONLY_PACKA, dt_float16, dt_float16, | |||||
PostprocessMode::NO_PROCESS, | |||||
"OnlyPackaStrategyType::FLOAT16_FLOAT16"_hash); | |||||
break; | |||||
#endif | |||||
#endif | |||||
case StrategyType::INT8x8x32: | |||||
cb2(NCHW, ONLY_PACKA, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
"OnlyPackaStrategyType::INT8x8x32"_hash); | |||||
break; | |||||
case StrategyType::INT8x8x16: | |||||
cb2(NCHW, ONLY_PACKA, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||||
"OnlyPackaStrategyType::INT8x8x16"_hash); | |||||
break; | |||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
case StrategyType::QUINT8x8x32: | |||||
cb2(NCHW, ONLY_PACKA, dtype::Quantized8Asymm, | |||||
dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, | |||||
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
"OnlyPackaStrategyType::QUINT8x8x32"_hash); | |||||
break; | |||||
case StrategyType::QUINT8x8x32x8: | |||||
cb2(NCHW, ONLY_PACKA, dtype::Quantized8Asymm, | |||||
dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8, | |||||
dt_int32, dt_uint8, PostprocessMode::QUANTIZED, | |||||
"OnlyPackaStrategyType::QUINT8x8x32x8"_hash); | |||||
break; | |||||
#endif | |||||
case StrategyType::QINT8x8x32: | |||||
cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||||
PostprocessMode::NO_PROCESS, | |||||
"OnlyPackaStrategyType::QINT8x8x32"_hash); | |||||
break; | |||||
case StrategyType::QINT8x8x32x8: | |||||
cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||||
PostprocessMode::QUANTIZED, | |||||
"OnlyPackaStrategyType::QINT8x8x32x8"_hash); | |||||
default: | |||||
megdnn_throw(ssprintf( | |||||
"Unsupported strategy type %u in onlypacka mode", | |||||
uint32_t(strategytype))); | |||||
break; | break; | ||||
} | } | ||||
megdnn_throw("error not support strategy type "); | |||||
megdnn_throw(ssprintf("Unsupported strategy type %u in onlypacka mode", | |||||
uint32_t(strategytype))); | |||||
} | } | ||||
#undef cb1 | #undef cb1 | ||||
@@ -11,6 +11,16 @@ | |||||
#pragma once | #pragma once | ||||
#include "src/fallback/conv_bias/opr_impl.h" | #include "src/fallback/conv_bias/opr_impl.h" | ||||
#if MEGDNN_X86 | |||||
#include "src/x86/conv_bias/postprocess_helper.h" | |||||
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#endif | |||||
using namespace megdnn; | |||||
#if MEGDNN_X86 | |||||
using namespace x86; | |||||
#endif | |||||
namespace megdnn { | namespace megdnn { | ||||
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; | using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; | ||||
@@ -75,6 +85,185 @@ public: | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode, PackMode packmode, | megdnn::PostprocessMode postprocess_mode, PackMode packmode, | ||||
FormatMode format> | |||||
//! this class is a new base class for StrategyDefault StrategyNoPack and so on, | |||||
//! in order to handle copy pad use the same code | |||||
class StrategyBridge : public StrategyBase { | |||||
public: | |||||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | |||||
StrategyBridge() = default; | |||||
virtual void copy_padding_kern( | |||||
WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_oc_size) override { | |||||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||||
MEGDNN_MARK_USED_VAR(N); | |||||
MEGDNN_MARK_USED_VAR(OC); | |||||
MEGDNN_MARK_USED_VAR(OH); | |||||
MEGDNN_MARK_USED_VAR(OW); | |||||
MEGDNN_MARK_USED_VAR(FH); | |||||
MEGDNN_MARK_USED_VAR(FW); | |||||
MEGDNN_MARK_USED_VAR(SH); | |||||
MEGDNN_MARK_USED_VAR(SW); | |||||
size_t IW2 = IW + 2 * PW; | |||||
size_t IH2 = IH + 2 * PH; | |||||
size_t batch_id = ncb_index.ndrange_id[0]; | |||||
size_t group_id = ncb_index.ndrange_id[1]; | |||||
size_t channel_id = ncb_index.ndrange_id[2]; | |||||
size_t PH_SIZE = PH * IW2 * pack_oc_size; | |||||
PW = PW * pack_oc_size; | |||||
IW = IW * pack_oc_size; | |||||
size_t padding_group_size = IH2 * IW2 * IC; | |||||
size_t workspace_channel_offset = pack_oc_size * IH2 * IW2 * channel_id; | |||||
size_t workspace_group_offset = group_id * padding_group_size; | |||||
size_t workspace_batch_offset = | |||||
param.filter_meta.group * batch_id * padding_group_size; | |||||
bundle.set(param.workspace_ptr); | |||||
src_ctype src_zp = static_cast<src_ctype>(0); | |||||
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||||
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | |||||
} | |||||
src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>( | |||||
batch_id, group_id, channel_id, 1, pack_oc_size)); | |||||
src_ctype* src2; | |||||
src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||||
workspace_group_offset + workspace_batch_offset + | |||||
workspace_channel_offset; | |||||
src_ctype* src2_ptr = src2; | |||||
const src_ctype* src_ptr = src; | |||||
if (PH != 0) { | |||||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||||
src2_ptr += PH_SIZE; | |||||
} | |||||
rep(ih, IH) { | |||||
if (PW != 0) | |||||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||||
std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW); | |||||
src2_ptr += IW; | |||||
src_ptr += IW; | |||||
if (PW != 0) | |||||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||||
} | |||||
if (PH != 0) { | |||||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||||
src2_ptr += PH_SIZE; | |||||
} | |||||
} | |||||
}; | |||||
namespace{ | |||||
template <typename bias_ctype> | |||||
inline void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const WorkspaceBundle& bundle_thread, | |||||
const StrategyParam& sparam, | |||||
size_t matmul_bundle_index) { | |||||
if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) { | |||||
return static_cast<void*>(bundle_thread.get(matmul_bundle_index)); | |||||
} else { | |||||
bias_ctype* dst = | |||||
param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) + | |||||
sparam.oc_cur_index * sparam.ohw; | |||||
return static_cast<void*>(dst); | |||||
} | |||||
} | |||||
template <typename bias_ctype> | |||||
inline void* get_bias_temp_ptr( | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const WorkspaceBundle& bundle_thread, size_t bias_bundle_index) { | |||||
bias_ctype* bias_tmp_ptr = | |||||
param.bias_mode == megdnn::BiasMode::BIAS | |||||
? static_cast<bias_ctype*>( | |||||
bundle_thread.get(bias_bundle_index)) | |||||
: nullptr; | |||||
return bias_tmp_ptr; | |||||
} | |||||
template <typename dst_ctype> | |||||
void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const void* matmul_dst, const StrategyParam& sparam) { | |||||
if (!sparam.skip_copy_dst) { | |||||
size_t pack_oc_size = sparam.pack_oc_size; | |||||
dst_ctype* dst_tmp_ptr = | |||||
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); | |||||
dst_ctype* dst = | |||||
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + | |||||
sparam.oc_cur_index * sparam.ohw + | |||||
sparam.ohw_cur_index * pack_oc_size; | |||||
size_t oc_loop = sparam.output_block_oc_size / pack_oc_size; | |||||
for (size_t oc = 0; oc < oc_loop; oc++) { | |||||
std::memcpy(dst, dst_tmp_ptr, | |||||
sizeof(dst_ctype) * sparam.output_block_size * | |||||
pack_oc_size); | |||||
dst_tmp_ptr += sparam.output_block_size * pack_oc_size; | |||||
dst += sparam.ohw * pack_oc_size; | |||||
} | |||||
} | |||||
} | |||||
template <typename bias_ctype> | |||||
void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
WorkspaceBundle bundle_thread, const StrategyParam& sparam, | |||||
size_t bias_index) { | |||||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||||
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||||
bias_ctype* bias_temp_ptr = static_cast<bias_ctype*>( | |||||
get_bias_temp_ptr<bias_ctype>(param, bundle_thread, bias_index)); | |||||
if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||||
bias_ctype* copy_dst = bias_temp_ptr; | |||||
size_t pack_oc_size = sparam.pack_oc_size; | |||||
const bias_ctype* copy_src = bias_ptr + | |||||
sparam.oc_cur_index * sparam.ohw + | |||||
sparam.ohw_cur_index * pack_oc_size; | |||||
for (size_t oc = sparam.oc_cur_index / pack_oc_size; | |||||
oc < sparam.oc_end_index / pack_oc_size; oc++) { | |||||
std::memcpy(copy_dst, copy_src, | |||||
sizeof(bias_ctype) * sparam.output_block_size * | |||||
pack_oc_size); | |||||
copy_dst += sparam.output_block_size * pack_oc_size; | |||||
copy_src += sparam.ohw * pack_oc_size; | |||||
} | |||||
} | |||||
} | |||||
template <typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
void do_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const StrategyParam& sparam, WorkspaceBundle bundle_thread, | |||||
size_t matmul_bundle_index, size_t bias_bundle_index) { | |||||
copy_bias<bias_ctype>(param, bundle_thread, sparam, bias_bundle_index); | |||||
void* matmul_dst = get_matmul_dst_ptr<bias_ctype>( | |||||
param, bundle_thread, sparam, matmul_bundle_index); | |||||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||||
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||||
void* bias_temp_ptr = get_bias_temp_ptr<bias_ctype>(param, bundle_thread, | |||||
bias_bundle_index); | |||||
void* bias_preprocess_ptr = const_cast<void*>( | |||||
param.bias_mode == megdnn::BiasMode::BIAS | |||||
? bias_temp_ptr | |||||
: static_cast<void*>(const_cast<bias_ctype*>( | |||||
bias_ptr + sparam.oc_cur_index))); | |||||
size_t pack_oc_size = sparam.pack_oc_size; | |||||
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||||
matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode, | |||||
param.nonlineMode, param.bias_type, param.dst_type, 1_z, | |||||
sparam.output_block_oc_size / pack_oc_size, 1_z, | |||||
sparam.output_block_size, pack_oc_size); | |||||
copy_dst<dst_ctype>(param, matmul_dst, sparam); | |||||
} | |||||
} | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode, PackMode packmode, | |||||
FormatMode format = FormatMode::NCHW> | FormatMode format = FormatMode::NCHW> | ||||
class Strategy; | class Strategy; | ||||
@@ -82,7 +271,10 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::DEFAULT> : public StrategyBase { | |||||
postprocess_mode, PackMode::DEFAULT> | |||||
: public StrategyBridge<src_ctype, bias_ctype, dst_ctype, op_ctype, | |||||
op_dtype, postprocess_mode, PackMode::DEFAULT, | |||||
FormatMode::NCHW> { | |||||
public: | public: | ||||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | constexpr static size_t BUNDLE_PADDING_INDEX = 0; | ||||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | constexpr static size_t BUNDLE_PACKA_INDEX = 1; | ||||
@@ -92,13 +284,7 @@ public: | |||||
Strategy() = default; | Strategy() = default; | ||||
void copy_padding_kern( | |||||
WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_size) override; | |||||
void packA_kern(WorkspaceBundle bundle, | |||||
virtual void packA_kern(WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
@@ -120,16 +306,13 @@ public: | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | ||||
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
WorkspaceBundle bundle_thread) override; | |||||
void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const void* matmul_dst, const StrategyParam& sparam); | |||||
void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
WorkspaceBundle bundle_thread, const StrategyParam& sparam); | |||||
WorkspaceBundle bundle_thread) override { | |||||
do_postprocess<bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode>(param, sparam, bundle_thread, | |||||
THREAD_BUNDLE_IM2COL_INDEX, | |||||
THREAD_BUNDLE_BIAS_INDEX); | |||||
} | |||||
void* get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const WorkspaceBundle& bundle_thread); | |||||
void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const WorkspaceBundle& bundle_thread, | const WorkspaceBundle& bundle_thread, | ||||
const StrategyParam& sparam); | const StrategyParam& sparam); | ||||
@@ -162,7 +345,10 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::NO_PACK> : public StrategyBase { | |||||
postprocess_mode, PackMode::NO_PACK> | |||||
: public StrategyBridge<src_ctype, bias_ctype, dst_ctype, op_ctype, | |||||
op_dtype, postprocess_mode, PackMode::NO_PACK, | |||||
FormatMode::NCHW> { | |||||
public: | public: | ||||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | constexpr static size_t BUNDLE_PADDING_INDEX = 0; | ||||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | constexpr static size_t BUNDLE_PACKA_INDEX = 1; | ||||
@@ -173,12 +359,6 @@ public: | |||||
Strategy() = default; | Strategy() = default; | ||||
void copy_padding_kern( | |||||
WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_size) override; | |||||
void packA_kern(WorkspaceBundle bundle, | void packA_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
@@ -198,17 +378,6 @@ public: | |||||
const WorkspaceBundle& bundle_thread, | const WorkspaceBundle& bundle_thread, | ||||
const StrategyParam& sparam); | const StrategyParam& sparam); | ||||
inline void* get_bias_temp_ptr( | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const WorkspaceBundle& bundle_thread) { | |||||
bias_ctype* bias_tmp_ptr = | |||||
param.bias_mode == megdnn::BiasMode::BIAS | |||||
? static_cast<bias_ctype*>( | |||||
bundle_thread.get(THREAD_BUNDLE_BIAS_INDEX)) | |||||
: nullptr; | |||||
return bias_tmp_ptr; | |||||
} | |||||
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
@@ -216,19 +385,22 @@ public: | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | ||||
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
WorkspaceBundle bundle_thread) override; | |||||
void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const void* matmul_dst, const StrategyParam& sparam); | |||||
void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
WorkspaceBundle bundle_thread, const StrategyParam& sparam); | |||||
WorkspaceBundle bundle_thread) override { | |||||
do_postprocess<bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode>(param, sparam, bundle_thread, | |||||
THREAD_BUNDLE_MATMULDST_INDEX, | |||||
THREAD_BUNDLE_BIAS_INDEX); | |||||
} | |||||
}; | }; | ||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::ONLY_PACKA> : public StrategyBase { | |||||
postprocess_mode, PackMode::ONLY_PACKA> | |||||
: public StrategyBridge<src_ctype, bias_ctype, dst_ctype, op_ctype, | |||||
op_dtype, postprocess_mode, | |||||
PackMode::ONLY_PACKA,FormatMode::NCHW> { | |||||
public: | public: | ||||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | constexpr static size_t BUNDLE_PADDING_INDEX = 0; | ||||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | constexpr static size_t BUNDLE_PACKA_INDEX = 1; | ||||
@@ -239,12 +411,6 @@ public: | |||||
Strategy() = default; | Strategy() = default; | ||||
void copy_padding_kern( | |||||
WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_size) override; | |||||
void packA_kern(WorkspaceBundle bundle, | void packA_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
@@ -269,24 +435,15 @@ public: | |||||
void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const WorkspaceBundle& bundle_thread, | const WorkspaceBundle& bundle_thread, | ||||
const StrategyParam& sparam); | const StrategyParam& sparam); | ||||
inline void* get_bias_temp_ptr( | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const WorkspaceBundle& bundle_thread) { | |||||
bias_ctype* bias_tmp_ptr = | |||||
param.bias_mode == megdnn::BiasMode::BIAS | |||||
? static_cast<bias_ctype*>( | |||||
bundle_thread.get(THREAD_BUNDLE_BIAS_INDEX)) | |||||
: nullptr; | |||||
return bias_tmp_ptr; | |||||
} | |||||
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
WorkspaceBundle bundle_thread) override; | |||||
void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const void* matmul_dst, const StrategyParam& sparam); | |||||
void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
WorkspaceBundle bundle_thread, const StrategyParam& sparam); | |||||
WorkspaceBundle bundle_thread) override { | |||||
do_postprocess<bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode>(param, sparam, bundle_thread, | |||||
THREAD_BUNDLE_MATMULDST_INDEX, | |||||
THREAD_BUNDLE_BIAS_INDEX); | |||||
} | |||||
}; | }; | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -10,16 +10,7 @@ | |||||
*/ | */ | ||||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | #include "src/fallback/conv_bias/im2col/strategy_base.h" | ||||
#include "src/fallback/convolution/img2col_helper.h" | #include "src/fallback/convolution/img2col_helper.h" | ||||
#if MEGDNN_X86 | |||||
#include "src/x86/conv_bias/postprocess_helper.h" | |||||
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#endif | |||||
using namespace megdnn; | |||||
#if MEGDNN_X86 | |||||
using namespace x86; | |||||
#endif | |||||
namespace megdnn { | namespace megdnn { | ||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
@@ -27,73 +18,6 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::DEFAULT>:: | postprocess_mode, PackMode::DEFAULT>:: | ||||
copy_padding_kern(WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_oc_size) { | |||||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||||
MEGDNN_MARK_USED_VAR(N); | |||||
MEGDNN_MARK_USED_VAR(OC); | |||||
MEGDNN_MARK_USED_VAR(OH); | |||||
MEGDNN_MARK_USED_VAR(OW); | |||||
MEGDNN_MARK_USED_VAR(FH); | |||||
MEGDNN_MARK_USED_VAR(FW); | |||||
MEGDNN_MARK_USED_VAR(SH); | |||||
MEGDNN_MARK_USED_VAR(SW); | |||||
size_t IW2 = IW + 2 * PW; | |||||
size_t IH2 = IH + 2 * PH; | |||||
size_t batch_id = ncb_index.ndrange_id[0]; | |||||
size_t group_id = ncb_index.ndrange_id[1]; | |||||
size_t channel_id = ncb_index.ndrange_id[2]; | |||||
size_t PH_SIZE = PH * IW2 * pack_oc_size; | |||||
PW = PW * pack_oc_size; | |||||
IW = IW * pack_oc_size; | |||||
size_t padding_group_size = IH2 * IW2 * IC; | |||||
size_t workspace_channel_offset = pack_oc_size * IH2 * IW2 * channel_id; | |||||
size_t workspace_group_offset = group_id * padding_group_size; | |||||
size_t workspace_batch_offset = | |||||
param.filter_meta.group * batch_id * padding_group_size; | |||||
bundle.set(param.workspace_ptr); | |||||
src_ctype src_zp = static_cast<src_ctype>(0); | |||||
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||||
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | |||||
} | |||||
src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>( | |||||
batch_id, group_id, channel_id, 1, pack_oc_size)); | |||||
src_ctype* src2; | |||||
src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||||
workspace_group_offset + workspace_batch_offset + | |||||
workspace_channel_offset; | |||||
src_ctype* src2_ptr = src2; | |||||
const src_ctype* src_ptr = src; | |||||
if (PH != 0) { | |||||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||||
src2_ptr += PH_SIZE; | |||||
} | |||||
rep(ih, IH) { | |||||
if (PW != 0) | |||||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||||
std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW); | |||||
src2_ptr += IW; | |||||
src_ptr += IW; | |||||
if (PW != 0) | |||||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||||
} | |||||
if (PH != 0) { | |||||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||||
src2_ptr += PH_SIZE; | |||||
} | |||||
} | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::DEFAULT>:: | |||||
packA_kern(WorkspaceBundle bundle, | packA_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
@@ -244,100 +168,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
matmul_kern_naked(matmul_param, a_panel, b_panel); | matmul_kern_naked(matmul_param, a_panel, b_panel); | ||||
} | } | ||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::DEFAULT>:: | |||||
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const StrategyParam& sparam, | |||||
WorkspaceBundle bundle_thread) { | |||||
copy_bias(param, bundle_thread, sparam); | |||||
void* matmul_dst = get_matmul_dst_ptr(param, bundle_thread, sparam); | |||||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||||
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||||
void* bias_temp_ptr = get_bias_temp_ptr(param, bundle_thread); | |||||
void* bias_preprocess_ptr = const_cast<void*>( | |||||
param.bias_mode == megdnn::BiasMode::BIAS | |||||
? bias_temp_ptr | |||||
: static_cast<void*>(const_cast<bias_ctype*>( | |||||
bias_ptr + sparam.oc_cur_index))); | |||||
size_t pack_oc_size = sparam.pack_oc_size; | |||||
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||||
matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode, | |||||
param.nonlineMode, param.bias_type, param.dst_type, 1_z, | |||||
sparam.output_block_oc_size / pack_oc_size, 1_z, | |||||
sparam.output_block_size, pack_oc_size); | |||||
copy_dst(param, matmul_dst, sparam); | |||||
} | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::DEFAULT>:: | |||||
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const void* matmul_dst, const StrategyParam& sparam) { | |||||
if (!sparam.skip_copy_dst) { | |||||
size_t pack_oc_size = sparam.pack_oc_size; | |||||
dst_ctype* dst_tmp_ptr = | |||||
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); | |||||
dst_ctype* dst = | |||||
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + | |||||
sparam.oc_cur_index * sparam.ohw + | |||||
sparam.ohw_cur_index * pack_oc_size; | |||||
size_t oc_loop = sparam.output_block_oc_size / pack_oc_size; | |||||
for (size_t oc = 0; oc < oc_loop; oc++) { | |||||
std::memcpy(dst, dst_tmp_ptr, | |||||
sizeof(dst_ctype) * sparam.output_block_size * | |||||
pack_oc_size); | |||||
dst_tmp_ptr += sparam.output_block_size * pack_oc_size; | |||||
dst += sparam.ohw * pack_oc_size; | |||||
} | |||||
} | |||||
} | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::DEFAULT>:: | |||||
get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const WorkspaceBundle& bundle_thread) { | |||||
bias_ctype* bias_tmp_ptr = | |||||
param.bias_mode == megdnn::BiasMode::BIAS | |||||
? static_cast<bias_ctype*>( | |||||
bundle_thread.get(THREAD_BUNDLE_BIAS_INDEX)) | |||||
: nullptr; | |||||
return bias_tmp_ptr; | |||||
} | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::DEFAULT>:: | |||||
copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | |||||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||||
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||||
bias_ctype* bias_temp_ptr = | |||||
static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread)); | |||||
if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||||
bias_ctype* copy_dst = bias_temp_ptr; | |||||
const bias_ctype* copy_src = bias_ptr + | |||||
sparam.oc_cur_index * sparam.ohw + | |||||
sparam.ohw_cur_index; | |||||
for (size_t oc = sparam.oc_cur_index; oc < sparam.oc_end_index; oc++) { | |||||
std::memcpy(copy_dst, copy_src, | |||||
sizeof(bias_ctype) * sparam.output_block_size); | |||||
copy_dst += sparam.output_block_size; | |||||
copy_src += sparam.ohw; | |||||
} | |||||
} | |||||
} | |||||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | ||||
_op_dtype, _postprocess_mode) \ | _op_dtype, _postprocess_mode) \ | ||||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | ||||
@@ -11,16 +11,7 @@ | |||||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | #include "src/fallback/conv_bias/im2col/strategy_base.h" | ||||
#include "src/fallback/convolution/img2col_helper.h" | #include "src/fallback/convolution/img2col_helper.h" | ||||
#if MEGDNN_X86 | |||||
#include "src/x86/conv_bias/postprocess_helper.h" | |||||
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#endif | |||||
using namespace megdnn; | |||||
#if MEGDNN_X86 | |||||
using namespace x86; | |||||
#endif | |||||
namespace megdnn { | namespace megdnn { | ||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
@@ -28,69 +19,6 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::NO_PACK>:: | postprocess_mode, PackMode::NO_PACK>:: | ||||
copy_padding_kern(WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t) { | |||||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||||
MEGDNN_MARK_USED_VAR(N); | |||||
MEGDNN_MARK_USED_VAR(OC); | |||||
MEGDNN_MARK_USED_VAR(OH); | |||||
MEGDNN_MARK_USED_VAR(OW); | |||||
MEGDNN_MARK_USED_VAR(FH); | |||||
MEGDNN_MARK_USED_VAR(FW); | |||||
MEGDNN_MARK_USED_VAR(SH); | |||||
MEGDNN_MARK_USED_VAR(SW); | |||||
size_t IW2 = IW + 2 * PW; | |||||
size_t IH2 = IH + 2 * PH; | |||||
size_t batch_id = ncb_index.ndrange_id[0]; | |||||
size_t group_id = ncb_index.ndrange_id[1]; | |||||
size_t channel_id = ncb_index.ndrange_id[2]; | |||||
size_t padding_group_size = IH2 * IW2 * IC; | |||||
size_t workspace_channel_offset = IH2 * IW2 * channel_id; | |||||
size_t workspace_group_offset = group_id * padding_group_size; | |||||
size_t workspace_batch_offset = | |||||
param.filter_meta.group * batch_id * padding_group_size; | |||||
bundle.set(param.workspace_ptr); | |||||
src_ctype src_zp = static_cast<src_ctype>(0); | |||||
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||||
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | |||||
} | |||||
src_ctype* src = const_cast<src_ctype*>( | |||||
param.src<src_ctype>(batch_id, group_id, channel_id)); | |||||
src_ctype* src2; | |||||
src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||||
workspace_group_offset + workspace_batch_offset + | |||||
workspace_channel_offset; | |||||
src_ctype* src2_ptr = src2; | |||||
const src_ctype* src_ptr = src; | |||||
if (PH != 0) { | |||||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||||
src2_ptr += PH * IW2; | |||||
} | |||||
rep(ih, IH) { | |||||
if (PW != 0) | |||||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||||
std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW); | |||||
src2_ptr += IW; | |||||
src_ptr += IW; | |||||
if (PW != 0) | |||||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||||
} | |||||
if (PH != 0) { | |||||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||||
src2_ptr += PH * IW2; | |||||
} | |||||
} | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::NO_PACK>:: | |||||
packA_kern(WorkspaceBundle bundle, | packA_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
@@ -220,81 +148,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
} | } | ||||
} | } | ||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::NO_PACK>:: | |||||
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const StrategyParam& sparam, | |||||
WorkspaceBundle bundle_thread) { | |||||
copy_bias(param, bundle_thread, sparam); | |||||
void* matmul_dst = get_matmul_dst_ptr(param, bundle_thread, sparam); | |||||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||||
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||||
bias_ctype* bias_temp_ptr = | |||||
static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread)); | |||||
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||||
matmul_dst, | |||||
const_cast<void*>( | |||||
param.bias_mode == megdnn::BiasMode::BIAS | |||||
? bias_temp_ptr | |||||
: static_cast<void*>(const_cast<bias_ctype*>( | |||||
bias_ptr + sparam.oc_cur_index))), | |||||
matmul_dst, param.bias_mode, param.nonlineMode, param.bias_type, | |||||
param.dst_type, 1_z, sparam.output_block_oc_size, 1_z, | |||||
sparam.output_block_size); | |||||
copy_dst(param, matmul_dst, sparam); | |||||
} | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::NO_PACK>:: | |||||
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const void* matmul_dst, const StrategyParam& sparam) { | |||||
if (!sparam.skip_copy_dst) { | |||||
dst_ctype* dst_tmp_ptr = | |||||
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); | |||||
dst_ctype* dst = | |||||
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + | |||||
sparam.oc_cur_index * sparam.ohw + sparam.ohw_cur_index; | |||||
for (size_t oc = 0; oc < sparam.output_block_oc_size; oc++) { | |||||
std::memcpy(dst, dst_tmp_ptr, | |||||
sizeof(dst_ctype) * sparam.output_block_size); | |||||
dst_tmp_ptr += sparam.output_block_size; | |||||
dst += sparam.ohw; | |||||
} | |||||
} | |||||
} | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::NO_PACK>:: | |||||
copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | |||||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||||
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||||
bias_ctype* bias_temp_ptr = | |||||
static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread)); | |||||
if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||||
bias_ctype* copy_dst = bias_temp_ptr; | |||||
const bias_ctype* copy_src = bias_ptr + | |||||
sparam.oc_cur_index * sparam.ohw + | |||||
sparam.ohw_cur_index; | |||||
for (size_t oc = sparam.oc_cur_index; oc < sparam.oc_end_index; oc++) { | |||||
std::memcpy(copy_dst, copy_src, | |||||
sizeof(bias_ctype) * sparam.output_block_size); | |||||
copy_dst += sparam.output_block_size; | |||||
copy_src += sparam.ohw; | |||||
} | |||||
} | |||||
} | |||||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | ||||
_op_dtype, _postprocess_mode) \ | _op_dtype, _postprocess_mode) \ | ||||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | ||||
@@ -302,34 +155,18 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | ||||
megdnn::PostprocessMode::FLOAT) | megdnn::PostprocessMode::FLOAT) | ||||
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||||
megdnn::PostprocessMode::FLOAT) | |||||
#else | #else | ||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | ||||
megdnn::PostprocessMode::NO_PROCESS) | megdnn::PostprocessMode::NO_PROCESS) | ||||
#endif | #endif | ||||
#endif | #endif | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | |||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||||
megdnn::PostprocessMode::QUANTIZED) | |||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
#endif | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||||
megdnn::PostprocessMode::QUANTIZED) | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
#undef INSTANTIAL_CLASS | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -11,16 +11,7 @@ | |||||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | #include "src/fallback/conv_bias/im2col/strategy_base.h" | ||||
#include "src/fallback/convolution/img2col_helper.h" | #include "src/fallback/convolution/img2col_helper.h" | ||||
#if MEGDNN_X86 | |||||
#include "src/x86/conv_bias/postprocess_helper.h" | |||||
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#endif | |||||
using namespace megdnn; | |||||
#if MEGDNN_X86 | |||||
using namespace x86; | |||||
#endif | |||||
namespace megdnn { | namespace megdnn { | ||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
@@ -28,69 +19,6 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::ONLY_PACKA>:: | postprocess_mode, PackMode::ONLY_PACKA>:: | ||||
copy_padding_kern(WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t) { | |||||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||||
MEGDNN_MARK_USED_VAR(N); | |||||
MEGDNN_MARK_USED_VAR(OC); | |||||
MEGDNN_MARK_USED_VAR(OH); | |||||
MEGDNN_MARK_USED_VAR(OW); | |||||
MEGDNN_MARK_USED_VAR(FH); | |||||
MEGDNN_MARK_USED_VAR(FW); | |||||
MEGDNN_MARK_USED_VAR(SH); | |||||
MEGDNN_MARK_USED_VAR(SW); | |||||
size_t IW2 = IW + 2 * PW; | |||||
size_t IH2 = IH + 2 * PH; | |||||
size_t batch_id = ncb_index.ndrange_id[0]; | |||||
size_t group_id = ncb_index.ndrange_id[1]; | |||||
size_t channel_id = ncb_index.ndrange_id[2]; | |||||
size_t padding_group_size = IH2 * IW2 * IC; | |||||
size_t workspace_channel_offset = IH2 * IW2 * channel_id; | |||||
size_t workspace_group_offset = group_id * padding_group_size; | |||||
size_t workspace_batch_offset = | |||||
param.filter_meta.group * batch_id * padding_group_size; | |||||
bundle.set(param.workspace_ptr); | |||||
src_ctype src_zp = static_cast<src_ctype>(0); | |||||
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||||
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | |||||
} | |||||
src_ctype* src = const_cast<src_ctype*>( | |||||
param.src<src_ctype>(batch_id, group_id, channel_id)); | |||||
src_ctype* src2; | |||||
src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||||
workspace_group_offset + workspace_batch_offset + | |||||
workspace_channel_offset; | |||||
src_ctype* src2_ptr = src2; | |||||
const src_ctype* src_ptr = src; | |||||
if (PH != 0) { | |||||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||||
src2_ptr += PH * IW2; | |||||
} | |||||
rep(ih, IH) { | |||||
if (PW != 0) | |||||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||||
std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW); | |||||
src2_ptr += IW; | |||||
src_ptr += IW; | |||||
if (PW != 0) | |||||
rep(pw, PW) * (src2_ptr++) = src_zp; | |||||
} | |||||
if (PH != 0) { | |||||
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||||
src2_ptr += PH * IW2; | |||||
} | |||||
} | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
packA_kern(WorkspaceBundle bundle, | packA_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
@@ -123,25 +51,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const WorkspaceBundle& bundle_thread, | |||||
const StrategyParam& sparam) { | |||||
if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) { | |||||
return static_cast<void*>( | |||||
bundle_thread.get(THREAD_BUNDLE_MATMULDST_INDEX)); | |||||
} else { | |||||
bias_ctype* dst = | |||||
param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) + | |||||
sparam.oc_cur_index * sparam.ohw; | |||||
return static_cast<void*>(dst); | |||||
} | |||||
} | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::ONLY_PACKA>:: | postprocess_mode, PackMode::ONLY_PACKA>:: | ||||
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
@@ -241,63 +150,19 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const StrategyParam& sparam, | |||||
WorkspaceBundle bundle_thread) { | |||||
void* matmul_dst = get_matmul_dst_ptr(param, bundle_thread, sparam); | |||||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||||
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||||
bias_ctype* bias_temp_ptr = | |||||
static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread)); | |||||
if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||||
bias_ctype* copy_dst = bias_temp_ptr; | |||||
const bias_ctype* copy_src = bias_ptr + | |||||
sparam.oc_cur_index * sparam.ohw + | |||||
sparam.ohw_cur_index; | |||||
for (size_t oc = sparam.oc_cur_index; oc < sparam.oc_end_index; oc++) { | |||||
std::memcpy(copy_dst, copy_src, | |||||
sizeof(bias_ctype) * sparam.output_block_size); | |||||
copy_dst += sparam.output_block_size; | |||||
copy_src += sparam.ohw; | |||||
} | |||||
} | |||||
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||||
matmul_dst, | |||||
const_cast<void*>( | |||||
param.bias_mode == megdnn::BiasMode::BIAS | |||||
? bias_temp_ptr | |||||
: static_cast<void*>(const_cast<bias_ctype*>( | |||||
bias_ptr + sparam.oc_cur_index))), | |||||
matmul_dst, param.bias_mode, param.nonlineMode, param.bias_type, | |||||
param.dst_type, 1_z, sparam.output_block_oc_size, 1_z, | |||||
sparam.output_block_size); | |||||
copy_dst(param, matmul_dst, sparam); | |||||
} | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | |||||
megdnn::PostprocessMode postprocess_mode> | |||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const void* matmul_dst, const StrategyParam& sparam) { | |||||
if (!sparam.skip_copy_dst) { | |||||
dst_ctype* dst_tmp_ptr = | |||||
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); | |||||
dst_ctype* dst = | |||||
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + | |||||
sparam.oc_cur_index * sparam.ohw + sparam.ohw_cur_index; | |||||
for (size_t oc = 0; oc < sparam.output_block_oc_size; oc++) { | |||||
std::memcpy(dst, dst_tmp_ptr, | |||||
sizeof(dst_ctype) * sparam.output_block_size); | |||||
dst_tmp_ptr += sparam.output_block_size; | |||||
dst += sparam.ohw; | |||||
} | |||||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
const WorkspaceBundle& bundle_thread, | |||||
const StrategyParam& sparam) { | |||||
if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) { | |||||
return static_cast<bias_ctype*>( | |||||
bundle_thread.get(THREAD_BUNDLE_MATMULDST_INDEX)); | |||||
} else { | |||||
bias_ctype* dst = | |||||
param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) + | |||||
sparam.oc_cur_index * sparam.ohw; | |||||
return static_cast<void*>(dst); | |||||
} | } | ||||
} | } | ||||
@@ -310,33 +175,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | ||||
megdnn::PostprocessMode::FLOAT) | megdnn::PostprocessMode::FLOAT) | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||||
megdnn::PostprocessMode::FLOAT) | |||||
#else | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
#endif | |||||
#endif | |||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | |||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||||
megdnn::PostprocessMode::QUANTIZED) | |||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
#endif | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||||
megdnn::PostprocessMode::QUANTIZED) | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
#undef INSTANTIAL_CLASS | #undef INSTANTIAL_CLASS | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -26,7 +26,7 @@ | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace fallback; | using namespace fallback; | ||||
size_t megdnn::fallback::get_format_pack_size(param::ConvBias::Format format) { | |||||
size_t megdnn::fallback::pack_size(param::ConvBias::Format format) { | |||||
switch (format) { | switch (format) { | ||||
case param::ConvBias::Format::NCHW44: | case param::ConvBias::Format::NCHW44: | ||||
case param::ConvBias::Format::NCHW44_DOT: | case param::ConvBias::Format::NCHW44_DOT: | ||||
@@ -23,8 +23,10 @@ namespace fallback { | |||||
/*! | /*! | ||||
* \brief get the pack_size according to the format | * \brief get the pack_size according to the format | ||||
* Note TODO: when remove format from param, | |||||
* may using like this "opr::param::format specify" | |||||
* */ | * */ | ||||
size_t get_format_pack_size(param::ConvBias::Format format); | |||||
size_t pack_size(param::ConvBias::Format format); | |||||
/*! | /*! | ||||
* \brief fallback conv bias forward impl | * \brief fallback conv bias forward impl | ||||
@@ -52,9 +52,21 @@ class GemmInterleaved<Strategy, true> { | |||||
} | } | ||||
size_t get_b_workspace_size() const { | size_t get_b_workspace_size() const { | ||||
#if __ARM_FEATURE_DOTPROD | |||||
size_t new_blockn = m_strategy.block_n; | |||||
if (m_strategy.KERNEL_W == 6 && m_strategy.UNROLL_K == 4 && | |||||
m_strategy.KERNEL_H == 8) { | |||||
new_blockn = round_up<size_t>((m_strategy.block_n-1) % 6, 4) + | |||||
m_strategy.block_n / 6 * 6; | |||||
} | |||||
size_t N = round_up(new_blockn, m_strategy.KERNEL_W); | |||||
size_t K = round_up(m_strategy.block_k, m_strategy.UNROLL_K); | |||||
return round_up(sizeof(stype) * N * K, CACHELINE_SIZE) + m_align_size; | |||||
#else | |||||
size_t N = round_up(m_strategy.block_n, m_strategy.KERNEL_W); | size_t N = round_up(m_strategy.block_n, m_strategy.KERNEL_W); | ||||
size_t K = round_up(m_strategy.block_k, m_strategy.UNROLL_K); | size_t K = round_up(m_strategy.block_k, m_strategy.UNROLL_K); | ||||
return round_up(sizeof(stype) * N * K, CACHELINE_SIZE) + m_align_size; | return round_up(sizeof(stype) * N * K, CACHELINE_SIZE) + m_align_size; | ||||
#endif | |||||
} | } | ||||
//! temporary storage for output, post process such as add bias or relu will | //! temporary storage for output, post process such as add bias or relu will | ||||
@@ -268,7 +268,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) { | |||||
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | ||||
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", false); | "IM2COLMATMUL:AARCH64_F32K8X12X1:192", false); | ||||
#else | #else | ||||
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | |||||
benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384", | |||||
"IM2COLMATMUL:ARMV7_F32:192", true); | "IM2COLMATMUL:ARMV7_F32:192", true); | ||||
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | ||||
"IM2COLMATMUL:ARMV7_F32:192", false); | "IM2COLMATMUL:ARMV7_F32:192", false); | ||||
@@ -72,10 +72,12 @@ std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args( | |||||
std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | ||||
std::vector<size_t> kernel_vec, size_t stride, bool no_pad = false, | std::vector<size_t> kernel_vec, size_t stride, bool no_pad = false, | ||||
bool no_bias = false, bool no_nonlinemode = false, | bool no_bias = false, bool no_nonlinemode = false, | ||||
bool is_input_nchw = false, bool support_full_bias = false, | |||||
bool support_sigmoid = false) { | |||||
bool is_input_nchw = false, bool is_nchw44_dot = false, | |||||
bool support_full_bias = false, bool support_sigmoid = false, | |||||
bool only_no_bias = false) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
using NLMode = param::ConvBias::NonlineMode; | using NLMode = param::ConvBias::NonlineMode; | ||||
std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, | auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, | ||||
@@ -102,7 +104,11 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | |||||
size_t kernel_h = kernel; | size_t kernel_h = kernel; | ||||
size_t kernel_w = kernel; | size_t kernel_w = kernel; | ||||
param::ConvBias param; | param::ConvBias param; | ||||
param.format = param::ConvBias::Format::NCHW44; | |||||
if (!is_nchw44_dot) { | |||||
param.format = param::ConvBias::Format::NCHW44; | |||||
} else { | |||||
param.format = param::ConvBias::Format::NCHW44_DOT; | |||||
} | |||||
param.stride_h = stride; | param.stride_h = stride; | ||||
param.stride_w = stride; | param.stride_w = stride; | ||||
param.pad_h = pad; | param.pad_h = pad; | ||||
@@ -155,18 +161,22 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | |||||
if (support_sigmoid) { | if (support_sigmoid) { | ||||
nonlinemode.emplace_back(NLMode::SIGMOID); | nonlinemode.emplace_back(NLMode::SIGMOID); | ||||
} | } | ||||
std::vector<megdnn::BiasMode> bias_mode = { | |||||
megdnn::BiasMode::BROADCAST_CHANNEL_BIAS}; | |||||
if (no_bias) { | |||||
std::vector<megdnn::BiasMode> bias_mode; | |||||
if (!only_no_bias) { | |||||
bias_mode.emplace_back(megdnn::BiasMode::BROADCAST_CHANNEL_BIAS); | |||||
if (no_bias) { | |||||
bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS); | |||||
} | |||||
} else { | |||||
bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS); | bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS); | ||||
} | } | ||||
if (support_full_bias) { | if (support_full_bias) { | ||||
bias_mode.emplace_back(megdnn::BiasMode::BIAS); | |||||
bias_mode.emplace_back(megdnn::BiasMode::BIAS); | |||||
} | } | ||||
for (auto bias : bias_mode) | for (auto bias : bias_mode) | ||||
for (auto nlmode : nonlinemode) | for (auto nlmode : nonlinemode) | ||||
for (size_t n : {1, 2}) | |||||
for (size_t n : {1,2}) | |||||
for (size_t kernel : kernel_vec) | for (size_t kernel : kernel_vec) | ||||
for (size_t oc : {4, 12}) | for (size_t oc : {4, 12}) | ||||
for (size_t ic : {1, 3, 4, 12}) | for (size_t ic : {1, 3, 4, 12}) | ||||
@@ -361,19 +371,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) { | ||||
check_conv_bias(get_nchw44_conv_bias_args({2, 3}, 1, false, false, false, | check_conv_bias(get_nchw44_conv_bias_args({2, 3}, 1, false, false, false, | ||||
false, true, true), | |||||
false, false, true, true), | |||||
handle(), "F32_CONV_NCHW44_DIRECT"); | handle(), "F32_CONV_NCHW44_DIRECT"); | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) { | ||||
check_conv_bias(get_nchw44_conv_bias_args({5}, 1, false, false, false, | check_conv_bias(get_nchw44_conv_bias_args({5}, 1, false, false, false, | ||||
false, true, true), | |||||
false, false, true, true), | |||||
handle(), "F32_CONV_NCHW44_DIRECT"); | handle(), "F32_CONV_NCHW44_DIRECT"); | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) { | ||||
check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, | check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, | ||||
false, false, true, true), | |||||
false, false, false, true, true), | |||||
handle(), "F32_CONV_NCHW44_DIRECT"); | handle(), "F32_CONV_NCHW44_DIRECT"); | ||||
} | } | ||||
@@ -1420,6 +1430,111 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) { | |||||
#endif | #endif | ||||
#undef cb | #undef cb | ||||
} | } | ||||
#if __ARM_FEATURE_DOTPROD | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) { | |||||
UniformIntRNG rng{-50, 50}; | |||||
#define cb(name) \ | |||||
checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, \ | |||||
false, false, false, true), \ | |||||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||||
dtype::QuantizedS8(60.25f), name); \ | |||||
checker_conv_bias( \ | |||||
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true), \ | |||||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||||
dtype::QuantizedS8(60.25f), name); | |||||
float epsilon = 0.001; | |||||
#if MEGDNN_AARCH64 | |||||
cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96"); | |||||
#elif MEGDNN_ARMV7 | |||||
cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X6X4_DOTPROD:96"); | |||||
#endif | |||||
#undef cb | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) { | |||||
UniformIntRNG rng{-50, 50}; | |||||
#define cb(name) \ | |||||
checker_conv_bias( \ | |||||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||||
true, false, true, false, false, true), \ | |||||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \ | |||||
checker_conv_bias( \ | |||||
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | |||||
false, false, true), \ | |||||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); | |||||
float epsilon = 0.001; | |||||
#if MEGDNN_AARCH64 | |||||
cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96"); | |||||
#elif MEGDNN_ARMV7 | |||||
cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X6X4_DOTPROD:96"); | |||||
#endif | |||||
#undef cb | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32_MK4_DOT) { | |||||
UniformIntRNG rng{-50, 50}; | |||||
#define cb(name) \ | |||||
checker_conv_bias( \ | |||||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||||
true, false, true, false, false, true), \ | |||||
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | |||||
dtype::Int32(), {}, name); \ | |||||
checker_conv_bias( \ | |||||
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | |||||
false, false, true), \ | |||||
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | |||||
dtype::Int32(), {}, name); | |||||
float epsilon = 0.001; | |||||
#if MEGDNN_AARCH64 | |||||
cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96"); | |||||
#elif MEGDNN_ARMV7 | |||||
cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X6X4_DOTPROD:96"); | |||||
#endif | |||||
#undef cb | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) { | |||||
UniformIntRNG rng{-50, 50}; | |||||
#define cb(name) \ | |||||
checker_conv_bias( \ | |||||
get_nchw44_conv_bias_args({1}, 1, true, true, false, false, true), \ | |||||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||||
dtype::QuantizedS8(60.25f), name); \ | |||||
checker_conv_bias( \ | |||||
get_nchw44_conv_bias_args({1}, 1, true, true, true, false, true, \ | |||||
false, false, true), \ | |||||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \ | |||||
checker_conv_bias( \ | |||||
get_nchw44_conv_bias_args({1}, 1, true, true, true, false, true, \ | |||||
false, false, true), \ | |||||
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | |||||
dtype::Int32(), {}, name); | |||||
float epsilon = 0.001; | |||||
#if MEGDNN_AARCH64 | |||||
cb("CONV1x1:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"); | |||||
#elif MEGDNN_ARMV7 | |||||
cb("CONV1x1:AARCH32_INT8_MK4_8X6X4_DOTPROD"); | |||||
#endif | |||||
#undef cb | |||||
} | |||||
#endif | |||||
// clang-format on | // clang-format on | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) { | ||||
@@ -1685,8 +1800,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({2, 4, 7}, 1); | |||||
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | |||||
{2, 4, 7}, 1, false, false, false, false, false, true,true); | |||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | ||||
#elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
@@ -1696,8 +1811,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({3, 5, 6}, 2); | |||||
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | |||||
{3, 5, 6}, 2, false, false, false, false, false, true, true); | |||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | ||||
#elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
@@ -897,6 +897,62 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32) { | |||||
#undef cb | #undef cb | ||||
} | } | ||||
#if MEGDNN_X86_WITH_MKL || MEGDNN_X86_WITH_OPENBLAS
//! Validates the im2col+matmul FP32 conv-bias path backed by the
//! X86_F32_BLAS matmul algorithm. Sweeps kernel sizes, channel counts,
//! padding and nonlinearities, covering no-bias, channel-broadcast-bias
//! and full-tensor-bias layouts.
TEST_F(X86, CONV_BIAS_IM2COLMATMUL_FP32) {
    using namespace conv_bias;
    std::vector<TestArg> args;
    auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
                   size_t p, NonlineMode nonline_mode) {
        //! skip shapes where the padded input is smaller than the kernel
        if (w + 2 * p < kernel || h + 2 * p < kernel)
            return;
        param::ConvBias param;
        param.stride_h = 1;
        param.stride_w = 1;
        param.pad_h = p;
        param.pad_w = p;
        param.nonlineMode = nonline_mode;
        //! no bias
        args.emplace_back(param, TensorShape{1, ic, h, w},
                          TensorShape{oc, ic, kernel, kernel}, TensorShape{});
        //! per-channel (broadcast) bias
        args.emplace_back(param, TensorShape{1, ic, h, w},
                          TensorShape{oc, ic, kernel, kernel},
                          TensorShape{1, oc, 1, 1});
        //! full-tensor bias matching the convolution output spatial size
        args.emplace_back(
                param, TensorShape{1, ic, h, w},
                TensorShape{oc, ic, kernel, kernel},
                TensorShape{1, oc, (h + 2 * p - kernel) / param.stride_h + 1,
                            (w + 2 * p - kernel) / param.stride_w + 1});
    };
    for (size_t kernel : {2, 3, 4, 5, 6, 7})
        for (size_t ic : {1, 4, 8, 16})
            for (size_t oc : {1, 4, 8, 16, 300})
                for (size_t p : {0, 2})
                    for (size_t size : {8, 24})
                        for (NonlineMode nonline_mode :
                             {NonlineMode::IDENTITY, NonlineMode::RELU}) {
                            run(oc, ic, size, size, kernel, p, nonline_mode);
                        }
    //! stress case with a very large output-channel count
    run(2046, 8, 20, 20, 3, 1, NonlineMode::IDENTITY);
    Checker<ConvBias> checker(handle());
#define cb(algo_name)                                                  \
    checker.set_before_exec_callback(                                  \
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));      \
    for (auto&& arg : args) {                                          \
        checker.set_param(arg.param).execs(                            \
                {arg.src, arg.filter, arg.bias, {}, {}});              \
    }
    cb("IM2COLMATMUL:X86_F32_BLAS");
#undef cb
}
#endif
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM | #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM | ||||
TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32_PACKA) { | TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32_PACKA) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||