@@ -749,7 +749,7 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
const int8_t* inptr1 = inptr0 + ldin; | const int8_t* inptr1 = inptr0 + ldin; | ||||
const int8_t* inptr2 = inptr1 + ldin; | const int8_t* inptr2 = inptr1 + ldin; | ||||
const int8_t* inptr3 = inptr2 + ldin; | const int8_t* inptr3 = inptr2 + ldin; | ||||
int8_t* output = outptr + start_y * out_offset; | |||||
int8_t* output = outptr + (y - y0) / 4 * out_offset; | |||||
prefetch_2x(inptr0); | prefetch_2x(inptr0); | ||||
prefetch_2x(inptr1); | prefetch_2x(inptr1); | ||||
prefetch_2x(inptr2); | prefetch_2x(inptr2); | ||||
@@ -776,7 +776,7 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
} | } | ||||
for (; y + 3 < ymax; y += 4, start_y++) { | for (; y + 3 < ymax; y += 4, start_y++) { | ||||
const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4; | const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4; | ||||
int8_t* output = outptr + start_y * out_offset; | |||||
int8_t* output = outptr + (y - y0) / 4 * out_offset; | |||||
prefetch_2x(inptr0); | prefetch_2x(inptr0); | ||||
int K = kmax - k0; | int K = kmax - k0; | ||||
for (; K > 15; K -= 16) { | for (; K > 15; K -= 16) { | ||||
@@ -227,7 +227,7 @@ static void gemm_mk4_s8_4x2_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
const int8_t* inptr1 = inptr0 + ldin; | const int8_t* inptr1 = inptr0 + ldin; | ||||
const int8_t* inptr2 = inptr1 + ldin; | const int8_t* inptr2 = inptr1 + ldin; | ||||
const int8_t* inptr3 = inptr2 + ldin; | const int8_t* inptr3 = inptr2 + ldin; | ||||
int8_t* output = outptr + start_y * out_offset; | |||||
int8_t* output = outptr + (y - y0) / 4 * out_offset; | |||||
prefetch_2x(inptr0); | prefetch_2x(inptr0); | ||||
prefetch_2x(inptr1); | prefetch_2x(inptr1); | ||||
prefetch_2x(inptr2); | prefetch_2x(inptr2); | ||||
@@ -254,7 +254,7 @@ static void gemm_mk4_s8_4x2_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
} | } | ||||
for (; y + 3 < ymax; y += 4, start_y++) { | for (; y + 3 < ymax; y += 4, start_y++) { | ||||
const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4; | const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4; | ||||
int8_t* output = outptr + start_y * out_offset; | |||||
int8_t* output = outptr + (y - y0) / 4 * out_offset; | |||||
prefetch_2x(inptr0); | prefetch_2x(inptr0); | ||||
int K = kmax - k0; | int K = kmax - k0; | ||||
for (; K > 15; K -= 16) { | for (; K > 15; K -= 16) { | ||||
@@ -22,20 +22,6 @@ namespace conv1x1 { | |||||
namespace { | namespace { | ||||
//! Return the pack size associated with a ConvBias tensor format:
//! 4 for NCHW44 and NCHW4, 8 for NCHW88 and 1 for NCHW.
//! Every other format is reported through megdnn_throw.
size_t get_format_pack_size(param::ConvBias::Format format) {
    using Format = param::ConvBias::Format;
    //! the two 4-packed layouts share one answer
    if (format == Format::NCHW44 || format == Format::NCHW4) {
        return 4_z;
    }
    if (format == Format::NCHW88) {
        return 8_z;
    }
    if (format == Format::NCHW) {
        return 1_z;
    }
    megdnn_throw("unknow pack size of the format");
}
struct StrategyHashParam { | struct StrategyHashParam { | ||||
ConvBiasImpl::NCBKernSizeParam param; | ConvBiasImpl::NCBKernSizeParam param; | ||||
param::ConvBias::Format format; | param::ConvBias::Format format; | ||||
@@ -125,13 +125,10 @@ public: | |||||
size_t oc_tile_size) { | size_t oc_tile_size) { | ||||
size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0], | size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0], | ||||
FW = param.filter_meta.spatial[1]; | FW = param.filter_meta.spatial[1]; | ||||
size_t pack_oc_size = 1; | |||||
size_t pack_oc_size = get_format_pack_size(param.filter_meta.format); | |||||
size_t im2col = 0, packb = 0, bias_temp = 0; | size_t im2col = 0, packb = 0, bias_temp = 0; | ||||
bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT; | bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT; | ||||
megdnn_assert(default_pack, "only support default packa"); | megdnn_assert(default_pack, "only support default packa"); | ||||
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { | |||||
pack_oc_size = 4; | |||||
} | |||||
size_t im2col_dst_size = | size_t im2col_dst_size = | ||||
IC * FH * FW * ohw_tile_size * sizeof(param.src_type); | IC * FH * FW * ohw_tile_size * sizeof(param.src_type); | ||||
size_t matmul_dst_size = pack_oc_size * oc_tile_size * ohw_tile_size * | size_t matmul_dst_size = pack_oc_size * oc_tile_size * ohw_tile_size * | ||||
@@ -321,14 +318,17 @@ fallback::MatrixMulImpl::KernSizeParam | |||||
ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, | ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, | ||||
size_t ohw_tile_size, | size_t ohw_tile_size, | ||||
size_t oc_tile_size) const { | size_t oc_tile_size) const { | ||||
bool is_nchw44 = | |||||
param.filter_meta.format == param::ConvBias::Format::NCHW44; | |||||
auto format = param::MatrixMul::Format::DEFAULT; | |||||
size_t pack_oc_size = get_format_pack_size(param.filter_meta.format); | |||||
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { | |||||
format = param::MatrixMul::Format::MK4; | |||||
} | |||||
size_t M = oc_tile_size; | size_t M = oc_tile_size; | ||||
size_t N = ohw_tile_size; | size_t N = ohw_tile_size; | ||||
size_t K = param.filter_meta.icpg * param.filter_meta.spatial[0] * | size_t K = param.filter_meta.icpg * param.filter_meta.spatial[0] * | ||||
param.filter_meta.spatial[1]; | param.filter_meta.spatial[1]; | ||||
size_t pack_oc_size = is_nchw44 ? 4 : 1; | |||||
size_t LDA = pack_oc_size * K, LDB = pack_oc_size * N, LDC = N; | |||||
size_t LDA = pack_oc_size * K, LDB = pack_oc_size * N, | |||||
LDC = N * pack_oc_size; | |||||
bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 && | bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 && | ||||
param.dst_type.enumv() == DTypeEnum::QuantizedS8) || | param.dst_type.enumv() == DTypeEnum::QuantizedS8) || | ||||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm && | (param.src_type.enumv() == DTypeEnum::Quantized8Asymm && | ||||
@@ -345,8 +345,7 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, | |||||
false, | false, | ||||
false, | false, | ||||
param::MatrixMul::ComputeMode::DEFAULT, | param::MatrixMul::ComputeMode::DEFAULT, | ||||
is_nchw44 ? param::MatrixMul::Format::MK4 | |||||
: param::MatrixMul::Format::DEFAULT}; | |||||
format}; | |||||
} | } | ||||
void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( | void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( | ||||
@@ -356,11 +355,7 @@ void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( | |||||
size_t nr_threads = param.nr_threads; | size_t nr_threads = param.nr_threads; | ||||
size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
size_t ohw = param.osz[0] * param.osz[1]; | size_t ohw = param.osz[0] * param.osz[1]; | ||||
//! NOTE: do not change the following two lines. | |||||
//! The opr reuses the same im2col algo, and choice_ohw_oc_block may change | |||||
//! m_ohw_tile_size and m_oc_tile_size. If those two values changed, the | |||||
//! workspace size could change and a workspace-mismatch problem would occur, | |||||
//! so always initialize the tile sizes from the original values here. | |||||
oc_tile_size = DEFAULT_OC_TILE_SIZE; | oc_tile_size = DEFAULT_OC_TILE_SIZE; | ||||
ohw_tile_size = m_ohw_tile_size; | ohw_tile_size = m_ohw_tile_size; | ||||
@@ -505,14 +500,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||||
size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size); | size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size); | ||||
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | ||||
size_t packa_parallel_times = 0; | size_t packa_parallel_times = 0; | ||||
size_t pack_oc_size = | |||||
(param.filter_meta.format == param::ConvBias::Format::NCHW ? 1 | |||||
: 4); | |||||
size_t pack_oc_size = get_format_pack_size(param.filter_meta.format); | |||||
if (only_packA) { | if (only_packA) { | ||||
packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | ||||
} else if (default_pack) { | } else if (default_pack) { | ||||
packa_parallel_times = div_ceil<size_t>( | packa_parallel_times = div_ceil<size_t>( | ||||
OC, m_matmul_algo->get_inner_block_size().m * pack_oc_size); | |||||
OC, m_matmul_algo->get_inner_block_size().m); | |||||
} | } | ||||
auto matmul_param = get_matmul_kern_param( | auto matmul_param = get_matmul_kern_param( | ||||
@@ -659,12 +653,16 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | ||||
return false; | return false; | ||||
} | } | ||||
//! currently im2col only supports int8 and quantized s8 for NCHW44 | |||||
if (opr->param().format == param::ConvBias::Format::NCHW44 && | |||||
(param.src_type.enumv() == param.filter_type.enumv() && | |||||
(param.src_type.enumv() != DTypeEnum::Int8) && | |||||
(param.src_type.enumv() != DTypeEnum::QuantizedS8))) { | |||||
return false; | |||||
if (opr->param().format == param::ConvBias::Format::NCHW44) { | |||||
//! currently NCHW44 im2col only supports DEFAULT mode matmul | |||||
if(m_matmul_algo->packmode() != Pack_Mode::DEFAULT) { | |||||
return false; | |||||
//! NCHW44 hybrid mode and channel-wise mode are not supported | |||||
} else if (param.filter_meta.icpg < 4_z || | |||||
param.filter_meta.icpg == 1 || | |||||
param.filter_meta.ocpg == 1) { | |||||
return false; | |||||
} | |||||
} | } | ||||
size_t oc_tile_size = 0, ohw_tile_size = 0; | size_t oc_tile_size = 0, ohw_tile_size = 0; | ||||
@@ -221,8 +221,17 @@ public: | |||||
param::ConvBias::Format format = param.filter_meta.format; | param::ConvBias::Format format = param.filter_meta.format; | ||||
switch (strategytype) { | switch (strategytype) { | ||||
case StrategyType::FLOAT: | case StrategyType::FLOAT: | ||||
cb1(NCHW, DEFAULT, dt_float32, dt_float32, | |||||
PostprocessMode::FLOAT, "DefaultStrategyType::FLOAT"_hash); | |||||
if (format == param::ConvBias::Format::NCHW) { | |||||
cb1(NCHW, DEFAULT, dt_float32, dt_float32, | |||||
PostprocessMode::FLOAT, | |||||
"DefaultStrategyType::FLOAT"_hash); | |||||
} else if (format == param::ConvBias::Format::NCHW44) { | |||||
cb1(NCHW44, DEFAULT, dt_float32, dt_float32, | |||||
PostprocessMode::FLOAT, | |||||
"DefaultStrategyTypeNCHW44::FLOAT"_hash); | |||||
} else { | |||||
megdnn_throw("not support format except nchw44 and nchw\n"); | |||||
} | |||||
break; | break; | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
case StrategyType::FLOAT_FP16: | case StrategyType::FLOAT_FP16: | ||||
@@ -75,15 +75,14 @@ public: | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode, PackMode packmode, | megdnn::PostprocessMode postprocess_mode, PackMode packmode, | ||||
FormatMode format> | |||||
FormatMode format = FormatMode::NCHW> | |||||
class Strategy; | class Strategy; | ||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW> | |||||
: public StrategyBase { | |||||
postprocess_mode, PackMode::DEFAULT> : public StrategyBase { | |||||
public: | public: | ||||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | constexpr static size_t BUNDLE_PADDING_INDEX = 0; | ||||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | constexpr static size_t BUNDLE_PACKA_INDEX = 1; | ||||
@@ -142,8 +141,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW44> | postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW44> | ||||
: public Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | : public Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::DEFAULT, | |||||
FormatMode::NCHW> { | |||||
postprocess_mode, PackMode::DEFAULT> { | |||||
public: | public: | ||||
const size_t BUNDLE_PADDING_INDEX = 0; | const size_t BUNDLE_PADDING_INDEX = 0; | ||||
const size_t BUNDLE_PACKA_INDEX = 1; | const size_t BUNDLE_PACKA_INDEX = 1; | ||||
@@ -164,8 +162,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW> | |||||
: public StrategyBase { | |||||
postprocess_mode, PackMode::NO_PACK> : public StrategyBase { | |||||
public: | public: | ||||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | constexpr static size_t BUNDLE_PADDING_INDEX = 0; | ||||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | constexpr static size_t BUNDLE_PACKA_INDEX = 1; | ||||
@@ -231,8 +228,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW> | |||||
: public StrategyBase { | |||||
postprocess_mode, PackMode::ONLY_PACKA> : public StrategyBase { | |||||
public: | public: | ||||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | constexpr static size_t BUNDLE_PADDING_INDEX = 0; | ||||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | constexpr static size_t BUNDLE_PACKA_INDEX = 1; | ||||
@@ -26,7 +26,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::DEFAULT>:: | |||||
copy_padding_kern(WorkspaceBundle bundle, | copy_padding_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
@@ -93,13 +93,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::DEFAULT>:: | |||||
packA_kern(WorkspaceBundle bundle, | packA_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
size_t pack_oc_size) { | |||||
size_t) { | |||||
bundle.set(param.workspace_ptr); | bundle.set(param.workspace_ptr); | ||||
fallback::MatrixMulImpl::KernParam matmul_param; | fallback::MatrixMulImpl::KernParam matmul_param; | ||||
size_t group_id = ncb_index.ndrange_id[0]; | size_t group_id = ncb_index.ndrange_id[0]; | ||||
@@ -112,19 +112,18 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
matmul_algo->get_packA_type_size(); | matmul_algo->get_packA_type_size(); | ||||
size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size; | size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size; | ||||
int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) + | int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) + | ||||
group_id * packA_group_size + | |||||
(pack_oc_size == 4 ? 0 : a_panel_offset); | |||||
group_id * packA_group_size + a_panel_offset; | |||||
matmul_param.A_ptr = | matmul_param.A_ptr = | ||||
const_cast<src_ctype*>(param.filter<src_ctype>(group_id)); | const_cast<src_ctype*>(param.filter<src_ctype>(group_id)); | ||||
matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1], | matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1], | ||||
matmul_algo->get_inner_block_size().m * pack_oc_size); | |||||
matmul_algo->get_inner_block_size().m); | |||||
} | } | ||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::DEFAULT>:: | |||||
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
@@ -193,7 +192,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::DEFAULT>:: | |||||
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const WorkspaceBundle& bundle_thread, | const WorkspaceBundle& bundle_thread, | ||||
const StrategyParam& sparam) { | const StrategyParam& sparam) { | ||||
@@ -212,7 +211,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::DEFAULT>:: | |||||
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, WorkspaceBundle bundle, | const StrategyParam& sparam, WorkspaceBundle bundle, | ||||
WorkspaceBundle bundle_thread, | WorkspaceBundle bundle_thread, | ||||
@@ -249,7 +248,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::DEFAULT>:: | |||||
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
WorkspaceBundle bundle_thread) { | WorkspaceBundle bundle_thread) { | ||||
@@ -264,12 +263,12 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
? bias_temp_ptr | ? bias_temp_ptr | ||||
: static_cast<void*>(const_cast<bias_ctype*>( | : static_cast<void*>(const_cast<bias_ctype*>( | ||||
bias_ptr + sparam.oc_cur_index))); | bias_ptr + sparam.oc_cur_index))); | ||||
size_t pack_oc_size = sparam.pack_oc_size; | |||||
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | ||||
matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode, | matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode, | ||||
param.nonlineMode, param.bias_type, param.dst_type, 1_z, | param.nonlineMode, param.bias_type, param.dst_type, 1_z, | ||||
sparam.output_block_oc_size, 1_z, sparam.output_block_size, | |||||
sparam.pack_oc_size); | |||||
sparam.output_block_oc_size / pack_oc_size, 1_z, | |||||
sparam.output_block_size, pack_oc_size); | |||||
copy_dst(param, matmul_dst, sparam); | copy_dst(param, matmul_dst, sparam); | ||||
} | } | ||||
@@ -277,7 +276,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::DEFAULT>:: | |||||
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const void* matmul_dst, const StrategyParam& sparam) { | const void* matmul_dst, const StrategyParam& sparam) { | ||||
if (!sparam.skip_copy_dst) { | if (!sparam.skip_copy_dst) { | ||||
@@ -303,7 +302,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::DEFAULT>:: | |||||
get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const WorkspaceBundle& bundle_thread) { | const WorkspaceBundle& bundle_thread) { | ||||
bias_ctype* bias_tmp_ptr = | bias_ctype* bias_tmp_ptr = | ||||
@@ -318,7 +317,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::DEFAULT>:: | |||||
copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | ||||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | ||||
@@ -339,11 +338,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
} | } | ||||
} | } | ||||
//! Explicitly instantiate the Strategy class template for one combination of
//! source/bias/destination/op types and postprocess mode, with the pack mode
//! fixed to PackMode::DEFAULT and the format fixed to FormatMode::NCHW.
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode) \ | |||||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode, PackMode::DEFAULT, \ | |||||
FormatMode::NCHW>; | |||||
//! Explicitly instantiate the Strategy class template for one combination of
//! source/bias/destination/op types and postprocess mode, with the pack mode
//! fixed to PackMode::DEFAULT. The format template argument is not spelled
//! out here.
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode) \ | |||||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode, PackMode::DEFAULT>; | |||||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | ||||
megdnn::PostprocessMode::FLOAT) | megdnn::PostprocessMode::FLOAT) | ||||
@@ -12,10 +12,9 @@ | |||||
#include "src/fallback/convolution/img2col_helper.h" | #include "src/fallback/convolution/img2col_helper.h" | ||||
#if MEGDNN_X86 | #if MEGDNN_X86 | ||||
#include "src/x86/conv_bias/postprocess_helper.h" | #include "src/x86/conv_bias/postprocess_helper.h" | ||||
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#endif | #endif | ||||
using namespace megdnn; | using namespace megdnn; | ||||
#if MEGDNN_X86 | #if MEGDNN_X86 | ||||
using namespace x86; | using namespace x86; | ||||
@@ -101,23 +100,12 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | ||||
megdnn::PostprocessMode::FLOAT) | megdnn::PostprocessMode::FLOAT) | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||||
megdnn::PostprocessMode::FLOAT) | |||||
#else | |||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | ||||
megdnn::PostprocessMode::NO_PROCESS) | megdnn::PostprocessMode::NO_PROCESS) | ||||
#endif | #endif | ||||
#endif | |||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
//! x86 does not have a uint8 matmul, so only armv7 and armv8 support uint8 | |||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||||
megdnn::PostprocessMode::QUANTIZED) | |||||
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
megdnn::PostprocessMode::NO_PROCESS) | |||||
#endif | |||||
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | ||||
megdnn::PostprocessMode::QUANTIZED) | megdnn::PostprocessMode::QUANTIZED) | ||||
@@ -27,7 +27,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::NO_PACK>:: | |||||
copy_padding_kern(WorkspaceBundle bundle, | copy_padding_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
@@ -90,7 +90,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::NO_PACK>:: | |||||
packA_kern(WorkspaceBundle bundle, | packA_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
@@ -110,7 +110,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::NO_PACK>:: | |||||
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const WorkspaceBundle& bundle_thread, | const WorkspaceBundle& bundle_thread, | ||||
const StrategyParam& sparam) { | const StrategyParam& sparam) { | ||||
@@ -129,7 +129,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::NO_PACK>:: | |||||
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, WorkspaceBundle bundle, | const StrategyParam& sparam, WorkspaceBundle bundle, | ||||
WorkspaceBundle bundle_thread, | WorkspaceBundle bundle_thread, | ||||
@@ -162,7 +162,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::NO_PACK>:: | |||||
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
@@ -224,7 +224,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::NO_PACK>:: | |||||
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
WorkspaceBundle bundle_thread) { | WorkspaceBundle bundle_thread) { | ||||
@@ -252,7 +252,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::NO_PACK>:: | |||||
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const void* matmul_dst, const StrategyParam& sparam) { | const void* matmul_dst, const StrategyParam& sparam) { | ||||
if (!sparam.skip_copy_dst) { | if (!sparam.skip_copy_dst) { | ||||
@@ -274,7 +274,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::NO_PACK>:: | |||||
copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | ||||
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | ||||
@@ -295,11 +295,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
} | } | ||||
} | } | ||||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode) \ | |||||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode, PackMode::NO_PACK, \ | |||||
FormatMode::NCHW>; | |||||
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode) \ | |||||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||||
_op_dtype, _postprocess_mode, PackMode::NO_PACK>; | |||||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | ||||
megdnn::PostprocessMode::FLOAT) | megdnn::PostprocessMode::FLOAT) | ||||
@@ -27,7 +27,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
copy_padding_kern(WorkspaceBundle bundle, | copy_padding_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
@@ -90,7 +90,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
packA_kern(WorkspaceBundle bundle, | packA_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
@@ -124,7 +124,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const WorkspaceBundle& bundle_thread, | const WorkspaceBundle& bundle_thread, | ||||
const StrategyParam& sparam) { | const StrategyParam& sparam) { | ||||
@@ -143,7 +143,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, WorkspaceBundle bundle, | const StrategyParam& sparam, WorkspaceBundle bundle, | ||||
WorkspaceBundle bundle_thread, | WorkspaceBundle bundle_thread, | ||||
@@ -181,7 +181,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
@@ -242,7 +242,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
WorkspaceBundle bundle_thread) { | WorkspaceBundle bundle_thread) { | ||||
@@ -283,7 +283,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||||
postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const void* matmul_dst, const StrategyParam& sparam) { | const void* matmul_dst, const StrategyParam& sparam) { | ||||
if (!sparam.skip_copy_dst) { | if (!sparam.skip_copy_dst) { | ||||
@@ -305,7 +305,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
_op_dtype, _postprocess_mode) \ | _op_dtype, _postprocess_mode) \ | ||||
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | ||||
_op_dtype, _postprocess_mode, \ | _op_dtype, _postprocess_mode, \ | ||||
PackMode::ONLY_PACKA, FormatMode::NCHW>; | |||||
PackMode::ONLY_PACKA>; | |||||
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | ||||
megdnn::PostprocessMode::FLOAT) | megdnn::PostprocessMode::FLOAT) | ||||
@@ -26,6 +26,18 @@ | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace fallback; | using namespace fallback; | ||||
/*!
 * \brief Map a ConvBias format to its pack size.
 *
 * \param format the param::ConvBias::Format value to inspect
 * \return 4 for NCHW44 and NCHW4, 8 for NCHW88, and 1 for every other format
 */
