You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.cpp 5.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. #include "test/common/utils.h"
  2. #include "megcore.h"
  3. #include "megdnn/basic_types.h"
  4. #include "src/naive/handle.h"
  5. #include "test/common/memory_manager.h"
  6. #include "test/common/random_state.h"
  7. #include <cmath>
  8. #include <random>
  9. using namespace megdnn;
  10. using namespace test;
  11. namespace {
  12. void megdnn_memcpy_internal(
  13. Handle* handle, void* dst, const void* src, size_t size_in_bytes,
  14. megcoreMemcpyKind_t kind) {
  15. auto comp_handle = handle->megcore_computing_handle();
  16. megcore_check(megcoreMemcpy(comp_handle, dst, src, size_in_bytes, kind));
  17. megcore_check(megcoreSynchronize(comp_handle));
  18. }
  19. class ErrorHandlerImpl final : public ErrorHandler {
  20. static ErrorHandlerImpl inst;
  21. void do_on_megdnn_error(const std::string& msg) override {
  22. fprintf(stderr, "megdnn error: %s\n", msg.c_str());
  23. #if MEGDNN_ENABLE_EXCEPTIONS
  24. throw MegDNNError{msg};
  25. #else
  26. megdnn_trap();
  27. #endif
  28. }
  29. void do_on_tensor_reshape_error(const std::string& msg) override {
  30. fprintf(stderr, "tensor reshape error: %s\n", msg.c_str());
  31. #if MEGDNN_ENABLE_EXCEPTIONS
  32. throw TensorReshapeError{msg};
  33. #else
  34. megdnn_trap();
  35. #endif
  36. }
  37. public:
  38. ErrorHandlerImpl() { ErrorHandler::set_handler(this); }
  39. };
  40. ErrorHandlerImpl ErrorHandlerImpl::inst;
  41. } // anonymous namespace
  42. CpuDispatchChecker::InstCounter CpuDispatchChecker::sm_inst_counter;
  43. std::unique_ptr<Handle> test::create_cpu_handle(
  44. int debug_level, bool check_dispatch, TaskExecutorConfig* config) {
  45. std::shared_ptr<MegcoreCPUDispatcher> dispatcher(nullptr);
  46. if (check_dispatch) {
  47. dispatcher = CpuDispatchChecker::make(config);
  48. }
  49. return create_cpu_handle_with_dispatcher(debug_level, dispatcher);
  50. }
  51. std::unique_ptr<Handle> test::create_cpu_handle_with_dispatcher(
  52. int debug_level, const std::shared_ptr<MegcoreCPUDispatcher>& dispatcher) {
  53. megcoreDeviceHandle_t dev_handle;
  54. megcore_check(megcoreCreateDeviceHandle(&dev_handle, megcorePlatformCPU));
  55. megcoreComputingHandle_t comp_handle;
  56. if (dispatcher) {
  57. megcore_check(megcoreCreateComputingHandleWithCPUDispatcher(
  58. &comp_handle, dev_handle, dispatcher));
  59. } else {
  60. megcore_check(megcoreCreateComputingHandle(&comp_handle, dev_handle));
  61. }
  62. auto destructor = [=]() {
  63. megcore_check(megcoreDestroyComputingHandle(comp_handle));
  64. megcore_check(megcoreDestroyDeviceHandle(dev_handle));
  65. };
  66. auto ret = Handle::make(comp_handle, debug_level);
  67. ret->set_destructor(destructor);
  68. return ret;
  69. }
  70. void test::megdnn_sync(Handle* handle) {
  71. auto comp_handle = handle->megcore_computing_handle();
  72. megcore_check(megcoreSynchronize(comp_handle));
  73. }
  74. void* test::megdnn_malloc(Handle* handle, size_t size_in_bytes) {
  75. auto mm = MemoryManagerHolder::instance()->get(handle);
  76. return mm->malloc(size_in_bytes);
  77. }
  78. void test::megdnn_free(Handle* handle, void* ptr) {
  79. auto mm = MemoryManagerHolder::instance()->get(handle);
  80. mm->free(ptr);
  81. }
  82. void test::megdnn_memcpy_D2H(
  83. Handle* handle, void* dst, const void* src, size_t size_in_bytes) {
  84. megdnn_memcpy_internal(handle, dst, src, size_in_bytes, megcoreMemcpyDeviceToHost);
  85. }
  86. void test::megdnn_memcpy_H2D(
  87. Handle* handle, void* dst, const void* src, size_t size_in_bytes) {
  88. megdnn_memcpy_internal(handle, dst, src, size_in_bytes, megcoreMemcpyHostToDevice);
  89. }
  90. void test::megdnn_memcpy_D2D(
  91. Handle* handle, void* dst, const void* src, size_t size_in_bytes) {
  92. megdnn_memcpy_internal(
  93. handle, dst, src, size_in_bytes, megcoreMemcpyDeviceToDevice);
  94. }
  95. TensorND DynOutMallocPolicyImpl::alloc_output(
  96. size_t /*id*/, DType dtype, const TensorShape& shape, void* /*user_data*/) {
  97. auto ptr = megdnn_malloc(m_handle, dtype.size() * shape.total_nr_elems());
  98. return {ptr, TensorLayout{shape, dtype}};
  99. }
  100. void* DynOutMallocPolicyImpl::alloc_workspace(size_t sz, void* /*user_data*/) {
  101. return megdnn_malloc(m_handle, sz);
  102. }
  103. void DynOutMallocPolicyImpl::free_workspace(void* ptr, void* /*user_data*/) {
  104. megdnn_free(m_handle, ptr);
  105. }
  106. std::shared_ptr<void> DynOutMallocPolicyImpl::make_output_refholder(
  107. const TensorND& out) {
  108. using namespace std::placeholders;
  109. auto deleter = std::bind(megdnn_free, m_handle, _1);
  110. return {out.raw_ptr(), deleter};
  111. }
  112. NaivePitchAlignmentScope::NaivePitchAlignmentScope(
  113. size_t alignment, megdnn::Handle::HandleVendorType vendor)
  114. : m_orig_val{naive::HandleImpl::exchange_image2d_pitch_alignment(alignment)},
  115. m_new_val{alignment},
  116. m_orig_vendor{naive::HandleImpl::exchange_image2d_vendor(vendor)},
  117. m_new_vendor{vendor} {}
  118. NaivePitchAlignmentScope::~NaivePitchAlignmentScope() {
  119. auto r = naive::HandleImpl::exchange_image2d_pitch_alignment(m_orig_val);
  120. megdnn_assert(r == m_new_val);
  121. auto v = naive::HandleImpl::exchange_image2d_vendor(m_orig_vendor);
  122. megdnn_assert(v == m_new_vendor);
  123. }
  124. size_t test::get_cpu_count() {
  125. return std::max<size_t>(std::thread::hardware_concurrency(), 1_z);
  126. }
  127. // vim: syntax=cpp.doxygen