
handle.cpp 5.2 kB

#include "src/common/handle_impl.h"
#include "src/common/version_symbol.h"

#include "megdnn/common.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"

#include <cuda.h>
#include <cstring>
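// Two-level stringization: STR(x) first expands its macro argument (e.g.
// CUDNN_MAJOR -> 8) and only then stringizes the result via STR_HELPER; a
// single-level "#x" would produce the unexpanded token "CUDNN_MAJOR" instead.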
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)

#define CUDNN_VERSION_STR \
    STR(CUDNN_MAJOR) "." STR(CUDNN_MINOR) "." STR(CUDNN_PATCHLEVEL)

#pragma message "compiled with cuDNN " CUDNN_VERSION_STR

static_assert(
        !(CUDNN_MAJOR == 5 && CUDNN_MINOR == 1),
        "cuDNN 5.1.x series has bugs. Use 5.0.x instead.");

#undef STR
#undef STR_HELPER
namespace megdnn {
namespace cuda {

HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle)
        : HandleImplHelper(comp_handle, HandleType::CUDA) {
    // Get megcore device handle
    megcoreDeviceHandle_t dev_handle;
    megcoreGetDeviceHandle(comp_handle, &dev_handle);
    int dev_id;
    megcoreGetDeviceID(dev_handle, &dev_id);
    if (dev_id < 0) {
        cuda_check(cudaGetDevice(&dev_id));
    }
    m_device_id = dev_id;
    m_device_prop = get_device_prop(dev_id);
    // Get stream from MegCore computing handle.
    megdnn_assert(
            CUDNN_VERSION == cudnnGetVersion(),
            "cudnn version mismatch: compiled with %d but detected %zu at runtime; "
            "this may be caused by a customized environment, e.g. LD_LIBRARY_PATH "
            "on Linux or PATH on Windows!",
            CUDNN_VERSION, cudnnGetVersion());
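    // A note on the comparison above: through the cuDNN 8.x releases,
    // CUDNN_VERSION is encoded as CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 +
    // CUDNN_PATCHLEVEL (e.g. 8004 for cuDNN 8.0.4), so a plain integer
    // equality check is enough to detect a mismatched runtime library.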
#if CUDA_VERSION >= 10010
    megdnn_assert(
            cublasLtGetVersion() >= 10010,
            "cuda library version is too low to run cublasLt");
#endif
#if CUDNN_VERSION >= 8000
    if (!MGB_GETENV("CUDA_CACHE_PATH")) {
        megdnn_log_warn(R"(
cuDNN 8 JIT-compiles PTX code and caches the result. You can set the
CUDA_CACHE_MAXSIZE and CUDA_CACHE_PATH environment variables to avoid
repeated JIT compilation (which is very slow).
For example: `export CUDA_CACHE_MAXSIZE=2147483647` and `export CUDA_CACHE_PATH=/data/.cuda_cache`)");
    }
#endif
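    // Debug instrumentation: snapshot free/total device memory here and again
    // after initialization below, so the memory allocated by
    // cudnnCreate/cublasCreate and friends can be measured.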
    size_t free, tot;
    cuda_check(cudaMemGetInfo(&free, &tot));
    printf("before cudnn create, free: %.2f MB, tot: %.2f MB, allocated: %.2f MB\n",
           free / 1024.0 / 1024.0, tot / 1024.0 / 1024.0,
           (tot - free) / 1024.0 / 1024.0);
    cudnn_check(cudnnCreate(&m_cudnn_handle));
    cublas_check(cublasCreate(&m_cublas_handle));
#if CUDA_VERSION >= 10010
    cublas_check(cublasLtCreate(&m_cublasLt_handle));
#endif
    megcore::getCUDAContext(comp_handle, &m_megcore_context);

    // Set stream for cuDNN and cublas handles.
    cudnn_check(cudnnSetStream(m_cudnn_handle, stream()));
    cublas_check(cublasSetStream(m_cublas_handle, stream()));

#if CUDNN_VERSION >= 8004
    // cudnn_check(cudnnOpsInferVersionCheck());
    // cudnn_check(cudnnCnnInferVersionCheck());
#endif

    // Note that all cublas scalars (alpha, beta) and scalar results such as
    // dot output reside on the device side.
    cublas_check(cublasSetPointerMode(m_cublas_handle, CUBLAS_POINTER_MODE_DEVICE));
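    // With CUBLAS_POINTER_MODE_DEVICE, the alpha/beta arguments of every
    // subsequent cuBLAS call must point to device memory; the ConstScalars
    // block allocated below provides device-resident 0 and 1 constants for
    // exactly that purpose.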
    // init const scalars
    cuda_check(cudaMalloc(&m_const_scalars, sizeof(ConstScalars)));
    ConstScalars const_scalars_val;
    const_scalars_val.init();
    cuda_check(cudaMemcpyAsync(
            m_const_scalars, &const_scalars_val, sizeof(ConstScalars),
            cudaMemcpyHostToDevice, stream()));
    cuda_check(cudaStreamSynchronize(stream()));

    cuda_check(cudaMemGetInfo(&free, &tot));
    printf("after cudnn create, free: %.2f MB, tot: %.2f MB, allocated: %.2f MB\n",
           free / 1024.0 / 1024.0, tot / 1024.0 / 1024.0,
           (tot - free) / 1024.0 / 1024.0);

    // check whether we are running on a Tegra K1 device
    m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0);
    m_cusolver_handle = nullptr;
}
HandleImpl::~HandleImpl() noexcept {
    cudnn_check(cudnnDestroy(m_cudnn_handle));
    cublas_check(cublasDestroy(m_cublas_handle));
#if CUDA_VERSION >= 10010
    cublas_check(cublasLtDestroy(m_cublasLt_handle));
#endif
    if (m_cusolver_handle) {
        cusolver_check(cusolverDnDestroy(m_cusolver_handle));
    }
    cuda_check(cudaFree(m_const_scalars));
}
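// Usage sketch (hypothetical, not a call site from this file): in device
// pointer mode, a GEMM would take these constants as its scalar arguments,
// with `scalars` pointing at the device-side ConstScalars block, e.g.
//   cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
//               &scalars->f32[1], A, lda, B, ldb,   // alpha = 1 (device ptr)
//               &scalars->f32[0], C, ldc);          // beta  = 0 (device ptr)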
void HandleImpl::ConstScalars::init() {
    f16[0].megdnn_x = 0;
    f16[1].megdnn_x = 1;
    f32[0] = 0;
    f32[1] = 1;
    i32[0] = 0;
    i32[1] = 1;
}
size_t HandleImpl::alignment_requirement() const {
    auto&& prop = m_device_prop;
    return std::max(prop->textureAlignment, prop->texturePitchAlignment);
}
bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) {
    // src is contiguous, or can be handled by
    // relayout::param::try_copy_2d/try_copy_last_contig
    return src.is_contiguous() || src.stride[src.ndim - 1] == 1;
}
void HandleImpl::initialize_cusolver() {
    cusolver_check(cusolverDnCreate(&m_cusolver_handle));
    cusolver_check(cusolverDnSetStream(m_cusolver_handle, stream()));
}
size_t HandleImpl::image2d_pitch_alignment() const {
    size_t align = device_prop().texturePitchAlignment;
    return align;
}

HandleImpl::HandleVendorType HandleImpl::vendor_type() const {
    return HandleVendorType::CUDA;
}
} // namespace cuda
} // namespace megdnn

MEGDNN_VERSION_SYMBOL(CUDA, CUDA_VERSION);
MEGDNN_VERSION_SYMBOL3(CUDNN, CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);

// vim: syntax=cpp.doxygen