From ef9aa8007443f2af129ca72c986268b4ceb321a4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 15 Apr 2021 15:44:58 +0800 Subject: [PATCH] fix(mgb/dnn): fix cuda naive matmul algo GitOrigin-RevId: 79c9bba73b46274d4db59ed5b57c1ba0b1dacf45 --- dnn/src/cuda/matrix_mul/algos.cpp | 2 +- dnn/src/cuda/matrix_mul/algos.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dnn/src/cuda/matrix_mul/algos.cpp b/dnn/src/cuda/matrix_mul/algos.cpp index 57d8941b..c9c01692 100644 --- a/dnn/src/cuda/matrix_mul/algos.cpp +++ b/dnn/src/cuda/matrix_mul/algos.cpp @@ -29,7 +29,6 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { #if CUDA_VERSION >= 10010 all_algos.push_back(&cublas_lt); #endif - all_algos.push_back(&naive); #if !MEGDNN_DISABLE_FLOAT16 all_algos.push_back(&bfloat16); #endif @@ -45,6 +44,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { all_algos.push_back(&algo); } #endif + all_algos.push_back(&naive); for (auto&& algo : all_algos) { m_all_algos_map.emplace(algo->info().desc, algo); diff --git a/dnn/src/cuda/matrix_mul/algos.h b/dnn/src/cuda/matrix_mul/algos.h index b783cf8d..5bbb9245 100644 --- a/dnn/src/cuda/matrix_mul/algos.h +++ b/dnn/src/cuda/matrix_mul/algos.h @@ -157,7 +157,7 @@ public: void exec(const ExecArgs& args) const override; MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE) AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; } };