From f214e14695cdaaed428cbf60ff8c6d31df30ad98 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 7 Dec 2020 17:30:05 +0800 Subject: [PATCH] refactor(mgb/cuda): use single implementation of get_device_prop from utils GitOrigin-RevId: 5cc95472b9f27339380f74a4a2828368af56c038 --- dnn/src/cuda/handle.cpp | 6 +++--- dnn/src/cuda/handle.h | 4 ++-- dnn/src/cuda/utils.cpp | 17 ++++++++++++----- dnn/src/cuda/utils.h | 5 ++++- 4 files changed, 21 insertions(+), 11 deletions(-) diff --git a/dnn/src/cuda/handle.cpp b/dnn/src/cuda/handle.cpp index bc909c95..b52e0015 100644 --- a/dnn/src/cuda/handle.cpp +++ b/dnn/src/cuda/handle.cpp @@ -46,7 +46,7 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle): cuda_check(cudaGetDevice(&dev_id)); } m_device_id = dev_id; - cuda_check(cudaGetDeviceProperties(&m_device_prop, dev_id)); + m_device_prop = get_device_prop(dev_id); // Get stream from MegCore computing handle. megdnn_assert(CUDNN_VERSION == cudnnGetVersion(), "cudnn version mismatch: compiled with %d; detected %zu at runtime", @@ -80,7 +80,7 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle): cuda_check(cudaStreamSynchronize(stream())); // check tk1 - m_is_tegra_k1 = (strcmp(m_device_prop.name, "GK20A") == 0); + m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0); m_cusolver_handle = nullptr; } @@ -104,7 +104,7 @@ void HandleImpl::ConstScalars::init() { size_t HandleImpl::alignment_requirement() const { auto &&prop = m_device_prop; - return std::max(prop.textureAlignment, prop.texturePitchAlignment); + return std::max(prop->textureAlignment, prop->texturePitchAlignment); } bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) { diff --git a/dnn/src/cuda/handle.h b/dnn/src/cuda/handle.h index 9aa6fdfb..c4ac6ac2 100644 --- a/dnn/src/cuda/handle.h +++ b/dnn/src/cuda/handle.h @@ -42,7 +42,7 @@ class HandleImpl: public HandleImplHelper { bool check_cross_dev_copy_constraint(const TensorLayout &src) override; const cudaDeviceProp& device_prop() const { - return m_device_prop; + return *m_device_prop; } template @@ -137,7 +137,7 @@ class HandleImpl: public HandleImplHelper { cusolverDnHandle_t m_cusolver_handle; std::once_flag m_cusolver_initialized; - cudaDeviceProp m_device_prop; + const cudaDeviceProp* m_device_prop; struct ConstScalars { union FP16 { diff --git a/dnn/src/cuda/utils.cpp b/dnn/src/cuda/utils.cpp index 52335d61..0060a7cd 100644 --- a/dnn/src/cuda/utils.cpp +++ b/dnn/src/cuda/utils.cpp @@ -107,19 +107,26 @@ uint32_t cuda::safe_size_in_kern(size_t size) { return size; } -cudaDeviceProp cuda::current_device_prop() { +const cudaDeviceProp& cuda::current_device_prop() { int dev; cuda_check(cudaGetDevice(&dev)); - megdnn_assert(dev < MAX_NR_DEVICE, "device number too large: %d", dev); - auto&& rec = device_prop_rec[dev]; + return *(cuda::get_device_prop(dev)); +} + +const cudaDeviceProp* cuda::get_device_prop(int device) { + megdnn_assert(device < MAX_NR_DEVICE, "device number too large: %d", + device); + megdnn_assert(device >= 0, "device number must not be negative, got %d", + device); + auto&& rec = device_prop_rec[device]; if (!rec.init) { std::lock_guard lock(rec.mtx); if (!rec.init) { - cuda_check(cudaGetDeviceProperties(&rec.prop, dev)); + cuda_check(cudaGetDeviceProperties(&rec.prop, device)); rec.init = true; } } - return rec.prop; + return &(rec.prop); } bool cuda::is_compute_capability_required(int major, int minor) { diff --git a/dnn/src/cuda/utils.h b/dnn/src/cuda/utils.h index 05fc0284..c5bb2698 100644 --- a/dnn/src/cuda/utils.h +++ b/dnn/src/cuda/utils.h @@ -52,7 +52,10 @@ static inline void CUDART_CB callback_free(cudaStream_t /* stream */, } //! get property of currently active device -cudaDeviceProp current_device_prop(); +const cudaDeviceProp& current_device_prop(); + +//! get property of device specified by device +const cudaDeviceProp* get_device_prop(int device); //! check compute capability satisfied with given sm version bool is_compute_capability_required(int major, int minor);