GitOrigin-RevId: 5cc95472b9
release-1.2
@@ -46,7 +46,7 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle): | |||||
cuda_check(cudaGetDevice(&dev_id)); | cuda_check(cudaGetDevice(&dev_id)); | ||||
} | } | ||||
m_device_id = dev_id; | m_device_id = dev_id; | ||||
cuda_check(cudaGetDeviceProperties(&m_device_prop, dev_id)); | |||||
m_device_prop = get_device_prop(dev_id); | |||||
// Get stream from MegCore computing handle. | // Get stream from MegCore computing handle. | ||||
megdnn_assert(CUDNN_VERSION == cudnnGetVersion(), | megdnn_assert(CUDNN_VERSION == cudnnGetVersion(), | ||||
"cudnn version mismatch: compiled with %d; detected %zu at runtime", | "cudnn version mismatch: compiled with %d; detected %zu at runtime", | ||||
@@ -80,7 +80,7 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle): | |||||
cuda_check(cudaStreamSynchronize(stream())); | cuda_check(cudaStreamSynchronize(stream())); | ||||
// check tk1 | // check tk1 | ||||
m_is_tegra_k1 = (strcmp(m_device_prop.name, "GK20A") == 0); | |||||
m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0); | |||||
m_cusolver_handle = nullptr; | m_cusolver_handle = nullptr; | ||||
} | } | ||||
@@ -104,7 +104,7 @@ void HandleImpl::ConstScalars::init() { | |||||
size_t HandleImpl::alignment_requirement() const { | size_t HandleImpl::alignment_requirement() const { | ||||
auto &&prop = m_device_prop; | auto &&prop = m_device_prop; | ||||
return std::max(prop.textureAlignment, prop.texturePitchAlignment); | |||||
return std::max(prop->textureAlignment, prop->texturePitchAlignment); | |||||
} | } | ||||
bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) { | bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) { | ||||
@@ -42,7 +42,7 @@ class HandleImpl: public HandleImplHelper { | |||||
bool check_cross_dev_copy_constraint(const TensorLayout &src) override; | bool check_cross_dev_copy_constraint(const TensorLayout &src) override; | ||||
const cudaDeviceProp& device_prop() const { | const cudaDeviceProp& device_prop() const { | ||||
return m_device_prop; | |||||
return *m_device_prop; | |||||
} | } | ||||
template <typename Opr> | template <typename Opr> | ||||
@@ -137,7 +137,7 @@ class HandleImpl: public HandleImplHelper { | |||||
cusolverDnHandle_t m_cusolver_handle; | cusolverDnHandle_t m_cusolver_handle; | ||||
std::once_flag m_cusolver_initialized; | std::once_flag m_cusolver_initialized; | ||||
cudaDeviceProp m_device_prop; | |||||
const cudaDeviceProp* m_device_prop; | |||||
struct ConstScalars { | struct ConstScalars { | ||||
union FP16 { | union FP16 { | ||||
@@ -107,19 +107,26 @@ uint32_t cuda::safe_size_in_kern(size_t size) { | |||||
return size; | return size; | ||||
} | } | ||||
cudaDeviceProp cuda::current_device_prop() { | |||||
const cudaDeviceProp& cuda::current_device_prop() { | |||||
int dev; | int dev; | ||||
cuda_check(cudaGetDevice(&dev)); | cuda_check(cudaGetDevice(&dev)); | ||||
megdnn_assert(dev < MAX_NR_DEVICE, "device number too large: %d", dev); | |||||
auto&& rec = device_prop_rec[dev]; | |||||
return *(cuda::get_device_prop(dev)); | |||||
} | |||||
const cudaDeviceProp* cuda::get_device_prop(int device) { | |||||
megdnn_assert(device < MAX_NR_DEVICE, "device number too large: %d", | |||||
device); | |||||
megdnn_assert(device >= 0, "device number must not be negative, got %d", | |||||
device); | |||||
auto&& rec = device_prop_rec[device]; | |||||
if (!rec.init) { | if (!rec.init) { | ||||
std::lock_guard<std::mutex> lock(rec.mtx); | std::lock_guard<std::mutex> lock(rec.mtx); | ||||
if (!rec.init) { | if (!rec.init) { | ||||
cuda_check(cudaGetDeviceProperties(&rec.prop, dev)); | |||||
cuda_check(cudaGetDeviceProperties(&rec.prop, device)); | |||||
rec.init = true; | rec.init = true; | ||||
} | } | ||||
} | } | ||||
return rec.prop; | |||||
return &(rec.prop); | |||||
} | } | ||||
bool cuda::is_compute_capability_required(int major, int minor) { | bool cuda::is_compute_capability_required(int major, int minor) { | ||||
@@ -52,7 +52,10 @@ static inline void CUDART_CB callback_free(cudaStream_t /* stream */, | |||||
} | } | ||||
//! get property of currently active device | //! get property of currently active device | ||||
cudaDeviceProp current_device_prop(); | |||||
const cudaDeviceProp& current_device_prop(); | |||||
//! get property of device specified by device | |||||
const cudaDeviceProp* get_device_prop(int device); | |||||
//! check compute capability satisfied with given sm version | //! check compute capability satisfied with given sm version | ||||
bool is_compute_capability_required(int major, int minor); | bool is_compute_capability_required(int major, int minor); | ||||