You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

switch_to_stream_switch_pass.h 9.8 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_
  17. #define GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_
  18. #include "inc/graph_pass.h"
  19. namespace ge {
  20. /* Variable initialization flow, treated as a FrameworkOp
  21. +-----------+
  22. | Merge |
  23. +-----------+
  24. / \
  25. 0/ \x
  26. / \
  27. +-----------+ +-----------+
  28. | Switch | | Switch |
  29. +-----------+ +-----------+
  30. | |F T| |
  31. 0| | | x|
  32. | | | |
  33. | +-----------------------+ |
  34. | | IsVariableInitialized | |
  35. | +-----------------------+ |
  36. | | |
  37. | | |
  38. | | |
  39. +-----------+ +-----------+
  40. | Const | | VariableV2|
  41. +-----------+ +-----------+
  42. */
  43. /* Switch branch optimization: Switch nodes in the same case are merged into one StreamSwitch, and the inputs of the following nodes are updated
  44. +-----------+
  45. / | task2 | \
  46. T/ +-----------+ \
  47. +-----------+ +-----------+ / \ +-----------+ +-----------+
  48. | task1 | --> | Switch | | task4 | --> | noop |
  49. +-----------+ +-----------+ \ / +-----------+ +-----------+
  50. F\ +-----------+ /
  51. \ | task3 | /
  52. +-----------+
  53. cond(x < y, lambda: add(x, z), lambda: square(y))
  54. +-----------+ +-----------+
  55. | Merge | +------------|StreamMerge|----------+
  56. +-----------+ | +-----------+ |
  57. / \ | | |
  58. / \ |c | |c
  59. / \ +----------+ ----------- +----------+
  60. +-----------+ +-----------+ | Active_f | / \ | Active_t |
  61. | Square | | Add | +----------+ / \ +----------+
  62. +-----------+ +-----------+ \ / \ /
  63. / / \ \c / \ /c
  64. y/ x/ \z +-----------+ +-----------+
  65. / / \ | Square | | Add |
  66. +-----------+ +-----------+ +-----------+ +-----------+ +-----------+
  67. | Switch | | Switch | | Switch | ====> / | / | \
  68. +-----------+ +-----------+ +-----------+ / | / | \
  69. y| |F T| |x T| |z +--------+ | +--------+ | +--------+
  70. | | | | | | | y/read | | | x/read | | | z/read |
  71. | +-----------+ | | | +--------+ | +--------+ | +--------+
  72. | | Less |-------------------+ | |c |c
  73. | +-----------+ | | +----------------+ +----------------+
  74. | | | | StreamSwitch_f | | StreamSwitch_t |
  75. | | | +----------------+ +----------------+
  76. +-----------+ +-----------+ +-----------+ | |
  77. | y/read | | x/read | | z/read | | +-----------+ |
  78. +-----------+ +-----------+ +-----------+ +-----| Less |----+
  79. +-----------+
  80. */
  81. class SwitchToStreamSwitchPass : public GraphPass {
  82. public:
  83. Status Run(ComputeGraphPtr graph);
  84. ///
  85. /// @brief Clear Status, used for subgraph pass
  86. /// @return
  87. ///
  88. Status ClearStatus() override;
  89. private:
  90. ///
  91. /// @brief Check cyclic dependence
  92. /// @param [in] graph
  93. /// @return Status
  94. ///
  95. Status CheckCycleDependence(const ComputeGraphPtr &graph);
  96. ///
  97. /// @brief Mark cyclic dependence
  98. /// @param [in] graph
  99. /// @param [in] cond_switch_map
  100. /// @return void
  101. ///
  102. void MarkCycleDependence(const std::unordered_map<NodePtr, std::vector<NodePtr>> &cond_switch_map);
  103. ///
  104. /// @brief Replace Switch Op
  105. /// @param [in] graph
  106. /// @param [in] switch_node
  107. /// @return Status
  108. ///
  109. Status ReplaceSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node);
  110. ///
  111. /// @brief Bypass Switch Node
  112. /// @param [in] switch_node
  113. /// @param [out] peer_data_anchor
  114. /// @param [out] peer_cond_anchor
  115. /// @return Status
  116. ///
  117. Status BypassSwitchNode(const NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor,
  118. OutDataAnchorPtr &peer_cond_anchor);
  119. ///
  120. /// @brief Find Switch cond input
  121. /// @param [in] pass_switch_flag
  122. /// @param [out] peer_cond_anchor
  123. /// @return Status
  124. ///
  125. Status FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor);
  126. ///
  127. /// @brief Create StreamSwitch Node
  128. /// @param [in] graph
  129. /// @param [in] switch_node
  130. /// @param [in] suffix
  131. /// @param [in] peer_cond_anchor
  132. /// @return ge::NodePtr
  133. ///
  134. NodePtr CreateStreamSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node, const std::string &suffix,
  135. const OutDataAnchorPtr &peer_cond_anchor);
  136. ///
  137. /// @brief Mark Switch Branch
  138. /// @param [in] peer_cond_anchor
  139. /// @param [in] stream_switch
  140. /// @param [in] true_branch_flag
  141. /// @return Status
  142. ///
  143. Status MarkBranches(const OutDataAnchorPtr &peer_cond_anchor, const NodePtr &stream_switch_node,
  144. bool true_branch_flag);
  145. ///
  146. /// @brief Get group_id for switch_node
  147. /// @param [in] node
  148. /// @return group_id
  149. ///
  150. int64_t GetGroupId(const NodePtr &node);
  151. ///
  152. /// @brief Combine switch nodes link to same cond
  153. /// @param [in] graph
  154. /// @return Status
  155. ///
  156. Status CombineSwitchNode(const ComputeGraphPtr &graph);
  157. ///
  158. /// @brief Create cast node
  159. /// @param [in] graph
  160. /// @param [in] peer_cond_anchor
  161. /// @return NodePtr
  162. ///
  163. NodePtr CreateCastOp(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_cond_anchor);
  164. ///
  165. /// @brief Create Active Op
  166. /// @param [in] graph
  167. /// @param [in] cond_node
  168. /// @return ge::NodePtr
  169. ///
  170. NodePtr CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node);
  171. ///
  172. /// @brief Add const node as switch input1
  173. /// @param [in] graph
  174. /// @param [in] stream_switch
  175. /// @return Status
  176. ///
  177. Status AddConstNode(const ComputeGraphPtr &graph, const NodePtr &stream_switch_node);
  178. ///
  179. /// @brief Modify in ctl edge for switch_node
  180. /// @param [in] switch_node
  181. /// @param [in] cast_node
  182. /// @param [in] same_cond_switch
  183. /// @return Status
  184. ///
  185. Status ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node,
  186. const std::set<NodePtr> &same_cond_switch);
  187. ///
  188. /// @brief Modify out ctl edge for switch_node
  189. /// @param [in] switch_node
  190. /// @param [in] stream_switch
  191. /// @param [in] active_node
  192. /// @return Status
  193. ///
  194. Status ModifySwitchOutCtlEdges(const NodePtr &switch_node, const NodePtr &stream_switch, const NodePtr &active_node);
  195. ///
  196. /// @brief Check duplicate node_name
  197. /// @param [in] node_name
  198. /// @return std::string
  199. ///
  200. std::string CheckDuplicateName(const std::string &node_name);
  201. ///
  202. /// @brief Move Control Edges
  203. /// @param [in] old_node
  204. /// @param [in] new_node
  205. /// @return void
  206. ///
  207. void MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node);
  208. std::vector<NodePtr> switch_nodes_;
  209. std::unordered_map<NodePtr, std::set<std::string>> switch_cyclic_map_;
  210. std::set<NodePtr> bypass_nodes_;
  211. std::vector<NodePtr> stream_switch_nodes_;
  212. std::unordered_map<OutDataAnchorPtr, std::map<int64_t, std::vector<std::list<NodePtr>>>> cond_node_map_;
  213. std::unordered_map<NodePtr, std::set<std::string>> switch_node_map_;
  214. std::unordered_map<std::string, uint32_t> node_num_map_;
  215. };
  216. } // namespace ge
  217. #endif // GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：