Browse Source

refactor(mgb/cuda): use single implementation of get_device_prop from utils

GitOrigin-RevId: 5cc95472b9
release-1.2
Megvii Engine Team 4 years ago
parent
commit
f214e14695
4 changed files with 21 additions and 11 deletions
  1. +3
    -3
      dnn/src/cuda/handle.cpp
  2. +2
    -2
      dnn/src/cuda/handle.h
  3. +12
    -5
      dnn/src/cuda/utils.cpp
  4. +4
    -1
      dnn/src/cuda/utils.h

+ 3
- 3
dnn/src/cuda/handle.cpp View File

@@ -46,7 +46,7 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle):
cuda_check(cudaGetDevice(&dev_id)); cuda_check(cudaGetDevice(&dev_id));
} }
m_device_id = dev_id; m_device_id = dev_id;
cuda_check(cudaGetDeviceProperties(&m_device_prop, dev_id));
m_device_prop = get_device_prop(dev_id);
// Get stream from MegCore computing handle. // Get stream from MegCore computing handle.
megdnn_assert(CUDNN_VERSION == cudnnGetVersion(), megdnn_assert(CUDNN_VERSION == cudnnGetVersion(),
"cudnn version mismatch: compiled with %d; detected %zu at runtime", "cudnn version mismatch: compiled with %d; detected %zu at runtime",
@@ -80,7 +80,7 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle):
cuda_check(cudaStreamSynchronize(stream())); cuda_check(cudaStreamSynchronize(stream()));


// check tk1 // check tk1
m_is_tegra_k1 = (strcmp(m_device_prop.name, "GK20A") == 0);
m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0);
m_cusolver_handle = nullptr; m_cusolver_handle = nullptr;
} }


@@ -104,7 +104,7 @@ void HandleImpl::ConstScalars::init() {


size_t HandleImpl::alignment_requirement() const { size_t HandleImpl::alignment_requirement() const {
auto &&prop = m_device_prop; auto &&prop = m_device_prop;
return std::max(prop.textureAlignment, prop.texturePitchAlignment);
return std::max(prop->textureAlignment, prop->texturePitchAlignment);
} }


bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) { bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) {


+ 2
- 2
dnn/src/cuda/handle.h View File

@@ -42,7 +42,7 @@ class HandleImpl: public HandleImplHelper {
bool check_cross_dev_copy_constraint(const TensorLayout &src) override; bool check_cross_dev_copy_constraint(const TensorLayout &src) override;


const cudaDeviceProp& device_prop() const { const cudaDeviceProp& device_prop() const {
return m_device_prop;
return *m_device_prop;
} }


template <typename Opr> template <typename Opr>
@@ -137,7 +137,7 @@ class HandleImpl: public HandleImplHelper {
cusolverDnHandle_t m_cusolver_handle; cusolverDnHandle_t m_cusolver_handle;
std::once_flag m_cusolver_initialized; std::once_flag m_cusolver_initialized;


cudaDeviceProp m_device_prop;
const cudaDeviceProp* m_device_prop;


struct ConstScalars { struct ConstScalars {
union FP16 { union FP16 {


+ 12
- 5
dnn/src/cuda/utils.cpp View File

@@ -107,19 +107,26 @@ uint32_t cuda::safe_size_in_kern(size_t size) {
return size; return size;
} }


cudaDeviceProp cuda::current_device_prop() {
const cudaDeviceProp& cuda::current_device_prop() {
int dev; int dev;
cuda_check(cudaGetDevice(&dev)); cuda_check(cudaGetDevice(&dev));
megdnn_assert(dev < MAX_NR_DEVICE, "device number too large: %d", dev);
auto&& rec = device_prop_rec[dev];
return *(cuda::get_device_prop(dev));
}

const cudaDeviceProp* cuda::get_device_prop(int device) {
megdnn_assert(device < MAX_NR_DEVICE, "device number too large: %d",
device);
megdnn_assert(device >= 0, "device number must not be negative, got %d",
device);
auto&& rec = device_prop_rec[device];
if (!rec.init) { if (!rec.init) {
std::lock_guard<std::mutex> lock(rec.mtx); std::lock_guard<std::mutex> lock(rec.mtx);
if (!rec.init) { if (!rec.init) {
cuda_check(cudaGetDeviceProperties(&rec.prop, dev));
cuda_check(cudaGetDeviceProperties(&rec.prop, device));
rec.init = true; rec.init = true;
} }
} }
return rec.prop;
return &(rec.prop);
} }


bool cuda::is_compute_capability_required(int major, int minor) { bool cuda::is_compute_capability_required(int major, int minor) {


+ 4
- 1
dnn/src/cuda/utils.h View File

@@ -52,7 +52,10 @@ static inline void CUDART_CB callback_free(cudaStream_t /* stream */,
} }


//! get property of currently active device //! get property of currently active device
cudaDeviceProp current_device_prop();
const cudaDeviceProp& current_device_prop();

//! get property of device specified by device
const cudaDeviceProp* get_device_prop(int device);


//! check compute capability satisfied with given sm version //! check compute capability satisfied with given sm version
bool is_compute_capability_required(int major, int minor); bool is_compute_capability_required(int major, int minor);


Loading…
Cancel
Save