
bases.h 6.0 kB

/**
 * \file src/core/include/megbrain/graph/bases.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "megbrain/utils/json.h"
#include "megbrain/utils/metahelper.h"
#include "megbrain/exception.h"
#include "megbrain/comp_node.h"

#include <string>

#ifndef MGB_ENABLE_DTR
#define MGB_ENABLE_DTR ((!MGB_BUILD_SLIM_SERVING) && (!!MGB_HAVE_THREAD))
#endif  // MGB_ENABLE_DTR

#ifndef MGB_ENABLE_SUBLINEAR
#define MGB_ENABLE_SUBLINEAR ((!MGB_BUILD_SLIM_SERVING) && (!!MGB_HAVE_THREAD))
#endif  // MGB_ENABLE_SUBLINEAR

// FIXME: reopen when rewriting memory swap or existing tests are passed
#define MGB_ENABLE_MEMORY_SWAP 0
#ifndef MGB_ENABLE_MEMORY_SWAP
#define MGB_ENABLE_MEMORY_SWAP \
    ((!MGB_BUILD_SLIM_SERVING) && (!!MGB_HAVE_THREAD) && (MGB_CUDA))
#endif  // MGB_ENABLE_MEMORY_SWAP

#ifndef MGB_ENABLE_PARTIAL_EXECUTION
#define MGB_ENABLE_PARTIAL_EXECUTION (!MGB_BUILD_SLIM_SERVING)
#endif  // MGB_ENABLE_PARTIAL_EXECUTION

#ifndef MGB_ENABLE_COND_EXEC
#define MGB_ENABLE_COND_EXEC !MGB_BUILD_SLIM_SERVING
#endif
#if MGB_ENABLE_COND_EXEC
#define MGB_IF_COND_EXEC(x...) x
#else
#define MGB_IF_COND_EXEC(x...)
#endif

#if MGB_CUDA && MGB_ENABLE_EXCEPTION
#define MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER 1
#else
#define MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER 0
#endif  // whether enable memory defragment
namespace mgb {

class GraphError : public MegBrainError {
public:
    using MegBrainError::MegBrainError;
};

}  // namespace mgb

namespace mgb {

//! computing graph
namespace cg {

namespace static_infer {
struct DepElement;
};

using GraphError = mgb::GraphError;

class VarNode;
class OperatorNodeBase;
class ComputingGraph;

using VarNodeArray = mgb::SmallVector<VarNode*>;

/*!
 * \brief Base class for a node in the graph.
 *
 * Each node must have a name for debugging and graph dump, and each node is
 * uniquely identified by its memory address. Every node in a computing graph
 * has its unique numerical ID.
 */
class GraphNodeBase : public json::Serializable, public NonCopyableObj {
    ComputingGraph* const m_owner_graph;
    size_t m_id;

protected:
    ~GraphNodeBase() = default;

public:
    GraphNodeBase(ComputingGraph* owner_graph);

    ComputingGraph* owner_graph() const { return m_owner_graph; }

    //! get node ID as string
    std::string id_str() const { return std::to_string(m_id); }

    //! get node ID as number
    size_t id() const { return m_id; }
};

class OutputVarsUserData final : public mgb::UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

private:
    VarNodeArray m_output_vars;

public:
    void set_output_vars(VarNodeArray vars) { m_output_vars = std::move(vars); }
    const VarNodeArray& get_output_vars() const { return m_output_vars; }
};
/*!
 * \brief an object that executes asynchronously
 */
class AsyncExecutable : public json::Serializable, public CompNodeDepedentObject {
    UserDataContainer m_user_data;

public:
    virtual ~AsyncExecutable() noexcept;

    virtual AsyncExecutable& execute() = 0;

    /*!
     * \brief wait for current task to finish
     */
    virtual AsyncExecutable& wait() = 0;

    /*!
     * \brief previous execution time in seconds
     */
    virtual double get_prev_exec_time() const = 0;

    /*!
     * \brief iterate over operator sequence
     * \param cb callback function, return false to stop iterating
     */
    virtual AsyncExecutable& iter_opr_seq(
            thin_function<bool(OperatorNodeBase*)> cb) = 0;

    /*!
     * \brief get RT_STATIC deps needed for static infer in this func
     */
    virtual const SmallVector<static_infer::DepElement>&
    get_rt_static_source_deps() = 0;

    /*!
     * \brief number of calls to execute()
     */
    virtual size_t get_run_id() const = 0;

    /*!
     * \brief update static memory allocation plan and allocation size
     *
     * Note: as a side effect, static shape inference would be executed and
     * var shapes are updated.
     *
     * \return static allocation size for each comp node
     */
    virtual const CompNode::UnorderedMap<size_t>&
    update_static_alloc_plan_and_get_size() = 0;

    /*!
     * \brief clear device memory; memory would be allocated in the next run
     */
    virtual void clear_device_memory() = 0;

    //! get the graph that owns this executable; nullptr if no owner graph
    virtual ComputingGraph* owner_graph() const = 0;

    //! user data associated with a compiled executable
    UserDataContainer& user_data() { return m_user_data; }

    void set_output_vars(const VarNodeArray& vars) {
        std::shared_ptr<OutputVarsUserData> ud =
                std::make_shared<OutputVarsUserData>();
        ud->set_output_vars(vars);
        m_user_data.add_user_data(ud);
    }

    const VarNodeArray& get_output_vars() const {
        auto output_vars_pair = m_user_data.get_user_data<OutputVarsUserData>();
        return (*(output_vars_pair.first))->get_output_vars();
    }

#ifndef __IN_TEE_ENV__
    virtual void get_static_memory_alloc_info(const std::string& log_dir) const {
        mgb_assert(log_dir.length() < 0, "can't call this function directly\n");
    }
#endif
};

}  // namespace cg
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
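
The AsyncExecutable interface above is what a compiled graph exposes to callers. The following is a minimal usage sketch, not part of bases.h: how the compiled function is obtained is left open here (the compilation API is declared elsewhere), and only member functions declared in this header are used.

#include <cstdio>
#include "megbrain/graph/bases.h"

void run_and_report(mgb::cg::AsyncExecutable& func) {
    // launch one asynchronous run and block until it finishes
    func.execute().wait();

    // statistics exposed by the interface declared above
    std::printf("run #%zu finished in %.3f s\n",
                func.get_run_id(), func.get_prev_exec_time());

    // release static device memory; it will be re-allocated on the next run
    func.clear_device_memory();
}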

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU device and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.