size_t megdnn::fallback::get_format_pack_size(param::ConvBias::Format format) { | |||||
switch(format){ | |||||
// NCHW44 and NCHW4 are stacked case labels that share a single return value.
case param::ConvBias::Format::NCHW44: | |||||
case param::ConvBias::Format::NCHW4: | |||||
return 4_z; | |||||
case param::ConvBias::Format::NCHW88: | |||||
return 8_z; | |||||
// Every format not listed above reports a pack size of 1.
default: | |||||
return 1_z; | |||||
} | |||||
} | |||||
namespace { | namespace { | ||||
template <typename T> | template <typename T> | ||||
void incr_ptr(T*& dst, ptrdiff_t delta) { | void incr_ptr(T*& dst, ptrdiff_t delta) { | ||||
@@ -22,6 +22,11 @@ namespace megdnn { | |||||
namespace fallback { | namespace fallback { | ||||
/*! | /*! | ||||
* \brief get the pack_size according to the format | |||||
* */ | |||||
size_t get_format_pack_size(param::ConvBias::Format format); | |||||
/*! | |||||
* \brief fallback conv bias forward impl | * \brief fallback conv bias forward impl | ||||
* | * | ||||
* Note: this operator class serves for multiple purposes: | * Note: this operator class serves for multiple purposes: | ||||
@@ -9,9 +9,8 @@ | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#if MEGDNN_ARMV7 || MEGDNN_AARCH64 | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#endif | |||||
namespace { | namespace { | ||||
template <bool is_xcorr, typename dtype> | template <bool is_xcorr, typename dtype> | ||||
@@ -268,12 +267,13 @@ void img2col_nchw4(const dtype* __restrict src, dtype* __restrict dst, | |||||
} | } | ||||
for (int w = cur_remain_w; w < OW; w++) { | for (int w = cur_remain_w; w < OW; w++) { | ||||
size_t index = ic * IH * IW + (start_h + fh2) * IW + | |||||
(w + fw2); | |||||
dst[i++] = src[4 * index]; | |||||
dst[i++] = src[4 * index + 1]; | |||||
dst[i++] = src[4 * index + 2]; | |||||
dst[i++] = src[4 * index + 3]; | |||||
size_t index = | |||||
4 * (ic * IH * IW + (start_h + fh2) * IW + | |||||
(w + fw2)); | |||||
dst[i++] = src[index]; | |||||
dst[i++] = src[index + 1]; | |||||
dst[i++] = src[index + 2]; | |||||
dst[i++] = src[index + 3]; | |||||
} | } | ||||
for (int h = start_h + 1; h < end_h; h++) { | for (int h = start_h + 1; h < end_h; h++) { | ||||
@@ -317,26 +317,11 @@ void img2col_nchw4(const dtype* __restrict src, dtype* __restrict dst, | |||||
fh2 = FH - fh - 1; | fh2 = FH - fh - 1; | ||||
fw2 = FW - fw - 1; | fw2 = FW - fw - 1; | ||||
} | } | ||||
#if MEGDNN_ARMV7 || MEGDNN_AARCH64 | |||||
int w = cur_remain_w; | |||||
size_t index = (ic * IH * IW + (start_h + fh2) * IW + | |||||
(w + fw2)); | |||||
for (; w + 3 < end_remain_w; w += 4) { | |||||
vst1q_u32(&output[i], | |||||
vld1q_u32(&uint32_src[index])); | |||||
i += 4; | |||||
index += 4; | |||||
} | |||||
for (; w < end_remain_w; w++) { | |||||
output[i++] = uint32_src[index]; | |||||
} | |||||
#else | |||||
for (int w = cur_remain_w; w < end_remain_w; w++) { | for (int w = cur_remain_w; w < end_remain_w; w++) { | ||||
size_t index = (ic * IH * IW + | size_t index = (ic * IH * IW + | ||||
(start_h + fh2) * IW + (w + fw2)); | (start_h + fh2) * IW + (w + fw2)); | ||||
output[i++] = uint32_src[index]; | output[i++] = uint32_src[index]; | ||||
} | } | ||||
#endif | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -360,27 +345,11 @@ void img2col_nchw4(const dtype* __restrict src, dtype* __restrict dst, | |||||
} | } | ||||
for (int h = start_h + 1; h < end_h; h++) { | for (int h = start_h + 1; h < end_h; h++) { | ||||
#if MEGDNN_ARMV7 || MEGDNN_AARCH64 | |||||
int ow = 0; | |||||
size_t index = (ic * IH * IW + (h + fh2) * IW + | |||||
(ow + fw2)); | |||||
for (; ow + 3 < OW; ow += 4) { | |||||
vst1q_u32(&output[i], | |||||
vld1q_u32(&uint32_src[index])); | |||||
i += 4; | |||||
index += 4; | |||||
} | |||||
for (; ow < OW; ow++) { | |||||
output[i++] = uint32_src[index++]; | |||||
} | |||||
#else | |||||
rep(ow, OW) { | rep(ow, OW) { | ||||
size_t index = (ic * IH * IW + (h + fh2) * IW + | size_t index = (ic * IH * IW + (h + fh2) * IW + | ||||
(ow + fw2)); | (ow + fw2)); | ||||
output[i++] = uint32_src[index]; | output[i++] = uint32_src[index]; | ||||
} | } | ||||
#endif | |||||
} | } | ||||
for (int w = 0; w < end_remain_w; w++) { | for (int w = 0; w < end_remain_w; w++) { | ||||
@@ -1173,10 +1173,10 @@ void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args, | |||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
#if !__ARM_FEATURE_DOTPROD | #if !__ARM_FEATURE_DOTPROD | ||||
TEST_F(ARM_COMMON, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<conv_bias::TestArg> args = | std::vector<conv_bias::TestArg> args = | ||||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true); | |||||
get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true); | |||||
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
@@ -1187,10 +1187,10 @@ TEST_F(ARM_COMMON, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44) { | |||||
#undef cb | #undef cb | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_MULTI) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<conv_bias::TestArg> args = | std::vector<conv_bias::TestArg> args = | ||||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true); | |||||
get_nchw44_conv_bias_args({3, 4, 6}, 1, false, true, true); | |||||
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
@@ -1202,13 +1202,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_MULTI) { | |||||
#undef cb | #undef cb | ||||
} | } | ||||
TEST_F(ARM_COMMON, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_S2) { | |||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
#define cb(name) \ | |||||
checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1), \ | |||||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||||
#define cb(name) \ | |||||
checker_conv_bias(get_nchw44_conv_bias_args({3, 4, 6}, 2), handle(), &rng, \ | |||||
epsilon, dtype::QuantizedS8(2.5f), \ | |||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||||
dtype::QuantizedS8(60.25f), name); | dtype::QuantizedS8(60.25f), name); | ||||
float epsilon = 0.001; | float epsilon = 0.001; | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
@@ -1220,13 +1221,13 @@ TEST_F(ARM_COMMON, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44) { | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, | TEST_F(ARM_COMMON_MULTI_THREADS, | ||||
CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_MULTI) { | |||||
CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_S1) { | |||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
#define cb(name) \ | |||||
checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1), \ | |||||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||||
#define cb(name) \ | |||||
checker_conv_bias(get_nchw44_conv_bias_args({2, 5, 7}, 1), handle(), &rng, \ | |||||
epsilon, dtype::QuantizedS8(2.5f), \ | |||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||||
dtype::QuantizedS8(60.25f), name); | dtype::QuantizedS8(60.25f), name); | ||||
float epsilon = 0.001; | float epsilon = 0.001; | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
@@ -1286,6 +1287,24 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) { | |||||
#undef cb | #undef cb | ||||
} | } | ||||
#if MEGDNN_AARCH64 | |||||
// Compiled only when MEGDNN_AARCH64 is set. Passes the argument list returned
// by get_nchw44_conv_bias_args({2, 4, 7}, 1) to check_conv_bias together with
// the algorithm name IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1.
// NOTE(review): the S1 in the test name suggests {2, 4, 7} are kernel sizes and
// 1 is the stride -- confirm against get_nchw44_conv_bias_args.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({2, 4, 7}, 1); | |||||
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | |||||
} | |||||
#endif | |||||
#if MEGDNN_AARCH64 | |||||
// Compiled only when MEGDNN_AARCH64 is set. Passes the argument list returned
// by get_nchw44_conv_bias_args({3, 5, 6}, 2) to check_conv_bias together with
// the algorithm name IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1.
// NOTE(review): the S2 in the test name suggests {3, 5, 6} are kernel sizes and
// 2 is the stride -- confirm against get_nchw44_conv_bias_args.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) { | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({3, 5, 6}, 2); | |||||
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | |||||
} | |||||
#endif | |||||
/***************************** Conv1x1 Algo Test ***********************/ | /***************************** Conv1x1 Algo Test ***********************/ | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||