@@ -49,6 +49,14 @@ namespace { | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | ||||
dst_type, N, OC, OH* OW); | dst_type, N, OC, OH* OW); | ||||
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(_op) \ | |||||
megdnn::arm_common::OpCallerBinary<_op<ctype>, \ | |||||
megdnn::arm_common::VEC_BCAST101x4>:: \ | |||||
run(static_cast<ctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||||
dst_type, N, OC, OH* OW, pack_oc_size); | |||||
#define FOR_NONLINEAR_BINARY(_op) \ | #define FOR_NONLINEAR_BINARY(_op) \ | ||||
megdnn::arm_common:: \ | megdnn::arm_common:: \ | ||||
OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \ | OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \ | ||||
@@ -57,20 +65,26 @@ namespace { | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | ||||
dst_type, N* OC* OH* OW); | dst_type, N* OC* OH* OW); | ||||
#define FOR_BIAS(_mode) \ | |||||
switch (_mode) { \ | |||||
case megdnn::BiasMode::NO_BIAS: \ | |||||
FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY) \ | |||||
break; \ | |||||
case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST) \ | |||||
break; \ | |||||
case megdnn::BiasMode::BIAS: \ | |||||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY) \ | |||||
break; \ | |||||
default: \ | |||||
megdnn_throw("no quantized unsupported biasmode"); \ | |||||
break; \ | |||||
#define FOR_BIAS(_mode) \ | |||||
switch (_mode) { \ | |||||
case megdnn::BiasMode::NO_BIAS: \ | |||||
FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY) \ | |||||
break; \ | |||||
case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||||
if (pack_oc_size == 1) { \ | |||||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ | |||||
} else { \ | |||||
megdnn_assert(pack_oc_size == 4, \ | |||||
"Only support nchw44 in ARM"); \ | |||||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ | |||||
} \ | |||||
break; \ | |||||
case megdnn::BiasMode::BIAS: \ | |||||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY) \ | |||||
break; \ | |||||
default: \ | |||||
megdnn_throw("no quantized unsupported biasmode"); \ | |||||
break; \ | |||||
} | } | ||||
#define FOR_NONLINEAR(_caller) \ | #define FOR_NONLINEAR(_caller) \ | ||||
@@ -129,6 +143,7 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||||
#undef FOR_NONLINEAR_UNARY | #undef FOR_NONLINEAR_UNARY | ||||
#undef FOR_NONLINEAR_BINARY_BROADCAST | #undef FOR_NONLINEAR_BINARY_BROADCAST | ||||
#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW44 | |||||
#undef FOR_NONLINEAR_BINARY | #undef FOR_NONLINEAR_BINARY | ||||
#undef FOR_NONLINEAR_NOBIAS | #undef FOR_NONLINEAR_NOBIAS | ||||
#undef FOR_NONLINEAR | #undef FOR_NONLINEAR | ||||
@@ -187,6 +202,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||||
if (pack_oc_size == 1) { \ | if (pack_oc_size == 1) { \ | ||||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ | FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ | ||||
} else { \ | } else { \ | ||||
megdnn_assert(pack_oc_size == 4, \ | |||||
"Only support nchw44 in ARM"); \ | |||||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ | FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ | ||||
} \ | } \ | ||||
break; \ | break; \ | ||||
@@ -216,14 +216,18 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) | param.nonlineMode != megdnn::NonlineMode::IDENTITY) | ||||
return false; | return false; | ||||
if (opr->param().format == param::ConvBias::Format::NCHW44) { | |||||
//! nchw44 hybird mode and channel wise is not support | |||||
if (param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 || | |||||
param.filter_meta.ocpg == 1) { | |||||
return false; | |||||
} | |||||
} | |||||
size_t OH = param.osz[0]; | size_t OH = param.osz[0]; | ||||
size_t OW = param.osz[1]; | size_t OW = param.osz[1]; | ||||
MatrixMulImpl::KernSizeParam matmul_param = | |||||
get_matmul_kern_param(param, OH * OW, get_oc_tile_size_heuristic(param)); | |||||
if(opr->param().format == param::ConvBias::Format::NCHW44) | |||||
matmul_param.format = param::MatrixMul::Format::MK4; | |||||
MatrixMulImpl::KernSizeParam matmul_param = get_matmul_kern_param( | |||||
param, OH * OW, get_oc_tile_size_heuristic(param)); | |||||
bool matmul_usable = m_matmul_algo->usable(matmul_param); | bool matmul_usable = m_matmul_algo->usable(matmul_param); | ||||
return matmul_usable && | return matmul_usable && | ||||
@@ -22,6 +22,20 @@ namespace conv1x1 { | |||||
namespace { | namespace { | ||||
size_t get_format_pack_size(param::ConvBias::Format format) { | |||||
switch(format){ | |||||
case param::ConvBias::Format::NCHW44: | |||||
case param::ConvBias::Format::NCHW4: | |||||
return 4_z; | |||||
case param::ConvBias::Format::NCHW88: | |||||
return 8_z; | |||||
case param::ConvBias::Format::NCHW: | |||||
return 1_z; | |||||
default: | |||||
megdnn_throw("unknow pack size of the format"); | |||||
} | |||||
} | |||||
struct StrategyHashParam { | struct StrategyHashParam { | ||||
ConvBiasImpl::NCBKernSizeParam param; | ConvBiasImpl::NCBKernSizeParam param; | ||||
param::ConvBias::Format format; | param::ConvBias::Format format; | ||||
@@ -71,7 +85,7 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||||
const ConvBiasImpl::NCBKernSizeParam& param, | const ConvBiasImpl::NCBKernSizeParam& param, | ||||
MatrixMulImpl::AlgoBase::PackMode pack_mode, | MatrixMulImpl::AlgoBase::PackMode pack_mode, | ||||
param::ConvBias::Format format) { | param::ConvBias::Format format) { | ||||
size_t pack_size = format == param::ConvBias::Format::NCHW ? 1 : 4; | |||||
size_t pack_size = get_format_pack_size(format); | |||||
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | #define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | ||||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | ||||
midout_iv(_midout_tag)) { \ | midout_iv(_midout_tag)) { \ | ||||
@@ -41,19 +41,25 @@ MatrixMulImpl::KernSizeParam get_matmul_kern_param( | |||||
param.dst_type.enumv() == DTypeEnum::QuantizedS8) || | param.dst_type.enumv() == DTypeEnum::QuantizedS8) || | ||||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm && | (param.src_type.enumv() == DTypeEnum::Quantized8Asymm && | ||||
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); | param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); | ||||
size_t pack_c_size = 1_z; | |||||
auto format = param::MatrixMul::Format::DEFAULT; | |||||
if(param.filter_meta.format == param::ConvBias::Format::NCHW44){ | |||||
pack_c_size = 4_z; | |||||
format = param::MatrixMul::Format::MK4; | |||||
} | |||||
return {param.filter_type, | return {param.filter_type, | ||||
param.src_type, | param.src_type, | ||||
is_dst_8bit ? param.bias_type : param.dst_type, | is_dst_8bit ? param.bias_type : param.dst_type, | ||||
M, | M, | ||||
N, | N, | ||||
K, | K, | ||||
LDA, | |||||
LDB, | |||||
LDC, | |||||
LDA * pack_c_size, | |||||
LDB * pack_c_size, | |||||
LDC * pack_c_size, | |||||
false, | false, | ||||
false, | false, | ||||
param::MatrixMul::ComputeMode::DEFAULT, | param::MatrixMul::ComputeMode::DEFAULT, | ||||
param::MatrixMul::Format::DEFAULT}; | |||||
format}; | |||||
} | } | ||||
} // namespace | } // namespace | ||||
@@ -137,9 +143,7 @@ public: | |||||
src_ctype* a_panel = reinterpret_cast<src_ctype*>( | src_ctype* a_panel = reinterpret_cast<src_ctype*>( | ||||
reinterpret_cast<int8_t*>(whole_bundle.get(0)) + | reinterpret_cast<int8_t*>(whole_bundle.get(0)) + | ||||
bytes_offset_of_a_panel); | bytes_offset_of_a_panel); | ||||
matmul_kern_param.LDA *= m_pack_size; | |||||
matmul_kern_param.A_ptr = const_cast<src_ctype*>( | matmul_kern_param.A_ptr = const_cast<src_ctype*>( | ||||
ncb_param.filter<src_ctype>(group_id) + | ncb_param.filter<src_ctype>(group_id) + | ||||
numbers_offset_of_filter); | numbers_offset_of_filter); | ||||
@@ -172,7 +176,6 @@ public: | |||||
static_cast<MatrixMulImpl::KernSizeParam&>(matmul_kern_param) = | static_cast<MatrixMulImpl::KernSizeParam&>(matmul_kern_param) = | ||||
get_matmul_kern_param(param, OH * OW, OC); | get_matmul_kern_param(param, OH * OW, OC); | ||||
matmul_kern_param.LDB *= m_pack_size; | |||||
rep(batch, BATCH) { | rep(batch, BATCH) { | ||||
rep(g, GROUP) { | rep(g, GROUP) { | ||||
@@ -282,8 +285,6 @@ public: | |||||
matmul_kern_param.C_ptr = matmul_dst; | matmul_kern_param.C_ptr = matmul_dst; | ||||
matmul_kern_param.LDC *= m_pack_size; | |||||
if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { | if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { | ||||
auto matmul_kern = matmul_algo->get_kern(matmul_kern_param); | auto matmul_kern = matmul_algo->get_kern(matmul_kern_param); | ||||
matmul_kern(matmul_kern_param); | matmul_kern(matmul_kern_param); | ||||
@@ -295,14 +296,15 @@ public: | |||||
//! do postprocess | //! do postprocess | ||||
void* bias_ptr = nullptr; | void* bias_ptr = nullptr; | ||||
if (param.bias_mode == megdnn::BiasMode::BIAS) | |||||
if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||||
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>( | bias_ptr = static_cast<void*>(const_cast<bias_ctype*>( | ||||
ncb_param.bias<bias_ctype>(batch_id, group_id) + | ncb_param.bias<bias_ctype>(batch_id, group_id) + | ||||
numbers_of_ncb_dst_offset)); | numbers_of_ncb_dst_offset)); | ||||
else | |||||
} else { | |||||
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>( | bias_ptr = static_cast<void*>(const_cast<bias_ctype*>( | ||||
ncb_param.bias<bias_ctype>(batch_id, group_id) + oc_start)); | ncb_param.bias<bias_ctype>(batch_id, group_id) + oc_start)); | ||||
} | |||||
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | ||||
matmul_dst, bias_ptr, conv_bias_dst, param.bias_mode, | matmul_dst, bias_ptr, conv_bias_dst, param.bias_mode, | ||||
param.nonlineMode, param.bias_type, param.dst_type, 1_z, | param.nonlineMode, param.bias_type, param.dst_type, 1_z, | ||||
@@ -137,8 +137,8 @@ class ConvBias { | |||||
sizeof(output_compute_type) * | sizeof(output_compute_type) * | ||||
std::max(Strategy::IC_BLOCK_SIZE, Strategy::OC_BLOCK_SIZE); | std::max(Strategy::IC_BLOCK_SIZE, Strategy::OC_BLOCK_SIZE); | ||||
size_t matmul_workspace_size = | |||||
matmul_algo->get_workspace(get_matmul_kern_param(param)); | |||||
size_t matmul_workspace_size = matmul_algo->get_workspace( | |||||
get_matmul_kern_param(param, m_unit_oc_size)); | |||||
//! compute workspace is independent and separated as far as possible | //! compute workspace is independent and separated as far as possible | ||||
//! in case of false cache line sharing | //! in case of false cache line sharing | ||||
@@ -384,7 +384,7 @@ public: | |||||
get_wbundle_compute(param, matmul_algo); | get_wbundle_compute(param, matmul_algo); | ||||
fallback::MatrixMulImpl::KernParam matmul_param; | fallback::MatrixMulImpl::KernParam matmul_param; | ||||
static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) = | static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) = | ||||
get_matmul_kern_param(param); | |||||
get_matmul_kern_param(param, m_unit_oc_size); | |||||
Strategy strategy = m_strategy; | Strategy strategy = m_strategy; | ||||
size_t unit_tile_size = m_unit_tile_size; | size_t unit_tile_size = m_unit_tile_size; | ||||
@@ -450,21 +450,24 @@ public: | |||||
} | } | ||||
fallback::MatrixMulImpl::KernSizeParam get_matmul_kern_param( | fallback::MatrixMulImpl::KernSizeParam get_matmul_kern_param( | ||||
const NCBKernSizeParam& param) const { | |||||
const NCBKernSizeParam& param, size_t nr_oc_in_unit = 0) const { | |||||
size_t M = 0; | size_t M = 0; | ||||
size_t N = 0; | size_t N = 0; | ||||
size_t K = 0; | size_t K = 0; | ||||
size_t LDA = 0, LDB = 0, LDC = 0; | size_t LDA = 0, LDB = 0, LDC = 0; | ||||
if (nr_oc_in_unit == 0) { | |||||
nr_oc_in_unit = param.filter_meta.ocpg; | |||||
} | |||||
if (format == param::MatrixMul::Format::DEFAULT) { | if (format == param::MatrixMul::Format::DEFAULT) { | ||||
M = m_unit_tile_size; | M = m_unit_tile_size; | ||||
N = param.filter_meta.ocpg; | |||||
N = nr_oc_in_unit; | |||||
K = param.filter_meta.icpg; | K = param.filter_meta.icpg; | ||||
LDA = K; | LDA = K; | ||||
LDB = N; | LDB = N; | ||||
LDC = N; | LDC = N; | ||||
} else { | } else { | ||||
M = param.filter_meta.ocpg; | |||||
M = nr_oc_in_unit; | |||||
N = m_unit_tile_size; | N = m_unit_tile_size; | ||||
K = param.filter_meta.icpg; | K = param.filter_meta.icpg; | ||||
megdnn_assert(K % Strategy::IC_BLOCK_SIZE == 0, "invalid K: %zu", | megdnn_assert(K % Strategy::IC_BLOCK_SIZE == 0, "invalid K: %zu", | ||||
@@ -126,6 +126,8 @@ struct PostProcess { | |||||
DType bias_type, DType dst_type, size_t N, size_t OC, | DType bias_type, DType dst_type, size_t N, size_t OC, | ||||
size_t OH, size_t OW, size_t pack_oc_size = 1) { | size_t OH, size_t OW, size_t pack_oc_size = 1) { | ||||
MEGDNN_MARK_USED_VAR(pack_oc_size); | MEGDNN_MARK_USED_VAR(pack_oc_size); | ||||
megdnn_assert(pack_oc_size == 1, | |||||
"PostProcess only support nchw in x86"); | |||||
megdnn::param::Elemwise::Mode elem_mode = | megdnn::param::Elemwise::Mode elem_mode = | ||||
megdnn::param::Elemwise::Mode::ADD; | megdnn::param::Elemwise::Mode::ADD; | ||||
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | ||||
@@ -150,38 +152,6 @@ struct PostProcess { | |||||
}; | }; | ||||
template <typename ctype, typename dtype> | template <typename ctype, typename dtype> | ||||
struct PostProcess<ctype, dtype, megdnn::PostprocessMode::FLOAT> { | |||||
static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | |||||
megdnn::ConvBiasForward::BiasMode bias_mode, | |||||
megdnn::param::ConvBias::NonlineMode nonlineMode, | |||||
DType bias_type, DType dst_type, size_t N, size_t OC, | |||||
size_t OH, size_t OW, size_t pack_oc_size=1) { | |||||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||||
megdnn::param::Elemwise::Mode elem_mode = | |||||
megdnn::param::Elemwise::Mode::ADD; | |||||
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||||
switch (nonlineMode) { | |||||
BIAS_CASE(RELU); | |||||
BIAS_CASE(SIGMOID); | |||||
BIAS_CASE(H_SWISH); | |||||
IDENTITY_CASE(IDENTITY); | |||||
DEFAULT_CASE; | |||||
} | |||||
} else { | |||||
switch (nonlineMode) { | |||||
NOBIAS_CASE(RELU); | |||||
NOBIAS_CASE(SIGMOID); | |||||
NOBIAS_CASE(H_SWISH); | |||||
IDENTITY_CASE(IDENTITY); | |||||
DEFAULT_CASE; | |||||
} | |||||
} | |||||
FOR_BIAS(bias_mode); | |||||
} | |||||
}; | |||||
template <typename ctype, typename dtype> | |||||
struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | ||||
static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | ||||
megdnn::ConvBiasForward::BiasMode bias_mode, | megdnn::ConvBiasForward::BiasMode bias_mode, | ||||
@@ -297,6 +267,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> { | |||||
DType bias_type, DType dst_type, size_t N, size_t OC, | DType bias_type, DType dst_type, size_t N, size_t OC, | ||||
size_t OH, size_t OW, size_t pack_oc_size = 1) { | size_t OH, size_t OW, size_t pack_oc_size = 1) { | ||||
MEGDNN_MARK_USED_VAR(pack_oc_size); | MEGDNN_MARK_USED_VAR(pack_oc_size); | ||||
megdnn_assert(pack_oc_size == 1, | |||||
"PostProcess only support nchw nchw in x86"); | |||||
megdnn::param::Elemwise::Mode elem_mode = | megdnn::param::Elemwise::Mode elem_mode = | ||||
megdnn::param::Elemwise::Mode::ADD; | megdnn::param::Elemwise::Mode::ADD; | ||||
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | ||||
@@ -1297,6 +1297,32 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) { | |||||
#endif | #endif | ||||
} | } | ||||
#if MEGDNN_AARCH64 | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) { | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({1}, 1, true, false, false); | |||||
check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24"); | |||||
} | |||||
#endif | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) { | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({1}, 1, true, false, false); | |||||
std::vector<conv_bias::TestArg> args_of_4; | |||||
for (auto&& arg : args) { | |||||
if (arg.src.shape[2] * arg.src.shape[3] % 4 == 0) { | |||||
args_of_4.push_back(arg); | |||||
} | |||||
} | |||||
#if MEGDNN_AARCH64 | |||||
check_conv_bias(args_of_4, handle(), "CONV1x1:AARCH64_F32_MK4_4x16:24"); | |||||
#elif MEGDNN_ARMV7 | |||||
check_conv_bias(args_of_4, handle(), "CONV1x1:ARMV7_F32_MK4_4x8:48"); | |||||
#endif | |||||
} | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||