You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

fixture.cpp 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. /**
  2. * \file dnn/test/rocm/fixture.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
#include "test/rocm/fixture.h"
#include "hcc_detail/hcc_defs_prologue.h"
#include "src/rocm/handle.h"
#include "src/rocm/utils.h"
#include "test/common/memory_manager.h"
#include "test/common/random_state.h"
#include "test/common/utils.h"
#include <climits>
#include <cstdlib>
#include <iostream>
#include "hip_header.h"
  20. using namespace megdnn;
  21. using namespace test;
  22. namespace {
  23. void setup_device() {
  24. #if !defined(WIN32)
  25. auto device_id_env = std::getenv("MEGDNN_DEVICE_ID");
  26. int device_id = -1;
  27. if (device_id_env) {
  28. device_id = std::atoi(device_id_env);
  29. std::cout << "Select device " << device_id
  30. << " because MEGDNN_DEVICE_ID is set." << std::endl;
  31. }
  32. auto pci_bus_id_env = std::getenv("MEGDNN_PCI_BUS_ID");
  33. if (pci_bus_id_env) {
  34. megdnn_assert(hipSuccess == hipDeviceGetByPCIBusId(&device_id, pci_bus_id_env));
  35. std::cout << "Select device " << pci_bus_id_env << " (" << device_id
  36. << ") because MEGDNN_PCI_BUS_ID is set." << std::endl;
  37. }
  38. if (device_id_env && pci_bus_id_env) {
  39. std::cout << "MEGDNN_DEVICE_ID and MEGDNN_PCI_BUS_ID should not "
  40. "be set simultaneously."
  41. << std::endl;
  42. exit(1);
  43. }
  44. if (device_id_env || pci_bus_id_env) {
  45. megdnn_assert(hipSuccess == hipSetDevice(device_id));
  46. }
  47. #endif
  48. }
  49. } // anonymous namespace
  50. void ROCM::SetUp() {
  51. RandomState::reset();
  52. setup_device();
  53. megcoreDeviceHandle_t dev_handle;
  54. megcore_check(megcoreCreateDeviceHandle(&dev_handle, megcorePlatformROCM));
  55. megcoreComputingHandle_t comp_handle;
  56. megcore_check(megcoreCreateComputingHandle(&comp_handle, dev_handle));
  57. m_handle_rocm = Handle::make(comp_handle);
  58. megdnn_assert(m_handle_rocm);
  59. }
  60. Handle* ROCM::handle_naive(bool check_dispatch) {
  61. if (!m_handle_naive)
  62. m_handle_naive = create_cpu_handle(2, check_dispatch);
  63. return m_handle_naive.get();
  64. }
  65. void ROCM::TearDown() {
  66. m_handle_naive.reset();
  67. m_handle_rocm.reset();
  68. MemoryManagerHolder::instance()->clear();
  69. }
  70. void ROCM_ERROR_INFO::SetUp() {
  71. setup_device();
  72. megcoreDeviceHandle_t dev_handle;
  73. megcore_check(megcoreCreateDeviceHandle(&dev_handle, megcorePlatformROCM));
  74. m_error_info_dev = nullptr;
  75. void* ptr;
  76. hip_check(hipMalloc(&ptr, sizeof(megcore::AsyncErrorInfo)));
  77. hip_check(hipMemset(ptr, 0, sizeof(megcore::AsyncErrorInfo)));
  78. hip_check(hipDeviceSynchronize());
  79. m_error_info_dev = static_cast<megcore::AsyncErrorInfo*>(ptr);
  80. // create handle bind with error_info
  81. megcoreComputingHandle_t comp_handle;
  82. megcore_check(megcore::createComputingHandleWithROCMContext(
  83. &comp_handle, dev_handle, 0, {nullptr, m_error_info_dev}));
  84. m_handle_rocm = Handle::make(comp_handle);
  85. megdnn_assert(static_cast<bool>(m_handle_rocm));
  86. }
  87. void ROCM_ERROR_INFO::TearDown() {
  88. if (m_error_info_dev) {
  89. hip_check(hipFree(m_error_info_dev));
  90. }
  91. m_handle_rocm.reset();
  92. MemoryManagerHolder::instance()->clear();
  93. }
  94. megcore::AsyncErrorInfo ROCM_ERROR_INFO::get_error_info() {
  95. megcore::AsyncErrorInfo ret;
  96. auto stream = rocm::hip_stream(m_handle_rocm.get());
  97. hip_check(hipMemcpyAsync(
  98. &ret, m_error_info_dev, sizeof(ret), hipMemcpyDeviceToHost, stream));
  99. hip_check(hipStreamSynchronize(stream));
  100. return ret;
  101. }
  102. // vim: syntax=cpp.doxygen