From f7b2bdae1ac94d971b07fc4ebc762d4e1e82ca60 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 7 Nov 2020 21:34:09 +0800 Subject: [PATCH] refactor(dnn): refactor algorithm type interface GitOrigin-RevId: 843d885f82a42456c8b0f1018290a3b5c04a3f00 --- dnn/include/megdnn/oprs/base.h | 5 +- dnn/src/aarch64/conv_bias/opr_impl.cpp | 2 +- dnn/src/aarch64/conv_bias/opr_impl.h | 9 +++- dnn/src/aarch64/matrix_mul/algos.h | 21 -------- dnn/src/aarch64/matrix_mul/opr_impl.cpp | 4 +- dnn/src/aarch64/matrix_mul/opr_impl.h | 10 +++- dnn/src/arm_common/conv_bias/opr_impl.cpp | 27 ++++++---- dnn/src/arm_common/conv_bias/opr_impl.h | 18 ++++--- dnn/src/arm_common/convolution/int8x8x32/algos.cpp | 16 +++--- dnn/src/arm_common/convolution/int8x8x32/algos.h | 33 ++++++------ dnn/src/arm_common/convolution/opr_impl.cpp | 45 ++++++++-------- dnn/src/arm_common/convolution/opr_impl.h | 21 ++++---- dnn/src/arm_common/convolution/quint8/algos.cpp | 16 +++--- dnn/src/arm_common/convolution/quint8/algos.h | 31 ++++++----- dnn/src/arm_common/matrix_mul/algos.h | 8 --- dnn/src/arm_common/matrix_mul/opr_impl.cpp | 11 +--- dnn/src/arm_common/matrix_mul/opr_impl.h | 11 ++-- dnn/src/armv7/conv_bias/opr_impl.cpp | 2 +- dnn/src/armv7/conv_bias/opr_impl.h | 9 +++- dnn/src/armv7/matrix_mul/algos.h | 14 ----- dnn/src/armv7/matrix_mul/opr_impl.cpp | 4 +- dnn/src/armv7/matrix_mul/opr_impl.h | 9 +++- dnn/src/cuda/batch_conv_bias/algo.h | 1 + dnn/src/cuda/batched_matrix_mul/algo.h | 1 + dnn/src/cuda/conv_bias/algo.h | 1 + dnn/src/cuda/convolution/backward_data/algo.h | 1 + dnn/src/cuda/convolution/backward_filter/algo.h | 1 + dnn/src/cuda/convolution3d/backward_data/algo.h | 1 + dnn/src/cuda/convolution3d/backward_filter/algo.h | 5 +- dnn/src/cuda/convolution3d/forward/algo.h | 5 +- dnn/src/cuda/deformable_conv/bwd_data/algo.h | 1 + dnn/src/cuda/deformable_conv/bwd_flt/algo.h | 1 + dnn/src/cuda/deformable_conv/fwd/algo.h | 1 + dnn/src/cuda/local_share/backward_data/algo.h | 1 + dnn/src/cuda/local_share/backward_filter/algo.h | 1 + dnn/src/cuda/local_share/forward/algo.h | 1 + dnn/src/cuda/matrix_mul/algos.h | 5 +- dnn/src/fallback/conv_bias/opr_impl.h | 3 ++ dnn/src/fallback/convolution/algos.h | 4 -- dnn/src/fallback/convolution/opr_impl.cpp | 12 +---- dnn/src/fallback/convolution/opr_impl.h | 10 ++-- dnn/src/fallback/matrix_mul/opr_impl.h | 1 + dnn/src/rocm/batched_matrix_mul/opr_impl.cpp | 9 ++-- dnn/src/rocm/convolution/backward_data/algo.h | 1 + dnn/src/rocm/convolution/backward_filter/algo.h | 1 + dnn/src/rocm/convolution/forward/algo.h | 1 + dnn/src/x86/conv_bias/f32/algos.h | 7 --- dnn/src/x86/conv_bias/int8/algos.h | 7 --- dnn/src/x86/conv_bias/opr_impl.cpp | 60 ++++------------------ dnn/src/x86/conv_bias/opr_impl.h | 9 +++- dnn/src/x86/matrix_mul/algos.h | 10 ---- dnn/src/x86/matrix_mul/opr_impl.cpp | 10 +--- dnn/src/x86/matrix_mul/opr_impl.h | 9 +++- 53 files changed, 230 insertions(+), 277 deletions(-) diff --git a/dnn/include/megdnn/oprs/base.h b/dnn/include/megdnn/oprs/base.h index 8f2197af..97233238 100644 --- a/dnn/include/megdnn/oprs/base.h +++ b/dnn/include/megdnn/oprs/base.h @@ -11,6 +11,7 @@ #pragma once #include "megdnn/basic_types.h" +#include "megdnn/handle.h" #include "megdnn/internal/visibility_prologue.h" namespace megdnn { @@ -105,11 +106,11 @@ public: virtual bool is_reproducible() const = 0; virtual const char* name() const = 0; - //! a pointer to represent class type - virtual void* type() const { return nullptr; } + Handle::HandleType handle_type() const { return m_handle_type; } protected: ~Algorithm() = default; + Handle::HandleType m_handle_type = Handle::HandleType::NAIVE; }; /*! diff --git a/dnn/src/aarch64/conv_bias/opr_impl.cpp b/dnn/src/aarch64/conv_bias/opr_impl.cpp index 65ec4809..0d997f73 100644 --- a/dnn/src/aarch64/conv_bias/opr_impl.cpp +++ b/dnn/src/aarch64/conv_bias/opr_impl.cpp @@ -45,7 +45,7 @@ public: SmallVector matmul_algos; }; -SmallVector ConvBiasImpl::algo_pack() { +SmallVector ConvBiasImpl::algo_pack() { static AlgoPack sl_algo_pack; auto&& algos = arm_common::ConvBiasImpl::algo_pack(); algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), diff --git a/dnn/src/aarch64/conv_bias/opr_impl.h b/dnn/src/aarch64/conv_bias/opr_impl.h index 16c11672..7666ab15 100644 --- a/dnn/src/aarch64/conv_bias/opr_impl.h +++ b/dnn/src/aarch64/conv_bias/opr_impl.h @@ -18,11 +18,16 @@ namespace aarch64 { class ConvBiasImpl : public arm_common::ConvBiasImpl { public: using arm_common::ConvBiasImpl::ConvBiasImpl; + class AlgoBase : public arm_common::ConvBiasImpl::AlgoBase { + public: + AlgoBase() : arm_common::ConvBiasImpl::AlgoBase() { + m_handle_type = Handle::HandleType::AARCH64; + } + }; - SmallVector algo_pack() override; + SmallVector algo_pack() override; protected: - const char* get_algorithm_set_name() const override; private: diff --git a/dnn/src/aarch64/matrix_mul/algos.h b/dnn/src/aarch64/matrix_mul/algos.h index b08eeb7d..a20247c4 100644 --- a/dnn/src/aarch64/matrix_mul/algos.h +++ b/dnn/src/aarch64/matrix_mul/algos.h @@ -26,7 +26,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -37,7 +36,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -48,7 +46,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -59,7 +56,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4) }; @@ -75,7 +71,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -86,7 +81,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8) }; @@ -103,7 +97,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -116,7 +109,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; #else @@ -129,7 +121,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::DEFAULT; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); @@ -143,7 +134,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -156,7 +146,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; #endif @@ -169,7 +158,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -182,7 +170,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -196,7 +183,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::DEFAULT; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); @@ -212,7 +198,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::DEFAULT; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); @@ -226,7 +211,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::DEFAULT; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); @@ -240,7 +224,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -251,7 +234,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) }; @@ -266,7 +248,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -278,7 +259,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override { return 0; } kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT) @@ -292,7 +272,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; #endif diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.cpp b/dnn/src/aarch64/matrix_mul/opr_impl.cpp index 2b7614eb..2910e582 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.cpp +++ b/dnn/src/aarch64/matrix_mul/opr_impl.cpp @@ -52,7 +52,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { #endif public: - SmallVector all_algos; + SmallVector all_algos; AlgoPack() { all_algos.emplace_back(&f32_gemv); @@ -89,7 +89,7 @@ public: } }; -SmallVector MatrixMulImpl::algo_pack() { +SmallVector MatrixMulImpl::algo_pack() { static AlgoPack s_algo_pack; auto&& algos = arm_common::MatrixMulImpl::algo_pack(); algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.h b/dnn/src/aarch64/matrix_mul/opr_impl.h index 906cbc85..31c8ef3b 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.h +++ b/dnn/src/aarch64/matrix_mul/opr_impl.h @@ -18,8 +18,14 @@ namespace aarch64 { class MatrixMulImpl : public arm_common::MatrixMulImpl { public: using arm_common::MatrixMulImpl::MatrixMulImpl; + class AlgoBase : public arm_common::MatrixMulImpl::AlgoBase { + public: + AlgoBase() : arm_common::MatrixMulImpl::AlgoBase() { + m_handle_type = Handle::HandleType::AARCH64; + } + }; - SmallVector algo_pack() override; + SmallVector algo_pack() override; private: class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 @@ -57,7 +63,7 @@ private: #else class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 #endif - class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16 + class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16 class AlgoPack; }; diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index 052d1430..2a50104d 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -11,6 +11,7 @@ */ #include "megdnn/opr_param_defs.h" +#include "megdnn/oprs/base.h" #include "src/arm_common/conv_bias/int8/algos.h" #include "src/arm_common/conv_bias/int8x8x16/algos.h" #include "src/arm_common/conv_bias/quint8/algos.h" @@ -18,6 +19,7 @@ #include "src/arm_common/conv_bias/opr_impl.h" #include "src/common/metahelper.h" #include "src/common/utils.h" +#include "src/fallback/conv_bias/opr_impl.h" #include "src/naive/handle.h" #include "src/arm_common/convolution/opr_impl.h" @@ -37,7 +39,12 @@ using namespace megdnn; using namespace arm_common; namespace { -uint8_t arm_common_algo_type_storage; + +bool is_fallback_or_naive(const detail::Algorithm* algo) { + return algo->handle_type() == Handle::HandleType::NAIVE || + algo->handle_type() == Handle::HandleType::FALLBACK; +} + } // anonymous namespace class ConvBiasImpl::AlgoPack : NonCopyableObj { @@ -50,7 +57,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoS8DirectStride1 s8_direct_stride1; AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; - AlgoS8x8x16ChanWiseStride1Stride2NCHW44 s8x8x16_channel_wise_stride1_stride2_nchw44; + AlgoS8x8x16ChanWiseStride1Stride2NCHW44 + s8x8x16_channel_wise_stride1_stride2_nchw44; #if __ARM_FEATURE_DOTPROD AlgoDotS8DirectStride1 ds8_direct_stride1; @@ -129,7 +137,7 @@ public: ->select_algo_type( {AlgoDataType::FLOAT32, MatmulFormat::MK4}); for (auto&& algo : matmul_algos) { - if (algo->type() == nullptr) + if (is_fallback_or_naive(algo)) continue; for (uint32_t tile_size : {16, 8, 24, 32}) { refhold.emplace_back(new AlgoFP32WinogradF23_4x4( @@ -166,7 +174,7 @@ public: ->select_algo_type({AlgoDataType::FLOAT32, MatmulFormat::DEFAULT}); for (auto&& algo : matmul_algos) { - if (algo->type() == nullptr) + if (is_fallback_or_naive(algo)) continue; for (uint32_t tile_size : {16, 8, 24, 32}) { refhold.emplace_back(new AlgoFP32WinogradF63( @@ -189,7 +197,7 @@ public: ->select_algo_type({AlgoDataType::FLOAT16, MatmulFormat::DEFAULT}); for (auto&& algo : matmul_algos) { - if (algo->type() == nullptr) + if (is_fallback_or_naive(algo)) continue; for (uint32_t tile_size : {16, 8, 24, 32}) { refhold.emplace_back(new AlgoFP16WinogradF23( @@ -210,7 +218,7 @@ public: ->select_algo_type({AlgoDataType::FLOAT16, MatmulFormat::MK8}); for (auto&& algo : matmul_algos) { - if (algo->type() == nullptr) + if (is_fallback_or_naive(algo)) continue; for (uint32_t tile_size : {16, 8, 24, 32}) { refhold.emplace_back(new AlgoFP16WinogradF23_8x8( @@ -224,7 +232,7 @@ public: ->select_algo_type({AlgoDataType::INT16X16X32, MatmulFormat::MK8}); for (auto&& algo : matmul_algos) { - if (algo->type() == nullptr) + if (is_fallback_or_naive(algo)) continue; for (uint32_t tile_size : {16, 8, 24, 32}) { refhold.emplace_back(new AlgoS8WinogradF23_8x8( @@ -242,7 +250,7 @@ public: SmallVector winograd_algos; }; -SmallVector ConvBiasImpl::algo_pack() { +SmallVector ConvBiasImpl::algo_pack() { static AlgoPack sl_algo_pack; auto&& algos = fallback::ConvBiasImpl::algo_pack(); algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), @@ -252,9 +260,6 @@ SmallVector ConvBiasImpl::algo_pack() { return std::move(algos); } -void* const ConvBiasImpl::sm_arm_common_algo_type = - &arm_common_algo_type_storage; - bool ConvBiasImpl::is_matmul_quantized_prefer( const ConvBiasImpl::NCBKernSizeParam& param) const { fallback::ConvBiasImpl::NCBKernSizeParam conv_ncb_param( diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index 80fe1c3e..61f622e9 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -19,23 +19,25 @@ namespace arm_common { class ConvBiasImpl : public fallback::ConvBiasImpl { public: using fallback::ConvBiasImpl::ConvBiasImpl; - using FallbackConvBiasImpl = fallback::ConvBiasImpl; - using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; - bool is_thread_safe() const override { return true; } + class AlgoBase : public fallback::ConvBiasImpl::AlgoBase { + public: + AlgoBase() : fallback::ConvBiasImpl::AlgoBase() { + m_handle_type = Handle::HandleType::ARM_COMMON; + } + }; - SmallVector algo_pack() override; + SmallVector algo_pack() override; bool is_matmul_quantized_prefer( - const ConvBiasImpl::NCBKernSizeParam& ncb_param) const override; + const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param) + const override; SmallVector suggest_algo_category_order( const NCBKernSizeParam& param) const override; class AlgoPack; protected: - static void* const sm_arm_common_algo_type; - const char* get_algorithm_set_name() const override; private: @@ -93,7 +95,7 @@ private: class AlgoF16Direct; class AlgoF16DirectStride1; #endif - }; +}; } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/convolution/int8x8x32/algos.cpp b/dnn/src/arm_common/convolution/int8x8x32/algos.cpp index 65e98640..68eeef83 100644 --- a/dnn/src/arm_common/convolution/int8x8x32/algos.cpp +++ b/dnn/src/arm_common/convolution/int8x8x32/algos.cpp @@ -26,12 +26,14 @@ using namespace arm_common; /* ===================== ConvolutionBackwardData ===================== */ /* ===================== direct stride 1 algo ===================== */ bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::usable( - ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + fallback::ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const { return deconv::can_stride1_int8x8x32_dot(param); } size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::get_workspace( - ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + fallback::ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl, midout_iv("AlgoSdot8DirectStride1::get_workspace"_hash)) { return deconv::get_workspace_in_bytes_stride1_int8x8x32_dot(param); @@ -42,7 +44,7 @@ size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::get_workspace( ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern( - ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { + fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl, midout_iv("AlgoSdot8DirectStride1::dispatch_kern"_hash)) { return deconv::stride1_int8x8x32_dot; @@ -53,12 +55,14 @@ ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern( /* ===================== direct stride 2 algo ===================== */ bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::usable( - ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + fallback::ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const { return deconv::can_stride2_int8x8x32_dot(param); } size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::get_workspace( - ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + fallback::ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl, midout_iv("AlgoSdot8DirectStride2::get_workspace"_hash)) { return deconv::get_workspace_in_bytes_stride2_int8x8x32_dot(param); @@ -69,7 +73,7 @@ size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::get_workspace( ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::dispatch_kern( - ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { + fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl, midout_iv("AlgoSdot8DirectStride2::dispatch_kern"_hash)) { return deconv::stride2_int8x8x32_dot; diff --git a/dnn/src/arm_common/convolution/int8x8x32/algos.h b/dnn/src/arm_common/convolution/int8x8x32/algos.h index 154cc8a5..e69794dc 100644 --- a/dnn/src/arm_common/convolution/int8x8x32/algos.h +++ b/dnn/src/arm_common/convolution/int8x8x32/algos.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once @@ -19,38 +20,40 @@ namespace arm_common { #if __ARM_FEATURE_DOTPROD /* ===================== ConvolutionBackwardData ===================== */ -class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final : public AlgoBase { +class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final + : public AlgoBase { public: bool is_reproducible() const override { return true; } - const char* name() const override { return "AARCH32_I8x8x32_DECONV_STRIDE1"; } + const char* name() const override { + return "AARCH32_I8x8x32_DECONV_STRIDE1"; + } - bool usable(ConvolutionBackwardDataImpl*, + bool usable(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const override; - size_t get_workspace(ConvolutionBackwardDataImpl*, + size_t get_workspace(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const override; - ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, + ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override; - - void* type() const override { return sm_arm_common_algo_type; } }; -class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final : public AlgoBase { +class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final + : public AlgoBase { public: bool is_reproducible() const override { return true; } - const char* name() const override { return "AARCH32_I8x8x32_DECONV_STRIDE2"; } + const char* name() const override { + return "AARCH32_I8x8x32_DECONV_STRIDE2"; + } - bool usable(ConvolutionBackwardDataImpl*, + bool usable(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const override; - size_t get_workspace(ConvolutionBackwardDataImpl*, + size_t get_workspace(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const override; - ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, + ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override; - - void* type() const override { return sm_arm_common_algo_type; } }; #endif diff --git a/dnn/src/arm_common/convolution/opr_impl.cpp b/dnn/src/arm_common/convolution/opr_impl.cpp index d9a36e94..9d7af4de 100644 --- a/dnn/src/arm_common/convolution/opr_impl.cpp +++ b/dnn/src/arm_common/convolution/opr_impl.cpp @@ -21,9 +21,6 @@ using namespace megdnn; using namespace arm_common; -namespace { -uint8_t arm_common_algo_type_storage; -} // anonymous namespace /* ===================== ConvolutionBackwardData ===================== */ struct ConvolutionBackwardDataImpl::AlgoPack { @@ -36,46 +33,44 @@ struct ConvolutionBackwardDataImpl::AlgoPack { }; ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack; -void* const ConvolutionBackwardDataImpl::sm_arm_common_algo_type = - &arm_common_algo_type_storage; - -ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern( +ConvolutionBackwardDataImpl::ncb_kern_t +ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern( Algorithm* algo, const NCBKernSizeParam& param) { - if (algo->type() == sm_arm_common_algo_type) { + if (algo->handle_type() == Handle::HandleType::ARM_COMMON) { return static_cast(algo)->dispatch_kern(this, param); } - return fallback::ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(algo, param); + return fallback::ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(algo, + param); } -size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(Algorithm* algo, - const NCBKernSizeParam& param) { - if (algo->type() == sm_arm_common_algo_type) { +size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace( + Algorithm* algo, const NCBKernSizeParam& param) { + if (algo->handle_type() == Handle::HandleType::ARM_COMMON) { return static_cast(algo)->get_workspace(this, param); } - return fallback::ConvolutionBackwardDataImpl::ncb_1g_get_workspace(algo, param); + return fallback::ConvolutionBackwardDataImpl::ncb_1g_get_workspace(algo, + param); } std::vector -ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(const NCBKernSizeParam& param) { - - auto ret = fallback::ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(param); +ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms( + const NCBKernSizeParam& param) { + auto ret = fallback::ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms( + param); #if __ARM_FEATURE_DOTPROD - if((param.filter_type.enumv() == DTypeEnum::QuantizedS8 || - param.filter_type.enumv() == DTypeEnum::Int8) && - (param.grad_type.enumv() == DTypeEnum::QuantizedS32 || - param.grad_type.enumv() == DTypeEnum::Int32)) { - + if ((param.filter_type.enumv() == DTypeEnum::QuantizedS8 || + param.filter_type.enumv() == DTypeEnum::Int8) && + (param.grad_type.enumv() == DTypeEnum::QuantizedS32 || + param.grad_type.enumv() == DTypeEnum::Int32)) { if (sm_algo_pack.i8x8x32_direct_stride1_sdot.usable(this, param)) { ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride1_sdot); } if (sm_algo_pack.i8x8x32_direct_stride2_sdot.usable(this, param)) { ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride2_sdot); } - } - else if(param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && - param.grad_type.enumv() == DTypeEnum::QuantizedS32) { - + } else if (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && + param.grad_type.enumv() == DTypeEnum::QuantizedS32) { if (sm_algo_pack.quint8_direct_stride1_udot.usable(this, param)) { ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride1_udot); } diff --git a/dnn/src/arm_common/convolution/opr_impl.h b/dnn/src/arm_common/convolution/opr_impl.h index 17e8ec0b..d8d124ad 100644 --- a/dnn/src/arm_common/convolution/opr_impl.h +++ b/dnn/src/arm_common/convolution/opr_impl.h @@ -18,24 +18,27 @@ namespace arm_common { class ConvBiasImpl; -class ConvolutionBackwardDataImpl : public fallback::ConvolutionBackwardDataImpl { +class ConvolutionBackwardDataImpl + : public fallback::ConvolutionBackwardDataImpl { public: using fallback::ConvolutionBackwardDataImpl::ConvolutionBackwardDataImpl; protected: - static void* const sm_arm_common_algo_type; - - class AlgoBase : public Algorithm { + class AlgoBase : public fallback::ConvolutionBackwardDataImpl::AlgoBase { protected: ~AlgoBase() = default; public: - virtual bool usable(ConvolutionBackwardDataImpl* opr, + AlgoBase() : fallback::ConvolutionBackwardDataImpl::AlgoBase() { + m_handle_type = Handle::HandleType::ARM_COMMON; + } + virtual bool usable(fallback::ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param) const = 0; - virtual size_t get_workspace(ConvolutionBackwardDataImpl* opr, + virtual size_t get_workspace(fallback::ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param) const = 0; virtual ncb_kern_t dispatch_kern( - ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param) const = 0; + fallback::ConvolutionBackwardDataImpl* opr, + const NCBKernSizeParam& param) const = 0; }; ncb_kern_t ncb_1g_dispatch_kern(Algorithm* algo, @@ -49,7 +52,7 @@ protected: const char* get_algorithm_set_name() const override; - private: +private: #if __ARM_FEATURE_DOTPROD class AlgoSdot8DirectStride1; class AlgoSdot8DirectStride2; @@ -62,4 +65,4 @@ protected: } // namespace arm_common } // namespace megdnn -// vim: syntax=cpp.doxygen + // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/convolution/quint8/algos.cpp b/dnn/src/arm_common/convolution/quint8/algos.cpp index cb816e8b..bc531610 100644 --- a/dnn/src/arm_common/convolution/quint8/algos.cpp +++ b/dnn/src/arm_common/convolution/quint8/algos.cpp @@ -27,12 +27,14 @@ using namespace arm_common; /* ===================== direct stride 1 algo ===================== */ bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::usable( - ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + fallback::ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const { return deconv::can_stride1_quint8_dot(param); } size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::get_workspace( - ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + fallback::ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl, midout_iv("AlgoUdot8DirectStride1::get_workspace"_hash)) { return deconv::get_workspace_in_bytes_stride1_quint8_dot(param); @@ -43,7 +45,7 @@ size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::get_workspace( ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern( - ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { + fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl, midout_iv("AlgoUdot8DirectStride1::dispatch_kern"_hash)) { return deconv::stride1_quint8_dot; @@ -54,12 +56,14 @@ ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern( /* ===================== direct stride 2 algo ===================== */ bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::usable( - ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + fallback::ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const { return deconv::can_stride2_quint8_dot(param); } size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::get_workspace( - ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + fallback::ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl, midout_iv("AlgoUdot8DirectStride2::get_workspace"_hash)) { return deconv::get_workspace_in_bytes_stride2_quint8_dot(param); @@ -70,7 +74,7 @@ size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::get_workspace( ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::dispatch_kern( - ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { + fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl, midout_iv("AlgoUdot8DirectStride2::dispatch_kern"_hash)) { return deconv::stride2_quint8_dot; diff --git a/dnn/src/arm_common/convolution/quint8/algos.h b/dnn/src/arm_common/convolution/quint8/algos.h index 5cba3485..a4815380 100644 --- a/dnn/src/arm_common/convolution/quint8/algos.h +++ b/dnn/src/arm_common/convolution/quint8/algos.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once @@ -18,38 +19,42 @@ namespace arm_common { #if __ARM_FEATURE_DOTPROD /* ===================== ConvolutionBackwardData ===================== */ -class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final : public AlgoBase { +class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final + : public AlgoBase { public: bool is_reproducible() const override { return true; } - const char* name() const override { return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE1"; } + const char* name() const override { + return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE1"; + } - bool usable(ConvolutionBackwardDataImpl*, + bool usable(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const override; - size_t get_workspace(ConvolutionBackwardDataImpl*, + size_t get_workspace(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const override; - ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, + ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } }; -class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final : public AlgoBase { +class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final + : public AlgoBase { public: bool is_reproducible() const override { return true; } - const char* name() const override { return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE2"; } + const char* name() const override { + return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE2"; + } - bool usable(ConvolutionBackwardDataImpl*, + bool usable(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const override; - size_t get_workspace(ConvolutionBackwardDataImpl*, + size_t get_workspace(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const override; - ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, + ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } }; #endif } // namespace arm_common diff --git a/dnn/src/arm_common/matrix_mul/algos.h b/dnn/src/arm_common/matrix_mul/algos.h index ef85bc73..e728a9df 100644 --- a/dnn/src/arm_common/matrix_mul/algos.h +++ b/dnn/src/arm_common/matrix_mul/algos.h @@ -24,7 +24,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT) }; @@ -37,7 +36,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override { return 0; } kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) @@ -51,7 +49,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override { return 0; } kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4) @@ -66,7 +63,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override { return 0; } kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT) @@ -84,7 +80,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override { return 0; } kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) @@ -98,7 +93,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override { return 0; } kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) @@ -113,7 +107,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override { return 0; } kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT) @@ -128,7 +121,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override { return 0; } kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC( diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.cpp b/dnn/src/arm_common/matrix_mul/opr_impl.cpp index 66d93084..f1527374 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.cpp +++ b/dnn/src/arm_common/matrix_mul/opr_impl.cpp @@ -15,13 +15,6 @@ using namespace megdnn; using namespace arm_common; -namespace { -uint8_t arm_common_algo_type_storage; -} // anonymous namespace - -void* const MatrixMulImpl::sm_arm_common_algo_type = - &arm_common_algo_type_storage; - class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoInt8x8x16 int8x8x16; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -49,10 +42,10 @@ public: all_algos.emplace_back(&f32_gemv_mk4); all_algos.emplace_back(&gevm); } - SmallVector all_algos; + SmallVector all_algos; }; -SmallVector MatrixMulImpl::algo_pack() { +SmallVector MatrixMulImpl::algo_pack() { static AlgoPack s_algo_pack; auto&& algos = fallback::MatrixMulImpl::algo_pack(); algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.h b/dnn/src/arm_common/matrix_mul/opr_impl.h index 9ed9f3c0..0014b0cf 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.h +++ b/dnn/src/arm_common/matrix_mul/opr_impl.h @@ -18,13 +18,18 @@ namespace arm_common { class MatrixMulImpl : public fallback::MatrixMulImpl { public: using fallback::MatrixMulImpl::MatrixMulImpl; - bool is_thread_safe() const override { return true; } - SmallVector algo_pack() override; + class AlgoBase : public fallback::MatrixMulImpl::AlgoBase { + public: + AlgoBase() : fallback::MatrixMulImpl::AlgoBase() { + m_handle_type = Handle::HandleType::ARM_COMMON; + } + }; + + SmallVector algo_pack() override; protected: - static void* const sm_arm_common_algo_type; class AlgoF32Gemv; // Arm_common F32 Gemv class AlgoF32GemvMK4; // Arm_common F32 Gemv NCHW44 class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv diff --git a/dnn/src/armv7/conv_bias/opr_impl.cpp b/dnn/src/armv7/conv_bias/opr_impl.cpp index 76602933..db6f09c3 100644 --- a/dnn/src/armv7/conv_bias/opr_impl.cpp +++ b/dnn/src/armv7/conv_bias/opr_impl.cpp @@ -32,7 +32,7 @@ public: SmallVector all_algos; }; -SmallVector ConvBiasImpl::algo_pack() { +SmallVector ConvBiasImpl::algo_pack() { static AlgoPack sl_algo_pack; auto&& algos = arm_common::ConvBiasImpl::algo_pack(); //! TODO fused matmul bias is slower than matmul + elemwise in armv7 now, diff --git a/dnn/src/armv7/conv_bias/opr_impl.h b/dnn/src/armv7/conv_bias/opr_impl.h index 4cbf4b06..32d97439 100644 --- a/dnn/src/armv7/conv_bias/opr_impl.h +++ b/dnn/src/armv7/conv_bias/opr_impl.h @@ -18,11 +18,16 @@ namespace armv7 { class ConvBiasImpl : public arm_common::ConvBiasImpl { public: using arm_common::ConvBiasImpl::ConvBiasImpl; + class AlgoBase : public arm_common::ConvBiasImpl::AlgoBase { + public: + AlgoBase() : arm_common::ConvBiasImpl::AlgoBase() { + m_handle_type = Handle::HandleType::ARMV7; + } + }; - SmallVector algo_pack() override; + SmallVector algo_pack() override; protected: - const char* get_algorithm_set_name() const override; private: diff --git a/dnn/src/armv7/matrix_mul/algos.h b/dnn/src/armv7/matrix_mul/algos.h index 481662fd..10c8f0ca 100644 --- a/dnn/src/armv7/matrix_mul/algos.h +++ b/dnn/src/armv7/matrix_mul/algos.h @@ -26,7 +26,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -37,7 +36,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -48,7 +46,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) }; @@ -61,7 +58,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase { @@ -71,7 +67,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8) }; @@ -121,7 +116,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -133,7 +127,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -144,7 +137,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -156,7 +148,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -168,7 +159,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -180,7 +170,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -192,7 +181,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -203,7 +191,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) }; @@ -216,7 +203,6 @@ public: bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; diff --git a/dnn/src/armv7/matrix_mul/opr_impl.cpp b/dnn/src/armv7/matrix_mul/opr_impl.cpp index 904bba8b..6887cfe3 100644 --- a/dnn/src/armv7/matrix_mul/opr_impl.cpp +++ b/dnn/src/armv7/matrix_mul/opr_impl.cpp @@ -44,7 +44,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8; public: - SmallVector all_algos; + SmallVector all_algos; AlgoPack() { all_algos.emplace_back(&f32_gemv); @@ -73,7 +73,7 @@ public: } }; -SmallVector MatrixMulImpl::algo_pack() { +SmallVector MatrixMulImpl::algo_pack() { static AlgoPack s_algo_pack; auto algos = arm_common::MatrixMulImpl::algo_pack(); algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), diff --git a/dnn/src/armv7/matrix_mul/opr_impl.h b/dnn/src/armv7/matrix_mul/opr_impl.h index 701cef4e..744099a8 100644 --- a/dnn/src/armv7/matrix_mul/opr_impl.h +++ b/dnn/src/armv7/matrix_mul/opr_impl.h @@ -18,7 +18,14 @@ namespace armv7 { class MatrixMulImpl : public arm_common::MatrixMulImpl { public: using arm_common::MatrixMulImpl::MatrixMulImpl; - SmallVector algo_pack() override; + class AlgoBase : public arm_common::MatrixMulImpl::AlgoBase { + public: + AlgoBase() : arm_common::MatrixMulImpl::AlgoBase() { + m_handle_type = Handle::HandleType::ARMV7; + } + }; + + SmallVector algo_pack() override; private: class AlgoF32; // Armv7 F32 diff --git a/dnn/src/cuda/batch_conv_bias/algo.h b/dnn/src/cuda/batch_conv_bias/algo.h index 6b2668ef..e7748cab 100644 --- a/dnn/src/cuda/batch_conv_bias/algo.h +++ b/dnn/src/cuda/batch_conv_bias/algo.h @@ -26,6 +26,7 @@ protected: ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { BatchConvBiasForwardImpl* opr; TensorLayout src_layout, filter_layout, bias_layout, z_layout, diff --git a/dnn/src/cuda/batched_matrix_mul/algo.h b/dnn/src/cuda/batched_matrix_mul/algo.h index 83597f5d..b0b3bd8a 100644 --- a/dnn/src/cuda/batched_matrix_mul/algo.h +++ b/dnn/src/cuda/batched_matrix_mul/algo.h @@ -28,6 +28,7 @@ protected: ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { BatchedMatrixMulForwardImpl* opr; TensorLayout layout_a, layout_b, layout_c; diff --git a/dnn/src/cuda/conv_bias/algo.h b/dnn/src/cuda/conv_bias/algo.h index 2c673ff0..32a214f1 100644 --- a/dnn/src/cuda/conv_bias/algo.h +++ b/dnn/src/cuda/conv_bias/algo.h @@ -38,6 +38,7 @@ protected: ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs : public conv_bias::BiasForwardSizeArgs { ConvBiasForwardImpl* opr; diff --git a/dnn/src/cuda/convolution/backward_data/algo.h b/dnn/src/cuda/convolution/backward_data/algo.h index 9e505e18..eaa6038d 100644 --- a/dnn/src/cuda/convolution/backward_data/algo.h +++ b/dnn/src/cuda/convolution/backward_data/algo.h @@ -28,6 +28,7 @@ class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm { ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { HandleImpl *handle; CanonizedFilterMeta filter_meta; diff --git a/dnn/src/cuda/convolution/backward_filter/algo.h b/dnn/src/cuda/convolution/backward_filter/algo.h index ef70258b..d54f3121 100644 --- a/dnn/src/cuda/convolution/backward_filter/algo.h +++ b/dnn/src/cuda/convolution/backward_filter/algo.h @@ -28,6 +28,7 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm { ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { HandleImpl *handle; const TensorLayout *src_layout, *diff_layout, *grad_layout; diff --git a/dnn/src/cuda/convolution3d/backward_data/algo.h b/dnn/src/cuda/convolution3d/backward_data/algo.h index 56a495d9..2d0baed9 100644 --- a/dnn/src/cuda/convolution3d/backward_data/algo.h +++ b/dnn/src/cuda/convolution3d/backward_data/algo.h @@ -28,6 +28,7 @@ class Convolution3DBackwardDataImpl::AlgoBase: public Algorithm { ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { HandleImpl *handle; CanonizedFilterMeta filter_meta; diff --git a/dnn/src/cuda/convolution3d/backward_filter/algo.h b/dnn/src/cuda/convolution3d/backward_filter/algo.h index 3750e25b..9ce504ec 100644 --- a/dnn/src/cuda/convolution3d/backward_filter/algo.h +++ b/dnn/src/cuda/convolution3d/backward_filter/algo.h @@ -22,6 +22,7 @@ class Convolution3DBackwardFilterImpl::AlgoBase: public Algorithm { ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { HandleImpl *handle; const TensorLayout *src_layout, *diff_layout; @@ -128,8 +129,8 @@ class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final: public AlgoBase const char* name() const override { return "INPLACE_MATMUL"; } - bool is_reproducible() const override { - return false; + bool is_reproducible() const override { + return false; } }; diff --git a/dnn/src/cuda/convolution3d/forward/algo.h b/dnn/src/cuda/convolution3d/forward/algo.h index baf6ad16..726dcbaf 100644 --- a/dnn/src/cuda/convolution3d/forward/algo.h +++ b/dnn/src/cuda/convolution3d/forward/algo.h @@ -34,6 +34,7 @@ class Convolution3DForwardImpl::AlgoBase: public Algorithm { ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs: public convolution3d::ForwardSizeArgs { Convolution3DForwardImpl *opr; @@ -42,11 +43,11 @@ class Convolution3DForwardImpl::AlgoBase: public Algorithm { desc.set(*src_layout, filter_meta, *dst_layout, opr->param()); } SizeArgs(Convolution3DForwardImpl *opr, - const TensorLayout &src, + const TensorLayout &src, const TensorLayout &filter, const TensorLayout &dst); SizeArgs(Convolution3DForwardImpl *opr, - const TensorLayout &src, + const TensorLayout &src, const CanonizedFilterMeta &filter, const TensorLayout &dst); }; diff --git a/dnn/src/cuda/deformable_conv/bwd_data/algo.h b/dnn/src/cuda/deformable_conv/bwd_data/algo.h index d16a66eb..5f83bc2f 100644 --- a/dnn/src/cuda/deformable_conv/bwd_data/algo.h +++ b/dnn/src/cuda/deformable_conv/bwd_data/algo.h @@ -26,6 +26,7 @@ protected: ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { DeformableConvBackwardDataImpl* opr; HandleImpl* handle; diff --git a/dnn/src/cuda/deformable_conv/bwd_flt/algo.h b/dnn/src/cuda/deformable_conv/bwd_flt/algo.h index ad9a9329..a2bb713e 100644 --- a/dnn/src/cuda/deformable_conv/bwd_flt/algo.h +++ b/dnn/src/cuda/deformable_conv/bwd_flt/algo.h @@ -26,6 +26,7 @@ protected: ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { DeformableConvBackwardFilterImpl* opr; HandleImpl* handle; diff --git a/dnn/src/cuda/deformable_conv/fwd/algo.h b/dnn/src/cuda/deformable_conv/fwd/algo.h index 768b49e5..f2d28ecb 100644 --- a/dnn/src/cuda/deformable_conv/fwd/algo.h +++ b/dnn/src/cuda/deformable_conv/fwd/algo.h @@ -24,6 +24,7 @@ protected: ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { DeformableConvForwardImpl* opr; HandleImpl* handle; diff --git a/dnn/src/cuda/local_share/backward_data/algo.h b/dnn/src/cuda/local_share/backward_data/algo.h index 7c5f2e8a..66a954ce 100644 --- a/dnn/src/cuda/local_share/backward_data/algo.h +++ b/dnn/src/cuda/local_share/backward_data/algo.h @@ -25,6 +25,7 @@ protected: ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { LocalShareBackwardDataImpl* opr; TensorLayout filter_layout, diff_layout, grad_layout; diff --git a/dnn/src/cuda/local_share/backward_filter/algo.h b/dnn/src/cuda/local_share/backward_filter/algo.h index 634f1203..cf916e78 100644 --- a/dnn/src/cuda/local_share/backward_filter/algo.h +++ b/dnn/src/cuda/local_share/backward_filter/algo.h @@ -25,6 +25,7 @@ protected: ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { LocalShareBackwardFilterImpl* opr; TensorLayout src_layout, diff_layout, grad_layout; diff --git a/dnn/src/cuda/local_share/forward/algo.h b/dnn/src/cuda/local_share/forward/algo.h index b41ec58d..a82be4f4 100644 --- a/dnn/src/cuda/local_share/forward/algo.h +++ b/dnn/src/cuda/local_share/forward/algo.h @@ -25,6 +25,7 @@ protected: ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { LocalShareForwardImpl* opr; TensorLayout src_layout, filter_layout, dst_layout; diff --git a/dnn/src/cuda/matrix_mul/algos.h b/dnn/src/cuda/matrix_mul/algos.h index 56e968c0..e52f628c 100644 --- a/dnn/src/cuda/matrix_mul/algos.h +++ b/dnn/src/cuda/matrix_mul/algos.h @@ -32,13 +32,14 @@ protected: ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { MatrixMulForwardImpl* opr; TensorLayout layout_a, layout_b, layout_c; std::string to_string() const; - SizeArgs(MatrixMulForwardImpl* opr, const TensorLayout& A, const TensorLayout& B, - const TensorLayout& C); + SizeArgs(MatrixMulForwardImpl* opr, const TensorLayout& A, + const TensorLayout& B, const TensorLayout& C); bool can_be_treated_as_int8x8x32() const { return layout_a.dtype.enumv() == layout_b.dtype.enumv() && diff --git a/dnn/src/fallback/conv_bias/opr_impl.h b/dnn/src/fallback/conv_bias/opr_impl.h index 86717452..8a44770b 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.h +++ b/dnn/src/fallback/conv_bias/opr_impl.h @@ -213,6 +213,9 @@ public: class AlgoBase : public Algorithm { public: + AlgoBase() : Algorithm() { + m_handle_type = Handle::HandleType::FALLBACK; + } virtual ~AlgoBase() = default; virtual bool usable( const NCBKernSizeParam& param, diff --git a/dnn/src/fallback/convolution/algos.h b/dnn/src/fallback/convolution/algos.h index f8f73fdd..57d959f9 100644 --- a/dnn/src/fallback/convolution/algos.h +++ b/dnn/src/fallback/convolution/algos.h @@ -141,8 +141,6 @@ public: return get_kimpl(m_algorithm, param); } - void* type() const override { return sm_fallback_conv_algo_type; } - //! select matmul to the highest preference bool is_preferred(const NCBKernSizeParam& param) const override; @@ -168,7 +166,6 @@ public: const NCBKernSizeParam& param) const override; ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override; - void* type() const override { return sm_fallback_deconv_algo_type; } }; class ConvolutionBackwardDataImpl::AlgoMatrixMul final : public AlgoBase { @@ -181,7 +178,6 @@ public: const NCBKernSizeParam& param) const override; ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override; - void* type() const override { return sm_fallback_deconv_algo_type; } }; } // namespace fallback diff --git a/dnn/src/fallback/convolution/opr_impl.cpp b/dnn/src/fallback/convolution/opr_impl.cpp index 21adb166..d20f68e6 100644 --- a/dnn/src/fallback/convolution/opr_impl.cpp +++ b/dnn/src/fallback/convolution/opr_impl.cpp @@ -37,8 +37,6 @@ class NaiveConvolutionBackwardData final const char* name() const override { return "NCBD"; } }; NaiveConvolutionBackwardData naive_conv_backward_data; -uint8_t fallback_deconv_algo_type_storage; -uint8_t fallback_conv_algo_type_storage; template void incr_ptr(T*& dst, ptrdiff_t delta) { @@ -69,9 +67,6 @@ public: SmallVector all_algos; }; -void* const ConvolutionImpl::sm_fallback_conv_algo_type = - &fallback_conv_algo_type_storage; - SmallVector ConvolutionImpl::algo_pack() { static AlgoPack sl_algo_pack; return sl_algo_pack.all_algos; @@ -412,9 +407,6 @@ ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const { /* ===================== ConvolutionBackwardData ===================== */ -void* const ConvolutionBackwardDataImpl::sm_fallback_deconv_algo_type = - &fallback_deconv_algo_type_storage; - struct ConvolutionBackwardDataImpl::AlgoPack { AlgoDirect direct; AlgoMatrixMul matmul; @@ -630,7 +622,7 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb( size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace( Algorithm* algo, const NCBKernSizeParam& param) { megdnn_assert(param.filter_meta.group == 1); - if (algo->type() == sm_fallback_deconv_algo_type) { + if (algo->handle_type() == Handle::HandleType::FALLBACK) { return static_cast(algo)->get_workspace(this, param); } megdnn_assert(algo == &naive_conv_backward_data); @@ -642,7 +634,7 @@ ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern( Algorithm* algo, const NCBKernSizeParam& param) { megdnn_assert(param.filter_meta.group == 1); - if (algo->type() == sm_fallback_deconv_algo_type) { + if (algo->handle_type() == Handle::HandleType::FALLBACK) { return static_cast(algo)->dispatch_kern(this, param); } diff --git a/dnn/src/fallback/convolution/opr_impl.h b/dnn/src/fallback/convolution/opr_impl.h index 62843d5a..1df426c7 100644 --- a/dnn/src/fallback/convolution/opr_impl.h +++ b/dnn/src/fallback/convolution/opr_impl.h @@ -177,8 +177,6 @@ public: } }; - static void* const sm_fallback_conv_algo_type; - /** * \brief Kernel run time id, This information is used for getting the * work data @@ -197,6 +195,9 @@ public: class AlgoBase : public Algorithm { public: + AlgoBase() : Algorithm() { + m_handle_type = Handle::HandleType::FALLBACK; + } virtual ~AlgoBase() = default; virtual bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy) const = 0; @@ -407,13 +408,14 @@ protected: const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, bool reproducible = false); - static void* const sm_fallback_deconv_algo_type; - class AlgoBase : public Algorithm { protected: ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { + m_handle_type = Handle::HandleType::FALLBACK; + } virtual bool usable(ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param) const = 0; virtual size_t get_workspace(ConvolutionBackwardDataImpl* opr, diff --git a/dnn/src/fallback/matrix_mul/opr_impl.h b/dnn/src/fallback/matrix_mul/opr_impl.h index fd4089dc..65e87a3a 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.h +++ b/dnn/src/fallback/matrix_mul/opr_impl.h @@ -103,6 +103,7 @@ public: } public: + AlgoBase() { m_handle_type = Handle::HandleType::FALLBACK; } enum class AlgoSet : uint32_t { ALGO_TYPE_GEMM = 0, ALGO_TYPE_GEMV = 1, diff --git a/dnn/src/rocm/batched_matrix_mul/opr_impl.cpp b/dnn/src/rocm/batched_matrix_mul/opr_impl.cpp index 3237e8ea..c53764dd 100644 --- a/dnn/src/rocm/batched_matrix_mul/opr_impl.cpp +++ b/dnn/src/rocm/batched_matrix_mul/opr_impl.cpp @@ -6,10 +6,11 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ -#include "hcc_detail/hcc_defs_prologue.h" #include "./opr_impl.h" +#include "hcc_detail/hcc_defs_prologue.h" #include "src/common/utils.cuh" #include "src/rocm/handle.h" @@ -92,8 +93,8 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, static_cast(A.raw_ptr), A.layout.stride[1], A.layout.stride[0], reinterpret_cast(zero_half), - static_cast(C.raw_ptr), - C.layout.stride[1], C.layout.stride[0], batch)); + static_cast(C.raw_ptr), C.layout.stride[1], + C.layout.stride[0], batch)); }; #endif diff --git a/dnn/src/rocm/convolution/backward_data/algo.h b/dnn/src/rocm/convolution/backward_data/algo.h index 7efd76d0..e67c4e99 100644 --- a/dnn/src/rocm/convolution/backward_data/algo.h +++ b/dnn/src/rocm/convolution/backward_data/algo.h @@ -25,6 +25,7 @@ protected: ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } struct SizeArgs { HandleImpl* handle; CanonizedFilterMeta filter_meta; diff --git a/dnn/src/rocm/convolution/backward_filter/algo.h b/dnn/src/rocm/convolution/backward_filter/algo.h index dfd4a788..30074e2a 100644 --- a/dnn/src/rocm/convolution/backward_filter/algo.h +++ b/dnn/src/rocm/convolution/backward_filter/algo.h @@ -26,6 +26,7 @@ protected: ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } struct SizeArgs { HandleImpl* handle; const TensorLayout *src_layout, *diff_layout; diff --git a/dnn/src/rocm/convolution/forward/algo.h b/dnn/src/rocm/convolution/forward/algo.h index f38cf8d3..b5906ba6 100644 --- a/dnn/src/rocm/convolution/forward/algo.h +++ b/dnn/src/rocm/convolution/forward/algo.h @@ -32,6 +32,7 @@ protected: ~AlgoBase() = default; public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } struct SizeArgs : public convolution::ForwardSizeArgs { ConvolutionForwardImpl* opr; diff --git a/dnn/src/x86/conv_bias/f32/algos.h b/dnn/src/x86/conv_bias/f32/algos.h index 7a4ed9c2..94a4b141 100644 --- a/dnn/src/x86/conv_bias/f32/algos.h +++ b/dnn/src/x86/conv_bias/f32/algos.h @@ -47,8 +47,6 @@ public: return get_kimpls(param); } - void* type() const override; - ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } @@ -84,8 +82,6 @@ public: return get_kimpls(param); } - void* type() const override; - ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } @@ -103,7 +99,6 @@ public: } return m_name.c_str(); } - void* type() const override; MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); }; @@ -119,7 +114,6 @@ public: } return m_name.c_str(); } - void* type() const override; MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); }; @@ -161,7 +155,6 @@ public: }; return {{kern, {1_z, 1_z, 1_z}}}; } - void* type() const override; ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; diff --git a/dnn/src/x86/conv_bias/int8/algos.h b/dnn/src/x86/conv_bias/int8/algos.h index e62dd8ff..65d015ef 100644 --- a/dnn/src/x86/conv_bias/int8/algos.h +++ b/dnn/src/x86/conv_bias/int8/algos.h @@ -32,7 +32,6 @@ public: const NCBKernSizeParam& param) const override { return get_kimpls(param); } - void* type() const override; bool is_preferred(const NCBKernSizeParam& param) const override; ConvAlgoTypePack get_algo_type() const override { @@ -57,7 +56,6 @@ public: const NCBKernSizeParam& param) const override { return get_kimpls(param); } - void* type() const override; bool is_preferred(const NCBKernSizeParam& param) const override; ConvAlgoTypePack get_algo_type() const override { @@ -82,7 +80,6 @@ public: const NCBKernSizeParam& param) const override { return get_kimpls(param); } - void* type() const override; bool is_preferred(const NCBKernSizeParam& param) const override; ConvAlgoTypePack get_algo_type() const override { @@ -107,7 +104,6 @@ public: const NCBKernSizeParam& param) const override { return get_kimpls(param); } - void* type() const override; bool is_preferred(const NCBKernSizeParam& param) const override; ConvAlgoTypePack get_algo_type() const override { @@ -148,7 +144,6 @@ public: }; return {{kern, {group, n, 1_z}}}; } - void* type() const override; bool is_preferred(const NCBKernSizeParam& param) const override; ConvAlgoTypePack get_algo_type() const override { @@ -179,8 +174,6 @@ public: //! select matmul to the highest preference bool is_preferred(const NCBKernSizeParam& param) const override; - void* type() const override; - ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; } diff --git a/dnn/src/x86/conv_bias/opr_impl.cpp b/dnn/src/x86/conv_bias/opr_impl.cpp index 51566605..936a9738 100644 --- a/dnn/src/x86/conv_bias/opr_impl.cpp +++ b/dnn/src/x86/conv_bias/opr_impl.cpp @@ -22,54 +22,14 @@ using namespace megdnn; using namespace x86; - namespace { -uint8_t x86_algo_type_storage; -void* x86_algo_type = &x86_algo_type_storage; -} // anonymous namespace -#if MEGDNN_X86_WITH_MKL_DNN -void* ConvBiasImpl::AlgoMkldnnQint8::type() const { - return x86_algo_type; -} -void* ConvBiasImpl::AlgoMkldnnMatmulQint8::type() const { - return x86_algo_type; -} -void* ConvBiasImpl::AlgoMkldnnConv::type() const { - return x86_algo_type; -} -#endif - -void* ConvBiasImpl::AlgoDirect::type() const { - return x86_algo_type; -} - -void* ConvBiasImpl::AlgoDirectStride2::type() const { - return x86_algo_type; -} -void* ConvBiasImpl::AlgoDirectAvx2Stride1Int8::type() const { - return x86_algo_type; +bool is_fallback_or_naive(const detail::Algorithm* algo) { + return algo->handle_type() == Handle::HandleType::NAIVE || + algo->handle_type() == Handle::HandleType::FALLBACK; } -void* ConvBiasImpl::AlgoFP32WinogradF63_8x8::type() const { - return x86_algo_type; -} - -void* ConvBiasImpl::AlgoFP32WinogradF23_8x8::type() const { - return x86_algo_type; -} - -void* ConvBiasImpl::AlgoAVX2DirectConvStride2::type() const { - return x86_algo_type; -} - -void* ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::type() const { - return x86_algo_type; -} - -void* ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::type() const { - return x86_algo_type; -} +} // anonymous namespace class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoDirect stride1_direct; @@ -88,8 +48,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { public: AlgoPack() { - //! FIXME: preference to use mkldnn algo on VNNI devices - //! But now mkldnn algo preference issue with NCHW->NHWC->NCHW + //! FIXME: preference to use mkldnn algo on VNNI devices + //! But now mkldnn algo preference issue with NCHW->NHWC->NCHW #if MEGDNN_X86_WITH_MKL_DNN //! Create the mkldnn algo all_algos.emplace_back(&mkldnn_conv_fp32); @@ -108,7 +68,7 @@ public: auto&& matmul_algos = static_cast(matmul_opr)->algo_pack(); for (auto&& algo : matmul_algos) { - if (algo->type() == nullptr) + if (is_fallback_or_naive(algo)) continue; for (uint32_t tile_size : {8, 16, 24}) { refhold.emplace_back(new AlgoFP32WinogradF63_8x8( @@ -126,7 +86,7 @@ public: SmallVector winograd_algos; }; -SmallVector ConvBiasImpl::algo_pack() { +SmallVector ConvBiasImpl::algo_pack() { static AlgoPack sl_algo_pack; auto&& algos = fallback::ConvBiasImpl::algo_pack(); algos.insert(algos.begin(), sl_algo_pack.all_algos.begin(), @@ -176,8 +136,8 @@ bool ConvBiasImpl::is_matmul_quantized_prefer( !chanwise_avx2_stride2_qint8_usable_preferred(param)); } -SmallVector -ConvBiasImpl::suggest_algo_category_order(const NCBKernSizeParam& param) const { +SmallVector ConvBiasImpl::suggest_algo_category_order( + const NCBKernSizeParam& param) const { auto IC = param.filter_meta.icpg; auto OC = param.filter_meta.ocpg; auto FH = param.filter_meta.spatial[0]; diff --git a/dnn/src/x86/conv_bias/opr_impl.h b/dnn/src/x86/conv_bias/opr_impl.h index e98aae5d..49f9c731 100644 --- a/dnn/src/x86/conv_bias/opr_impl.h +++ b/dnn/src/x86/conv_bias/opr_impl.h @@ -20,10 +20,15 @@ namespace x86 { class ConvBiasImpl : public fallback::ConvBiasImpl { public: using fallback::ConvBiasImpl::ConvBiasImpl; - using FallbackConvBiasImpl = fallback::ConvBiasImpl; + class AlgoBase : public fallback::ConvBiasImpl::AlgoBase { + public: + AlgoBase() : fallback::ConvBiasImpl::AlgoBase() { + m_handle_type = Handle::HandleType::X86; + } + }; bool is_thread_safe() const override { return true; } - SmallVector algo_pack() override; + SmallVector algo_pack() override; SmallVector suggest_algo_category_order( const NCBKernSizeParam& param) const override; diff --git a/dnn/src/x86/matrix_mul/algos.h b/dnn/src/x86/matrix_mul/algos.h index d93c9b5e..79b26004 100644 --- a/dnn/src/x86/matrix_mul/algos.h +++ b/dnn/src/x86/matrix_mul/algos.h @@ -25,7 +25,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override { return 0; } kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_x86_algo_type; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) }; @@ -38,7 +37,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override { return 0; } kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_x86_algo_type; } PackMode packmode() const override { return PackMode::ONLY_PACKA; } kern_naked_t get_kern_naked(const KernSizeParam&) const override; void pack_A(const KernParam& kern_param, void* out, size_t index, @@ -60,7 +58,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_x86_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -71,7 +68,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_x86_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -86,7 +82,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_x86_algo_type; } bool preferred(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -102,7 +97,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_x86_algo_type; } bool preferred(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -114,7 +108,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_x86_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -125,7 +118,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_x86_algo_type; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4, AlgoDataType::FLOAT32, MK8) }; @@ -138,7 +130,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_x86_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; #endif @@ -151,7 +142,6 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override { return 0; } kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_x86_algo_type; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) }; diff --git a/dnn/src/x86/matrix_mul/opr_impl.cpp b/dnn/src/x86/matrix_mul/opr_impl.cpp index a9d7c312..8d6f1ee6 100644 --- a/dnn/src/x86/matrix_mul/opr_impl.cpp +++ b/dnn/src/x86/matrix_mul/opr_impl.cpp @@ -16,12 +16,6 @@ using namespace megdnn; using namespace x86; -namespace { -uint8_t x86_algo_type_storage; -} // anonymous namespace - -void* const MatrixMulImpl::sm_x86_algo_type = &x86_algo_type_storage; - class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoF32Blas f32blas; @@ -62,10 +56,10 @@ public: all_algos.emplace_back(&f32mkl_packa); #endif } - SmallVector all_algos; + SmallVector all_algos; }; -SmallVector MatrixMulImpl::algo_pack() { +SmallVector MatrixMulImpl::algo_pack() { static AlgoPack s_algo_pack; auto&& algos = fallback::MatrixMulImpl::algo_pack(); algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), diff --git a/dnn/src/x86/matrix_mul/opr_impl.h b/dnn/src/x86/matrix_mul/opr_impl.h index be76c26c..3c8a0f90 100644 --- a/dnn/src/x86/matrix_mul/opr_impl.h +++ b/dnn/src/x86/matrix_mul/opr_impl.h @@ -33,13 +33,18 @@ namespace x86 { class MatrixMulImpl : public fallback::MatrixMulImpl { public: using fallback::MatrixMulImpl::MatrixMulImpl; + class AlgoBase : public fallback::MatrixMulImpl::AlgoBase { + public: + AlgoBase() : fallback::MatrixMulImpl::AlgoBase() { + m_handle_type = Handle::HandleType::X86; + } + }; bool is_thread_safe() const override { return true; } - SmallVector algo_pack() override; + SmallVector algo_pack() override; protected: - static void* const sm_x86_algo_type; class AlgoF32Blas; #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM class AlgoF32MKLPackA;