GitOrigin-RevId: f870ad964c
tags/v0.5.0
@@ -182,7 +182,7 @@ bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable( | |||||
bool ok_type = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | bool ok_type = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | ||||
param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | ||||
(param.dst_type.enumv() == DTypeEnum::QuantizedS8))) && | (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) && | ||||
(fm.format == param::Convolution::Format::NCHW44); | |||||
(fm.format == param::Convolution::Format::NCHW44_DOT); | |||||
bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic < 4); | bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic < 4); | ||||
bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && | bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && | ||||
(fh == 2 || fh == 3 || fh == 5 || fh == 7); | (fh == 2 || fh == 3 || fh == 5 || fh == 7); | ||||
@@ -55,7 +55,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; | AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
AlgoDotS8DirectNCHWNCHW44 ds8_direct_stride2_nchw_nchw44; | |||||
AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true}; | AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true}; | ||||
AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false}; | AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false}; | ||||
AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true}; | AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true}; | ||||
@@ -66,6 +65,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false}; | AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false}; | ||||
AlgoDotS8Direct_NCHW44 ds8_direct_nchw44; | AlgoDotS8Direct_NCHW44 ds8_direct_nchw44; | ||||
AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44; | |||||
#endif | #endif | ||||
AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44; | AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44; | ||||
@@ -96,7 +96,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
public: | public: | ||||
AlgoPack() { | AlgoPack() { | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
direct_algos.emplace_back(&ds8_direct_stride2_nchw_nchw44); | |||||
direct_algos.emplace_back(&ds8_direct_stride1_large_group); | direct_algos.emplace_back(&ds8_direct_stride1_large_group); | ||||
direct_algos.emplace_back(&ds8_direct_stride1_small_group); | direct_algos.emplace_back(&ds8_direct_stride1_small_group); | ||||
direct_algos.emplace_back(&ds8_direct_stride2_large_group); | direct_algos.emplace_back(&ds8_direct_stride2_large_group); | ||||
@@ -107,6 +106,7 @@ public: | |||||
direct_algos.emplace_back(&du8_direct_stride2_small_group); | direct_algos.emplace_back(&du8_direct_stride2_small_group); | ||||
direct_algos.emplace_back(&ds8_direct_nchw44); | direct_algos.emplace_back(&ds8_direct_nchw44); | ||||
direct_algos.emplace_back(&ds8_direct_nchw_nchw44); | |||||
#endif | #endif | ||||
direct_algos.emplace_back(&qu8_direct_stride2_large_group); | direct_algos.emplace_back(&qu8_direct_stride2_large_group); | ||||
direct_algos.emplace_back(&qu8_direct_stride2_small_group); | direct_algos.emplace_back(&qu8_direct_stride2_small_group); | ||||
@@ -582,14 +582,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) { | |||||
/****************************dot qint8 direct*************************/ | /****************************dot qint8 direct*************************/ | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { | ||||
checker_conv_bias_qint8x8x8( | |||||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false, | |||||
true), | |||||
handle(), "ARMDOTS8_NCHW_NCHW44"); | |||||
checker_conv_bias_qint8x8x8( | |||||
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false, | |||||
true), | |||||
handle(), "ARMDOTS8_NCHW_NCHW44"); | |||||
auto args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false, | |||||
true); | |||||
for (auto&& arg : args) { | |||||
arg.param.format = param::ConvBias::Format::NCHW44_DOT; | |||||
} | |||||
checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); | |||||
args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false, | |||||
true); | |||||
for (auto&& arg : args) { | |||||
arg.param.format = param::ConvBias::Format::NCHW44_DOT; | |||||
} | |||||
checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, | TEST_F(ARM_COMMON_MULTI_THREADS, | ||||
CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) { | CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) { | ||||
@@ -987,6 +987,12 @@ Args Args::from_argv(int argc, char **argv) { | |||||
cb(nchw32); | cb(nchw32); | ||||
cb(nhwcd4); | cb(nhwcd4); | ||||
#undef cb | #undef cb | ||||
if (!strcmp(argv[i], "--enable-nchw44-dot")) { | |||||
mgb_log_warn("enable-nchw44-dot optimization"); | |||||
graph_opt.graph_opt.enable_nchw44_dot(); | |||||
continue; | |||||
} | |||||
if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) { | if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) { | ||||
mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization"); | mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization"); | ||||
graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity(); | graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity(); | ||||
@@ -94,7 +94,6 @@ namespace mgb { | |||||
m_param{param}, m_param_size{param_size} | m_param{param}, m_param_size{param_size} | ||||
{ | { | ||||
} | } | ||||
//! build a blob representation to be used as cache key | //! build a blob representation to be used as cache key | ||||
PersistentCache::Blob build_blob() const; | PersistentCache::Blob build_blob() const; | ||||
}; | }; | ||||
@@ -611,7 +611,6 @@ AlgoChooserProfileCache::Result AlgoChooser<Opr>::get_profile_result( | |||||
AlgoChooserProfileCache::Key cache_key{origin_layouts.data(), | AlgoChooserProfileCache::Key cache_key{origin_layouts.data(), | ||||
origin_layouts.size(), &origin_param, | origin_layouts.size(), &origin_param, | ||||
sizeof(origin_param)}; | sizeof(origin_param)}; | ||||
{ | { | ||||
auto&& rst = cache.get(cache_key); | auto&& rst = cache.get(cache_key); | ||||
if (rst.valid()) | if (rst.valid()) | ||||
@@ -107,7 +107,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, | |||||
src_shape[1] / group * 2; | src_shape[1] / group * 2; | ||||
return hybird_nchwx ? computation : computation * 8; | return hybird_nchwx ? computation : computation * 8; | ||||
} | } | ||||
if (param.format == Param::Format::NCHW44) { | |||||
if (param.format == Param::Format::NCHW44 || | |||||
param.format == Param::Format::NCHW44_DOT) { | |||||
//! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4} | //! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4} | ||||
if (filter_shape[1] == 1 && filter_shape[2] == 1) { | if (filter_shape[1] == 1 && filter_shape[2] == 1) { | ||||
group *= 4; | group *= 4; | ||||
@@ -145,6 +146,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, | |||||
if (param.format == Param::Format::NCHW4 || | if (param.format == Param::Format::NCHW4 || | ||||
param.format == Param::Format::NCHW88 || | param.format == Param::Format::NCHW88 || | ||||
param.format == Param::Format::NCHW44 || | param.format == Param::Format::NCHW44 || | ||||
param.format == Param::Format::NCHW44_DOT || | |||||
param.format == Param::Format::NCHW32) { | param.format == Param::Format::NCHW32) { | ||||
return eval_conv_computation_nchwx(); | return eval_conv_computation_nchwx(); | ||||
} | } | ||||