@@ -49,6 +49,14 @@ namespace { | |||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||
dst_type, N, OC, OH* OW); | |||
//! Expands to a call of
//! OpCallerBinary<_op<ctype>, VEC_BCAST101x4>::run(...) for the binary
//! operator template \p _op.
//! Arguments passed, in order: conv_dst_ptr (cast to ctype*), bias_ptr (cast
//! to const ctype*), dst_ptr (cast to ctype*), bias_type twice, dst_type,
//! N, OC, OH * OW and pack_oc_size.
//! The names ctype, conv_dst_ptr, bias_ptr, dst_ptr, bias_type, dst_type, N,
//! OC, OH, OW and pack_oc_size are not macro parameters; they must be in
//! scope wherever the macro is expanded.
//! NOTE(review): presumably VEC_BCAST101x4 denotes a per-channel bias
//! broadcast over a 4-channel-packed layout -- confirm against OpCallerBinary.
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(_op) \ | |||
megdnn::arm_common::OpCallerBinary<_op<ctype>, \ | |||
megdnn::arm_common::VEC_BCAST101x4>:: \ | |||
run(static_cast<ctype*>(conv_dst_ptr), \ | |||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||
dst_type, N, OC, OH* OW, pack_oc_size); | |||
#define FOR_NONLINEAR_BINARY(_op) \ | |||
megdnn::arm_common:: \ | |||
OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \ | |||
@@ -57,20 +65,26 @@ namespace { | |||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||
dst_type, N* OC* OH* OW); | |||
#define FOR_BIAS(_mode) \ | |||
switch (_mode) { \ | |||
case megdnn::BiasMode::NO_BIAS: \ | |||
FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY) \ | |||
break; \ | |||
case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST) \ | |||
break; \ | |||
case megdnn::BiasMode::BIAS: \ | |||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY) \ | |||
break; \ | |||
default: \ | |||
megdnn_throw("no quantized unsupported biasmode"); \ | |||
break; \ | |||
//! Switches on the bias mode \p _mode and expands the matching caller:
//!  - NO_BIAS: FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY)
//!  - BROADCAST_CHANNEL_BIAS: FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST)
//!    when pack_oc_size == 1; otherwise asserts pack_oc_size == 4 and uses
//!    FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44)
//!  - BIAS: FOR_NONLINEAR(FOR_NONLINEAR_BINARY)
//!  - any other value: megdnn_throw
//! pack_oc_size is not a macro parameter; it must be in scope at the
//! expansion site.
#define FOR_BIAS(_mode) \ | |||
switch (_mode) { \ | |||
case megdnn::BiasMode::NO_BIAS: \ | |||
FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY) \ | |||
break; \ | |||
case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||
if (pack_oc_size == 1) { \ | |||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ | |||
} else { \ | |||
megdnn_assert(pack_oc_size == 4, \ | |||
"Only support nchw44 in ARM"); \ | |||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ | |||
} \ | |||
break; \ | |||
case megdnn::BiasMode::BIAS: \ | |||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY) \ | |||
break; \ | |||
default: \ | |||
megdnn_throw("no quantized unsupported biasmode"); \ | |||
break; \ | |||
} | |||
#define FOR_NONLINEAR(_caller) \ | |||
@@ -129,6 +143,7 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||
#undef FOR_NONLINEAR_UNARY | |||
#undef FOR_NONLINEAR_BINARY_BROADCAST | |||
#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW44 | |||
#undef FOR_NONLINEAR_BINARY | |||
#undef FOR_NONLINEAR_NOBIAS | |||
#undef FOR_NONLINEAR | |||
@@ -187,6 +202,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||
if (pack_oc_size == 1) { \ | |||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ | |||
} else { \ | |||
megdnn_assert(pack_oc_size == 4, \ | |||
"Only support nchw44 in ARM"); \ | |||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ | |||
} \ | |||
break; \ | |||
@@ -216,14 +216,18 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) | |||
return false; | |||
if (opr->param().format == param::ConvBias::Format::NCHW44) { | |||
//! nchw44 hybrid mode and channel-wise mode are not supported | |||
if (param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 || | |||
param.filter_meta.ocpg == 1) { | |||
return false; | |||
} | |||
} | |||
size_t OH = param.osz[0]; | |||
size_t OW = param.osz[1]; | |||
MatrixMulImpl::KernSizeParam matmul_param = | |||
get_matmul_kern_param(param, OH * OW, get_oc_tile_size_heuristic(param)); | |||
if(opr->param().format == param::ConvBias::Format::NCHW44) | |||
matmul_param.format = param::MatrixMul::Format::MK4; | |||
MatrixMulImpl::KernSizeParam matmul_param = get_matmul_kern_param( | |||
param, OH * OW, get_oc_tile_size_heuristic(param)); | |||
bool matmul_usable = m_matmul_algo->usable(matmul_param); | |||
return matmul_usable && | |||
@@ -22,6 +22,20 @@ namespace conv1x1 { | |||
namespace { | |||
//! Map a ConvBias tensor format to its pack size.
//!
//! \param format the ConvBias format to query
//! \return 4 for NCHW44 and NCHW4, 8 for NCHW88, 1 for NCHW
//!
//! Any other format is rejected through megdnn_throw.
size_t get_format_pack_size(param::ConvBias::Format format) {
    switch (format) {
        case param::ConvBias::Format::NCHW44:
        case param::ConvBias::Format::NCHW4:
            return 4_z;
        case param::ConvBias::Format::NCHW88:
            return 8_z;
        case param::ConvBias::Format::NCHW:
            return 1_z;
        default:
            //! no pack size is defined for the remaining formats
            megdnn_throw("unknown pack size of the format");
    }
}
struct StrategyHashParam { | |||
ConvBiasImpl::NCBKernSizeParam param; | |||
param::ConvBias::Format format; | |||
@@ -71,7 +85,7 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||
const ConvBiasImpl::NCBKernSizeParam& param, | |||
MatrixMulImpl::AlgoBase::PackMode pack_mode, | |||
param::ConvBias::Format format) { | |||
size_t pack_size = format == param::ConvBias::Format::NCHW ? 1 : 4; | |||
size_t pack_size = get_format_pack_size(format); | |||
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||
midout_iv(_midout_tag)) { \ | |||
@@ -41,19 +41,25 @@ MatrixMulImpl::KernSizeParam get_matmul_kern_param( | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS8) || | |||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm && | |||
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); | |||
size_t pack_c_size = 1_z; | |||
auto format = param::MatrixMul::Format::DEFAULT; | |||
if(param.filter_meta.format == param::ConvBias::Format::NCHW44){ | |||
pack_c_size = 4_z; | |||
format = param::MatrixMul::Format::MK4; | |||
} | |||
return {param.filter_type, | |||
param.src_type, | |||
is_dst_8bit ? param.bias_type : param.dst_type, | |||
M, | |||
N, | |||
K, | |||
LDA, | |||
LDB, | |||
LDC, | |||
LDA * pack_c_size, | |||
LDB * pack_c_size, | |||
LDC * pack_c_size, | |||
false, | |||
false, | |||
param::MatrixMul::ComputeMode::DEFAULT, | |||
param::MatrixMul::Format::DEFAULT}; | |||
format}; | |||
} | |||
} // namespace | |||
@@ -137,9 +143,7 @@ public: | |||
src_ctype* a_panel = reinterpret_cast<src_ctype*>( | |||
reinterpret_cast<int8_t*>(whole_bundle.get(0)) + | |||
bytes_offset_of_a_panel); | |||
matmul_kern_param.LDA *= m_pack_size; | |||
matmul_kern_param.A_ptr = const_cast<src_ctype*>( | |||
ncb_param.filter<src_ctype>(group_id) + | |||
numbers_offset_of_filter); | |||
@@ -172,7 +176,6 @@ public: | |||
static_cast<MatrixMulImpl::KernSizeParam&>(matmul_kern_param) = | |||
get_matmul_kern_param(param, OH * OW, OC); | |||
matmul_kern_param.LDB *= m_pack_size; | |||
rep(batch, BATCH) { | |||
rep(g, GROUP) { | |||
@@ -282,8 +285,6 @@ public: | |||
matmul_kern_param.C_ptr = matmul_dst; | |||
matmul_kern_param.LDC *= m_pack_size; | |||
if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { | |||
auto matmul_kern = matmul_algo->get_kern(matmul_kern_param); | |||
matmul_kern(matmul_kern_param); | |||
@@ -295,14 +296,15 @@ public: | |||
//! do postprocess | |||
void* bias_ptr = nullptr; | |||
if (param.bias_mode == megdnn::BiasMode::BIAS) | |||
if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>( | |||
ncb_param.bias<bias_ctype>(batch_id, group_id) + | |||
numbers_of_ncb_dst_offset)); | |||
else | |||
} else { | |||
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>( | |||
ncb_param.bias<bias_ctype>(batch_id, group_id) + oc_start)); | |||
} | |||
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||
matmul_dst, bias_ptr, conv_bias_dst, param.bias_mode, | |||
param.nonlineMode, param.bias_type, param.dst_type, 1_z, | |||
@@ -137,8 +137,8 @@ class ConvBias { | |||
sizeof(output_compute_type) * | |||
std::max(Strategy::IC_BLOCK_SIZE, Strategy::OC_BLOCK_SIZE); | |||
size_t matmul_workspace_size = | |||
matmul_algo->get_workspace(get_matmul_kern_param(param)); | |||
size_t matmul_workspace_size = matmul_algo->get_workspace( | |||
get_matmul_kern_param(param, m_unit_oc_size)); | |||
//! compute workspace is independent and separated as far as possible | |||
//! in case of false cache line sharing | |||
@@ -384,7 +384,7 @@ public: | |||
get_wbundle_compute(param, matmul_algo); | |||
fallback::MatrixMulImpl::KernParam matmul_param; | |||
static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) = | |||
get_matmul_kern_param(param); | |||
get_matmul_kern_param(param, m_unit_oc_size); | |||
Strategy strategy = m_strategy; | |||
size_t unit_tile_size = m_unit_tile_size; | |||
@@ -450,21 +450,24 @@ public: | |||
} | |||
fallback::MatrixMulImpl::KernSizeParam get_matmul_kern_param( | |||
const NCBKernSizeParam& param) const { | |||
const NCBKernSizeParam& param, size_t nr_oc_in_unit = 0) const { | |||
size_t M = 0; | |||
size_t N = 0; | |||
size_t K = 0; | |||
size_t LDA = 0, LDB = 0, LDC = 0; | |||
if (nr_oc_in_unit == 0) { | |||
nr_oc_in_unit = param.filter_meta.ocpg; | |||
} | |||
if (format == param::MatrixMul::Format::DEFAULT) { | |||
M = m_unit_tile_size; | |||
N = param.filter_meta.ocpg; | |||
N = nr_oc_in_unit; | |||
K = param.filter_meta.icpg; | |||
LDA = K; | |||
LDB = N; | |||
LDC = N; | |||
} else { | |||
M = param.filter_meta.ocpg; | |||
M = nr_oc_in_unit; | |||
N = m_unit_tile_size; | |||
K = param.filter_meta.icpg; | |||
megdnn_assert(K % Strategy::IC_BLOCK_SIZE == 0, "invalid K: %zu", | |||
@@ -126,6 +126,8 @@ struct PostProcess { | |||
DType bias_type, DType dst_type, size_t N, size_t OC, | |||
size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
megdnn_assert(pack_oc_size == 1, | |||
"PostProcess only support nchw in x86"); | |||
megdnn::param::Elemwise::Mode elem_mode = | |||
megdnn::param::Elemwise::Mode::ADD; | |||
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||
@@ -150,38 +152,6 @@ struct PostProcess { | |||
}; | |||
//! PostProcess specialization selected by PostprocessMode::FLOAT.
//! NOTE(review): BIAS_CASE, NOBIAS_CASE, IDENTITY_CASE, DEFAULT_CASE and
//! FOR_BIAS are macros defined outside this view; their expansions are not
//! documented here -- confirm what they read and write before changing the
//! local names used in run().
template <typename ctype, typename dtype> | |||
struct PostProcess<ctype, dtype, megdnn::PostprocessMode::FLOAT> { | |||
//! Entry point: selects a case from \p nonlineMode (separate case lists for
//! the bias and the no-bias situation) and then expands FOR_BIAS(bias_mode).
//! \p pack_oc_size defaults to 1.
static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | |||
megdnn::ConvBiasForward::BiasMode bias_mode, | |||
megdnn::param::ConvBias::NonlineMode nonlineMode, | |||
DType bias_type, DType dst_type, size_t N, size_t OC, | |||
size_t OH, size_t OW, size_t pack_oc_size=1) { | |||
//! pack_oc_size is not referenced by name anywhere else in this body
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
//! elem_mode starts out as ADD before the nonlineMode switch runs
megdnn::param::Elemwise::Mode elem_mode = | |||
megdnn::param::Elemwise::Mode::ADD; | |||
//! a bias is present: RELU / SIGMOID / H_SWISH go through BIAS_CASE
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||
switch (nonlineMode) { | |||
BIAS_CASE(RELU); | |||
BIAS_CASE(SIGMOID); | |||
BIAS_CASE(H_SWISH); | |||
IDENTITY_CASE(IDENTITY); | |||
DEFAULT_CASE; | |||
} | |||
} else { | |||
//! no bias: the same nonlinearities go through NOBIAS_CASE instead
switch (nonlineMode) { | |||
NOBIAS_CASE(RELU); | |||
NOBIAS_CASE(SIGMOID); | |||
NOBIAS_CASE(H_SWISH); | |||
IDENTITY_CASE(IDENTITY); | |||
DEFAULT_CASE; | |||
} | |||
} | |||
FOR_BIAS(bias_mode); | |||
} | |||
}; | |||
template <typename ctype, typename dtype> | |||
struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||
static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | |||
megdnn::ConvBiasForward::BiasMode bias_mode, | |||
@@ -297,6 +267,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> { | |||
DType bias_type, DType dst_type, size_t N, size_t OC, | |||
size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
megdnn_assert(pack_oc_size == 1, | |||
"PostProcess only support nchw nchw in x86"); | |||
megdnn::param::Elemwise::Mode elem_mode = | |||
megdnn::param::Elemwise::Mode::ADD; | |||
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||
@@ -1297,6 +1297,32 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) { | |||
#endif | |||
} | |||
#if MEGDNN_AARCH64 | |||
//! Runs check_conv_bias over the argument set returned by
//! get_nchw44_conv_bias_args({1}, 1, true, false, false), requesting the
//! algorithm named "CONV1x1:AARCH64_F32_MK4_K8X12X1:24".
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
    using namespace conv_bias;

    //! algorithm under test
    const char* const algo_name = "CONV1x1:AARCH64_F32_MK4_K8X12X1:24";

    std::vector<conv_bias::TestArg> nchw44_args =
            get_nchw44_conv_bias_args({1}, 1, true, false, false);
    check_conv_bias(nchw44_args, handle(), algo_name);
}
#endif | |||
//! Runs check_conv_bias on the subset of
//! get_nchw44_conv_bias_args({1}, 1, true, false, false) whose
//! src.shape[2] * src.shape[3] is a multiple of 4. The algorithm name depends
//! on the target: AARCH64_F32_MK4_4x16 on aarch64, ARMV7_F32_MK4_4x8 on armv7.
//! NOTE(review): shape[2] * shape[3] is presumably the spatial size H * W of
//! the nchw44 source tensor -- confirm against get_nchw44_conv_bias_args.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) {
    using namespace conv_bias;

    std::vector<conv_bias::TestArg> all_args =
            get_nchw44_conv_bias_args({1}, 1, true, false, false);

    //! keep only the cases whose shape product is divisible by 4
    std::vector<conv_bias::TestArg> aligned_args;
    for (auto&& candidate : all_args) {
        if (candidate.src.shape[2] * candidate.src.shape[3] % 4 != 0) {
            continue;
        }
        aligned_args.push_back(candidate);
    }

#if MEGDNN_AARCH64
    check_conv_bias(aligned_args, handle(), "CONV1x1:AARCH64_F32_MK4_4x16:24");
#elif MEGDNN_ARMV7
    check_conv_bias(aligned_args, handle(), "CONV1x1:ARMV7_F32_MK4_4x8:48");
#endif
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) { | |||
using namespace conv_bias; | |||