You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fixture.cpp 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. /**
  2. * \file dnn/test/cuda/fixture.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
#include "test/cuda/fixture.h"
#include "test/common/utils.h"
#include "test/common/memory_manager.h"
#include "src/cuda/utils.h"
#include "src/cuda/handle.h"
#include "test/common/random_state.h"
#include <cerrno>
#include <climits>
#include <cstdlib>
#include <iostream>
#include <cuda_runtime_api.h>
using namespace megdnn;
using namespace test;
  21. namespace {
  22. void setup_device() {
  23. #if !defined(_WIN32)
  24. auto device_id_env = std::getenv("MEGDNN_DEVICE_ID");
  25. int device_id = -1;
  26. if (device_id_env) {
  27. device_id = std::atoi(device_id_env);
  28. std::cout << "Select device " << device_id
  29. << " because MEGDNN_DEVICE_ID is set." << std::endl;
  30. }
  31. auto pci_bus_id_env = std::getenv("MEGDNN_PCI_BUS_ID");
  32. if (pci_bus_id_env) {
  33. megdnn_assert(cudaSuccess == cudaDeviceGetByPCIBusId(&device_id,
  34. pci_bus_id_env));
  35. std::cout << "Select device " << pci_bus_id_env << " ("
  36. << device_id << ") because MEGDNN_PCI_BUS_ID is set."
  37. << std::endl;
  38. }
  39. if (device_id_env && pci_bus_id_env) {
  40. std::cout << "MEGDNN_DEVICE_ID and MEGDNN_PCI_BUS_ID should not "
  41. "be set simultaneously." << std::endl;
  42. exit(1);
  43. }
  44. if (device_id_env || pci_bus_id_env) {
  45. megdnn_assert(cudaSuccess == cudaSetDevice(device_id));
  46. }
  47. #endif
  48. }
  49. } // anonymous namespace
  50. void CUDA::SetUp() {
  51. RandomState::reset();
  52. setup_device();
  53. megcoreDeviceHandle_t dev_handle;
  54. megcore_check(megcoreCreateDeviceHandle(&dev_handle,
  55. megcorePlatformCUDA));
  56. megcoreComputingHandle_t comp_handle;
  57. megcore_check(megcoreCreateComputingHandle(&comp_handle,
  58. dev_handle));
  59. m_handle_cuda = Handle::make(comp_handle);
  60. megdnn_assert(m_handle_cuda);
  61. }
  62. Handle* CUDA::handle_naive() {
  63. if (!m_handle_naive)
  64. m_handle_naive = create_cpu_handle(2);
  65. return m_handle_naive.get();
  66. }
  67. void CUDA::TearDown() {
  68. m_handle_naive.reset();
  69. m_handle_cuda.reset();
  70. MemoryManagerHolder::instance()->clear();
  71. }
  72. void CUDA_ERROR_INFO::SetUp() {
  73. setup_device();
  74. megcoreDeviceHandle_t dev_handle;
  75. megcore_check(megcoreCreateDeviceHandle(&dev_handle, megcorePlatformCUDA));
  76. m_error_info_dev = nullptr;
  77. void* ptr;
  78. cuda_check(cudaMalloc(&ptr, sizeof(megcore::AsyncErrorInfo)));
  79. cuda_check(cudaMemset(ptr, 0, sizeof(megcore::AsyncErrorInfo)));
  80. cuda_check(cudaDeviceSynchronize());
  81. m_error_info_dev = static_cast<megcore::AsyncErrorInfo*>(ptr);
  82. // create handle bind with error_info
  83. megcoreComputingHandle_t comp_handle;
  84. megcore_check(megcore::createComputingHandleWithCUDAContext(
  85. &comp_handle, dev_handle, 0, {nullptr, m_error_info_dev}));
  86. m_handle_cuda = Handle::make(comp_handle);
  87. megdnn_assert(static_cast<bool>(m_handle_cuda));
  88. }
  89. void CUDA_ERROR_INFO::TearDown() {
  90. if (m_error_info_dev) {
  91. cuda_check(cudaFree(m_error_info_dev));
  92. }
  93. m_handle_cuda.reset();
  94. MemoryManagerHolder::instance()->clear();
  95. }
  96. megcore::AsyncErrorInfo CUDA_ERROR_INFO::get_error_info() {
  97. megcore::AsyncErrorInfo ret;
  98. auto stream = cuda::cuda_stream(m_handle_cuda.get());
  99. cuda_check(cudaMemcpyAsync(&ret, m_error_info_dev, sizeof(ret),
  100. cudaMemcpyDeviceToHost, stream));
  101. cuda_check(cudaStreamSynchronize(stream));
  102. return ret;
  103. }
  104. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。