From e05c795b45aef7bb99dfc73232e687fce9bb16d8 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 20 Jul 2020 14:20:30 +0800
Subject: [PATCH] refactor(dnn/arm): refactor direct algo in algo selection

GitOrigin-RevId: d195f44decb45847fa46e0e90d6e64368c07539c
---
 dnn/src/aarch64/conv_bias/fp16/algos.cpp           |  35 ++--
 dnn/src/aarch64/conv_bias/fp16/algos.h             |   8 +-
 dnn/src/aarch64/conv_bias/fp32/algos.cpp           |  35 ++--
 dnn/src/aarch64/conv_bias/fp32/algos.h             |   8 +-
 dnn/src/aarch64/conv_bias/opr_impl.cpp             |  12 +-
 dnn/src/arm_common/conv_bias/f16/algos.cpp         |  70 +++----
 dnn/src/arm_common/conv_bias/f16/algos.h           |  13 +-
 dnn/src/arm_common/conv_bias/fp32/algos.cpp        | 107 +++++------
 dnn/src/arm_common/conv_bias/fp32/algos.h          |  19 +-
 dnn/src/arm_common/conv_bias/int8/algos.cpp        |  86 ++++-----
 dnn/src/arm_common/conv_bias/int8/algos.h          |  27 +--
 dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp   |  76 ++++----
 dnn/src/arm_common/conv_bias/int8x8x16/algos.h     |  15 +-
 dnn/src/arm_common/conv_bias/opr_impl.cpp          |  90 +++------
 dnn/src/arm_common/conv_bias/quint8/algos.cpp      |  81 +++------
 dnn/src/arm_common/conv_bias/quint8/algos.h        |  26 +--
 dnn/src/x86/conv_bias/f32/algos.cpp                |  62 +++----
 dnn/src/x86/conv_bias/f32/algos.h                  |  12 +-
 dnn/src/x86/conv_bias/opr_impl.cpp                 |  12 +-
 dnn/test/aarch64/conv_bias.cpp                     |  19 +-
 dnn/test/arm_common/conv_bias.cpp                  |   9 +-
 dnn/test/arm_common/conv_bias_multi_thread.cpp     | 201 +++++----------------
 .../conv_bias_multi_thread_benchmark.cpp           |  60 +++---
 dnn/test/arm_common/convolution.cpp                |   8 +-
 dnn/test/x86/conv_bias.cpp                         | 108 ++++++++---
 dnn/test/x86/convolution.cpp                       |   4 +-
 26 files changed, 451 insertions(+), 752 deletions(-)

diff --git a/dnn/src/aarch64/conv_bias/fp16/algos.cpp b/dnn/src/aarch64/conv_bias/fp16/algos.cpp
index cd05b606..60ccc693 100644
--- a/dnn/src/aarch64/conv_bias/fp16/algos.cpp
+++ b/dnn/src/aarch64/conv_bias/fp16/algos.cpp
@@ -22,26 +22,19 @@ using namespace aarch64;
 /* ===================== stride-2 algo ===================== */
 MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16)
 
-bool ConvBiasImpl::AlgoF16DirectStride2::usable(
-        const NCBKernSizeParam& param,
-        AlgoSelectionStrategy algo_selection_strategy) const {
+bool ConvBiasImpl::AlgoF16DirectStride2::usable(const NCBKernSizeParam& param,
+                                                AlgoSelectionStrategy) const {
     MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) {
         auto&& fm = param.filter_meta;
         auto FH = fm.spatial[0];
-        bool aviliable =
-                param.filter_meta.format == param::Convolution::Format::NCHW &&
-                param.src_type.enumv() == DTypeEnum::Float16 &&
-                param.filter_type.enumv() == DTypeEnum::Float16 &&
-                param.dst_type.enumv() == DTypeEnum::Float16 &&
-                !fm.should_flip && fm.spatial_ndim == 2 &&
-                fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
-                fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
-                (FH == 2 || FH == 3 || FH == 5 || FH == 7);
-        if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
-            bool large_group = param.filter_meta.group >= param.nr_threads;
-            aviliable &= (large_group == m_large_group);
-        }
-        return aviliable;
+        return param.filter_meta.format == param::Convolution::Format::NCHW &&
+               param.src_type.enumv() == DTypeEnum::Float16 &&
+               param.filter_type.enumv() == DTypeEnum::Float16 &&
+               param.dst_type.enumv() == DTypeEnum::Float16 &&
+               !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
+               fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
+               FH == fm.spatial[1] &&
+               (FH == 2 || FH == 3 || FH == 5 || FH == 7);
     }
     MIDOUT_END();
     return false;
@@ -50,8 +43,9 @@ bool
ConvBiasImpl::AlgoF16DirectStride2::usable( size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 1) { + bool large_group = param.filter_meta.group >= param.nr_threads; auto wbundle = arm_common::MultithreadDirectConvCommon< - dt_float16, __fp16>::get_bundle_stride(param, m_large_group); + dt_float16, __fp16>::get_bundle_stride(param, large_group); return wbundle.total_size_in_bytes(); } MIDOUT_END(); @@ -77,6 +71,7 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( size_t IC = param.filter_meta.icpg; size_t OC = param.filter_meta.ocpg; size_t group = fm.group; + bool large_group = group >= param.nr_threads; using Func = std::function; Func conv = nullptr; @@ -91,11 +86,11 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( } WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon< - dt_float16, __fp16>::get_bundle_stride(param, m_large_group); + dt_float16, __fp16>::get_bundle_stride(param, large_group); SmallVector ret_kerns; //! Dense conv and small group - if (m_large_group) { + if (large_group) { //! Channel wise conv and big groups auto exec_one_group = [bundle, conv]( const NCBKernParam& kern_param, diff --git a/dnn/src/aarch64/conv_bias/fp16/algos.h b/dnn/src/aarch64/conv_bias/fp16/algos.h index 2e0321bd..77ab5d76 100644 --- a/dnn/src/aarch64/conv_bias/fp16/algos.h +++ b/dnn/src/aarch64/conv_bias/fp16/algos.h @@ -18,15 +18,9 @@ namespace aarch64 { /* ===================== stride-2 algo ===================== */ class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; - bool m_large_group; - public: - AlgoF16DirectStride2(bool large_group) : m_large_group(large_group) {} bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? 
"ARMV8F16STRD2_LARGE_GROUP" - : "ARMV8F16STRD2_SMALL_GROUP"; - } + const char* name() const override { return "ARMV8F16STRD2"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; diff --git a/dnn/src/aarch64/conv_bias/fp32/algos.cpp b/dnn/src/aarch64/conv_bias/fp32/algos.cpp index 4436837d..3742032d 100644 --- a/dnn/src/aarch64/conv_bias/fp32/algos.cpp +++ b/dnn/src/aarch64/conv_bias/fp32/algos.cpp @@ -21,26 +21,19 @@ using namespace megdnn; using namespace aarch64; MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32) -bool ConvBiasImpl::AlgoF32DirectStride2::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { +bool ConvBiasImpl::AlgoF32DirectStride2::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; - bool aviliable = - param.filter_meta.format == param::ConvBias::Format::NCHW && - param.src_type.enumv() == DTypeEnum::Float32 && - param.filter_type.enumv() == DTypeEnum::Float32 && - param.dst_type.enumv() == DTypeEnum::Float32 && - !fm.should_flip && fm.spatial_ndim == 2 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && - (FH == 2 || FH == 3 || FH == 5 || FH == 7); - if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - aviliable &= (large_group == m_large_group); - } - return aviliable; + return param.filter_meta.format == param::ConvBias::Format::NCHW && + param.src_type.enumv() == DTypeEnum::Float32 && + param.filter_type.enumv() == DTypeEnum::Float32 && + param.dst_type.enumv() == DTypeEnum::Float32 && + !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && + FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7); } MIDOUT_END(); return false; @@ -49,8 +42,9 @@ bool ConvBiasImpl::AlgoF32DirectStride2::usable( size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 1) { + bool large_group = param.filter_meta.group >= param.nr_threads; auto wbundle = arm_common::MultithreadDirectConvCommon< - float, float>::get_bundle_stride(param, m_large_group); + float, float>::get_bundle_stride(param, large_group); return wbundle.total_size_in_bytes(); } MIDOUT_END(); @@ -75,6 +69,7 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( size_t IC = param.filter_meta.icpg; size_t OC = param.filter_meta.ocpg; size_t group = fm.group; + bool large_group = group >= param.nr_threads; using Func = std::function; Func conv = nullptr; @@ -89,11 +84,11 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( } WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon< - float, float>::get_bundle_stride(param, m_large_group); + float, float>::get_bundle_stride(param, large_group); SmallVector ret_kerns; //! Dense conv and small group - if (m_large_group) { + if (large_group) { //! 
Channel wise conv and big groups auto exec_one_group = [bundle, conv]( const NCBKernParam& kern_param, diff --git a/dnn/src/aarch64/conv_bias/fp32/algos.h b/dnn/src/aarch64/conv_bias/fp32/algos.h index 1947fd19..6ae1bf00 100644 --- a/dnn/src/aarch64/conv_bias/fp32/algos.h +++ b/dnn/src/aarch64/conv_bias/fp32/algos.h @@ -22,15 +22,9 @@ using FallbackConvBiasImpl = fallback::ConvBiasImpl; class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; - bool m_large_group; - public: - AlgoF32DirectStride2(bool large_group) : m_large_group(large_group) {} bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? "ARMV8F32STRD2_LARGE_GROUP" - : "ARMV8F32STRD2_SMALL_GROUP"; - } + const char* name() const override { return "ARMV8F32STRD2"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; diff --git a/dnn/src/aarch64/conv_bias/opr_impl.cpp b/dnn/src/aarch64/conv_bias/opr_impl.cpp index f9905dd0..e65a96a5 100644 --- a/dnn/src/aarch64/conv_bias/opr_impl.cpp +++ b/dnn/src/aarch64/conv_bias/opr_impl.cpp @@ -25,13 +25,11 @@ using namespace megdnn; using namespace aarch64; class ConvBiasImpl::AlgoPack : NonCopyableObj { - AlgoF32DirectStride2 f32_direct_stride2_large_group{true}; - AlgoF32DirectStride2 f32_direct_stride2_small_group{false}; + AlgoF32DirectStride2 f32_direct_stride2; AlgoS8MatrixMul s8_matrix_mul; AlgoQU8MatrixMul qu8_matrix_mul; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - AlgoF16DirectStride2 f16_direct_stride2_large_group{true}; - AlgoF16DirectStride2 f16_direct_stride2_small_group{false}; + AlgoF16DirectStride2 f16_direct_stride2; #endif public: @@ -39,11 +37,9 @@ public: matmul_algos.emplace_back(&qu8_matrix_mul); matmul_algos.emplace_back(&s8_matrix_mul); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - direct_algos.emplace_back(&f16_direct_stride2_large_group); - direct_algos.emplace_back(&f16_direct_stride2_small_group); + direct_algos.emplace_back(&f16_direct_stride2); #endif - direct_algos.emplace_back(&f32_direct_stride2_large_group); - direct_algos.emplace_back(&f32_direct_stride2_small_group); + direct_algos.emplace_back(&f32_direct_stride2); } SmallVector direct_algos; SmallVector matmul_algos; diff --git a/dnn/src/arm_common/conv_bias/f16/algos.cpp b/dnn/src/arm_common/conv_bias/f16/algos.cpp index d5d095e3..9ab11020 100644 --- a/dnn/src/arm_common/conv_bias/f16/algos.cpp +++ b/dnn/src/arm_common/conv_bias/f16/algos.cpp @@ -192,9 +192,8 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP16WinogradF23_8x8, MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_kimpl) -bool ConvBiasImpl::AlgoF16Direct::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { +bool ConvBiasImpl::AlgoF16Direct::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 0) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; @@ -203,20 +202,14 @@ bool ConvBiasImpl::AlgoF16Direct::usable( // ``param.osz[0]*param.osz[1] >= 8'' comes from the fact that the // kernel may have access to up to 8 fp16 after the end of the memory // chunk. 
- bool aviliable = fm.format == param::ConvBias::Format::NCHW && - param.src_type.enumv() == DTypeEnum::Float16 && - param.filter_type.enumv() == DTypeEnum::Float16 && - param.dst_type.enumv() == DTypeEnum::Float16 && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && - param.isz[0] * param.isz[1] >= 8 && - param.osz[0] * param.osz[1] >= 8 && FH <= 7 && - SH == 1 && SW == 1; - if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - aviliable &= (large_group == m_large_group); - } - return aviliable; + return fm.format == param::ConvBias::Format::NCHW && + param.src_type.enumv() == DTypeEnum::Float16 && + param.filter_type.enumv() == DTypeEnum::Float16 && + param.dst_type.enumv() == DTypeEnum::Float16 && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && param.isz[0] * param.isz[1] >= 8 && + param.osz[0] * param.osz[1] >= 8 && FH <= 7 && SH == 1 && + SW == 1; } MIDOUT_END(); return false; @@ -225,9 +218,10 @@ bool ConvBiasImpl::AlgoF16Direct::usable( size_t ConvBiasImpl::AlgoF16Direct::get_workspace( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 1) { + bool large_group = param.filter_meta.group >= param.nr_threads; auto wbundle = MultithreadDirectConvCommon::get_bundle( - param, m_large_group); + param, large_group); return wbundle.total_size_in_bytes(); } MIDOUT_END(); @@ -241,13 +235,14 @@ SmallVector ConvBiasImpl::AlgoF16Direct::get_kimpls( size_t IC = param.filter_meta.icpg; size_t OC = param.filter_meta.ocpg; size_t group = fm.group; + bool large_group = group >= param.nr_threads; WorkspaceBundle bundle = MultithreadDirectConvCommon::get_bundle( - param, m_large_group); + param, large_group); SmallVector ret_kerns; //! When group >= nr_threads, treat it as large_group, each thread process //! one group for better performance - if (m_large_group) { + if (large_group) { //! 
Channel wise conv and big groups auto exec_one_group = [bundle](const NCBKernParam& kern_param, const NCBKernIndex& ncb_index) mutable { @@ -316,27 +311,18 @@ SmallVector ConvBiasImpl::AlgoF16Direct::dispatch_kerns( /* ===================== stride-1 algo ===================== */ -bool ConvBiasImpl::AlgoF16DirectStride1::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { +bool ConvBiasImpl::AlgoF16DirectStride1::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 0) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; - bool aviliable = - param.filter_meta.format == param::ConvBias::Format::NCHW && - param.src_type.enumv() == DTypeEnum::Float16 && - param.filter_type.enumv() == DTypeEnum::Float16 && - param.dst_type.enumv() == DTypeEnum::Float16 && - !fm.should_flip && fm.spatial_ndim == 2 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && - (FH == 2 || FH == 3 || FH == 5); - if (algo_selection_strategy == - ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - aviliable &= (large_group == m_large_group); - } - return aviliable; + return param.filter_meta.format == param::ConvBias::Format::NCHW && + param.src_type.enumv() == DTypeEnum::Float16 && + param.filter_type.enumv() == DTypeEnum::Float16 && + param.dst_type.enumv() == DTypeEnum::Float16 && + !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5); } MIDOUT_END(); return false; @@ -351,6 +337,7 @@ ConvBiasImpl::AlgoF16DirectStride1::get_kimpls( size_t IC = param.filter_meta.icpg; size_t OC = param.filter_meta.ocpg; size_t group = fm.group; + bool large_group = group >= param.nr_threads; using Func = std::function; Func conv_kern_function = nullptr; @@ -371,11 +358,11 @@ ConvBiasImpl::AlgoF16DirectStride1::get_kimpls( WorkspaceBundle bundle = MultithreadDirectConvCommon::get_bundle_stride( - param, m_large_group); + param, large_group); SmallVector ret_kerns; //! When group >= nr_threads, treat it as large_group, each thread process //! one group for better performance - if (m_large_group) { + if (large_group) { //! 
Channel wise conv and big groups auto exec_one_group = [bundle, conv_kern_function]( const NCBKernParam& kern_param, @@ -423,8 +410,9 @@ ConvBiasImpl::AlgoF16DirectStride1::get_kimpls( size_t ConvBiasImpl::AlgoF16DirectStride1::get_workspace( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 1) { + bool large_group = param.filter_meta.group >= param.nr_threads; auto bundle = MultithreadDirectConvCommon< - dt_float16, __fp16>::get_bundle_stride(param, m_large_group); + dt_float16, __fp16>::get_bundle_stride(param, large_group); return bundle.total_size_in_bytes(); } MIDOUT_END(); diff --git a/dnn/src/arm_common/conv_bias/f16/algos.h b/dnn/src/arm_common/conv_bias/f16/algos.h index a38c8607..0e9d5c32 100644 --- a/dnn/src/arm_common/conv_bias/f16/algos.h +++ b/dnn/src/arm_common/conv_bias/f16/algos.h @@ -79,15 +79,10 @@ public: class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; - bool m_large_group; public: - AlgoF16Direct(bool is_large_group) : m_large_group{is_large_group} {} bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? "F16DIRECT_LARGE_GROUP" - : "F16DIRECT_SMALL_GROUP"; - } + const char* name() const override { return "F16DIRECT"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; @@ -99,14 +94,10 @@ public: class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; - bool m_large_group; public: - AlgoF16DirectStride1(bool is_large_group) : m_large_group{is_large_group} {} bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? "F16STRD1_LARGE_GROUP" : "F16STRD1_SMALL_GROUP"; - } + const char* name() const override { return "F16STRD1"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.cpp b/dnn/src/arm_common/conv_bias/fp32/algos.cpp index e25d40ae..20f11b81 100644 --- a/dnn/src/arm_common/conv_bias/fp32/algos.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/algos.cpp @@ -334,9 +334,8 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF63_4x4_NCHW44, /* ===================== direct algo ===================== */ MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl); -bool ConvBiasImpl::AlgoF32Direct::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { +bool ConvBiasImpl::AlgoF32Direct::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 0) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; @@ -345,20 +344,14 @@ bool ConvBiasImpl::AlgoF32Direct::usable( // ``param.osz[0]*param.osz[1] >= 4'' comes from the fact that the // kernel may have access to up to 4 floats after the end of the memory // chunk. 
- bool aviliable = fm.format == param::ConvBias::Format::NCHW && - param.src_type.enumv() == DTypeEnum::Float32 && - param.filter_type.enumv() == DTypeEnum::Float32 && - param.dst_type.enumv() == DTypeEnum::Float32 && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && - param.isz[0] * param.isz[1] >= 4 && - param.osz[0] * param.osz[1] >= 4 && FH <= 7 && - SH == 1 && SW == 1; - if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - aviliable &= (large_group == m_large_group); - } - return aviliable; + return fm.format == param::ConvBias::Format::NCHW && + param.src_type.enumv() == DTypeEnum::Float32 && + param.filter_type.enumv() == DTypeEnum::Float32 && + param.dst_type.enumv() == DTypeEnum::Float32 && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && param.isz[0] * param.isz[1] >= 4 && + param.osz[0] * param.osz[1] >= 4 && FH <= 7 && SH == 1 && + SW == 1; } MIDOUT_END(); return false; @@ -366,8 +359,9 @@ bool ConvBiasImpl::AlgoF32Direct::usable( size_t ConvBiasImpl::AlgoF32Direct::get_workspace( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) { + bool large_group = param.filter_meta.group >= param.nr_threads; auto wbundle = MultithreadDirectConvCommon::get_bundle( - param, m_large_group); + param, large_group); return wbundle.total_size_in_bytes(); } MIDOUT_END(); @@ -380,13 +374,14 @@ SmallVector ConvBiasImpl::AlgoF32Direct::get_kimpls( size_t IC = param.filter_meta.icpg; size_t OC = param.filter_meta.ocpg; size_t group = fm.group; + bool large_group = group >= param.nr_threads; WorkspaceBundle bundle = MultithreadDirectConvCommon::get_bundle( - param, m_large_group); + param, large_group); SmallVector ret_kerns; //! When group >= nr_threads, treat it as large_group, each thread process //! one group for better performance - if (m_large_group) { + if (large_group) { //! 
Channel wise conv and big groups auto exec_one_group = [bundle](const NCBKernParam& kern_param, const NCBKernIndex& ncb_index) mutable { @@ -452,27 +447,19 @@ SmallVector ConvBiasImpl::AlgoF32Direct::dispatch_kerns( return {}; } /* ===================== stride-1 algo ===================== */ -bool ConvBiasImpl::AlgoF32DirectStride1::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { +bool ConvBiasImpl::AlgoF32DirectStride1::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; - bool aviliable = - param.filter_meta.format == param::ConvBias::Format::NCHW && - param.src_type.enumv() == DTypeEnum::Float32 && - param.filter_type.enumv() == DTypeEnum::Float32 && - param.dst_type.enumv() == DTypeEnum::Float32 && - !fm.should_flip && fm.spatial_ndim == 2 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && - (FH == 2 || FH == 3 || FH == 5 || FH == 7); - if (algo_selection_strategy == - ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - aviliable &= (large_group == m_large_group); - } - return aviliable; + return param.filter_meta.format == param::ConvBias::Format::NCHW && + param.src_type.enumv() == DTypeEnum::Float32 && + param.filter_type.enumv() == DTypeEnum::Float32 && + param.dst_type.enumv() == DTypeEnum::Float32 && + !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && + FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7); } MIDOUT_END(); return false; @@ -481,9 +468,10 @@ bool ConvBiasImpl::AlgoF32DirectStride1::usable( size_t ConvBiasImpl::AlgoF32DirectStride1::get_workspace( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { + bool large_group = param.filter_meta.group >= param.nr_threads; auto bundle = MultithreadDirectConvCommon::get_bundle_stride( - param, m_large_group); + param, large_group); return bundle.total_size_in_bytes(); } MIDOUT_END(); @@ -499,6 +487,7 @@ ConvBiasImpl::AlgoF32DirectStride1::get_kimpls( size_t IC = param.filter_meta.icpg; size_t OC = param.filter_meta.ocpg; size_t group = fm.group; + bool large_group = group >= param.nr_threads; using Func = std::function; Func conv_kern_function = nullptr; @@ -522,11 +511,11 @@ ConvBiasImpl::AlgoF32DirectStride1::get_kimpls( WorkspaceBundle bundle = MultithreadDirectConvCommon::get_bundle_stride( - param, m_large_group); + param, large_group); SmallVector ret_kerns; //! When group >= nr_threads, treat it as large_group, each thread process //! one group for better performance - if (m_large_group) { + if (large_group) { //! 
Channel wise conv and big groups auto exec_one_group = [bundle, conv_kern_function]( const NCBKernParam& kern_param, @@ -580,27 +569,19 @@ ConvBiasImpl::AlgoF32DirectStride1::dispatch_kerns( /* ===================== stride-2 algo ===================== */ -bool ConvBiasImpl::AlgoF32DirectStride2::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { +bool ConvBiasImpl::AlgoF32DirectStride2::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 0) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; - bool aviliable = - param.filter_meta.format == param::ConvBias::Format::NCHW && - param.src_type.enumv() == DTypeEnum::Float32 && - param.filter_type.enumv() == DTypeEnum::Float32 && - param.dst_type.enumv() == DTypeEnum::Float32 && - !fm.should_flip && fm.spatial_ndim == 2 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && - (FH == 2 || FH == 3 || FH == 5 || FH == 7); - if (algo_selection_strategy == - ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - aviliable &= (large_group == m_large_group); - } - return aviliable; + return param.filter_meta.format == param::ConvBias::Format::NCHW && + param.src_type.enumv() == DTypeEnum::Float32 && + param.filter_type.enumv() == DTypeEnum::Float32 && + param.dst_type.enumv() == DTypeEnum::Float32 && + !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && + FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7); } MIDOUT_END(); return false; @@ -608,9 +589,10 @@ bool ConvBiasImpl::AlgoF32DirectStride2::usable( size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 1) { + bool large_group = param.filter_meta.group >= param.nr_threads; auto bundle = MultithreadDirectConvCommon::get_bundle_stride( - param, m_large_group); + param, large_group); return bundle.total_size_in_bytes(); } MIDOUT_END(); @@ -625,6 +607,7 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( size_t IC = param.filter_meta.icpg; size_t OC = param.filter_meta.ocpg; size_t group = fm.group; + bool large_group = group >= param.nr_threads; using Func = std::function; Func conv_kern_function = nullptr; @@ -648,11 +631,11 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( WorkspaceBundle bundle = MultithreadDirectConvCommon::get_bundle_stride( - param, m_large_group); + param, large_group); SmallVector ret_kerns; //! When group >= nr_threads, treat it as large_group, each thread process //! one group for better performance - if (m_large_group) { + if (large_group) { //! 
Channel wise conv and big groups auto exec_one_group = [bundle, conv_kern_function]( const NCBKernParam& kern_param, diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.h b/dnn/src/arm_common/conv_bias/fp32/algos.h index af290b4c..c04d009d 100644 --- a/dnn/src/arm_common/conv_bias/fp32/algos.h +++ b/dnn/src/arm_common/conv_bias/fp32/algos.h @@ -128,15 +128,10 @@ public: class ConvBiasImpl::AlgoF32Direct final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; - bool m_large_group; public: - AlgoF32Direct(bool is_large_group) : m_large_group{is_large_group} {} bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? "F32DIRECT_LARGE_GROUP" - : "F32DIRECT_SMALL_GROUP"; - } + const char* name() const override { return "F32DIRECT"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; @@ -147,14 +142,10 @@ public: class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; - bool m_large_group; public: - AlgoF32DirectStride1(bool is_large_group) : m_large_group{is_large_group} {} bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? "F32STRD1_LARGE_GROUP" : "F32STRD1_SMALL_GROUP"; - } + const char* name() const override { return "F32STRD1"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; @@ -165,14 +156,10 @@ public: class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; - bool m_large_group; public: - AlgoF32DirectStride2(bool is_large_group) : m_large_group{is_large_group} {} bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? 
"F32STRD2_LARGE_GROUP" : "F32STRD2_SMALL_GROUP"; - } + const char* name() const override { return "F32STRD2"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; diff --git a/dnn/src/arm_common/conv_bias/int8/algos.cpp b/dnn/src/arm_common/conv_bias/int8/algos.cpp index 23613615..a54dfc03 100644 --- a/dnn/src/arm_common/conv_bias/int8/algos.cpp +++ b/dnn/src/arm_common/conv_bias/int8/algos.cpp @@ -27,17 +27,10 @@ using namespace arm_common; MIDOUT_DECL(megdnn_arm_common_conv_bias_int8) /* ===================== stride1 algo ===================== */ -bool ConvBiasImpl::AlgoS8DirectStride1::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { - bool avaible = direct_int8_stride1::can_conv_direct_stride1_int8(param); - auto fm = param.filter_meta; - if (algo_selection_strategy == - ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { - bool large_group = fm.group >= param.nr_threads; - avaible &= (large_group == m_large_group); - } - return avaible; + +bool ConvBiasImpl::AlgoS8DirectStride1::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { + return direct_int8_stride1::can_conv_direct_stride1_int8(param); } bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred( const NCBKernSizeParam& param) const { @@ -53,8 +46,9 @@ bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred( } size_t ConvBiasImpl::AlgoS8DirectStride1::get_workspace( - const NCBKernSizeParam& param) const { - auto bundle = direct_int8_stride1::get_bundle(param, m_large_group); + const NCBKernSizeParam& param) const { + bool large_group = param.filter_meta.group >= param.nr_threads; + auto bundle = direct_int8_stride1::get_bundle(param, large_group); return bundle.total_size_in_bytes(); } @@ -62,7 +56,8 @@ SmallVector ConvBiasImpl::AlgoS8DirectStride1::dispatch_kerns( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 0) { - return direct_int8_stride1::get_kimpls(param, m_large_group); + bool large_group = param.filter_meta.group >= param.nr_threads; + return direct_int8_stride1::get_kimpls(param, large_group); } MIDOUT_END(); return {}; @@ -117,21 +112,15 @@ ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::dispatch_kerns( } /* ===================== stride2 algo ===================== */ -bool ConvBiasImpl::AlgoS8DirectStride2::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { - bool avaible = direct_int8_stride2::can_conv_direct_stride2_int8(param); - if (algo_selection_strategy == - ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - avaible &= (large_group == m_large_group); - } - return avaible; +bool ConvBiasImpl::AlgoS8DirectStride2::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { + return direct_int8_stride2::can_conv_direct_stride2_int8(param); } size_t ConvBiasImpl::AlgoS8DirectStride2::get_workspace( const NCBKernSizeParam& param) const { - auto bundle = direct_int8_stride2::get_bundle(param, m_large_group); + bool large_group = param.filter_meta.group >= param.nr_threads; + auto bundle = direct_int8_stride2::get_bundle(param, large_group); return bundle.total_size_in_bytes(); } @@ -139,7 +128,8 @@ SmallVector ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 1) { - return direct_int8_stride2::get_kimpls(param, m_large_group); + bool large_group = 
param.filter_meta.group >= param.nr_threads; + return direct_int8_stride2::get_kimpls(param, large_group); } MIDOUT_END(); return {}; @@ -147,24 +137,15 @@ ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns( #if __ARM_FEATURE_DOTPROD /* ===================== dot stride1 algo ======================== */ -bool ConvBiasImpl::AlgoDotS8DirectStride1::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { - bool avaible = - direct_dotprod_int8_stride1::can_conv_direct_stride1_int8(param); - - if (algo_selection_strategy == - ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - avaible &= (large_group == m_large_group); - } - - return avaible; +bool ConvBiasImpl::AlgoDotS8DirectStride1::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { + return direct_dotprod_int8_stride1::can_conv_direct_stride1_int8(param); } size_t ConvBiasImpl::AlgoDotS8DirectStride1::get_workspace( const NCBKernSizeParam& param) const { - auto bundle = direct_dotprod_int8_stride1::get_bundle(param, m_large_group); + bool large_group = param.filter_meta.group >= param.nr_threads; + auto bundle = direct_dotprod_int8_stride1::get_bundle(param, large_group); return bundle.total_size_in_bytes(); } @@ -172,29 +153,23 @@ SmallVector ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 1) { - return direct_dotprod_int8_stride1::get_kimpls(param, m_large_group); + bool large_group = param.filter_meta.group >= param.nr_threads; + return direct_dotprod_int8_stride1::get_kimpls(param, large_group); } MIDOUT_END(); return {}; } /* ===================== dot stride2 algo ======================== */ -bool ConvBiasImpl::AlgoDotS8DirectStride2::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { - bool avaible = - direct_dotprod_int8_stride2::can_conv_direct_stride2_int8(param); - if (algo_selection_strategy == - ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - avaible &= (large_group == m_large_group); - } - return avaible; +bool ConvBiasImpl::AlgoDotS8DirectStride2::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { + return direct_dotprod_int8_stride2::can_conv_direct_stride2_int8(param); } size_t ConvBiasImpl::AlgoDotS8DirectStride2::get_workspace( - const NCBKernSizeParam& param) const { - auto bundle = direct_dotprod_int8_stride2::get_bundle(param, m_large_group); + const NCBKernSizeParam& param) const { + bool large_group = param.filter_meta.group >= param.nr_threads; + auto bundle = direct_dotprod_int8_stride2::get_bundle(param, large_group); return bundle.total_size_in_bytes(); } @@ -202,7 +177,8 @@ SmallVector ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 2) { - return direct_dotprod_int8_stride2::get_kimpls(param, m_large_group); + bool large_group = param.filter_meta.group >= param.nr_threads; + return direct_dotprod_int8_stride2::get_kimpls(param, large_group); } MIDOUT_END(); return {}; diff --git a/dnn/src/arm_common/conv_bias/int8/algos.h b/dnn/src/arm_common/conv_bias/int8/algos.h index 9bbdb194..196584f7 100644 --- a/dnn/src/arm_common/conv_bias/int8/algos.h +++ b/dnn/src/arm_common/conv_bias/int8/algos.h @@ -18,14 +18,10 @@ namespace megdnn { namespace arm_common { class 
ConvBiasImpl::AlgoS8DirectStride1 final : public AlgoBase { - bool m_large_group; public: - AlgoS8DirectStride1(bool large_group) : m_large_group(large_group) {} bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? "S8STRD1_LARGE_GROUP" : "S8STRD1_SMALL_GROUP"; - } + const char* name() const override { return "S8STRD1"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; @@ -36,14 +32,10 @@ public: }; class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { - bool m_large_group; public: - AlgoS8DirectStride2(bool large_group) : m_large_group(large_group) {} bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? "S8STRD2_LARGE_GROUP" : "S8STRD2_SMALL_GROUP"; - } + const char* name() const override { return "S8STRD2"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; @@ -115,16 +107,10 @@ public: }; class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { - bool m_large_group; public: - AlgoDotS8DirectStride1(bool large_group) : m_large_group(large_group) {} - bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? "ARMDOTS8STRD1_LARGE_GROUP" - : "ARMDOTS8STRD1_SMALL_GROUP"; - } + const char* name() const override { return "ARMDOTS8STRD1"; } bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy) const override; @@ -134,15 +120,10 @@ public: }; class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { - bool m_large_group; public: - AlgoDotS8DirectStride2(bool large_group) : m_large_group(large_group) {} bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? 
"ARMDOTS8STRD2_LARGE_GROUP" - : "ARMDOTS8STRD2_SMALL_GROUP"; - } + const char* name() const override { return "ARMDOTS8STRD2"; } bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy) const override; diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp index f9bf53b3..d82c18ff 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp @@ -82,28 +82,20 @@ void get_rectified_size_str2(size_t IH, size_t IW, size_t OH, size_t OW, } // namespace /* ===================== direct algo ===================== */ -bool ConvBiasImpl::AlgoI8x8x16Direct::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { +bool ConvBiasImpl::AlgoI8x8x16Direct::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 0) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; - bool aviliable = - param.bias_mode == BiasMode::NO_BIAS && - param.nonlineMode == NonlineMode::IDENTITY && - fm.format == param::ConvBias::Format::NCHW && !fm.should_flip && - param.src_type.enumv() == DTypeEnum::Int8 && - param.filter_type.enumv() == DTypeEnum::Int8 && - param.dst_type.enumv() == DTypeEnum::Int16 && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && - FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5); - if (algo_selection_strategy == - ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - aviliable &= (large_group == m_large_group); - } - return aviliable; + return param.bias_mode == BiasMode::NO_BIAS && + param.nonlineMode == NonlineMode::IDENTITY && + fm.format == param::ConvBias::Format::NCHW && !fm.should_flip && + param.src_type.enumv() == DTypeEnum::Int8 && + param.filter_type.enumv() == DTypeEnum::Int8 && + param.dst_type.enumv() == DTypeEnum::Int16 && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5); } MIDOUT_END(); return false; @@ -117,11 +109,12 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Direct::get_bundle( auto OH = param.osz[0], OW = param.osz[1]; auto PH = fm.padding[0], PW = fm.padding[1]; size_t OH2, OW2, IH2, IW2; + bool large_group = group >= param.nr_threads; get_rectified_size_str1(IH, IW, OH, OW, PH, PW, IH2, IW2, OH2, OW2); size_t part0 = 0u, part1 = 0u; if (need_src_copy_str1(param)) { - part0 = m_large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads - : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; + part0 = large_group ? 
IC * IH2 * IW2 * sizeof(int8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; } if (need_dst_copy_str1(param)) { part1 = OH2 * OW2 * sizeof(int16_t) * nr_threads + 16; @@ -255,9 +248,10 @@ SmallVector ConvBiasImpl::AlgoI8x8x16Direct::get_kimpls( size_t IC = param.filter_meta.icpg; size_t OC = param.filter_meta.ocpg; size_t group = fm.group; + bool large_group = group >= param.nr_threads; WorkspaceBundle bundle = get_bundle(param); SmallVector ret_kerns; - if (m_large_group) { + if (large_group) { auto exec_one_group = [bundle](const NCBKernParam& kern_param, const NCBKernIndex& ncb_index) mutable { auto fm = kern_param.filter_meta; @@ -302,28 +296,20 @@ ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns( } /* ===================== stride-2 algo ===================== */ -bool ConvBiasImpl::AlgoI8x8x16Stride2::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { +bool ConvBiasImpl::AlgoI8x8x16Stride2::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 0) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; - bool aviliable = param.bias_mode == BiasMode::NO_BIAS && - param.nonlineMode == NonlineMode::IDENTITY && - fm.format == param::ConvBias::Format::NCHW && - !fm.should_flip && - param.src_type.enumv() == DTypeEnum::Int8 && - param.filter_type.enumv() == DTypeEnum::Int8 && - param.dst_type.enumv() == DTypeEnum::Int16 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.stride[0] == 2 && fm.stride[1] == 2 && - FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5); - if (algo_selection_strategy == - ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - aviliable &= (large_group == m_large_group); - } - return aviliable; + return param.bias_mode == BiasMode::NO_BIAS && + param.nonlineMode == NonlineMode::IDENTITY && + fm.format == param::ConvBias::Format::NCHW && !fm.should_flip && + param.src_type.enumv() == DTypeEnum::Int8 && + param.filter_type.enumv() == DTypeEnum::Int8 && + param.dst_type.enumv() == DTypeEnum::Int16 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5); } MIDOUT_END(); return false; @@ -340,9 +326,10 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Stride2::get_bundle( size_t OH2, OW2, IH2, IW2; get_rectified_size_str2(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2); size_t part0 = 0u, part1 = 0u; + bool large_group = group >= param.nr_threads; if (need_src_copy_str2(param)) { - part0 = m_large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads - : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; + part0 = large_group ? 
IC * IH2 * IW2 * sizeof(int8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; } if (need_dst_copy_str2(param)) { part1 = OH2 * OW2 * sizeof(int16_t) * nr_threads + 16; @@ -475,9 +462,10 @@ SmallVector ConvBiasImpl::AlgoI8x8x16Stride2::get_kimpls( size_t IC = param.filter_meta.icpg; size_t OC = param.filter_meta.ocpg; size_t group = fm.group; + bool large_group = group >= param.nr_threads; WorkspaceBundle bundle = get_bundle(param); SmallVector ret_kerns; - if (m_large_group) { + if (large_group) { auto exec_one_group = [bundle](const NCBKernParam& kern_param, const NCBKernIndex& ncb_index) mutable { auto fm = kern_param.filter_meta; diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/algos.h b/dnn/src/arm_common/conv_bias/int8x8x16/algos.h index acabe888..198fa891 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/algos.h +++ b/dnn/src/arm_common/conv_bias/int8x8x16/algos.h @@ -26,15 +26,10 @@ class ConvBiasImpl::AlgoI8x8x16Direct final : public AlgoBase { const NCBKernParam& kern_param, const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); - bool m_large_group; public: - AlgoI8x8x16Direct(bool large_group) : m_large_group(large_group) {} bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? "I8816DIRECT_LARGE_GROUP" - : "I8816DIRECT_SMALL_GROUP"; - } + const char* name() const override { return "I8816DIRECT"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; @@ -53,15 +48,9 @@ class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { const NCBKernParam& kern_param, const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); - bool m_large_group; - public: - AlgoI8x8x16Stride2(bool large_group) : m_large_group(large_group) {} bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? 
"I8816STRD2_LARGE_GROUP" - : "I8816STRD2_SMALL_GROUP"; - } + const char* name() const override { return "I8816STRD2"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index cae7c4ae..33a00b66 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -40,28 +40,20 @@ uint8_t arm_common_algo_type_storage; } // anonymous namespace class ConvBiasImpl::AlgoPack : NonCopyableObj { - AlgoQU8DirectStride2 qu8_direct_stride2_large_group{true}; - AlgoQU8DirectStride2 qu8_direct_stride2_small_group{false}; - AlgoQU8DirectStride1 qu8_direct_stride1_large_group{true}; - AlgoQU8DirectStride1 qu8_direct_stride1_small_group{false}; - AlgoS8DirectStride2 s8_direct_stride2_large_group{true}; - AlgoS8DirectStride2 s8_direct_stride2_small_group{false}; + AlgoQU8DirectStride2 qu8_direct_stride2; + AlgoQU8DirectStride1 qu8_direct_stride1; + AlgoS8DirectStride2 s8_direct_stride2; AlgoS8DirectNCHW44 s8_direct_nchw44; AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44; - AlgoS8DirectStride1 s8_direct_stride1_large_group{true}; - AlgoS8DirectStride1 s8_direct_stride1_small_group{false}; + AlgoS8DirectStride1 s8_direct_stride1; AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; #if __ARM_FEATURE_DOTPROD - AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true}; - AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false}; - AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true}; - AlgoDotS8DirectStride2 ds8_direct_stride2_small_group{false}; - AlgoDotU8DirectStride1 du8_direct_stride1_large_group{true}; - AlgoDotU8DirectStride1 du8_direct_stride1_small_group{false}; - AlgoDotU8DirectStride2 du8_direct_stride2_large_group{true}; - AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false}; + AlgoDotS8DirectStride1 ds8_direct_stride1; + AlgoDotS8DirectStride2 ds8_direct_stride2; + AlgoDotU8DirectStride1 du8_direct_stride1; + AlgoDotU8DirectStride2 du8_direct_stride2; AlgoDotS8Direct_NCHW44 ds8_direct_nchw44; AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44; @@ -71,23 +63,16 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44; AlgoF32DirectNCHW44 f32_direct_nchw44; - AlgoF32Direct f32_direct_large_group{true}; - AlgoF32Direct f32_direct_small_group{false}; - AlgoF32DirectStride2 f32_direct_stride2_large_group{true}; - AlgoF32DirectStride2 f32_direct_stride2_small_group{false}; - AlgoF32DirectStride1 f32_direct_stride1_large_group{true}; - AlgoF32DirectStride1 f32_direct_stride1_small_group{false}; + AlgoF32Direct f32_direct; + AlgoF32DirectStride2 f32_direct_stride2; + AlgoF32DirectStride1 f32_direct_stride1; - AlgoI8x8x16Direct i8x8x16_direct_large_group{true}; - AlgoI8x8x16Direct i8x8x16_direct_small_group{false}; - AlgoI8x8x16Stride2 i8x8x16_stride2_large_group{true}; - AlgoI8x8x16Stride2 i8x8x16_stride2_small_group{false}; + AlgoI8x8x16Direct i8x8x16_direct; + AlgoI8x8x16Stride2 i8x8x16_stride2; AlgoI8x8x16Stride2Filter2 i8x8x16_stride2_filter2; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - AlgoF16Direct f16_direct_large_group{true}; - AlgoF16Direct f16_direct_small_group{false}; - AlgoF16DirectStride1 f16_direct_stride1_large_group{true}; - AlgoF16DirectStride1 f16_direct_stride1_small_group{false}; + AlgoF16Direct f16_direct; + AlgoF16DirectStride1 f16_direct_stride1; #endif SmallVector> 
refhold; @@ -95,54 +80,39 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { public: AlgoPack() { #if __ARM_FEATURE_DOTPROD - direct_algos.emplace_back(&ds8_direct_stride1_large_group); - direct_algos.emplace_back(&ds8_direct_stride1_small_group); - direct_algos.emplace_back(&ds8_direct_stride2_large_group); - direct_algos.emplace_back(&ds8_direct_stride2_small_group); - direct_algos.emplace_back(&du8_direct_stride1_large_group); - direct_algos.emplace_back(&du8_direct_stride1_small_group); - direct_algos.emplace_back(&du8_direct_stride2_large_group); - direct_algos.emplace_back(&du8_direct_stride2_small_group); + direct_algos.emplace_back(&ds8_direct_stride1); + direct_algos.emplace_back(&ds8_direct_stride2); + direct_algos.emplace_back(&du8_direct_stride1); + direct_algos.emplace_back(&du8_direct_stride2); direct_algos.emplace_back(&ds8_direct_nchw44); direct_algos.emplace_back(&ds8_direct_nchw_nchw44); #endif - direct_algos.emplace_back(&qu8_direct_stride2_large_group); - direct_algos.emplace_back(&qu8_direct_stride2_small_group); - direct_algos.emplace_back(&qu8_direct_stride1_large_group); - direct_algos.emplace_back(&qu8_direct_stride1_small_group); - direct_algos.emplace_back(&s8_direct_stride2_large_group); - direct_algos.emplace_back(&s8_direct_stride2_small_group); + direct_algos.emplace_back(&qu8_direct_stride2); + direct_algos.emplace_back(&qu8_direct_stride1); + direct_algos.emplace_back(&s8_direct_stride2); direct_algos.emplace_back(&s8_direct_nchw44); direct_algos.emplace_back(&s8_direct_nchw_nchw44); - direct_algos.emplace_back(&s8_direct_stride1_large_group); - direct_algos.emplace_back(&s8_direct_stride1_small_group); + direct_algos.emplace_back(&s8_direct_stride1); direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - direct_algos.emplace_back(&f16_direct_stride1_large_group); - direct_algos.emplace_back(&f16_direct_stride1_small_group); - direct_algos.emplace_back(&f16_direct_large_group); - direct_algos.emplace_back(&f16_direct_small_group); + direct_algos.emplace_back(&f16_direct_stride1); + direct_algos.emplace_back(&f16_direct); #endif - direct_algos.emplace_back(&i8x8x16_direct_large_group); - direct_algos.emplace_back(&i8x8x16_direct_small_group); + direct_algos.emplace_back(&i8x8x16_direct); direct_algos.emplace_back(&i8x8x16_stride2_filter2); - direct_algos.emplace_back(&i8x8x16_stride2_large_group); - direct_algos.emplace_back(&i8x8x16_stride2_small_group); + direct_algos.emplace_back(&i8x8x16_stride2); direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); direct_algos.emplace_back(&f32_chanel_wise_nchw44); direct_algos.emplace_back(&f32_direct_nchw44); - direct_algos.emplace_back(&f32_direct_stride1_large_group); - direct_algos.emplace_back(&f32_direct_stride1_small_group); - direct_algos.emplace_back(&f32_direct_stride2_large_group); - direct_algos.emplace_back(&f32_direct_stride2_small_group); - direct_algos.emplace_back(&f32_direct_large_group); - direct_algos.emplace_back(&f32_direct_small_group); + direct_algos.emplace_back(&f32_direct_stride1); + direct_algos.emplace_back(&f32_direct_stride2); + direct_algos.emplace_back(&f32_direct); static CpuOprDelegationStorage<2> storage; auto matmul_opr = storage.get(); diff --git a/dnn/src/arm_common/conv_bias/quint8/algos.cpp b/dnn/src/arm_common/conv_bias/quint8/algos.cpp index ae48792f..f6e42f84 100644 --- a/dnn/src/arm_common/conv_bias/quint8/algos.cpp +++ 
b/dnn/src/arm_common/conv_bias/quint8/algos.cpp @@ -25,21 +25,15 @@ using namespace megdnn; using namespace arm_common; /* ===================== stride1 algo ===================== */ -bool ConvBiasImpl::AlgoQU8DirectStride1::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { - bool avaible = direct_quint8_stride1::can_conv_direct_stride1_quint8(param); - if (algo_selection_strategy == - ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - avaible &= (large_group == m_large_group); - } - return avaible; +bool ConvBiasImpl::AlgoQU8DirectStride1::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { + return direct_quint8_stride1::can_conv_direct_stride1_quint8(param); } size_t ConvBiasImpl::AlgoQU8DirectStride1::get_workspace( const NCBKernSizeParam& param) const { - auto bundle = direct_quint8_stride1::get_bundle(param, m_large_group); + bool large_group = param.filter_meta.group >= param.nr_threads; + auto bundle = direct_quint8_stride1::get_bundle(param, large_group); return bundle.total_size_in_bytes(); } @@ -47,7 +41,8 @@ SmallVector ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 0) { - return direct_quint8_stride1::get_kimpls(param, m_large_group); + bool large_group = param.filter_meta.group >= param.nr_threads; + return direct_quint8_stride1::get_kimpls(param, large_group); } MIDOUT_END(); return {}; @@ -55,20 +50,15 @@ ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns( /* ===================== stride2 algo ===================== */ bool ConvBiasImpl::AlgoQU8DirectStride2::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { - bool avaible = direct_quint8_stride2::can_conv_direct_stride2_quint8(param); - if (algo_selection_strategy == - ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - avaible &= (large_group == m_large_group); - } - return avaible; + const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { + return direct_quint8_stride2::can_conv_direct_stride2_quint8(param); } size_t ConvBiasImpl::AlgoQU8DirectStride2::get_workspace( const NCBKernSizeParam& param) const { - auto bundle = direct_quint8_stride2::get_bundle(param, m_large_group); + bool large_group = param.filter_meta.group >= param.nr_threads; + auto bundle = direct_quint8_stride2::get_bundle(param, large_group); return bundle.total_size_in_bytes(); } @@ -76,31 +66,23 @@ SmallVector ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 1) { - return direct_quint8_stride2::get_kimpls(param, m_large_group); + bool large_group = param.filter_meta.group >= param.nr_threads; + return direct_quint8_stride2::get_kimpls(param, large_group); } MIDOUT_END(); return {}; } #if __ARM_FEATURE_DOTPROD /* ===================== stride1 algo ===================== */ -bool ConvBiasImpl::AlgoDotU8DirectStride1::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { - bool avaible = - direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8( - param); - if (algo_selection_strategy == - ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - avaible &= (large_group == m_large_group); - } - return avaible; 
+bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { + return direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8(param); } size_t ConvBiasImpl::AlgoDotU8DirectStride1::get_workspace( const NCBKernSizeParam& param) const { - auto bundle = - direct_dotprod_quint8_stride1::get_bundle(param, m_large_group); + bool large_group = param.filter_meta.group >= param.nr_threads; + auto bundle = direct_dotprod_quint8_stride1::get_bundle(param, large_group); return bundle.total_size_in_bytes(); } @@ -108,31 +90,23 @@ SmallVector ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 0) { - return direct_dotprod_quint8_stride1::get_kimpls(param, m_large_group); + bool large_group = param.filter_meta.group >= param.nr_threads; + return direct_dotprod_quint8_stride1::get_kimpls(param, large_group); } MIDOUT_END(); return {}; } /* ===================== stride2 algo ===================== */ -bool ConvBiasImpl::AlgoDotU8DirectStride2::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { - bool avaible = - direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8( - param); - if (algo_selection_strategy == - ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - avaible &= (large_group == m_large_group); - } - return avaible; +bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { + return direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8(param); } size_t ConvBiasImpl::AlgoDotU8DirectStride2::get_workspace( const NCBKernSizeParam& param) const { - auto bundle = - direct_dotprod_quint8_stride2::get_bundle(param, m_large_group); + bool large_group = param.filter_meta.group >= param.nr_threads; + auto bundle = direct_dotprod_quint8_stride2::get_bundle(param, large_group); return bundle.total_size_in_bytes(); } @@ -140,7 +114,8 @@ SmallVector ConvBiasImpl::AlgoDotU8DirectStride2::dispatch_kerns( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 1) { - return direct_dotprod_quint8_stride2::get_kimpls(param, m_large_group); + bool large_group = param.filter_meta.group >= param.nr_threads; + return direct_dotprod_quint8_stride2::get_kimpls(param, large_group); } MIDOUT_END(); return {}; diff --git a/dnn/src/arm_common/conv_bias/quint8/algos.h b/dnn/src/arm_common/conv_bias/quint8/algos.h index 2ff7dcdf..2f0de9a2 100644 --- a/dnn/src/arm_common/conv_bias/quint8/algos.h +++ b/dnn/src/arm_common/conv_bias/quint8/algos.h @@ -18,14 +18,10 @@ namespace megdnn { namespace arm_common { class ConvBiasImpl::AlgoQU8DirectStride1 final : public AlgoBase { - bool m_large_group; public: - AlgoQU8DirectStride1(bool large_group) : m_large_group(large_group) {} bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? 
"QU8STRD1_LARGE_GROUP" : "QU8STRD1_SMALL_GROUP"; - } + const char* name() const override { return "QU8STRD1"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; @@ -36,14 +32,10 @@ public: }; class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { - bool m_large_group; public: - AlgoQU8DirectStride2(bool large_group) : m_large_group(large_group) {} bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? "QU8STRD2_LARGE_GROUP" : "QU8STRD2_SMALL_GROUP"; - } + const char* name() const override { return "QU8STRD2"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; @@ -53,15 +45,10 @@ public: }; #if __ARM_FEATURE_DOTPROD class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { - bool m_large_group; public: - AlgoDotU8DirectStride1(bool large_group) : m_large_group(large_group) {} bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? "ARMDOTU8STRD1_LARGE_GROUP" - : "ARMDOTU8STRD1_SMALL_GROUP"; - } + const char* name() const override { return "ARMDOTU8STRD1"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; @@ -72,15 +59,10 @@ public: }; class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { - bool m_large_group; public: - AlgoDotU8DirectStride2(bool large_group) : m_large_group(large_group) {} bool is_reproducible() const override { return true; } - const char* name() const override { - return m_large_group ? "ARMDOTU8STRD2_LARGE_GROUP" - : "ARMDOTU8STRD2_SMALL_GROUP"; - } + const char* name() const override { return "ARMDOTU8STRD2"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; diff --git a/dnn/src/x86/conv_bias/f32/algos.cpp b/dnn/src/x86/conv_bias/f32/algos.cpp index 73248523..04171df1 100644 --- a/dnn/src/x86/conv_bias/f32/algos.cpp +++ b/dnn/src/x86/conv_bias/f32/algos.cpp @@ -65,9 +65,10 @@ void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, size_t IC = param.filter_meta.icpg; \ size_t OC = param.filter_meta.ocpg; \ size_t group = fm.group; \ + bool large_group = group >= param.nr_threads; \ WorkspaceBundle bundle = get_bundle(param); \ SmallVector ret_kerns; \ - if (m_large_group) { \ + if (large_group) { \ auto exec_one_group = [bundle]( \ const NCBKernParam& kern_param, \ const NCBKernIndex& ncb_index) mutable { \ @@ -104,22 +105,15 @@ void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, /* ===================== direct algo ===================== */ -bool ConvBiasImpl::AlgoDirect::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { +bool ConvBiasImpl::AlgoDirect::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { auto&& fm = param.filter_meta; - bool aviliable = fm.format == Param::Format::NCHW && fm.spatial_ndim == 2 && - param.src_type.enumv() == DTypeEnum::Float32 && - param.filter_type.enumv() == DTypeEnum::Float32 && - param.dst_type.enumv() == DTypeEnum::Float32 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.spatial[0] <= 7 && fm.stride[0] == 1 && - fm.stride[1] == 1; - if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - aviliable &= (large_group == m_large_group); - } - return 
aviliable; + return fm.format == Param::Format::NCHW && fm.spatial_ndim == 2 && + param.src_type.enumv() == DTypeEnum::Float32 && + param.filter_type.enumv() == DTypeEnum::Float32 && + param.dst_type.enumv() == DTypeEnum::Float32 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && fm.spatial[0] <= 7 && + fm.stride[0] == 1 && fm.stride[1] == 1; } WorkspaceBundle ConvBiasImpl::AlgoDirect::get_bundle( const NCBKernSizeParam& param) const { @@ -133,9 +127,10 @@ WorkspaceBundle ConvBiasImpl::AlgoDirect::get_bundle( get_rectified_img_size(IH, IW, FH, FW, OH, OW, fm.padding[0], fm.padding[1], IH2, IW2, OH2, OW2); size_t part0 = 0u, part1 = 0u; + bool large_group = group >= param.nr_threads; if (IH != IH2 || IW != IW2) { - part0 = m_large_group ? IC * IH2 * IW2 * sizeof(float) * nr_threads - : IC * IH2 * IW2 * sizeof(float) * group * batch; + part0 = large_group ? IC * IH2 * IW2 * sizeof(float) * nr_threads + : IC * IH2 * IW2 * sizeof(float) * group * batch; } if (OH != OH2 || OW != OW2) { part1 = OH2 * OW2 * sizeof(float) * nr_threads; @@ -319,24 +314,17 @@ SmallVector ConvBiasImpl::AlgoDirect::get_kimpls( GET_KERN; } /* ===================== direct-stride2 algo ===================== */ -bool ConvBiasImpl::AlgoDirectStride2::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { +bool ConvBiasImpl::AlgoDirectStride2::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; - bool aviliable = - param.filter_meta.format == param::ConvBias::Format::NCHW && - param.src_type.enumv() == DTypeEnum::Float32 && - param.filter_type.enumv() == DTypeEnum::Float32 && - param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && - FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); - if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { - bool large_group = param.filter_meta.group >= param.nr_threads; - aviliable &= (large_group == m_large_group); - } - return aviliable; + return param.filter_meta.format == param::ConvBias::Format::NCHW && + param.src_type.enumv() == DTypeEnum::Float32 && + param.filter_type.enumv() == DTypeEnum::Float32 && + param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7); } WorkspaceBundle ConvBiasImpl::AlgoDirectStride2::get_bundle( @@ -352,10 +340,10 @@ WorkspaceBundle ConvBiasImpl::AlgoDirectStride2::get_bundle( size_t src_size = 0, dst_size = 0; size_t IH2, IW2, OH2, OW2; get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2); + bool large_group = group >= param.nr_threads; \ if (need_src_copy(param)) { - src_size = m_large_group - ? IC * IH2 * IW2 * sizeof(float) * nr_threads - : IC * IH2 * IW2 * sizeof(float) * group * batch; + src_size = large_group ? 
IC * IH2 * IW2 * sizeof(float) * nr_threads + : IC * IH2 * IW2 * sizeof(float) * group * batch; } if (need_dst_copy(param)) { // we only need one dst plane diff --git a/dnn/src/x86/conv_bias/f32/algos.h b/dnn/src/x86/conv_bias/f32/algos.h index 0f9111e1..a231c565 100644 --- a/dnn/src/x86/conv_bias/f32/algos.h +++ b/dnn/src/x86/conv_bias/f32/algos.h @@ -29,14 +29,10 @@ class ConvBiasImpl::AlgoDirect final : public AlgoBase { const NCBKernParam& kern_param, const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); - bool m_large_group; - public: - AlgoDirect(bool large_group) : m_large_group(large_group) {} bool is_reproducible() const override { return true; } const char* name() const override { - return m_large_group ? "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP" - : "X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP"; + return "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; @@ -65,14 +61,10 @@ class ConvBiasImpl::AlgoDirectStride2 final : public AlgoBase { const NCBKernParam& kern_param, const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); - bool m_large_group; - public: - AlgoDirectStride2(bool large_group) : m_large_group(large_group) {} bool is_reproducible() const override { return true; } const char* name() const override { - return m_large_group ? "X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP" - : "X86_CONV_BIAS_DIRECT_STRIDE2_SMALL_GROUP"; + return "X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP"; } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; diff --git a/dnn/src/x86/conv_bias/opr_impl.cpp b/dnn/src/x86/conv_bias/opr_impl.cpp index 1f5adeeb..0de20e30 100644 --- a/dnn/src/x86/conv_bias/opr_impl.cpp +++ b/dnn/src/x86/conv_bias/opr_impl.cpp @@ -76,10 +76,8 @@ void* ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::type() const { } class ConvBiasImpl::AlgoPack : NonCopyableObj { - AlgoDirect stride1_direct_large_group{true}; - AlgoDirect stride1_direct_small_group{false}; - AlgoDirectStride2 stride2_direct_large_group{true}; - AlgoDirectStride2 stride2_direct_small_group{false}; + AlgoDirect stride1_direct; + AlgoDirectStride2 stride2_direct; AlgoDirectAvx2Stride1Int8 avx2_stride1_direct_int8; AlgoAVX2DirectConvStride2 avx2_stride2_direct; AlgoChanWiseAvx2Stride1Qint8 avx2_stride1_chanwsie_qint8; @@ -103,10 +101,8 @@ public: all_algos.emplace_back(&mkldnn_matmul_qint8); all_algos.emplace_back(&mkldnn_qint8); #endif - all_algos.emplace_back(&stride1_direct_large_group); - all_algos.emplace_back(&stride1_direct_small_group); - all_algos.emplace_back(&stride2_direct_large_group); - all_algos.emplace_back(&stride2_direct_small_group); + all_algos.emplace_back(&stride1_direct); + all_algos.emplace_back(&stride2_direct); all_algos.emplace_back(&avx2_stride1_direct_int8); all_algos.emplace_back(&avx2_stride2_direct); all_algos.emplace_back(&avx2_stride1_chanwsie_qint8); diff --git a/dnn/test/aarch64/conv_bias.cpp b/dnn/test/aarch64/conv_bias.cpp index 02eeed25..8d993d2e 100644 --- a/dnn/test/aarch64/conv_bias.cpp +++ b/dnn/test/aarch64/conv_bias.cpp @@ -81,15 +81,10 @@ void checker_conv_bias(std::vector args, Handle* handle, {arg.src, arg.filter, arg.bias, {}, {}}); } } -TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_LARGE_GROUP) { +TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) { check_conv_bias( conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), - handle(), "ARMV8F32STRD2_LARGE_GROUP"); -} 
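The recomputed flag still selects between the two workspace regimes shown in the x86 get_bundle hunks above (the arm_common bundles follow the same rule). A sketch of that sizing, with src_scratch_bytes a hypothetical helper introduced only for illustration; IC, IH2 and IW2 are the per-group input channels and the rectified (padded) input extent from the surrounding code:

    #include <cstddef>

    // Bytes of scratch needed for the copied, padded source plane(s).
    size_t src_scratch_bytes(bool large_group, size_t IC, size_t IH2,
                             size_t IW2, size_t nr_threads, size_t group,
                             size_t batch) {
        size_t plane = IC * IH2 * IW2 * sizeof(float);
        // large_group: threads are spread across groups, so one reusable
        //              plane per worker thread is enough.
        // small group: threads work inside a group, so every (batch, group)
        //              slice needs its own plane to stay independent.
        return large_group ? plane * nr_threads : plane * group * batch;
    }

Recomputing the decision here, instead of baking it into two algorithm instances, keeps the workspace bound correct for whichever regime the dispatcher ends up using.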
-TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) { - check_conv_bias( - conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), - handle(), "ARMV8F32STRD2_SMALL_GROUP"); + handle(), "ARMV8F32STRD2"); } #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -114,17 +109,11 @@ void checker_conv_bias_fp16(std::vector args, } } -TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR2_LARGE_GROUP) { - NormalRNG rng(1); - checker_conv_bias_f16( - conv_bias::get_conv_bias_args({2, 3, 5}, 2, false, false, false), - handle(), rng, "ARMV8F16STRD2_LARGE_GROUP", 0.04); -} -TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR2_SMALL_GROUP) { +TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR2) { NormalRNG rng(1); checker_conv_bias_f16( conv_bias::get_conv_bias_args({2, 3, 5}, 2, false, false, false), - handle(), rng, "ARMV8F16STRD2_SMALL_GROUP", 0.04); + handle(), rng, "ARMV8F16STRD2", 0.04); } #endif diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index 2e52a54f..1b3094e5 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -1310,8 +1310,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE1_NCHW44) { benchmark0.set_param(param); benchmark0.set_times(RUN); benchmark0.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker( - "F32STRD1_LARGE_GROUP")); + conv_bias::ConvBiasAlgoChecker("F32STRD1")); auto opr = handle()->create_operator(); opr->param() = param; @@ -1385,8 +1384,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE2_NCHW44) { benchmark0.set_param(param); benchmark0.set_times(RUN); benchmark0.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker( - "F32STRD2_LARGE_GROUP")); + conv_bias::ConvBiasAlgoChecker("F32STRD2")); auto opr = handle()->create_operator(); opr->param() = param; @@ -1464,8 +1462,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) { benchmark0.set_param(param); benchmark0.set_times(RUN); benchmark0.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker( - "S8STRD1_LARGE_GROUP")); + conv_bias::ConvBiasAlgoChecker("S8STRD1")); auto opr = handle()->create_operator(); opr->param() = param; diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 279f9e4c..5300f1b4 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -356,15 +356,10 @@ void checker_conv_bias_int8x8x32_multi(std::vector args, } /**********************************F32 direct************************/ -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_LARGE_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32) { check_conv_bias( get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), - handle(), "F32DIRECT_LARGE_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) { - check_conv_bias( - get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), - handle(), "F32DIRECT_SMALL_GROUP"); + handle(), "F32DIRECT"); } TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) { @@ -391,21 +386,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) { handle(), "F32_CONV_NCHW44_DIRECT"); } -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_LARGE_GROUP) { - check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), - handle(), "F32STRD1_LARGE_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_SMALL_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, 
CONVBIAS_DIRECT_FP32_STR1) { check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), - handle(), "F32STRD1_SMALL_GROUP"); + handle(), "F32STRD1"); } -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_LARGE_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) { check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), - handle(), "F32STRD2_LARGE_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) { - check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), - handle(), "F32STRD2_SMALL_GROUP"); + handle(), "F32STRD2"); } TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S2) { check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, @@ -437,72 +424,41 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) { /**********************************F16 direct************************/ #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) { NormalRNG rng(1); checker_conv_bias_f16( get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), - handle(), rng, "F16DIRECT_LARGE_GROUP", 0.03); + handle(), rng, "F16DIRECT", 0.03); } -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_SMALL_GROUP) { - NormalRNG rng(1); - checker_conv_bias_f16( - get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), - handle(), rng, "F16DIRECT_SMALL_GROUP", 0.03); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_LARGE_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1) { NormalRNG rng(1); checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false), - handle(), rng, "F16STRD1_LARGE_GROUP", 0.03); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_SMALL_GROUP) { - NormalRNG rng(1); - checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false), - handle(), rng, "F16STRD1_SMALL_GROUP", 0.03); + handle(), rng, "F16STRD1", 0.03); } #endif /**********************************algo 8816 direct************************/ -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_LARGE_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT) { checker_conv_bias_int8x8x16( get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(), - "I8816DIRECT_LARGE_GROUP"); + "I8816DIRECT"); } -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_SMALL_GROUP) { - checker_conv_bias_int8x8x16( - get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(), - "I8816DIRECT_SMALL_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_LARGE_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2) { checker_conv_bias_int8x8x16( get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(), - "I8816STRD2_LARGE_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_SMALL_GROUP) { - checker_conv_bias_int8x8x16( - get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(), - "I8816STRD2_SMALL_GROUP"); + "I8816STRD2"); } /**********************************algo 8-8-32 direct************************/ -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_LARGE_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1) { checker_conv_bias_int8x8x32_multi( get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), - "S8STRD1_LARGE_GROUP"); + "S8STRD1"); } 
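Because usable() no longer depends on the selection strategy, one forced algorithm name now exercises both grouping regimes, so every *_LARGE_GROUP/*_SMALL_GROUP test pair in this file folds into a single case. The consolidated shape, with an illustrative test name (the helpers are the ones defined earlier in this file):

    TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_EXAMPLE) {
        // One case per algorithm: the group split is handled inside the
        // algorithm, not encoded in its name.
        check_conv_bias(
                get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
                handle(), "F32STRD1");
    }

Which regime a given argument hits is decided at run time: with group = 32 on a 4-thread handle, group >= nr_threads holds and the large-group path runs, while the dense group = 1 cases take the small-group path through the same name.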
-TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_SMALL_GROUP) { - checker_conv_bias_int8x8x32_multi( - get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), - "S8STRD1_SMALL_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_LARGE_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2) { checker_conv_bias_int8x8x32_multi( get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), - "S8STRD2_LARGE_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_SMALL_GROUP) { - checker_conv_bias_int8x8x32_multi( - get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), - "S8STRD2_SMALL_GROUP"); + "S8STRD2"); } TEST_F(ARM_COMMON_MULTI_THREADS, @@ -520,25 +476,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, } /********************************qint8 direct******************************/ -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_LARGE_GROUP) { - checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( - {2, 3, 5, 7}, 1, false, false, false), - handle(), "S8STRD1_LARGE_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_SMALL_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1) { checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( {2, 3, 5, 7}, 1, false, false, false), - handle(), "S8STRD1_SMALL_GROUP"); + handle(), "S8STRD1"); } -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_LARGE_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2) { checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( {2, 3, 5, 7}, 2, false, false, false), - handle(), "S8STRD2_LARGE_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) { - checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( - {2, 3, 5, 7}, 2, false, false, false), - handle(), "S8STRD2_SMALL_GROUP"); + handle(), "S8STRD2"); } TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) { checker_conv_bias_qint8x8x8( @@ -586,25 +532,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) { } /*****************************quint8 direct****************************/ -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_LARGE_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) { checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( {2, 3, 5, 7}, 1, false, false, false), - handle(), "QU8STRD1_LARGE_GROUP"); + handle(), "QU8STRD1"); } -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_SMALL_GROUP) { - checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( - {2, 3, 5, 7}, 1, false, false, false), - handle(), "QU8STRD1_SMALL_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_LARGE_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) { checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( {2, 3, 5, 7}, 2, false, false, false), - handle(), "QU8STRD2_LARGE_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) { - checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( - {2, 3, 5, 7}, 2, false, false, false), - handle(), "QU8STRD2_SMALL_GROUP"); + handle(), "QU8STRD2"); } /****************************dot qint8 direct*************************/ @@ -624,100 +560,53 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { } checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); } -TEST_F(ARM_COMMON_MULTI_THREADS, - CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, 
CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) { checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( {2, 3, 5, 7}, 1, false, false, false), - handle(), "ARMDOTS8STRD1_LARGE_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, - CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) { - checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( - {2, 3, 5, 7}, 1, false, false, false), - handle(), "ARMDOTS8STRD1_SMALL_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, - CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) { - checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( - {2, 3, 5, 7}, 2, false, false, false), - handle(), "ARMDOTS8STRD2_LARGE_GROUP"); + handle(), "ARMDOTS8STRD1"); } -TEST_F(ARM_COMMON_MULTI_THREADS, - CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) { + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_WITHDOTPROD) { checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( {2, 3, 5, 7}, 2, false, false, false), - handle(), "ARMDOTS8STRD2_SMALL_GROUP"); + handle(), "ARMDOTS8STRD2"); } - /****************************dot 8-8-32 direct*************************/ -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_LARGE_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT) { checker_conv_bias_qint8x8x32( get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), - "ARMDOTS8STRD1_LARGE_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_SMALL_GROUP) { - checker_conv_bias_qint8x8x32( - get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), - "ARMDOTS8STRD1_SMALL_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_LARGE_GROUP) { - checker_conv_bias_qint8x8x32( - get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), - "ARMDOTS8STRD2_LARGE_GROUP"); + "ARMDOTS8STRD1"); } -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_SMALL_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT) { checker_conv_bias_qint8x8x32( get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), - "ARMDOTS8STRD2_SMALL_GROUP"); + "ARMDOTS8STRD2"); } /******************************dot quint8*****************************/ -TEST_F(ARM_COMMON_MULTI_THREADS, - CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD) { checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( {2, 3, 5, 7}, 1, false, false, false), - handle(), "ARMDOTU8STRD1_LARGE_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, - CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) { - checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( - {2, 3, 5, 7}, 1, false, false, false), - handle(), "ARMDOTU8STRD1_SMALL_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, - CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) { - checker_conv_bias_quint8x8x8( - get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false), - handle(), "ARMDOTU8STRD2_LARGE_GROUP"); + handle(), "ARMDOTU8STRD1"); } -TEST_F(ARM_COMMON_MULTI_THREADS, - CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) { +//! TODO: this test without test kernel size=3, add it will case buss error now +//! 
in armv7 +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) { checker_conv_bias_quint8x8x8( get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false), - handle(), "ARMDOTU8STRD2_SMALL_GROUP"); + handle(), "ARMDOTU8STRD2"); } /******************************dot quint8x8x32***********************/ -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_LARGE_GROUP) { - checker_conv_bias_quint8x8x32( - get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), - "ARMDOTU8STRD1_LARGE_GROUP"); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_SMALL_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1) { checker_conv_bias_quint8x8x32( get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), - "ARMDOTU8STRD1_SMALL_GROUP"); + "ARMDOTU8STRD1"); } -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_LARGE_GROUP) { +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2) { checker_conv_bias_quint8x8x32( get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), - "ARMDOTU8STRD2_LARGE_GROUP"); + "ARMDOTU8STRD2"); } -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_SMALL_GROUP) { - checker_conv_bias_quint8x8x32( - get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), - "ARMDOTU8STRD2_SMALL_GROUP"); -} - /******************************dot int8x8x8 nchw44 ***********************/ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) { using namespace conv_bias; diff --git a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp index 0431615b..58b62d8a 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp @@ -125,7 +125,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) { bench_case(1, 32, 32, 80, 80, 3, 4); bench_case(1, 32, 32, 80, 80, 3, 32); - std::string algo_name = "F32DIRECT_LARGE_GROUP"; + std::string algo_name = "F32DIRECT"; printf("Benchmark F32DIRECT_LARGE_GROUP algo\n"); std::vector data_type = {dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; @@ -137,7 +137,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) { {1, {4}}, data_type); shapes_and_computation.clear(); - algo_name = "F32DIRECT_SMALL_GROUP"; + algo_name = "F32DIRECT"; printf("Benchmark F32DIRECT_SMALL_GROUP algo\n"); bench_case(1, 32, 32, 200, 200, 3, 1); bench_case(1, 32, 32, 128, 128, 3, 1); @@ -186,7 +186,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR1) { bench_case(1, 32, 32, 80, 80, 3, 4); bench_case(1, 32, 32, 80, 80, 3, 32); - std::string algo_name = "F32STRD1_LARGE_GROUP"; + std::string algo_name = "F32STRD1"; printf("Benchmark F32STRD1_LARGE_GROUP algo\n"); std::vector data_type = {dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; @@ -198,7 +198,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR1) { {1, {4}}, data_type); shapes_and_computation.clear(); - algo_name = "F32STRD1_SMALL_GROUP"; + algo_name = "F32STRD1"; printf("Benchmark F32STRD1_SMALL_GROUP algo\n"); bench_case(1, 32, 32, 200, 200, 3, 1); bench_case(1, 32, 32, 128, 128, 3, 1); @@ -249,7 +249,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR2) { bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); - std::string algo_name = 
"F32STRD2_LARGE_GROUP"; + std::string algo_name = "F32STRD2"; printf("Benchmark F32STRD2_LARGE_GROUP algo\n"); std::vector data_type = {dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; @@ -261,7 +261,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR2) { {1, {4}}, data_type); shapes_and_computation.clear(); - algo_name = "F32STRD2_SMALL_GROUP"; + algo_name = "F32STRD2"; printf("Benchmark F32STRD2_SMALL_GROUP algo\n"); bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); @@ -313,7 +313,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16) { bench_case(1, 32, 32, 80, 80, 3, 4); bench_case(1, 32, 32, 80, 80, 3, 32); - std::string algo_name = "F16DIRECT_LARGE_GROUP"; + std::string algo_name = "F16DIRECT"; printf("Benchmark F16DIRECT_LARGE_GROUP algo\n"); std::vector data_type = {dtype::Float16(), dtype::Float16(), dtype::Float16(), dtype::Float16()}; @@ -325,7 +325,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16) { {1, {4}}, data_type); shapes_and_computation.clear(); - algo_name = "F16DIRECT_SMALL_GROUP"; + algo_name = "F16DIRECT"; printf("Benchmark F16DIRECT_SMALL_GROUP algo\n"); bench_case(1, 32, 32, 200, 200, 3, 1); bench_case(1, 32, 32, 128, 128, 3, 1); @@ -375,7 +375,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16_STR1) { bench_case(1, 32, 32, 80, 80, 3, 4); bench_case(1, 32, 32, 80, 80, 3, 32); - std::string algo_name = "F16STRD1_LARGE_GROUP"; + std::string algo_name = "F16STRD1"; printf("Benchmark F16STRD1_LARGE_GROUP algo\n"); std::vector data_type = {dtype::Float16(), dtype::Float16(), dtype::Float16(), dtype::Float16()}; @@ -387,7 +387,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16_STR1) { {1, {4}}, data_type); shapes_and_computation.clear(); - algo_name = "F16STRD1_SMALL_GROUP"; + algo_name = "F16STRD1"; printf("Benchmark F16STRD1_SMALL_GROUP algo\n"); bench_case(1, 32, 32, 200, 200, 3, 1); bench_case(1, 32, 32, 128, 128, 3, 1); @@ -439,7 +439,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, bench_case(1, 32, 32, 80, 80, 3, 4); bench_case(1, 32, 32, 80, 80, 3, 32); - std::string algo_name = "I8816DIRECT_LARGE_GROUP"; + std::string algo_name = "I8816DIRECT"; printf("Benchmark I8816DIRECT_LARGE_GROUP algo\n"); std::vector data_type = {dtype::Int8(), dtype::Int8(), dtype::Int16(), dtype::Int16()}; @@ -451,7 +451,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, {1, {4}}, data_type); shapes_and_computation.clear(); - algo_name = "I8816DIRECT_SMALL_GROUP"; + algo_name = "I8816DIRECT"; printf("Benchmark I8816DIRECT_SMALL_GROUP algo\n"); bench_case(1, 32, 32, 200, 200, 3, 1); bench_case(1, 32, 32, 128, 128, 3, 1); @@ -503,7 +503,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); - std::string algo_name = "I8816STRD2_LARGE_GROUP"; + std::string algo_name = "I8816STRD2"; printf("Benchmark I8816STRD2_LARGE_GROUP algo\n"); std::vector data_type = {dtype::Int8(), dtype::Int8(), dtype::Int16(), dtype::Int16()}; @@ -515,7 +515,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, {1, {4}}, data_type); shapes_and_computation.clear(); - algo_name = "I8816STRD2_SMALL_GROUP"; + algo_name = "I8816STRD2"; printf("Benchmark I8816STRD2_SMALL_GROUP algo\n"); bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); @@ -567,7 +567,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, bench_case(1, 32, 32, 
80, 80, 3, 4, 1, 1); bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); - std::string algo_name = "S8STRD1_LARGE_GROUP"; + std::string algo_name = "S8STRD1"; printf("Benchmark S8STRD1_LARGE_GROUP algo\n"); std::vector data_type = { dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), @@ -580,7 +580,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, {1, {4}}, data_type); shapes_and_computation.clear(); - algo_name = "S8STRD1_SMALL_GROUP"; + algo_name = "S8STRD1"; printf("Benchmark S8STRD1_SMALL_GROUP algo\n"); bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); @@ -866,7 +866,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); - std::string algo_name = "S8STRD2_LARGE_GROUP"; + std::string algo_name = "S8STRD2"; printf("Benchmark S8STRD2_LARGE_GROUP algo\n"); std::vector data_type = { dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), @@ -879,7 +879,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, {1, {4}}, data_type); shapes_and_computation.clear(); - algo_name = "S8STRD2_SMALL_GROUP"; + algo_name = "S8STRD2"; printf("Benchmark S8STRD2_SMALL_GROUP algo\n"); bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); @@ -932,7 +932,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); - std::string algo_name = "ARMDOTS8STRD1_LARGE_GROUP"; + std::string algo_name = "ARMDOTS8STRD1"; printf("Benchmark ARMDOTS8STRD1_LARGE_GROUP algo\n"); std::vector data_type = { dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), @@ -945,7 +945,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, {1, {4}}, data_type); shapes_and_computation.clear(); - algo_name = "ARMDOTS8STRD1_SMALL_GROUP"; + algo_name = "ARMDOTS8STRD1"; printf("Benchmark ARMDOTS8STRD1_SMALL_GROUP algo\n"); bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); @@ -997,7 +997,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); - std::string algo_name = "ARMDOTS8STRD2_LARGE_GROUP"; + std::string algo_name = "ARMDOTS8STRD2"; printf("Benchmark ARMDOTS8STRD2_LARGE_GROUP algo\n"); std::vector data_type = { dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), @@ -1010,7 +1010,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, {1, {4}}, data_type); shapes_and_computation.clear(); - algo_name = "ARMDOTS8STRD2_SMALL_GROUP"; + algo_name = "ARMDOTS8STRD2"; printf("Benchmark ARMDOTS8STRD2_SMALL_GROUP algo\n"); bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); @@ -1064,7 +1064,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); - std::string algo_name = "QU8STRD1_LARGE_GROUP"; + std::string algo_name = "QU8STRD1"; printf("Benchmark QU8STRD1_LARGE_GROUP algo\n"); std::vector data_type = {dtype::Quantized8Asymm(0.2f, 100), dtype::Quantized8Asymm(0.2f, 120), @@ -1078,7 +1078,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, {1, {4}}, data_type); shapes_and_computation.clear(); - algo_name = "QU8STRD1_SMALL_GROUP"; + algo_name = "QU8STRD1"; printf("Benchmark QU8STRD1_SMALL_GROUP algo\n"); bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); @@ -1130,7 +1130,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); bench_case(1, 32, 32, 80, 80, 3, 
32, 1, 2); - std::string algo_name = "QU8STRD2_LARGE_GROUP"; + std::string algo_name = "QU8STRD2"; printf("Benchmark QU8STRD2_LARGE_GROUP algo\n"); std::vector data_type = {dtype::Quantized8Asymm(0.2f, 100), dtype::Quantized8Asymm(0.2f, 120), @@ -1144,7 +1144,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, {1, {4}}, data_type); shapes_and_computation.clear(); - algo_name = "QU8STRD2_SMALL_GROUP"; + algo_name = "QU8STRD2"; printf("Benchmark QU8STRD2_SMALL_GROUP algo\n"); bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); @@ -1198,7 +1198,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); - std::string algo_name = "ARMDOTU8STRD1_LARGE_GROUP"; + std::string algo_name = "ARMDOTU8STRD1"; printf("Benchmark ARMDOTU8STRD1_LARGE_GROUP algo\n"); std::vector data_type = {dtype::Quantized8Asymm(0.2f, 100), dtype::Quantized8Asymm(0.2f, 120), @@ -1212,7 +1212,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, {1, {4}}, data_type); shapes_and_computation.clear(); - algo_name = "ARMDOTU8STRD1_SMALL_GROUP"; + algo_name = "ARMDOTU8STRD1"; printf("Benchmark ARMDOTS8STRD1_SMALL_GROUP algo\n"); bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); @@ -1265,7 +1265,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, bench_case(1, 32, 32, 80, 80, 5, 4, 1, 2); bench_case(1, 32, 32, 80, 80, 5, 32, 1, 2); - std::string algo_name = "ARMDOTU8STRD2_LARGE_GROUP"; + std::string algo_name = "ARMDOTU8STRD2"; printf("Benchmark ARMDOTU8STRD2_LARGE_GROUP algo\n"); std::vector data_type = {dtype::Quantized8Asymm(0.2f, 100), dtype::Quantized8Asymm(0.2f, 120), @@ -1279,7 +1279,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, {1, {4}}, data_type); shapes_and_computation.clear(); - algo_name = "ARMDOTU8STRD2_SMALL_GROUP"; + algo_name = "ARMDOTU8STRD2"; printf("Benchmark ARMDOTU8STRD2_SMALL_GROUP algo\n"); bench_case(1, 32, 32, 200, 200, 5, 1, 1, 2); bench_case(1, 32, 32, 128, 128, 5, 1, 1, 2); diff --git a/dnn/test/arm_common/convolution.cpp b/dnn/test/arm_common/convolution.cpp index eda5ee9b..101b3b03 100644 --- a/dnn/test/arm_common/convolution.cpp +++ b/dnn/test/arm_common/convolution.cpp @@ -176,7 +176,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE1_I8x8x32_WITHDOTPROD) { constexpr size_t RUN = 50; Benchmarker benchmark(handle()); benchmark.set_before_exec_callback( - AlgoChecker("CONVOLUTION_DEFAULT_ARMDOTS8STRD1_SMALL_GROUP")); + AlgoChecker("CONVOLUTION_DEFAULT_ARMDOTS8STRD1")); benchmark.set_dtype(0, dtype::Int8()) .set_dtype(1, dtype::Int8()) .set_dtype(2, dtype::Int32()); @@ -243,7 +243,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE2_I8x8x32_WITHDOTPROD) { constexpr size_t RUN = 10; Benchmarker benchmark(handle()); benchmark.set_before_exec_callback( - AlgoChecker("CONVOLUTION_DEFAULT_ARMDOTS8STRD2_SMALL_GROUP")); + AlgoChecker("CONVOLUTION_DEFAULT_ARMDOTS8STRD2")); benchmark.set_dtype(0, dtype::Int8()) .set_dtype(1, dtype::Int8()) .set_dtype(2, dtype::Int32()); @@ -317,7 +317,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE1_QUINT8_WITHDOTPROD) { benchmark.set_display(false); benchmark.set_times(RUN); benchmark.set_before_exec_callback(AlgoChecker( - "CONVOLUTION_DEFAULT_ARMDOTU8STRD1_SMALL_GROUP")); + "CONVOLUTION_DEFAULT_ARMDOTU8STRD1")); Benchmarker benchmark_float(handle()); benchmark_float.set_display(false); @@ -387,7 +387,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE2_QUINT8_WITHDOTPROD) { benchmark.set_display(false); 
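The plain-convolution benchmarks in this file pick up the same renaming through the CONVOLUTION_DEFAULT_ prefix, which these tests use when ConvolutionForward forwards to a conv_bias algorithm. A sketch of the pinning pattern, assuming the Benchmarker/AlgoChecker test helpers used throughout these files:

    // Force the convolution benchmark onto the renamed direct algorithm;
    // the suffix-free name now covers both grouping regimes.
    Benchmarker<ConvolutionForward> benchmark(handle());
    benchmark.set_before_exec_callback(AlgoChecker<ConvolutionForward>(
            "CONVOLUTION_DEFAULT_ARMDOTS8STRD1"));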
benchmark.set_times(RUN); benchmark.set_before_exec_callback(AlgoChecker( - "CONVOLUTION_DEFAULT_ARMDOTU8STRD2_SMALL_GROUP")); + "CONVOLUTION_DEFAULT_ARMDOTU8STRD2")); Benchmarker benchmark_float(handle()); benchmark_float.set_display(false); diff --git a/dnn/test/x86/conv_bias.cpp b/dnn/test/x86/conv_bias.cpp index a1180453..e3ec4001 100644 --- a/dnn/test/x86/conv_bias.cpp +++ b/dnn/test/x86/conv_bias.cpp @@ -583,7 +583,7 @@ TEST_F(X86_MULTI_THREADS, AVX2_CONV_BIAS_DIRECT_STRIDE2_S8S8S8) { } } -TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP) { +TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_DENSE) { using namespace conv_bias; std::vector args; @@ -633,19 +633,19 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP) { .set_rng(2, &rng); checker.set_before_exec_callback( conv_bias::ConvBiasAlgoChecker( - "X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP")); + "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP")); for (auto&& arg : args) { checker.set_param(arg.param).exec( {arg.src, arg.filter, arg.bias, {}, {}}); } } -TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP) { +TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_GROUP) { using namespace conv_bias; std::vector args; - auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, - size_t p, NonlineMode nonline_mode) { + auto run = [&](size_t group, size_t channel, size_t w, size_t h, + size_t kernel, size_t p, NonlineMode nonline_mode) { if (w + 2 * p < kernel || h + 2 * p < kernel) return; param::ConvBias param; @@ -654,30 +654,37 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP) { param.pad_h = p; param.pad_w = p; param.nonlineMode = nonline_mode; + param.sparse = param::ConvBias::Sparse::GROUP; //! no bias - args.emplace_back(param, TensorShape{1, ic, h, w}, - TensorShape{oc, ic, kernel, kernel}, TensorShape{}); + args.emplace_back( + param, TensorShape{1, channel, h, w}, + TensorShape{group, channel / group, channel / group, kernel, kernel}, + TensorShape{}); //! bias channel - args.emplace_back(param, TensorShape{2, ic, h, w}, - TensorShape{oc, ic, kernel, kernel}, - TensorShape{1, oc, 1, 1}); + args.emplace_back(param, TensorShape{2, channel, h, w}, + TensorShape{group, channel / group, channel / group, + kernel, kernel}, + TensorShape{1, channel, 1, 1}); //! 
bias - args.emplace_back(param, TensorShape{2, ic, h, w}, - TensorShape{oc, ic, kernel, kernel}, - TensorShape{2, oc, (h + param.pad_h * 2 - kernel) + 1, - (w + param.pad_w * 2 - kernel) + 1}); + args.emplace_back( + param, TensorShape{2, channel, h, w}, + TensorShape{group, channel / group, channel / group, kernel, + kernel}, + TensorShape{2, channel, (h + param.pad_h * 2 - kernel) + 1, + (w + param.pad_w * 2 - kernel) + 1}); }; for (size_t kernel : {1, 2, 3, 4, 5, 6, 7}) - for (size_t ic : {1, 4, 8, 16}) - for (size_t oc : {1, 4, 8}) + for (size_t channel : {4, 8, 16}) + for (size_t group : {1, 2, 4}) for (size_t p : {0, 2}) for (size_t size : {20, 21, 24}) for (NonlineMode nonline_mode : {NonlineMode::RELU, NonlineMode::SIGMOID, NonlineMode::H_SWISH, NonlineMode::IDENTITY}) { - run(oc, ic, size, size, kernel, p, nonline_mode); + run(group, channel, size, size, kernel, p, + nonline_mode); } Checker checker(handle()); @@ -697,7 +704,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP) { } } -TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2) { +TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2_DENSE) { using namespace conv_bias; std::vector args; @@ -738,11 +745,68 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2) { .set_rng(2, &rng); checker.set_before_exec_callback( conv_bias::ConvBiasAlgoChecker( - "X86_CONV_BIAS_DIRECT_STRIDE2_SMALL_GROUP")); + "X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP")); for (auto&& arg : args) { checker.set_param(arg.param).exec( {arg.src, arg.filter, arg.bias, {}, {}}); } +} + +TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2_GROUP) { + using namespace conv_bias; + std::vector args; + + auto run = [&](size_t group, size_t channel, size_t w, size_t h, + size_t kernel, size_t p, NonlineMode nonline_mode) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::ConvBias param; + param.stride_h = 2; + param.stride_w = 2; + param.pad_h = p; + param.pad_w = p; + param.nonlineMode = nonline_mode; + param.sparse = param::ConvBias::Sparse::GROUP; + + //! no bias + args.emplace_back( + param, TensorShape{1, channel, h, w}, + TensorShape{group, channel / group, channel / group, kernel, kernel}, + TensorShape{}); + //! bias channel + args.emplace_back(param, TensorShape{2, channel, h, w}, + TensorShape{group, channel / group, channel / group, + kernel, kernel}, + TensorShape{1, channel, 1, 1}); + //! 
bias + args.emplace_back( + param, TensorShape{2, channel, h, w}, + TensorShape{group, channel / group, channel / group, kernel, + kernel}, + TensorShape{2, channel, (h + param.pad_h * 2 - kernel) / 2 + 1, + (w + param.pad_w * 2 - kernel) / 2 + 1}); + }; + + for (size_t kernel : {2, 3, 5, 7}) + for (size_t channel : {4, 8, 16}) + for (size_t group : {1, 2, 4}) + for (size_t p : {0, 2}) + for (size_t size : {20, 21, 24}) + for (NonlineMode nonline_mode : + {NonlineMode::RELU, NonlineMode::SIGMOID, + NonlineMode::H_SWISH, NonlineMode::IDENTITY}) { + run(group, channel, size, size, kernel, p, + nonline_mode); + } + + Checker checker(handle()); + UniformIntRNG rng{-50, 50}; + checker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()) + .set_rng(0, &rng) + .set_rng(1, &rng) + .set_rng(2, &rng); checker.set_before_exec_callback( conv_bias::ConvBiasAlgoChecker( "X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP")); @@ -2502,7 +2566,7 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) { bench_case(1, 32, 32, 80, 80, 3, 32); std::string algo_name = "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP"; - printf("Benchmark X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP algo\n"); + printf("Benchmark X86_CONV_BIAS_DIRECT_STRIDE1_GROUP algo\n"); benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, data_type); benchmark_impl(param, shapes_and_computation, algo_name, RUNS, @@ -2511,8 +2575,8 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) { {1, {4}}, data_type); shapes_and_computation.clear(); - algo_name = "X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP"; - printf("Benchmark X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP algo\n"); + algo_name = "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP"; + printf("Benchmark X86_CONV_BIAS_DIRECT_STRIDE1_DENSE algo\n"); bench_case(1, 32, 32, 200, 200, 3, 1); bench_case(1, 32, 32, 128, 128, 3, 1); bench_case(1, 32, 32, 100, 100, 3, 1); diff --git a/dnn/test/x86/convolution.cpp b/dnn/test/x86/convolution.cpp index 5de87bcd..97a22b09 100644 --- a/dnn/test/x86/convolution.cpp +++ b/dnn/test/x86/convolution.cpp @@ -125,7 +125,7 @@ TEST_F(X86, DEFAULT_CONV_DIRECT_STRIDE1) { Checker checker(handle()); checker.set_before_exec_callback(AlgoChecker( - "CONVOLUTION_DEFAULT_X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP")); + "CONVOLUTION_DEFAULT_X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP")); checker.set_epsilon(1); UniformIntRNG rng{-50, 50}; checker.set_dtype(0, dtype::Float32()) @@ -167,7 +167,7 @@ TEST_F(X86, DEFAULT_CONV_DIRECT_STRIDE2) { Checker checker(handle()); checker.set_before_exec_callback(AlgoChecker( - "CONVOLUTION_DEFAULT_X86_CONV_BIAS_DIRECT_STRIDE2_SMALL_GROUP")); + "CONVOLUTION_DEFAULT_X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP")); checker.set_epsilon(1); UniformIntRNG rng{-50, 50}; checker.set_dtype(0, dtype::Float32())