diff --git a/dnn/include/megdnn/oprs/base.h b/dnn/include/megdnn/oprs/base.h index 97233238..aef6de42 100644 --- a/dnn/include/megdnn/oprs/base.h +++ b/dnn/include/megdnn/oprs/base.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once @@ -92,24 +93,72 @@ enum class AlgoDataType : uint32_t { /*! * \brief Abstract representation of an algorithm for implementing * the operator - * - * All pointers to Algorithm should be allocated globally and usable - * across multiple megdnn handles, and they should not be freed by - * the caller. */ class Algorithm { public: + static constexpr uint32_t INVALID_ALGO_TYPE = static_cast(-1); + /** + * \brief Algorithm information, we can get real algo from + * AlgorithmInfo::Info::Desc + */ + struct Info { + struct Desc { + //! backend of the algo belonging to + Handle::HandleType handle_type; + //! indicate the real algo implementation + uint32_t type = INVALID_ALGO_TYPE; + //! serialized param of the algo type + std::string param; + bool valid() const { return type != INVALID_ALGO_TYPE; } + void reset() { type = INVALID_ALGO_TYPE; } + + bool operator==(const Desc& rhs) const { + return handle_type == rhs.handle_type && type == rhs.type && + param == rhs.param; + } + } desc; + //! algorithm name + std::string name; + bool is_reproducible; + bool valid() const { return desc.valid(); } + void reset() { desc.reset(); } + //! desc donate the algo + bool operator==(const Info& rhs) const { return desc == rhs.desc; } + }; + + virtual ~Algorithm() = default; + /** * \brief whether the execution result is * reproducible across multiple runs. */ virtual bool is_reproducible() const = 0; virtual const char* name() const = 0; + //! 
serialized param + virtual std::string param() const { return {}; } + virtual uint32_t type() const = 0; Handle::HandleType handle_type() const { return m_handle_type; } + Info info() const { + return {{handle_type(), type(), param()}, name(), is_reproducible()}; + } + + template + static void serialize_write_pod(const T& val, std::string& result) { + result.append(reinterpret_cast(&val), sizeof(T)); + } + + static void serialize_write_pod(const char* val, std::string& result) { + result.append(val, strlen(val)); + } + + template + static T deserialize_read_pod(const std::string& data, size_t offset = 0) { + T ret = *reinterpret_cast(&data[offset]); + return ret; + } protected: - ~Algorithm() = default; Handle::HandleType m_handle_type = Handle::HandleType::NAIVE; }; @@ -127,6 +176,8 @@ class MultiAlgoOpr; template class MultiAlgoOpr { public: + using AlgorithmInfo = detail::Algorithm::Info; + using AlgorithmDesc = detail::Algorithm::Info::Desc; using Algorithm = detail::Algorithm; /*! * \brief get a string representation for current algorithm set; @@ -139,8 +190,8 @@ public: //! policy for executing the operator struct ExecutionPolicy { - //! nullptr means using heuristic - Algorithm* algorithm = nullptr; + //! INVALID_ALGO_TYPE algo_type means using heuristic + AlgorithmInfo algo; }; ExecutionPolicy& execution_policy() { return m_execution_policy; } @@ -161,6 +212,39 @@ template class MultiAlgoOpr : public MultiAlgoOpr { public: using Algorithm = detail::Algorithm; + using AlgorithmInfo = detail::Algorithm::Info; + + //! get all possible algorithm decriptions for the specified layouts + std::vector get_all_algorithms_info(const TensorLayout& p0, + const TensorLayout& p1, + const TensorLayout& p2) { + std::vector ret; + for (auto&& algo : get_all_algorithms(p0, p1, p2)) { + ret.emplace_back(algo->info()); + } + return ret; + } + + /** + * \brief Returns the best algorithm information which indicate the + * algorithm by heuristic. 
+ * + * The selected algorithm should not use workspace more than + * \p workspace_limit_in_bytes. + */ + AlgorithmInfo get_algorithm_info_heuristic( + const TensorLayout& p0, const TensorLayout& p1, + const TensorLayout& p2, + size_t workspace_limit_in_bytes = + std::numeric_limits::max(), + bool reproducible = false) { + return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes, + reproducible) + ->info(); + } + +protected: + ~MultiAlgoOpr() = default; //! get all possible algorithms for the specified layouts virtual std::vector get_all_algorithms( @@ -179,9 +263,6 @@ public: size_t workspace_limit_in_bytes = std::numeric_limits::max(), bool reproducible = false) = 0; - -protected: - ~MultiAlgoOpr() = default; }; //! specializae for nargs == 4 @@ -189,6 +270,40 @@ template class MultiAlgoOpr : public MultiAlgoOpr { public: using Algorithm = detail::Algorithm; + using AlgorithmInfo = detail::Algorithm::Info; + + //! get all possible algorithm decriptions for the specified layouts + std::vector get_all_algorithms_info(const TensorLayout& p0, + const TensorLayout& p1, + const TensorLayout& p2, + const TensorLayout& p3) { + std::vector ret; + for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) { + ret.emplace_back(algo->info()); + } + return ret; + } + + /** + * \brief Returns the best algorithm information which indicate the + * algorithm by heuristic. + * + * The selected algorithm should not use workspace more than + * \p workspace_limit_in_bytes. + */ + AlgorithmInfo get_algorithm_info_heuristic( + const TensorLayout& p0, const TensorLayout& p1, + const TensorLayout& p2, const TensorLayout& p3, + size_t workspace_limit_in_bytes = + std::numeric_limits::max(), + bool reproducible = false) { + return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes, + reproducible) + ->info(); + } + +protected: + ~MultiAlgoOpr() = default; //! 
get all possible algorithms for the specified layouts virtual std::vector get_all_algorithms( @@ -207,9 +322,6 @@ public: size_t workspace_limit_in_bytes = std::numeric_limits::max(), bool reproducible = false) = 0; - -protected: - ~MultiAlgoOpr() = default; }; //! specializae for nargs == 5 @@ -217,6 +329,42 @@ template class MultiAlgoOpr : public MultiAlgoOpr { public: using Algorithm = detail::Algorithm; + using AlgorithmInfo = detail::Algorithm::Info; + + //! get all possible algorithm decriptions for the specified layouts + std::vector get_all_algorithms_info(const TensorLayout& p0, + const TensorLayout& p1, + const TensorLayout& p2, + const TensorLayout& p3, + const TensorLayout& p4) { + std::vector ret; + for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) { + ret.emplace_back(algo->info()); + } + return ret; + } + + /** + * \brief Returns the best algorithm information which indicate the + * algorithm by heuristic. + * + * The selected algorithm should not use workspace more than + * \p workspace_limit_in_bytes. + */ + AlgorithmInfo get_algorithm_info_heuristic( + const TensorLayout& p0, const TensorLayout& p1, + const TensorLayout& p2, const TensorLayout& p3, + const TensorLayout& p4, + size_t workspace_limit_in_bytes = + std::numeric_limits::max(), + bool reproducible = false) { + return get_algorithm_heuristic(p0, p1, p2, p3, p4, + workspace_limit_in_bytes, reproducible) + ->info(); + } + +protected: + ~MultiAlgoOpr() = default; //! get all possible algorithms for the specified layouts virtual std::vector get_all_algorithms( @@ -237,9 +385,6 @@ public: size_t workspace_limit_in_bytes = std::numeric_limits::max(), bool reproducible = false) = 0; - -protected: - ~MultiAlgoOpr() = default; }; //! specializae for nargs == 8 @@ -247,6 +392,42 @@ template class MultiAlgoOpr : public MultiAlgoOpr { public: using Algorithm = detail::Algorithm; + using AlgorithmInfo = detail::Algorithm::Info; + + //! 
get all possible algorithm decriptions for the specified layouts + std::vector get_all_algorithms_info( + const TensorLayout& p0, const TensorLayout& p1, + const TensorLayout& p2, const TensorLayout& p3, + const TensorLayout& p4, const TensorLayout& p5, + const TensorLayout& p6, const TensorLayout& p7) { + std::vector ret; + for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) { + ret.emplace_back(algo->info()); + } + return ret; + } + + /** + * \brief Returns the best algorithm information which indicate the + * algorithm by heuristic. + * + * The selected algorithm should not use workspace more than + */ + AlgorithmInfo get_algorithm_info_heuristic( + const TensorLayout& p0, const TensorLayout& p1, + const TensorLayout& p2, const TensorLayout& p3, + const TensorLayout& p4, const TensorLayout& p5, + const TensorLayout& p6, const TensorLayout& p7, + size_t workspace_limit_in_bytes = + std::numeric_limits::max(), + bool reproducible = false) { + return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7, + workspace_limit_in_bytes, reproducible) + ->info(); + } + +protected: + ~MultiAlgoOpr() = default; //! 
get all possible algorithms for the specified layouts virtual std::vector get_all_algorithms( @@ -269,9 +450,6 @@ public: size_t workspace_limit_in_bytes = std::numeric_limits::max(), bool reproducible = false) = 0; - -protected: - ~MultiAlgoOpr() = default; }; } // namespace detail } // namespace megdnn diff --git a/dnn/src/aarch64/conv_bias/fp16/algos.h b/dnn/src/aarch64/conv_bias/fp16/algos.h index 3b36bfd3..c6ddc95b 100644 --- a/dnn/src/aarch64/conv_bias/fp16/algos.h +++ b/dnn/src/aarch64/conv_bias/fp16/algos.h @@ -31,6 +31,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(AARCH64_DIRECT_STRD2_FP16) }; } // namespace aarch64 } // namespace megdnn diff --git a/dnn/src/aarch64/conv_bias/fp32/algos.h b/dnn/src/aarch64/conv_bias/fp32/algos.h index 3340c726..53a33968 100644 --- a/dnn/src/aarch64/conv_bias/fp32/algos.h +++ b/dnn/src/aarch64/conv_bias/fp32/algos.h @@ -36,6 +36,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(AARCH64_DIRECT_STRD2_FP32) }; } // namespace aarch64 diff --git a/dnn/src/aarch64/conv_bias/int8/algos.h b/dnn/src/aarch64/conv_bias/int8/algos.h index afac5922..4a5bf6b1 100644 --- a/dnn/src/aarch64/conv_bias/int8/algos.h +++ b/dnn/src/aarch64/conv_bias/int8/algos.h @@ -48,6 +48,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; } + MEGDNN_DECL_ALGO_TYPE(AARCH64_MATMUL_S8) }; } // namespace aarch64 diff --git a/dnn/src/aarch64/conv_bias/opr_impl.cpp b/dnn/src/aarch64/conv_bias/opr_impl.cpp index 0d997f73..9bca6877 100644 --- a/dnn/src/aarch64/conv_bias/opr_impl.cpp +++ b/dnn/src/aarch64/conv_bias/opr_impl.cpp @@ -32,28 +32,54 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoF16DirectStride2 f16_direct_stride2; #endif + fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map; + SmallVector 
m_direct_algos; + SmallVector m_matmul_algos; + public: AlgoPack() { - matmul_algos.emplace_back(&qu8_matrix_mul); - matmul_algos.emplace_back(&s8_matrix_mul); + m_matmul_algos.emplace_back(&qu8_matrix_mul); + m_matmul_algos.emplace_back(&s8_matrix_mul); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - direct_algos.emplace_back(&f16_direct_stride2); + m_direct_algos.emplace_back(&f16_direct_stride2); #endif - direct_algos.emplace_back(&f32_direct_stride2); + m_direct_algos.emplace_back(&f32_direct_stride2); + + for (auto&& algo : m_direct_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } + for (auto&& algo : m_matmul_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } + } + + const SmallVector& direct_algos() const { + return m_direct_algos; + } + const SmallVector& matmul_algos() + const { + return m_matmul_algos; } - SmallVector direct_algos; - SmallVector matmul_algos; + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } + }; -SmallVector ConvBiasImpl::algo_pack() { - static AlgoPack sl_algo_pack; - auto&& algos = arm_common::ConvBiasImpl::algo_pack(); - algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), - sl_algo_pack.direct_algos.end()); +const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { + static AlgoPack algo_pack; + return algo_pack; +} + +MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) + +SmallVector +ConvBiasImpl::get_all_packed_algo() { + auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); + algos.insert(algos.begin(), algo_pack().direct_algos().begin(), + algo_pack().direct_algos().end()); //! We put matmul algos at the begin. Because matmul will get privilege when //! prefer return true. 
See - algos.insert(algos.begin(), sl_algo_pack.matmul_algos.begin(), - sl_algo_pack.matmul_algos.end()); + algos.insert(algos.begin(), algo_pack().matmul_algos().begin(), + algo_pack().matmul_algos().end()); return std::move(algos); } diff --git a/dnn/src/aarch64/conv_bias/opr_impl.h b/dnn/src/aarch64/conv_bias/opr_impl.h index 7666ab15..f1867cf9 100644 --- a/dnn/src/aarch64/conv_bias/opr_impl.h +++ b/dnn/src/aarch64/conv_bias/opr_impl.h @@ -25,7 +25,9 @@ public: } }; - SmallVector algo_pack() override; + SmallVector get_all_packed_algo() override; + + MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl); protected: const char* get_algorithm_set_name() const override; @@ -38,6 +40,7 @@ private: class AlgoF16DirectStride2; #endif class AlgoPack; + static const AlgoPack& algo_pack(); }; } // namespace aarch64 diff --git a/dnn/src/aarch64/conv_bias/quint8/algos.h b/dnn/src/aarch64/conv_bias/quint8/algos.h index a55ee568..ba3ab203 100644 --- a/dnn/src/aarch64/conv_bias/quint8/algos.h +++ b/dnn/src/aarch64/conv_bias/quint8/algos.h @@ -48,6 +48,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; } + MEGDNN_DECL_ALGO_TYPE(AARCH64_MATMUL_QU8) }; } // namespace aarch64 } // namespace megdnn diff --git a/dnn/src/aarch64/matrix_mul/algos.h b/dnn/src/aarch64/matrix_mul/algos.h index a20247c4..aa9e7e0e 100644 --- a/dnn/src/aarch64/matrix_mul/algos.h +++ b/dnn/src/aarch64/matrix_mul/algos.h @@ -27,6 +27,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_K8X12X1) }; class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { @@ -37,6 +38,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_K8X12X1) }; class 
MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { @@ -47,6 +49,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_K4X16X1) }; class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { @@ -58,10 +61,17 @@ public: kern_t get_kern(const KernSizeParam&) const override; PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4) + MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_4x16) }; class MatrixMulImpl::AlgoF32Gemv final - : public arm_common::MatrixMulImpl::AlgoF32Gemv {}; + : public arm_common::MatrixMulImpl::AlgoF32Gemv { +public: + AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { + m_handle_type = Handle::HandleType::AARCH64; + } + MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_GEMV) +}; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { @@ -72,6 +82,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_K8X24X1) }; class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { @@ -83,6 +94,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8) + MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_MK8_8X8) }; #endif @@ -98,6 +110,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K8X12X4_DOTPROD) }; class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { @@ -110,6 +123,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const 
KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD) }; #else @@ -124,6 +138,7 @@ public: PackMode packmode() const override { return PackMode::DEFAULT; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_MK4_4X4X16) }; class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { @@ -136,6 +151,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K4X4X16) }; class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { @@ -147,6 +163,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K8X8X8) }; #endif @@ -160,6 +177,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_K8X8X8) }; class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { @@ -171,6 +189,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_K4X4X16) }; class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { @@ -186,6 +205,7 @@ public: PackMode packmode() const override { return PackMode::DEFAULT; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_16X12X4) }; class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { @@ -201,6 +221,7 @@ public: PackMode packmode() const override { return PackMode::DEFAULT; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_K8X8X8) }; class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { @@ -214,6 +235,7 @@ public: PackMode packmode() const override { return 
PackMode::DEFAULT; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_4X4X8) }; class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { @@ -225,6 +247,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_INT16X16X32_K12X8X1) }; class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { @@ -236,6 +259,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) + MEGDNN_DECL_ALGO_TYPE(AARCH64_INT16X16X32_MK8_8X8) }; #if __ARM_FEATURE_DOTPROD @@ -249,6 +273,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X4_DOTPROD) }; class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { @@ -262,6 +287,7 @@ public: AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT) + MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_GEMV_DOTPROD) }; #else @@ -273,6 +299,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X8) }; #endif diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.cpp b/dnn/src/aarch64/matrix_mul/opr_impl.cpp index 2910e582..f4ee77e9 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.cpp +++ b/dnn/src/aarch64/matrix_mul/opr_impl.cpp @@ -51,49 +51,66 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoQuint8K8x8x8 quint8_k8x8x8; #endif + SmallVector m_all_algos; + 
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; public: - SmallVector all_algos; AlgoPack() { - all_algos.emplace_back(&f32_gemv); - all_algos.emplace_back(&f32K8x12x1); - all_algos.emplace_back(&f32_mk4_8x12x1); - all_algos.emplace_back(&f32k4x16x1); - all_algos.emplace_back(&f32mk4_4x16); + m_all_algos.emplace_back(&f32_gemv); + m_all_algos.emplace_back(&f32K8x12x1); + m_all_algos.emplace_back(&f32_mk4_8x12x1); + m_all_algos.emplace_back(&f32k4x16x1); + m_all_algos.emplace_back(&f32mk4_4x16); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - all_algos.emplace_back(&f16_k8x24x1); - all_algos.emplace_back(&f16_mk8_8x8); + m_all_algos.emplace_back(&f16_k8x24x1); + m_all_algos.emplace_back(&f16_mk8_8x8); #endif #if __ARM_FEATURE_DOTPROD - all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); - all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); + m_all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); + m_all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); #else - all_algos.emplace_back(&int8x8x32_k4x4x16); - all_algos.emplace_back(&int8x8x32_k8x8x8); - all_algos.emplace_back(&int8x8x32_mk4_4x4x16); + m_all_algos.emplace_back(&int8x8x32_k4x4x16); + m_all_algos.emplace_back(&int8x8x32_k8x8x8); + m_all_algos.emplace_back(&int8x8x32_mk4_4x4x16); #endif - all_algos.emplace_back(&int8x8x16_k4x4x16); - all_algos.emplace_back(&int8x8x16_k8x8x8); - all_algos.emplace_back(&int8x8x16_mk4_k8x8x8); - all_algos.emplace_back(&int8x8x16_mk4_4x4x8); - all_algos.emplace_back(&int8x8x16_mk4_16x12x4); + m_all_algos.emplace_back(&int8x8x16_k4x4x16); + m_all_algos.emplace_back(&int8x8x16_k8x8x8); + m_all_algos.emplace_back(&int8x8x16_mk4_k8x8x8); + m_all_algos.emplace_back(&int8x8x16_mk4_4x4x8); + m_all_algos.emplace_back(&int8x8x16_mk4_16x12x4); - all_algos.emplace_back(&int16x16x32_k12x8x1); - all_algos.emplace_back(&int16x16x32_mk8_8x8); + m_all_algos.emplace_back(&int16x16x32_k12x8x1); + m_all_algos.emplace_back(&int16x16x32_mk8_8x8); #if __ARM_FEATURE_DOTPROD - 
all_algos.emplace_back(&quint8_gemv_dotprod); - all_algos.emplace_back(&quint8_k8x8x4_dotprod); + m_all_algos.emplace_back(&quint8_gemv_dotprod); + m_all_algos.emplace_back(&quint8_k8x8x4_dotprod); #else - all_algos.emplace_back(&quint8_k8x8x8); + m_all_algos.emplace_back(&quint8_k8x8x8); #endif + + for (auto&& algo : m_all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } + } + + const SmallVector& all_algos() const { + return m_all_algos; } + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -SmallVector MatrixMulImpl::algo_pack() { - static AlgoPack s_algo_pack; - auto&& algos = arm_common::MatrixMulImpl::algo_pack(); - algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), - s_algo_pack.all_algos.end()); +const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { + static AlgoPack algo_pack; + return algo_pack; +} + +MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl) + +SmallVector +MatrixMulImpl::get_all_packed_algo() { + auto&& algos = arm_common::MatrixMulImpl::get_all_packed_algo(); + algos.insert(algos.begin(), algo_pack().all_algos().begin(), + algo_pack().all_algos().end()); return std::move(algos); } diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.h b/dnn/src/aarch64/matrix_mul/opr_impl.h index 31c8ef3b..0e4a5fa9 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.h +++ b/dnn/src/aarch64/matrix_mul/opr_impl.h @@ -25,7 +25,10 @@ public: } }; - SmallVector algo_pack() override; + SmallVector get_all_packed_algo() + override; + + MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl); private: class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 @@ -66,6 +69,8 @@ private: class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16 class AlgoPack; +public: + static const AlgoPack& algo_pack(); }; } // namespace aarch64 diff --git a/dnn/src/arm_common/conv_bias/f16/algos.h b/dnn/src/arm_common/conv_bias/f16/algos.h index 0f985651..47c0ead5 100644 --- a/dnn/src/arm_common/conv_bias/f16/algos.h +++ 
b/dnn/src/arm_common/conv_bias/f16/algos.h @@ -30,6 +30,7 @@ public: return m_name.c_str(); } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_FP16) }; class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase { @@ -45,7 +46,7 @@ public: return m_name.c_str(); } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); - + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP16) }; class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase { public: @@ -61,6 +62,7 @@ public: } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP16) }; class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase { public: @@ -75,6 +77,7 @@ public: return m_name.c_str(); } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_FP16) }; class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { @@ -94,6 +97,7 @@ public: ConvAlgoTypePack get_algo_type() const override{ return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP16) }; class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { @@ -110,6 +114,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP16) }; } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.h b/dnn/src/arm_common/conv_bias/fp32/algos.h index e65fe8bd..15fad6ed 100644 --- a/dnn/src/arm_common/conv_bias/fp32/algos.h +++ b/dnn/src/arm_common/conv_bias/fp32/algos.h @@ -30,6 +30,7 @@ public: return m_name.c_str(); } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_FP32) }; class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { @@ -45,6 +46,7 @@ public: return m_name.c_str(); } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); 
+ MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP32) }; class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { @@ -60,6 +62,7 @@ public: return m_name.c_str(); } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_FP32) }; class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { @@ -75,6 +78,7 @@ public: return m_name.c_str(); } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F54_FP32) }; class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { @@ -90,6 +94,7 @@ public: return m_name.c_str(); } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP32) }; //===================== NCHW44 Winograd Support =====================// @@ -107,6 +112,7 @@ public: return m_name.c_str(); } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32) }; class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase { @@ -123,6 +129,7 @@ public: return m_name.c_str(); } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32) }; class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase { @@ -139,6 +146,7 @@ public: return m_name.c_str(); } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32) }; // ================================================================= // @@ -157,6 +165,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP32) }; class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { @@ -174,6 +183,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } + 
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP32) }; class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { @@ -191,6 +201,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_FP32) }; class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { @@ -209,6 +220,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_FP32) }; class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { @@ -227,6 +239,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_FP32) }; class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase { @@ -244,6 +257,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_NCHW44_F32) }; } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/int8/algos.h b/dnn/src/arm_common/conv_bias/int8/algos.h index f9611372..e509e43d 100644 --- a/dnn/src/arm_common/conv_bias/int8/algos.h +++ b/dnn/src/arm_common/conv_bias/int8/algos.h @@ -33,6 +33,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_S8) }; class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { @@ -49,6 +50,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_S8) }; class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase { @@ -65,6 +67,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; 
} + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44) }; class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase { @@ -81,6 +84,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_S8) }; class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase { @@ -95,6 +99,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD1_NCHW44_S8) }; class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase { @@ -109,6 +114,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD2_NCHW44_S8) }; #if __ARM_FEATURE_DOTPROD @@ -126,6 +132,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8) }; class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { @@ -142,6 +149,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_S8) }; class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { @@ -159,6 +167,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_S8) }; class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase { @@ -180,6 +189,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_DOT_S8) }; #endif @@ -196,6 +206,7 @@ public: return m_name.c_str(); } 
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_S8) }; //=======================input int8 compute fp32 output int8============ @@ -213,6 +224,7 @@ public: return m_name.c_str(); } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32) }; //=======================input int8 compute int16 output int8============ @@ -231,6 +243,7 @@ public: } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8) }; } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/algos.h b/dnn/src/arm_common/conv_bias/int8x8x16/algos.h index 4591a278..9b83ac02 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/algos.h +++ b/dnn/src/arm_common/conv_bias/int8x8x16/algos.h @@ -39,6 +39,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_INT8X8X16) }; class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase { @@ -54,6 +55,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_INT8X8X16) }; class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { @@ -80,6 +82,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_INT8X8X16) }; class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase { @@ -96,12 +99,16 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_F2_INT8X8X16) }; -class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final : public AlgoBase { +class 
ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final + : public AlgoBase { public: bool is_reproducible() const override { return true; } - const char* name() const override { return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"; } + const char* name() const override { + return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"; + } bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace( @@ -111,6 +118,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_STRD1_STRD2_NCHW44_INT8X8X16) }; class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase { @@ -129,6 +137,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_INT8X8X16) }; } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index 2a50104d..99a13d0a 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -88,46 +88,50 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { #endif SmallVector> refhold; + fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map; + SmallVector m_direct_algos; + SmallVector m_winograd_algos; public: AlgoPack() { #if __ARM_FEATURE_DOTPROD - direct_algos.emplace_back(&ds8_direct_stride1); - direct_algos.emplace_back(&ds8_direct_stride2); - direct_algos.emplace_back(&du8_direct_stride1); - direct_algos.emplace_back(&du8_direct_stride2); + m_direct_algos.emplace_back(&ds8_direct_stride1); + m_direct_algos.emplace_back(&ds8_direct_stride2); + m_direct_algos.emplace_back(&du8_direct_stride1); + m_direct_algos.emplace_back(&du8_direct_stride2); - direct_algos.emplace_back(&ds8_direct_nchw44); - direct_algos.emplace_back(&ds8_direct_nchw_nchw44); + 
m_direct_algos.emplace_back(&ds8_direct_nchw44); + m_direct_algos.emplace_back(&ds8_direct_nchw_nchw44); #endif - direct_algos.emplace_back(&qu8_direct_stride2); - direct_algos.emplace_back(&qu8_direct_stride1); - direct_algos.emplace_back(&s8_direct_stride2); - direct_algos.emplace_back(&s8_direct_nchw44); - direct_algos.emplace_back(&s8x8x16_direct_nchw44); - direct_algos.emplace_back(&s8_direct_nchw_nchw44); - direct_algos.emplace_back(&s8_direct_stride1); - - direct_algos.emplace_back(&s8x8x16_channel_wise_stride1_stride2_nchw44); - direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); - direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); + m_direct_algos.emplace_back(&qu8_direct_stride2); + m_direct_algos.emplace_back(&qu8_direct_stride1); + m_direct_algos.emplace_back(&s8_direct_stride2); + m_direct_algos.emplace_back(&s8_direct_nchw44); + m_direct_algos.emplace_back(&s8x8x16_direct_nchw44); + m_direct_algos.emplace_back(&s8_direct_nchw_nchw44); + m_direct_algos.emplace_back(&s8_direct_stride1); + + m_direct_algos.emplace_back( + &s8x8x16_channel_wise_stride1_stride2_nchw44); + m_direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); + m_direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - direct_algos.emplace_back(&f16_direct_stride1); - direct_algos.emplace_back(&f16_direct); + m_direct_algos.emplace_back(&f16_direct_stride1); + m_direct_algos.emplace_back(&f16_direct); #endif - direct_algos.emplace_back(&i8x8x16_direct); - direct_algos.emplace_back(&i8x8x16_stride2_filter2); - direct_algos.emplace_back(&i8x8x16_stride2); - direct_algos.emplace_back(&i8x8x16_nchw_nchw44); + m_direct_algos.emplace_back(&i8x8x16_direct); + m_direct_algos.emplace_back(&i8x8x16_stride2_filter2); + m_direct_algos.emplace_back(&i8x8x16_stride2); + m_direct_algos.emplace_back(&i8x8x16_nchw_nchw44); - direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); - direct_algos.emplace_back(&f32_chanel_wise_nchw44); 
- direct_algos.emplace_back(&f32_direct_nchw44); + m_direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); + m_direct_algos.emplace_back(&f32_chanel_wise_nchw44); + m_direct_algos.emplace_back(&f32_direct_nchw44); - direct_algos.emplace_back(&f32_direct_stride1); - direct_algos.emplace_back(&f32_direct_stride2); - direct_algos.emplace_back(&f32_direct); + m_direct_algos.emplace_back(&f32_direct_stride1); + m_direct_algos.emplace_back(&f32_direct_stride2); + m_direct_algos.emplace_back(&f32_direct); static CpuOprDelegationStorage<2> storage; auto matmul_opr = storage.get(); @@ -143,31 +147,31 @@ public: refhold.emplace_back(new AlgoFP32WinogradF23_4x4( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); refhold.emplace_back(new AlgoFP32WinogradF63_4x4( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); //! uncomment this when low precision mode is done #if 0 refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); #endif //! 
Qint8x8x32 winograd compute with fp32 refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); } } matmul_algos = static_cast(matmul_opr) @@ -180,15 +184,15 @@ public: refhold.emplace_back(new AlgoFP32WinogradF63( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); refhold.emplace_back(new AlgoFP32WinogradF54( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); refhold.emplace_back(new AlgoFP32WinogradF45( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); } } @@ -203,15 +207,15 @@ public: refhold.emplace_back(new AlgoFP16WinogradF23( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); refhold.emplace_back(new AlgoFP16WinogradF45( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); refhold.emplace_back(new AlgoFP16WinogradF63( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); } } matmul_algos = static_cast(matmul_opr) @@ -224,7 +228,7 @@ public: refhold.emplace_back(new AlgoFP16WinogradF23_8x8( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); } } #endif @@ -238,25 +242,48 @@ public: refhold.emplace_back(new AlgoS8WinogradF23_8x8( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); refhold.emplace_back(new 
AlgoS8WinogradF23_8x8_NCHW44( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); } } + + + for (auto&& algo : m_direct_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } + for (auto&& algo : m_winograd_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } - SmallVector direct_algos; - SmallVector winograd_algos; + + const SmallVector& direct_algos() + const { + return m_direct_algos; + } + const SmallVector& winograd_algos() + const { + return m_winograd_algos; + } + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -SmallVector ConvBiasImpl::algo_pack() { - static AlgoPack sl_algo_pack; - auto&& algos = fallback::ConvBiasImpl::algo_pack(); - algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), - sl_algo_pack.direct_algos.end()); - algos.insert(algos.end(), sl_algo_pack.winograd_algos.begin(), - sl_algo_pack.winograd_algos.end()); +const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { + static AlgoPack algo_pack; + return algo_pack; +} + +MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) + +SmallVector +ConvBiasImpl::get_all_packed_algo() { + auto&& algos = fallback::ConvBiasImpl::get_all_packed_algo(); + algos.insert(algos.begin(), algo_pack().direct_algos().begin(), + algo_pack().direct_algos().end()); + algos.insert(algos.end(), algo_pack().winograd_algos().begin(), + algo_pack().winograd_algos().end()); return std::move(algos); } diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index 61f622e9..6e9d94d5 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -12,6 +12,7 @@ #pragma once #include "src/common/utils.h" #include "src/fallback/conv_bias/opr_impl.h" +#include "src/common/algo_base.h" namespace megdnn { namespace arm_common { @@ -27,7 +28,7 @@ public: } }; - SmallVector algo_pack() override; + SmallVector 
get_all_packed_algo() override; bool is_matmul_quantized_prefer( const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param) @@ -35,7 +36,8 @@ public: SmallVector suggest_algo_category_order( const NCBKernSizeParam& param) const override; - class AlgoPack; + + MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl); protected: const char* get_algorithm_set_name() const override; @@ -95,6 +97,9 @@ private: class AlgoF16Direct; class AlgoF16DirectStride1; #endif + + class AlgoPack; + static const AlgoPack& algo_pack(); }; } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/quint8/algos.h b/dnn/src/arm_common/conv_bias/quint8/algos.h index bda2412f..df67c166 100644 --- a/dnn/src/arm_common/conv_bias/quint8/algos.h +++ b/dnn/src/arm_common/conv_bias/quint8/algos.h @@ -32,6 +32,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_QU8) }; class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { @@ -48,6 +49,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_QU8) }; #if __ARM_FEATURE_DOTPROD class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { @@ -65,6 +67,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_QU8) }; class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { @@ -81,6 +84,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_QU8) }; #endif } // namespace arm_common diff --git a/dnn/src/arm_common/convolution/int8x8x32/algos.h b/dnn/src/arm_common/convolution/int8x8x32/algos.h index e69794dc..44fc6e84 100644 --- 
a/dnn/src/arm_common/convolution/int8x8x32/algos.h +++ b/dnn/src/arm_common/convolution/int8x8x32/algos.h @@ -36,6 +36,7 @@ public: ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_INT8X8X32) }; class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final @@ -54,6 +55,7 @@ public: ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_INT8X8X32) }; #endif diff --git a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp index 5ffed44d..6e45a10e 100644 --- a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp +++ b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp @@ -1086,6 +1086,10 @@ bool deconv::can_stride1_int8x8x32_dot(const NCBKernSizeParam& param) { (FH == 2 || FH == 3 || FH == 5 || FH == 7) && FH >= PH + 1 && FW >= PW + 1; + avaiable &= (param.filter_type.enumv() == DTypeEnum::QuantizedS8 || + param.filter_type.enumv() == DTypeEnum::Int8) && + (param.grad_type.enumv() == DTypeEnum::QuantizedS32 || + param.grad_type.enumv() == DTypeEnum::Int32); return avaiable && ((FH == 2 && OC <= 8) || ((FH == 3 || FH == 5 || FH == 7) && (IC < 32 && OC <= 16))); diff --git a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp index 9a8e9774..8b17812d 100644 --- a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp +++ b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp @@ -1180,6 +1180,10 @@ bool deconv::can_stride2_int8x8x32_dot(const NCBKernSizeParam& param) { (FH == 2 || FH == 3 || FH == 5 || FH == 7) && FH >= PH + 1 && FW >= PW + 1; + avaiable &= (param.filter_type.enumv() == DTypeEnum::QuantizedS8 || + param.filter_type.enumv() 
== DTypeEnum::Int8) && + (param.grad_type.enumv() == DTypeEnum::QuantizedS32 || + param.grad_type.enumv() == DTypeEnum::Int32); return avaiable && ((FH == 2 && OC <= 4) || (FH == 3 && OC <= 8) || (FH == 5 && OC <= 16) || (FH == 7 && OC < 32)); } diff --git a/dnn/src/arm_common/convolution/opr_impl.cpp b/dnn/src/arm_common/convolution/opr_impl.cpp index 9d7af4de..822ad29b 100644 --- a/dnn/src/arm_common/convolution/opr_impl.cpp +++ b/dnn/src/arm_common/convolution/opr_impl.cpp @@ -23,15 +23,54 @@ using namespace arm_common; /* ===================== ConvolutionBackwardData ===================== */ -struct ConvolutionBackwardDataImpl::AlgoPack { +class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { #if __ARM_FEATURE_DOTPROD AlgoSdot8DirectStride1 i8x8x32_direct_stride1_sdot; AlgoSdot8DirectStride2 i8x8x32_direct_stride2_sdot; AlgoUdot8DirectStride1 quint8_direct_stride1_udot; AlgoUdot8DirectStride2 quint8_direct_stride2_udot; #endif + + fallback::ConvolutionBackwardDataImpl::AlgoBase::Mapper m_all_algos_map; + SmallVector + m_all_algos; + +public: + AlgoPack() { +#if __ARM_FEATURE_DOTPROD + m_all_algos.emplace_back(&i8x8x32_direct_stride1_sdot); + m_all_algos.emplace_back(&i8x8x32_direct_stride2_sdot); + m_all_algos.emplace_back(&quint8_direct_stride1_udot); + m_all_algos.emplace_back(&quint8_direct_stride2_udot); +#endif + + for (auto&& algo : m_all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } + } + + const SmallVector& + all_algos() const { + return m_all_algos; + } + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack; + +const ConvolutionBackwardDataImpl::AlgoPack& +ConvolutionBackwardDataImpl::algo_pack() { + static AlgoPack algo_pack; + return algo_pack; +} + +MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl) + +SmallVector +ConvolutionBackwardDataImpl::get_all_packed_algo() { + auto&& algos = 
fallback::ConvolutionBackwardDataImpl::get_all_packed_algo(); + algos.insert(algos.begin(), algo_pack().all_algos().begin(), + algo_pack().all_algos().end()); + return std::move(algos); +} ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern( @@ -52,35 +91,6 @@ size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace( param); } -std::vector -ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms( - const NCBKernSizeParam& param) { - auto ret = fallback::ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms( - param); - -#if __ARM_FEATURE_DOTPROD - if ((param.filter_type.enumv() == DTypeEnum::QuantizedS8 || - param.filter_type.enumv() == DTypeEnum::Int8) && - (param.grad_type.enumv() == DTypeEnum::QuantizedS32 || - param.grad_type.enumv() == DTypeEnum::Int32)) { - if (sm_algo_pack.i8x8x32_direct_stride1_sdot.usable(this, param)) { - ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride1_sdot); - } - if (sm_algo_pack.i8x8x32_direct_stride2_sdot.usable(this, param)) { - ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride2_sdot); - } - } else if (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && - param.grad_type.enumv() == DTypeEnum::QuantizedS32) { - if (sm_algo_pack.quint8_direct_stride1_udot.usable(this, param)) { - ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride1_udot); - } - if (sm_algo_pack.quint8_direct_stride2_udot.usable(this, param)) { - ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride2_udot); - } - } -#endif - return ret; -} const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const { // arm common version 0 return "DeconvAC0"; diff --git a/dnn/src/arm_common/convolution/opr_impl.h b/dnn/src/arm_common/convolution/opr_impl.h index d8d124ad..cc87b698 100644 --- a/dnn/src/arm_common/convolution/opr_impl.h +++ b/dnn/src/arm_common/convolution/opr_impl.h @@ -47,11 +47,14 @@ protected: size_t ncb_1g_get_workspace(Algorithm* algo, const NCBKernSizeParam& param) 
override; - std::vector ncb_1g_get_all_algorithms( - const NCBKernSizeParam& param) override; - const char* get_algorithm_set_name() const override; + SmallVector + get_all_packed_algo() override; + +public: + MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl); + private: #if __ARM_FEATURE_DOTPROD class AlgoSdot8DirectStride1; @@ -59,8 +62,8 @@ private: class AlgoUdot8DirectStride1; class AlgoUdot8DirectStride2; #endif - struct AlgoPack; - static AlgoPack sm_algo_pack; + class AlgoPack; + static const AlgoPack& algo_pack(); }; } // namespace arm_common diff --git a/dnn/src/arm_common/convolution/quint8/algos.h b/dnn/src/arm_common/convolution/quint8/algos.h index a4815380..74a12340 100644 --- a/dnn/src/arm_common/convolution/quint8/algos.h +++ b/dnn/src/arm_common/convolution/quint8/algos.h @@ -36,6 +36,7 @@ public: ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_QU8) }; class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final @@ -55,6 +56,7 @@ public: ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_QU8) }; #endif } // namespace arm_common diff --git a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp index bd2f2176..8b0d1f6c 100644 --- a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp +++ b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp @@ -1236,6 +1236,9 @@ bool deconv::can_stride1_quint8_dot(const NCBKernSizeParam& param) { (FH == 2 || FH == 3 || FH == 5 || FH == 7) && FH >= PH + 1 && FW >= PW + 1; + avaiable &= (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && + param.grad_type.enumv() == DTypeEnum::QuantizedS32); + /** * \note In the kernel, we use int32_t to calc the value, in order * not generate negative
number, we first initialize SHIFT and sub diff --git a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp index 5db7ea06..e07c9ff5 100644 --- a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp +++ b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp @@ -1337,6 +1337,9 @@ bool deconv::can_stride2_quint8_dot(const NCBKernSizeParam& param) { (FH == 2 || FH == 3 || FH == 5 || FH == 7) && FH >= PH + 1 && FW >= PW + 1; + avaiable &= (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && + param.grad_type.enumv() == DTypeEnum::QuantizedS32); + /** * \note In the kernel, we use uint32_t to calc the value, in order * not generate negative number, we first initialize SHIFT and sub diff --git a/dnn/src/arm_common/elemwise/opr_impl.h b/dnn/src/arm_common/elemwise/opr_impl.h index eab11604..8a2d6abb 100644 --- a/dnn/src/arm_common/elemwise/opr_impl.h +++ b/dnn/src/arm_common/elemwise/opr_impl.h @@ -59,6 +59,7 @@ public: virtual bool is_available(const KernParam&) const = 0; virtual void exec(const KernParam&) const = 0; virtual ~AlgoBase() = default; + uint32_t type() const override { return INVALID_ALGO_TYPE; }; }; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/dnn/src/arm_common/matrix_mul/algos.h b/dnn/src/arm_common/matrix_mul/algos.h index e728a9df..0cbb7289 100644 --- a/dnn/src/arm_common/matrix_mul/algos.h +++ b/dnn/src/arm_common/matrix_mul/algos.h @@ -26,6 +26,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT) + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X16) }; class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { @@ -39,6 +40,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8,
16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV) }; class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { @@ -52,6 +54,7 @@ public: AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4) + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4) }; #if __ARM_FEATURE_DOTPROD @@ -66,6 +69,7 @@ public: AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT) + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4_DOT) }; #endif @@ -96,6 +100,7 @@ public: AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F32_GEMV_MK4) }; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -110,6 +115,7 @@ public: AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT) + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F16_GEMV) }; #endif @@ -130,6 +136,7 @@ public: static_cast(AlgoDataType::FLOAT32) | static_cast(AlgoDataType::QINT8X8X32)), DEFAULT) + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_GEVM) }; } // namespace arm_common diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.cpp b/dnn/src/arm_common/matrix_mul/opr_impl.cpp index f1527374..a2dcd28c 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.cpp +++ b/dnn/src/arm_common/matrix_mul/opr_impl.cpp @@ -28,28 +28,47 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoGevm gevm; AlgoF32GemvMK4 f32_gemv_mk4; + SmallVector m_all_algos; + fallback::MatrixMulImpl::AlgoBase::Mapper 
m_all_algos_map; + public: AlgoPack() { - all_algos.emplace_back(&int8x8x16); + m_all_algos.emplace_back(&int8x8x16); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - all_algos.emplace_back(&f16gemv); + m_all_algos.emplace_back(&f16gemv); #endif #if __ARM_FEATURE_DOTPROD - all_algos.emplace_back(&int8x8x32_gemv_mk4_dot); + m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot); #endif - all_algos.emplace_back(&int8x8x32_gemv); - all_algos.emplace_back(&int8x8x32_gemv_mk4); - all_algos.emplace_back(&f32_gemv_mk4); - all_algos.emplace_back(&gevm); + m_all_algos.emplace_back(&int8x8x32_gemv); + m_all_algos.emplace_back(&int8x8x32_gemv_mk4); + m_all_algos.emplace_back(&f32_gemv_mk4); + m_all_algos.emplace_back(&gevm); + + for (auto&& algo : m_all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } + } + + const SmallVector& all_algos() const { + return m_all_algos; } - SmallVector all_algos; + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -SmallVector MatrixMulImpl::algo_pack() { +const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { + static AlgoPack algo_pack; + return algo_pack; +} + +MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl) + +SmallVector +MatrixMulImpl::get_all_packed_algo() { static AlgoPack s_algo_pack; - auto&& algos = fallback::MatrixMulImpl::algo_pack(); - algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), - s_algo_pack.all_algos.end()); + auto&& algos = fallback::MatrixMulImpl::get_all_packed_algo(); + algos.insert(algos.begin(), algo_pack().all_algos().begin(), + algo_pack().all_algos().end()); return std::move(algos); } diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.h b/dnn/src/arm_common/matrix_mul/opr_impl.h index 0014b0cf..2ea5e7d6 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.h +++ b/dnn/src/arm_common/matrix_mul/opr_impl.h @@ -11,6 +11,7 @@ #pragma once #include "src/common/utils.h" #include "src/fallback/matrix_mul/opr_impl.h" +#include "src/common/algo_base.h" namespace megdnn { 
namespace arm_common { @@ -27,7 +28,10 @@ public: } }; - SmallVector algo_pack() override; + SmallVector get_all_packed_algo() + override; + + MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl); protected: class AlgoF32Gemv; // Arm_common F32 Gemv @@ -43,6 +47,9 @@ protected: #endif class AlgoInt8x8x16; // Arm_common Int 8x8x16 class AlgoPack; + +public: + static const AlgoPack& algo_pack(); }; } // namespace arm_common diff --git a/dnn/src/arm_common/pooling/opr_impl.h b/dnn/src/arm_common/pooling/opr_impl.h index c3a5335e..b198f68f 100644 --- a/dnn/src/arm_common/pooling/opr_impl.h +++ b/dnn/src/arm_common/pooling/opr_impl.h @@ -10,6 +10,7 @@ * implied. */ #pragma once +#include "megdnn/oprs/base.h" #include "src/fallback/pooling/opr_impl.h" namespace megdnn { @@ -72,6 +73,8 @@ public: virtual ~AlgoBase() = default; virtual bool usable(const PoolingKernSizeParam& param) const = 0; virtual void exec(const PoolingKernParam& param) const = 0; + + uint32_t type() const override { return INVALID_ALGO_TYPE; }; }; private: diff --git a/dnn/src/armv7/conv_bias/int8/algos.h b/dnn/src/armv7/conv_bias/int8/algos.h index 748e92d1..b0cb6f93 100644 --- a/dnn/src/armv7/conv_bias/int8/algos.h +++ b/dnn/src/armv7/conv_bias/int8/algos.h @@ -40,6 +40,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; } + MEGDNN_DECL_ALGO_TYPE(ARMV7_MATMUL_S8) }; } // namespace armv7 diff --git a/dnn/src/armv7/conv_bias/opr_impl.cpp b/dnn/src/armv7/conv_bias/opr_impl.cpp index db6f09c3..53421925 100644 --- a/dnn/src/armv7/conv_bias/opr_impl.cpp +++ b/dnn/src/armv7/conv_bias/opr_impl.cpp @@ -24,22 +24,40 @@ using namespace armv7; class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoS8MatrixMul s8_matrix_mul; AlgoQU8MatrixMul qu8_matrix_mul; + fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map; + SmallVector m_all_algos; public: AlgoPack() { - all_algos.emplace_back(&qu8_matrix_mul); - all_algos.emplace_back(&s8_matrix_mul); + 
m_all_algos.emplace_back(&qu8_matrix_mul); + m_all_algos.emplace_back(&s8_matrix_mul); + + for (auto&& algo : m_all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } + } + + const SmallVector& all_algos() + const { + return m_all_algos; } - SmallVector all_algos; + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -SmallVector ConvBiasImpl::algo_pack() { - static AlgoPack sl_algo_pack; - auto&& algos = arm_common::ConvBiasImpl::algo_pack(); +const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { + static AlgoPack algo_pack; + return algo_pack; +} + +MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) + +SmallVector +ConvBiasImpl::get_all_packed_algo() { + auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); //! TODO fused matmul bias is slower than matmul + elemwise in armv7 now, //! and nearly equal in aarch64, because of the waste of register in //! postprocess - algos.insert(algos.end(), sl_algo_pack.all_algos.begin(), - sl_algo_pack.all_algos.end()); + algos.insert(algos.end(), algo_pack().all_algos().begin(), + algo_pack().all_algos().end()); return std::move(algos); } diff --git a/dnn/src/armv7/conv_bias/opr_impl.h b/dnn/src/armv7/conv_bias/opr_impl.h index 32d97439..744dc37e 100644 --- a/dnn/src/armv7/conv_bias/opr_impl.h +++ b/dnn/src/armv7/conv_bias/opr_impl.h @@ -25,7 +25,9 @@ public: } }; - SmallVector algo_pack() override; + SmallVector get_all_packed_algo() override; + + MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl); protected: const char* get_algorithm_set_name() const override; @@ -34,6 +36,7 @@ private: class AlgoS8MatrixMul; class AlgoQU8MatrixMul; class AlgoPack; + static const AlgoPack& algo_pack(); }; } // namespace armv7 diff --git a/dnn/src/armv7/conv_bias/quint8/algos.h b/dnn/src/armv7/conv_bias/quint8/algos.h index cd6ed708..6dd0b8dc 100644 --- a/dnn/src/armv7/conv_bias/quint8/algos.h +++ b/dnn/src/armv7/conv_bias/quint8/algos.h @@ -42,6 +42,7 @@ public: ConvAlgoTypePack get_algo_type() 
const override { return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; } + MEGDNN_DECL_ALGO_TYPE(ARMV7_MATMUL_QU8) }; } // namespace armv7 diff --git a/dnn/src/armv7/matrix_mul/algos.h b/dnn/src/armv7/matrix_mul/algos.h index 10c8f0ca..2be450fa 100644 --- a/dnn/src/armv7/matrix_mul/algos.h +++ b/dnn/src/armv7/matrix_mul/algos.h @@ -27,6 +27,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(ARMV7_F32) }; class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase { @@ -37,6 +38,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_MK4_PACK_4X12) }; class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase { @@ -48,6 +50,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) + MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_MK4_4x8) }; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -59,6 +62,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_K4X16X1) }; class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase { public: @@ -69,6 +73,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8) + MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_MK8_4X8) }; #endif #if __ARM_FEATURE_DOTPROD @@ -80,6 +85,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8_K6X8X4) 
}; class MatrixMulImpl::AlgoQuint8DotK4x8x4 final : public AlgoBase { @@ -90,6 +96,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(ARMV7_QUINT8_K4X8X4) }; class MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd final : public AlgoBase { @@ -102,11 +109,18 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8_MK4_8X4X4_DOTPROD) }; #endif class MatrixMulImpl::AlgoF32Gemv final - : public arm_common::MatrixMulImpl::AlgoF32Gemv {}; + : public arm_common::MatrixMulImpl::AlgoF32Gemv { +public: + AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { + m_handle_type = Handle::HandleType::ARMV7; + } + MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_GEMV) +}; class MatrixMulImpl::AlgoInt8x8x32K4x2x16 final : public AlgoBase { public: @@ -117,6 +131,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_K4X2X16) }; class MatrixMulImpl::AlgoInt8x8x32K4x8x8 final : public AlgoBase { @@ -128,6 +143,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_K4X8X8) }; class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase { @@ -138,6 +154,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(ARMV7_QUINT8_K4X8X8) }; class MatrixMulImpl::AlgoInt8x8x16K4x2x16 final : public AlgoBase { @@ -149,6 +166,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) 
const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K4X2X16) }; class MatrixMulImpl::AlgoInt8x8x16K4x8x8 final : public AlgoBase { @@ -160,6 +178,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K4X8X8) }; class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { @@ -171,6 +190,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_MK4_K8X8X4) }; class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase { @@ -182,6 +202,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(ARMV7_INT16X16X32_K12X4X1) }; class MatrixMulImpl::AlgoInt16x16x32MK8_4x8 final : public AlgoBase { @@ -193,6 +214,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) + MEGDNN_DECL_ALGO_TYPE(ARMV7_INT16X16X32_MK8_4X8) }; class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { @@ -204,6 +226,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_MK4_4X2X16) }; } // namespace armv7 diff --git a/dnn/src/armv7/matrix_mul/opr_impl.cpp b/dnn/src/armv7/matrix_mul/opr_impl.cpp index 6887cfe3..a2807a56 100644 --- a/dnn/src/armv7/matrix_mul/opr_impl.cpp +++ b/dnn/src/armv7/matrix_mul/opr_impl.cpp @@ -43,42 +43,60 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1; 
AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8; + SmallVector m_all_algos; + fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; + public: - SmallVector all_algos; AlgoPack() { - all_algos.emplace_back(&f32_gemv); - all_algos.emplace_back(&f32); - all_algos.emplace_back(&f32_mk4_pack_4x12); - all_algos.emplace_back(&f32_mk4_4x8); + m_all_algos.emplace_back(&f32_gemv); + m_all_algos.emplace_back(&f32); + m_all_algos.emplace_back(&f32_mk4_pack_4x12); + m_all_algos.emplace_back(&f32_mk4_4x8); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - all_algos.emplace_back(&f16_k4x16x1); - all_algos.emplace_back(&f16_mk8_4x8); + m_all_algos.emplace_back(&f16_k4x16x1); + m_all_algos.emplace_back(&f16_mk8_4x8); #endif #if __ARM_FEATURE_DOTPROD - all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod); - all_algos.emplace_back(&int8_k6x8x4); - all_algos.emplace_back(&quint8_k4x8x4); + m_all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod); + m_all_algos.emplace_back(&int8_k6x8x4); + m_all_algos.emplace_back(&quint8_k4x8x4); #endif - all_algos.emplace_back(&int8x8x32_mk4_4x2x16); - all_algos.emplace_back(&int8x8x32_k4x2x16); - all_algos.emplace_back(&int8x8x32_k4x8x8); - all_algos.emplace_back(&quint8_k4x8x8); - all_algos.emplace_back(&int8x8x16_mk4_8x8x4); - all_algos.emplace_back(&int8x8x16_k4x2x16); - all_algos.emplace_back(&int8x8x16_k4x8x8); + m_all_algos.emplace_back(&int8x8x32_mk4_4x2x16); + m_all_algos.emplace_back(&int8x8x32_k4x2x16); + m_all_algos.emplace_back(&int8x8x32_k4x8x8); + m_all_algos.emplace_back(&quint8_k4x8x8); + m_all_algos.emplace_back(&int8x8x16_mk4_8x8x4); + m_all_algos.emplace_back(&int8x8x16_k4x2x16); + m_all_algos.emplace_back(&int8x8x16_k4x8x8); + + m_all_algos.emplace_back(&int16x16x32_k12x4x1); + m_all_algos.emplace_back(&int16x16x32_mk8_4x8); + + for (auto&& algo : m_all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } + } - all_algos.emplace_back(&int16x16x32_k12x4x1); - all_algos.emplace_back(&int16x16x32_mk8_4x8); + const SmallVector& 
all_algos() const { + return m_all_algos; } + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -SmallVector MatrixMulImpl::algo_pack() { - static AlgoPack s_algo_pack; - auto algos = arm_common::MatrixMulImpl::algo_pack(); - algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), - s_algo_pack.all_algos.end()); +const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { + static AlgoPack algo_pack; + return algo_pack; +} + +SmallVector +MatrixMulImpl::get_all_packed_algo() { + auto algos = arm_common::MatrixMulImpl::get_all_packed_algo(); + algos.insert(algos.begin(), algo_pack().all_algos().begin(), + algo_pack().all_algos().end()); return algos; } +MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl) + // vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/opr_impl.h b/dnn/src/armv7/matrix_mul/opr_impl.h index 744099a8..8ff401f7 100644 --- a/dnn/src/armv7/matrix_mul/opr_impl.h +++ b/dnn/src/armv7/matrix_mul/opr_impl.h @@ -25,7 +25,10 @@ public: } }; - SmallVector algo_pack() override; + SmallVector get_all_packed_algo() + override; + + MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl); private: class AlgoF32; // Armv7 F32 @@ -52,6 +55,9 @@ private: // DotProduct #endif class AlgoPack; + +public: + static const AlgoPack& algo_pack(); }; } // namespace armv7 diff --git a/dnn/src/common/algo_base.h b/dnn/src/common/algo_base.h new file mode 100644 index 00000000..39854557 --- /dev/null +++ b/dnn/src/common/algo_base.h @@ -0,0 +1,101 @@ +/** + * \file dnn/src/common/algo_base.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#pragma once + +#include +#include + +#include "megdnn/oprs/base.h" +#include "src/common/utils.h" + +namespace megdnn { + +#define MEGDNN_DECL_ALGO_TYPE(_type) \ + uint32_t type() const override { \ + return static_cast::type>( \ + AlgoType::_type); \ + } + +#define MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(_opr) \ + static fallback::_opr::AlgoBase* get_algo_from_desc( \ + const AlgorithmDesc& desc) + +#define MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(_opr) \ + fallback::_opr::AlgoBase* _opr::get_algo_from_desc( \ + const AlgorithmDesc& desc) { \ + megdnn_assert(algo_pack().all_algos_map().find(desc) != \ + algo_pack().all_algos_map().end()); \ + return algo_pack().all_algos_map().at(desc); \ + } + +#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \ + _opr::AlgoBase* _opr::get_algo_from_desc(const AlgorithmDesc& desc) { \ + megdnn_assert(algo_pack().all_algos_map().find(desc) != \ + algo_pack().all_algos_map().end()); \ + return algo_pack().all_algos_map().at(desc); \ + } + +/** + * \brief construct algo from AlgorithmDesc + */ +template +class AlgoConstructMixin { +private: + std::vector> m_refhold; +protected: + typename AlgoBase::Mapper m_all_algos_map; + +public: + + //! 
construct the algo which described by desc, and return the instance + AlgoBase* construct_and_get_algo( + const detail::Algorithm::Info::Desc& desc) { + auto iter = m_all_algos_map.find(desc); + if (iter != m_all_algos_map.end()) { + return m_all_algos_map.at(desc); + } + std::string serialized_bin; + AlgoBase::serialize_write_pod(desc.type, serialized_bin); + serialized_bin += desc.param; + m_refhold.emplace_back(AlgoBase::deserialize(serialized_bin)); + m_all_algos_map.emplace(desc, m_refhold.back().get()); + return m_refhold.back().get(); + } + + void clear() { + m_all_algos_map.clear(); + m_refhold.clear(); + } + + const typename AlgoBase::Mapper& all_algos_map() const { + return m_all_algos_map; + } +}; + +} // namespace megdnn + +namespace std { +template <> +struct hash { + std::size_t operator()( + const megdnn::detail::Algorithm::Info::Desc& desc) const { + return megdnn::hash_combine( + megdnn::hash_combine( + std::hash()(desc.param), + std::hash()(desc.type)), + std::hash()(static_cast(desc.handle_type))); + } +}; +} // namespace std + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/common/algo_chooser.h b/dnn/src/common/algo_chooser.h index 49d449f0..e597ca3f 100644 --- a/dnn/src/common/algo_chooser.h +++ b/dnn/src/common/algo_chooser.h @@ -25,15 +25,34 @@ namespace megdnn { */ template typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { - typename Opr::Algorithm* ret; - if (auto set = opr->execution_policy().algorithm) { + typename Opr::AlgorithmInfo ret; + auto set = opr->execution_policy().algo; + if (set.valid()) { ret = set; } else { - ret = opr->get_algorithm_heuristic(std::forward(args)..., - std::numeric_limits::max(), - false); + ret = opr->get_algorithm_info_heuristic( + std::forward(args)..., std::numeric_limits::max(), + false); + } + return opr->get_algo_from_desc(ret.desc); +} + +/*! + * \brief get user-configured algorithm, or heuristic algorithm. used in opencl + * whose algo need to be constructed each time. 
+ */ +template +typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) { + typename Opr::AlgorithmInfo ret; + auto set = opr->execution_policy().algo; + if (set.valid()) { + return opr->algo_pack().construct_and_get_algo(set.desc); + } else { + ret = opr->get_algorithm_info_heuristic( + std::forward(args)..., std::numeric_limits::max(), + false); + return opr->get_algo_from_desc(ret.desc); } - return static_cast(ret); } /*! diff --git a/dnn/src/common/utils.h b/dnn/src/common/utils.h index ec04802f..7da91606 100644 --- a/dnn/src/common/utils.h +++ b/dnn/src/common/utils.h @@ -9,6 +9,32 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/** + * Boost Software License - Version 1.0 - August 17th, 2003 + * + * Permission is hereby granted, free of charge, to any person or organization + * obtaining a copy of the software and accompanying documentation covered by + * this license (the "Software") to use, reproduce, display, distribute, + * execute, and transmit the Software, and to prepare derivative works of the + * Software, and to permit third-parties to whom the Software is furnished to + * do so, all subject to the following: + * + * The copyright notices in the Software and this entire statement, including + * the above license grant, this restriction and the following disclaimer, + * must be included in all copies of the Software, in whole or in part, and + * all derivative works of the Software, unless such copies or derivative + * works are solely in the form of machine-executable object code generated by + * a source language processor. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. 
IN NO EVENT + * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE + * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, + * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + #pragma once #include "megdnn/arch.h" @@ -263,6 +289,13 @@ constexpr uint32_t operator"" _hash(char const* str, size_t count) { return XXHash64CT::hash(str, count, 20160701); } +// refer to https://www.boost.org/doc/libs/1_64_0/boost/functional/hash/hash.hpp +template +inline T hash_combine(T seed, T value) { + seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2); + return seed; +} + template std::string vec2str(Vec&& vec) { std::string res; diff --git a/dnn/src/cuda/batch_conv_bias/algo.cpp b/dnn/src/cuda/batch_conv_bias/algo.cpp index 705ff270..c829278b 100644 --- a/dnn/src/cuda/batch_conv_bias/algo.cpp +++ b/dnn/src/cuda/batch_conv_bias/algo.cpp @@ -18,8 +18,14 @@ using namespace cuda; BatchConvBiasForwardImpl::AlgoPack::AlgoPack() { all_algos.push_back(&int8_nchw4_gemm_dotprod); all_algos.push_back(&int8_nchw4_implicit_gemm_dotprod); + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } +MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchConvBiasForwardImpl) + BatchConvBiasForwardImpl::AlgoPack BatchConvBiasForwardImpl::sm_algo_pack; BatchConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( diff --git a/dnn/src/cuda/batch_conv_bias/algo.h b/dnn/src/cuda/batch_conv_bias/algo.h index e7748cab..1dade319 100644 --- a/dnn/src/cuda/batch_conv_bias/algo.h +++ b/dnn/src/cuda/batch_conv_bias/algo.h @@ -11,13 +11,16 @@ #pragma once -#include +#include #include "megdnn/oprs.h" #include "src/common/utils.h" #include "src/cuda/batch_conv_bias/opr_impl.h" #include "src/cuda/handle.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" + namespace megdnn { namespace cuda { @@ -26,6 +29,12 @@ protected: ~AlgoBase() = default; public: + enum class 
AlgoType : uint32_t { + CUDA_GEMM_NCHW4_DOTPROD_INT8, + CUDA_IMPLICIT_GEMM_PRECOMP_NCHW4_DOTPROD_INT8, + }; + using Mapper = std::unordered_map; + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { BatchConvBiasForwardImpl* opr; @@ -85,6 +94,7 @@ public: const char* name() const override { return "BATCH_CONV_BIAS_INT8_NCHW4_GEMM_DOTPROD"; } + MEGDNN_DECL_ALGO_TYPE(CUDA_GEMM_NCHW4_DOTPROD_INT8) }; class BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp final @@ -99,15 +109,16 @@ public: const char* name() const override { return "BATCH_CONV_BIAS_INT8_NCHW4_IMPLICIT_GEMM_PRECOMP_DOTPROD"; } + MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_PRECOMP_NCHW4_DOTPROD_INT8) private: WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; }; -class BatchConvBiasForwardImpl::AlgoPack { - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator=(const AlgoPack&) = delete; +class BatchConvBiasForwardImpl::AlgoPack : NonCopyableObj { +private: + AlgoBase::Mapper m_all_algos_map; public: AlgoPack(); @@ -116,6 +127,8 @@ public: AlgoInt8NCHW4DotProdImplicitGemmPrecomp int8_nchw4_implicit_gemm_dotprod; std::vector all_algos; + + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; } // namespace cuda diff --git a/dnn/src/cuda/batch_conv_bias/opr_impl.h b/dnn/src/cuda/batch_conv_bias/opr_impl.h index 4ad3faaa..276de278 100644 --- a/dnn/src/cuda/batch_conv_bias/opr_impl.h +++ b/dnn/src/cuda/batch_conv_bias/opr_impl.h @@ -26,6 +26,18 @@ public: const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst) override; + const char* get_algorithm_set_name() const override; + + class AlgoBase; + class AlgoInt8NCHW4DotProdGemm; + class AlgoInt8NCHW4DotProdImplicitGemmPrecomp; + + class AlgoPack; + + static const AlgoPack& algo_pack() { return sm_algo_pack; } + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + +protected: std::vector get_all_algorithms( const 
TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& z, @@ -37,15 +49,6 @@ public: const TensorLayout& dst, size_t workspace_limit_in_bytes, bool reproducible) override; - const char* get_algorithm_set_name() const override; - - class AlgoBase; - class AlgoInt8NCHW4DotProdGemm; - class AlgoInt8NCHW4DotProdImplicitGemmPrecomp; - - class AlgoPack; - - static const AlgoPack& algo_pack() { return sm_algo_pack; } private: static AlgoPack sm_algo_pack; diff --git a/dnn/src/cuda/batched_matrix_mul/algo.cpp b/dnn/src/cuda/batched_matrix_mul/algo.cpp index da8d396b..01b7001d 100644 --- a/dnn/src/cuda/batched_matrix_mul/algo.cpp +++ b/dnn/src/cuda/batched_matrix_mul/algo.cpp @@ -60,4 +60,12 @@ BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() { for (auto& algo : brute_force_algos) { all_algos.push_back(&algo); } + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } + +MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchedMatrixMulForwardImpl) + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/batched_matrix_mul/algo.h b/dnn/src/cuda/batched_matrix_mul/algo.h index b0b3bd8a..1329a546 100644 --- a/dnn/src/cuda/batched_matrix_mul/algo.h +++ b/dnn/src/cuda/batched_matrix_mul/algo.h @@ -16,6 +16,8 @@ #include "src/common/utils.h" #include "src/cuda/batched_matrix_mul/opr_impl.h" #include "src/cuda/matrix_mul/cublasLt_wrapper.h" +#include "src/common/metahelper.h" + #if CUDA_VERSION >= 10010 #include #endif @@ -28,6 +30,14 @@ protected: ~AlgoBase() = default; public: + enum class AlgoType : uint32_t { + CUDA_BRUTE_FORCE, + CUDA_CUBLAS, + CUDA_CUBLASLT, + CUDA_INT8X8X32, + }; + using Mapper = std::unordered_map; + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { BatchedMatrixMulForwardImpl* opr; @@ -90,6 +100,13 @@ public: void exec(const ExecArgs& args) const final; bool is_reproducible() const override { return true; } const char* name() const override { return 
m_name.c_str(); } + MEGDNN_DECL_ALGO_TYPE(CUDA_BRUTE_FORCE) + + std::string param() const override { + std::string ret; + serialize_write_pod(m_algorithm, ret); + return ret; + } }; class BatchedMatrixMulForwardImpl::AlgoCublas final : public BatchedMatrixMulForwardImpl::AlgoBase { @@ -100,6 +117,7 @@ public: void exec(const ExecArgs& args) const final; bool is_reproducible() const override { return true; } const char* name() const override { return "CUBLAS"; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) }; #if CUDA_VERSION >= 10010 class BatchedMatrixMulForwardImpl::AlgoCublasLt final : public AlgoBase { @@ -110,6 +128,7 @@ public: void exec(const ExecArgs& args) const final; bool is_reproducible() const override { return true; } const char* name() const override { return "CUBLAS_LT"; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) }; #endif class BatchedMatrixMulForwardImpl::AlgoInt8x8x32 final @@ -121,11 +140,13 @@ public: void exec(const ExecArgs& args) const final; bool is_reproducible() const override { return true; } const char* name() const override { return "INT8x8x32"; } + MEGDNN_DECL_ALGO_TYPE(CUDA_INT8X8X32) }; -class BatchedMatrixMulForwardImpl::AlgoPack { + +class BatchedMatrixMulForwardImpl::AlgoPack : NonCopyableObj { +private: + AlgoBase::Mapper m_all_algos_map; MatrixMulForwardImpl::AlgoPack mm_pack; - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator=(const AlgoPack&) = delete; public: AlgoPack(); @@ -137,6 +158,8 @@ public: AlgoInt8x8x32 int8x8x32; std::vector all_algos; std::vector brute_force_algos; + + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/batched_matrix_mul/brute_force.cpp b/dnn/src/cuda/batched_matrix_mul/brute_force.cpp index 0da6aa14..3c12ff73 100644 --- a/dnn/src/cuda/batched_matrix_mul/brute_force.cpp +++ b/dnn/src/cuda/batched_matrix_mul/brute_force.cpp @@ -24,7 +24,7 @@ bool BatchedMatrixMulForwardImpl::AlgoBruteForce::is_available( 
const SizeArgs& args) const { MatrixMulForwardImpl mm{args.opr->handle()}; mm.param() = {args.opr->param().transposeA, args.opr->param().transposeB}; - mm.execution_policy() = {m_algorithm}; + mm.execution_policy() = {m_algorithm->info()}; auto mm_layout_a = args.layout_a.remove_axis(0); auto mm_layout_b = args.layout_b.remove_axis(0); @@ -39,7 +39,7 @@ size_t BatchedMatrixMulForwardImpl::AlgoBruteForce::get_workspace_in_bytes( auto mm_opr = args.opr->handle()->create_operator(); mm_opr->param() = {args.opr->param().transposeA, args.opr->param().transposeB}; - mm_opr->execution_policy() = {m_algorithm}; + mm_opr->execution_policy() = {m_algorithm->info()}; return mm_opr->get_workspace_in_bytes(args.layout_a, args.layout_b, args.layout_c); @@ -50,7 +50,7 @@ void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec( auto&& mm_opr = args.opr->handle()->create_operator(); mm_opr->param() = {args.opr->param().transposeA, args.opr->param().transposeB}; - mm_opr->execution_policy() = {m_algorithm}; + mm_opr->execution_policy() = {m_algorithm->info()}; rep(n, N) { TensorND A_, B_, C_; auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) { diff --git a/dnn/src/cuda/batched_matrix_mul/opr_impl.h b/dnn/src/cuda/batched_matrix_mul/opr_impl.h index c38da62b..eafd0fc9 100644 --- a/dnn/src/cuda/batched_matrix_mul/opr_impl.h +++ b/dnn/src/cuda/batched_matrix_mul/opr_impl.h @@ -32,6 +32,16 @@ public: _megdnn_workspace workspace) override; size_t get_workspace_in_bytes(const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) override; + + const char* get_algorithm_set_name() const override { + return "BATCHED_MATMUL"; + } + + bool is_thread_safe() const override { return true; } + static const AlgoPack& algo_pack() { return sm_algo_pack; } + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + +protected: std::vector get_all_algorithms(const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) override; @@ -40,12 +50,6 @@ public: const 
TensorLayout& C, size_t workspace_limit_in_bytes, bool reproducible) override; - const char* get_algorithm_set_name() const override { - return "BATCHED_MATMUL"; - } - - bool is_thread_safe() const override { return true; } - static const AlgoPack& algo_pack() { return sm_algo_pack; } private: static AlgoPack sm_algo_pack; diff --git a/dnn/src/cuda/conv_bias/algo.cpp b/dnn/src/cuda/conv_bias/algo.cpp index 014bbde9..8614b790 100644 --- a/dnn/src/cuda/conv_bias/algo.cpp +++ b/dnn/src/cuda/conv_bias/algo.cpp @@ -100,10 +100,16 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() { for (size_t i = all_algo_size; i < all_algos.size(); ++i) { non_cudnn_algos.push_back(all_algos[i]); } + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } ConvBiasForwardImpl::AlgoPack ConvBiasForwardImpl::sm_algo_pack; +MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvBiasForwardImpl) + ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( ConvBiasForwardImpl* o, const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, @@ -172,43 +178,10 @@ std::string ConvBiasForwardImpl::AlgoBase::SizeArgs::to_string() const { } void ConvBiasForwardImpl::AlgoPack::fill_cudnn_algos() { -#define V1(v) #v -#define V(v) V1(v) - -#define DEF_ALGO(NAME, REPROD) \ - cudnn_conv_bias_activations.push_back( \ - {REPROD, \ - "CUDNN:ConvBiasActivation:" #NAME \ - "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL), \ - NAME}); \ - cudnn_convs.push_back( \ - {REPROD, \ - "CUDNN:Convolution:" #NAME \ - "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." 
V(CUDNN_PATCHLEVEL), \ - NAME}) - - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true); - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true); - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true); - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true); - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true); - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true); - -#if CUDNN_MAJOR >= 5 - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true); -#if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, true); -#endif -#endif - -#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) -#pragma message "not latest cudnn" -#endif - -#undef DEF_ALGO - -#undef V -#undef V1 + for (auto&& algo : CudnnAlgoPack::conv_fwd_algos()) { + cudnn_conv_bias_activations.push_back(algo.first); + cudnn_convs.push_back(algo.first); + } } #if CUDA_VERSION >= 10000 diff --git a/dnn/src/cuda/conv_bias/algo.h b/dnn/src/cuda/conv_bias/algo.h index 21b3c6e4..95580231 100644 --- a/dnn/src/cuda/conv_bias/algo.h +++ b/dnn/src/cuda/conv_bias/algo.h @@ -6,19 +6,23 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once #include "megdnn/oprs.h" +#include "src/common/algo_base.h" #include "src/common/utils.h" +#include "src/common/metahelper.h" #include "src/cuda/conv_bias/conv_bias_int8.cuh" #include "src/cuda/conv_bias/helper.h" #include "src/cuda/conv_bias/opr_impl.h" #include "src/cuda/convolution_helper/parameter.cuh" #include "src/cuda/handle.h" +#include "src/cuda/cudnn_wrapper.h" #include #include @@ -38,11 +42,39 @@ protected: ~AlgoBase() = default; public: + enum class AlgoType : uint32_t { + CUDA_CUDNN_CONVBIAS, + CUDA_CHANWISE, + CUDA_CHANWISE_SMALL, + CUDA_CHANWISE_INT8X8X32, + CUDA_CUDNN_CONV, + CUDA_INPLACE_MATMUL, + CUDA_MATMUL, + CUDA_MATMUL_INT8X8X32, + CUDA_1X1, + CUDA_BATCHED_MATMUL, + CUDA_GROUP_CONV_GENERAL, + CUDA_WMMA_UINT4X4X32, + CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8, + CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8, + CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8, + CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8, + CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8, + CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8, + CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8, + CUDA_BFLOAT16, + CUDA_IMPLICIT_GEMM_SASS_NCHW4_DOTPROD_INT8, + CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8, + CUDA_IMPLICIT_GEMM_SASS_NCHW32_IMMA_INT8, + CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW32_IMMA_INT8, + }; + using Mapper = std::unordered_map; + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs : public conv_bias::BiasForwardSizeArgs { ConvBiasForwardImpl* opr; const PreprocessedFilter* preprocessed_filter; - + std::string to_string() const; SizeArgs(ConvBiasForwardImpl* opr, const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, @@ -80,13 +112,17 @@ public: virtual void exec(const ExecArgs& args) const = 0; virtual size_t get_preprocess_workspace_in_bytes( const SizeArgs& args) const { + MEGDNN_MARK_USED_VAR(args); return 0; } virtual SmallVector deduce_preprocessed_filter_layout( const SizeArgs& args) const { + MEGDNN_MARK_USED_VAR(args); return {}; } - 
virtual void exec_preprocess(const ExecArgs& args) const {} + virtual void exec_preprocess(const ExecArgs& args) const { + MEGDNN_MARK_USED_VAR(args); + } bool is_available_wk(const SizeArgs& args, size_t limit) { return is_available(args) && get_workspace_in_bytes(args) <= limit; @@ -114,11 +150,14 @@ public: class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final : public AlgoBase { public: - AlgoCUDNNConvBiasActivation(bool is_reproducible, const char* name, - cudnnConvolutionFwdAlgo_t cudnn_enum) - : m_is_reproducible(is_reproducible), - m_name(ConvBiasForward::algo_name(name, {})), - m_cudnn_enum(cudnn_enum) {} + AlgoCUDNNConvBiasActivation(cudnnConvolutionFwdAlgo_t cudnn_enum) + : m_cudnn_enum(cudnn_enum) { + megdnn_assert(CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) != + CudnnAlgoPack::conv_fwd_algos().end()); + m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum); + m_name = ConvBiasForward::algo_name( + "CUDNN:ConvBiasActivation:" + m_attr.name, {}); + } size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; @@ -127,16 +166,24 @@ public: const char* name() const override { return m_name.c_str(); } - bool is_reproducible() const override { return m_is_reproducible; } + bool is_reproducible() const override { return m_attr.is_reproducible; } cudnnConvolutionFwdAlgo_t cudnn_enum() { return m_cudnn_enum; } bool is_cudnn() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONVBIAS) + + std::string param() const override { + std::string ret; + serialize_write_pod(m_cudnn_enum, ret); + return ret; + } + private: - bool m_is_reproducible; std::string m_name; cudnnConvolutionFwdAlgo_t m_cudnn_enum; + CudnnAlgoPack::Attr m_attr; }; class ConvBiasForwardImpl::AlgoChanwise final : public AlgoBase { @@ -154,6 +201,8 @@ public: } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) + private: mutable std::string m_name; }; @@ -172,6 +221,7 
@@ public: return m_name.c_str(); } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) private: mutable std::string m_name; @@ -190,6 +240,7 @@ public: return m_name.c_str(); } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_INT8X8X32) private: mutable std::string m_name; @@ -197,27 +248,39 @@ private: class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoBase { public: - AlgoCUDNNConv(bool is_reproducible, const char* name, - cudnnConvolutionFwdAlgo_t cudnn_enum) - : m_is_reproducible(is_reproducible), - m_name(ConvBiasForward::algo_name(name, {})), - m_cudnn_enum(cudnn_enum) {} + AlgoCUDNNConv(cudnnConvolutionFwdAlgo_t cudnn_enum) + : m_cudnn_enum(cudnn_enum) { + megdnn_assert(CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) != + CudnnAlgoPack::conv_fwd_algos().end()); + m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum); + m_name = ConvBiasForward::algo_name( + "CUDNN:Convolution:" + m_attr.name, {}); + } bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; - bool is_reproducible() const override { return m_is_reproducible; } + bool is_reproducible() const override { return m_attr.is_reproducible; } const char* name() const override { return m_name.c_str(); } cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; } bool is_cudnn() const override { return true; } + + MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONV) + + std::string param() const override { + std::string ret; + serialize_write_pod(m_cudnn_enum, ret); + return ret; + } + private: - bool m_is_reproducible; std::string m_name; cudnnConvolutionFwdAlgo_t m_cudnn_enum; + CudnnAlgoPack::Attr m_attr; WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; }; @@ -237,6 +300,7 @@ public: return m_name.c_str(); } bool is_reproducible() const override { return true; 
} + MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) private: mutable std::string m_name; @@ -261,6 +325,7 @@ public: return m_name.c_str(); } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) private: WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; @@ -281,6 +346,7 @@ public: return m_name.c_str(); } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL_INT8X8X32) private: bool need_src_unroll(const SizeArgs& args) const; @@ -310,6 +376,7 @@ public: return m_name.c_str(); } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_1X1) private: WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; @@ -333,6 +400,7 @@ public: return m_name.c_str(); } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) private: WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; @@ -354,6 +422,13 @@ public: static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, TensorLayout& dst_pg, TensorLayout& bias_pg); + MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) + + std::string param() const override { + std::string ret; + serialize_write_pod(m_impl, ret); + return ret; + } private: WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; @@ -370,10 +445,13 @@ public: void exec(const ExecArgs& args) const override; const char* name() const override { return "QUINT4x4x32_WMMA"; } bool is_reproducible() const override { return true; } + private: - WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; + WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, + const SizeArgs& args) const; bool use_kernel_fhxfw(const SizeArgs& args) const; size_t get_workspace_in_bytes_do_conv(const SizeArgs& args) const; + MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32) }; #endif @@ -395,6 +473,7 @@ public: const 
convolution::ConvParam& param, float alpha, float beta, float gamma, float scale, cudaStream_t stream, param::ConvBias::NonlineMode nonlinear_mode); + MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8) }; class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final @@ -415,8 +494,9 @@ public: warp_k == 32 && stage == 2) { return ""; } - return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n, - threadblock_k, warp_m, warp_n, warp_k, stage); + return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, + threadblock_n, threadblock_k, warp_m, warp_n, + warp_k, stage); } }; AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param) @@ -433,6 +513,13 @@ public: SmallVector deduce_preprocessed_filter_layout( const SizeArgs& args) const override; void exec_preprocess(const ExecArgs& args) const override; + MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8) + + std::string param() const override { + std::string ret; + serialize_write_pod(m_algo_param, ret); + return ret; + } private: WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, @@ -457,9 +544,7 @@ public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; - const char* name() const override { - return m_name.c_str(); - } + const char* name() const override { return m_name.c_str(); } bool is_reproducible() const override { return true; } template static void dispatch_nonlinear_mode( @@ -471,6 +556,14 @@ public: MMATileSize mma_tile_size); static std::string to_string(MMATileSize mma_tile_size); + MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8) + + std::string param() const override { + std::string ret; + serialize_write_pod(m_mma_tile_size, ret); + return ret; + } + private: MMATileSize m_mma_tile_size; std::string m_name; @@ -488,10 +581,16 @@ public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const 
SizeArgs& args) const override; void exec(const ExecArgs& args) const override; - const char* name() const override { - return m_name.c_str(); - } + const char* name() const override { return m_name.c_str(); } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8) + + std::string param() const override { + std::string ret; + serialize_write_pod(m_mma_tile_size, ret); + return ret; + } + private: WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; @@ -513,6 +612,13 @@ public: void exec(const ExecArgs& args) const override; const char* name() const override { return m_name.c_str(); } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8) + + std::string param() const override { + std::string ret; + serialize_write_pod(m_mma_tile_size, ret); + return ret; + } private: MMATileSize m_mma_tile_size; @@ -533,6 +639,13 @@ public: void exec(const ExecArgs& args) const override; const char* name() const override { return m_name.c_str(); } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8) + + std::string param() const override { + std::string ret; + serialize_write_pod(m_mma_tile_size, ret); + return ret; + } private: MMATileSize m_mma_tile_size; @@ -570,6 +683,13 @@ public: SmallVector deduce_preprocessed_filter_layout( const SizeArgs& args) const override; void exec_preprocess(const ExecArgs& args) const override; + MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8) + + std::string param() const override { + std::string ret; + serialize_write_pod(m_algo_param, ret); + return ret; + } private: WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, @@ -592,6 +712,14 @@ public: bool is_reproducible() const override { return m_impl->is_reproducible(); } + MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) + + std::string param() const override 
{ + std::string ret; + serialize_write_pod(m_impl, ret); + return ret; + } + private: SizeArgs float_args(const SizeArgs& args, ConvBiasForwardImpl* opr, TensorLayout& fsrc, TensorLayout& ffilter, @@ -603,17 +731,16 @@ private: }; -class ConvBiasForwardImpl::AlgoPack { - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator=(const AlgoPack&) = delete; +class ConvBiasForwardImpl::AlgoPack : NonCopyableObj { +private: + AlgoBase::Mapper m_all_algos_map; public: AlgoPack(); std::vector all_algos, //! non-cudnn algos, used for heuristic if cudnn is not supported - non_cudnn_algos, - bfloat16_algos; + non_cudnn_algos, bfloat16_algos; std::vector cudnn_conv_bias_activations; std::vector cudnn_convs; AlgoChanwise chanwise; @@ -646,6 +773,8 @@ public: AlgoBase* cudnn_conv_from_enum(cudnnConvolutionFwdAlgo_t algo); + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } + private: #if CUDA_VERSION >= 10000 void fill_imma_algos(); diff --git a/dnn/src/cuda/conv_bias/bfloat16.cpp b/dnn/src/cuda/conv_bias/bfloat16.cpp index 283ffd8f..68f93659 100644 --- a/dnn/src/cuda/conv_bias/bfloat16.cpp +++ b/dnn/src/cuda/conv_bias/bfloat16.cpp @@ -47,7 +47,7 @@ ConvBiasForwardImpl::AlgoBFloat16::float_args( change_dtype(fdst); opr->param() = args.opr->param(); opr->param().compute_mode = Param::ComputeMode::DEFAULT; - opr->execution_policy() = {m_impl}; + opr->execution_policy() = {m_impl->info()}; return SizeArgs(opr, fsrc, ffilter, fbias, fz, fdst); } @@ -110,7 +110,7 @@ void ConvBiasForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const { auto convbias_opr = args.handle->create_operator(); convbias_opr->param() = args.opr->param(); convbias_opr->param().compute_mode = Param::ComputeMode::DEFAULT; - convbias_opr->execution_policy() = {m_impl}; + convbias_opr->execution_policy() = {m_impl->info()}; convbias_opr->exec(fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor, fdst_tensor, nullptr, cvter.workspace()); } diff --git 
a/dnn/src/cuda/conv_bias/opr_impl.cpp b/dnn/src/cuda/conv_bias/opr_impl.cpp index 6bccb117..ad793a75 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.cpp +++ b/dnn/src/cuda/conv_bias/opr_impl.cpp @@ -63,12 +63,12 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( auto conv_args = args; auto cudnn_conv_bias_act_from_enum_wrapper = - [this](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* { + [](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* { return sm_algo_pack.cudnn_conv_bias_act_from_enum(algo); }; auto cudnn_conv_from_enum_wrapper = - [this](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* { + [](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* { return sm_algo_pack.cudnn_conv_from_enum(algo); }; diff --git a/dnn/src/cuda/conv_bias/opr_impl.h b/dnn/src/cuda/conv_bias/opr_impl.h index 09c87973..4ba78f9c 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.h +++ b/dnn/src/cuda/conv_bias/opr_impl.h @@ -24,17 +24,6 @@ public: _megdnn_tensor_out dst, const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) override; - std::vector get_all_algorithms( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& bias, - const TensorLayout& z, - const TensorLayout& dst, - size_t workspace_limit_in_bytes, - bool reproducible) override; size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, @@ -80,6 +69,20 @@ public: static const AlgoPack& algo_pack() { return sm_algo_pack; } + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + + std::vector get_all_algorithms( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& z, + const TensorLayout& dst) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + 
const TensorLayout& filter, + const TensorLayout& bias, + const TensorLayout& z, + const TensorLayout& dst, + size_t workspace_limit_in_bytes, + bool reproducible) override; + private: static AlgoPack sm_algo_pack; }; diff --git a/dnn/src/cuda/convolution/backward_data/algo.cpp b/dnn/src/cuda/convolution/backward_data/algo.cpp index b888e947..cba028a4 100644 --- a/dnn/src/cuda/convolution/backward_data/algo.cpp +++ b/dnn/src/cuda/convolution/backward_data/algo.cpp @@ -52,8 +52,14 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() { all_algos.push_back(bfloat16_refhold.back().get()); bfloat16_algos.push_back(bfloat16_refhold.back().get()); } + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } +MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl) + ConvolutionBackwardDataImpl::AlgoCUDNN* ConvolutionBackwardDataImpl::AlgoPack::cudnn_from_enum( cudnnConvolutionBwdDataAlgo_t algo) { diff --git a/dnn/src/cuda/convolution/backward_data/algo.h b/dnn/src/cuda/convolution/backward_data/algo.h index eaa6038d..380fb783 100644 --- a/dnn/src/cuda/convolution/backward_data/algo.h +++ b/dnn/src/cuda/convolution/backward_data/algo.h @@ -11,8 +11,11 @@ #pragma once -#include "src/cuda/convolution/helper.h" #include +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" +#include "src/cuda/convolution/helper.h" +#include "src/cuda/cudnn_wrapper.h" namespace megdnn { namespace cuda { @@ -23,154 +26,146 @@ namespace cuda { * All the algo impls should try to support non-contiguous batch dim, for group * conv execution. 
*/ -class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm { - protected: - ~AlgoBase() = default; - - public: - AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } - struct SizeArgs { - HandleImpl *handle; - CanonizedFilterMeta filter_meta; - const TensorLayout *diff_layout, *grad_layout, *filter_layout; - ConvolutionBackwardDataImpl *opr; - - std::string to_string() const; - void init_desc(convolution::CUDNNBwdDataDescs &desc) const { - desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); - } - SizeArgs(ConvolutionBackwardDataImpl* opr, - const TensorLayout& filter, const TensorLayout& diff, - const TensorLayout& grad); - SizeArgs(ConvolutionBackwardDataImpl* opr, - const TensorLayout& filter, - const CanonizedFilterMeta& filter_meta, - const TensorLayout& diff, const TensorLayout& grad); - - convolution::ForwardSizeArgs as_fwd_args() const { - return {handle, grad_layout, filter_layout, filter_meta, - diff_layout}; - } - }; - struct ExecArgs: public SizeArgs { - const TensorND *filter_tensor, *diff_tensor, *grad_tensor; - Workspace workspace; - - ExecArgs(ConvolutionBackwardDataImpl *opr, - _megdnn_tensor_in filter, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace); - }; - virtual bool is_available(const SizeArgs &args) const = 0; - virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; - virtual void exec(const ExecArgs &args) const = 0; - - bool is_available_wk(const SizeArgs &args, size_t limit) { - return is_available(args) && get_workspace_in_bytes(args) <= limit; - } +class ConvolutionBackwardDataImpl::AlgoBase : public Algorithm { +protected: + ~AlgoBase() = default; - bool is_available_reproducible( - const SizeArgs& args, bool reproducible = true, - size_t limit = std::numeric_limits::max()) { - return (!reproducible || is_reproducible()) && - is_available_wk(args, limit); - } - - AlgoBase& check_workspace( - const SizeArgs &args, const Workspace &workspace) { - auto 
req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "conv bwd data algo %s: " - "required workspace %zu bytes, got %zu", - name(), req, workspace.size); - return *this; +public: + enum class AlgoType : uint32_t { + CUDA_CUDNN, + CUDA_MATMUL, + CUDA_CHANWISE, + CUDA_CHANWISE_SMALL, + CUDA_BFLOAT16, + CUDA_GROUP_CONV_GENERAL, + }; + using Mapper = std::unordered_map; + + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } + struct SizeArgs { + HandleImpl* handle; + CanonizedFilterMeta filter_meta; + const TensorLayout *diff_layout, *grad_layout, *filter_layout; + ConvolutionBackwardDataImpl* opr; + + std::string to_string() const; + void init_desc(convolution::CUDNNBwdDataDescs& desc) const { + desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); } - - virtual bool is_cudnn() const { - return false; + SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter, + const TensorLayout& diff, const TensorLayout& grad); + SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter, + const CanonizedFilterMeta& filter_meta, + const TensorLayout& diff, const TensorLayout& grad); + + convolution::ForwardSizeArgs as_fwd_args() const { + return {handle, grad_layout, filter_layout, filter_meta, + diff_layout}; } + }; + struct ExecArgs : public SizeArgs { + const TensorND *filter_tensor, *diff_tensor, *grad_tensor; + Workspace workspace; + + ExecArgs(ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace); + }; + virtual bool is_available(const SizeArgs& args) const = 0; + virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; + virtual void exec(const ExecArgs& args) const = 0; + + bool is_available_wk(const SizeArgs& args, size_t limit) { + return is_available(args) && get_workspace_in_bytes(args) <= limit; + } + + bool is_available_reproducible( + const SizeArgs& args, bool reproducible = true, + 
size_t limit = std::numeric_limits::max()) { + return (!reproducible || is_reproducible()) && + is_available_wk(args, limit); + } + + AlgoBase& check_workspace(const SizeArgs& args, + const Workspace& workspace) { + auto req = get_workspace_in_bytes(args); + megdnn_assert(req <= workspace.size, + "conv bwd data algo %s: " + "required workspace %zu bytes, got %zu", + name(), req, workspace.size); + return *this; + } + + virtual bool is_cudnn() const { return false; } }; class ConvolutionBackwardDataImpl::AlgoCUDNN final : public AlgoBase { - bool m_is_reproducible; - const char *m_name; cudnnConvolutionBwdDataAlgo_t m_cudnn_enum; + CudnnAlgoPack::Attr m_attr; - public: +public: + AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum) + : m_cudnn_enum(cudnn_enum) { + megdnn_assert(CudnnAlgoPack::conv_bwd_data_algos().find(cudnn_enum) != + CudnnAlgoPack::conv_bwd_data_algos().end()); + m_attr = CudnnAlgoPack::conv_bwd_data_algos().at(cudnn_enum); + } - AlgoCUDNN(bool is_reproducible, const char *name, - cudnnConvolutionBwdDataAlgo_t cudnn_enum): - m_is_reproducible(is_reproducible), - m_name(name), - m_cudnn_enum(cudnn_enum) - {} + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; + bool is_reproducible() const override { return m_attr.is_reproducible; } - bool is_reproducible() const override { - return m_is_reproducible; - } + const char* name() const override { return m_attr.name.c_str(); } - const char* name() const override { - return m_name; - } + cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; } - cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { - return m_cudnn_enum; - } + bool is_cudnn() const override { return true; } + 
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) - bool is_cudnn() const override { - return true; - } + std::string param() const override { + std::string ret; + serialize_write_pod(m_cudnn_enum, ret); + return ret; + } }; //! im2col and matmul, with dilation -class ConvolutionBackwardDataImpl::AlgoMatmul final: public AlgoBase { - template - static void exec_internal(const ExecArgs &args); +class ConvolutionBackwardDataImpl::AlgoMatmul final : public AlgoBase { + template + static void exec_internal(const ExecArgs& args); - public: - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; - const char* name() const override { - return "MATMUL"; - } - bool is_reproducible() const override { - return true; - } + const char* name() const override { return "MATMUL"; } + bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) }; -class ConvolutionBackwardDataImpl::AlgoChanwise final: public AlgoBase { - public: - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; +class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase { +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; - const char* name() const override { - return "CHANNEL_WISE"; - } - bool is_reproducible() const override { - return true; - } + const char* name() const override { return "CHANNEL_WISE"; } + bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) }; -class 
ConvolutionBackwardDataImpl::AlgoChanwiseSmall final: public AlgoBase { - public: - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; +class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final : public AlgoBase { +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; - const char* name() const override { - return "CHANNEL_WISE_SMALL"; - } - bool is_reproducible() const override { - return true; - } + const char* name() const override { return "CHANNEL_WISE_SMALL"; } + bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) }; class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase { @@ -190,61 +185,72 @@ private: TensorLayout& fsrc, TensorLayout& ffilter, TensorLayout& fdst) const; WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; + MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) + + std::string param() const override { + std::string ret; + serialize_write_pod(m_algorithm, ret); + return ret; + } }; //! 
implement group conv by another algo -class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase { - AlgoBase *m_impl; +class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final + : public AlgoBase { + AlgoBase* m_impl; std::string m_name; - public: - AlgoGroupConvGeneral(AlgoBase *impl); +public: + AlgoGroupConvGeneral(AlgoBase* impl); - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; - const char* name() const override { - return m_name.c_str(); - } + const char* name() const override { return m_name.c_str(); } - bool is_reproducible() const override { - return m_impl->is_reproducible(); - } + bool is_reproducible() const override { return m_impl->is_reproducible(); } + + static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg, + TensorLayout& grad_pg); + MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) - static void modify_size_args(SizeArgs &args, - TensorLayout &diff_pg, TensorLayout &grad_pg); + std::string param() const override { + std::string ret; + serialize_write_pod(m_impl, ret); + return ret; + } }; -class ConvolutionBackwardDataImpl::AlgoPack { +class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { // defined in cudnn.cpp void fill_cudnn_algos(); - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator = (const AlgoPack &) = delete; + AlgoBase::Mapper m_all_algos_map; - public: - AlgoPack(); +public: + AlgoPack(); - std::vector cudnn; - AlgoMatmul matmul; - AlgoChanwise chanwise; - AlgoChanwiseSmall chanwise_small; - std::vector gconv; - std::unordered_map algo2gconv; - std::vector> bfloat16_refhold; + std::vector cudnn; + AlgoMatmul matmul; + AlgoChanwise chanwise; + AlgoChanwiseSmall chanwise_small; + 
std::vector gconv; + std::unordered_map algo2gconv; + std::vector> bfloat16_refhold; - std::vector + std::vector //! all algorithms all_algos, //! non-cudnn algos, used for heuristic if cudnn is not supported - non_cudnn_algos, - bfloat16_algos; + non_cudnn_algos, bfloat16_algos; + + AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); - AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution/backward_data/bfloat16.cpp b/dnn/src/cuda/convolution/backward_data/bfloat16.cpp index 2c33ae83..d1388ac3 100644 --- a/dnn/src/cuda/convolution/backward_data/bfloat16.cpp +++ b/dnn/src/cuda/convolution/backward_data/bfloat16.cpp @@ -42,7 +42,7 @@ ConvolutionBackwardDataImpl::AlgoBFloat16::float_args( change_dtype(fgrad); opr->param() = args.opr->param(); opr->param().compute_mode = Param::ComputeMode::DEFAULT; - opr->execution_policy() = {m_algorithm}; + opr->execution_policy() = {m_algorithm->info()}; return SizeArgs(opr, ffilter, fdiff, fgrad); } @@ -105,7 +105,7 @@ void ConvolutionBackwardDataImpl::AlgoBFloat16::exec( args.handle->create_operator(); conv_back_data_opr->param() = args.opr->param(); conv_back_data_opr->param().compute_mode = Param::ComputeMode::DEFAULT; - conv_back_data_opr->execution_policy() = {m_algorithm}; + conv_back_data_opr->execution_policy() = {m_algorithm->info()}; conv_back_data_opr->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor, cvter.workspace()); } diff --git a/dnn/src/cuda/convolution/backward_data/cudnn.cpp b/dnn/src/cuda/convolution/backward_data/cudnn.cpp index c70c1ca3..3b046972 100644 --- a/dnn/src/cuda/convolution/backward_data/cudnn.cpp +++ b/dnn/src/cuda/convolution/backward_data/cudnn.cpp @@ -98,35 +98,9 @@ void ConvolutionBackwardDataImpl::AlgoCUDNN::exec( } void 
ConvolutionBackwardDataImpl::AlgoPack::fill_cudnn_algos() { -#define V1(v) #v -#define V(v) V1(v) - -#define DEF_ALGO(NAME, REPROD) \ - cudnn.push_back({ \ - REPROD, #NAME \ - "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \ - "." V(CUDNN_PATCHLEVEL), \ - NAME}) - - DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false); - DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true); - DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true); - DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true); -#if CUDNN_MAJOR >= 5 - DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true); -#if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 - DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED, true); -#endif -#endif - -#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) -#pragma message "not latest cudnn" -#endif - -#undef DEF_ALGO - -#undef V -#undef V1 + for (auto&& algo : CudnnAlgoPack::conv_bwd_data_algos()) { + cudnn.push_back(algo.first); + } } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution/backward_filter/algo.cpp b/dnn/src/cuda/convolution/backward_filter/algo.cpp index 601663af..d7c9c4bc 100644 --- a/dnn/src/cuda/convolution/backward_filter/algo.cpp +++ b/dnn/src/cuda/convolution/backward_filter/algo.cpp @@ -49,8 +49,14 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() { all_algos.push_back(bfloat16_refhold.back().get()); bfloat16_algos.push_back(bfloat16_refhold.back().get()); } + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } +MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardFilterImpl) + ConvolutionBackwardFilterImpl::AlgoCUDNN* ConvolutionBackwardFilterImpl::AlgoPack::cudnn_from_enum( cudnnConvolutionBwdFilterAlgo_t algo) { diff --git a/dnn/src/cuda/convolution/backward_filter/algo.h b/dnn/src/cuda/convolution/backward_filter/algo.h index d54f3121..2c8b3563 100644 --- a/dnn/src/cuda/convolution/backward_filter/algo.h +++ b/dnn/src/cuda/convolution/backward_filter/algo.h @@ -6,13 +6,16 @@ * * Unless required by applicable law or 
agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once -#include "src/cuda/convolution/helper.h" #include +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" +#include "src/cuda/convolution/helper.h" namespace megdnn { namespace cuda { @@ -23,141 +26,134 @@ namespace cuda { * All the algo impls should try to support non-contiguous batch dim, for group * conv execution. */ -class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm { - protected: - ~AlgoBase() = default; - - public: - AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } - struct SizeArgs { - HandleImpl *handle; - const TensorLayout *src_layout, *diff_layout, *grad_layout; - CanonizedFilterMeta grad_filter_meta; - ConvolutionBackwardFilterImpl *opr; - - std::string to_string() const; - void init_desc(convolution::CUDNNBwdFilterDescs &desc) const { - desc.set(*src_layout, *diff_layout, grad_filter_meta, - opr->param()); - } - SizeArgs(ConvolutionBackwardFilterImpl *opr, - const TensorLayout &src, const TensorLayout &diff, - const TensorLayout &grad); - SizeArgs(ConvolutionBackwardFilterImpl* opr, - const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad, - const CanonizedFilterMeta& grad_meta); - - convolution::ForwardSizeArgs as_fwd_args() const { - return {handle, src_layout, grad_layout, grad_filter_meta, - diff_layout}; - } - }; - struct ExecArgs: public SizeArgs { - const TensorND *src_tensor, *diff_tensor, *grad_tensor; - Workspace workspace; - - ExecArgs(ConvolutionBackwardFilterImpl *opr, - _megdnn_tensor_in src, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace); - }; - virtual bool is_available(const SizeArgs &args) const = 0; - virtual size_t 
get_workspace_in_bytes(const SizeArgs &args) const = 0; - virtual void exec(const ExecArgs &args) const = 0; - - bool is_available_wk(const SizeArgs &args, size_t limit) { - return is_available(args) && get_workspace_in_bytes(args) <= limit; - } +class ConvolutionBackwardFilterImpl::AlgoBase : public Algorithm { +protected: + ~AlgoBase() = default; - bool is_available_reproducible( - const SizeArgs& args, bool reproducible = true, - size_t limit = std::numeric_limits::max()) { - return (!reproducible || is_reproducible()) && - is_available_wk(args, limit); +public: + enum class AlgoType : uint32_t { + CUDA_CUDNN, + CUDA_MATMUL, + CUDA_CHANWISE, + CUDA_BFLOAT16, + CUDA_GROUP_CONV_GENERAL, + }; + using Mapper = std::unordered_map; + + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } + struct SizeArgs { + HandleImpl* handle; + const TensorLayout *src_layout, *diff_layout, *grad_layout; + CanonizedFilterMeta grad_filter_meta; + ConvolutionBackwardFilterImpl* opr; + + std::string to_string() const; + void init_desc(convolution::CUDNNBwdFilterDescs& desc) const { + desc.set(*src_layout, *diff_layout, grad_filter_meta, opr->param()); } - - AlgoBase& check_workspace( - const SizeArgs &args, const Workspace &workspace) { - auto req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "conv bwd filter algo %s: " - "required workspace %zu bytes, got %zu", - name(), req, workspace.size); - return *this; - } - - virtual bool is_cudnn() const { - return false; + SizeArgs(ConvolutionBackwardFilterImpl* opr, const TensorLayout& src, + const TensorLayout& diff, const TensorLayout& grad); + SizeArgs(ConvolutionBackwardFilterImpl* opr, const TensorLayout& src, + const TensorLayout& diff, const TensorLayout& grad, + const CanonizedFilterMeta& grad_meta); + + convolution::ForwardSizeArgs as_fwd_args() const { + return {handle, src_layout, grad_layout, grad_filter_meta, + diff_layout}; } + }; + struct ExecArgs : public SizeArgs { + const 
TensorND *src_tensor, *diff_tensor, *grad_tensor; + Workspace workspace; + + ExecArgs(ConvolutionBackwardFilterImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace); + }; + virtual bool is_available(const SizeArgs& args) const = 0; + virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; + virtual void exec(const ExecArgs& args) const = 0; + + bool is_available_wk(const SizeArgs& args, size_t limit) { + return is_available(args) && get_workspace_in_bytes(args) <= limit; + } + + bool is_available_reproducible( + const SizeArgs& args, bool reproducible = true, + size_t limit = std::numeric_limits::max()) { + return (!reproducible || is_reproducible()) && + is_available_wk(args, limit); + } + + AlgoBase& check_workspace(const SizeArgs& args, + const Workspace& workspace) { + auto req = get_workspace_in_bytes(args); + megdnn_assert(req <= workspace.size, + "conv bwd filter algo %s: " + "required workspace %zu bytes, got %zu", + name(), req, workspace.size); + return *this; + } + + virtual bool is_cudnn() const { return false; } }; class ConvolutionBackwardFilterImpl::AlgoCUDNN final : public AlgoBase { - bool m_is_reproducible; - const char *m_name; cudnnConvolutionBwdFilterAlgo_t m_cudnn_enum; + CudnnAlgoPack::Attr m_attr; - public: +public: + AlgoCUDNN(cudnnConvolutionBwdFilterAlgo_t cudnn_enum) + : m_cudnn_enum(cudnn_enum) { + megdnn_assert(CudnnAlgoPack::conv_bwd_flt_algos().find(cudnn_enum) != + CudnnAlgoPack::conv_bwd_flt_algos().end()); + m_attr = CudnnAlgoPack::conv_bwd_flt_algos().at(cudnn_enum); + } - AlgoCUDNN(bool is_reproducible, const char *name, - cudnnConvolutionBwdFilterAlgo_t cudnn_enum): - m_is_reproducible(is_reproducible), - m_name(name), - m_cudnn_enum(cudnn_enum) - {} + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; - bool is_available(const 
SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; + bool is_reproducible() const override { return m_attr.is_reproducible; } - bool is_reproducible() const override { - return m_is_reproducible; - } + const char* name() const override { return m_attr.name.c_str(); } - const char* name() const override { - return m_name; - } + cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { return m_cudnn_enum; } - cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { - return m_cudnn_enum; - } + bool is_cudnn() const override { return true; } - bool is_cudnn() const override { - return true; - } + MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) + std::string param() const override { + std::string ret; + serialize_write_pod(m_cudnn_enum, ret); + return ret; + } }; //! im2col and matmul, with dilation -class ConvolutionBackwardFilterImpl::AlgoMatmul final: public AlgoBase { - template - static void exec_internal(const ExecArgs &args); +class ConvolutionBackwardFilterImpl::AlgoMatmul final : public AlgoBase { + template + static void exec_internal(const ExecArgs& args); - public: - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; - const char* name() const override { - return "MATMUL"; - } - bool is_reproducible() const override { - return true; - } + const char* name() const override { return "MATMUL"; } + bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) }; -class ConvolutionBackwardFilterImpl::AlgoChanwise final: public AlgoBase { - public: - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const 
SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; +class ConvolutionBackwardFilterImpl::AlgoChanwise final : public AlgoBase { +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; - const char* name() const override { - return "CHANNEL_WISE"; - } - bool is_reproducible() const override { - return true; - } + const char* name() const override { return "CHANNEL_WISE"; } + bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) }; class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase { @@ -169,6 +165,13 @@ public: const char* name() const override { return m_name.c_str(); } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) + + std::string param() const override { + std::string ret; + serialize_write_pod(m_algorithm, ret); + return ret; + } private: std::string m_name; @@ -180,57 +183,62 @@ private: }; //! 
implement group conv by another algo -class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase { - AlgoBase *m_impl; +class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final + : public AlgoBase { + AlgoBase* m_impl; std::string m_name; - public: - AlgoGroupConvGeneral(AlgoBase *impl); +public: + AlgoGroupConvGeneral(AlgoBase* impl); + + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; + + const char* name() const override { return m_name.c_str(); } - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; + bool is_reproducible() const override { return m_impl->is_reproducible(); } - const char* name() const override { - return m_name.c_str(); - } + static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, + TensorLayout& diff_pg); - bool is_reproducible() const override { - return m_impl->is_reproducible(); - } + MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) - static void modify_size_args(SizeArgs &args, - TensorLayout &src_pg, TensorLayout &diff_pg); + std::string param() const override { + std::string ret; + serialize_write_pod(m_impl, ret); + return ret; + } }; -class ConvolutionBackwardFilterImpl::AlgoPack { +class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj { // defined in cudnn.cpp void fill_cudnn_algos(); - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator = (const AlgoPack &) = delete; + AlgoBase::Mapper m_all_algos_map; - public: - AlgoPack(); +public: + AlgoPack(); - std::vector cudnn; - AlgoMatmul matmul; - AlgoChanwise chanwise; - std::vector gconv; - std::unordered_map algo2gconv; - std::vector> bfloat16_refhold; + std::vector cudnn; + AlgoMatmul matmul; + AlgoChanwise chanwise; + std::vector gconv; + std::unordered_map algo2gconv; + std::vector> 
bfloat16_refhold; - std::vector + std::vector //! all algorithms all_algos, //! non-cudnn algos, used for heuristic if cudnn is not supported - non_cudnn_algos, - bfloat16_algos; + non_cudnn_algos, bfloat16_algos; + + AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); - AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution/backward_filter/bfloat16.cpp b/dnn/src/cuda/convolution/backward_filter/bfloat16.cpp index 21c98745..efe98b25 100644 --- a/dnn/src/cuda/convolution/backward_filter/bfloat16.cpp +++ b/dnn/src/cuda/convolution/backward_filter/bfloat16.cpp @@ -42,7 +42,7 @@ ConvolutionBackwardFilterImpl::AlgoBFloat16::float_args( change_dtype(fgrad); opr->param() = args.opr->param(); opr->param().compute_mode = Param::ComputeMode::DEFAULT; - opr->execution_policy() = {m_algorithm}; + opr->execution_policy() = {m_algorithm->info()}; return SizeArgs(opr, fsrc, fdiff, fgrad); } @@ -107,7 +107,7 @@ void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec( conv_back_filter_opr->param() = args.opr->param(); conv_back_filter_opr->param().compute_mode = Param::ComputeMode::DEFAULT; - conv_back_filter_opr->execution_policy() = {m_algorithm}; + conv_back_filter_opr->execution_policy() = {m_algorithm->info()}; conv_back_filter_opr->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor, cvter.workspace()); } diff --git a/dnn/src/cuda/convolution/backward_filter/cudnn.cpp b/dnn/src/cuda/convolution/backward_filter/cudnn.cpp index 17b31934..30d38dd5 100644 --- a/dnn/src/cuda/convolution/backward_filter/cudnn.cpp +++ b/dnn/src/cuda/convolution/backward_filter/cudnn.cpp @@ -80,35 +80,9 @@ void ConvolutionBackwardFilterImpl::AlgoCUDNN::exec( } void ConvolutionBackwardFilterImpl::AlgoPack::fill_cudnn_algos() { -#define V1(v) #v -#define 
V(v) V1(v) - -#define DEF_ALGO(NAME, REPROD) \ - cudnn.push_back({ \ - REPROD, #NAME \ - "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \ - "." V(CUDNN_PATCHLEVEL), \ - NAME}) - - DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false); - DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true); - DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true); - DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false); -#if CUDNN_MAJOR >= 6 || (CUDNN_MAJOR >= 5 && CUDNN_MINOR >= 1) - DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, true); -#if CUDNN_MAJOR >= 6 - DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true); -#endif -#endif - -#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) -#pragma message "not latest cudnn" -#endif - -#undef DEF_ALGO - -#undef V -#undef V1 + for(auto&& algo : CudnnAlgoPack::conv_bwd_flt_algos()) { + cudnn.push_back(algo.first); + } } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution/opr_impl.cpp b/dnn/src/cuda/convolution/opr_impl.cpp index 7deda097..66ea1c27 100644 --- a/dnn/src/cuda/convolution/opr_impl.cpp +++ b/dnn/src/cuda/convolution/opr_impl.cpp @@ -70,7 +70,7 @@ ConvolutionForwardImpl::conv_bias_extra_data(const TensorLayout& src, conv_param.dilate_w, 0, conv_param.compute_mode}; - ret.convbias_opr->execution_policy() = {this->execution_policy().algorithm}; + ret.convbias_opr->execution_policy() = {this->execution_policy().algo}; return ret; } @@ -183,15 +183,6 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter, CUDNNBwdDataDescs desc; args.init_desc(desc); - //disable, segfault in megbrain, need further investigate. 
-#if 0 - bool is_heuristic_success= convolution:: - PerformanceModelBackwardData::get_algo_backward_data_success( - args, desc, workspace_limit_in_bytes, &algo); - if (is_heuristic_success) { - return sm_algo_pack.cudnn_from_enum(algo); - } -#endif #if CUDNN_MAJOR >= 7 int max_count = 0; cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( diff --git a/dnn/src/cuda/convolution/opr_impl.h b/dnn/src/cuda/convolution/opr_impl.h index 42afb90e..f4693849 100644 --- a/dnn/src/cuda/convolution/opr_impl.h +++ b/dnn/src/cuda/convolution/opr_impl.h @@ -24,14 +24,6 @@ class ConvolutionForwardImpl: public ConvolutionForward { const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) override; - std::vector get_all_algorithms(const TensorLayout &src, - const TensorLayout &filter, - const TensorLayout &dst) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst, - size_t workspace_limit_in_bytes, - bool reproducible) override; size_t get_workspace_in_bytes( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, @@ -60,99 +52,129 @@ class ConvolutionForwardImpl: public ConvolutionForward { TensorLayout bias_layout; TensorLayout z_layout; }; - private: - ConvBiasExtraData conv_bias_extra_data(const TensorLayout&, - const TensorLayout&, - const TensorLayout&); -}; -class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { - public: - using ConvolutionBackwardData::ConvolutionBackwardData; - void exec(_megdnn_tensor_in filter, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - std::vector get_all_algorithms(const TensorLayout &filter, - const TensorLayout &diff, - const TensorLayout &grad) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad, + std::vector get_all_algorithms( + const TensorLayout& src, const TensorLayout& filter, + 
const TensorLayout& dst) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& dst, size_t workspace_limit_in_bytes, bool reproducible) override; - Algorithm* get_algorithm_heuristic( - const TensorLayout& filter, - const CanonizedFilterMeta& filter_meta, - const TensorLayout& diff, const TensorLayout& grad, - size_t workspace_limit_in_bytes, bool reproducible); - size_t get_workspace_in_bytes(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) override; - const char* get_algorithm_set_name() const override; - - class AlgoBase; - class AlgoCUDNN; - class AlgoMatmul; - class AlgoChanwise; - class AlgoChanwiseSmall; - class AlgoGroupConvGeneral; - class AlgoBFloat16; - - class AlgoPack; - - static const AlgoPack& algo_pack() { - return sm_algo_pack; - } private: - static AlgoPack sm_algo_pack; + ConvBiasExtraData conv_bias_extra_data(const TensorLayout&, + const TensorLayout&, + const TensorLayout&); }; -class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { - public: - using ConvolutionBackwardFilter::ConvolutionBackwardFilter; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - std::vector get_all_algorithms(const TensorLayout &src, - const TensorLayout &diff, - const TensorLayout &grad) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_limit_in_bytes, - bool reproducible) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& gradk, - const CanonizedFilterMeta& grad_meta, - size_t workspace_limit_in_bytes, - bool reproducible); - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) override; - const char* get_algorithm_set_name() const override; - - 
class AlgoBase; - class AlgoCUDNN; - class AlgoMatmul; - class AlgoChanwise; - class AlgoGroupConvGeneral; - class AlgoBFloat16; - - class AlgoPack; - - static const AlgoPack& algo_pack() { - return sm_algo_pack; - } +class ConvolutionBackwardDataImpl : public ConvolutionBackwardData { +public: + using ConvolutionBackwardData::ConvolutionBackwardData; + void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + AlgorithmInfo get_algorithm_info_heuristic( + const TensorLayout& filter, const CanonizedFilterMeta& filter_meta, + const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_limit_in_bytes, bool reproducible) { + return get_algorithm_heuristic(filter, filter_meta, diff, grad, + workspace_limit_in_bytes, reproducible) + ->info(); + } + size_t get_workspace_in_bytes(const TensorLayout& filter, + const TensorLayout& diff, + const TensorLayout& grad) override; + const char* get_algorithm_set_name() const override; + + class AlgoBase; + class AlgoCUDNN; + class AlgoMatmul; + class AlgoChanwise; + class AlgoChanwiseSmall; + class AlgoGroupConvGeneral; + class AlgoBFloat16; + + class AlgoPack; + + static const AlgoPack& algo_pack() { return sm_algo_pack; } + + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + +protected: + std::vector get_all_algorithms( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& filter, + const TensorLayout& diff, + const TensorLayout& grad, + size_t workspace_limit_in_bytes, + bool reproducible) override; +private: + Algorithm* get_algorithm_heuristic(const TensorLayout& filter, + const CanonizedFilterMeta& filter_meta, + const TensorLayout& diff, + const TensorLayout& grad, + size_t workspace_limit_in_bytes, + bool reproducible); + + static AlgoPack sm_algo_pack; +}; - private: - static AlgoPack sm_algo_pack; +class 
ConvolutionBackwardFilterImpl : public ConvolutionBackwardFilter { +public: + using ConvolutionBackwardFilter::ConvolutionBackwardFilter; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& diff, + const TensorLayout& grad) override; + AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src, + const TensorLayout& diff, + const TensorLayout& grad, + const CanonizedFilterMeta& grad_meta, + size_t workspace_limit_in_bytes, + bool reproducible) { + return get_algorithm_heuristic(src, diff, grad, grad_meta, + workspace_limit_in_bytes, reproducible) + ->info(); + } + + const char* get_algorithm_set_name() const override; + + class AlgoBase; + class AlgoCUDNN; + class AlgoMatmul; + class AlgoChanwise; + class AlgoGroupConvGeneral; + class AlgoBFloat16; + + class AlgoPack; + + static const AlgoPack& algo_pack() { return sm_algo_pack; } + + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + +protected: + std::vector get_all_algorithms( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const TensorLayout& diff, + const TensorLayout& grad, + size_t workspace_limit_in_bytes, + bool reproducible) override; +private: + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const TensorLayout& diff, + const TensorLayout& grad, + const CanonizedFilterMeta& grad_meta, + size_t workspace_limit_in_bytes, + bool reproducible); + + static AlgoPack sm_algo_pack; }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution3d/backward_data/algo.cpp b/dnn/src/cuda/convolution3d/backward_data/algo.cpp index 9c243c42..3c2b5dab 100644 --- a/dnn/src/cuda/convolution3d/backward_data/algo.cpp +++ 
b/dnn/src/cuda/convolution3d/backward_data/algo.cpp @@ -39,8 +39,14 @@ Convolution3DBackwardDataImpl::AlgoPack::AlgoPack() { all_algos.push_back(&i); } megdnn_assert(all_algos_data == all_algos.data()); + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } +MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DBackwardDataImpl) + Convolution3DBackwardDataImpl::AlgoCUDNN* Convolution3DBackwardDataImpl::AlgoPack::cudnn_from_enum( cudnnConvolutionBwdDataAlgo_t algo) { @@ -96,7 +102,7 @@ std::string Convolution3DBackwardDataImpl::AlgoBase::SizeArgs::to_string() const fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], fm.spatial[2], diff_layout->to_string().c_str(), grad_layout->to_string().c_str(), - fm.padding[0], fm.padding[1], fm.padding[2], + fm.padding[0], fm.padding[1], fm.padding[2], fm.stride[0], fm.stride[1], fm.stride[2], fm.dilation[0], fm.dilation[1] ,fm.dilation[2], !fm.should_flip, diff --git a/dnn/src/cuda/convolution3d/backward_data/algo.h b/dnn/src/cuda/convolution3d/backward_data/algo.h index 2d0baed9..cf134da8 100644 --- a/dnn/src/cuda/convolution3d/backward_data/algo.h +++ b/dnn/src/cuda/convolution3d/backward_data/algo.h @@ -6,13 +6,16 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once -#include "src/cuda/convolution3d/helper.h" #include +#include "src/cuda/convolution3d/helper.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" namespace megdnn { namespace cuda { @@ -23,170 +26,174 @@ namespace cuda { * All the algo impls should try to support non-contiguous batch dim, for group * conv execution. 
*/ -class Convolution3DBackwardDataImpl::AlgoBase: public Algorithm { - protected: - ~AlgoBase() = default; - - public: - AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } - struct SizeArgs { - HandleImpl *handle; - CanonizedFilterMeta filter_meta; - const TensorLayout *diff_layout, *grad_layout; - Convolution3DBackwardDataImpl *opr; - - std::string to_string() const; - void init_desc(convolution3d::CUDNNBwdDataDescs &desc) const { - desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); - } - SizeArgs(Convolution3DBackwardDataImpl *opr, - const TensorLayout &filter, const TensorLayout &diff, - const TensorLayout &grad); - SizeArgs(Convolution3DBackwardDataImpl *opr, - const CanonizedFilterMeta &filter, const TensorLayout &diff, - const TensorLayout &grad); - - convolution3d::ForwardSizeArgs as_fwd_args() const { - return {handle, grad_layout, filter_meta, diff_layout, - opr->param().data_type}; - } - }; - struct ExecArgs: public SizeArgs { - const TensorND *filter_tensor, *diff_tensor, *grad_tensor; - Workspace workspace; - - ExecArgs(Convolution3DBackwardDataImpl *opr, - _megdnn_tensor_in filter, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace); - }; - virtual bool is_available(const SizeArgs &args) const = 0; - virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; - virtual void exec(const ExecArgs &args) const = 0; - - bool is_available_wk(const SizeArgs &args, size_t limit) { - return is_available(args) && get_workspace_in_bytes(args) <= limit; - } - bool is_available_reproducible( - const SizeArgs& args, bool reproducible = true, - size_t limit = std::numeric_limits::max()) { - return (!reproducible || is_reproducible()) && - is_available_wk(args, limit); - } - AlgoBase& check_workspace( - const SizeArgs &args, const Workspace &workspace) { - auto req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "conv bwd data algo %s: " - "required workspace %zu 
bytes, got %zu", - name(), req, workspace.size); - return *this; +class Convolution3DBackwardDataImpl::AlgoBase : public Algorithm { +protected: + ~AlgoBase() = default; + +public: + enum class AlgoType : uint32_t { + CUDA_GROUP_CONV_GENERAL, + CUDA_CUDNN, + CUDA_CHANWISE, + }; + using Mapper = std::unordered_map; + + + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } + struct SizeArgs { + HandleImpl* handle; + CanonizedFilterMeta filter_meta; + const TensorLayout *diff_layout, *grad_layout; + Convolution3DBackwardDataImpl* opr; + + std::string to_string() const; + void init_desc(convolution3d::CUDNNBwdDataDescs& desc) const { + desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); } - - virtual bool is_cudnn() const { - return false; + SizeArgs(Convolution3DBackwardDataImpl* opr, const TensorLayout& filter, + const TensorLayout& diff, const TensorLayout& grad); + SizeArgs(Convolution3DBackwardDataImpl* opr, + const CanonizedFilterMeta& filter, const TensorLayout& diff, + const TensorLayout& grad); + + convolution3d::ForwardSizeArgs as_fwd_args() const { + return {handle, grad_layout, filter_meta, diff_layout, + opr->param().data_type}; } + }; + struct ExecArgs : public SizeArgs { + const TensorND *filter_tensor, *diff_tensor, *grad_tensor; + Workspace workspace; + + ExecArgs(Convolution3DBackwardDataImpl* opr, _megdnn_tensor_in filter, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace); + }; + virtual bool is_available(const SizeArgs& args) const = 0; + virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; + virtual void exec(const ExecArgs& args) const = 0; + + bool is_available_wk(const SizeArgs& args, size_t limit) { + return is_available(args) && get_workspace_in_bytes(args) <= limit; + } + bool is_available_reproducible( + const SizeArgs& args, bool reproducible = true, + size_t limit = std::numeric_limits::max()) { + return (!reproducible || is_reproducible()) && + 
is_available_wk(args, limit); + } + AlgoBase& check_workspace(const SizeArgs& args, + const Workspace& workspace) { + auto req = get_workspace_in_bytes(args); + megdnn_assert(req <= workspace.size, + "conv bwd data algo %s: " + "required workspace %zu bytes, got %zu", + name(), req, workspace.size); + return *this; + } + + virtual bool is_cudnn() const { return false; } }; class Convolution3DBackwardDataImpl::AlgoCUDNN final : public AlgoBase { - bool m_is_reproducible; - const char *m_name; cudnnConvolutionBwdDataAlgo_t m_cudnn_enum; + CudnnAlgoPack::Attr m_attr; - public: +public: + AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum) + : m_cudnn_enum(cudnn_enum) { + megdnn_assert(CudnnAlgoPack::conv3d_bwd_data_algos().find(cudnn_enum) != + CudnnAlgoPack::conv3d_bwd_data_algos().end()); + m_attr = CudnnAlgoPack::conv3d_bwd_data_algos().at(cudnn_enum); + } - AlgoCUDNN(bool is_reproducible, const char *name, - cudnnConvolutionBwdDataAlgo_t cudnn_enum): - m_is_reproducible(is_reproducible), - m_name(name), - m_cudnn_enum(cudnn_enum) - {} + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; + bool is_reproducible() const override { return m_attr.is_reproducible; } - bool is_reproducible() const override { - return m_is_reproducible; - } + const char* name() const override { return m_attr.name.c_str(); } - const char* name() const override { - return m_name; - } + cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; } - cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { - return m_cudnn_enum; - } + bool is_cudnn() const override { return true; } - bool is_cudnn() const override { - return true; - } + MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) + + std::string 
param() const override { + std::string ret; + serialize_write_pod(m_cudnn_enum, ret); + return ret; + } }; -class Convolution3DBackwardDataImpl::AlgoChanwise final: public AlgoBase { - public: - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; +class Convolution3DBackwardDataImpl::AlgoChanwise final : public AlgoBase { +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; - const char* name() const override { - return "CHANNEL_WISE"; - } - bool is_reproducible() const override { - return true; - } + const char* name() const override { return "CHANNEL_WISE"; } + bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) }; //! implement group conv by another algo -class Convolution3DBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase { - AlgoBase *m_impl; +class Convolution3DBackwardDataImpl::AlgoGroupConvGeneral final + : public AlgoBase { + AlgoBase* m_impl; std::string m_name; - public: - AlgoGroupConvGeneral(AlgoBase *impl); +public: + AlgoGroupConvGeneral(AlgoBase* impl); - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; - const char* name() const override { - return m_name.c_str(); - } + const char* name() const override { return m_name.c_str(); } - bool is_reproducible() const override { - return m_impl->is_reproducible(); - } + bool is_reproducible() const override { return m_impl->is_reproducible(); } + + static void modify_size_args(SizeArgs& args, 
TensorLayout& diff_pg, + TensorLayout& grad_pg); - static void modify_size_args(SizeArgs &args, - TensorLayout &diff_pg, TensorLayout &grad_pg); + MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) + std::string param() const override { + std::string ret; + serialize_write_pod(m_impl, ret); + return ret; + } }; -class Convolution3DBackwardDataImpl::AlgoPack { + +class Convolution3DBackwardDataImpl::AlgoPack : NonCopyableObj { // defined in cudnn.cpp void fill_cudnn_algos(); - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator = (const AlgoPack &) = delete; + AlgoBase::Mapper m_all_algos_map; - public: - AlgoPack(); +public: + AlgoPack(); - std::vector cudnn; - AlgoChanwise chanwise; - std::vector gconv; - std::unordered_map algo2gconv; + std::vector cudnn; + AlgoChanwise chanwise; + std::vector gconv; + std::unordered_map algo2gconv; - std::vector + std::vector //! all algorithms all_algos, //! non-cudnn algos, used for heuristic if cudnn is not supported non_cudnn_algos; - AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); + AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); + + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution3d/backward_data/cudnn.cpp b/dnn/src/cuda/convolution3d/backward_data/cudnn.cpp index 01caa236..271b07cd 100644 --- a/dnn/src/cuda/convolution3d/backward_data/cudnn.cpp +++ b/dnn/src/cuda/convolution3d/backward_data/cudnn.cpp @@ -80,27 +80,9 @@ void Convolution3DBackwardDataImpl::AlgoCUDNN::exec( } void Convolution3DBackwardDataImpl::AlgoPack::fill_cudnn_algos() { -#define V1(v) #v -#define V(v) V1(v) - -#define DEF_ALGO(NAME, REPROD) \ - cudnn.push_back({ \ - REPROD, #NAME \ - "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \ - "." 
V(CUDNN_PATCHLEVEL), \ - NAME}) - -DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false); -DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true); -DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true); -#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) -#pragma message "not latest cudnn" -#endif - -#undef DEF_ALGO - -#undef V -#undef V1 + for (auto&& algo : CudnnAlgoPack::conv3d_bwd_data_algos()) { + cudnn.push_back(algo.first); + } } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution3d/backward_filter/algo.cpp b/dnn/src/cuda/convolution3d/backward_filter/algo.cpp index 0af54db1..5533bbe3 100644 --- a/dnn/src/cuda/convolution3d/backward_filter/algo.cpp +++ b/dnn/src/cuda/convolution3d/backward_filter/algo.cpp @@ -17,7 +17,7 @@ using namespace cuda; Convolution3DBackwardFilterImpl::AlgoPack::AlgoPack() { non_cudnn_algos.push_back(&chanwise); - non_cudnn_algos.push_back(&inplace_matmul); + non_cudnn_algos.push_back(&inplace_matmul); all_algos.push_back(&chanwise); // prefer chanwise fill_cudnn_algos(); @@ -41,8 +41,14 @@ Convolution3DBackwardFilterImpl::AlgoPack::AlgoPack() { } megdnn_assert(all_algos_data == all_algos.data()); non_cudnn_algos.push_back(all_algos.rbegin()[0]); //group inplace_matmul + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } +MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DBackwardFilterImpl) + Convolution3DBackwardFilterImpl::AlgoCUDNN* Convolution3DBackwardFilterImpl::AlgoPack::cudnn_from_enum( cudnnConvolutionBwdFilterAlgo_t algo) { @@ -99,9 +105,9 @@ Convolution3DBackwardFilterImpl::AlgoBase::SizeArgs::to_string() const { "pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, dtype=%s,%s", src_layout->to_string().c_str(), diff_layout->to_string().c_str(), - fm.group, fm.ocpg, fm.icpg, + fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], fm.spatial[2], - fm.padding[0], fm.padding[1], fm.padding[2], + fm.padding[0], fm.padding[1], fm.padding[2], fm.stride[0], fm.stride[1], fm.stride[2], 
fm.dilation[0], fm.dilation[1], fm.dilation[2], !fm.should_flip, diff --git a/dnn/src/cuda/convolution3d/backward_filter/algo.h b/dnn/src/cuda/convolution3d/backward_filter/algo.h index 9ce504ec..8c71eed9 100644 --- a/dnn/src/cuda/convolution3d/backward_filter/algo.h +++ b/dnn/src/cuda/convolution3d/backward_filter/algo.h @@ -6,198 +6,198 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once -#include "src/cuda/convolution3d/helper.h" #include +#include "src/cuda/convolution3d/helper.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" namespace megdnn { namespace cuda { -class Convolution3DBackwardFilterImpl::AlgoBase: public Algorithm { - protected: - ~AlgoBase() = default; - - public: - AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } - struct SizeArgs { - HandleImpl *handle; - const TensorLayout *src_layout, *diff_layout; - CanonizedFilterMeta grad_filter_meta; - Convolution3DBackwardFilterImpl *opr; - - std::string to_string() const; - void init_desc(convolution3d::CUDNNBwdFilterDescs &desc) const { - desc.set(*src_layout, *diff_layout, grad_filter_meta, - opr->param()); - } - SizeArgs(Convolution3DBackwardFilterImpl *opr, - const TensorLayout &src, const TensorLayout &diff, - const TensorLayout &grad); - SizeArgs(Convolution3DBackwardFilterImpl *opr, - const TensorLayout &src, const TensorLayout &diff, - const CanonizedFilterMeta &grad); - - convolution3d::ForwardSizeArgs as_fwd_args() const { - return {handle, src_layout, grad_filter_meta, diff_layout, - opr->param().data_type}; - } - }; - struct ExecArgs: public SizeArgs { - const TensorND *src_tensor, *diff_tensor, *grad_tensor; - Workspace workspace; - - 
ExecArgs(Convolution3DBackwardFilterImpl *opr, - _megdnn_tensor_in src, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace); - }; - virtual bool is_available(const SizeArgs &args) const = 0; - virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; - virtual void exec(const ExecArgs &args) const = 0; - - bool is_available_wk(const SizeArgs &args, size_t limit) { - return is_available(args) && get_workspace_in_bytes(args) <= limit; - } - bool is_available_reproducible( - const SizeArgs& args, bool reproducible = true, - size_t limit = std::numeric_limits::max()) { - return (!reproducible || is_reproducible()) && - is_available_wk(args, limit); - } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { - auto req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "conv bwd filter algo %s: " - "required workspace %zu bytes, got %zu", - name(), req, workspace.size); - return *this; +class Convolution3DBackwardFilterImpl::AlgoBase : public Algorithm { +protected: + ~AlgoBase() = default; + +public: + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } + enum class AlgoType : uint32_t { + CUDA_GROUP_CONV_GENERAL, + CUDA_CUDNN, + CUDA_INPLACE_MATMUL, + CUDA_CHANWISE, + }; + using Mapper = std::unordered_map; + + struct SizeArgs { + HandleImpl* handle; + const TensorLayout *src_layout, *diff_layout; + CanonizedFilterMeta grad_filter_meta; + Convolution3DBackwardFilterImpl* opr; + + std::string to_string() const; + void init_desc(convolution3d::CUDNNBwdFilterDescs& desc) const { + desc.set(*src_layout, *diff_layout, grad_filter_meta, opr->param()); } + SizeArgs(Convolution3DBackwardFilterImpl* opr, const TensorLayout& src, + const TensorLayout& diff, const TensorLayout& grad); + SizeArgs(Convolution3DBackwardFilterImpl* opr, const TensorLayout& src, + const TensorLayout& diff, const CanonizedFilterMeta& grad); - virtual bool is_cudnn() const { - return false; + 
convolution3d::ForwardSizeArgs as_fwd_args() const { + return {handle, src_layout, grad_filter_meta, diff_layout, + opr->param().data_type}; } + }; + struct ExecArgs : public SizeArgs { + const TensorND *src_tensor, *diff_tensor, *grad_tensor; + Workspace workspace; + + ExecArgs(Convolution3DBackwardFilterImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace); + }; + virtual bool is_available(const SizeArgs& args) const = 0; + virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; + virtual void exec(const ExecArgs& args) const = 0; + + bool is_available_wk(const SizeArgs& args, size_t limit) { + return is_available(args) && get_workspace_in_bytes(args) <= limit; + } + bool is_available_reproducible( + const SizeArgs& args, bool reproducible = true, + size_t limit = std::numeric_limits::max()) { + return (!reproducible || is_reproducible()) && + is_available_wk(args, limit); + } + AlgoBase& check_workspace(const SizeArgs& args, + const Workspace& workspace) { + auto req = get_workspace_in_bytes(args); + megdnn_assert(req <= workspace.size, + "conv bwd filter algo %s: " + "required workspace %zu bytes, got %zu", + name(), req, workspace.size); + return *this; + } + + virtual bool is_cudnn() const { return false; } }; class Convolution3DBackwardFilterImpl::AlgoCUDNN final : public AlgoBase { - bool m_is_reproducible; - const char *m_name; cudnnConvolutionBwdFilterAlgo_t m_cudnn_enum; + CudnnAlgoPack::Attr m_attr; - public: +public: + AlgoCUDNN(cudnnConvolutionBwdFilterAlgo_t cudnn_enum) + : m_cudnn_enum(cudnn_enum) { + megdnn_assert(CudnnAlgoPack::conv3d_bwd_flt_algos().find(cudnn_enum) != + CudnnAlgoPack::conv3d_bwd_flt_algos().end()); + m_attr = CudnnAlgoPack::conv3d_bwd_flt_algos().at(cudnn_enum); + } - AlgoCUDNN(bool is_reproducible, const char *name, - cudnnConvolutionBwdFilterAlgo_t cudnn_enum): - m_is_reproducible(is_reproducible), - m_name(name), - m_cudnn_enum(cudnn_enum) - {} + 
bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; + bool is_reproducible() const override { return m_attr.is_reproducible; } - bool is_reproducible() const override { - return m_is_reproducible; - } + const char* name() const override { return m_attr.name.c_str(); } - const char* name() const override { - return m_name; - } + cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { return m_cudnn_enum; } - cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { - return m_cudnn_enum; - } + bool is_cudnn() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) - bool is_cudnn() const override { - return true; - } + std::string param() const override { + std::string ret; + serialize_write_pod(m_cudnn_enum, ret); + return ret; + } }; +class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final + : public AlgoBase { +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; -class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final: public AlgoBase { - public: - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; - - const char* name() const override { - return "INPLACE_MATMUL"; - } - bool is_reproducible() const override { - return false; - } + const char* name() const override { return "INPLACE_MATMUL"; } + bool is_reproducible() const override { return false; } + MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) }; -class Convolution3DBackwardFilterImpl::AlgoChanwise final: public AlgoBase { - public: - bool 
is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; +class Convolution3DBackwardFilterImpl::AlgoChanwise final : public AlgoBase { +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; - const char* name() const override { - return "CHANNEL_WISE"; - } - bool is_reproducible() const override { - return true; - } + const char* name() const override { return "CHANNEL_WISE"; } + bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) }; //! implement group conv by another algo -class Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase { - AlgoBase *m_impl; +class Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral final + : public AlgoBase { + AlgoBase* m_impl; std::string m_name; - public: - AlgoGroupConvGeneral(AlgoBase *impl); +public: + AlgoGroupConvGeneral(AlgoBase* impl); - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; - const char* name() const override { - return m_name.c_str(); - } + const char* name() const override { return m_name.c_str(); } - bool is_reproducible() const override { - return m_impl->is_reproducible(); - } + bool is_reproducible() const override { return m_impl->is_reproducible(); } - static void modify_size_args(SizeArgs &args, - TensorLayout &src_pg, TensorLayout &diff_pg); + static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, + TensorLayout& diff_pg); + + MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) + 
std::string param() const override { + std::string ret; + serialize_write_pod(m_impl, ret); + return ret; + } }; -class Convolution3DBackwardFilterImpl::AlgoPack { +class Convolution3DBackwardFilterImpl::AlgoPack : NonCopyableObj { // defined in cudnn.cpp void fill_cudnn_algos(); - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator = (const AlgoPack &) = delete; + AlgoBase::Mapper m_all_algos_map; - public: - AlgoPack(); +public: + AlgoPack(); - std::vector cudnn; - AlgoInplaceMatmul inplace_matmul; - AlgoChanwise chanwise; - std::vector gconv; - std::unordered_map algo2gconv; + std::vector cudnn; + AlgoInplaceMatmul inplace_matmul; + AlgoChanwise chanwise; + std::vector gconv; + std::unordered_map algo2gconv; - std::vector + std::vector //! all algorithms all_algos, //! non-cudnn algos, used for heuristic if cudnn is not supported non_cudnn_algos; - AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); + AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); + + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp b/dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp index 1ff883db..5662b9ec 100644 --- a/dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp +++ b/dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp @@ -66,29 +66,9 @@ void Convolution3DBackwardFilterImpl::AlgoCUDNN::exec( } void Convolution3DBackwardFilterImpl::AlgoPack::fill_cudnn_algos() { -#define V1(v) #v -#define V(v) V1(v) - -#define DEF_ALGO(NAME, REPROD) \ - cudnn.push_back({REPROD, \ - #NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." 
V( \ - CUDNN_PATCHLEVEL), \ - NAME}) - - DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false); -#pragma message \ - "fp16 dilated conv with odd size filter, only algo_1 works, need focus on doc" - DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true); - DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false); - -#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) -#pragma message "not latest cudnn" -#endif - -#undef DEF_ALGO - -#undef V -#undef V1 + for (auto&& algo : CudnnAlgoPack::conv3d_bwd_flt_algos()) { + cudnn.push_back(algo.first); + } } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution3d/forward/algo.cpp b/dnn/src/cuda/convolution3d/forward/algo.cpp index 231c0f33..1b52a0e8 100644 --- a/dnn/src/cuda/convolution3d/forward/algo.cpp +++ b/dnn/src/cuda/convolution3d/forward/algo.cpp @@ -21,13 +21,13 @@ Convolution3DForwardImpl::AlgoPack::AlgoPack() { non_cudnn_algos.push_back(&a1x1x1); all_algos.push_back(&chanwise); - + fill_cudnn_algos(); for (auto &&i: cudnn) { - all_algos.push_back(&i); + all_algos.push_back(&i); } all_algos.push_back(&inplace_matmul); - all_algos.push_back(&a1x1x1); + all_algos.push_back(&a1x1x1); all_algos.reserve(all_algos.size() * 2); // add gconv algos by AlgoGroupConvGeneral @@ -42,10 +42,16 @@ Convolution3DForwardImpl::AlgoPack::AlgoPack() { all_algos.push_back(&i); } megdnn_assert(all_algos_data == all_algos.data()); - non_cudnn_algos.push_back(all_algos.rbegin()[1]); // group inplace_matmul + non_cudnn_algos.push_back(all_algos.rbegin()[1]); // group inplace_matmul non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group 1x1x1 + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } +MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DForwardImpl) + Convolution3DForwardImpl::AlgoCUDNN* Convolution3DForwardImpl::AlgoPack::cudnn_from_enum( cudnnConvolutionFwdAlgo_t algo) { @@ -99,7 +105,7 @@ std::string Convolution3DForwardImpl::AlgoBase::SizeArgs::to_string() const { "src=%s, 
filter=%u{%u,%u,%u,%u,%u}, dst=%s, " "pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, dtype=%s,%s", src_layout->to_string().c_str(), - fm.group, fm.ocpg, fm.icpg, + fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], fm.spatial[2], dst_layout->to_string().c_str(), fm.padding[0], fm.padding[1], fm.padding[2], diff --git a/dnn/src/cuda/convolution3d/forward/algo.h b/dnn/src/cuda/convolution3d/forward/algo.h index 726dcbaf..e6073dfa 100644 --- a/dnn/src/cuda/convolution3d/forward/algo.h +++ b/dnn/src/cuda/convolution3d/forward/algo.h @@ -6,17 +6,20 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include "megdnn/oprs.h" +#include "src/common/utils.h" #include "src/cuda/convolution3d/helper.h" -#include "src/cuda/handle.h" #include "src/cuda/convolution3d/opr_impl.h" -#include "src/common/utils.h" +#include "src/cuda/handle.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" #include @@ -29,195 +32,189 @@ namespace cuda { * All the algo impls should try to support non-contiguous batch dim, for group * conv execution. 
*/ -class Convolution3DForwardImpl::AlgoBase: public Algorithm { - protected: - ~AlgoBase() = default; - - public: - AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } - struct SizeArgs: public convolution3d::ForwardSizeArgs { - Convolution3DForwardImpl *opr; - - std::string to_string() const; - void init_desc(convolution3d::CUDNNForwardDescs &desc) const { - desc.set(*src_layout, filter_meta, *dst_layout, opr->param()); - } - SizeArgs(Convolution3DForwardImpl *opr, - const TensorLayout &src, - const TensorLayout &filter, - const TensorLayout &dst); - SizeArgs(Convolution3DForwardImpl *opr, - const TensorLayout &src, - const CanonizedFilterMeta &filter, - const TensorLayout &dst); - }; - struct ExecArgs : public SizeArgs { - const TensorND *src_tensor, *filter_tensor, *dst_tensor; - Workspace workspace; - - ExecArgs(Convolution3DForwardImpl *opr, - _megdnn_tensor_in src, - _megdnn_tensor_in filter, - _megdnn_tensor_out dst, - _megdnn_workspace workspace); - }; - virtual bool is_available(const SizeArgs &args) const = 0; - virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; - virtual void exec(const ExecArgs &args) const = 0; - - bool is_available_wk(const SizeArgs &args, size_t limit) { - return is_available(args) && get_workspace_in_bytes(args) <= limit; - } - bool is_available_reproducible( - const SizeArgs& args, bool reproducible = true, - size_t limit = std::numeric_limits::max()) { - return (!reproducible || is_reproducible()) && - is_available_wk(args, limit); - } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { - auto req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "conv3d fwd algo %s: required workspace %zu bytes, got %zu", - name(), req, workspace.size); - return *this; - } - - virtual bool is_cudnn() const { - return false; - } +class Convolution3DForwardImpl::AlgoBase : public Algorithm { +protected: + ~AlgoBase() = default; + +public: + enum class 
AlgoType : uint32_t { + CUDA_1X1X1, + CUDA_GROUP_CONV_GENERAL, + CUDA_CUDNN, + CUDA_INPLACE_MATMUL, + CUDA_CHANWISE, + }; + using Mapper = std::unordered_map; + + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } + struct SizeArgs : public convolution3d::ForwardSizeArgs { + Convolution3DForwardImpl* opr; + + std::string to_string() const; + void init_desc(convolution3d::CUDNNForwardDescs& desc) const { + desc.set(*src_layout, filter_meta, *dst_layout, opr->param()); + } + SizeArgs(Convolution3DForwardImpl* opr, const TensorLayout& src, + const TensorLayout& filter, const TensorLayout& dst); + SizeArgs(Convolution3DForwardImpl* opr, const TensorLayout& src, + const CanonizedFilterMeta& filter, const TensorLayout& dst); + }; + struct ExecArgs : public SizeArgs { + const TensorND *src_tensor, *filter_tensor, *dst_tensor; + Workspace workspace; + + ExecArgs(Convolution3DForwardImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace); + }; + virtual bool is_available(const SizeArgs& args) const = 0; + virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; + virtual void exec(const ExecArgs& args) const = 0; + + bool is_available_wk(const SizeArgs& args, size_t limit) { + return is_available(args) && get_workspace_in_bytes(args) <= limit; + } + bool is_available_reproducible( + const SizeArgs& args, bool reproducible = true, + size_t limit = std::numeric_limits::max()) { + return (!reproducible || is_reproducible()) && + is_available_wk(args, limit); + } + AlgoBase& check_workspace(const SizeArgs& args, + const Workspace& workspace) { + auto req = get_workspace_in_bytes(args); + megdnn_assert( + req <= workspace.size, + "conv3d fwd algo %s: required workspace %zu bytes, got %zu", + name(), req, workspace.size); + return *this; + } + + virtual bool is_cudnn() const { return false; } }; -class Convolution3DForwardImpl::Algo1x1x1 final: public AlgoBase { - static void 
extract_matmul_layouts(const SizeArgs &args, - TensorLayout &A, TensorLayout &B, TensorLayout &C); - public: - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; - - const char* name() const override { - return "1x1x1"; - } - bool is_reproducible() const override { - return true; - } +class Convolution3DForwardImpl::Algo1x1x1 final : public AlgoBase { + static void extract_matmul_layouts(const SizeArgs& args, TensorLayout& A, + TensorLayout& B, TensorLayout& C); + +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; + + const char* name() const override { return "1x1x1"; } + bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_1X1X1) }; //! implement group conv by another algo -class Convolution3DForwardImpl::AlgoGroupConvGeneral final: public AlgoBase { - AlgoBase *m_impl; +class Convolution3DForwardImpl::AlgoGroupConvGeneral final : public AlgoBase { + AlgoBase* m_impl; std::string m_name; - public: - AlgoGroupConvGeneral(AlgoBase *impl); +public: + AlgoGroupConvGeneral(AlgoBase* impl); - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; - const char* name() const override { - return m_name.c_str(); - } + const char* name() const override { return m_name.c_str(); } - bool is_reproducible() const override { - return m_impl->is_reproducible(); - } + bool is_reproducible() const override { return m_impl->is_reproducible(); } - static void modify_size_args(SizeArgs &args, - 
TensorLayout &src_pg, TensorLayout &dst_pg); + static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, + TensorLayout& dst_pg); + MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) + std::string param() const override { + std::string ret; + serialize_write_pod(m_impl, ret); + return ret; + } }; class Convolution3DForwardImpl::AlgoCUDNN final : public AlgoBase { - bool m_is_reproducible; - const char *m_name; cudnnConvolutionFwdAlgo_t m_cudnn_enum; + CudnnAlgoPack::Attr m_attr; - public: +public: + AlgoCUDNN(cudnnConvolutionFwdAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) { + megdnn_assert(CudnnAlgoPack::conv3d_fwd_algos().find(cudnn_enum) != + CudnnAlgoPack::conv3d_fwd_algos().end()); + m_attr = CudnnAlgoPack::conv3d_fwd_algos().at(cudnn_enum); + } - AlgoCUDNN(bool is_reproducible, const char *name, - cudnnConvolutionFwdAlgo_t cudnn_enum): - m_is_reproducible(is_reproducible), - m_name(name), - m_cudnn_enum(cudnn_enum) - {} + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; + bool is_reproducible() const override { return m_attr.is_reproducible; } - bool is_reproducible() const override { - return m_is_reproducible; - } + const char* name() const override { return m_attr.name.c_str(); } - const char* name() const override { - return m_name; - } + cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; } - cudnnConvolutionFwdAlgo_t cudnn_enum() const { - return m_cudnn_enum; - } + bool is_cudnn() const override { return true; } - bool is_cudnn() const override { - return true; - } -}; + MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) -class Convolution3DForwardImpl::AlgoInplaceMatmul final: public AlgoBase { - public: - bool is_available(const SizeArgs &args) 
const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; + std::string param() const override { + std::string ret; + serialize_write_pod(m_cudnn_enum, ret); + return ret; + } - const char* name() const override { - return "INPLACE_MATMUL"; - } - bool is_reproducible() const override { - return true; - } }; +class Convolution3DForwardImpl::AlgoInplaceMatmul final : public AlgoBase { +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; -class Convolution3DForwardImpl::AlgoChanwise final: public AlgoBase { - public: - bool is_available(const SizeArgs &args) const override; - size_t get_workspace_in_bytes(const SizeArgs &args) const override; - void exec(const ExecArgs &args) const override; + const char* name() const override { return "INPLACE_MATMUL"; } + bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) +}; - const char* name() const override { - return "CHANNEL_WISE"; - } - bool is_reproducible() const override { - return true; - } +class Convolution3DForwardImpl::AlgoChanwise final : public AlgoBase { +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; + + const char* name() const override { return "CHANNEL_WISE"; } + bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) }; -class Convolution3DForwardImpl::AlgoPack { +class Convolution3DForwardImpl::AlgoPack : NonCopyableObj { // defined in cudnn.cpp void fill_cudnn_algos(); - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator = (const AlgoPack &) = delete; + AlgoBase::Mapper m_all_algos_map; - public: - AlgoPack(); +public: + AlgoPack(); - std::vector cudnn; - Algo1x1x1 a1x1x1; 
- AlgoInplaceMatmul inplace_matmul; - AlgoChanwise chanwise; - std::vector gconv; - std::unordered_map algo2gconv; + std::vector cudnn; + Algo1x1x1 a1x1x1; + AlgoInplaceMatmul inplace_matmul; + AlgoChanwise chanwise; + std::vector gconv; + std::unordered_map algo2gconv; - std::vector + std::vector //! all algorithms all_algos, //! non-cudnn algos, used for heuristic if cudnn is not supported non_cudnn_algos; - AlgoCUDNN* cudnn_from_enum(cudnnConvolutionFwdAlgo_t algo); + AlgoCUDNN* cudnn_from_enum(cudnnConvolutionFwdAlgo_t algo); + + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution3d/forward/cudnn.cpp b/dnn/src/cuda/convolution3d/forward/cudnn.cpp index 178d373a..6dfcbc59 100644 --- a/dnn/src/cuda/convolution3d/forward/cudnn.cpp +++ b/dnn/src/cuda/convolution3d/forward/cudnn.cpp @@ -78,30 +78,10 @@ void Convolution3DForwardImpl::AlgoCUDNN::exec( cudnnGetErrorString(status), args.to_string().c_str()); } - void Convolution3DForwardImpl::AlgoPack::fill_cudnn_algos() { -#define V1(v) #v -#define V(v) V1(v) - -#define DEF_ALGO(NAME, REPROD) \ - cudnn.push_back({ \ - REPROD, #NAME \ - "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \ - "." 
V(CUDNN_PATCHLEVEL), \ - NAME}) - -DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true); -DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true); -DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true); - -#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) -#pragma message "not latest cudnn" -#endif - -#undef DEF_ALGO - -#undef V -#undef V1 + for (auto&& algo : CudnnAlgoPack::conv3d_fwd_algos()) { + cudnn.push_back(algo.first); + } } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution3d/opr_impl.h b/dnn/src/cuda/convolution3d/opr_impl.h index 120b1fa2..a20249bb 100644 --- a/dnn/src/cuda/convolution3d/opr_impl.h +++ b/dnn/src/cuda/convolution3d/opr_impl.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once @@ -15,126 +16,155 @@ namespace megdnn { namespace cuda { -class Convolution3DForwardImpl: public Convolution3DForward { - public: - using Convolution3DForward::Convolution3DForward; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; - std::vector get_all_algorithms(const TensorLayout &src, - const TensorLayout &filter, - const TensorLayout &dst) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst, - size_t workspace_limit_in_bytes, - bool reproducible) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& src, - const CanonizedFilterMeta& filter, - const TensorLayout& dst, - size_t workspace_limit_in_bytes, - bool reproducible); - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) override; - const char* get_algorithm_set_name() const override; - class AlgoBase; - class AlgoCUDNN; - class Algo1x1x1; - class AlgoInplaceMatmul; - class AlgoChanwise; - class AlgoGroupConvGeneral; - class AlgoPack; - static const AlgoPack& algo_pack() { - return sm_algo_pack; - } - private: - static AlgoPack sm_algo_pack; +class Convolution3DForwardImpl : public Convolution3DForward { +public: + using Convolution3DForward::Convolution3DForward; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, + _megdnn_tensor_out dst, _megdnn_workspace workspace) override; + AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src, + const CanonizedFilterMeta& filter, + const TensorLayout& dst, + size_t workspace_limit_in_bytes, + bool reproducible) { + return get_algorithm_heuristic(src, filter, dst, + workspace_limit_in_bytes, reproducible) + ->info(); + } + size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& dst) override; + const char* get_algorithm_set_name() const override; + 
class AlgoBase; + class AlgoCUDNN; + class Algo1x1x1; + class AlgoInplaceMatmul; + class AlgoChanwise; + class AlgoGroupConvGeneral; + class AlgoPack; + static const AlgoPack& algo_pack() { return sm_algo_pack; } + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + +protected: + std::vector get_all_algorithms( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& dst, + size_t workspace_limit_in_bytes, + bool reproducible) override; + +private: + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const CanonizedFilterMeta& filter, + const TensorLayout& dst, + size_t workspace_limit_in_bytes, + bool reproducible); + + + static AlgoPack sm_algo_pack; }; -class Convolution3DBackwardDataImpl: public Convolution3DBackwardData { - public: - using Convolution3DBackwardData::Convolution3DBackwardData; - void exec(_megdnn_tensor_in filter, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - std::vector get_all_algorithms(const TensorLayout &filter, - const TensorLayout &diff, - const TensorLayout &grad) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_limit_in_bytes, - bool reproducible) override; - Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_limit_in_bytes, - bool reproducible); - size_t get_workspace_in_bytes(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) override; - const char* get_algorithm_set_name() const override; - - class AlgoBase; - class AlgoCUDNN; - class AlgoInplaceMatmul; - class AlgoChanwise; - class AlgoGroupConvGeneral; - - class AlgoPack; - - static const AlgoPack& algo_pack() { - return sm_algo_pack; - 
} - - private: - static AlgoPack sm_algo_pack; +class Convolution3DBackwardDataImpl : public Convolution3DBackwardData { +public: + using Convolution3DBackwardData::Convolution3DBackwardData; + void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + AlgorithmInfo get_algorithm_info_heuristic( + const CanonizedFilterMeta& filter, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_limit_in_bytes, + bool reproducible) { + return get_algorithm_heuristic(filter, diff, grad, + workspace_limit_in_bytes, reproducible) + ->info(); + } + size_t get_workspace_in_bytes(const TensorLayout& filter, + const TensorLayout& diff, + const TensorLayout& grad) override; + const char* get_algorithm_set_name() const override; + + class AlgoBase; + class AlgoCUDNN; + class AlgoInplaceMatmul; + class AlgoChanwise; + class AlgoGroupConvGeneral; + + class AlgoPack; + + static const AlgoPack& algo_pack() { return sm_algo_pack; } + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + +protected: + std::vector get_all_algorithms( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& filter, + const TensorLayout& diff, + const TensorLayout& grad, + size_t workspace_limit_in_bytes, + bool reproducible) override; + +private: + Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, + const TensorLayout& diff, + const TensorLayout& grad, + size_t workspace_limit_in_bytes, + bool reproducible); + + static AlgoPack sm_algo_pack; }; -class Convolution3DBackwardFilterImpl: public Convolution3DBackwardFilter { - public: - using Convolution3DBackwardFilter::Convolution3DBackwardFilter; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - std::vector get_all_algorithms(const TensorLayout &src, - const TensorLayout 
&diff, - const TensorLayout &grad) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_limit_in_bytes, - bool reproducible) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& src, - const TensorLayout& diff, - const CanonizedFilterMeta& grad, - size_t workspace_limit_in_bytes, - bool reproducible); - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) override; - const char* get_algorithm_set_name() const override; - - class AlgoBase; - class AlgoCUDNN; - class AlgoInplaceMatmul; - class AlgoChanwise; - class AlgoGroupConvGeneral; - - class AlgoPack; - - static const AlgoPack& algo_pack() { - return sm_algo_pack; - } - - private: - static AlgoPack sm_algo_pack; +class Convolution3DBackwardFilterImpl : public Convolution3DBackwardFilter { +public: + using Convolution3DBackwardFilter::Convolution3DBackwardFilter; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& diff, + const TensorLayout& grad) override; + AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src, + const TensorLayout& diff, + const CanonizedFilterMeta& grad, + size_t workspace_limit_in_bytes, + bool reproducible) { + return get_algorithm_heuristic(src, diff, grad, + workspace_limit_in_bytes, reproducible) + ->info(); + } + + const char* get_algorithm_set_name() const override; + + class AlgoBase; + class AlgoCUDNN; + class AlgoInplaceMatmul; + class AlgoChanwise; + class AlgoGroupConvGeneral; + + class AlgoPack; + + static const AlgoPack& algo_pack() { return sm_algo_pack; } + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + +protected: + std::vector get_all_algorithms( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) 
override; + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const TensorLayout& diff, + const TensorLayout& grad, + size_t workspace_limit_in_bytes, + bool reproducible) override; + +private: + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const TensorLayout& diff, + const CanonizedFilterMeta& grad, + size_t workspace_limit_in_bytes, + bool reproducible); + + static AlgoPack sm_algo_pack; }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/cudnn_wrapper.cpp b/dnn/src/cuda/cudnn_wrapper.cpp index 62217f69..64ecf4ac 100644 --- a/dnn/src/cuda/cudnn_wrapper.cpp +++ b/dnn/src/cuda/cudnn_wrapper.cpp @@ -433,6 +433,137 @@ void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) { desc, 3, padA, filterStrideA, dilationA, mode, CUDNN_DATA_FLOAT)); } +////////////////////////// CudnnAlgoPack ////////////////////////// + +#define V1(v) #v +#define V(v) V1(v) +#define DEF_NAME(NAME) \ + #NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." 
V(CUDNN_PATCHLEVEL) +#define DEF_ALGO(NAME, PROD) \ + { \ + NAME, { DEF_NAME(NAME), PROD } \ + } + +#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) +#pragma message "not latest cudnn" +#endif + +const std::unordered_map +CudnnAlgoPack::conv_bwd_data_algos() { + static const std::unordered_map + algos = { + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true), +#if CUDNN_MAJOR >= 5 + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true), +#if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED, + true), +#endif +#endif + }; + + return algos; +} + +const std::unordered_map +CudnnAlgoPack::conv_bwd_flt_algos() { + static const std::unordered_map + algos = { + DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false), +#if CUDNN_MAJOR >= 6 || (CUDNN_MAJOR >= 5 && CUDNN_MINOR >= 1) + DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, + true), +#if CUDNN_MAJOR >= 6 + DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true), +#endif +#endif + + }; + + return algos; +} + + +const std::unordered_map +CudnnAlgoPack::conv_fwd_algos() { + static const std::unordered_map + algos = { + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + true), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true), + +#if CUDNN_MAJOR >= 5 + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true), +#if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, 
true), +#endif +#endif + + }; + + return algos; +} + +const std::unordered_map +CudnnAlgoPack::conv3d_bwd_data_algos() { + static const std::unordered_map + algos = { + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true), + }; + + return algos; +} // namespace cuda + +const std::unordered_map +CudnnAlgoPack::conv3d_bwd_flt_algos() { +#pragma message \ + "fp16 dilated conv with odd size filter, only algo_1 works, need focus on doc" + static const std::unordered_map + algos = { + DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false), + }; + + return algos; +} + +const std::unordered_map +CudnnAlgoPack::conv3d_fwd_algos() { + static const std::unordered_map + algos = { + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + true), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true), + }; + + return algos; +} + +#undef DEF_ALGO +#undef DEF_NAME +#undef V +#undef V1 + } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/cudnn_wrapper.h b/dnn/src/cuda/cudnn_wrapper.h index c4ada5d2..8346ef0b 100644 --- a/dnn/src/cuda/cudnn_wrapper.h +++ b/dnn/src/cuda/cudnn_wrapper.h @@ -10,6 +10,7 @@ */ #pragma once +#include #include "megdnn/basic_types.h" #include "megdnn/oprs/nn.h" #include "src/cuda/cudnn_with_check.h" @@ -27,7 +28,7 @@ class TensorDesc { public: TensorDesc(); //! default layout is nchw - void set(const TensorLayout& layout, const param::Convolution::Format = + void set(const TensorLayout& layout, const param::Convolution::Format = param::Convolution::Format::NCHW); ~TensorDesc(); cudnnTensorDescriptor_t desc; @@ -103,9 +104,52 @@ class Conv3DDesc { cudnnConvolutionDescriptor_t desc; }; +class CudnnAlgoPack { +public: + //! 
algorithm attr + struct Attr { + std::string name; + bool is_reproducible; + }; + static const std::unordered_map + conv_bwd_data_algos(); -} // namespace cuda -} // namespace megdnn + static const std::unordered_map + conv_bwd_flt_algos(); + + static const std::unordered_map + conv_fwd_algos(); + + static const std::unordered_map + conv3d_bwd_data_algos(); + + static const std::unordered_map + conv3d_bwd_flt_algos(); + + static const std::unordered_map + conv3d_fwd_algos(); + +}; + +} // namespace cuda +} // namespace megdnn + +namespace std { + +#define DEF_HASH(_type) \ + template <> \ + struct hash<_type> { \ + std::size_t operator()(const _type& algo) const { \ + return std::hash()(static_cast(algo)); \ + } \ + } + +DEF_HASH(cudnnConvolutionBwdDataAlgo_t); +DEF_HASH(cudnnConvolutionBwdFilterAlgo_t); +DEF_HASH(cudnnConvolutionFwdAlgo_t); + +#undef DEF_HASH +} // namespace std // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/deformable_conv/bwd_data/algo.cpp b/dnn/src/cuda/deformable_conv/bwd_data/algo.cpp index ce3fefed..df7ec722 100644 --- a/dnn/src/cuda/deformable_conv/bwd_data/algo.cpp +++ b/dnn/src/cuda/deformable_conv/bwd_data/algo.cpp @@ -19,7 +19,12 @@ using OprImpl = DeformableConvBackwardDataImpl; OprImpl::AlgoPack::AlgoPack() { all_algos.push_back(&algo_matmul); + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } +MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvBackwardDataImpl) OprImpl::AlgoPack OprImpl::sm_algo_pack; diff --git a/dnn/src/cuda/deformable_conv/bwd_data/algo.h b/dnn/src/cuda/deformable_conv/bwd_data/algo.h index 5f83bc2f..af8fcc0d 100644 --- a/dnn/src/cuda/deformable_conv/bwd_data/algo.h +++ b/dnn/src/cuda/deformable_conv/bwd_data/algo.h @@ -13,11 +13,15 @@ #include "megdnn/oprs.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" #include "src/common/utils.h" #include "src/cuda/handle.h" #include "src/cuda/deformable_conv/opr_impl.h" +#include + namespace megdnn { 
namespace cuda { @@ -26,6 +30,10 @@ protected: ~AlgoBase() = default; public: + enum class AlgoType : uint32_t { + CUDA_MATMUL, + }; + using Mapper = std::unordered_map; AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { DeformableConvBackwardDataImpl* opr; @@ -107,17 +115,18 @@ public: bool is_reproducible() const override { return true; } const char* name() const override { return "AlgoMatmul"; } + MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) }; -class DeformableConvBackwardDataImpl::AlgoPack { - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator=(const AlgoPack&) = delete; - +class DeformableConvBackwardDataImpl::AlgoPack : NonCopyableObj { + AlgoBase::Mapper m_all_algos_map; public: AlgoPack(); AlgoMatmul algo_matmul; //! all algorithms std::vector all_algos; + + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; } // namespace cuda diff --git a/dnn/src/cuda/deformable_conv/bwd_flt/algo.cpp b/dnn/src/cuda/deformable_conv/bwd_flt/algo.cpp index a7c37236..a726e026 100644 --- a/dnn/src/cuda/deformable_conv/bwd_flt/algo.cpp +++ b/dnn/src/cuda/deformable_conv/bwd_flt/algo.cpp @@ -20,7 +20,11 @@ using OprImpl = DeformableConvBackwardFilterImpl; OprImpl::AlgoPack::AlgoPack() { all_algos.push_back(&algo_matmul); + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } +MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvBackwardFilterImpl) OprImpl::AlgoPack OprImpl::sm_algo_pack; diff --git a/dnn/src/cuda/deformable_conv/bwd_flt/algo.h b/dnn/src/cuda/deformable_conv/bwd_flt/algo.h index a2bb713e..390cd2ba 100644 --- a/dnn/src/cuda/deformable_conv/bwd_flt/algo.h +++ b/dnn/src/cuda/deformable_conv/bwd_flt/algo.h @@ -13,11 +13,15 @@ #include "megdnn/oprs.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" #include "src/common/utils.h" #include "src/cuda/handle.h" #include "src/cuda/deformable_conv/opr_impl.h" +#include + namespace megdnn { namespace cuda { @@ -26,6 
+30,11 @@ protected: ~AlgoBase() = default; public: + enum class AlgoType : uint32_t { + CUDA_MATMUL, + }; + using Mapper = std::unordered_map; + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { DeformableConvBackwardFilterImpl* opr; @@ -97,18 +106,18 @@ public: bool is_reproducible() const override { return true; } const char* name() const override { return "AlgoMatmul"; } + MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) }; -class DeformableConvBackwardFilterImpl::AlgoPack { - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator=(const AlgoPack&) = delete; - +class DeformableConvBackwardFilterImpl::AlgoPack : NonCopyableObj { + AlgoBase::Mapper m_all_algos_map; public: AlgoPack(); AlgoMatmul algo_matmul; //! all algorithms std::vector all_algos; + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; } // namespace cuda diff --git a/dnn/src/cuda/deformable_conv/fwd/algo.cpp b/dnn/src/cuda/deformable_conv/fwd/algo.cpp index f26c80e3..f321011d 100644 --- a/dnn/src/cuda/deformable_conv/fwd/algo.cpp +++ b/dnn/src/cuda/deformable_conv/fwd/algo.cpp @@ -22,8 +22,14 @@ using OprImpl = DeformableConvForwardImpl; OprImpl::AlgoPack::AlgoPack() { all_algos.push_back(&algo_matmul); + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } +MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvForwardImpl) + OprImpl::AlgoPack OprImpl::sm_algo_pack; OprImpl::AlgoBase::SizeArgs::SizeArgs(OprImpl* o, const TensorLayout& im, diff --git a/dnn/src/cuda/deformable_conv/fwd/algo.h b/dnn/src/cuda/deformable_conv/fwd/algo.h index f2d28ecb..2de2fca9 100644 --- a/dnn/src/cuda/deformable_conv/fwd/algo.h +++ b/dnn/src/cuda/deformable_conv/fwd/algo.h @@ -13,9 +13,13 @@ #include "megdnn/oprs.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" #include "src/cuda/deformable_conv/opr_impl.h" #include "src/cuda/utils.h" +#include + namespace megdnn { namespace cuda { @@ -24,6 +28,11 @@ protected: 
~AlgoBase() = default; public: + enum class AlgoType : uint32_t { + CUDA_MATMUL, + }; + using Mapper = std::unordered_map; + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { DeformableConvForwardImpl* opr; @@ -92,17 +101,17 @@ public: bool is_reproducible() const override { return true; } const char* name() const override { return "AlgoMatmul"; } + MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) }; -class DeformableConvForwardImpl::AlgoPack { - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator=(const AlgoPack&) = delete; - +class DeformableConvForwardImpl::AlgoPack : NonCopyableObj { + AlgoBase::Mapper m_all_algos_map; public: AlgoPack(); AlgoMatmul algo_matmul; //! all algorithms std::vector all_algos; + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; } // namespace cuda diff --git a/dnn/src/cuda/deformable_conv/opr_impl.h b/dnn/src/cuda/deformable_conv/opr_impl.h index 3a6ec138..1740aedd 100644 --- a/dnn/src/cuda/deformable_conv/opr_impl.h +++ b/dnn/src/cuda/deformable_conv/opr_impl.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once @@ -29,19 +30,6 @@ public: const TensorLayout& mask, const TensorLayout& dst) override; - std::vector get_all_algorithms( - const TensorLayout& im, const TensorLayout& filter, - const TensorLayout& offset, const TensorLayout& mask, - const TensorLayout& dst) override; - - Algorithm* get_algorithm_heuristic(const TensorLayout& im, - const TensorLayout& filter, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& dst, - size_t workspace_limit_in_bytes, - bool reproducible) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& im, const CanonizedFilterMeta& filter, const TensorLayout& offset, @@ -58,31 +46,35 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + +protected: + std::vector get_all_algorithms( + const TensorLayout& im, const TensorLayout& filter, + const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& dst) override; + + Algorithm* get_algorithm_heuristic(const TensorLayout& im, + const TensorLayout& filter, + const TensorLayout& offset, + const TensorLayout& mask, + const TensorLayout& dst, + size_t workspace_limit_in_bytes, + bool reproducible) override; private: static AlgoPack sm_algo_pack; }; -class DeformableConvBackwardFilterImpl: public DeformableConvBackwardFilter { +class DeformableConvBackwardFilterImpl : public DeformableConvBackwardFilter { public: using DeformableConvBackwardFilter::DeformableConvBackwardFilter; - void exec(_megdnn_tensor_in im,_megdnn_tensor_in offset, _megdnn_tensor_in mask, - _megdnn_tensor_in out_grad, _megdnn_tensor_out filter_grad, + void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset, + _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, + _megdnn_tensor_out filter_grad, _megdnn_workspace workspace) override; - std::vector get_all_algorithms( - const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask, - const 
TensorLayout& out_grad, const TensorLayout& filter_grad) override; - - Algorithm* get_algorithm_heuristic(const TensorLayout& im, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& out_grad, - const TensorLayout& filter_grad, - size_t workspace_limit_in_bytes, - bool reproducible) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask, @@ -91,9 +83,11 @@ public: size_t workspace_limit_in_bytes, bool reproducible); - size_t get_workspace_in_bytes( - const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask, - const TensorLayout& out_grad, const TensorLayout& filter_grad) override; + size_t get_workspace_in_bytes(const TensorLayout& im, + const TensorLayout& offset, + const TensorLayout& mask, + const TensorLayout& out_grad, + const TensorLayout& filter_grad) override; const char* get_algorithm_set_name() const override; @@ -103,6 +97,21 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + +protected: + std::vector get_all_algorithms( + const TensorLayout& im, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& out_grad, + const TensorLayout& filter_grad) override; + + Algorithm* get_algorithm_heuristic(const TensorLayout& im, + const TensorLayout& offset, + const TensorLayout& mask, + const TensorLayout& out_grad, + const TensorLayout& filter_grad, + size_t workspace_limit_in_bytes, + bool reproducible) override; private: static AlgoPack sm_algo_pack; @@ -118,19 +127,6 @@ public: _megdnn_tensor_out offset_grad, _megdnn_tensor_out mask_grad, _megdnn_workspace workspace) override; - std::vector get_all_algorithms( - const TensorLayout& im, const TensorLayout& filter, - const TensorLayout& offset, const TensorLayout& mask, - const TensorLayout& out_grad, const TensorLayout& im_grad, - const TensorLayout& offset_grad, const 
TensorLayout& mask_grad) override; - - Algorithm* get_algorithm_heuristic( - const TensorLayout& im, const TensorLayout& filter, - const TensorLayout& offset, const TensorLayout& mask, - const TensorLayout& out_grad, const TensorLayout& im_grad, - const TensorLayout& offset_grad, const TensorLayout& mask_grad, - size_t workspace_limit_in_bytes, bool reproducible) override; - Algorithm* get_algorithm_heuristic( const TensorLayout& im, const CanonizedFilterMeta& filter, const TensorLayout& offset, const TensorLayout& mask, @@ -138,11 +134,14 @@ public: const TensorLayout& offset_grad, const TensorLayout& mask_grad, size_t workspace_limit_in_bytes, bool reproducible); - size_t get_workspace_in_bytes( - const TensorLayout& im, const TensorLayout& filter, - const TensorLayout& offset, const TensorLayout& mask, - const TensorLayout& out_grad, const TensorLayout& im_grad, - const TensorLayout& offset_grad, const TensorLayout& mask_grad) override; + size_t get_workspace_in_bytes(const TensorLayout& im, + const TensorLayout& filter, + const TensorLayout& offset, + const TensorLayout& mask, + const TensorLayout& out_grad, + const TensorLayout& im_grad, + const TensorLayout& offset_grad, + const TensorLayout& mask_grad) override; const char* get_algorithm_set_name() const override; @@ -152,6 +151,22 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + +protected: + std::vector get_all_algorithms( + const TensorLayout& im, const TensorLayout& filter, + const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& out_grad, const TensorLayout& im_grad, + const TensorLayout& offset_grad, + const TensorLayout& mask_grad) override; + + Algorithm* get_algorithm_heuristic( + const TensorLayout& im, const TensorLayout& filter, + const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& out_grad, const TensorLayout& im_grad, + const TensorLayout& 
offset_grad, const TensorLayout& mask_grad, + size_t workspace_limit_in_bytes, bool reproducible) override; private: static AlgoPack sm_algo_pack; diff --git a/dnn/src/cuda/local_share/backward_data/algo.cpp b/dnn/src/cuda/local_share/backward_data/algo.cpp index 0e3f26b8..ec4afd49 100644 --- a/dnn/src/cuda/local_share/backward_data/algo.cpp +++ b/dnn/src/cuda/local_share/backward_data/algo.cpp @@ -18,8 +18,14 @@ using namespace cuda; LocalShareBackwardDataImpl::AlgoPack::AlgoPack() { all_algos.push_back(&implicit_gemm); all_algos.push_back(&batched_matmul); + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } +MEGDNN_DEF_GET_ALGO_FROM_DESC(LocalShareBackwardDataImpl) + LocalShareBackwardDataImpl::AlgoPack LocalShareBackwardDataImpl::sm_algo_pack; LocalShareBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs( diff --git a/dnn/src/cuda/local_share/backward_data/algo.h b/dnn/src/cuda/local_share/backward_data/algo.h index 66a954ce..e4a62b38 100644 --- a/dnn/src/cuda/local_share/backward_data/algo.h +++ b/dnn/src/cuda/local_share/backward_data/algo.h @@ -13,10 +13,14 @@ #include "megdnn/oprs.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" #include "src/common/utils.h" #include "src/cuda/handle.h" #include "src/cuda/local_share/opr_impl.h" +#include + namespace megdnn { namespace cuda { @@ -25,6 +29,13 @@ protected: ~AlgoBase() = default; public: + enum class AlgoType : uint32_t { + CUDA_IMPLICIT_GEMM, + CUDA_BATCHED_MATMUL, + }; + using Mapper = std::unordered_map; + + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { LocalShareBackwardDataImpl* opr; @@ -77,6 +88,7 @@ public: const char* name() const override { return "LOCAL_SHARE_IMPLICIT_GEMM"; } + MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM) }; class LocalShareBackwardDataImpl::AlgoBatchedMatMul final @@ -93,11 +105,11 @@ public: const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } + 
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) }; -class LocalShareBackwardDataImpl::AlgoPack { - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator=(const AlgoPack&) = delete; +class LocalShareBackwardDataImpl::AlgoPack : NonCopyableObj { + AlgoBase::Mapper m_all_algos_map; public: AlgoPack(); @@ -106,6 +118,7 @@ public: AlgoBatchedMatMul batched_matmul; std::vector all_algos; + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; } // namespace cuda diff --git a/dnn/src/cuda/local_share/backward_filter/algo.cpp b/dnn/src/cuda/local_share/backward_filter/algo.cpp index 0513aeee..f5350ba0 100644 --- a/dnn/src/cuda/local_share/backward_filter/algo.cpp +++ b/dnn/src/cuda/local_share/backward_filter/algo.cpp @@ -18,8 +18,14 @@ using namespace cuda; LocalShareBackwardFilterImpl::AlgoPack::AlgoPack() { all_algos.push_back(&implicit_gemm); all_algos.push_back(&batched_matmul); + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } +MEGDNN_DEF_GET_ALGO_FROM_DESC(LocalShareBackwardFilterImpl) + LocalShareBackwardFilterImpl::AlgoPack LocalShareBackwardFilterImpl::sm_algo_pack; LocalShareBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( diff --git a/dnn/src/cuda/local_share/backward_filter/algo.h b/dnn/src/cuda/local_share/backward_filter/algo.h index cf916e78..83190674 100644 --- a/dnn/src/cuda/local_share/backward_filter/algo.h +++ b/dnn/src/cuda/local_share/backward_filter/algo.h @@ -13,10 +13,14 @@ #include "megdnn/oprs.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" #include "src/common/utils.h" #include "src/cuda/handle.h" #include "src/cuda/local_share/opr_impl.h" +#include + namespace megdnn { namespace cuda { @@ -25,6 +29,12 @@ protected: ~AlgoBase() = default; public: + enum class AlgoType : uint32_t { + CUDA_IMPLICIT_GEMM, + CUDA_BATCHED_MATMUL, + }; + using Mapper = std::unordered_map; + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs 
{ LocalShareBackwardFilterImpl* opr; @@ -75,6 +85,7 @@ public: bool is_reproducible() const override { return true; } const char* name() const override { return "LOCAL_SHARE_IMPLICIT_GEMM"; } + MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM) }; class LocalShareBackwardFilterImpl::AlgoBatchedMatMul final : public AlgoBase { @@ -88,11 +99,11 @@ public: bool is_reproducible() const override { return true; } const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } + MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) }; -class LocalShareBackwardFilterImpl::AlgoPack { - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator=(const AlgoPack&) = delete; +class LocalShareBackwardFilterImpl::AlgoPack : NonCopyableObj { + AlgoBase::Mapper m_all_algos_map; public: AlgoPack(); @@ -101,6 +112,8 @@ public: AlgoBatchedMatMul batched_matmul; std::vector all_algos; + + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; } // namespace cuda diff --git a/dnn/src/cuda/local_share/forward/algo.cpp b/dnn/src/cuda/local_share/forward/algo.cpp index 67c13eb7..3f8b0180 100644 --- a/dnn/src/cuda/local_share/forward/algo.cpp +++ b/dnn/src/cuda/local_share/forward/algo.cpp @@ -19,8 +19,14 @@ LocalShareForwardImpl::AlgoPack::AlgoPack() { all_algos.push_back(&batch_size_aware_chwn_small_image); all_algos.push_back(&batch_size_aware_chwn); all_algos.push_back(&batched_matmul); + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } +MEGDNN_DEF_GET_ALGO_FROM_DESC(LocalShareForwardImpl) + LocalShareForwardImpl::AlgoPack LocalShareForwardImpl::sm_algo_pack; LocalShareForwardImpl::AlgoBase::SizeArgs::SizeArgs(LocalShareForwardImpl* o, diff --git a/dnn/src/cuda/local_share/forward/algo.h b/dnn/src/cuda/local_share/forward/algo.h index a82be4f4..0063f2d5 100644 --- a/dnn/src/cuda/local_share/forward/algo.h +++ b/dnn/src/cuda/local_share/forward/algo.h @@ -14,9 +14,13 @@ #include "megdnn/oprs.h" #include "src/common/utils.h" +#include 
"src/common/algo_base.h" +#include "src/common/metahelper.h" #include "src/cuda/handle.h" #include "src/cuda/local_share/opr_impl.h" +#include + namespace megdnn { namespace cuda { @@ -25,6 +29,13 @@ protected: ~AlgoBase() = default; public: + enum class AlgoType : uint32_t { + CUDA_CHWN_BATCH_SIZE_AWARE, + CUDA_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE, + CUDA_BATCHED_MATMUL + }; + using Mapper = std::unordered_map; + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { LocalShareForwardImpl* opr; @@ -79,6 +90,7 @@ public: const char* name() const override { return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE"; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CHWN_BATCH_SIZE_AWARE) }; class LocalShareForwardImpl::AlgoCHWNBatchSizeAwareSmallImage final @@ -95,6 +107,7 @@ public: const char* name() const override { return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE"; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE) }; class LocalShareForwardImpl::AlgoBatchedMatMul final : public AlgoBase { @@ -108,11 +121,11 @@ public: bool is_reproducible() const override { return true; } const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } + MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) }; -class LocalShareForwardImpl::AlgoPack { - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator=(const AlgoPack&) = delete; +class LocalShareForwardImpl::AlgoPack : NonCopyableObj { + AlgoBase::Mapper m_all_algos_map; public: AlgoPack(); @@ -122,6 +135,7 @@ public: AlgoBatchedMatMul batched_matmul; std::vector all_algos; + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; } // namespace cuda diff --git a/dnn/src/cuda/local_share/opr_impl.h b/dnn/src/cuda/local_share/opr_impl.h index 76aba387..4d404155 100644 --- a/dnn/src/cuda/local_share/opr_impl.h +++ b/dnn/src/cuda/local_share/opr_impl.h @@ -23,14 +23,6 @@ public: size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) 
override; - std::vector get_all_algorithms( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst, - size_t workspace_limit_in_bytes, - bool reproducible) override; const char* get_algorithm_set_name() const override; class AlgoBase; @@ -41,7 +33,17 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); +protected: + std::vector get_all_algorithms( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& dst, + size_t workspace_limit_in_bytes, + bool reproducible) override; private: static AlgoPack sm_algo_pack; }; @@ -54,14 +56,6 @@ public: size_t get_workspace_in_bytes(const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad) override; - std::vector get_all_algorithms( - const TensorLayout& filter, const TensorLayout& diff, - const TensorLayout& grad) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_limit_in_bytes, - bool reproducible) override; const char* get_algorithm_set_name() const override; class AlgoBase; @@ -71,6 +65,17 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + +protected: + std::vector get_all_algorithms( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& filter, + const TensorLayout& diff, + const TensorLayout& grad, + size_t workspace_limit_in_bytes, + bool reproducible) override; private: static AlgoPack sm_algo_pack; @@ 
-84,14 +89,6 @@ public: size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) override; - std::vector get_all_algorithms( - const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_limit_in_bytes, - bool reproducible) override; const char* get_algorithm_set_name() const override; class AlgoBase; @@ -101,6 +98,17 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + +protected: + std::vector get_all_algorithms( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const TensorLayout& diff, + const TensorLayout& grad, + size_t workspace_limit_in_bytes, + bool reproducible) override; private: static AlgoPack sm_algo_pack; diff --git a/dnn/src/cuda/matrix_mul/algos.cpp b/dnn/src/cuda/matrix_mul/algos.cpp index 38598335..85e5d47f 100644 --- a/dnn/src/cuda/matrix_mul/algos.cpp +++ b/dnn/src/cuda/matrix_mul/algos.cpp @@ -11,6 +11,7 @@ #include "./algos.h" #include "src/cuda/utils.h" +#include "src/common/algo_base.h" #include #if CUDA_VERSION >= 10010 @@ -33,10 +34,16 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { cublas_bfloat16 = std::make_unique(&cublas); all_algos.push_back(cublas_bfloat16.get()); #endif + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack; +MEGDNN_DEF_GET_ALGO_FROM_DESC(MatrixMulForwardImpl) + MatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(MatrixMulForwardImpl* o, const TensorLayout& A, const TensorLayout& B, @@ -67,4 +74,5 @@ std::string MatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { m, k, k, n, m, n, 
param.transposeA, param.transposeB, layout_a.stride[0], layout_b.stride[0], layout_c.stride[0])); } + // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/matrix_mul/algos.h b/dnn/src/cuda/matrix_mul/algos.h index e52f628c..5e5def36 100644 --- a/dnn/src/cuda/matrix_mul/algos.h +++ b/dnn/src/cuda/matrix_mul/algos.h @@ -6,14 +6,18 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include "megdnn/oprs.h" #include "src/common/utils.h" #include "src/cuda/matrix_mul/opr_impl.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" +#include #include #include #if CUDA_VERSION >= 10010 @@ -32,6 +36,15 @@ protected: ~AlgoBase() = default; public: + enum class AlgoType : uint32_t { + CUDA_CUBLAS, + CUDA_WMMA_UINT4X4X32, + CUDA_CUBLASLT, + CUDA_NAIVE, + CUDA_BFLOAT16 + }; + using Mapper = std::unordered_map; + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { MatrixMulForwardImpl* opr; @@ -62,12 +75,12 @@ public: virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; virtual void exec(const ExecArgs& args) const = 0; - bool is_available_wk(const SizeArgs& args, size_t limit) { + bool is_available_wk(const SizeArgs& args, size_t limit) const { return is_available(args) && get_workspace_in_bytes(args) <= limit; } bool is_available_reproducible( const SizeArgs& args, bool reproducible = true, - size_t limit = std::numeric_limits::max()) { + size_t limit = std::numeric_limits::max()) const { return (!reproducible || is_reproducible()) && is_available_wk(args, limit); } @@ -80,8 +93,6 @@ public: name(), req, workspace.size); return *this; } - - }; class MatrixMulForwardImpl::AlgoCuBlas final : public AlgoBase { @@ -91,13 +102,10 @@ 
public: size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override { return 0_z; } - const char* name() const override { - return "CUBLAS"; - } + const char* name() const override { return "CUBLAS"; } void exec(const ExecArgs& args) const override; - bool is_reproducible() const override { - return true; - } + bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) }; #if CUDA_VERSION >= 10000 @@ -106,13 +114,10 @@ public: AlgoUInt4x4x32WMMA() = default; bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; - const char* name() const override { - return "UINT4x4x32_WMMA"; - } + const char* name() const override { return "UINT4x4x32_WMMA"; } void exec(const ExecArgs& args) const override; - bool is_reproducible() const override { - return true; - } + bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32) }; #endif #if CUDA_VERSION >= 10010 @@ -120,13 +125,10 @@ class MatrixMulForwardImpl::AlgoCuBlasLt final : public AlgoBase { public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; - const char* name() const override { - return "CUBLAS_LT"; - } + const char* name() const override { return "CUBLAS_LT"; } void exec(const ExecArgs& args) const override; - bool is_reproducible() const override { - return true; - } + bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) }; #endif @@ -140,6 +142,7 @@ public: const char* name() const override { return "NAIVE"; } void exec(const ExecArgs& args) const override; bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE) }; #if !MEGDNN_DISABLE_FLOAT16 @@ -151,6 +154,13 @@ public: const char* name() const override { return m_name.c_str(); } void exec(const ExecArgs& args) const override; bool is_reproducible() const 
override { return true; } + MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE) + + std::string param() const override { + std::string ret; + serialize_write_pod(m_algorithm, ret); + return ret; + } private: MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr; @@ -160,9 +170,9 @@ private: }; #endif -class MatrixMulForwardImpl::AlgoPack { - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator=(const AlgoPack&) = delete; +class MatrixMulForwardImpl::AlgoPack : NonCopyableObj { +private: + AlgoBase::Mapper m_all_algos_map; public: AlgoPack(); @@ -178,6 +188,8 @@ public: std::unique_ptr cublas_bfloat16; #endif std::vector all_algos; + + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; } // namespace cuda diff --git a/dnn/src/cuda/matrix_mul/bfloat16.cpp b/dnn/src/cuda/matrix_mul/bfloat16.cpp index 7d97c21e..f635dc6a 100644 --- a/dnn/src/cuda/matrix_mul/bfloat16.cpp +++ b/dnn/src/cuda/matrix_mul/bfloat16.cpp @@ -82,7 +82,7 @@ void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const { args.opr->handle()->create_operator(); matmul_opr->param() = args.opr->param(); matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT; - matmul_opr->execution_policy() = {m_algorithm}; + matmul_opr->execution_policy() = {m_algorithm->info()}; matmul_opr->exec(a, b, c, ctypecvt.workspace()); } ctypecvt.comp_to_dst_type(c, args.tensor_c); diff --git a/dnn/src/cuda/matrix_mul/opr_impl.h b/dnn/src/cuda/matrix_mul/opr_impl.h index cc75bd9d..4c8f3d2b 100644 --- a/dnn/src/cuda/matrix_mul/opr_impl.h +++ b/dnn/src/cuda/matrix_mul/opr_impl.h @@ -25,15 +25,6 @@ public: bool is_thread_safe() const override { return true; } - std::vector get_all_algorithms(const TensorLayout& A, - const TensorLayout& B, - const TensorLayout& C) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& A, - const TensorLayout& B, - const TensorLayout& C, - size_t workspace_limit_in_bytes, - bool reproducible) override; - const char* get_algorithm_set_name() const 
override { return "CUDA MATMUL"; } @@ -55,6 +46,17 @@ public: static const AlgoPack& algo_pack() { return sm_algo_pack; } + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + +protected: + std::vector get_all_algorithms(const TensorLayout& A, + const TensorLayout& B, + const TensorLayout& C) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& A, + const TensorLayout& B, + const TensorLayout& C, + size_t workspace_limit_in_bytes, + bool reproducible) override; private: static AlgoPack sm_algo_pack; diff --git a/dnn/src/fallback/conv_bias/algos.cpp b/dnn/src/fallback/conv_bias/algos.cpp index 80100c6a..af26286f 100644 --- a/dnn/src/fallback/conv_bias/algos.cpp +++ b/dnn/src/fallback/conv_bias/algos.cpp @@ -10,10 +10,14 @@ */ #include "src/fallback/conv_bias/algos.h" +#include "src/fallback/conv_bias/conv1x1/algos.h" +#include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h" +#include "src/fallback/conv_bias/im2col/algos.h" #include "megdnn/opr_param_defs.h" #include "src/common/opr_delegate.h" #include "src/fallback/conv_bias/winograd/strategy.h" #include "src/naive/convolution/helper.h" +#include "src/common/algo_base.h" #include "midout.h" @@ -176,6 +180,7 @@ void kern_default(const ConvBiasImpl::NCBKernParam& p) { } // namespace MIDOUT_DECL(megdnn_fallback_naive) + /* ======================= AlgoNaive ======================== */ bool ConvBiasImpl::AlgoNaive::usable( diff --git a/dnn/src/fallback/conv_bias/algos.h b/dnn/src/fallback/conv_bias/algos.h index e70959bb..b92fded7 100644 --- a/dnn/src/fallback/conv_bias/algos.h +++ b/dnn/src/fallback/conv_bias/algos.h @@ -36,6 +36,7 @@ public: static_cast(AlgoDataType::QUINT8X8X32)); return {support_data_type, AlgoCategory::NAIVE}; } + MEGDNN_DECL_ALGO_TYPE(FB_NAIVE) }; class ConvBiasImpl::AlgoWinogradF32 final : public AlgoBase { @@ -59,6 +60,12 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; } + 
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_F32) + std::string param() const override { + std::string ret; + serialize_write_pod(m_matmul_algo, ret); + return ret; + } private: MatrixMulImpl::AlgoBase* m_matmul_algo; @@ -87,6 +94,12 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; } + MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_4X4_F32) + std::string param() const override { + std::string ret; + serialize_write_pod(m_matmul_algo, ret); + return ret; + } private: MatrixMulImpl::AlgoBase* m_matmul_algo; @@ -115,6 +128,12 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; } + MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_QS8) + std::string param() const override { + std::string ret; + serialize_write_pod(m_matmul_algo, ret); + return ret; + } private: MatrixMulImpl::AlgoBase* m_matmul_algo; @@ -143,6 +162,12 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; } + MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_8X8_QS8) + std::string param() const override { + std::string ret; + serialize_write_pod(m_matmul_algo, ret); + return ret; + } private: MatrixMulImpl::AlgoBase* m_matmul_algo; diff --git a/dnn/src/fallback/conv_bias/common.h b/dnn/src/fallback/conv_bias/common.h index fb9caf2e..7b33c6ba 100644 --- a/dnn/src/fallback/conv_bias/common.h +++ b/dnn/src/fallback/conv_bias/common.h @@ -156,6 +156,12 @@ using BiasMode = ConvBiasForward::BiasMode; ConvAlgoTypePack get_algo_type() const override { \ return {_algo_data_type, AlgoCategory::WINOGRAD}; \ } \ + std::string param() const override { \ + std::string ret; \ + serialize_write_pod(m_matmul_algo, ret); \ + serialize_write_pod(m_tile_size, ret); \ + return ret; \ + } \ \ private: \ fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; \ diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos.h b/dnn/src/fallback/conv_bias/conv1x1/algos.h index 6c3bc4ef..662812bc 
100644 --- a/dnn/src/fallback/conv_bias/conv1x1/algos.h +++ b/dnn/src/fallback/conv_bias/conv1x1/algos.h @@ -60,6 +60,13 @@ public: return {m_matmul_algo->matmul_description().algo_type.data_type, AlgoCategory::IM2COL}; } + MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_8X8_QS8) + std::string param() const override { + std::string ret; + serialize_write_pod(m_matmul_algo, ret); + serialize_write_pod(m_oc_block_size, ret); + return ret; + } protected: size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h b/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h index b56bb138..f7250f7c 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h +++ b/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h @@ -43,6 +43,7 @@ public: static_cast(AlgoDataType::QUINT8X8X32)); return {support_data_type, AlgoCategory::IM2COL}; } + MEGDNN_DECL_ALGO_TYPE(FB_CONV1x1_GEMV) protected: size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; diff --git a/dnn/src/fallback/conv_bias/im2col/algos.h b/dnn/src/fallback/conv_bias/im2col/algos.h index 919ae250..c7d0d4c9 100644 --- a/dnn/src/fallback/conv_bias/im2col/algos.h +++ b/dnn/src/fallback/conv_bias/im2col/algos.h @@ -68,6 +68,14 @@ public: return {m_matmul_algo->matmul_description().algo_type.data_type, AlgoCategory::IM2COL}; } + MEGDNN_DECL_ALGO_TYPE(FB_IM2COL) + + std::string param() const override { + std::string ret; + serialize_write_pod(m_matmul_algo, ret); + serialize_write_pod(m_ohw_tile_size, ret); + return ret; + } private: MatrixMulImpl::AlgoBase* m_matmul_algo; diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index efefbfea..b0b5e2d9 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -22,6 +22,14 @@ #include "src/naive/convolution/algorithms.h" #include "src/naive/handle.h" +#if MEGDNN_X86 +#include 
"src/x86/conv_bias/opr_impl.h" +#elif MEGDNN_AARCH64 +#include "src/aarch64/conv_bias/opr_impl.h" +#elif MEGDNN_ARMV7 +#include "src/armv7/conv_bias/opr_impl.h" +#endif + #include using namespace megdnn; @@ -65,17 +73,19 @@ void incr_ptr(T*& dst, ptrdiff_t delta) { class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoNaive algo_naive; SmallVector> refhold; + SmallVector m_all_algos; + AlgoBase::Mapper m_all_algos_map; public: AlgoPack() { refhold.emplace_back(new AlgoConv1x1Gemv()); - all_algos.emplace_back(refhold.back().get()); + m_all_algos.emplace_back(refhold.back().get()); static CpuOprDelegationStorage<> storage; auto matmul_opr = storage.get(); - auto&& matmul_algos = - static_cast(matmul_opr)->algo_pack(); + auto&& matmul_algos = static_cast(matmul_opr) + ->get_all_packed_algo(); for (auto&& algo : matmul_algos) { #if MEGDNN_X86 //! As we haven't direct conv for int8x8x16 yet, if we disable gemv here, it may @@ -97,13 +107,13 @@ public: refhold.emplace_back(new AlgoIm2col( static_cast(algo), ohw_tile_size)); - all_algos.emplace_back(refhold.back().get()); + m_all_algos.emplace_back(refhold.back().get()); } for (size_t oc_tile_size : {48, 24}) { refhold.emplace_back(new AlgoConv1x1( static_cast(algo), oc_tile_size)); - all_algos.emplace_back(refhold.back().get()); + m_all_algos.emplace_back(refhold.back().get()); } #endif @@ -113,26 +123,35 @@ public: //! FIXME: I do not know a better way to do it. 
refhold.emplace_back(new AlgoWinogradF32( static_cast(algo))); - all_algos.emplace_back(refhold.back().get()); + m_all_algos.emplace_back(refhold.back().get()); refhold.emplace_back(new AlgoWinogradF32_4x4( static_cast(algo))); - all_algos.emplace_back(refhold.back().get()); + m_all_algos.emplace_back(refhold.back().get()); refhold.emplace_back(new AlgoWinogradQS8( static_cast(algo))); - all_algos.emplace_back(refhold.back().get()); + m_all_algos.emplace_back(refhold.back().get()); refhold.emplace_back(new AlgoWinogradQS8_8x8( static_cast(algo))); - all_algos.emplace_back(refhold.back().get()); + m_all_algos.emplace_back(refhold.back().get()); #endif } - all_algos.emplace_back(&algo_naive); + m_all_algos.emplace_back(&algo_naive); + + for (auto&& algo : m_all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } - SmallVector all_algos; + const SmallVector& all_algos() const { return m_all_algos; } + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -SmallVector ConvBiasImpl::algo_pack() { - static AlgoPack sl_algo_pack; - return sl_algo_pack.all_algos; +const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { + static AlgoPack algo_pack; + return algo_pack; +} + +SmallVector ConvBiasImpl::get_all_packed_algo() { + return algo_pack().all_algos(); } SmallVector ConvBiasImpl::select_algo_type( @@ -140,7 +159,7 @@ SmallVector ConvBiasImpl::select_algo_type( megdnn_assert(nr_type_contain(target_type.data_type), "ConvBias algo selection only support one type"); SmallVector algos; - for (auto&& algo : algo_pack()) { + for (auto&& algo : get_all_packed_algo()) { auto algo_type = algo->get_algo_type(); if (contain_data_type(algo_type.data_type, target_type.data_type) && algo_type.algo_category == target_type.algo_category) { @@ -166,7 +185,7 @@ void ConvBiasImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, workspace.size, preprocessed_filter); auto fparam = make_ncb_kern_param(src, filter, bias, dst, workspace, 
preprocessed_filter); - ConvBiasImpl::Algorithm* algo = get_algorithm(fparam, workspace.size); + auto&& algo = get_algorithm(fparam, workspace.size); if (!is_naive_algo(algo) && NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) { exec_with_ncb_kern(fparam, algo); @@ -189,9 +208,10 @@ void ConvBiasImpl::exec_preprocess(const TensorLayout& src_layout, auto fparam = make_ncb_kern_param(src, filter, bias, dst, workspace, preprocessed_filter); //! should not pass workspace_size limit otherwise can not find match algo - ConvBiasImpl::Algorithm* algo = get_algorithm(fparam); - if (!is_naive_algo(algo) && NCB_ALGO_FUNC(get_preprocess_workspace, algo, - fparam) <= workspace.size) { + auto&& algo = get_algorithm(fparam); + if (!is_naive_algo(algo) && + NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam) <= + workspace.size) { exec_preprocess_with_ncb_kern(fparam, algo); } else { naive::ConvBiasForwardImpl::exec_preprocess( @@ -207,7 +227,7 @@ size_t ConvBiasImpl::get_workspace_in_bytes( const PreprocessedFilter* preprocessed_filter) { auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, preprocessed_filter); - ConvBiasImpl::Algorithm* algo = get_algorithm(fparam); + auto&& algo = get_algorithm(fparam); if (is_naive_algo(algo)) { return naive::ConvBiasForwardImpl::get_workspace_in_bytes( src, filter, bias, z, dst, preprocessed_filter); @@ -221,7 +241,7 @@ size_t ConvBiasImpl::get_preprocess_workspace_in_bytes( const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst) { auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); - Algorithm* algo = get_algorithm(fparam); + auto&& algo = get_algorithm(fparam); if (is_naive_algo(algo)) { return naive::ConvBiasForwardImpl::get_preprocess_workspace_in_bytes( src, filter, bias, z, dst); @@ -235,7 +255,7 @@ SmallVector ConvBiasImpl::deduce_preprocessed_filter_layout( const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst) { auto fparam = 
make_ncb_kern_size_param(src, filter, bias, dst, nullptr); - Algorithm* algo = get_algorithm(fparam); + auto&& algo = get_algorithm(fparam); if (is_naive_algo(algo)) { return naive::ConvBiasForwardImpl::deduce_preprocessed_filter_layout( src, filter, bias, z, dst); @@ -443,7 +463,7 @@ std::vector ConvBiasImpl::get_all_algorithms_with_ncb( MEGDNN_MARK_USED_VAR(param); std::vector algos; std::vector prefer_algos; - for (auto&& algo : algo_pack()) { + for (auto&& algo : get_all_packed_algo()) { if (algo->usable(param, AlgoSelectionStrategy::FULL_RUN)) { if (algo->is_preferred(param)) { prefer_algos.push_back(algo); @@ -457,10 +477,49 @@ std::vector ConvBiasImpl::get_all_algorithms_with_ncb( return algos; } +ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc( + const AlgorithmDesc& desc) const { + if (!desc.valid()) { + return nullptr; + } else { + switch (desc.handle_type) { + case Handle::HandleType::FALLBACK: { + const auto& map = algo_pack().all_algos_map(); + megdnn_assert(map.find(desc) != map.end()); + return map.at(desc); + }; + +#if MEGDNN_X86 + case Handle::HandleType::X86: + return x86::ConvBiasImpl::get_algo_from_desc(desc); +#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 + case Handle::HandleType::ARM_COMMON: + return arm_common::ConvBiasImpl::get_algo_from_desc(desc); +#if MEGDNN_AARCH64 + case Handle::HandleType::AARCH64: + return aarch64::ConvBiasImpl::get_algo_from_desc(desc); +#else + case Handle::HandleType::ARMV7: + return armv7::ConvBiasImpl::get_algo_from_desc(desc); +#endif +#endif + case Handle::HandleType::NAIVE: { + auto algo = static_cast(handle()) + ->default_conv_bias_fwd_algo(); + megdnn_assert(algo->info().desc == desc); + return algo; + } + default: + megdnn_throw("Unknown handle type"); + return nullptr; + } + } +} + ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( const NCBKernSizeParam& param, size_t workspace_size) { - if (auto set = execution_policy().algorithm) { - return set; + if (auto algo = 
get_algo_from_desc(execution_policy().algo.desc)) { + return algo; } if (!m_prev_selected_algo || memcmp(&m_prev_selected_algo_sizep, ¶m, sizeof(NCBKernSizeParam))) { diff --git a/dnn/src/fallback/conv_bias/opr_impl.h b/dnn/src/fallback/conv_bias/opr_impl.h index 34318cb9..7ac49b0b 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.h +++ b/dnn/src/fallback/conv_bias/opr_impl.h @@ -216,6 +216,86 @@ public: AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::FALLBACK; } + + enum class AlgoType : uint32_t { + //! fallback + FB_NAIVE = 1 << 0, + FB_WINOGRAD_F32, + FB_WINOGRAD_4X4_F32, + FB_WINOGRAD_QS8, + FB_WINOGRAD_8X8_QS8, + FB_CONV1x1, + FB_CONV1x1_GEMV, + FB_IM2COL, + +#if MEGDNN_X86 + X86_DIRECT = 1 << 8, + X86_DIRECT_STRD2, + X86_WINOGRAD_F63_8x8_F32, + X86_WINOGRAD_F23_8x8_F32, + X86_MKLDNN, + X86_CHANWISE_AVX2_STRD1_QINT8, + X86_CHANWISE_AVX2_STRD2_QINT8, + X86_DIRECT_AVX2_STRD1_INT8, + X86_DIRECT_AVX2_STRD2_INT8, + X86_MKLDNN_QINT8, + X86_MKLDNN_MATMUL_QINT8, +#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 + ARM_COMMON_WINOGRAD_F23_FP16 = 1 << 8, + ARM_COMMON_WINOGRAD_F45_FP16, + ARM_COMMON_WINOGRAD_F63_FP16, + ARM_COMMON_WINOGRAD_F23_8X8_FP16, + ARM_COMMON_DIRECT_FP16, + ARM_COMMON_DIRECT_STRD1_FP16, + ARM_COMMON_WINOGRAD_F23_4X4_FP32, + ARM_COMMON_WINOGRAD_F63_FP32, + ARM_COMMON_WINOGRAD_F63_4X4_FP32, + ARM_COMMON_WINOGRAD_F54_FP32, + ARM_COMMON_WINOGRAD_F45_FP32, + ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32, + ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32, + ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32, + ARM_COMMON_DIRECT_FP32, + ARM_COMMON_DIRECT_STRD1_FP32, + ARM_COMMON_DIRECT_STRD2_FP32, + ARM_COMMON_DIRECT_NCHW44_FP32, + ARM_COMMON_DIRECT_NCHW_NCHW44_FP32, + ARM_COMMON_CHWNWISE_NCHW44_F32, + ARM_COMMON_DIRECT_STRD1_S8, + ARM_COMMON_DIRECT_STRD2_S8, + ARM_COMMON_DIRECT_NCHW44, + ARM_COMMON_DIRECT_NCHW_NCHW44_S8, + ARM_COMMON_CHANWISE_STRD1_NCHW44_S8, + ARM_COMMON_CHANWISE_STRD2_NCHW44_S8, + ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8, + ARM_COMMON_DIRECT_STRD1_DOT_S8, + 
ARM_COMMON_DIRECT_STRD2_DOT_S8, + ARM_COMMON_DIRECT_NCHW44_DOT_S8, + ARM_COMMON_WINOGRAD_F23_8X8_S8, + ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32, + ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8, + ARM_COMMON_DIRECT_INT8X8X16, + ARM_COMMON_DIRECT_NCHW44_INT8X8X16, + ARM_COMMON_DIRECT_STRD2_INT8X8X16, + ARM_COMMON_DIRECT_STRD2_F2_INT8X8X16, + ARM_COMMON_CHWNWISE_STRD1_STRD2_NCHW44_INT8X8X16, + ARM_COMMON_DIRECT_NCHW_NCHW44_INT8X8X16, + ARM_COMMON_DIRECT_STRD1_QU8, + ARM_COMMON_DIRECT_STRD2_QU8, + ARM_COMMON_DIRECT_STRD1_DOT_QU8, + ARM_COMMON_DIRECT_STRD2_DOT_QU8, +#if MEGDNN_AARCH64 + AARCH64_DIRECT_STRD2_FP16, + AARCH64_DIRECT_STRD2_FP32, + AARCH64_MATMUL_S8, + AARCH64_MATMUL_QU8, +#else + ARMV7_MATMUL_S8, + ARMV7_MATMUL_QU8, +#endif // MEGDNN_AARCH64 +#endif + }; + virtual ~AlgoBase() = default; virtual bool usable( const NCBKernSizeParam& param, @@ -255,12 +335,14 @@ public: //! get the type of the algo virtual ConvAlgoTypePack get_algo_type() const = 0; + using Mapper = std::unordered_map; }; + using AlgoMapper = AlgoBase::Mapper; /** * \brief get all the algorithm for the opr. */ - virtual SmallVector algo_pack(); + virtual SmallVector get_all_packed_algo(); /** * \brief select algo according to input algo type @@ -305,6 +387,8 @@ private: bool is_naive_algo(ConvBiasImpl::Algorithm* algo); + Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; + //! 
get algorithm set by user or by heuristic Algorithm* get_algorithm( const NCBKernSizeParam& param, @@ -320,6 +404,8 @@ private: _megdnn_tensor_in bias, _megdnn_tensor_out dst, _megdnn_workspace workspace, const PreprocessedFilter* preprocessed_filter); + + static const AlgoPack& algo_pack(); }; inline bool is_enable_filter_preprocess( diff --git a/dnn/src/fallback/convolution/algos.cpp b/dnn/src/fallback/convolution/algos.cpp index 9f39f62c..36cb709b 100644 --- a/dnn/src/fallback/convolution/algos.cpp +++ b/dnn/src/fallback/convolution/algos.cpp @@ -162,6 +162,7 @@ void kern_direct(const NCBKernParam& param) { } // namespace + /* ===================== fallback algo ===================== */ bool ConvolutionImpl::AlgoFallback::usable( @@ -461,7 +462,6 @@ SmallVector ConvolutionImpl::AlgoDefault::get_kimpl( } /////////////////////////// ConvolutionBackwardData ///////////////////// - /* ===================== naive algo ===================== */ bool ConvolutionBackwardDataImpl::AlgoNaive::usable( diff --git a/dnn/src/fallback/convolution/algos.h b/dnn/src/fallback/convolution/algos.h index b28ccf5d..216ec104 100644 --- a/dnn/src/fallback/convolution/algos.h +++ b/dnn/src/fallback/convolution/algos.h @@ -15,6 +15,7 @@ #include "src/fallback/conv_bias/algos.h" #include "src/fallback/convolution/opr_impl.h" #include "src/naive/convolution/helper.h" +#include "src/common/algo_chooser.h" namespace megdnn { namespace fallback { @@ -87,6 +88,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::NAIVE}; } + MEGDNN_DECL_ALGO_TYPE(FB_ALGO) }; class ConvolutionImpl::AlgoNaive final : public AlgoBase { @@ -108,6 +110,7 @@ public: static_cast(AlgoDataType::QUINT8X8X32)); return {support_data_type, AlgoCategory::NAIVE}; } + MEGDNN_DECL_ALGO_TYPE(FB_NAIVE) }; class ConvolutionImpl::AlgoDefault final : public AlgoBase { @@ -144,12 +147,19 @@ public: //! 
select matmul to the highest preference bool is_preferred(const NCBKernSizeParam& param) const override; + std::string param() const override { + std::string ret; + serialize_write_pod(m_algorithm, ret); + return ret; + } + static ConvBiasImpl::NCBKernSizeParam init_conv_bias_param( const NCBKernSizeParam& param); ConvAlgoTypePack get_algo_type() const override { return m_algorithm->get_algo_type(); } + MEGDNN_DECL_ALGO_TYPE(FB_DEFAULT) private: std::string m_name; @@ -168,6 +178,7 @@ public: ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override; bool is_naive() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(FB_NAIVE) }; class ConvolutionBackwardDataImpl::AlgoDirect final : public AlgoBase { @@ -180,6 +191,7 @@ public: const NCBKernSizeParam& param) const override; ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override; + MEGDNN_DECL_ALGO_TYPE(FB_DIRECT) }; class ConvolutionBackwardDataImpl::AlgoMatrixMul final : public AlgoBase { @@ -193,6 +205,7 @@ public: ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override; bool is_preferred(const NCBKernSizeParam& param) const override; + MEGDNN_DECL_ALGO_TYPE(FB_MATMUL) }; } // namespace fallback diff --git a/dnn/src/fallback/convolution/opr_impl.cpp b/dnn/src/fallback/convolution/opr_impl.cpp index d7b18914..ec7352ef 100644 --- a/dnn/src/fallback/convolution/opr_impl.cpp +++ b/dnn/src/fallback/convolution/opr_impl.cpp @@ -22,6 +22,10 @@ #include "midout.h" +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 +#include "src/arm_common/convolution/opr_impl.h" +#endif + #include #include @@ -31,39 +35,50 @@ using namespace megdnn; using namespace fallback; namespace { - template void incr_ptr(T*& dst, ptrdiff_t delta) { dst = reinterpret_cast(reinterpret_cast(dst) + delta); } + } // namespace class ConvolutionImpl::AlgoPack : NonCopyableObj { AlgoFallback algo_fallback; AlgoNaive algo_naive; SmallVector> refhold; - 
+ SmallVector m_all_algos; + AlgoBase::Mapper m_all_algos_map; public: AlgoPack() { static CpuOprDelegationStorage<1> storage; auto conv_bias_opr = storage.get(); auto&& conv_bias_algo = - static_cast(conv_bias_opr)->algo_pack(); + static_cast(conv_bias_opr)->get_all_packed_algo(); for (auto&& algorithm : conv_bias_algo) { // fallback algo refhold.emplace_back(new AlgoDefault(algorithm)); - all_algos.emplace_back(refhold.back().get()); + m_all_algos.emplace_back(refhold.back().get()); } - all_algos.emplace_back(&algo_fallback); - all_algos.emplace_back(&algo_naive); + m_all_algos.emplace_back(&algo_fallback); + m_all_algos.emplace_back(&algo_naive); + + for (auto&& algo : m_all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } - SmallVector all_algos; + + const SmallVector& all_algos() const { return m_all_algos; } + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -SmallVector ConvolutionImpl::algo_pack() { - static AlgoPack sl_algo_pack; - return sl_algo_pack.all_algos; +const ConvolutionImpl::AlgoPack& ConvolutionImpl::algo_pack() { + static AlgoPack algo_pack; + return algo_pack; +} + +SmallVector ConvolutionImpl::get_all_packed_algo() { + return algo_pack().all_algos(); } SmallVector ConvolutionImpl::select_algo_type( @@ -71,7 +86,7 @@ SmallVector ConvolutionImpl::select_algo_type( megdnn_assert(nr_type_contain(target_type.data_type), "ConvBias algo selection only support one type"); SmallVector algos; - for (auto&& algo : algo_pack()) { + for (auto&& algo : get_all_packed_algo()) { auto algo_type = algo->get_algo_type(); if (contain_data_type(algo_type.data_type, target_type.data_type) && algo_type.algo_category == target_type.algo_category) { @@ -94,7 +109,7 @@ void ConvolutionImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_workspace workspace) { auto fparam = make_ncb_kern_param(src, filter, dst, preprocessed_filter, workspace); - ConvolutionImpl::Algorithm* algo = get_algorithm(fparam, 
workspace.size); + auto&& algo = get_algorithm(fparam, workspace.size); if (!is_naive_algo(algo) && NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) { exec_with_ncb_kern(fparam, algo); @@ -116,9 +131,10 @@ void ConvolutionImpl::exec_preprocess(const TensorLayout& src_layout, workspace); //! should not pass workspace_size limit otherwise can not find match algo - ConvolutionImpl::Algorithm* algo = get_algorithm(fparam); - if (!is_naive_algo(algo) && NCB_ALGO_FUNC(get_preprocess_workspace, algo, - fparam) <= workspace.size) { + auto&& algo = get_algorithm(fparam); + if (!is_naive_algo(algo) && + NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam) <= + workspace.size) { exec_preprocess_with_ncb_kern(fparam, algo); } else { naive::ConvolutionForwardImpl::exec_preprocess( @@ -132,7 +148,7 @@ size_t ConvolutionImpl::get_workspace_in_bytes( const PreprocessedFilter* preprocessed_filter) { auto fparam = make_ncb_kern_size_param(src, filter, dst, preprocessed_filter); - Algorithm* algo = get_algorithm(fparam); + auto&& algo = get_algorithm(fparam); if (is_naive_algo(algo)) { return naive::ConvolutionForwardImpl::get_workspace_in_bytes( src, filter, dst, preprocessed_filter); @@ -145,7 +161,7 @@ size_t ConvolutionImpl::get_preprocess_workspace_in_bytes( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) { auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr); - Algorithm* algo = get_algorithm(fparam); + auto&& algo = get_algorithm(fparam); if (is_naive_algo(algo)) { return naive::ConvolutionForwardImpl::get_preprocess_workspace_in_bytes( src, filter, dst); @@ -158,7 +174,7 @@ SmallVector ConvolutionImpl::deduce_preprocessed_filter_layout( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) { auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr); - Algorithm* algo = get_algorithm(fparam); + auto&& algo = get_algorithm(fparam); if (is_naive_algo(algo)) { return 
naive::ConvolutionForwardImpl::deduce_preprocessed_filter_layout( src, filter, dst); @@ -333,7 +349,7 @@ std::vector ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) { std::vector ret; std::vector prefer_algos; - for (auto&& i : algo_pack()) { + for (auto&& i : get_all_packed_algo()) { if (i->usable(param, AlgoSelectionStrategy::FULL_RUN)) { if (i->is_preferred(param)) { prefer_algos.push_back(i); @@ -346,10 +362,34 @@ ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) { return ret; } +ConvolutionImpl::Algorithm* ConvolutionImpl::get_algo_from_desc( + const AlgorithmDesc& desc) const { + if (!desc.valid()) { + return nullptr; + } else { + switch (desc.handle_type) { + case Handle::HandleType::FALLBACK: { + const auto& map = algo_pack().all_algos_map(); + megdnn_assert(map.find(desc) != map.end()); + return map.at(desc); + } + case Handle::HandleType::NAIVE: { + auto algo = static_cast(handle()) + ->default_conv_fwd_algo(); + megdnn_assert(algo->info().desc == desc); + return algo; + } + default: + megdnn_throw("Unknown handle type"); + return nullptr; + } + } +} + ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm( const NCBKernSizeParam& param, size_t workspace_size) { - if (auto set = execution_policy().algorithm) { - return set; + if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { + return algo; } if (!m_prev_selected_algo || memcmp(&m_prev_selected_algo_sizep, ¶m, sizeof(NCBKernSizeParam))) { @@ -405,20 +445,31 @@ class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { AlgoNaive algo_naive; AlgoDirect algo_direct; AlgoMatrixMul algo_matmul; + SmallVector m_all_algos; + AlgoBase::Mapper m_all_algos_map; public: AlgoPack() { - all_algos.emplace_back(&algo_matmul); - all_algos.emplace_back(&algo_direct); - all_algos.emplace_back(&algo_naive); + m_all_algos.emplace_back(&algo_matmul); + m_all_algos.emplace_back(&algo_direct); + m_all_algos.emplace_back(&algo_naive); + + for (auto&& 
algo : m_all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } - SmallVector all_algos; + const SmallVector& all_algos() const { return m_all_algos; } + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; +const ConvolutionBackwardDataImpl::AlgoPack& +ConvolutionBackwardDataImpl::algo_pack() { + static AlgoPack algo_pack; + return algo_pack; +} SmallVector -ConvolutionBackwardDataImpl::algo_pack() { - static AlgoPack sl_algo_pack; - return sl_algo_pack.all_algos; +ConvolutionBackwardDataImpl::get_all_packed_algo() { + return algo_pack().all_algos(); } void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter, @@ -545,7 +596,7 @@ void ConvolutionBackwardDataImpl::exec_with_ncb_kern( auto p1g = param; auto group = p1g.filter_meta.group; p1g.filter_meta.group = 1; - auto algo = get_algorithm(p1g); + auto&& algo = get_algorithm(p1g); auto kptr = ncb_1g_dispatch_kern(algo, p1g); if (group == 1 || static_cast(algo)->is_naive()) { auto run = [kptr, param]() { kptr(param); }; @@ -597,9 +648,11 @@ size_t ConvolutionBackwardDataImpl::get_workspace_with_ncb( if (param.filter_meta.group != 1) { auto p1g = param; p1g.filter_meta.group = 1; - return ncb_1g_get_workspace(get_algorithm(p1g), p1g); + auto algo = get_algorithm(p1g); + return ncb_1g_get_workspace(algo, p1g); } - return ncb_1g_get_workspace(get_algorithm(param), param); + auto algo = get_algorithm(param); + return ncb_1g_get_workspace(algo, param); } std::vector @@ -664,7 +717,7 @@ ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms( const NCBKernSizeParam& param) { std::vector ret; std::vector prefer_algos; - for (auto&& i : algo_pack()) { + for (auto&& i : get_all_packed_algo()) { if (i->usable(this, param)) { if (i->is_preferred(param)) { prefer_algos.push_back(i); @@ -697,9 +750,42 @@ ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic( } ConvolutionBackwardDataImpl::Algorithm* +ConvolutionBackwardDataImpl::get_algo_from_desc( + const AlgorithmDesc& desc) 
const { + if (!desc.valid()) { + return nullptr; + } else { + switch (desc.handle_type) { + case Handle::HandleType::FALLBACK: { + const auto& map = algo_pack().all_algos_map(); + megdnn_assert(map.find(desc) != map.end()); + return map.at(desc); + } +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 + case Handle::HandleType::ARM_COMMON: + case Handle::HandleType::AARCH64: + case Handle::HandleType::ARMV7: + return arm_common::ConvolutionBackwardDataImpl:: + get_algo_from_desc(desc); +#endif + case Handle::HandleType::NAIVE: { + auto algo = static_cast(handle()) + ->default_conv_bwd_data_algo(); + megdnn_assert(algo->info().desc == desc); + return algo; + } + default: + megdnn_throw("Unknown handle type"); + return nullptr; + } + } +} + + +ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) { - if (auto set = execution_policy().algorithm) { - return set; + if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { + return algo; } if (!m_prev_selected_algo || memcmp(&m_prev_selected_algo_sizep, ¶m, sizeof(NCBKernSizeParam))) { diff --git a/dnn/src/fallback/convolution/opr_impl.h b/dnn/src/fallback/convolution/opr_impl.h index 7ad66242..59fd40ae 100644 --- a/dnn/src/fallback/convolution/opr_impl.h +++ b/dnn/src/fallback/convolution/opr_impl.h @@ -10,8 +10,11 @@ */ #pragma once +#include +#include #include "megdnn/oprs/base.h" #include "src/common/utils.h" +#include "src/common/algo_base.h" #include "src/fallback/handle.h" #include "src/naive/convolution/opr_impl.h" @@ -198,6 +201,14 @@ public: AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::FALLBACK; } + + enum class AlgoType : uint32_t { + //! fallback + FB_ALGO = 1 << 0, + FB_NAIVE, + FB_DEFAULT, + }; + virtual ~AlgoBase() = default; virtual bool usable(const NCBKernSizeParam& param, AlgoSelectionStrategy) const = 0; @@ -235,12 +246,13 @@ public: //! 
get the type of the algo virtual ConvAlgoTypePack get_algo_type() const = 0; + using Mapper = std::unordered_map; }; /** * \brief get all the algorithm for the opr. */ - virtual SmallVector algo_pack(); + virtual SmallVector get_all_packed_algo(); /** * \brief select algo according to input algo type @@ -268,11 +280,12 @@ protected: class AlgoPack; private: + NCBKernSizeParam m_prev_selected_algo_sizep; Algorithm* m_prev_selected_algo = nullptr; + Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; bool is_naive_algo(ConvolutionImpl::Algorithm* algo); - //! get algorithm set by user or by heuristic Algorithm* get_algorithm( const NCBKernSizeParam& param, size_t workspace_size = std::numeric_limits::max()); @@ -290,6 +303,9 @@ private: SmallVector suggest_algo_category_order( const NCBKernSizeParam& param) const; + +public: + static const AlgoPack& algo_pack(); }; class ConvolutionBackwardDataImpl : public naive::ConvolutionBackwardDataImpl { @@ -374,6 +390,49 @@ public: protected: using ncb_kern_t = thin_function; + class AlgoBase : public Algorithm { + protected: + ~AlgoBase() = default; + + public: + AlgoBase() : Algorithm() { + m_handle_type = Handle::HandleType::FALLBACK; + } + enum class AlgoType : uint32_t { + //! 
fallback + FB_NAIVE = 1 << 0, + FB_DIRECT, + FB_MATMUL, + +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 + ARM_COMMON_DIRECT_STRD1_DOT_INT8X8X32 = 1 << 8, + ARM_COMMON_DIRECT_STRD2_DOT_INT8X8X32, + ARM_COMMON_DIRECT_STRD1_DOT_QU8, + ARM_COMMON_DIRECT_STRD2_DOT_QU8 +#endif + }; + + virtual bool usable(ConvolutionBackwardDataImpl* opr, + const NCBKernSizeParam& param) const = 0; + virtual size_t get_workspace(ConvolutionBackwardDataImpl* opr, + const NCBKernSizeParam& param) const = 0; + virtual ncb_kern_t dispatch_kern( + ConvolutionBackwardDataImpl* opr, + const NCBKernSizeParam& param) const = 0; + bool usable_reproducible(ConvolutionBackwardDataImpl* opr, + const NCBKernSizeParam& param, + bool reproducible = true) const { + return (!reproducible || is_reproducible()) && usable(opr, param); + } + virtual bool is_preferred(const NCBKernSizeParam&) const { + return false; + } + //! if the algo is naive, it will not split by group + virtual bool is_naive() const { return false; } + using Mapper = std::unordered_map; + }; + +protected: //! 
default impl calls ncb_1g_dispatch_kern() virtual void exec_with_ncb_kern(const NCBKernParam& param); @@ -408,38 +467,11 @@ protected: const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, bool reproducible = false); - class AlgoBase : public Algorithm { - protected: - ~AlgoBase() = default; - - public: - AlgoBase() : Algorithm() { - m_handle_type = Handle::HandleType::FALLBACK; - } - virtual bool usable(ConvolutionBackwardDataImpl* opr, - const NCBKernSizeParam& param) const = 0; - virtual size_t get_workspace(ConvolutionBackwardDataImpl* opr, - const NCBKernSizeParam& param) const = 0; - virtual ncb_kern_t dispatch_kern( - ConvolutionBackwardDataImpl* opr, - const NCBKernSizeParam& param) const = 0; - bool usable_reproducible(ConvolutionBackwardDataImpl* opr, - const NCBKernSizeParam& param, - bool reproducible = true) const { - return (!reproducible || is_reproducible()) && usable(opr, param); - } - virtual bool is_preferred(const NCBKernSizeParam&) const { - return false; - } - //! if the algo is naive, it will not split by group - virtual bool is_naive() const { return false; } - }; - static bool is_matrix_mul_preferred(const NCBKernSizeParam& param); /** * \brief get all the algorithm for the opr. */ - virtual SmallVector algo_pack(); + virtual SmallVector get_all_packed_algo(); private: NCBKernSizeParam m_prev_selected_algo_sizep; @@ -461,6 +493,11 @@ private: class AlgoDirect; class AlgoMatrixMul; class AlgoPack; + Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; + +public: + //! 
maintain all the algos in the opr of fallback + static const AlgoPack& algo_pack(); +}; } // namespace fallback diff --git a/dnn/src/fallback/matrix_mul/algos.cpp b/dnn/src/fallback/matrix_mul/algos.cpp index d3b78292..708e7fe5 100644 --- a/dnn/src/fallback/matrix_mul/algos.cpp +++ b/dnn/src/fallback/matrix_mul/algos.cpp @@ -41,6 +41,8 @@ void f32_8x12x1_kern(const MatrixMulImpl::KernParam& kern_param) { } } // anonymous namespace +////////////////////// AlgoF32K8x12x1 /////////////////////////// + bool MatrixMulImpl::AlgoF32K8x12x1::usable( const KernSizeParam& kern_size_param) const { return kern_size_param.compute_mode == diff --git a/dnn/src/fallback/matrix_mul/algos.h b/dnn/src/fallback/matrix_mul/algos.h index 8abb64bc..f0cd51be 100644 --- a/dnn/src/fallback/matrix_mul/algos.h +++ b/dnn/src/fallback/matrix_mul/algos.h @@ -11,8 +11,10 @@ #pragma once +#include #include "src/fallback/matrix_mul/opr_impl.h" #include "src/fallback/matrix_mul/gemm_common.h" +#include "src/common/algo_base.h" namespace megdnn { namespace fallback { @@ -24,6 +26,7 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; + MEGDNN_DECL_ALGO_TYPE(FB_F32K8x12x1) MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; @@ -37,6 +40,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } + MEGDNN_DECL_ALGO_TYPE(FB_GEMV) MEGDNN_OVERRIDE_MATMUL_DESC( 8, 16, 1, 4, static_cast( diff --git a/dnn/src/fallback/matrix_mul/opr_impl.cpp b/dnn/src/fallback/matrix_mul/opr_impl.cpp index 6ba8c0b6..c82c65ba 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.cpp +++ b/dnn/src/fallback/matrix_mul/opr_impl.cpp @@ -9,7 +9,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ #include "src/fallback/matrix_mul/opr_impl.h" +#include +#include "megdnn/oprs/base.h" #include "src/common/metahelper.h" #include "src/common/utils.h" #include "src/fallback/matrix_mul/algos.h" @@ -19,23 +21,43 @@ #include "src/naive/matrix_mul/opr_impl.h" #include "src/common/algo_chooser.h" +#if MEGDNN_X86 +#include "src/x86/matrix_mul/opr_impl.h" +#elif MEGDNN_AARCH64 +#include "src/aarch64/matrix_mul/opr_impl.h" +#elif MEGDNN_ARMV7 +#include "src/armv7/matrix_mul/opr_impl.h" +#endif + using namespace megdnn; using namespace fallback; class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoF32K8x12x1 f32_k8x12x1; -public: AlgoGemv gemv; + SmallVector m_all_algos; + AlgoBase::Mapper m_all_algos_map; + +public: AlgoPack() { - all_algos.emplace_back(&gemv); - all_algos.emplace_back(&f32_k8x12x1); + m_all_algos.emplace_back(&gemv); + m_all_algos.emplace_back(&f32_k8x12x1); + for (auto&& algo : m_all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } - SmallVector all_algos; + + const SmallVector& all_algos() const { return m_all_algos; } + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -SmallVector MatrixMulImpl::algo_pack() { - static AlgoPack s_algo_pack; - return s_algo_pack.all_algos; +const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { + static AlgoPack algo_pack; + return algo_pack; +} + +SmallVector MatrixMulImpl::get_all_packed_algo() { + return algo_pack().all_algos(); } SmallVector MatrixMulImpl::select_algo_type( @@ -43,7 +65,7 @@ SmallVector MatrixMulImpl::select_algo_type( megdnn_assert(nr_type_contain(index.data_type), "Matmul algo selection only support one type"); SmallVector algos; - for (auto&& algo : algo_pack()) { + for (auto&& algo : get_all_packed_algo()){ auto algo_desc = algo->matmul_description(); if (contain_data_type(algo_desc.algo_type.data_type, index.data_type) && @@ -58,7 +80,7 @@ std::vector MatrixMulImpl::get_all_algorithms( const TensorLayout& A, const TensorLayout& B, const 
TensorLayout& C) { std::vector gemm_algos, gemv_algos; auto kern_size_param = make_kern_size_param(A, B, C); - for (auto&& algo : algo_pack()) { + for (auto&& algo : get_all_packed_algo()) { if (algo->usable(kern_size_param)) { if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { // simple gemv @@ -72,15 +94,48 @@ std::vector MatrixMulImpl::get_all_algorithms( return gemv_algos; } +MatrixMulImpl::AlgoBase* MatrixMulImpl::get_algo_from_desc( + const AlgorithmDesc& desc) { + if (!desc.valid()) { + return nullptr; + } else { + switch (desc.handle_type) { + case Handle::HandleType::FALLBACK: { + const auto& map = algo_pack().all_algos_map(); + megdnn_assert(map.find(desc) != map.end()); + return map.at(desc); + }; + +#if MEGDNN_X86 + case Handle::HandleType::X86: + return x86::MatrixMulImpl::get_algo_from_desc(desc); +#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 + case Handle::HandleType::ARM_COMMON: + return arm_common::MatrixMulImpl::get_algo_from_desc(desc); +#if MEGDNN_AARCH64 + case Handle::HandleType::AARCH64: + return aarch64::MatrixMulImpl::get_algo_from_desc(desc); +#else + case Handle::HandleType::ARMV7: + return armv7::MatrixMulImpl::get_algo_from_desc(desc); +#endif +#endif + default: + megdnn_throw("Unknown handle type"); + return {}; + } + } +} + MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic( const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, size_t workspace_limit_in_bytes, bool reproducible) { auto kern_size_param = make_kern_size_param(A, B, C); - if (auto algo = execution_policy().algorithm) { - megdnn_assert(static_cast(algo)->get_workspace( - kern_size_param) < workspace_limit_in_bytes); - auto cur = megdnn::get_reproducible_algo( - static_cast(algo), reproducible); + if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { + megdnn_assert(algo->get_workspace(kern_size_param) < + workspace_limit_in_bytes); + auto cur = megdnn::get_reproducible_algo(algo, + reproducible); if (cur) return cur; megdnn_throw( 
diff --git a/dnn/src/fallback/matrix_mul/opr_impl.h b/dnn/src/fallback/matrix_mul/opr_impl.h index 65e87a3a..f740b435 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.h +++ b/dnn/src/fallback/matrix_mul/opr_impl.h @@ -10,10 +10,12 @@ * implied. */ #pragma once +#include #include "megdnn/opr_param_defs.h" +#include "megdnn/oprs/base.h" +#include "src/common/algo_base.h" #include "src/common/utils.h" #include "src/naive/matrix_mul/opr_impl.h" -#include namespace megdnn { @@ -104,6 +106,77 @@ public: public: AlgoBase() { m_handle_type = Handle::HandleType::FALLBACK; } + enum class AlgoType : uint32_t { + //! fallback + FB_F32K8x12x1 = 1 << 0, + FB_GEMV, + +#if MEGDNN_X86 + //! x86 + X86_F32_BLAS = 1 << 8, + X86_F32_MKL_PACKA, + X86_INT8X8X32_AVX2_2X4X16, + X86_INT8X8X32_AVX2_4X16X2, + X86_INT8X8X16_AVX2, + X86_INT8X8X16_SSE, + X86_INT8X8X32_SSE_4X8X2, + X86_F32_MK8_8X8, + X86_INT8X8X32_VNNI, + X86_INT8X8X32_MKLDNN, +#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 + ARM_COMMON_INT8X8X16 = 1 << 8, + ARM_COMMON_INT8X8X32_GEMV, + ARM_COMMON_INT8X8X32_GEMV_MK4, + ARM_COMMON_INT8X8X32_GEMV_MK4_DOT, + ARM_COMMON_F32_GEMV_MK4, + ARM_COMMON_F16_GEMV, + ARM_COMMON_GEVM, +#if MEGDNN_AARCH64 + AARCH64_F32_K8X12X1 = 1 << 16, + AARCH64_F32_MK4_K8X12X1, + AARCH64_F32_K4X16X1, + AARCH64_F32_MK4_4x16, + AARCH64_F32_GEMV, + AARCH64_F16_K8X24X1, + AARCH64_F16_MK8_8X8, + AARCH64_INT8X8X32_K8X12X4_DOTPROD, + AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD, + AARCH64_INT8X8X32_MK4_4X4X16, + AARCH64_INT8X8X32_K4X4X16, + AARCH64_INT8X8X32_K8X8X8, + AARCH64_INT8X8X16_K8X8X8, + AARCH64_INT8X8X16_K4X4X16, + AARCH64_INT8X8X16_MK4_16X12X4, + AARCH64_INT8X8X16_MK4_K8X8X8, + AARCH64_INT8X8X16_MK4_4X4X8, + AARCH64_INT16X16X32_K12X8X1, + AARCH64_INT16X16X32_MK8_8X8, + AARCH64_QUINT8_K8X8X4_DOTPROD, + AARCH64_QUINT8_GEMV_DOTPROD, + AARCH64_QUINT8_K8X8X8, +#else + ARMV7_F32 = 1 << 16, + ARMV7_F32_MK4_PACK_4X12, + ARMV7_F32_MK4_4x8, + ARMV7_F16_K4X16X1, + ARMV7_F16_MK8_4X8, + ARMV7_INT8_K6X8X4, + ARMV7_QUINT8_K4X8X4, + 
ARMV7_INT8_MK4_8X4X4_DOTPROD, + ARMV7_F32_GEMV, + ARMV7_INT8X8X32_K4X2X16, + ARMV7_INT8X8X32_K4X8X8, + ARMV7_QUINT8_K4X8X8, + ARMV7_INT8X8X16_K4X2X16, + ARMV7_INT8X8X16_K4X8X8, + ARMV7_INT8X8X16_MK4_K8X8X4, + ARMV7_INT16X16X32_K12X4X1, + ARMV7_INT16X16X32_MK8_4X8, + ARMV7_INT8X8X32_MK4_4X2X16 +#endif +#endif + }; + enum class AlgoSet : uint32_t { ALGO_TYPE_GEMM = 0, ALGO_TYPE_GEMV = 1, @@ -152,12 +225,23 @@ public: return (!reproducible || is_reproducible()) && preferred(param); }; virtual MatmulDescription matmul_description() const = 0; + + using Mapper = std::unordered_map; }; +private: + class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 + class AlgoGemv; + class AlgoPack; + //! maintain all the algos in the opr of fallback + static const AlgoPack& algo_pack(); + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); +public: + /** * \brief get all the algorithm for the opr. */ - virtual SmallVector algo_pack(); + virtual SmallVector get_all_packed_algo(); /** * \brief select algo according to input algo type @@ -183,10 +267,6 @@ protected: size_t workspace_limit_in_bytes, bool reproducible) override; -private: - class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 - class AlgoGemv; - class AlgoPack; }; } // namespace fallback diff --git a/dnn/src/naive/convolution/algorithms.h b/dnn/src/naive/convolution/algorithms.h index 3ae94239..6ed1ba52 100644 --- a/dnn/src/naive/convolution/algorithms.h +++ b/dnn/src/naive/convolution/algorithms.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once #include "megdnn/oprs.h" @@ -14,43 +15,38 @@ namespace megdnn { namespace naive { -class DefaultConvolutionForwardAlgorithm final: - public megdnn::ConvolutionForward::Algorithm { - bool is_reproducible() const override - { return true; } - const char* name() const override - { return "DEFAULT"; } +class DefaultConvolutionForwardAlgorithm final + : public megdnn::ConvolutionForward::Algorithm { + bool is_reproducible() const override { return true; } + const char* name() const override { return "DEFAULT"; } + uint32_t type() const override { return 0; } }; -class DefaultConvolutionBackwardDataAlgorithm final: - public megdnn::ConvolutionBackwardData::Algorithm { - bool is_reproducible() const override - { return true; } - const char* name() const override - { return "DEFAULT"; } +class DefaultConvolutionBackwardDataAlgorithm final + : public megdnn::ConvolutionBackwardData::Algorithm { + bool is_reproducible() const override { return true; } + const char* name() const override { return "DEFAULT"; } + uint32_t type() const override { return 0; } }; -class DefaultConvolutionBackwardFilterAlgorithm final: - public megdnn::ConvolutionBackwardFilter::Algorithm { - bool is_reproducible() const override - { return true; } - const char* name() const override - { return "DEFAULT"; } +class DefaultConvolutionBackwardFilterAlgorithm final + : public megdnn::ConvolutionBackwardFilter::Algorithm { + bool is_reproducible() const override { return true; } + const char* name() const override { return "DEFAULT"; } + uint32_t type() const override { return 0; } }; -class DefaultConvBiasForwardAlgorithm final: - public megdnn::ConvBiasForward::Algorithm { - bool is_reproducible() const override - { return true; } - const char* name() const override - { return "DEFAULT"; } +class DefaultConvBiasForwardAlgorithm final + : public megdnn::ConvBiasForward::Algorithm { + bool is_reproducible() const override { return true; } + const char* name() const override { return 
"DEFAULT"; } + uint32_t type() const override { return 0; } }; class DefaultBatchConvBiasForwardAlgorithm final : public megdnn::BatchConvBiasForward::Algorithm { - bool is_reproducible() const override - { return true; } - const char* name() const override - { return "DEFAULT"; } + bool is_reproducible() const override { return true; } + const char* name() const override { return "DEFAULT"; } + uint32_t type() const override { return 0; } }; -} // namespace naive -} // namespace megdnn +} // namespace naive +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/convolution3d/algorithms.h b/dnn/src/naive/convolution3d/algorithms.h index a5c5218a..4d7652cb 100644 --- a/dnn/src/naive/convolution3d/algorithms.h +++ b/dnn/src/naive/convolution3d/algorithms.h @@ -18,16 +18,19 @@ class DefaultConvolution3DForwardAlgorithm final : public megdnn::Convolution3DForward::Algorithm { bool is_reproducible() const override { return true; } const char* name() const override { return "DEFAULT"; } + uint32_t type() const override { return 0; } }; class DefaultConvolution3DBackwardDataAlgorithm final : public megdnn::Convolution3DBackwardData::Algorithm { bool is_reproducible() const override { return true; } const char* name() const override { return "DEFAULT"; } + uint32_t type() const override { return 0; } }; class DefaultConvolution3DBackwardFilterAlgorithm final : public megdnn::Convolution3DBackwardFilter::Algorithm { bool is_reproducible() const override { return true; } const char* name() const override { return "DEFAULT"; } + uint32_t type() const override { return 0; } }; } // namespace naive diff --git a/dnn/src/naive/local_share/algorithms.h b/dnn/src/naive/local_share/algorithms.h index 5b5c8865..fb52c21d 100644 --- a/dnn/src/naive/local_share/algorithms.h +++ b/dnn/src/naive/local_share/algorithms.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS 
IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include "megdnn/oprs.h" @@ -14,28 +15,25 @@ namespace megdnn { namespace naive { -class DefaultLocalShareForwardAlgorithm final: - public megdnn::LocalShareForward::Algorithm { - bool is_reproducible() const override - { return true; } - const char* name() const override - { return "DEFAULT"; } +class DefaultLocalShareForwardAlgorithm final + : public megdnn::LocalShareForward::Algorithm { + bool is_reproducible() const override { return true; } + const char* name() const override { return "DEFAULT"; } + uint32_t type() const override { return 0; } }; -class DefaultLocalShareBackwardDataAlgorithm final: - public megdnn::LocalShareBackwardData::Algorithm { - bool is_reproducible() const override - { return true; } - const char* name() const override - { return "DEFAULT"; } +class DefaultLocalShareBackwardDataAlgorithm final + : public megdnn::LocalShareBackwardData::Algorithm { + bool is_reproducible() const override { return true; } + const char* name() const override { return "DEFAULT"; } + uint32_t type() const override { return 0; } }; -class DefaultLocalShareBackwardFilterAlgorithm final: - public megdnn::LocalShareBackwardFilter::Algorithm { - bool is_reproducible() const override - { return true; } - const char* name() const override - { return "DEFAULT"; } +class DefaultLocalShareBackwardFilterAlgorithm final + : public megdnn::LocalShareBackwardFilter::Algorithm { + bool is_reproducible() const override { return true; } + const char* name() const override { return "DEFAULT"; } + uint32_t type() const override { return 0; } }; -} // namespace naive -} // namespace megdnn +} // namespace naive +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/convolution/backward_data/algo.cpp b/dnn/src/rocm/convolution/backward_data/algo.cpp index 
8a14527e..30df20d4 100644 --- a/dnn/src/rocm/convolution/backward_data/algo.cpp +++ b/dnn/src/rocm/convolution/backward_data/algo.cpp @@ -23,8 +23,13 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() { non_miopen_algos.push_back(&matmul); non_miopen_algos.push_back(&chanwise); miopen_algos.push_back(&miopen); + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } +MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl) ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack; ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs( diff --git a/dnn/src/rocm/convolution/backward_data/algo.h b/dnn/src/rocm/convolution/backward_data/algo.h index e67c4e99..81b21843 100644 --- a/dnn/src/rocm/convolution/backward_data/algo.h +++ b/dnn/src/rocm/convolution/backward_data/algo.h @@ -12,6 +12,11 @@ #pragma once #include "src/rocm/convolution/helper.h" +#include "src/common/utils.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" + +#include namespace megdnn { namespace rocm { @@ -25,6 +30,13 @@ protected: ~AlgoBase() = default; public: + enum class AlgoType : uint32_t { + ROCM_MIOPEN, + ROCM_MATMUL, + ROCM_CHANWISE + }; + using Mapper = std::unordered_map; + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } struct SizeArgs { HandleImpl* handle; @@ -103,6 +115,13 @@ public: } bool is_miopen() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN) + std::string param() const override { + std::string ret; + serialize_write_pod(m_is_reproducible, ret); + return ret; + } + static convolution::MIOpenCache sm_miopen_algo_cache; static convolution::MIOpenCache sm_miopen_ws_cache; @@ -119,6 +138,7 @@ public: const char* name() const override { return "MATMUL"; } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(ROCM_MATMUL) }; class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase { @@ -129,15 +149,14 @@ public: 
const char* name() const override { return "CHANNEL_WISE"; } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(ROCM_CHANWISE) }; -class ConvolutionBackwardDataImpl::AlgoPack { +class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { // defined in miopen.cpp void fill_miopen_algos(); - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator=(const AlgoPack&) = delete; - + AlgoBase::Mapper m_all_algos_map; public: AlgoPack(); @@ -148,6 +167,7 @@ public: std::vector //! all algorithms all_algos, miopen_algos, non_miopen_algos; + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; } // namespace rocm diff --git a/dnn/src/rocm/convolution/backward_filter/algo.cpp b/dnn/src/rocm/convolution/backward_filter/algo.cpp index 8b01d13d..35b31ca7 100644 --- a/dnn/src/rocm/convolution/backward_filter/algo.cpp +++ b/dnn/src/rocm/convolution/backward_filter/algo.cpp @@ -24,8 +24,13 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() { non_miopen_algos.push_back(&chanwise); non_miopen_algos.push_back(all_algos.back()); miopen_algos.push_back(&miopen); + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } +MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardFilterImpl) ConvolutionBackwardFilterImpl::AlgoPack ConvolutionBackwardFilterImpl::sm_algo_pack; diff --git a/dnn/src/rocm/convolution/backward_filter/algo.h b/dnn/src/rocm/convolution/backward_filter/algo.h index 30074e2a..96f46dce 100644 --- a/dnn/src/rocm/convolution/backward_filter/algo.h +++ b/dnn/src/rocm/convolution/backward_filter/algo.h @@ -13,6 +13,8 @@ #include #include "src/rocm/convolution/helper.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" namespace megdnn { namespace rocm { @@ -26,6 +28,12 @@ protected: ~AlgoBase() = default; public: + enum class AlgoType : uint32_t { + ROCM_MIOPEN, + ROCM_MATMUL, + ROCM_CHANWISE + }; + using Mapper = std::unordered_map; AlgoBase() : Algorithm() { 
m_handle_type = Handle::HandleType::ROCM; } struct SizeArgs { HandleImpl* handle; @@ -103,6 +111,13 @@ public: } bool is_miopen() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN) + std::string param() const override { + std::string ret; + serialize_write_pod(m_is_reproducible, ret); + return ret; + } + static convolution::MIOpenCache sm_miopen_algo_cache; static convolution::MIOpenCache sm_miopen_ws_cache; @@ -119,6 +134,7 @@ public: const char* name() const override { return "MATMUL"; } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(ROCM_MATMUL) }; class ConvolutionBackwardFilterImpl::AlgoChanwise final : public AlgoBase { @@ -129,14 +145,13 @@ public: const char* name() const override { return "CHANNEL_WISE"; } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(ROCM_CHANWISE) }; -class ConvolutionBackwardFilterImpl::AlgoPack { +class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj { void fill_miopen_algos(); - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator=(const AlgoPack&) = delete; - + AlgoBase::Mapper m_all_algos_map; public: AlgoPack(); @@ -147,6 +162,7 @@ public: std::vector //! 
all algorithms all_algos, miopen_algos, non_miopen_algos; + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; } // namespace rocm diff --git a/dnn/src/rocm/convolution/forward/algo.cpp b/dnn/src/rocm/convolution/forward/algo.cpp index df4db044..b1a5382f 100644 --- a/dnn/src/rocm/convolution/forward/algo.cpp +++ b/dnn/src/rocm/convolution/forward/algo.cpp @@ -30,8 +30,14 @@ ConvolutionForwardImpl::AlgoPack::AlgoPack() { all_algos.push_back(&batched_matrix_mul); all_algos.push_back(&chanwise); all_algos.push_back(&miopen); + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } +MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionForwardImpl) + ConvolutionForwardImpl::AlgoPack ConvolutionForwardImpl::sm_algo_pack; ConvolutionForwardImpl::AlgoBase::SizeArgs::SizeArgs(ConvolutionForwardImpl* o, diff --git a/dnn/src/rocm/convolution/forward/algo.h b/dnn/src/rocm/convolution/forward/algo.h index b5906ba6..baf95579 100644 --- a/dnn/src/rocm/convolution/forward/algo.h +++ b/dnn/src/rocm/convolution/forward/algo.h @@ -6,13 +6,16 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once #include "megdnn/oprs.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" #include "src/common/utils.h" #include "src/rocm/convolution/helper.h" #include "src/rocm/convolution/opr_impl.h" @@ -32,6 +35,16 @@ protected: ~AlgoBase() = default; public: + enum class AlgoType : uint32_t { + ROCM_MIOPEN, + ROCM_MATMUL, + ROCM_INPLACE_MATMUL, + ROCM_1X1, + ROCM_1X1_LARGE_BATCH, + ROCM_CHANWISE + }; + using Mapper = std::unordered_map; + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } struct SizeArgs : public convolution::ForwardSizeArgs { ConvolutionForwardImpl* opr; @@ -99,6 +112,12 @@ public: const char* name() const override { return "MIOpenConvolutionForward"; } bool is_miopen() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN) + std::string param() const override { + std::string ret; + serialize_write_pod(m_is_reproducible, ret); + return ret; + } static convolution::MIOpenCache sm_miopen_algo_cache; @@ -116,6 +135,7 @@ public: const char* name() const override { return "MATMUL"; } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(ROCM_MATMUL) }; //! compute small matmul in the kernel @@ -127,6 +147,7 @@ public: const char* name() const override { return "INPLACE_MATMUL"; } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(ROCM_INPLACE_MATMUL) }; //! optimized 1x1 conv @@ -141,6 +162,7 @@ public: const char* name() const override { return "1x1"; } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(ROCM_1X1) }; //! 
optimized 1x1 conv when input data batchsize is larger than 32 @@ -155,6 +177,7 @@ public: const char* name() const override { return "LARGE_BATCH_1x1"; } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(ROCM_1X1_LARGE_BATCH) }; class ConvolutionForwardImpl::AlgoChanwise final : public AlgoBase { @@ -165,15 +188,14 @@ public: const char* name() const override { return "CHANNEL_WISE"; } bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(ROCM_CHANWISE) }; -class ConvolutionForwardImpl::AlgoPack { +class ConvolutionForwardImpl::AlgoPack : NonCopyableObj { // defined in miopen.cpp void fill_miopen_algos(); - AlgoPack(const AlgoPack&) = delete; - AlgoPack& operator=(const AlgoPack&) = delete; - + AlgoBase::Mapper m_all_algos_map; public: AlgoPack(); @@ -187,9 +209,11 @@ public: std::vector //! all algorithms all_algos, miopen_algos, non_miopen_algos; + + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -} // namespace rocm -} // namespace megdnn +} // namespace rocm +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/convolution/opr_impl.h b/dnn/src/rocm/convolution/opr_impl.h index a19fbc89..abb8c08a 100644 --- a/dnn/src/rocm/convolution/opr_impl.h +++ b/dnn/src/rocm/convolution/opr_impl.h @@ -23,19 +23,14 @@ public: _megdnn_tensor_out dst, const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) override; - std::vector get_all_algorithms( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst, - size_t workspace_limit_in_bytes, - bool reproducible) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& src, - const CanonizedFilterMeta& filter, - const TensorLayout& dst, - size_t workspace_limit_in_bytes, - bool reproducible); + AlgorithmInfo get_algorithm_info_heuristic( + 
const TensorLayout& src, const CanonizedFilterMeta& filter, + const TensorLayout& dst, size_t workspace_limit_in_bytes, + bool reproducible) { + return get_algorithm_heuristic(src, filter, dst, + workspace_limit_in_bytes, reproducible) + ->info(); + } size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, @@ -71,8 +66,23 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); private: + std::vector get_all_algorithms( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& dst, + size_t workspace_limit_in_bytes, + bool reproducible) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const CanonizedFilterMeta& filter, + const TensorLayout& dst, + size_t workspace_limit_in_bytes, + bool reproducible); + static AlgoPack sm_algo_pack; }; @@ -81,6 +91,30 @@ public: using ConvolutionBackwardData::ConvolutionBackwardData; void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + AlgorithmInfo get_algorithm_info_heuristic( + const CanonizedFilterMeta& filter, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_limit_in_bytes, + bool reproducible) { + return get_algorithm_heuristic(filter, diff, grad, + workspace_limit_in_bytes, reproducible) + ->info(); + } + size_t get_workspace_in_bytes(const TensorLayout& filter, + const TensorLayout& diff, + const TensorLayout& grad) override; + const char* get_algorithm_set_name() const override; + + class AlgoBase; + class AlgoMIOpen; + class AlgoMatmul; + class AlgoChanwise; + + class AlgoPack; + + static const AlgoPack& algo_pack() { return sm_algo_pack; } + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + 
+private: std::vector get_all_algorithms( const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad) override; @@ -94,7 +128,25 @@ public: const TensorLayout& grad, size_t workspace_limit_in_bytes, bool reproducible); - size_t get_workspace_in_bytes(const TensorLayout& filter, + + static AlgoPack sm_algo_pack; +}; + +class ConvolutionBackwardFilterImpl : public ConvolutionBackwardFilter { +public: + using ConvolutionBackwardFilter::ConvolutionBackwardFilter; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src, + const TensorLayout& diff, + const CanonizedFilterMeta& grad, + size_t workspace_limit_in_bytes, + bool reproducible) { + return get_algorithm_heuristic(src, diff, grad, + workspace_limit_in_bytes, reproducible) + ->info(); + } + size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) override; const char* get_algorithm_set_name() const override; @@ -106,17 +158,10 @@ public: class AlgoPack; + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); static const AlgoPack& algo_pack() { return sm_algo_pack; } private: - static AlgoPack sm_algo_pack; -}; - -class ConvolutionBackwardFilterImpl : public ConvolutionBackwardFilter { -public: - using ConvolutionBackwardFilter::ConvolutionBackwardFilter; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) override; std::vector get_all_algorithms( const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) override; @@ -130,25 +175,11 @@ public: const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, bool reproducible); - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) override; - const char* get_algorithm_set_name() const override; - class AlgoBase; 
- class AlgoMIOpen; - class AlgoMatmul; - class AlgoChanwise; - - class AlgoPack; - - static const AlgoPack& algo_pack() { return sm_algo_pack; } - -private: static AlgoPack sm_algo_pack; }; -} // namespace rocm -} // namespace megdnn +} // namespace rocm +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/x86/conv_bias/f32/algos.h b/dnn/src/x86/conv_bias/f32/algos.h index 94a4b141..c15e56c5 100644 --- a/dnn/src/x86/conv_bias/f32/algos.h +++ b/dnn/src/x86/conv_bias/f32/algos.h @@ -50,6 +50,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(X86_DIRECT) }; /* ===================== direct-stride2 algo ===================== */ @@ -85,6 +86,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(X86_DIRECT_STRD2) }; /* =========================== winograd ======================== */ class ConvBiasImpl::AlgoFP32WinogradF63_8x8 final : public AlgoBase { @@ -100,6 +102,7 @@ public: return m_name.c_str(); } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); + MEGDNN_DECL_ALGO_TYPE(X86_WINOGRAD_F63_8x8_F32) }; class ConvBiasImpl::AlgoFP32WinogradF23_8x8 final : public AlgoBase { @@ -115,6 +118,7 @@ public: return m_name.c_str(); } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); + MEGDNN_DECL_ALGO_TYPE(X86_WINOGRAD_F23_8x8_F32) }; #if MEGDNN_X86_WITH_MKL_DNN @@ -159,6 +163,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(X86_MKLDNN) }; #endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/x86/conv_bias/int8/algos.h b/dnn/src/x86/conv_bias/int8/algos.h index 65d015ef..4830d36b 100644 --- a/dnn/src/x86/conv_bias/int8/algos.h +++ b/dnn/src/x86/conv_bias/int8/algos.h @@ -37,6 +37,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, 
AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(X86_CHANWISE_AVX2_STRD1_QINT8) }; /* ===================== avx2 stride2 chanwise algo ===================== */ @@ -61,6 +62,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(X86_CHANWISE_AVX2_STRD2_QINT8) }; /* ===================== avx2 stride1 direct algo ===================== */ @@ -85,6 +87,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(X86_DIRECT_AVX2_STRD1_INT8) }; /* ================== avx2 int8 direct conv stride2 algo ================== */ @@ -109,6 +112,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(X86_DIRECT_AVX2_STRD2_INT8) }; #if MEGDNN_X86_WITH_MKL_DNN @@ -149,6 +153,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; } + MEGDNN_DECL_ALGO_TYPE(X86_MKLDNN_QINT8) }; /* ===================== mkldnn qint8 matmul algo ===================== */ class ConvBiasImpl::AlgoMkldnnMatmulQint8 final : public AlgoBase { @@ -177,6 +182,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; } + MEGDNN_DECL_ALGO_TYPE(X86_MKLDNN_MATMUL_QINT8) }; #endif diff --git a/dnn/src/x86/conv_bias/opr_impl.cpp b/dnn/src/x86/conv_bias/opr_impl.cpp index 936a9738..36fa51f1 100644 --- a/dnn/src/x86/conv_bias/opr_impl.cpp +++ b/dnn/src/x86/conv_bias/opr_impl.cpp @@ -45,6 +45,9 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoMkldnnConv mkldnn_conv_fp32; #endif SmallVector> refhold; + SmallVector m_all_no_winograd_algo; + SmallVector m_winograd_algos; + fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map; public: AlgoPack() { @@ -52,21 +55,21 @@ public: //! 
But now mkldnn algo preference issue with NCHW->NHWC->NCHW #if MEGDNN_X86_WITH_MKL_DNN //! Create the mkldnn algo - all_algos.emplace_back(&mkldnn_conv_fp32); - all_algos.emplace_back(&mkldnn_matmul_qint8); - all_algos.emplace_back(&mkldnn_qint8); + m_all_no_winograd_algo.emplace_back(&mkldnn_conv_fp32); + m_all_no_winograd_algo.emplace_back(&mkldnn_matmul_qint8); + m_all_no_winograd_algo.emplace_back(&mkldnn_qint8); #endif - all_algos.emplace_back(&stride1_direct); - all_algos.emplace_back(&stride2_direct); - all_algos.emplace_back(&avx2_stride1_chanwsie_qint8); - all_algos.emplace_back(&avx2_stride2_chanwsie_qint8); - all_algos.emplace_back(&avx2_stride1_direct_int8); - all_algos.emplace_back(&avx2_stride2_direct); + m_all_no_winograd_algo.emplace_back(&stride1_direct); + m_all_no_winograd_algo.emplace_back(&stride2_direct); + m_all_no_winograd_algo.emplace_back(&avx2_stride1_chanwsie_qint8); + m_all_no_winograd_algo.emplace_back(&avx2_stride2_chanwsie_qint8); + m_all_no_winograd_algo.emplace_back(&avx2_stride1_direct_int8); + m_all_no_winograd_algo.emplace_back(&avx2_stride2_direct); static CpuOprDelegationStorage<> storage; auto matmul_opr = storage.get(); auto&& matmul_algos = - static_cast(matmul_opr)->algo_pack(); + static_cast(matmul_opr)->get_all_packed_algo(); for (auto&& algo : matmul_algos) { if (is_fallback_or_naive(algo)) continue; @@ -74,25 +77,52 @@ public: refhold.emplace_back(new AlgoFP32WinogradF63_8x8( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); refhold.emplace_back(new AlgoFP32WinogradF23_8x8( static_cast(algo), tile_size)); - winograd_algos.emplace_back(refhold.back().get()); + m_winograd_algos.emplace_back(refhold.back().get()); } } + + for (auto&& algo : m_all_no_winograd_algo) { + m_all_algos_map.emplace(algo->info().desc, algo); + } + for (auto&& algo : m_winograd_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } + } + const 
SmallVector& all_no_winograd_algo() + const { + return m_all_no_winograd_algo; + } + const SmallVector& winograd_algos() + const { + return m_winograd_algos; } - SmallVector all_algos; - SmallVector winograd_algos; + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -SmallVector ConvBiasImpl::algo_pack() { - static AlgoPack sl_algo_pack; - auto&& algos = fallback::ConvBiasImpl::algo_pack(); - algos.insert(algos.begin(), sl_algo_pack.all_algos.begin(), - sl_algo_pack.all_algos.end()); - algos.insert(algos.end(), sl_algo_pack.winograd_algos.begin(), - sl_algo_pack.winograd_algos.end()); +const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { + static AlgoPack algo_pack; + return algo_pack; +} + +fallback::ConvBiasImpl::AlgoBase* ConvBiasImpl::get_algo_from_desc( + const AlgorithmDesc& desc) { + megdnn_assert(algo_pack().all_algos_map().find(desc) != + algo_pack().all_algos_map().end()); + return algo_pack().all_algos_map().at(desc); +} + +SmallVector +ConvBiasImpl::get_all_packed_algo() { + auto&& algos = fallback::ConvBiasImpl::get_all_packed_algo(); + algos.insert(algos.begin(), algo_pack().all_no_winograd_algo().begin(), + algo_pack().all_no_winograd_algo().end()); + algos.insert(algos.end(), algo_pack().winograd_algos().begin(), + algo_pack().winograd_algos().end()); + return std::move(algos); } diff --git a/dnn/src/x86/conv_bias/opr_impl.h b/dnn/src/x86/conv_bias/opr_impl.h index 49f9c731..6a977919 100644 --- a/dnn/src/x86/conv_bias/opr_impl.h +++ b/dnn/src/x86/conv_bias/opr_impl.h @@ -28,25 +28,10 @@ public: }; bool is_thread_safe() const override { return true; } - SmallVector algo_pack() override; + SmallVector get_all_packed_algo() override; SmallVector suggest_algo_category_order( const NCBKernSizeParam& param) const override; - class AlgoDirect; - class AlgoDirectStride2; - class AlgoFP32WinogradF63_8x8; - class AlgoFP32WinogradF23_8x8; - class AlgoDirectAvx2Stride1Int8; - class AlgoAVX2DirectConvStride2; - class 
AlgoChanWiseAvx2Stride1Qint8; - class AlgoChanWiseAvx2Stride2Qint8; -#if MEGDNN_X86_WITH_MKL_DNN - class AlgoMkldnnConv; - class AlgoMkldnnQint8; - class AlgoMkldnnMatmulQint8; -#endif - class AlgoPack; - /** * \brief Adjust tensor layouts to fulfill alignment requirements. * OW2 would be 8-byte aligned. @@ -62,6 +47,26 @@ public: bool is_matmul_quantized_prefer( const ConvBiasImpl::NCBKernSizeParam& ncb_param) const override; + static fallback::ConvBiasImpl::AlgoBase* get_algo_from_desc( + const AlgorithmDesc& desc); + +private: + class AlgoDirect; + class AlgoDirectStride2; + class AlgoFP32WinogradF63_8x8; + class AlgoFP32WinogradF23_8x8; + class AlgoDirectAvx2Stride1Int8; + class AlgoAVX2DirectConvStride2; + class AlgoChanWiseAvx2Stride1Qint8; + class AlgoChanWiseAvx2Stride2Qint8; +#if MEGDNN_X86_WITH_MKL_DNN + class AlgoMkldnnConv; + class AlgoMkldnnQint8; + class AlgoMkldnnMatmulQint8; +#endif + class AlgoPack; + + static const AlgoPack& algo_pack(); }; } // namespace x86 diff --git a/dnn/src/x86/matrix_mul/algos.h b/dnn/src/x86/matrix_mul/algos.h index 79b26004..b17c197d 100644 --- a/dnn/src/x86/matrix_mul/algos.h +++ b/dnn/src/x86/matrix_mul/algos.h @@ -27,6 +27,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) + MEGDNN_DECL_ALGO_TYPE(X86_F32_BLAS) }; #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM @@ -48,6 +49,7 @@ public: InnerBlockSize get_inner_block_size() const override{ return {8, 16, 1}; }; MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) + MEGDNN_DECL_ALGO_TYPE(X86_F32_MKL_PACKA) }; #endif @@ -59,6 +61,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(X86_INT8X8X32_AVX2_2X4X16) }; class MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2 : public 
AlgoBase { @@ -69,6 +72,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(X86_INT8X8X32_AVX2_4X16X2) }; class MatrixMulImpl::AlgoInt8x8x16AVX2 : public AlgoBase { @@ -84,6 +88,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(X86_INT8X8X16_AVX2) }; class MatrixMulImpl::AlgoInt8x8x16SSE : public AlgoBase { @@ -99,6 +104,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(X86_INT8X8X16_SSE) }; class MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2 : public AlgoBase { @@ -109,6 +115,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(X86_INT8X8X32_SSE_4X8X2) }; class MatrixMulImpl::AlgoF32MK8_8x8 : public AlgoBase { @@ -120,6 +127,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4, AlgoDataType::FLOAT32, MK8) + MEGDNN_DECL_ALGO_TYPE(X86_F32_MK8_8X8) }; #if MEGDNN_X86_WITH_VNNI @@ -131,6 +139,7 @@ public: size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(X86_INT8X8X32_VNNI) }; #endif @@ -144,6 +153,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; PackMode packmode() const override { return PackMode::NO_PACK; } MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) + MEGDNN_DECL_ALGO_TYPE(X86_INT8X8X32_MKLDNN) }; #endif } // namespace x86 diff --git a/dnn/src/x86/matrix_mul/opr_impl.cpp 
b/dnn/src/x86/matrix_mul/opr_impl.cpp index 8d6f1ee6..fb19c1ff 100644 --- a/dnn/src/x86/matrix_mul/opr_impl.cpp +++ b/dnn/src/x86/matrix_mul/opr_impl.cpp @@ -35,35 +35,58 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoInt8x8x16SSE algoint8x8x16sse_m4n8k2; AlgoF32MK8_8x8 algof32mk8_8x8; + SmallVector m_all_algos; + fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; + public: AlgoPack() { if (is_supported(SIMDType::VNNI)) { #if MEGDNN_X86_WITH_VNNI - all_algos.emplace_back(&algoint8x8x32vnni); + m_all_algos.emplace_back(&algoint8x8x32vnni); #endif } - all_algos.emplace_back(&algoint8x8x32avx2_m4n16k2); - all_algos.emplace_back(&algoint8x8x16avx2_m4n16k2); - all_algos.emplace_back(&algoint8x8x32avx2_m2n4k16); - all_algos.emplace_back(&algoint8x8x32sse_m4n8k2); - all_algos.emplace_back(&algoint8x8x16sse_m4n8k2); - all_algos.emplace_back(&algof32mk8_8x8); + m_all_algos.emplace_back(&algoint8x8x32avx2_m4n16k2); + m_all_algos.emplace_back(&algoint8x8x16avx2_m4n16k2); + m_all_algos.emplace_back(&algoint8x8x32avx2_m2n4k16); + m_all_algos.emplace_back(&algoint8x8x32sse_m4n8k2); + m_all_algos.emplace_back(&algoint8x8x16sse_m4n8k2); + m_all_algos.emplace_back(&algof32mk8_8x8); #if MEGDNN_X86_WITH_MKL_DNN - all_algos.emplace_back(&algoint8x8x32mkldnn); + m_all_algos.emplace_back(&algoint8x8x32mkldnn); #endif - all_algos.emplace_back(&f32blas); + m_all_algos.emplace_back(&f32blas); #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM - all_algos.emplace_back(&f32mkl_packa); + m_all_algos.emplace_back(&f32mkl_packa); #endif + + for (auto&& algo : m_all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } + } + + const SmallVector& all_algos() const { + return m_all_algos; } - SmallVector all_algos; + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -SmallVector MatrixMulImpl::algo_pack() { - static AlgoPack s_algo_pack; - auto&& algos = fallback::MatrixMulImpl::algo_pack(); - algos.insert(algos.begin(), 
s_algo_pack.all_algos.begin(), - s_algo_pack.all_algos.end()); +const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { + static AlgoPack algo_pack; + return algo_pack; +} + +fallback::MatrixMulImpl::AlgoBase* MatrixMulImpl::get_algo_from_desc( + const AlgorithmDesc& desc) { + megdnn_assert(algo_pack().all_algos_map().find(desc) != + algo_pack().all_algos_map().end()); + return algo_pack().all_algos_map().at(desc); +} + +SmallVector +MatrixMulImpl::get_all_packed_algo() { + auto&& algos = fallback::MatrixMulImpl::get_all_packed_algo(); + algos.insert(algos.begin(), algo_pack().all_algos().begin(), + algo_pack().all_algos().end()); return std::move(algos); } diff --git a/dnn/src/x86/matrix_mul/opr_impl.h b/dnn/src/x86/matrix_mul/opr_impl.h index 3c8a0f90..cc474773 100644 --- a/dnn/src/x86/matrix_mul/opr_impl.h +++ b/dnn/src/x86/matrix_mul/opr_impl.h @@ -42,9 +42,13 @@ public: bool is_thread_safe() const override { return true; } - SmallVector algo_pack() override; + SmallVector get_all_packed_algo() + override; -protected: + static fallback::MatrixMulImpl::AlgoBase* get_algo_from_desc( + const AlgorithmDesc& desc); + +private: class AlgoF32Blas; #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM class AlgoF32MKLPackA; @@ -64,6 +68,9 @@ protected: class AlgoInt8x8x16SSE; class AlgoPack; class AlgoF32MK8_8x8; + +public: + static const AlgoPack& algo_pack(); }; } // namespace x86 diff --git a/dnn/test/common/benchmarker.h b/dnn/test/common/benchmarker.h index da7809d5..51903994 100644 --- a/dnn/test/common/benchmarker.h +++ b/dnn/test/common/benchmarker.h @@ -376,16 +376,16 @@ float algo_benchmark(Benchmarker& benchmark, TensorLayoutArray layouts, auto opr = benchmark.opr(); opr->param() = benchmark.param(); proxy.deduce_layout(opr, layouts); - auto algos = OprAlgoProxy::get_all_algorithms(opr, layouts); + auto algos = OprAlgoProxy::get_all_algorithms_info(opr, layouts); float min_used = std::numeric_limits::max(); bool execed = false; for (auto i : algos) { - if 
(std::regex_match(i->name(), + if (std::regex_match(i.name, std::regex("(" + algo_base + ")(.*)"))) { - opr->execution_policy().algorithm = i; + opr->execution_policy().algo = i; auto used = benchmark.exec(layouts); min_used = std::min(min_used, used); - printf("run algo: %s used: %f ms min_used: %f ms\n", i->name(), + printf("run algo: %s used: %f ms min_used: %f ms\n", i.name.c_str(), used, min_used); execed = true; } diff --git a/dnn/test/common/checker.h b/dnn/test/common/checker.h index 2b4c4717..a16e89fb 100644 --- a/dnn/test/common/checker.h +++ b/dnn/test/common/checker.h @@ -24,15 +24,15 @@ #include // clang-format off -#if defined(__has_feature) +#if defined(__has_feature) #if __has_feature(address_sanitizer) #define MEGDNN_TEST_ASAN 1 #else #define MEGDNN_TEST_ASAN 0 #endif -#elif defined(__SANITIZE_ADDRESS__) +#elif defined(__SANITIZE_ADDRESS__) #define MEGDNN_TEST_ASAN 1 -#else +#else #define MEGDNN_TEST_ASAN 0 #endif // clang-format on @@ -392,8 +392,8 @@ TensorND TensorValue(const TensorShape& shape, T dtype, tensor.layout = {shape, dtype}; tensor.raw_ptr = static_cast(malloc(tensor.layout.span().dist_byte())); - megdnn_assert(values.size() == tensor.layout.total_nr_elems(), "%zu == %zu", values.size(), - tensor.layout.total_nr_elems()); + megdnn_assert(values.size() == tensor.layout.total_nr_elems(), "%zu == %zu", + values.size(), tensor.layout.total_nr_elems()); auto ptr = tensor.ptr::ctype>(); for (const auto& v : values) { *ptr++ = typename DTypeTrait::ctype(v); @@ -456,28 +456,29 @@ public: : m_algo{algo}, m_require_algo{require_algo} {} void operator()(Opr* opr, const CheckerHelper::TensorValueArray& arr) { - opr->execution_policy().algorithm = nullptr; TensorLayoutArray layouts; for (auto&& val : arr) { layouts.push_back(val.layout); } if (m_require_algo && *m_require_algo) { - auto algo = OprAlgoProxy::get_algorithm_heuristic(opr, layouts); + auto algo = + OprAlgoProxy::get_algorithm_info_heuristic(opr, layouts); if (m_name.empty()) { - 
ASSERT_EQ(m_algo->name(), algo->name()); + ASSERT_EQ(m_algo->name(), algo.name.c_str()); } else { ASSERT_TRUE(std::regex_match( - algo->name(), std::regex("(" + m_name + ")(.*)"))); + algo.name.c_str(), std::regex("(" + m_name + ")(.*)"))); } } else { if (m_name.empty()) { - opr->execution_policy().algorithm = m_algo; + opr->execution_policy().algo = m_algo->info(); return; } else { - for (auto i : OprAlgoProxy::get_all_algorithms(opr, layouts)) { - if (std::regex_match(i->name(), + for (auto i : + OprAlgoProxy::get_all_algorithms_info(opr, layouts)) { + if (std::regex_match(i.name, std::regex("(" + m_name + ")(.*)"))) { - opr->execution_policy().algorithm = i; + opr->execution_policy().algo = i; return; } } diff --git a/dnn/test/common/convolution.cpp b/dnn/test/common/convolution.cpp index 67a0db3f..416e2f0f 100644 --- a/dnn/test/common/convolution.cpp +++ b/dnn/test/common/convolution.cpp @@ -11,6 +11,7 @@ #include "test/common/checker.h" #include "test/common/convolution.h" +#include "src/common/algo_base.h" #include #include @@ -52,7 +53,7 @@ std::vector convolution::get_args_common() { TensorShape{5, 2, i, i+1}, TensorShape{3, 2, 3, 4}); } - + return args; } @@ -73,7 +74,7 @@ std::vector convolution::get_args_padding() { TensorShape{5, 2, i, i+1}, TensorShape{3, 2, 3, 4}); } - + return args; } @@ -107,7 +108,7 @@ std::vector convolution::get_args_large_channel() { TensorShape{2, 20, i, i+1}, TensorShape{30, 20, 3, 4}); } - + return args; } @@ -126,7 +127,7 @@ std::vector convolution::get_args_1x1() { TensorShape{2, 20, i, i+1}, TensorShape{30, 20, 1, 1}); } - + return args; } @@ -145,7 +146,7 @@ std::vector convolution::get_args_large_filter() { TensorShape{2, 2, i, i+1}, TensorShape{3, 2, 7, 8}); } - + return args; } @@ -184,7 +185,7 @@ std::vector convolution::get_args_4x4() { TensorShape{4, 3, oh+3, oh+4}, TensorShape{2, 3, 4, 4}); } - + return args; } @@ -309,7 +310,7 @@ std::vector convolution::get_args_x86_winograd_algorithm() { TensorShape{2, ic_size, 
102, 102}, TensorShape{8, ic_size, 3, 3}); } - + return args; } @@ -330,7 +331,7 @@ std::vector convolution::get_args_BRAIN_481() { TensorShape{3, 4, 16-margin, 15-margin}); } } - + return args; } @@ -470,9 +471,10 @@ void convolution::test_conv_config_combinations(int k_size, #define CONF_BOOL(var) for (int var: {0, 1}) - std::unordered_set used_algos; - std::unordered_set used_algos_bwd_data; - std::unordered_set + std::unordered_set used_algos; + std::unordered_set + used_algos_bwd_data; + std::unordered_set used_algos_bwd_flt; using Param = Convolution::Param; @@ -576,14 +578,14 @@ void convolution::test_conv_config_combinations(int k_size, float scale = 1.0f / sqrt(fshp[channel_start] * FH * FW); UniformFloatRNG rng(scale, 2 * scale); checker.set_rng(0, &rng).set_rng(1, &rng); - for (auto algo : opr->get_all_algorithms(ily, fly, oly)) { - used_algos.insert(algo); - opr->execution_policy().algorithm = algo; + for (auto algo : opr->get_all_algorithms_info(ily, fly, oly)) { + used_algos.insert(algo.desc); + opr->execution_policy().algo = algo; checker - .set_epsilon(eps_getter(dtype == 1, 0, algo->name())) + .set_epsilon(eps_getter(dtype == 1, 0, algo.name.c_str())) .execs({ishp, fshp, {}}); - opr->execution_policy().algorithm = nullptr; - ASSERT_TRUE(checker.prev_succ()) << errmsg(algo->name()); + opr->execution_policy().algo.reset(); + ASSERT_TRUE(checker.prev_succ()) << errmsg(algo.name.c_str()); } if (test_backward) { @@ -595,15 +597,15 @@ void convolution::test_conv_config_combinations(int k_size, auto opr = checker_bwd_data.opr(); opr->param() = param; - for (auto algo: opr->get_all_algorithms(fly, oly, ily)) { - used_algos_bwd_data.insert(algo); - opr->execution_policy().algorithm = algo; + for (auto algo: opr->get_all_algorithms_info(fly, oly, ily)) { + used_algos_bwd_data.insert(algo.desc); + opr->execution_policy().algo = algo; checker_bwd_data - .set_epsilon(eps_getter(dtype == 1, 1, algo->name())) + .set_epsilon(eps_getter(dtype == 1, 1, 
algo.name.c_str())) .execl({fly, oly, ily}); - opr->execution_policy().algorithm = nullptr; + opr->execution_policy().algo.reset(); ASSERT_TRUE(checker_bwd_data.prev_succ()) << - errmsg(algo->name()); + errmsg(algo.name.c_str()); } } if (test_backward) { @@ -616,38 +618,19 @@ void convolution::test_conv_config_combinations(int k_size, auto opr = checker_bwd_filter.opr(); opr->param() = param; - for (auto algo: opr->get_all_algorithms(ily, oly, fly)) { - used_algos_bwd_flt.insert(algo); - opr->execution_policy().algorithm = algo; + for (auto algo: opr->get_all_algorithms_info(ily, oly, fly)) { + used_algos_bwd_flt.insert(algo.desc); + opr->execution_policy().algo = algo; checker_bwd_filter - .set_epsilon(eps_getter(dtype == 1, 2, algo->name())) + .set_epsilon(eps_getter(dtype == 1, 2, algo.name.c_str())) .execl({ily, oly, fly}); - opr->execution_policy().algorithm = nullptr; + opr->execution_policy().algo.reset(); ASSERT_TRUE(checker_bwd_filter.prev_succ()) << - errmsg(algo->name()); + errmsg(algo.name.c_str()); } } - - //printf("%s\r", config2str().c_str()); - //fflush(stdout); } - //printf("tested algos: fwd:{"); - //for (auto i: used_algos) { - // printf(" %s", i->name()); - //} - //if (test_backward) { - // printf("} bwd_data:{"); - // for (auto i: used_algos_bwd_data) { - // printf(" %s", i->name()); - // } - // printf("} bwd_filter:{"); - // for (auto i: used_algos_bwd_flt) { - // printf(" %s", i->name()); - // } - //} - //printf("} \n"); } // vim: syntax=cpp.doxygen - diff --git a/dnn/test/common/opr_algo_proxy.h b/dnn/test/common/opr_algo_proxy.h index 5f4e0854..fa362756 100644 --- a/dnn/test/common/opr_algo_proxy.h +++ b/dnn/test/common/opr_algo_proxy.h @@ -22,31 +22,32 @@ struct AlgoProxy; template struct AlgoProxy { - static std::vector get_all_algorithms( + static std::vector get_all_algorithms_info( Opr* opr, TensorLayoutArray& layouts) { megdnn_assert(layouts.size() == 3); - return opr->get_all_algorithms(layouts[0], layouts[1], layouts[2]); + return 
opr->get_all_algorithms_info(layouts[0], layouts[1], layouts[2]); } - static typename Opr::Algorithm* get_algorithm_heuristic( + static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( Opr* opr, TensorLayoutArray& layouts) { megdnn_assert(layouts.size() == 3); - return opr->get_algorithm_heuristic(layouts[0], layouts[1], layouts[2]); + return opr->get_algorithm_info_heuristic(layouts[0], layouts[1], + layouts[2]); } }; template struct AlgoProxy { - static std::vector get_all_algorithms( + static std::vector get_all_algorithms_info( Opr* opr, TensorLayoutArray& layouts) { megdnn_assert(layouts.size() == 5); - return opr->get_all_algorithms(layouts[0], layouts[1], layouts[2], - layouts[3], layouts[4]); + return opr->get_all_algorithms_info(layouts[0], layouts[1], layouts[2], + layouts[3], layouts[4]); } - static typename Opr::Algorithm* get_algorithm_heuristic( + static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( Opr* opr, TensorLayoutArray& layouts) { megdnn_assert(layouts.size() == 5); - return opr->get_algorithm_heuristic(layouts[0], layouts[1], layouts[2], - layouts[3], layouts[4]); + return opr->get_algorithm_info_heuristic( + layouts[0], layouts[1], layouts[2], layouts[3], layouts[4]); } }; diff --git a/dnn/test/common/opr_proxy.h b/dnn/test/common/opr_proxy.h index 2a9e21f1..5d587415 100644 --- a/dnn/test/common/opr_proxy.h +++ b/dnn/test/common/opr_proxy.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once @@ -20,8 +21,6 @@ #include #include - - namespace megdnn { namespace test { @@ -142,7 +141,7 @@ struct OprProxyProfilingBase //! target algo setup by profiler; it can also be directly specified by the //! 
caller - typename Opr::Algorithm* target_algo = nullptr; + typename Opr::AlgorithmInfo target_algo_info; OprProxyProfilingBase(bool profile = false) { m_profiling = profile; } @@ -178,12 +177,12 @@ struct OprProxyProfilingTernary : public OprProxyProfilingBase { if (!Base::W.valid()) { Base::W = WorkspaceWrapper(opr->handle(), 0); } - if (Base::m_profiling && !Base::target_algo) { + if (Base::m_profiling && !Base::target_algo_info.valid()) { size_t min_time = std::numeric_limits::max(); - for (auto algo : - opr->get_all_algorithms(tensors[0].layout, tensors[1].layout, - tensors[2].layout)) { - opr->execution_policy().algorithm = algo; + for (auto algo : opr->get_all_algorithms_info(tensors[0].layout, + tensors[1].layout, + tensors[2].layout)) { + opr->execution_policy().algo = algo; auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout); @@ -202,18 +201,18 @@ struct OprProxyProfilingTernary : public OprProxyProfilingBase { megcoreSynchronize(opr->handle()->megcore_computing_handle()); timer.stop(); printf("%.3fms %s\n", timer.get_time_in_us() / 1e3, - algo->name()); + algo.name.c_str()); if (min_time > timer.get_time_in_us()) { min_time = timer.get_time_in_us(); - Base::target_algo = algo; + Base::target_algo_info = algo; } } - opr->execution_policy().algorithm = Base::target_algo; + opr->execution_policy().algo = Base::target_algo_info; auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout); Base::W.update(workspace_size); } - if (!Base::target_algo) { + if (!Base::target_algo_info.valid()) { auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout); Base::W.update(workspace_size); @@ -238,18 +237,19 @@ DEF_PROF3(LocalShareBackwardFilter); template <> struct OprProxy : public OprProxyProfilingTernary { - using OprProxyProfilingTernary::OprProxyProfilingTernary; + using OprProxyProfilingTernary< + 
ConvolutionForward>::OprProxyProfilingTernary; void exec(ConvolutionForward* opr, const TensorNDArray& tensors) { megdnn_assert(tensors.size() == 3); if (!Base::W.valid()) { Base::W = WorkspaceWrapper(opr->handle(), 0); } - if (Base::m_profiling && !Base::target_algo) { + if (Base::m_profiling && !Base::target_algo_info.desc.valid()) { size_t min_time = std::numeric_limits::max(); - for (auto algo : - opr->get_all_algorithms(tensors[0].layout, tensors[1].layout, - tensors[2].layout)) { - opr->execution_policy().algorithm = algo; + for (auto algo : opr->get_all_algorithms_info(tensors[0].layout, + tensors[1].layout, + tensors[2].layout)) { + opr->execution_policy().algo = algo; auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, nullptr); @@ -268,18 +268,19 @@ struct OprProxy megcoreSynchronize(opr->handle()->megcore_computing_handle()); timer.stop(); printf("%.3fms %s\n", timer.get_time_in_us() / 1e3, - algo->name()); + algo.name.c_str()); if (min_time > timer.get_time_in_us()) { min_time = timer.get_time_in_us(); - Base::target_algo = algo; + Base::target_algo_info = algo; } } - opr->execution_policy().algorithm = Base::target_algo; + opr->execution_policy().algo = Base::target_algo_info; auto workspace_size = opr->get_workspace_in_bytes( - tensors[0].layout, tensors[1].layout, tensors[2].layout, nullptr); + tensors[0].layout, tensors[1].layout, tensors[2].layout, + nullptr); Base::W.update(workspace_size); } - if (!Base::target_algo) { + if (!Base::target_algo_info.desc.valid()) { auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, nullptr); @@ -293,23 +294,25 @@ struct OprProxy template <> struct OprWeightPreprocessProxy : public OprProxyProfilingTernary { - using OprProxyProfilingTernary::OprProxyProfilingTernary; + using OprProxyProfilingTernary< + ConvolutionForward>::OprProxyProfilingTernary; void exec(ConvolutionForward* opr, const TensorNDArray& 
tensors) { megdnn_assert(tensors.size() == 3); if (!Base::W.valid()) { Base::W = WorkspaceWrapper(opr->handle(), 0); } - if (Base::m_profiling && !Base::target_algo) { + if (Base::m_profiling && !Base::target_algo_info.desc.valid()) { size_t min_time = std::numeric_limits::max(); - for (auto algo : - opr->get_all_algorithms(tensors[0].layout, tensors[1].layout, - tensors[2].layout)) { - opr->execution_policy().algorithm = algo; + for (auto algo : opr->get_all_algorithms_info(tensors[0].layout, + tensors[1].layout, + tensors[2].layout)) { + opr->execution_policy().algo = algo; - auto preprocess_tensors = weight_prerocess(opr, tensors, algo); + auto preprocess_tensors = + weight_prerocess(opr, tensors, algo.desc); megcoreSynchronize(opr->handle()->megcore_computing_handle()); ConvolutionForward::PreprocessedFilter preprocessed_filter{ - algo, *preprocess_tensors}; + nullptr, *preprocess_tensors}; auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, @@ -329,29 +332,29 @@ struct OprWeightPreprocessProxy megcoreSynchronize(opr->handle()->megcore_computing_handle()); timer.stop(); printf("%.3fms %s\n", timer.get_time_in_us() / 1e3, - algo->name()); + algo.name.c_str()); if (min_time > timer.get_time_in_us()) { min_time = timer.get_time_in_us(); - Base::target_algo = algo; + Base::target_algo_info = algo; } } - opr->execution_policy().algorithm = Base::target_algo; + opr->execution_policy().algo = Base::target_algo_info; auto preprocess_tensors = - weight_prerocess(opr, tensors, Base::target_algo); + weight_prerocess(opr, tensors, Base::target_algo_info.desc); megcoreSynchronize(opr->handle()->megcore_computing_handle()); ConvolutionForward::PreprocessedFilter preprocessed_filter{ - Base::target_algo, *preprocess_tensors}; + nullptr, *preprocess_tensors}; auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, &preprocessed_filter); Base::W.update(workspace_size); } 
auto preprocess_tensors = - weight_prerocess(opr, tensors, Base::target_algo); + weight_prerocess(opr, tensors, Base::target_algo_info.desc); megcoreSynchronize(opr->handle()->megcore_computing_handle()); ConvolutionForward::PreprocessedFilter preprocessed_filter{ - Base::target_algo, *preprocess_tensors}; - if (!Base::target_algo) { + nullptr, *preprocess_tensors}; + if (!Base::target_algo_info.valid()) { auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, &preprocessed_filter); @@ -364,13 +367,13 @@ struct OprWeightPreprocessProxy //! handle weight preprocess std::shared_ptr weight_prerocess( ConvolutionForward* opr, const TensorNDArray& tensors, - ConvolutionForward::Algorithm* algo) { + const ConvolutionForward::AlgorithmDesc&) { auto weight_perprocess_layouts = opr->deduce_preprocessed_filter_layout( tensors[0].layout, tensors[1].layout, tensors[2].layout); auto preprocessed_filter_tensors_ptr = alloc_tensors(opr->handle(), weight_perprocess_layouts); ConvolutionForward::PreprocessedFilter preprocessed_filter{ - algo, *preprocessed_filter_tensors_ptr}; + nullptr, *preprocessed_filter_tensors_ptr}; size_t preprocess_workspace_size = opr->get_preprocess_workspace_in_bytes(tensors[0].layout, tensors[1].layout, @@ -384,7 +387,6 @@ struct OprWeightPreprocessProxy } }; - template struct OprProxyProfiling5 : public OprProxyProfilingBase { using Base = OprProxyProfilingBase; @@ -394,13 +396,13 @@ struct OprProxyProfiling5 : public OprProxyProfilingBase { if (!Base::W.valid()) { Base::W = WorkspaceWrapper(opr->handle(), 0); } - if (Base::m_profiling && !Base::target_algo) { + if (Base::m_profiling && !Base::target_algo_info.valid()) { size_t min_time = std::numeric_limits::max(); - for (auto algo : - opr->get_all_algorithms(tensors[0].layout, tensors[1].layout, - tensors[2].layout, tensors[3].layout, - tensors[4].layout)) { - opr->execution_policy().algorithm = algo; + for (auto algo : 
opr->get_all_algorithms_info( + tensors[0].layout, tensors[1].layout, + tensors[2].layout, tensors[3].layout, + tensors[4].layout)) { + opr->execution_policy().algo = algo; auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, tensors[3].layout, tensors[4].layout); @@ -419,19 +421,19 @@ struct OprProxyProfiling5 : public OprProxyProfilingBase { megcoreSynchronize(opr->handle()->megcore_computing_handle()); timer.stop(); printf("%.3fms %s\n", timer.get_time_in_us() / 1e3, - algo->name()); + algo.name.c_str()); if (min_time > timer.get_time_in_us()) { min_time = timer.get_time_in_us(); - Base::target_algo = algo; + Base::target_algo_info = algo; } } - opr->execution_policy().algorithm = Base::target_algo; + opr->execution_policy().algo = Base::target_algo_info; auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, tensors[3].layout, tensors[4].layout); Base::W.update(workspace_size); } - if (!Base::target_algo) { + if (!Base::target_algo_info.valid()) { auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, tensors[3].layout, tensors[4].layout); @@ -461,13 +463,13 @@ struct OprProxy : public OprProxyProfiling5 { if (!Base::W.valid()) { Base::W = WorkspaceWrapper(opr->handle(), 0); } - if (Base::m_profiling && !Base::target_algo) { + if (Base::m_profiling && !Base::target_algo_info.desc.valid()) { size_t min_time = std::numeric_limits::max(); - for (auto algo : - opr->get_all_algorithms(tensors[0].layout, tensors[1].layout, - tensors[2].layout, tensors[3].layout, - tensors[4].layout)) { - opr->execution_policy().algorithm = algo; + for (auto algo : opr->get_all_algorithms_info( + tensors[0].layout, tensors[1].layout, + tensors[2].layout, tensors[3].layout, + tensors[4].layout)) { + opr->execution_policy().algo = algo; auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, 
tensors[2].layout, tensors[3].layout, tensors[4].layout, nullptr); @@ -486,19 +488,19 @@ struct OprProxy : public OprProxyProfiling5 { megcoreSynchronize(opr->handle()->megcore_computing_handle()); timer.stop(); printf("%.3fms %s\n", timer.get_time_in_us() / 1e3, - algo->name()); + algo.name.c_str()); if (min_time > timer.get_time_in_us()) { min_time = timer.get_time_in_us(); - Base::target_algo = algo; + Base::target_algo_info = algo; } } - opr->execution_policy().algorithm = Base::target_algo; + opr->execution_policy().algo = Base::target_algo_info; auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, tensors[3].layout, tensors[4].layout, nullptr); Base::W.update(workspace_size); } - if (!Base::target_algo) { + if (!Base::target_algo_info.valid()) { auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, tensors[3].layout, tensors[4].layout, nullptr); @@ -518,18 +520,19 @@ struct OprWeightPreprocessProxy if (!Base::W.valid()) { Base::W = WorkspaceWrapper(opr->handle(), 0); } - if (Base::m_profiling && !Base::target_algo) { + if (Base::m_profiling && !Base::target_algo_info.valid()) { size_t min_time = std::numeric_limits::max(); - for (auto algo : - opr->get_all_algorithms(tensors[0].layout, tensors[1].layout, - tensors[2].layout, tensors[3].layout, - tensors[4].layout)) { - opr->execution_policy().algorithm = algo; + for (auto algo : opr->get_all_algorithms_info( + tensors[0].layout, tensors[1].layout, + tensors[2].layout, tensors[3].layout, + tensors[4].layout)) { + opr->execution_policy().algo = algo; - auto preprocess_tensors = weight_prerocess(opr, tensors, algo); + auto preprocess_tensors = + weight_prerocess(opr, tensors, algo.desc); megcoreSynchronize(opr->handle()->megcore_computing_handle()); ConvBiasForward::PreprocessedFilter preprocessed_filter{ - algo, *preprocess_tensors}; + nullptr, *preprocess_tensors}; auto workspace_size = 
opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, @@ -552,29 +555,29 @@ struct OprWeightPreprocessProxy megcoreSynchronize(opr->handle()->megcore_computing_handle()); timer.stop(); printf("%.3fms %s\n", timer.get_time_in_us() / 1e3, - algo->name()); + algo.name.c_str()); if (min_time > timer.get_time_in_us()) { min_time = timer.get_time_in_us(); - Base::target_algo = algo; + Base::target_algo_info = algo; } } - opr->execution_policy().algorithm = Base::target_algo; + opr->execution_policy().algo = Base::target_algo_info; auto preprocess_tensors = - weight_prerocess(opr, tensors, Base::target_algo); + weight_prerocess(opr, tensors, Base::target_algo_info.desc); megcoreSynchronize(opr->handle()->megcore_computing_handle()); ConvBiasForward::PreprocessedFilter preprocessed_filter{ - Base::target_algo, *preprocess_tensors}; + nullptr, *preprocess_tensors}; auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, tensors[3].layout, tensors[4].layout, &preprocessed_filter); Base::W.update(workspace_size); } auto preprocess_tensors = - weight_prerocess(opr, tensors, Base::target_algo); + weight_prerocess(opr, tensors, Base::target_algo_info.desc); megcoreSynchronize(opr->handle()->megcore_computing_handle()); ConvBiasForward::PreprocessedFilter preprocessed_filter{ - Base::target_algo, *preprocess_tensors}; - if (!Base::target_algo) { + nullptr, *preprocess_tensors}; + if (!Base::target_algo_info.valid()) { auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, tensors[3].layout, tensors[4].layout, &preprocessed_filter); @@ -587,14 +590,14 @@ struct OprWeightPreprocessProxy //! 
handle weight preprocess std::shared_ptr weight_prerocess( ConvBiasForward* opr, const TensorNDArray& tensors, - ConvBiasForward::Algorithm* algo) { + const ConvBiasForward::AlgorithmDesc&) { auto weight_perprocess_layouts = opr->deduce_preprocessed_filter_layout( tensors[0].layout, tensors[1].layout, tensors[2].layout, tensors[3].layout, tensors[4].layout); auto preprocessed_filter_tensors_ptr = alloc_tensors(opr->handle(), weight_perprocess_layouts); ConvBiasForward::PreprocessedFilter preprocessed_filter{ - algo, *preprocessed_filter_tensors_ptr}; + nullptr, *preprocessed_filter_tensors_ptr}; size_t preprocess_workspace_size = opr->get_preprocess_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, @@ -618,14 +621,14 @@ struct OprProxyProfiling8 : public OprProxyProfilingBase { if (!Base::W.valid()) { Base::W = WorkspaceWrapper(opr->handle(), 0); } - if (Base::m_profiling && !Base::target_algo) { + if (Base::m_profiling && !Base::target_algo_info.valid()) { size_t min_time = std::numeric_limits::max(); - for (auto algo : opr->get_all_algorithms( + for (auto algo : opr->get_all_algorithms_info( tensors[0].layout, tensors[1].layout, tensors[2].layout, tensors[3].layout, tensors[4].layout, tensors[5].layout, tensors[6].layout, tensors[7].layout)) { - opr->execution_policy().algorithm = algo; + opr->execution_policy().algo = algo; auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, tensors[3].layout, tensors[4].layout, tensors[5].layout, @@ -647,20 +650,20 @@ struct OprProxyProfiling8 : public OprProxyProfilingBase { megcoreSynchronize(opr->handle()->megcore_computing_handle()); timer.stop(); printf("%.3fms %s\n", timer.get_time_in_us() / 1e3, - algo->name()); + algo.name.c_str()); if (min_time > timer.get_time_in_us()) { min_time = timer.get_time_in_us(); - Base::target_algo = algo; + Base::target_algo_info = algo; } } - opr->execution_policy().algorithm = Base::target_algo; + 
opr->execution_policy().algo = Base::target_algo_info; auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, tensors[3].layout, tensors[4].layout, tensors[5].layout, tensors[6].layout, tensors[7].layout); Base::W.update(workspace_size); } - if (!Base::target_algo) { + if (!Base::target_algo_info.valid()) { auto workspace_size = opr->get_workspace_in_bytes( tensors[0].layout, tensors[1].layout, tensors[2].layout, tensors[3].layout, tensors[4].layout, tensors[5].layout, diff --git a/dnn/test/cuda/batch_conv_bias.cpp b/dnn/test/cuda/batch_conv_bias.cpp index 9d1a9704..66e34f8c 100644 --- a/dnn/test/cuda/batch_conv_bias.cpp +++ b/dnn/test/cuda/batch_conv_bias.cpp @@ -279,7 +279,7 @@ void benchmark_target_algo(Handle* handle, const std::vector& args, benchmarker.set_param(bparam); if (!algo) { - benchmarker.proxy()->target_algo = nullptr; + benchmarker.proxy()->target_algo_info.reset(); } auto time_in_ms = benchmarker.execs( diff --git a/dnn/test/cuda/chanwise_convolution.cpp b/dnn/test/cuda/chanwise_convolution.cpp index c29db534..626d5a30 100644 --- a/dnn/test/cuda/chanwise_convolution.cpp +++ b/dnn/test/cuda/chanwise_convolution.cpp @@ -514,7 +514,7 @@ TEST_F(CUDA, CHANWISE_CONVOLUTION_BENCH_ALL_ALGO_FWD) { auto run = [&](size_t N, size_t C, size_t IH, size_t IW, size_t FH, size_t FW) { - checker.proxy()->target_algo = nullptr; + checker.proxy()->target_algo_info.reset(); checker.execs({{N, C, IH, IW}, {C, 1, 1, FH, FW}, {}}); }; @@ -538,7 +538,7 @@ TEST_F(CUDA, CHANWISE_CONVOLUTION_BENCH_ALL_ALGO_BWD_DATA) { auto run = [&](size_t N, size_t C, size_t IH, size_t IW, size_t FH, size_t FW) { - checker.proxy()->target_algo = nullptr; + checker.proxy()->target_algo_info.reset(); checker.execs({{C, 1, 1, FH, FW}, {N, C, IH - FH + 1, IW - FW + 1}, {N, C, IH, IW}}); @@ -564,7 +564,7 @@ TEST_F(CUDA, CHANWISE_CONVOLUTION_BENCH_ALL_ALGO_BWD_FILTER) { auto run = [&](size_t N, size_t C, size_t IH, size_t IW, size_t FH, size_t 
FW) { - checker.proxy()->target_algo = nullptr; + checker.proxy()->target_algo_info.reset(); checker.execs({{N, C, IH, IW}, {N, C, IH - FH + 1, IW - FW + 1}, {C, 1, 1, FH, FW}}); @@ -614,7 +614,7 @@ TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_ALL_ALGO_FORWARD) { .set_dtype(2, dtype::Float32()) .set_rng(0, &rng) .set_rng(1, &rng); - bencher.proxy()->target_algo = nullptr; + bencher.proxy()->target_algo_info.reset(); auto time_in_ms_fp32 = bencher.execs({src, filter, {}}) / RUNS; bencher.set_param(param) @@ -623,10 +623,10 @@ TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_ALL_ALGO_FORWARD) { .set_dtype(2, dtype::Float16()) .set_rng(0, &rng) .set_rng(1, &rng); - bencher.proxy()->target_algo = nullptr; + bencher.proxy()->target_algo_info.reset(); auto time_in_ms_fp16 = bencher.execs({src, filter, {}}) / RUNS; - bencher.proxy()->target_algo = nullptr; + bencher.proxy()->target_algo_info.reset(); param.compute_mode = param::Convolution::ComputeMode::FLOAT32; bencher.set_param(param); auto time_in_ms_pseudo_fp16 = bencher.execs({src, filter, {}}) / RUNS; @@ -1022,7 +1022,7 @@ TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BWD_FILTER) { opr->param() = param; float bandwith = static_cast(flt_grad.total_nr_elems() + dst_grad.total_nr_elems() + - src.total_nr_elems()) / + src.total_nr_elems()) / (1024 * 1024 * 1024) * 1e3; bencher.set_param(param) .set_dtype(0, dtype::Float32()) diff --git a/dnn/test/cuda/conv_bias_int8.cpp b/dnn/test/cuda/conv_bias_int8.cpp index 5a3560c7..693da6ca 100644 --- a/dnn/test/cuda/conv_bias_int8.cpp +++ b/dnn/test/cuda/conv_bias_int8.cpp @@ -168,7 +168,7 @@ void benchmark_target_algo( benchmarker.set_param(param); if (!algo) { - benchmarker.proxy()->target_algo = nullptr; + benchmarker.proxy()->target_algo_info.reset(); } TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, filter{arg.co, arg.ci, arg.f, arg.f}, bias{1, arg.co, 1, 1}, @@ -327,7 +327,7 @@ void benchmark_target_algo_with_cudnn_tsc( benchmarker.set_param(param); if (!algo) { - benchmarker.proxy()->target_algo = 
nullptr; + benchmarker.proxy()->target_algo_info.reset(); } TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, filter{arg.co, arg.ci, arg.f, arg.f}, bias{1, arg.co, 1, 1}, diff --git a/dnn/test/cuda/convolution.cpp b/dnn/test/cuda/convolution.cpp index c8f47857..226809f5 100644 --- a/dnn/test/cuda/convolution.cpp +++ b/dnn/test/cuda/convolution.cpp @@ -429,7 +429,7 @@ TEST_F(CUDA, CONVOLUTION_FWD_BENCHMARK) { param.pad_h = param.pad_w = PH; param.compute_mode = param::Convolution::ComputeMode::DEFAULT; bench.set_param(param); - bench.proxy()->target_algo = nullptr; + bench.proxy()->target_algo_info.reset(); TensorLayout src{{N, IC, IH, IW}, dtype::Float32()}, filter{{OC, IC, FH, FH}, dtype::Float32()}; TensorLayout dst; @@ -440,13 +440,13 @@ TEST_F(CUDA, CONVOLUTION_FWD_BENCHMARK) { } auto time_ms_fp32 = bench.execl({src, filter, dst}) / RUNS; src.dtype = filter.dtype = dst.dtype = dtype::Float16(); - bench.proxy()->target_algo = nullptr; + bench.proxy()->target_algo_info.reset(); bench.set_dtype(0, dtype::Float16()) .set_dtype(1, dtype::Float16()) .set_dtype(2, dtype::Float16()); auto time_ms_true_fp16 = bench.execl({src, filter, dst}) / RUNS; param.compute_mode = param::Convolution::ComputeMode::FLOAT32; - bench.proxy()->target_algo = nullptr; + bench.proxy()->target_algo_info.reset(); bench.set_param(param); auto time_ms_pseudo_fp16 = bench.execl({src, filter, dst}) / RUNS; float flo = 2.0 * N * OC * IC * dst[2] * dst[3] * FH * FH; @@ -500,7 +500,7 @@ TEST_F(CUDA, CONVOLUTION_BWD_DATA_BENCHMARK) { param.pad_h = param.pad_w = PH; param.compute_mode = param::Convolution::ComputeMode::DEFAULT; bench.set_param(param); - bench.proxy()->target_algo = nullptr; + bench.proxy()->target_algo_info.reset(); TensorLayout src{{N, IC, IH, IW}, dtype::Float32()}, filter{{OC, IC, FH, FH}, dtype::Float32()}; TensorLayout dst; @@ -511,13 +511,13 @@ TEST_F(CUDA, CONVOLUTION_BWD_DATA_BENCHMARK) { } auto time_ms_fp32 = bench.execl({filter, dst, src}) / RUNS; src.dtype = filter.dtype = 
dst.dtype = dtype::Float16(); - bench.proxy()->target_algo = nullptr; + bench.proxy()->target_algo_info.reset(); bench.set_dtype(0, dtype::Float16()) .set_dtype(1, dtype::Float16()) .set_dtype(2, dtype::Float16()); auto time_ms_true_fp16 = bench.execl({filter, dst, src}) / RUNS; param.compute_mode = param::Convolution::ComputeMode::FLOAT32; - bench.proxy()->target_algo = nullptr; + bench.proxy()->target_algo_info.reset(); bench.set_param(param); auto time_ms_pseudo_fp16 = bench.execl({filter, dst, src}) / RUNS; float flo = 2.0 * N * OC * IC * dst[2] * dst[3] * FH * FH; @@ -571,7 +571,7 @@ TEST_F(CUDA, CONVOLUTION_BWD_FILTER_BENCHMARK) { param.pad_h = param.pad_w = PH; param.compute_mode = param::Convolution::ComputeMode::DEFAULT; bench.set_param(param); - bench.proxy()->target_algo = nullptr; + bench.proxy()->target_algo_info.reset(); TensorLayout src{{N, IC, IH, IW}, dtype::Float32()}, filter{{OC, IC, FH, FH}, dtype::Float32()}; TensorLayout dst; @@ -582,13 +582,13 @@ TEST_F(CUDA, CONVOLUTION_BWD_FILTER_BENCHMARK) { } auto time_ms_fp32 = bench.execl({src, dst, filter}) / RUNS; src.dtype = filter.dtype = dst.dtype = dtype::Float16(); - bench.proxy()->target_algo = nullptr; + bench.proxy()->target_algo_info.reset(); bench.set_dtype(0, dtype::Float16()) .set_dtype(1, dtype::Float16()) .set_dtype(2, dtype::Float16()); auto time_ms_true_fp16 = bench.execl({src, dst, filter}) / RUNS; param.compute_mode = param::Convolution::ComputeMode::FLOAT32; - bench.proxy()->target_algo = nullptr; + bench.proxy()->target_algo_info.reset(); bench.set_param(param); auto time_ms_pseudo_fp16 = bench.execl({src, dst, filter}) / RUNS; float flo = 2.0 * N * OC * IC * dst[2] * dst[3] * FH * FH; diff --git a/dnn/test/cuda/local_share.cpp b/dnn/test/cuda/local_share.cpp index 333fd8d5..1fc9983b 100644 --- a/dnn/test/cuda/local_share.cpp +++ b/dnn/test/cuda/local_share.cpp @@ -778,7 +778,7 @@ TEST_F(CUDA, BENCHMARK_LOCAL_SHARE_BWD_FILTER) { .set_dtype(2, dtype::Float32()) .set_rng(0, &rng) 
.set_rng(1, &rng); - bencher.proxy()->target_algo = nullptr; + bencher.proxy()->target_algo_info.reset(); auto time_in_ms = bencher.execs({src, diff, grad}) / RUNS; printf("src=%s, diff=%s, grad=%s, float32: %.2fms " @@ -856,7 +856,7 @@ TEST_F(CUDA, BENCHMARK_GROUP_LOCAL_SHARE_FORWARD) { .set_dtype(2, dtype::Float32()) .set_rng(0, &rng) .set_rng(1, &rng); - bencher.proxy()->target_algo = nullptr; + bencher.proxy()->target_algo_info.reset(); auto time_in_ms = bencher.execs({src, filter, {}}) / RUNS; ; @@ -915,7 +915,7 @@ TEST_F(CUDA, BENCHMARK_LOCAL_SHARE_BWD_DATA) { .set_dtype(2, dtype::Float32()) .set_rng(0, &rng) .set_rng(1, &rng); - bencher.proxy()->target_algo = nullptr; + bencher.proxy()->target_algo_info.reset(); auto time_in_ms = bencher.execs({filter, diff, grad}) / RUNS; printf("filter=%s, diff=%s, grad=%s, float32: %.2fms " @@ -1002,11 +1002,11 @@ TEST_F(CUDA, BENCHMARK_LOCAL_SHARE_FORWARD_BOTTLENECK) { .set_dtype(2, dtype::Float32()) .set_rng(0, &rng) .set_rng(1, &rng); - bencher.proxy()->target_algo = nullptr; + bencher.proxy()->target_algo_info.reset(); auto time_in_ms = bencher.execs({src, filter, {}}) / RUNS; bencher_conv.set_param(conv_param); - bencher_conv.proxy()->target_algo = nullptr; + bencher_conv.proxy()->target_algo_info.reset(); auto time_in_ms_conv = bencher_conv.execs({src, {oc, ic, f, f}, {}}) / RUNS; @@ -1094,11 +1094,11 @@ TEST_F(CUDA, BENCHMARK_LOCAL_SHARE_FORWARD_FROM_RESEARCH) { .set_dtype(2, dtype::Float32()) .set_rng(0, &rng) .set_rng(1, &rng); - bencher.proxy()->target_algo = nullptr; + bencher.proxy()->target_algo_info.reset(); auto time_in_ms = bencher.execs({src, filter, {}}) / RUNS; bencher_conv.set_param(conv_param); - bencher_conv.proxy()->target_algo = nullptr; + bencher_conv.proxy()->target_algo_info.reset(); auto time_in_ms_conv = bencher_conv.execs({src, {oc, ic, f, f}, {}}) / RUNS; @@ -1177,11 +1177,11 @@ TEST_F(CUDA, BENCHMARK_LOCAL_SHARE_FORWARD) { .set_dtype(2, dtype::Float32()) .set_rng(0, &rng) .set_rng(1, &rng); 
- bencher.proxy()->target_algo = nullptr; + bencher.proxy()->target_algo_info.reset(); auto time_in_ms = bencher.execs({src, filter, {}}) / RUNS; bencher_conv.set_param(conv_param); - bencher_conv.proxy()->target_algo = nullptr; + bencher_conv.proxy()->target_algo_info.reset(); auto time_in_ms_conv = bencher_conv.execs({src, {oc, ic, f, f}, {}}) / RUNS; diff --git a/src/core/test/graph/misc.cpp b/src/core/test/graph/misc.cpp index 5d7797a2..830ff21a 100644 --- a/src/core/test/graph/misc.cpp +++ b/src/core/test/graph/misc.cpp @@ -1922,19 +1922,19 @@ TEST(TestGraph, NaiveRecord2NCHW44) { namespace { template -typename DnnOp::Algorithm* try_find_any_weight_preprocess_algo( +typename DnnOp::AlgorithmInfo try_find_any_weight_preprocess_algo( DnnOp* dnn_op, const char* mgb_info, Maybe& found, Args&& ...args) { if (found.valid()) { if (found.val()) { - return dnn_op->execution_policy().algorithm; + return dnn_op->execution_policy().algo; } else { - return nullptr; + return {}; } } - for (auto&& algo : dnn_op->get_all_algorithms( + for (auto&& algo : dnn_op->get_all_algorithms_info( std::forward(args)...)) { - dnn_op->execution_policy().algorithm = algo; + dnn_op->execution_policy().algo = algo; auto layouts = dnn_op->deduce_preprocessed_filter_layout( std::forward(args)...); if (layouts.empty()) continue; @@ -1952,23 +1952,23 @@ typename DnnOp::Algorithm* try_find_any_weight_preprocess_algo( } found.emplace(false); mgb_log_warn("Can't find weight preprocess algo for op %s", mgb_info); - return nullptr; + return {}; } template -typename DnnOp::Algorithm* try_find_any_bias_preprocess_algo( +typename DnnOp::AlgorithmInfo try_find_any_bias_preprocess_algo( DnnOp* dnn_op, const char* mgb_info, Maybe& found, Args&& ...args) { if (found.valid()) { if (found.val()) { - return dnn_op->execution_policy().algorithm; + return dnn_op->execution_policy().algo; } else { - return nullptr; + return {}; } } - for (auto&& algo : dnn_op->get_all_algorithms( + for (auto&& algo : 
dnn_op->get_all_algorithms_info( std::forward(args)...)) { - dnn_op->execution_policy().algorithm = algo; + dnn_op->execution_policy().algo = algo; auto layouts = dnn_op->deduce_preprocessed_filter_layout( std::forward(args)...); if (layouts.size() <= 1) @@ -1984,7 +1984,7 @@ typename DnnOp::Algorithm* try_find_any_bias_preprocess_algo( } found.emplace(false); mgb_log_warn("Can't find bias preprocess algo for op %s", mgb_info); - return nullptr; + return {}; } void test_free_memory_in_weight_preprocess(int record_level, CompNode cn) { diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp index 293f1540..4816d159 100644 --- a/src/opr/impl/search_policy/algo_chooser.cpp +++ b/src/opr/impl/search_policy/algo_chooser.cpp @@ -69,7 +69,7 @@ AlgoChooserProfileCache::Result AlgoChooser::get_profile_result( Maybe cur_rst; std::string msg = ssprintf("profiling %s algorithm %s %s", ctx.mgb_opr()->dyn_typeinfo()->name, - algo->name(), str_on_inp_shape.c_str()); + algo.name.c_str(), str_on_inp_shape.c_str()); timer.reset(); MGB_TRY { cur_rst = ctx.profile_single_algo(algo, cur_timeout); } MGB_CATCH(std::exception & exc, { @@ -122,20 +122,20 @@ typename AlgoChooser::ImplAlgo AlgoChooser::choose_by_profile( MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile"))) auto opr = ctx.mgb_opr(); if (opr->owner_graph()->options().no_profiling_on_shape_change) { - auto algo = ctx.megdnn_opr()->execution_policy().algorithm; - if (algo) + auto algo = ctx.megdnn_opr()->execution_policy().algo; + if (algo.valid()) return algo; } std::unordered_map algo_map; for (auto i : ctx.get_all_candidates()) { - auto ins = algo_map.emplace(i->name(), i); - mgb_assert(ins.second, "duplicated algo name: %s", i->name()); + auto ins = algo_map.emplace(i.name.c_str(), i); + mgb_assert(ins.second, "duplicated algo name: %s", i.name.c_str()); } auto&& prof = get_profile_result(ctx, enable_update); if (prof.empty()) - return nullptr; + return {}; 
for (auto&& i : prof) { if ((!require_reproducible || i.reproducible)) { auto iter = algo_map.find(i.algo); @@ -173,13 +173,13 @@ size_t AlgoChooser::setup_algo(const ConvTensorLayouts& layouts, return 0; } - ImplAlgo algo = nullptr; + ImplAlgo algo = {}; ExeContext ctx(layouts, megdnn_opr, mgb_opr, allow_weight_preprocess); if (auto algo_choose_hook = mgb_opr->algo_chooser()) { algo = algo_choose_hook(mgb_opr); } - if (!algo) { + if (!algo.valid()) { algo = get_algo(ctx); } size_t workspace = ctx.get_workspace_size_bytes(algo); @@ -190,8 +190,8 @@ size_t AlgoChooser::setup_algo(const ConvTensorLayouts& layouts, layouts[0].dtype.name(), layouts[1].to_string().c_str(), layouts[1].dtype.name(), layouts[layouts.size() - 1].to_string().c_str(), - layouts[layouts.size() - 1].dtype.name(), algo->name(), - workspace / (1024 * 1024.0), algo->is_reproducible()); + layouts[layouts.size() - 1].dtype.name(), algo.name.c_str(), + workspace / (1024 * 1024.0), algo.is_reproducible); megdnn_opr->execution_policy() = {algo}; return workspace; } @@ -208,7 +208,7 @@ typename AlgoChooser::ImplAlgo AlgoChooser::get_algo( return ctx.choose_by_heuristic(true); case S::PROFILE_HEURISTIC: { ImplAlgo algo = choose_by_profile(ctx, false, false); - if (algo == nullptr) + if (!algo.valid()) algo = ctx.choose_by_heuristic(); return algo; } @@ -249,8 +249,8 @@ AlgoChooser::ExeContext::choose_by_heuristic(bool reproducible) const { auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( opr->owner_graph(), opr->comp_node(), opr->execution_policy().workspace_limit); - return APPLY(m_megdnn_opr->get_algorithm_heuristic(args..., workspace_limit, - reproducible), + return APPLY(m_megdnn_opr->get_algorithm_info_heuristic( + args..., workspace_limit, reproducible), m_layouts); } @@ -258,7 +258,8 @@ template std::vector::ImplAlgo> AlgoChooser::ExeContext::get_all_candidates() const { auto heu = choose_by_heuristic(); - auto&& ret = APPLY(m_megdnn_opr->get_all_algorithms(args...), m_layouts); + 
auto&& ret = + APPLY(m_megdnn_opr->get_all_algorithms_info(args...), m_layouts); bool found = false; for (size_t i = 0; i < ret.size(); ++i) { if (ret[i] == heu) { @@ -269,7 +270,8 @@ AlgoChooser::ExeContext::get_all_candidates() const { } mgb_assert(found, "algo %s got by heuristic not found in " - "candidate list", heu->name()); + "candidate list", + heu.name.c_str()); return std::move(ret); } @@ -320,7 +322,7 @@ Maybe AlgoChooser::ExeContext::profile_single_algo(ImplAlgo algo, double& timeout) const { typename TimedProfiler::Param param; - auto name = algo->name(); + auto name = algo.name.c_str(); // force check copy size <= dest len-1 from gcc8 for safe auto len = sizeof(param.algo_name); strncpy(param.algo_name, name, len - 1); @@ -354,7 +356,7 @@ AlgoChooser::ExeContext::profile_single_algo(ImplAlgo algo, if (!rst.valid()) return None; return AlgoChooserProfileCache::ResultEntry{ - algo->name(), algo->is_reproducible(), rst.val().time, + algo.name.c_str(), algo.is_reproducible, rst.val().time, param.workspace}; } diff --git a/src/opr/impl/search_policy/profiler.cpp b/src/opr/impl/search_policy/profiler.cpp index 506a8330..76712386 100644 --- a/src/opr/impl/search_policy/profiler.cpp +++ b/src/opr/impl/search_policy/profiler.cpp @@ -99,14 +99,15 @@ typename TimedProfiler::TResult TimedProfiler::prof_impl( megdnn_opr->param() = param.opr_param; { - typename Opr::Algorithm* algo = nullptr; - for (auto i : APPLY(megdnn_opr->get_all_algorithms(args...), layouts)) { - if (!strcmp(i->name(), param.algo_name)) { + typename Opr::AlgorithmInfo algo; + for (auto i : + APPLY(megdnn_opr->get_all_algorithms_info(args...), layouts)) { + if (!strcmp(i.name.c_str(), param.algo_name)) { algo = i; break; } } - mgb_assert(algo, "algorithm %s not found", param.algo_name); + mgb_assert(algo.valid(), "algorithm %s not found", param.algo_name); megdnn_opr->execution_policy() = {algo}; } diff --git a/src/opr/include/megbrain/opr/dnn/convolution.h 
b/src/opr/include/megbrain/opr/dnn/convolution.h index b3071817..c1fed6d4 100644 --- a/src/opr/include/megbrain/opr/dnn/convolution.h +++ b/src/opr/include/megbrain/opr/dnn/convolution.h @@ -25,9 +25,9 @@ namespace mixin { class Convolution { public: using ExecutionPolicy = megdnn::param::ExecutionPolicy; - using Algorithm = megdnn::detail::Algorithm; + using AlgorithmInfo = megdnn::detail::Algorithm::Info; using AlgoChooserHook = - std::function; + std::function; const ExecutionPolicy& execution_policy() const { if (!m_policy_accessed) { @@ -618,9 +618,9 @@ private: const override final; }; -MGB_DEFINE_OPR_CLASS(BatchConvBiasForward, intl::BatchConvBiasForwardBase, +MGB_DEFINE_OPR_CLASS(BatchConvBiasForward, intl::BatchConvBiasForwardBase, public mixin::Convolution) // { - + void init_output_dtype() override; size_t get_workspace_size_bytes( const TensorShapeArray& input_shapes, diff --git a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h index 332664ab..7219fe5c 100644 --- a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h +++ b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h @@ -46,7 +46,7 @@ class AlgoChooser { static constexpr int arity_out = OprArityTrait::arity_out; static constexpr int arity = OprArityTrait::arity; - using ImplAlgo = typename Opr::Algorithm*; + using ImplAlgo = typename Opr::AlgorithmInfo; using MGBOpr = typename MegDNNOpr2MGBOpr::MGBOpr; using ConvTensorLayouts = std::array; diff --git a/src/opr/test/dnn/convolution.cpp b/src/opr/test/dnn/convolution.cpp index 500efd92..461dd978 100644 --- a/src/opr/test/dnn/convolution.cpp +++ b/src/opr/test/dnn/convolution.cpp @@ -20,6 +20,7 @@ #include "megbrain/opr/basic_arith.h" #include "megbrain/gopt/inference.h" #include "megbrain/opr/tensor_manip.h" +#include "megdnn/oprs/base.h" #include @@ -2008,11 +2009,11 @@ TEST(TestOprDNN, HeuristicReproducible) { bwd_flt->owner_opr()) ->megdnn_opr()) 
->execution_policy() - .algorithm; + .algo; if (strategy == S::HEURISTIC_REPRODUCIBLE) { - EXPECT_TRUE(algo->is_reproducible()); + EXPECT_TRUE(algo.is_reproducible); } - algo_name0 = algo->name(); + algo_name0 = algo.name.c_str(); } { Checker checker(make_graph, fwd); @@ -2024,8 +2025,8 @@ TEST(TestOprDNN, HeuristicReproducible) { bwd_flt->owner_opr()) ->megdnn_opr()) ->execution_policy() - .algorithm; - algo_name1 = algo->name(); + .algo; + algo_name1 = algo.name.c_str(); } EXPECT_TRUE(algo_name0 == algo_name1); } @@ -2183,6 +2184,17 @@ public: MOCK_METHOD3(get_preprocess_workspace_in_bytes, size_t(const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst)); + + MOCK_METHOD3(get_all_algorithms_info, + std::vector(const TensorLayout& p0, + const TensorLayout& p1, + const TensorLayout& p2)); + MOCK_METHOD5(get_algorithm_info_heuristic, + AlgorithmInfo(const TensorLayout& p0, const TensorLayout& p1, + const TensorLayout& p2, + size_t workspace_limit_in_bytes, + bool reproducible)); + MOCK_METHOD3(get_all_algorithms, std::vector(const TensorLayout& p0, const TensorLayout& p1, @@ -2192,6 +2204,7 @@ public: const TensorLayout& p2, size_t workspace_limit_in_bytes, bool reproducible)); +protected: const char* get_algorithm_set_name() const override { return m_algorithm_set_name; } @@ -2204,6 +2217,9 @@ public: MockAlgorithm(const char* name = "NotImportant") : m_name(name) {} bool is_reproducible() const override { return true; } const char* name() const override { return m_name; } + uint32_t type() const override { + return megdnn::detail::Algorithm::INVALID_ALGO_TYPE; + } virtual ~MockAlgorithm() = default; };