You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

next_iteration_pass.cc 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/next_iteration_pass.h"
  17. #include "common/ge/ge_util.h"
  18. #include "graph/common/omg_util.h"
  19. namespace ge {
  20. Status NextIterationPass::Run(ComputeGraphPtr graph) {
  21. GELOGD("NextIterationPass Enter");
  22. /// Enter-----------+
  23. /// +-> Merge -> Switch <- LoopCond <- Cond
  24. /// NextIteration---+
  25. for (auto &node : graph->GetDirectNode()) {
  26. const std::string type = node->GetType();
  27. if ((type != ENTER) && (type != REFENTER)) {
  28. continue;
  29. }
  30. if (GroupEnterNode(node) != SUCCESS) {
  31. GELOGE(INTERNAL_ERROR, "Group enter_node %s failed.", node->GetName().c_str());
  32. return INTERNAL_ERROR;
  33. }
  34. }
  35. if (GroupWithNoBatch(graph) != SUCCESS) {
  36. GELOGE(INTERNAL_ERROR, "Group enter_nodes failed without batch_label attr.");
  37. return INTERNAL_ERROR;
  38. }
  39. if (FindWhileGroups() != SUCCESS) {
  40. GELOGE(INTERNAL_ERROR, "Find while groups failed.");
  41. return INTERNAL_ERROR;
  42. }
  43. if (!VerifyWhileGroup()) {
  44. GELOGE(INTERNAL_ERROR, "Verify while groups failed.");
  45. return INTERNAL_ERROR;
  46. }
  47. if (HandleWhileGroup(graph) != SUCCESS) {
  48. GELOGE(FAILED, "Handle while groups failed.");
  49. return FAILED;
  50. }
  51. GELOGD("NextIterationPass Leave");
  52. return SUCCESS;
  53. }
  54. ///
  55. /// @brief Group Enter node
  56. /// @param [in] enter_node
  57. /// @return Status
  58. ///
  59. Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) {
  60. OpDescPtr enter_desc = enter_node->GetOpDesc();
  61. GE_CHECK_NOTNULL(enter_desc);
  62. std::string frame_name;
  63. if (!ge::AttrUtils::GetStr(enter_desc, ENTER_ATTR_FRAME_NAME, frame_name) || frame_name.empty()) {
  64. GELOGE(FAILED, "Get attr ENTER_ATTR_FRAME_NAME failed, node: %s", enter_desc->GetName().c_str());
  65. return FAILED;
  66. }
  67. std::string batch_label;
  68. (void)ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label);
  69. if (batch_label.empty()) {
  70. auto frame_iter = frame_enter_map_.find(frame_name);
  71. if (frame_iter == frame_enter_map_.end()) {
  72. std::vector<NodePtr> enter_nodes;
  73. enter_nodes.emplace_back(enter_node);
  74. frame_enter_map_[frame_name] = enter_nodes;
  75. } else {
  76. frame_iter->second.emplace_back(enter_node);
  77. }
  78. return SUCCESS;
  79. }
  80. auto group_iter = loop_group_map_.find(frame_name);
  81. if (group_iter == loop_group_map_.end()) {
  82. LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>();
  83. if (loop_group == nullptr) {
  84. GELOGE(FAILED, "MakeShared for LoopCondGroup failed.");
  85. return FAILED;
  86. }
  87. loop_group->enter_nodes.emplace_back(enter_node);
  88. loop_group_map_[frame_name][batch_label] = loop_group;
  89. } else {
  90. auto batch_iter = group_iter->second.find(batch_label);
  91. if (batch_iter == group_iter->second.end()) {
  92. LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>();
  93. if (loop_group == nullptr) {
  94. GELOGE(FAILED, "MakeShared for LoopCondGroup failed.");
  95. return FAILED;
  96. }
  97. loop_group->enter_nodes.emplace_back(enter_node);
  98. group_iter->second[batch_label] = loop_group;
  99. } else {
  100. batch_iter->second->enter_nodes.emplace_back(enter_node);
  101. }
  102. }
  103. return SUCCESS;
  104. }
  105. ///
  106. /// @brief Group Enter nodes without batch_label attr
  107. /// @param [in] compute_graph
  108. /// @return Status
  109. ///
  110. Status NextIterationPass::GroupWithNoBatch(const ComputeGraphPtr &graph) {
  111. if (frame_enter_map_.empty()) {
  112. GELOGI("All enter nodes in graph %s has batch_label attr.", graph->GetName().c_str());
  113. return SUCCESS;
  114. }
  115. for (const auto &item : frame_enter_map_) {
  116. const std::string &frame_name = item.first;
  117. auto iter = loop_group_map_.find(frame_name);
  118. if (iter == loop_group_map_.end()) {
  119. LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>();
  120. if (loop_group == nullptr) {
  121. GELOGE(FAILED, "MakeShared for LoopCondGroup failed.");
  122. return FAILED;
  123. }
  124. loop_group->enter_nodes = item.second;
  125. loop_group_map_[frame_name][""] = loop_group;
  126. } else {
  127. for (auto &batch_item : iter->second) {
  128. for (const auto &enter_node : item.second) {
  129. batch_item.second->enter_nodes.emplace_back(enter_node);
  130. }
  131. }
  132. }
  133. }
  134. return SUCCESS;
  135. }
  136. ///
  137. /// @brief Find while groups
  138. /// @return Status
  139. ///
  140. Status NextIterationPass::FindWhileGroups() {
  141. for (const auto &loop_group_iter : loop_group_map_) {
  142. const std::string &frame_name = loop_group_iter.first;
  143. for (const auto &batch_iter : loop_group_iter.second) {
  144. const std::string &batch_label = batch_iter.first;
  145. for (const auto &enter_node : batch_iter.second->enter_nodes) {
  146. for (const auto &out_node : enter_node->GetOutAllNodes()) {
  147. GELOGI("Find while_group for enter_node %s, frame_name:%s, batch_label:%s.", enter_node->GetName().c_str(),
  148. frame_name.c_str(), batch_label.c_str());
  149. if ((out_node->GetType() != MERGE) && (out_node->GetType() != REFMERGE)) {
  150. continue;
  151. }
  152. std::string tmp_label;
  153. GE_CHECK_NOTNULL(out_node->GetOpDesc());
  154. (void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label);
  155. bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label));
  156. if (need_skip) {
  157. continue;
  158. }
  159. NodePtr next_node = nullptr;
  160. if (FindTargetNode(out_node, NEXTITERATION, true, batch_label, next_node) != SUCCESS) {
  161. GELOGE(INTERNAL_ERROR,
  162. "Get NextIteration node failed: inputs of Merge should be Enter/NextIteration, current_Merge=%s",
  163. out_node->GetName().c_str());
  164. return INTERNAL_ERROR;
  165. }
  166. batch_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node));
  167. NodePtr switch_node = nullptr;
  168. if (FindTargetNode(out_node, SWITCH, false, batch_label, switch_node) != SUCCESS) {
  169. GELOGE(INTERNAL_ERROR, "Get Switch node failed: output of Merge should be Switch, current_Merge=%s",
  170. out_node->GetName().c_str());
  171. return INTERNAL_ERROR;
  172. }
  173. if (switch_node == nullptr) {
  174. continue;
  175. }
  176. NodePtr loop_cond = nullptr;
  177. if (FindTargetNode(switch_node, LOOPCOND, true, batch_label, loop_cond) != SUCCESS) {
  178. GELOGE(INTERNAL_ERROR,
  179. "Get LoopCond node failed: pred input of Switch should be LoopCond, current_Switch=%s",
  180. switch_node->GetName().c_str());
  181. return INTERNAL_ERROR;
  182. }
  183. if (batch_iter.second->loop_cond == nullptr) {
  184. batch_iter.second->loop_cond = loop_cond;
  185. } else if (batch_iter.second->loop_cond != loop_cond) {
  186. GELOGE(FAILED, "Multi LoopCond nodes exist.");
  187. return FAILED;
  188. }
  189. }
  190. }
  191. }
  192. }
  193. return SUCCESS;
  194. }
  195. ///
  196. /// @brief Verify if valid
  197. /// @return bool
  198. ///
  199. bool NextIterationPass::VerifyWhileGroup() {
  200. // map<frame_name, LoopCondGroup>
  201. for (const auto &loop_group_iter : loop_group_map_) {
  202. const std::string &frame_name = loop_group_iter.first;
  203. if (frame_name.empty()) {
  204. GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty.");
  205. return false;
  206. }
  207. for (const auto &batch_iter : loop_group_iter.second) {
  208. if (batch_iter.second->loop_cond == nullptr) {
  209. GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str());
  210. return false;
  211. }
  212. for (const auto &pair_iter : batch_iter.second->merge_next_pairs) {
  213. if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) {
  214. GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.",
  215. frame_name.c_str());
  216. return false;
  217. }
  218. }
  219. }
  220. }
  221. return true;
  222. }
  223. ///
  224. /// @brief Handle while group
  225. /// @param [in] graph
  226. /// @return Status
  227. ///
  228. Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
  229. for (const auto &loop_cond_iter : loop_group_map_) {
  230. for (const auto &batch_iter : loop_cond_iter.second) {
  231. const std::string &cond_name = batch_iter.second->loop_cond->GetName();
  232. GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str());
  233. // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge
  234. NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE);
  235. NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE);
  236. if ((enter_active == nullptr) || (next_active == nullptr)) {
  237. GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str());
  238. return INTERNAL_ERROR;
  239. }
  240. for (const auto &enter_node : batch_iter.second->enter_nodes) {
  241. // Enter --> Active
  242. if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) !=
  243. GRAPH_SUCCESS) {
  244. GELOGE(INTERNAL_ERROR, "Add control edge failed.");
  245. return INTERNAL_ERROR;
  246. }
  247. }
  248. for (const auto &pair : batch_iter.second->merge_next_pairs) {
  249. NodePtr merge_node = pair.first;
  250. NodePtr next_node = pair.second;
  251. // Active --> Merge
  252. if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) !=
  253. GRAPH_SUCCESS) {
  254. GELOGE(INTERNAL_ERROR, "Add control edge failed.");
  255. return INTERNAL_ERROR;
  256. }
  257. // NextIteration --> Active
  258. if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) {
  259. GELOGE(INTERNAL_ERROR, "Add control edge failed.");
  260. return INTERNAL_ERROR;
  261. }
  262. // break link between NextIteration and Merge
  263. if (BreakNextIteration(next_node, merge_node) != SUCCESS) {
  264. GELOGE(INTERNAL_ERROR, "Break NextIteration failed");
  265. return INTERNAL_ERROR;
  266. }
  267. }
  268. if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) ||
  269. (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) {
  270. GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed.");
  271. return INTERNAL_ERROR;
  272. }
  273. }
  274. }
  275. return SUCCESS;
  276. }
  277. ///
  278. /// @brief Create Active Node
  279. /// @param [in] graph
  280. /// @param [in] name
  281. /// @return ge::NodePtr
  282. ///
  283. NodePtr NextIterationPass::CreateActiveNode(ComputeGraphPtr &graph, const std::string &name) {
  284. OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMACTIVE);
  285. if (op_desc == nullptr) {
  286. return nullptr;
  287. }
  288. GELOGI("Create StreamActive op:%s.", op_desc->GetName().c_str());
  289. NodePtr active_node = graph->AddNode(op_desc);
  290. if (active_node == nullptr) {
  291. GELOGE(INTERNAL_ERROR, "Create node[%s] failed.", name.c_str());
  292. return nullptr;
  293. }
  294. if (SetSwitchBranchNodeLabel(active_node, name) != SUCCESS) {
  295. GELOGE(INTERNAL_ERROR, "Set attr SWITCH_BRANCH_NODE_LABEL for node: %s failed.", active_node->GetName().c_str());
  296. return nullptr;
  297. }
  298. return active_node;
  299. }
  300. ///
  301. /// @brief Break NextIteration Link & add name to merge attr
  302. /// @param [in] next_node
  303. /// @param [in] merge_node
  304. /// @return Status
  305. ///
  306. Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr &merge_node) {
  307. if ((merge_node == nullptr) || (next_node == nullptr)) {
  308. GELOGE(PARAM_INVALID, "merge node or next node is null.");
  309. return PARAM_INVALID;
  310. }
  311. for (const auto &in_anchor : merge_node->GetAllInDataAnchors()) {
  312. OutDataAnchorPtr out_anchor = in_anchor->GetPeerOutAnchor();
  313. if ((out_anchor == nullptr) || (out_anchor->GetOwnerNode() != next_node)) {
  314. continue;
  315. }
  316. if (GraphUtils::RemoveEdge(out_anchor, in_anchor) != SUCCESS) {
  317. GELOGE(INTERNAL_ERROR, "Remove data edge failed, %s->%s.", next_node->GetName().c_str(),
  318. merge_node->GetName().c_str());
  319. return INTERNAL_ERROR;
  320. }
  321. if (SetNextIteration(merge_node, next_node->GetName()) != SUCCESS) {
  322. GELOGE(INTERNAL_ERROR, "Set attr NEXT_ITERATION for node %s failed.", merge_node->GetName().c_str());
  323. return INTERNAL_ERROR;
  324. }
  325. }
  326. return SUCCESS;
  327. }
  328. ///
  329. /// @brief find target node
  330. /// @param [in] node
  331. /// @param [in] target_type
  332. /// @param [in] is_input
  333. /// @param [in] batch_label
  334. /// @param [out] target_node
  335. /// @return Status
  336. ///
  337. Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input,
  338. const std::string &batch_label, NodePtr &target_node) {
  339. if (node == nullptr) {
  340. GELOGE(PARAM_INVALID, "node is null.");
  341. return PARAM_INVALID;
  342. }
  343. std::vector<NodePtr> nodes;
  344. if (is_input) {
  345. for (const auto &tmp_node : node->GetInDataNodes()) {
  346. nodes.emplace_back(tmp_node);
  347. }
  348. } else {
  349. for (const auto &tmp_node : node->GetOutDataNodes()) {
  350. nodes.emplace_back(tmp_node);
  351. }
  352. }
  353. for (const auto &tmp_node : nodes) {
  354. std::string tmp_label;
  355. (void)AttrUtils::GetStr(tmp_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label);
  356. bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label));
  357. if (need_skip) {
  358. continue;
  359. }
  360. const std::string type = tmp_node->GetType();
  361. if ((target_type == LOOPCOND) && (type == target_type)) {
  362. target_node = tmp_node;
  363. break;
  364. } else if ((type == target_type) || (type == "Ref" + target_type)) {
  365. target_node = tmp_node;
  366. break;
  367. }
  368. }
  369. if ((target_type != SWITCH) && (target_node == nullptr)) {
  370. GELOGE(INTERNAL_ERROR, "Find node %s failed.", target_type.c_str());
  371. return INTERNAL_ERROR;
  372. }
  373. return SUCCESS;
  374. }
  375. ///
  376. /// @brief Clear Status, used for subgraph pass
  377. /// @return SUCCESS
  378. ///
  379. Status NextIterationPass::ClearStatus() {
  380. frame_enter_map_.clear();
  381. loop_group_map_.clear();
  382. return SUCCESS;
  383. }
  384. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示