You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

switch_to_stream_switch_pass.h 9.7 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_
  17. #define GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_
  18. #include "inc/graph_pass.h"
  19. namespace ge {
  20. /* Variable initialization flow, treated as a FrameworkOp
  21. +-----------+
  22. | Merge |
  23. +-----------+
  24. / \
  25. 0/ \x
  26. / \
  27. +-----------+ +-----------+
  28. | Switch | | Switch |
  29. +-----------+ +-----------+
  30. | |F T| |
  31. 0| | | x|
  32. | | | |
  33. | +-----------------------+ |
  34. | | IsVariableInitialized | |
  35. | +-----------------------+ |
  36. | | |
  37. | | |
  38. | | |
  39. +-----------+ +-----------+
  40. | Const | | VariableV2|
  41. +-----------+ +-----------+
  42. Switch branch optimization: Switches in the same case are merged into one StreamSwitch, and the inputs of the following nodes are updated
  43. +-----------+
  44. / | task2 | \
  45. T/ +-----------+ \
  46. +-----------+ +-----------+ / \ +-----------+ +-----------+
  47. | task1 | --> | Switch | | task4 | --> | noop |
  48. +-----------+ +-----------+ \ / +-----------+ +-----------+
  49. F\ +-----------+ /
  50. \ | task3 | /
  51. +-----------+
  52. cond(x < y, lambda: add(x, z), lambda: square(y))
  53. +-----------+ +-----------+
  54. | Merge | +------------|StreamMerge|----------+
  55. +-----------+ | +-----------+ |
  56. / \ | | |
  57. / \ |c | |c
  58. / \ +----------+ ----------- +----------+
  59. +-----------+ +-----------+ | Active_f | / \ | Active_t |
  60. | Square | | Add | +----------+ / \ +----------+
  61. +-----------+ +-----------+ \ / \ /
  62. / / \ \c / \ /c
  63. y/ x/ \z +-----------+ +-----------+
  64. / / \ | Square | | Add |
  65. +-----------+ +-----------+ +-----------+ +-----------+ +-----------+
  66. | Switch | | Switch | | Switch | ====> / | / | \
  67. +-----------+ +-----------+ +-----------+ / | / | \
  68. y| |F T| |x T| |z +--------+ | +--------+ | +--------+
  69. | | | | | | | y/read | | | x/read | | | z/read |
  70. | +-----------+ | | | +--------+ | +--------+ | +--------+
  71. | | Less |-------------------+ | |c |c
  72. | +-----------+ | | +----------------+ +----------------+
  73. | | | | StreamSwitch_f | | StreamSwitch_t |
  74. | | | +----------------+ +----------------+
  75. +-----------+ +-----------+ +-----------+ | |
  76. | y/read | | x/read | | z/read | | +-----------+ |
  77. +-----------+ +-----------+ +-----------+ +-----| Less |----+
  78. +-----------+
  79. */
  80. class SwitchToStreamSwitchPass : public GraphPass {  // Graph pass for the Switch -> StreamSwitch rewrite sketched in the diagrams above
  81. public:
  82. Status Run(ComputeGraphPtr graph);  // Pass entry point. NOTE(review): not marked `override`, unlike ClearStatus() — confirm against GraphPass
  83. ///
  84. /// @brief Clear Status, used for subgraph pass
  85. /// @return Status
  86. ///
  87. Status ClearStatus() override;
  88. private:
  89. ///
  90. /// @brief Check cyclic dependence
  91. /// @param [in] graph
  92. /// @return Status
  93. ///
  94. Status CheckCycleDependence(const ComputeGraphPtr &graph);
  95. ///
  96. /// @brief Mark cyclic dependence
  97. /// @note This function takes no graph parameter; cond_switch_map is its only input
  98. /// @param [in] cond_switch_map  map from a node to a vector of nodes (presumably cond node -> its Switch nodes; confirm)
  99. /// @return void
  100. ///
  101. void MarkCycleDependence(const std::unordered_map<NodePtr, std::vector<NodePtr>> &cond_switch_map);
  102. ///
  103. /// @brief Replace Switch Op
  104. /// @param [in] graph
  105. /// @param [in] switch_node
  106. /// @return Status
  107. ///
  108. Status ReplaceSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node);
  109. ///
  110. /// @brief Bypass Switch Node
  111. /// @param [in] switch_node
  112. /// @param [out] peer_data_anchor
  113. /// @param [out] peer_cond_anchor
  114. /// @return Status
  115. ///
  116. Status BypassSwitchNode(const NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor,
  117. OutDataAnchorPtr &peer_cond_anchor);
  118. ///
  119. /// @brief Find Switch cond input
  120. /// @param [out] peer_cond_anchor  passed by non-const reference (NOTE(review): may also be read as input — confirm)
  121. /// @return Status
  122. ///
  123. Status FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor);
  124. ///
  125. /// @brief Create StreamSwitch Node
  126. /// @param [in] graph
  127. /// @param [in] switch_node
  128. /// @param [in] suffix
  129. /// @param [in] peer_cond_anchor
  130. /// @return ge::NodePtr
  131. ///
  132. NodePtr CreateStreamSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node, const std::string &suffix,
  133. const OutDataAnchorPtr &peer_cond_anchor);
  134. ///
  135. /// @brief Mark Switch Branch
  136. /// @param [in] peer_cond_anchor
  137. /// @param [in] stream_switch_node
  138. /// @param [in] true_branch_flag
  139. /// @return Status
  140. ///
  141. Status MarkBranches(const OutDataAnchorPtr &peer_cond_anchor, const NodePtr &stream_switch_node,
  142. bool true_branch_flag);
  143. ///
  144. /// @brief Get group_id for switch_node
  145. /// @param [in] node
  146. /// @return group_id
  147. ///
  148. int64_t GetGroupId(const NodePtr &node);
  149. ///
  150. /// @brief Combine switch nodes link to same cond
  151. /// @param [in] graph
  152. /// @return Status
  153. ///
  154. Status CombineSwitchNode(const ComputeGraphPtr &graph);
  155. ///
  156. /// @brief Create cast node
  157. /// @param [in] graph
  158. /// @param [in] peer_cond_anchor
  159. /// @return NodePtr
  160. ///
  161. NodePtr CreateCastOp(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_cond_anchor);
  162. ///
  163. /// @brief Create Active Op
  164. /// @param [in] graph
  165. /// @param [in] node
  166. /// @return ge::NodePtr
  167. ///
  168. NodePtr CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node);
  169. ///
  170. /// @brief Add const node as switch input1
  171. /// @param [in] graph
  172. /// @param [in] stream_switch_node
  173. /// @return Status
  174. ///
  175. Status AddConstNode(const ComputeGraphPtr &graph, const NodePtr &stream_switch_node);
  176. ///
  177. /// @brief Modify in ctl edge for switch_node
  178. /// @param [in] switch_node
  179. /// @param [in] cast_node
  180. /// @param [in] same_cond_switch
  181. /// @return Status
  182. ///
  183. Status ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node,
  184. const std::set<NodePtr> &same_cond_switch);
  185. ///
  186. /// @brief Modify out ctl edge for switch_node
  187. /// @param [in] switch_node
  188. /// @param [in] stream_switch
  189. /// @param [in] active_node
  190. /// @return Status
  191. ///
  192. Status ModifySwitchOutCtlEdges(const NodePtr &switch_node, const NodePtr &stream_switch, const NodePtr &active_node);
  193. ///
  194. /// @brief Check duplicate node_name
  195. /// @param [in] node_name
  196. /// @return std::string
  197. ///
  198. std::string CheckDuplicateName(const std::string &node_name);
  199. ///
  200. /// @brief Move Control Edges
  201. /// @param [in] old_node
  202. /// @param [in] new_node
  203. /// @return void
  204. ///
  205. void MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node);
  206. std::vector<NodePtr> switch_nodes_;  // presumably the Switch nodes collected for replacement — TODO confirm
  207. std::unordered_map<NodePtr, std::set<std::string>> switch_cyclic_map_;  // node -> set of strings; presumably cycle-dependence marks — TODO confirm
  208. std::set<NodePtr> bypass_nodes_;  // presumably nodes bypassed by BypassSwitchNode — TODO confirm
  209. std::vector<NodePtr> stream_switch_nodes_;  // presumably the StreamSwitch nodes created by this pass — TODO confirm
  210. std::unordered_map<OutDataAnchorPtr, std::map<int64_t, std::vector<std::list<NodePtr>>>> cond_node_map_;  // keyed by cond output anchor, then by an int64_t (presumably group_id — confirm)
  211. std::unordered_map<NodePtr, std::set<std::string>> switch_node_map_;  // node -> set of strings; meaning not visible here — TODO confirm
  212. std::map<std::string, uint32_t> node_num_map_;  // string -> counter; presumably backs CheckDuplicateName — TODO confirm
  213. };
  214. } // namespace ge
  215. #endif // GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。