You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

handle.h 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. /**
  2. * \file dnn/include/megdnn/handle.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megcore.h"
  13. #include "megdnn/config/config.h"
  14. #include "megdnn/basic_types.h"
  15. #include <functional>
  16. #include <memory>
  17. #include "megdnn/internal/visibility_prologue.h"
  18. namespace megdnn {
  19. class OperatorBase;
  20. class Handle {
  21. public:
  22. enum class HandleType {
  23. NAIVE = 0,
  24. FALLBACK = 1,
  25. X86 = 2,
  26. ARM_COMMON = 3,
  27. ARMV7 = 4,
  28. AARCH64 = 5,
  29. CUDA = 6,
  30. ROCM = 11,
  31. ATLAS = 13,
  32. CAMBRICON = 12,
  33. };
  34. //! Device vendor
  35. enum class HandleVendorType : uint32_t {
  36. NOT_SPEC = 0,
  37. MALI = 1,
  38. ADRENO = 2,
  39. CUDA = 3,
  40. INTEL = 4,
  41. POWERVR = 5,
  42. AMD = 6,
  43. };
  44. protected:
  45. Handle(megcoreComputingHandle_t computing_handle, HandleType type);
  46. public:
  47. /**
  48. * \brief Create a MegDNN handle from a MegCore Computing handle.
  49. *
  50. * \param[in] computing_handle MegCore computing handle. Please note
  51. * that computing_handle would not be released when this Handle is
  52. * destructed
  53. * \param[in] debug_level
  54. * Applicable for CPU computing handle.
  55. * 0 means taking the fastest possible code path; it may contains
  56. * platform-specific instructions such as SSE for x86_64 or NEON for
  57. * armv7v7.
  58. * 1 means taking the fastest possible code path without
  59. * platform-specific instructions in C++ code. Note that the compiled
  60. * binary file still contains platform-specific codes.
  61. * 2 means taking the naive code path. Performance is severely
  62. * hampered, but it is less error-prone since the internal
  63. * implementation is rather straightforward.
  64. *
  65. * **Debug level 1 and 2 should not be used in productions.**
  66. */
  67. static std::unique_ptr<Handle> make(
  68. megcoreComputingHandle_t computing_handle,
  69. int debug_level = 0);
  70. #if MEGDNN_WITH_CUDA
  71. static std::unique_ptr<Handle> make_cuda_handle(
  72. megcoreComputingHandle_t computing_handle);
  73. template <typename opr>
  74. std::unique_ptr<opr> create_cuda_operator();
  75. #endif
  76. #if MEGDNN_WITH_ROCM
  77. static std::unique_ptr<Handle> make_rocm_handle(
  78. megcoreComputingHandle_t computing_handle);
  79. template <typename opr>
  80. std::unique_ptr<opr> create_rocm_operator();
  81. #endif
  82. virtual ~Handle();
  83. /*!
  84. * \brief Get the underlying megcore computing handle.
  85. */
  86. megcoreComputingHandle_t megcore_computing_handle() const {
  87. return m_computing_handle;
  88. }
  89. /*!
  90. * \brief set a callback function to be invoked when this handle is
  91. * destructed, so associated resources can be released (e.g.
  92. * computing handle)
  93. *
  94. * This function can be called at most once.
  95. */
  96. void set_destructor(const thin_function<void()> &d);
  97. /*!
  98. * \brief set a callback to be invoked when an operator is destructed
  99. * \param[in,out] cb the callback function; it would be set to the
  100. * previous callback function
  101. */
  102. void set_opr_destruct_callback(thin_function<void(OperatorBase*)> &cb) {
  103. cb.swap(m_on_opr_destructed);
  104. }
  105. void on_opr_destructed(OperatorBase* opr);
  106. /**
  107. * \brief Create operator of Opr type.
  108. */
  109. template <typename Opr>
  110. std::unique_ptr<Opr> create_operator();
  111. /*
  112. * =============================================================
  113. * Users should call functions below to query memory requirement.
  114. * =============================================================
  115. */
  116. /**
  117. * \brief The internal data pointer of TensorND should be aligned to
  118. * alignment_requirement() in bytes.
  119. */
  120. virtual size_t alignment_requirement() const;
  121. //! get alignment in bytes for rows of image 2D tensor format
  122. virtual size_t image2d_pitch_alignment() const;
  123. //! get vendor type
  124. virtual HandleVendorType vendor_type() const;
  125. HandleType type() const {
  126. return m_handle_type;
  127. }
  128. /**
  129. * \brief Check is the layout satisfy cross device copy constraint.
  130. * 1. The handle of the src and the dst is the same kind
  131. * 2. The dst is continguous.
  132. */
  133. virtual bool check_cross_dev_copy_constraint(const TensorLayout &src);
  134. private:
  135. static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u;
  136. volatile uint32_t m_alive_magic = ALIVE_MAGIC;
  137. megcoreComputingHandle_t m_computing_handle;
  138. const HandleType m_handle_type;
  139. thin_function<void()> m_destructor;
  140. thin_function<void(OperatorBase*)> m_on_opr_destructed;
  141. Handle() = delete;
  142. Handle(const Handle &rhs) = delete;
  143. Handle &operator=(const Handle &rhs) = delete;
  144. };
  145. } // namespace megdnn
  146. #include "megdnn/internal/visibility_epilogue.h"
  147. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台