You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parallel_group_pass.cc 15 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/parallel_group_pass.h"
  17. #include <queue>
  18. #include "framework/common/debug/ge_log.h"
  19. #include "common/ge/ge_util.h"
  20. #include "framework/common/ge_inner_error_codes.h"
  21. #include "graph/debug/ge_attr_define.h"
  22. #include "graph/utils/graph_utils.h"
  23. #include "graph/utils/node_utils.h"
  24. namespace ge {
  namespace {
  // Upper bound on the subgraph nesting depth that ProcessGraphGroupNodes will recurse into.
  constexpr int32_t kMaxRecursionDepth = 10;
  // Value of ATTR_NAME_STREAM_SWITCH_TYPE that IsWhileStreamSwitch compares against.
  constexpr int64_t kLoopType = 1;
  }  // namespace
  29. Status ParallelGroupPass::Run(ComputeGraphPtr graph) {
  30. GELOGD("ParallelGroupPass running");
  31. if (graph == nullptr) {
  32. GELOGE(PARAM_INVALID, "[Check][Graph]Input param graph is null, skip ParallelGroupPass.");
  33. REPORT_INNER_ERROR("E19999", "Input param graph is null, skip ParallelGroupPass.");
  34. return PARAM_INVALID;
  35. }
  36. if (graph->GetParentGraph() != nullptr) {
  37. GELOGD("Current graph %s is a subgraph, this pass only support root graph.",
  38. graph->GetName().c_str());
  39. return SUCCESS;
  40. }
  41. if (graph->TopologicalSorting() != GRAPH_SUCCESS) {
  42. GELOGE(FAILED, "[TopoSort][Graph]Graph:%s topological sort failed.", graph->GetName().c_str());
  43. REPORT_CALL_ERROR("E19999", "Graph:%s topological sort failed when ParallelGroupPass run.",
  44. graph->GetName().c_str());
  45. return FAILED;
  46. }
  47. std::unordered_set<std::string> parallel_groups;
  48. int depth = 0;
  49. if (ProcessGraphGroupNodes(graph, depth, parallel_groups) != SUCCESS) {
  50. GELOGE(INTERNAL_ERROR, "[Process][Graph]Process group nodes of graph %s failed.", graph->GetName().c_str());
  51. return INTERNAL_ERROR;
  52. }
  53. if (graph->TopologicalSorting() != GRAPH_SUCCESS) {
  54. GELOGE(FAILED, "[TopoSort][Graph]Graph:%s topological sort failed.", graph->GetName().c_str());
  55. REPORT_CALL_ERROR("E19999", "Graph:%s topological sort failed when ParallelGroupPass run.",
  56. graph->GetName().c_str());
  57. return FAILED;
  58. }
  59. return SUCCESS;
  60. }
  61. Status ParallelGroupPass::ProcessGraphGroupNodes(ComputeGraphPtr graph, int32_t depth,
  62. std::unordered_set<std::string> &parallel_groups) {
  63. if (depth >= kMaxRecursionDepth) {
  64. GELOGE(FAILED, "[Process][SubGraph]There are too much subgraphs:%d > %d(max subgraphs)", depth, kMaxRecursionDepth);
  65. REPORT_INNER_ERROR("E19999", "There are too much subgraphs:%d > %d(max subgraphs)", depth, kMaxRecursionDepth);
  66. return FAILED;
  67. }
  68. std::map<std::string, vector<NodePtr>> group_nodes;
  69. auto candidates = graph->GetDirectNode();
  70. auto root_graph = GraphUtils::FindRootGraph(graph);
  71. for (const auto &node : candidates) {
  72. OpDescPtr op_desc = node->GetOpDesc();
  73. if (op_desc == nullptr) {
  74. continue;
  75. }
  76. std::string group_name;
  77. if (AttrUtils::GetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, group_name)) {
  78. group_nodes[group_name].push_back(node);
  79. parallel_groups.insert(group_name);
  80. GELOGD("Find group node:%s, group_name:%s", node->GetName().c_str(), group_name.c_str());
  81. }
  82. const auto &subgraph_name = op_desc->GetSubgraphInstanceNames();
  83. GE_CHECK_NOTNULL(root_graph);
  84. for (auto name_iter = subgraph_name.rbegin(); name_iter != subgraph_name.rend(); ++name_iter) {
  85. const auto &sub_graph = root_graph->GetSubgraph(*name_iter);
  86. GE_CHECK_NOTNULL(sub_graph);
  87. // if the pass add control edge for known and unknown graph, then the known graph will become unknown graph
  88. // the order between known and unknown graph is guaranteed by dynamic shape executor
  89. // so the parallel group pass do nothing for unknown graph
  90. if (sub_graph->GetGraphUnknownFlag()) {
  91. continue;
  92. }
  93. std::unordered_set<std::string> sub_parallel_groups;
  94. auto ret = ProcessGraphGroupNodes(sub_graph, depth + 1, sub_parallel_groups);
  95. if (ret != SUCCESS) {
  96. GELOGE(FAILED, "[Process][SubGraph]Process sub graph %s failed.", sub_graph->GetName().c_str());
  97. return FAILED;
  98. }
  99. for (const auto &sub_parallel_group : sub_parallel_groups) {
  100. parallel_groups.insert(sub_parallel_group);
  101. group_nodes[sub_parallel_group].emplace_back(node);
  102. }
  103. }
  104. }
  105. std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> node_2_switch_merge;
  106. if (ProcessGroupNodeInSwitch(graph, node_2_switch_merge) != SUCCESS) {
  107. GELOGE(FAILED, "[Process][Node]Process group node in switch failed, graph:%s.", graph->GetName().c_str());
  108. return FAILED;
  109. }
  110. for (const auto &itr : group_nodes) {
  111. const auto &nodes = itr.second;
  112. if (nodes.empty()) {
  113. continue;
  114. }
  115. NodePtr pre_node = nodes[0];
  116. NodePtr cur_node = nullptr;
  117. for (std::size_t i = 1; i < nodes.size(); i++) {
  118. cur_node = nodes[i];
  119. GELOGD("Original add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
  120. if (ReplaceWithSwitchAndMerge(pre_node, cur_node, node_2_switch_merge) != SUCCESS) {
  121. GELOGE(FAILED, "[Replace][Node]Replace switch and merges for nodes: %s and %s failed.",
  122. pre_node->GetName().c_str(), cur_node->GetName().c_str());
  123. return FAILED;
  124. }
  125. pre_node = cur_node;
  126. }
  127. }
  128. return SUCCESS;
  129. }
  130. Status ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) {
  131. if (pre_node == cur_node) {
  132. GELOGD("Pre_node and cur_node are same, no need add anchor");
  133. return SUCCESS;
  134. }
  135. auto in_nodes = cur_node->GetInAllNodes();
  136. for (const auto &node : in_nodes) {
  137. if (pre_node == node) {
  138. GELOGD("Node:%s and %s already linked", pre_node->GetName().c_str(),
  139. cur_node->GetName().c_str());
  140. return SUCCESS;
  141. }
  142. }
  143. GELOGD("Finally add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
  144. return GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), cur_node->GetInControlAnchor());
  145. }
  146. Status ParallelGroupPass::ProcessGroupNodeInSwitch(ComputeGraphPtr graph,
  147. std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
  148. std::string type;
  149. auto direct_nodes = graph->GetDirectNode();
  150. for (const auto &node : direct_nodes) {
  151. type = node->GetType();
  152. if (type != STREAMSWITCH) {
  153. continue;
  154. }
  155. if (IsBigSmallLoopStreamSwitch(node->GetOpDesc()) ||
  156. IsWhileStreamSwitch(node->GetOpDesc())) {
  157. continue;
  158. }
  159. std::vector<NodePtr> merge_nodes;
  160. std::set<NodePtr> group_nodes;
  161. std::set<std::string> stream_labels;
  162. FindGroupNodeAndMerge(node, group_nodes, merge_nodes, stream_labels);
  163. if (merge_nodes.empty() || (!group_nodes.empty() && stream_labels.size() > 1)) {
  164. GELOGE(FAILED, "[Process][Node]Cannot find merge node or exist switch nestification, switch node:%s,"
  165. "merge_vec size:%zu, stream_labels size:%zu, graph:%s.", node->GetName().c_str(),
  166. merge_nodes.size(), stream_labels.size(), graph->GetName().c_str());
  167. REPORT_INNER_ERROR("E19999", "Cannot find merge node or exist switch nest, switch node:%s,"
  168. "merge_vec size: %zu, stream_labels size: %zu, graph:%s.", node->GetName().c_str(),
  169. merge_nodes.size(), stream_labels.size(), graph->GetName().c_str());
  170. return FAILED;
  171. }
  172. std::sort(merge_nodes.begin(), merge_nodes.end(),
  173. [] (NodePtr a, NodePtr b) -> bool {
  174. return (a->GetOpDesc()->GetId() < b->GetOpDesc()->GetId());
  175. });
  176. NodePtr cast_node = NodeUtils::GetInDataNodeByIndex(*node, 0);
  177. GE_CHECK_NOTNULL(cast_node);
  178. if (MappingNodeToSwitchAndMerge(group_nodes, merge_nodes, cast_node, node, node_2_switch_merge) != SUCCESS) {
  179. GELOGE(FAILED, "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", graph->GetName().c_str());
  180. REPORT_CALL_ERROR("E19999", "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.",
  181. graph->GetName().c_str());
  182. return FAILED;
  183. }
  184. }
  185. return SUCCESS;
  186. }
  187. void ParallelGroupPass::FindGroupNodeAndMerge(NodePtr stream_switch_node, std::set<NodePtr> &group_nodes,
  188. std::vector<NodePtr> &merge_nodes, std::set<std::string> &stream_labels) {
  189. std::string type;
  190. std::deque<NodePtr> candidates;
  191. std::set<NodePtr> visited;
  192. candidates.push_back(stream_switch_node);
  193. while (!candidates.empty()) {
  194. NodePtr tmp_node = candidates.front();
  195. candidates.pop_front();
  196. for (const auto &out_node : tmp_node->GetOutAllNodes()) {
  197. type = out_node->GetType();
  198. if (type == STREAMMERGE) {
  199. merge_nodes.emplace_back(out_node);
  200. continue;
  201. }
  202. const auto &op = out_node->GetOpDesc();
  203. if (op != nullptr && op->HasAttr(ATTR_NAME_PARALLEL_GROUP)) {
  204. group_nodes.emplace(out_node);
  205. }
  206. if (visited.count(out_node) > 0) {
  207. continue;
  208. }
  209. candidates.push_back(out_node);
  210. visited.insert(out_node);
  211. std::string stream_label;
  212. if (ge::AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) {
  213. stream_labels.insert(stream_label);
  214. }
  215. }
  216. }
  217. }
  218. Status ParallelGroupPass::MappingNodeToSwitchAndMerge(const std::set<NodePtr> &group_nodes,
  219. const std::vector<NodePtr> &merge_nodes, const NodePtr &cast_node, const NodePtr &switch_node,
  220. std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
  221. for (const auto &group_node : group_nodes) {
  222. auto itr = node_2_switch_merge.find(group_node);
  223. if (itr != node_2_switch_merge.end()) {
  224. auto &tmp = itr->second;
  225. auto &switch_set = tmp.first;
  226. const auto &merge_node = tmp.second;
  227. GELOGD("Find group node: %s in switch %s and merge %s.",
  228. group_node->GetName().c_str(), switch_node->GetName().c_str(), merge_node->GetName().c_str());
  229. if (merge_node != merge_nodes.back()) {
  230. GELOGE(FAILED, "[Mapping][Node]Has two different merge nodes: %s and %s, graph's structure is invalid",
  231. merge_node->GetName().c_str(), merge_nodes.back()->GetName().c_str());
  232. REPORT_INNER_ERROR("E19999", "Has two different merge nodes: %s and %s,"
  233. "graph's structure is invalid",
  234. merge_node->GetName().c_str(), merge_nodes.back()->GetName().c_str());
  235. return FAILED;
  236. }
  237. switch_set.insert(cast_node);
  238. } else {
  239. node_2_switch_merge.emplace(group_node,
  240. std::make_pair(std::set<NodePtr>{cast_node}, merge_nodes.back()));
  241. }
  242. }
  243. return SUCCESS;
  244. }
  245. Status ParallelGroupPass::ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cur_node,
  246. const std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
  247. auto pre_itr = node_2_switch_merge.find(pre_node);
  248. auto cur_itr = node_2_switch_merge.find(cur_node);
  249. if (pre_itr != node_2_switch_merge.end()) {
  250. if (cur_itr != node_2_switch_merge.end()) {
  251. const auto &pre_set = pre_itr->second.first;
  252. const auto &cur_set = cur_itr->second.first;
  253. if (!HasSameSwitch(pre_set, cur_set)) {
  254. pre_node = pre_itr->second.second;
  255. for (const auto &switch_node : cur_itr->second.first) {
  256. if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) {
  257. GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
  258. pre_node->GetName().c_str(), switch_node->GetName().c_str());
  259. REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
  260. pre_node->GetName().c_str(), switch_node->GetName().c_str());
  261. return FAILED;
  262. }
  263. }
  264. }
  265. return SUCCESS;
  266. } else {
  267. pre_node = pre_itr->second.second;
  268. return AddCtrlEdge(pre_node, cur_node);
  269. }
  270. } else {
  271. if (cur_itr != node_2_switch_merge.end()) {
  272. for (const auto &switch_node : cur_itr->second.first) {
  273. int64_t pre_id = pre_node->GetOpDesc()->GetId();
  274. int64_t switch_id = switch_node->GetOpDesc()->GetId();
  275. NodePtr first_node = pre_node;
  276. NodePtr second_node = switch_node;
  277. if (pre_id > switch_id && IsIndirectConnect(switch_node, pre_node)) {
  278. // avoid ring, merge->pre_node
  279. first_node = cur_itr->second.second;
  280. second_node = pre_node;
  281. }
  282. if (AddCtrlEdge(first_node, second_node) != SUCCESS) {
  283. GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
  284. first_node->GetName().c_str(), second_node->GetName().c_str());
  285. REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
  286. first_node->GetName().c_str(), second_node->GetName().c_str());
  287. return FAILED;
  288. }
  289. }
  290. } else {
  291. return AddCtrlEdge(pre_node, cur_node);
  292. }
  293. }
  294. return SUCCESS;
  295. }
  296. bool ParallelGroupPass::HasSameSwitch(const std::set<NodePtr> &switch_set1, const std::set<NodePtr> &switch_set2) {
  297. for (const auto &node1 : switch_set1) {
  298. auto itr = switch_set2.find(node1);
  299. if (itr != switch_set2.end()) {
  300. return true;
  301. }
  302. }
  303. return false;
  304. }
  305. bool ParallelGroupPass::IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc) {
  306. return !AttrUtils::HasAttr(switch_op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG);
  307. }
  308. bool ParallelGroupPass::IsWhileStreamSwitch(OpDescPtr switch_op_desc) {
  309. int64_t stream_switch_type = -1;
  310. return (AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type) &&
  311. stream_switch_type == kLoopType);
  312. }
  313. bool ParallelGroupPass::IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b) {
  314. if (node_a == nullptr || node_b == nullptr) {
  315. GELOGW("node_a or node_b is nullptr.");
  316. return false;
  317. }
  318. int64_t end_id = node_b->GetOpDesc()->GetId();
  319. std::queue<NodePtr> nodes;
  320. nodes.push(node_a);
  321. while (!nodes.empty()) {
  322. NodePtr tmp_node = nodes.front();
  323. nodes.pop();
  324. if (tmp_node == nullptr || tmp_node->GetOpDesc() == nullptr ||
  325. tmp_node->GetOpDesc()->GetId() > end_id) {
  326. continue;
  327. }
  328. if (tmp_node == node_b) {
  329. return true;
  330. }
  331. for (const auto &out_node : tmp_node->GetOutAllNodes()) {
  332. nodes.push(out_node);
  333. }
  334. }
  335. return false;
  336. }
  337. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。