You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.cpp 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. /**
  2. * \file dnn/test/common/utils.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/utils.h"
  12. #include "megcore.h"
  13. #include "megdnn/basic_types.h"
  14. #include "src/naive/handle.h"
  15. #include "test/common/memory_manager.h"
  16. #include "test/common/random_state.h"
  17. #include <cmath>
  18. #include <random>
  19. using namespace megdnn;
  20. using namespace test;
  21. namespace {
  22. void megdnn_memcpy_internal(
  23. Handle* handle, void* dst, const void* src, size_t size_in_bytes,
  24. megcoreMemcpyKind_t kind) {
  25. auto comp_handle = handle->megcore_computing_handle();
  26. megcore_check(megcoreMemcpy(comp_handle, dst, src, size_in_bytes, kind));
  27. megcore_check(megcoreSynchronize(comp_handle));
  28. }
  29. class ErrorHandlerImpl final : public ErrorHandler {
  30. static ErrorHandlerImpl inst;
  31. void do_on_megdnn_error(const std::string& msg) override {
  32. fprintf(stderr, "megdnn error: %s\n", msg.c_str());
  33. #if MEGDNN_ENABLE_EXCEPTIONS
  34. throw MegDNNError{msg};
  35. #else
  36. megdnn_trap();
  37. #endif
  38. }
  39. void do_on_tensor_reshape_error(const std::string& msg) override {
  40. fprintf(stderr, "tensor reshape error: %s\n", msg.c_str());
  41. #if MEGDNN_ENABLE_EXCEPTIONS
  42. throw TensorReshapeError{msg};
  43. #else
  44. megdnn_trap();
  45. #endif
  46. }
  47. public:
  48. ErrorHandlerImpl() { ErrorHandler::set_handler(this); }
  49. };
  50. ErrorHandlerImpl ErrorHandlerImpl::inst;
  51. } // anonymous namespace
  52. CpuDispatchChecker::InstCounter CpuDispatchChecker::sm_inst_counter;
  53. std::unique_ptr<Handle> test::create_cpu_handle(
  54. int debug_level, bool check_dispatch, TaskExecutorConfig* config) {
  55. std::shared_ptr<MegcoreCPUDispatcher> dispatcher(nullptr);
  56. if (check_dispatch) {
  57. dispatcher = CpuDispatchChecker::make(config);
  58. }
  59. return create_cpu_handle_with_dispatcher(debug_level, dispatcher);
  60. }
  61. std::unique_ptr<Handle> test::create_cpu_handle_with_dispatcher(
  62. int debug_level, const std::shared_ptr<MegcoreCPUDispatcher>& dispatcher) {
  63. megcoreDeviceHandle_t dev_handle;
  64. megcore_check(megcoreCreateDeviceHandle(&dev_handle, megcorePlatformCPU));
  65. megcoreComputingHandle_t comp_handle;
  66. if (dispatcher) {
  67. megcore_check(megcoreCreateComputingHandleWithCPUDispatcher(
  68. &comp_handle, dev_handle, dispatcher));
  69. } else {
  70. megcore_check(megcoreCreateComputingHandle(&comp_handle, dev_handle));
  71. }
  72. auto destructor = [=]() {
  73. megcore_check(megcoreDestroyComputingHandle(comp_handle));
  74. megcore_check(megcoreDestroyDeviceHandle(dev_handle));
  75. };
  76. auto ret = Handle::make(comp_handle, debug_level);
  77. ret->set_destructor(destructor);
  78. return ret;
  79. }
  80. void test::megdnn_sync(Handle* handle) {
  81. auto comp_handle = handle->megcore_computing_handle();
  82. megcore_check(megcoreSynchronize(comp_handle));
  83. }
  84. void* test::megdnn_malloc(Handle* handle, size_t size_in_bytes) {
  85. auto mm = MemoryManagerHolder::instance()->get(handle);
  86. return mm->malloc(size_in_bytes);
  87. }
  88. void test::megdnn_free(Handle* handle, void* ptr) {
  89. auto mm = MemoryManagerHolder::instance()->get(handle);
  90. mm->free(ptr);
  91. }
  92. void test::megdnn_memcpy_D2H(
  93. Handle* handle, void* dst, const void* src, size_t size_in_bytes) {
  94. megdnn_memcpy_internal(handle, dst, src, size_in_bytes, megcoreMemcpyDeviceToHost);
  95. }
  96. void test::megdnn_memcpy_H2D(
  97. Handle* handle, void* dst, const void* src, size_t size_in_bytes) {
  98. megdnn_memcpy_internal(handle, dst, src, size_in_bytes, megcoreMemcpyHostToDevice);
  99. }
  100. void test::megdnn_memcpy_D2D(
  101. Handle* handle, void* dst, const void* src, size_t size_in_bytes) {
  102. megdnn_memcpy_internal(
  103. handle, dst, src, size_in_bytes, megcoreMemcpyDeviceToDevice);
  104. }
  105. TensorND DynOutMallocPolicyImpl::alloc_output(
  106. size_t /*id*/, DType dtype, const TensorShape& shape, void* /*user_data*/) {
  107. auto ptr = megdnn_malloc(m_handle, dtype.size() * shape.total_nr_elems());
  108. return {ptr, TensorLayout{shape, dtype}};
  109. }
  110. void* DynOutMallocPolicyImpl::alloc_workspace(size_t sz, void* /*user_data*/) {
  111. return megdnn_malloc(m_handle, sz);
  112. }
  113. void DynOutMallocPolicyImpl::free_workspace(void* ptr, void* /*user_data*/) {
  114. megdnn_free(m_handle, ptr);
  115. }
  116. std::shared_ptr<void> DynOutMallocPolicyImpl::make_output_refholder(
  117. const TensorND& out) {
  118. using namespace std::placeholders;
  119. auto deleter = std::bind(megdnn_free, m_handle, _1);
  120. return {out.raw_ptr(), deleter};
  121. }
  122. NaivePitchAlignmentScope::NaivePitchAlignmentScope(size_t alignment)
  123. : m_orig_val{naive::HandleImpl::exchange_image2d_pitch_alignment(alignment)},
  124. m_new_val{alignment} {}
  125. NaivePitchAlignmentScope::~NaivePitchAlignmentScope() {
  126. auto r = naive::HandleImpl::exchange_image2d_pitch_alignment(m_orig_val);
  127. megdnn_assert(r == m_new_val);
  128. }
  129. size_t test::get_cpu_count() {
  130. return std::max<size_t>(std::thread::hardware_concurrency(), 1_z);
  131. }
  132. // vim: syntax=cpp.doxygen