You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fixture.cpp 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. #include "test/rocm/fixture.h"
  2. #include "hcc_detail/hcc_defs_prologue.h"
  3. #include "src/rocm/handle.h"
  4. #include "src/rocm/utils.h"
  5. #include "test/common/memory_manager.h"
  6. #include "test/common/random_state.h"
  7. #include "test/common/utils.h"
  8. #include <cstdlib>
  9. #include "hip_header.h"
  10. using namespace megdnn;
  11. using namespace test;
  12. namespace {
  13. void setup_device() {
  14. #if !defined(WIN32)
  15. auto device_id_env = std::getenv("MEGDNN_DEVICE_ID");
  16. int device_id = -1;
  17. if (device_id_env) {
  18. device_id = std::atoi(device_id_env);
  19. std::cout << "Select device " << device_id
  20. << " because MEGDNN_DEVICE_ID is set." << std::endl;
  21. }
  22. auto pci_bus_id_env = std::getenv("MEGDNN_PCI_BUS_ID");
  23. if (pci_bus_id_env) {
  24. megdnn_assert(hipSuccess == hipDeviceGetByPCIBusId(&device_id, pci_bus_id_env));
  25. std::cout << "Select device " << pci_bus_id_env << " (" << device_id
  26. << ") because MEGDNN_PCI_BUS_ID is set." << std::endl;
  27. }
  28. if (device_id_env && pci_bus_id_env) {
  29. std::cout << "MEGDNN_DEVICE_ID and MEGDNN_PCI_BUS_ID should not "
  30. "be set simultaneously."
  31. << std::endl;
  32. exit(1);
  33. }
  34. if (device_id_env || pci_bus_id_env) {
  35. megdnn_assert(hipSuccess == hipSetDevice(device_id));
  36. }
  37. #endif
  38. }
  39. } // anonymous namespace
  40. void ROCM::SetUp() {
  41. RandomState::reset();
  42. setup_device();
  43. megcoreDeviceHandle_t dev_handle;
  44. megcore_check(megcoreCreateDeviceHandle(&dev_handle, megcorePlatformROCM));
  45. megcoreComputingHandle_t comp_handle;
  46. megcore_check(megcoreCreateComputingHandle(&comp_handle, dev_handle));
  47. m_handle_rocm = Handle::make(comp_handle);
  48. megdnn_assert(m_handle_rocm);
  49. }
  50. Handle* ROCM::handle_naive(bool check_dispatch) {
  51. if (!m_handle_naive)
  52. m_handle_naive = create_cpu_handle(2, check_dispatch);
  53. return m_handle_naive.get();
  54. }
  55. void ROCM::TearDown() {
  56. m_handle_naive.reset();
  57. m_handle_rocm.reset();
  58. MemoryManagerHolder::instance()->clear();
  59. }
  60. void ROCM_ERROR_INFO::SetUp() {
  61. setup_device();
  62. megcoreDeviceHandle_t dev_handle;
  63. megcore_check(megcoreCreateDeviceHandle(&dev_handle, megcorePlatformROCM));
  64. m_error_info_dev = nullptr;
  65. void* ptr;
  66. hip_check(hipMalloc(&ptr, sizeof(megcore::AsyncErrorInfo)));
  67. hip_check(hipMemset(ptr, 0, sizeof(megcore::AsyncErrorInfo)));
  68. hip_check(hipDeviceSynchronize());
  69. m_error_info_dev = static_cast<megcore::AsyncErrorInfo*>(ptr);
  70. // create handle bind with error_info
  71. megcoreComputingHandle_t comp_handle;
  72. megcore_check(megcore::createComputingHandleWithROCMContext(
  73. &comp_handle, dev_handle, 0, {nullptr, m_error_info_dev}));
  74. m_handle_rocm = Handle::make(comp_handle);
  75. megdnn_assert(static_cast<bool>(m_handle_rocm));
  76. }
  77. void ROCM_ERROR_INFO::TearDown() {
  78. if (m_error_info_dev) {
  79. hip_check(hipFree(m_error_info_dev));
  80. }
  81. m_handle_rocm.reset();
  82. MemoryManagerHolder::instance()->clear();
  83. }
  84. megcore::AsyncErrorInfo ROCM_ERROR_INFO::get_error_info() {
  85. megcore::AsyncErrorInfo ret;
  86. auto stream = rocm::hip_stream(m_handle_rocm.get());
  87. hip_check(hipMemcpyAsync(
  88. &ret, m_error_info_dev, sizeof(ret), hipMemcpyDeviceToHost, stream));
  89. hip_check(hipStreamSynchronize(stream));
  90. return ret;
  91. }
  92. // vim: syntax=cpp.doxygen