You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

next_iteration_pass.cc 11 kB

5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/next_iteration_pass.h"
  17. #include "common/ge/ge_util.h"
  18. #include "graph/common/omg_util.h"
  19. using std::string;
  20. namespace ge {
  21. Status NextIterationPass::Run(ComputeGraphPtr graph) {
  22. GELOGD("NextIterationPass Enter");
  23. /// Enter-----------+
  24. /// +-> Merge -> Switch <- LoopCond <- Cond
  25. /// NextIteration---+
  26. for (auto &node : graph->GetDirectNode()) {
  27. const std::string type = node->GetType();
  28. if ((type != ENTER) && (type != REFENTER)) {
  29. continue;
  30. }
  31. if (GroupEnterNode(node) != SUCCESS) {
  32. GELOGE(INTERNAL_ERROR, "Group enter_node %s failed.", node->GetName().c_str());
  33. return INTERNAL_ERROR;
  34. }
  35. }
  36. if (FindWhileGroups() != SUCCESS) {
  37. GELOGE(INTERNAL_ERROR, "Find while groups failed.");
  38. return INTERNAL_ERROR;
  39. }
  40. if (!VerifyWhileGroup()) {
  41. GELOGE(INTERNAL_ERROR, "Verify while groups failed.");
  42. return INTERNAL_ERROR;
  43. }
  44. if (HandleWhileGroup(graph) != SUCCESS) {
  45. GELOGE(FAILED, "Handle while groups failed.");
  46. return FAILED;
  47. }
  48. GELOGD("NextIterationPass Leave");
  49. return SUCCESS;
  50. }
  51. ///
  52. /// @brief Group Enter node
  53. /// @param [in] enter_node
  54. /// @return Status
  55. ///
  56. Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) {
  57. OpDescPtr enter_desc = enter_node->GetOpDesc();
  58. GE_CHECK_NOTNULL(enter_desc);
  59. std::string frame_name;
  60. if (!ge::AttrUtils::GetStr(enter_desc, ENTER_ATTR_FRAME_NAME, frame_name) || frame_name.empty()) {
  61. GELOGE(FAILED, "Get attr ENTER_ATTR_FRAME_NAME failed, node: %s", enter_desc->GetName().c_str());
  62. return FAILED;
  63. }
  64. string batch_label;
  65. if (ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
  66. frame_name += batch_label;
  67. }
  68. auto iter = loop_group_map_.find(frame_name);
  69. if (iter == loop_group_map_.end()) {
  70. LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>();
  71. if (loop_group == nullptr) {
  72. GELOGE(FAILED, "MakeShared for LoopCondGroup failed.");
  73. return FAILED;
  74. }
  75. loop_group->enter_nodes.emplace_back(enter_node);
  76. loop_group_map_[frame_name] = loop_group;
  77. } else {
  78. iter->second->enter_nodes.emplace_back(enter_node);
  79. }
  80. return SUCCESS;
  81. }
  82. ///
  83. /// @brief Find while groups
  84. /// @return Status
  85. ///
  86. Status NextIterationPass::FindWhileGroups() {
  87. for (const auto &loop_group_iter : loop_group_map_) {
  88. const std::string &frame_name = loop_group_iter.first;
  89. for (const auto &enter_node : loop_group_iter.second->enter_nodes) {
  90. for (const auto &out_node : enter_node->GetOutAllNodes()) {
  91. std::string type;
  92. GE_CHK_STATUS_RET(GetOriginalType(out_node, type), "Get node type failed.");
  93. if ((type != MERGE) && (type != REFMERGE)) {
  94. continue;
  95. }
  96. NodePtr next_node = nullptr;
  97. if (FindTargetNode(out_node, NEXTITERATION, true, next_node) != SUCCESS) {
  98. GELOGE(INTERNAL_ERROR, "Get NextIteration node failed, frame_name: %s", frame_name.c_str());
  99. return INTERNAL_ERROR;
  100. }
  101. loop_group_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node));
  102. NodePtr switch_node = nullptr;
  103. if (FindTargetNode(out_node, SWITCH, false, switch_node) != SUCCESS) {
  104. GELOGE(INTERNAL_ERROR, "Get Switch node failed, frame_name: %s.", frame_name.c_str());
  105. return INTERNAL_ERROR;
  106. }
  107. if (switch_node == nullptr) {
  108. continue;
  109. }
  110. NodePtr loop_cond = nullptr;
  111. if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) {
  112. GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str());
  113. return INTERNAL_ERROR;
  114. }
  115. if (loop_group_iter.second->loop_cond == nullptr) {
  116. loop_group_iter.second->loop_cond = loop_cond;
  117. } else if (loop_group_iter.second->loop_cond != loop_cond) {
  118. GELOGE(FAILED, "Multi LoopCond nodes exist, frame_name: %s.", frame_name.c_str());
  119. return FAILED;
  120. }
  121. }
  122. }
  123. }
  124. return SUCCESS;
  125. }
  126. ///
  127. /// @brief Verify if valid
  128. /// @return bool
  129. ///
  130. bool NextIterationPass::VerifyWhileGroup() {
  131. // map<frame_name, LoopCondGroup>
  132. for (const auto &loop_group_iter : loop_group_map_) {
  133. const std::string &frame_name = loop_group_iter.first;
  134. if (frame_name.empty()) {
  135. GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty.");
  136. return false;
  137. }
  138. if (loop_group_iter.second->loop_cond == nullptr) {
  139. GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str());
  140. return false;
  141. }
  142. for (const auto &pair_iter : loop_group_iter.second->merge_next_pairs) {
  143. if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) {
  144. GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.",
  145. frame_name.c_str());
  146. return false;
  147. }
  148. }
  149. }
  150. return true;
  151. }
  152. ///
  153. /// @brief Handle while group
  154. /// @param [in] graph
  155. /// @return Status
  156. ///
  157. Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
  158. for (const auto &loop_cond_iter : loop_group_map_) {
  159. const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName();
  160. GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str());
  161. // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge
  162. NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE);
  163. NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE);
  164. if ((enter_active == nullptr) || (next_active == nullptr)) {
  165. GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str());
  166. return INTERNAL_ERROR;
  167. }
  168. for (const auto &enter_node : loop_cond_iter.second->enter_nodes) {
  169. // Enter --> Active
  170. if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != GRAPH_SUCCESS) {
  171. GELOGE(INTERNAL_ERROR, "Add control edge from %s to %s failed.", enter_node->GetName().c_str(),
  172. enter_active->GetName().c_str());
  173. return INTERNAL_ERROR;
  174. }
  175. }
  176. for (const auto &pair : loop_cond_iter.second->merge_next_pairs) {
  177. NodePtr merge_node = pair.first;
  178. NodePtr next_node = pair.second;
  179. // Active --> Merge
  180. if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
  181. GELOGE(INTERNAL_ERROR, "Add control edge failed.");
  182. return INTERNAL_ERROR;
  183. }
  184. // NextIteration --> Active
  185. if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) {
  186. GELOGE(INTERNAL_ERROR, "Add control edge failed.");
  187. return INTERNAL_ERROR;
  188. }
  189. // break link between NextIteration and Merge
  190. if (BreakNextIteration(next_node, merge_node) != SUCCESS) {
  191. GELOGE(INTERNAL_ERROR, "Break NextIteration failed");
  192. return INTERNAL_ERROR;
  193. }
  194. }
  195. if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) ||
  196. (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) {
  197. GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed.");
  198. return INTERNAL_ERROR;
  199. }
  200. }
  201. return SUCCESS;
  202. }
  203. ///
  204. /// @brief Create Active Node
  205. /// @param [in] graph
  206. /// @param [in] name
  207. /// @return ge::NodePtr
  208. ///
  209. NodePtr NextIterationPass::CreateActiveNode(ComputeGraphPtr &graph, const std::string &name) {
  210. OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMACTIVE);
  211. if (op_desc == nullptr) {
  212. return nullptr;
  213. }
  214. GELOGI("Create StreamActive op:%s.", op_desc->GetName().c_str());
  215. NodePtr active_node = graph->AddNode(op_desc);
  216. if (active_node == nullptr) {
  217. GELOGE(INTERNAL_ERROR, "Create node[%s] failed.", name.c_str());
  218. return nullptr;
  219. }
  220. if (SetSwitchBranchNodeLabel(active_node, name) != SUCCESS) {
  221. GELOGE(INTERNAL_ERROR, "Set attr SWITCH_BRANCH_NODE_LABEL for node: %s failed.", active_node->GetName().c_str());
  222. return nullptr;
  223. }
  224. return active_node;
  225. }
  226. ///
  227. /// @brief Break NextIteration Link & add name to merge attr
  228. /// @param [in] next_node
  229. /// @param [in] merge_node
  230. /// @return Status
  231. ///
  232. Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr &merge_node) {
  233. if ((merge_node == nullptr) || (next_node == nullptr)) {
  234. GELOGE(PARAM_INVALID, "merge node or next node is null.");
  235. return PARAM_INVALID;
  236. }
  237. for (const auto &in_anchor : merge_node->GetAllInDataAnchors()) {
  238. OutDataAnchorPtr out_anchor = in_anchor->GetPeerOutAnchor();
  239. if ((out_anchor == nullptr) || (out_anchor->GetOwnerNode() != next_node)) {
  240. continue;
  241. }
  242. if (GraphUtils::RemoveEdge(out_anchor, in_anchor) != SUCCESS) {
  243. GELOGE(INTERNAL_ERROR, "Remove data edge failed, %s->%s.", next_node->GetName().c_str(),
  244. merge_node->GetName().c_str());
  245. return INTERNAL_ERROR;
  246. }
  247. if (SetNextIteration(merge_node, next_node->GetName()) != SUCCESS) {
  248. GELOGE(INTERNAL_ERROR, "Set attr NEXT_ITERATION for node %s failed.", merge_node->GetName().c_str());
  249. return INTERNAL_ERROR;
  250. }
  251. }
  252. return SUCCESS;
  253. }
  254. ///
  255. /// @brief find target node
  256. /// @param [in] node
  257. /// @param [in] target_type
  258. /// @param [in] is_input
  259. /// @param [out] target_node
  260. /// @return Status
  261. ///
  262. Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input,
  263. NodePtr &target_node) {
  264. if (node == nullptr) {
  265. GELOGE(PARAM_INVALID, "node is null.");
  266. return PARAM_INVALID;
  267. }
  268. std::vector<NodePtr> nodes;
  269. if (is_input) {
  270. for (const auto &tmp_node : node->GetInDataNodes()) {
  271. nodes.emplace_back(tmp_node);
  272. }
  273. } else {
  274. for (const auto &tmp_node : node->GetOutDataNodes()) {
  275. nodes.emplace_back(tmp_node);
  276. }
  277. }
  278. for (const auto &tmp_node : nodes) {
  279. std::string type;
  280. GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed.");
  281. if ((target_type == LOOPCOND) && (type == target_type)) {
  282. target_node = tmp_node;
  283. break;
  284. } else if ((type == target_type) || (type == "Ref" + target_type)) {
  285. target_node = tmp_node;
  286. break;
  287. }
  288. }
  289. if ((target_type != SWITCH) && (target_node == nullptr)) {
  290. GELOGE(INTERNAL_ERROR, "Find node %s failed.", target_type.c_str());
  291. return INTERNAL_ERROR;
  292. }
  293. return SUCCESS;
  294. }
  295. ///
  296. /// @brief Clear Status, used for subgraph pass
  297. /// @return SUCCESS
  298. ///
  299. Status NextIterationPass::ClearStatus() {
  300. loop_group_map_.clear();
  301. return SUCCESS;
  302. }
  303. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。