
dnn/test/common/utils.cpp

/**
 * \file dnn/test/common/utils.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "test/common/utils.h"
#include "megdnn/basic_types.h"
#include "test/common/random_state.h"
#include "test/common/memory_manager.h"
#include "src/naive/handle.h"
#include "megcore.h"

#include <cmath>
#include <random>
#include <thread>  // std::thread::hardware_concurrency() in get_cpu_count()

using namespace megdnn;
using namespace test;
namespace {

// Copy through megcore and block until the transfer has completed.
void megdnn_memcpy_internal(Handle *handle, void *dst, const void *src,
        size_t size_in_bytes, megcoreMemcpyKind_t kind)
{
    auto comp_handle = handle->megcore_computing_handle();
    megcore_check(megcoreMemcpy(comp_handle, dst, src, size_in_bytes, kind));
    megcore_check(megcoreSynchronize(comp_handle));
}

// Error handler installed for the whole test binary: print the message, then
// throw when exceptions are enabled, otherwise trap.
class ErrorHandlerImpl final: public ErrorHandler {
    static ErrorHandlerImpl inst;

    void do_on_megdnn_error(const std::string &msg) override {
        fprintf(stderr, "megdnn error: %s\n", msg.c_str());
#if MEGDNN_ENABLE_EXCEPTIONS
        throw MegDNNError{msg};
#else
        megdnn_trap();
#endif
    }

    void do_on_tensor_reshape_error(const std::string &msg) override {
        fprintf(stderr, "tensor reshape error: %s\n", msg.c_str());
#if MEGDNN_ENABLE_EXCEPTIONS
        throw TensorReshapeError{msg};
#else
        megdnn_trap();
#endif
    }

public:
    ErrorHandlerImpl() {
        ErrorHandler::set_handler(this);
    }
};

ErrorHandlerImpl ErrorHandlerImpl::inst;

} // anonymous namespace
CpuDispatchChecker::InstCounter CpuDispatchChecker::sm_inst_counter;

std::unique_ptr<Handle> test::create_cpu_handle(int debug_level,
        bool check_dispatch,
        TaskExecutorConfig* config) {
    std::shared_ptr<MegcoreCPUDispatcher> dispatcher(nullptr);
    if (check_dispatch) {
        dispatcher = CpuDispatchChecker::make(config);
    }
    return create_cpu_handle_with_dispatcher(debug_level, dispatcher);
}

std::unique_ptr<Handle> test::create_cpu_handle_with_dispatcher(int debug_level,
        const std::shared_ptr<MegcoreCPUDispatcher> &dispatcher)
{
    megcoreDeviceHandle_t dev_handle;
    megcore_check(megcoreCreateDeviceHandle(&dev_handle, megcorePlatformCPU));
    megcoreComputingHandle_t comp_handle;
    if (dispatcher) {
        megcore_check(megcoreCreateComputingHandleWithCPUDispatcher(
                &comp_handle, dev_handle, dispatcher));
    } else {
        megcore_check(megcoreCreateComputingHandle(&comp_handle, dev_handle));
    }
    // Release the megcore handles when the megdnn handle itself is destroyed.
    auto destructor = [=]() {
        megcore_check(megcoreDestroyComputingHandle(comp_handle));
        megcore_check(megcoreDestroyDeviceHandle(dev_handle));
    };
    auto ret = Handle::make(comp_handle, debug_level);
    ret->set_destructor(destructor);
    return ret;
}
void test::megdnn_sync(Handle *handle)
{
    auto comp_handle = handle->megcore_computing_handle();
    megcore_check(megcoreSynchronize(comp_handle));
}

void* test::megdnn_malloc(Handle *handle, size_t size_in_bytes)
{
    auto mm = MemoryManagerHolder::instance()->get(handle);
    return mm->malloc(size_in_bytes);
}

void test::megdnn_free(Handle *handle, void *ptr)
{
    auto mm = MemoryManagerHolder::instance()->get(handle);
    mm->free(ptr);
}

void test::megdnn_memcpy_D2H(Handle *handle, void *dst, const void *src,
        size_t size_in_bytes)
{
    megdnn_memcpy_internal(handle, dst, src, size_in_bytes,
            megcoreMemcpyDeviceToHost);
}

void test::megdnn_memcpy_H2D(Handle *handle, void *dst, const void *src,
        size_t size_in_bytes)
{
    megdnn_memcpy_internal(handle, dst, src, size_in_bytes,
            megcoreMemcpyHostToDevice);
}

void test::megdnn_memcpy_D2D(Handle *handle, void *dst, const void *src,
        size_t size_in_bytes)
{
    megdnn_memcpy_internal(handle, dst, src, size_in_bytes,
            megcoreMemcpyDeviceToDevice);
}
TensorND DynOutMallocPolicyImpl::alloc_output(
        size_t /*id*/, DType dtype, const TensorShape &shape,
        void * /*user_data*/) {
    auto ptr = megdnn_malloc(m_handle, dtype.size() * shape.total_nr_elems());
    return {ptr, TensorLayout{shape, dtype}};
}

void* DynOutMallocPolicyImpl::alloc_workspace(size_t sz, void * /*user_data*/) {
    return megdnn_malloc(m_handle, sz);
}

void DynOutMallocPolicyImpl::free_workspace(void *ptr, void * /*user_data*/) {
    megdnn_free(m_handle, ptr);
}

std::shared_ptr<void> DynOutMallocPolicyImpl::make_output_refholder(
        const TensorND &out) {
    using namespace std::placeholders;
    auto deleter = std::bind(megdnn_free, m_handle, _1);
    return {out.raw_ptr, deleter};
}
// RAII helper: override the naive handle's image2d pitch alignment and
// restore the previous value when the scope ends.
NaivePitchAlignmentScope::NaivePitchAlignmentScope(size_t alignment)
        : m_orig_val{naive::HandleImpl::exchange_image2d_pitch_alignment(
                  alignment)},
          m_new_val{alignment} {}

NaivePitchAlignmentScope::~NaivePitchAlignmentScope() {
    auto r = naive::HandleImpl::exchange_image2d_pitch_alignment(m_orig_val);
    megdnn_assert(r == m_new_val);
}

size_t test::get_cpu_count() {
    return std::max<size_t>(std::thread::hardware_concurrency(), 1_z);
}

// vim: syntax=cpp.doxygen
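
The file above only defines the helpers; the snippet below is a minimal usage sketch, not taken from the repository, showing how a test body might wire them together. The explicit arguments to create_cpu_handle, the buffer size, and the 64-byte alignment are illustrative assumptions (the project's header may supply defaults for some parameters).

#include "test/common/utils.h"
#include <vector>

using namespace megdnn;
using namespace test;

void example_handle_roundtrip() {
    // All three parameters are spelled out because only the .cpp signature is
    // visible here; the real header may provide defaults.
    auto handle = create_cpu_handle(/*debug_level=*/0,
                                    /*check_dispatch=*/false,
                                    /*config=*/nullptr);

    constexpr size_t nr_elems = 16;  // illustrative size
    std::vector<float> host_src(nr_elems, 1.f), host_dst(nr_elems, 0.f);

    // Allocate through the per-handle memory manager, round-trip the data,
    // then release everything. The memcpy wrappers synchronize internally.
    void *dev_ptr = megdnn_malloc(handle.get(), nr_elems * sizeof(float));
    megdnn_memcpy_H2D(handle.get(), dev_ptr, host_src.data(),
                      nr_elems * sizeof(float));
    megdnn_memcpy_D2H(handle.get(), host_dst.data(), dev_ptr,
                      nr_elems * sizeof(float));
    megdnn_sync(handle.get());
    megdnn_free(handle.get(), dev_ptr);

    {
        // Temporarily force a 64-byte image2d pitch alignment in the naive
        // backend; the previous value is restored when the scope exits.
        NaivePitchAlignmentScope pitch_scope{64};
        // ... run checks that depend on the alignment ...
    }
    // The handle's destructor releases the megcore device/computing handles.
}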

The MegEngine package already bundles the CUDA environment needed to run code on the GPU, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.