You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base_pass.cc 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/base_pass.h"
  17. #include <queue>
  18. #include <unordered_set>
  19. #include "common/debug/log.h"
  20. #include "graph/utils/graph_utils.h"
  21. namespace ge {
  22. namespace {
  // Upper bound on the number of re-pass rounds performed for one graph.
  constexpr int kMaxRePassTimes{10000};
  // Nodes with more input nodes than this are deferred to the "last" node set.
  constexpr size_t kMaxOneInNodes{1000};
  // Every nesting level costs roughly 0.3k of stack, so the subgraph recursion
  // depth is capped; the recursion should be turned into a loop later.
  constexpr int kMaxRecursiveDepth{20};
  27. void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph,
  28. GEPass::GraphLevelState &g_state) {
  29. for (auto &node : graph->GetDirectNode()) {
  30. if (node == nullptr) {
  31. continue;
  32. }
  33. size_t in_nums = node->GetInNodes().size();
  34. if (in_nums == 0) {
  35. g_state.AddNodeToQueueIfNotSeen(node);
  36. } else if (in_nums > kMaxOneInNodes) {
  37. g_state.nodes_last.insert(node);
  38. }
  39. }
  40. }
  41. bool AnyNodesIn(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_set) {
  42. return std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) {
  43. return nodes_set.count(n) > 0;
  44. });
  45. }
  46. bool IsNodeReadyToQueue(const NodePtr &node, GEPass::GraphLevelState &g_state) {
  47. if (node == nullptr) {
  48. GELOGW("node is null");
  49. return false;
  50. }
  51. if (g_state.nodes_deleted.count(node) > 0) {
  52. GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str());
  53. return false;
  54. }
  55. if (g_state.nodes_last.count(node) != 0) {
  56. return false;
  57. }
  58. // all in_node seen && all in_node not suspend
  59. if (!node->IsAllInNodesSeen(g_state.nodes_seen)) {
  60. return false;
  61. }
  62. if (g_state.nodes_suspend.count(node) > 0) {
  63. GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.",
  64. node->GetName().c_str());
  65. return false;
  66. }
  67. if (AnyNodesIn(node->GetInAllNodes(), g_state.nodes_suspend)) {
  68. GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.",
  69. node->GetName().c_str());
  70. return false;
  71. }
  72. return true;
  73. }
  74. void AddNextIterNodes(const NodePtr &cur_node,
  75. std::unordered_set<NodePtr> &out_nodes_before_pass,
  76. GEPass::GraphLevelState &g_state) {
  77. for (auto &node : cur_node->GetOutNodes()) {
  78. if (node == nullptr) {
  79. continue;
  80. }
  81. if(out_nodes_before_pass.erase(node) == 0) {
  82. // after pass node, new output node come up
  83. GELOGI("New output node %s come up after pass %s.",
  84. node->GetName().c_str(), cur_node->GetName().c_str());
  85. }
  86. // all in_node seen && all in_node not suspend
  87. if (IsNodeReadyToQueue(node, g_state)) {
  88. g_state.AddNodeToQueueIfNotSeen(node);
  89. }
  90. }
  91. //
  92. for (const auto &node : out_nodes_before_pass) {
  93. // A-->B-->C if B was
  94. // unlink edge may happend, add these node to queue if needed
  95. if (node->GetInAllNodes().empty() && IsNodeReadyToQueue(node, g_state)) {
  96. GELOGI("Node %s may lost from cur node, add to queue if not seen.",
  97. node->GetName().c_str(), cur_node->GetName().c_str());
  98. g_state.AddNodeToQueueIfNotSeen(node);
  99. }
  100. }
  101. }
  102. void AddImmediateRepassNodesToQueue(NodePtr &cur_node,
  103. std::unordered_map<NodePtr, std::string> re_pass_imm_nodes_to_pass_names,
  104. GEPass::GraphLevelState &g_state) {
  105. for (const auto &node_2_pass_names : re_pass_imm_nodes_to_pass_names) {
  106. auto imme_repass_node = node_2_pass_names.first;
  107. if (imme_repass_node == nullptr) {
  108. GELOGW("Found null immediately re-pass node when executing pass %s on node %s type %s",
  109. node_2_pass_names.second.c_str(),
  110. cur_node->GetName().c_str(), cur_node->GetType().c_str());
  111. continue;
  112. }
  113. if (g_state.nodes_passed.count(imme_repass_node) > 0) {
  114. GELOGD("The node %s specified by pass %s has been passed, it will repass immediately",
  115. imme_repass_node->GetName().c_str(), node_2_pass_names.second.c_str());
  116. g_state.AddNodeToQueueFront(imme_repass_node);
  117. continue;
  118. }
  119. GELOGW("The node %s specified by pass %s has un-passed, it will not repass immediately",
  120. node_2_pass_names.first->GetName().c_str(), node_2_pass_names.second.c_str());
  121. }
  122. }
  123. void AddLastNodesToQueue(GEPass::GraphLevelState &g_state) {
  124. for (auto &node : g_state.nodes_last) {
  125. if (node->IsAllInNodesSeen(g_state.nodes_seen)) {
  126. g_state.AddNodeToQueueIfNotSeen(node);
  127. }
  128. }
  129. g_state.nodes_last.clear();
  130. }
  131. void AddResumeNodesToQueue(const std::unordered_map<NodePtr, std::string> resume_node_2_pass_names,
  132. GEPass::GraphLevelState &g_state) {
  133. // Now base pass doesnt record the order of suspend & resume, so we dont know which one come first in a node pass.
  134. // Here if one node pass suspend and resume a node ,consider it resume that node.
  135. // Better way to record the order, and here suspend or resume in order.
  136. for (const auto &node_2_pass_names : resume_node_2_pass_names) {
  137. auto node = node_2_pass_names.first;
  138. if (g_state.nodes_suspend.erase(node) > 0) {
  139. if (g_state.nodes_seen.count(node.get()) > 0 || node->IsAllInNodesSeen(g_state.nodes_seen)) {
  140. g_state.nodes.push_back(node);
  141. GELOGD("Node %s has been resumed by pass %s, and add to pass queue",
  142. node->GetName().c_str(), node_2_pass_names.second.c_str());
  143. }
  144. }
  145. }
  146. }
  147. void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass,
  148. std::unordered_set<Node *> &nodes_seen, const std::vector<NodePtr> &nodes_to_re_pass,
  149. GEPass::RepassLevelState &rp_state) {
  150. for (const auto &node_to_re_pass : nodes_to_re_pass) {
  151. if (node_to_re_pass == nullptr) {
  152. GELOGW("Found null re-pass node when executing %s on node %s type %s", name_to_pass.first.c_str(),
  153. node->GetName().c_str(), node->GetType().c_str());
  154. continue;
  155. }
  156. if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) {
  157. if (rp_state.AddNodeToRepass(node_to_re_pass)) {
  158. GELOGD("The node %s will be re-pass.", node_to_re_pass->GetName().c_str());
  159. continue;
  160. }
  161. GELOGD("Node %s has been added to repass queue, no need to add again.", node_to_re_pass->GetName().c_str());
  162. } else {
  163. GELOGD("The node %s are not all seen, don't set repass this time", node_to_re_pass->GetName().c_str());
  164. }
  165. }
  166. }
  167. void SetFlagOption(NodePassOption option, NamesToPass names_to_pass) {
  168. for (auto &name_to_pass : names_to_pass) {
  169. name_to_pass.second->SetOption(option, "");
  170. }
  171. }
  172. void ClearOption(NamesToPass names_to_pass) {
  173. for (auto &name_to_pass : names_to_pass) {
  174. name_to_pass.second->ClearOptions();
  175. }
  176. }
  177. } // namespace
  178. Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map,
  179. bool is_repass_io_immediately) {
  180. if (node == nullptr) {
  181. REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid.");
  182. GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
  183. return FAILED;
  184. }
  185. GELOGI("Prepare to isolate and delete node, name:%s, type:%s.", node->GetName().c_str(),
  186. node->GetType().c_str());
  187. ComputeGraphPtr graph = node->GetOwnerComputeGraph();
  188. if (graph == nullptr) {
  189. REPORT_INNER_ERROR("E19999", "The owner graph of node:%s must not be null.", node->GetName().c_str());
  190. GELOGE(FAILED, "[Get][OwnerComputeGraph] failed, The owner graph of node:%s must not be null.",
  191. node->GetName().c_str());
  192. return FAILED;
  193. }
  194. is_repass_io_immediately ? AddImmediateRePassNodesWithInOut(node) : AddRePassNodesWithInOut(node);
  195. if (GraphUtils::IsolateNode(node, io_map) != GRAPH_SUCCESS) {
  196. REPORT_CALL_ERROR("E19999", "Isolate Node:%s failed", node->GetName().c_str());
  197. GELOGE(FAILED, "[Isolate][Node] %s failed.", node->GetName().c_str());
  198. return FAILED;
  199. }
  200. if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != SUCCESS) {
  201. REPORT_CALL_ERROR("E19999", "call RemoveNodeWithoutRelink for node:%s failed.", node->GetName().c_str());
  202. GELOGE(FAILED, "[Call][RemoveNodeWithoutRelink] for node:%s failed.", node->GetName().c_str());
  203. return FAILED;
  204. }
  205. AddNodeDeleted(node);
  206. return SUCCESS;
  207. }
  208. Status GEPass::Run(const NamesToPass &names_to_passes) {
  209. if (graph_ == nullptr) {
  210. REPORT_INNER_ERROR("E19999", "graph_ is nullptr, check invalid.");
  211. GELOGE(INTERNAL_ERROR, "[Check][Param] The graph is nullptr");
  212. return INTERNAL_ERROR;
  213. }
  214. if (names_to_passes.empty()) {
  215. GELOGW("No passes input, the GEPass will do nothing");
  216. return INTERNAL_ERROR;
  217. }
  218. for (const auto &name_to_pass : names_to_passes) {
  219. if (name_to_pass.second == nullptr) {
  220. GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s)", name_to_pass.first.c_str());
  221. return INTERNAL_ERROR;
  222. }
  223. }
  224. if (depth_ > kMaxRecursiveDepth) {
  225. GELOGE(PARAM_INVALID,
  226. "[Check][Param] The pass for root graph %s will be terminated because too many nesting"
  227. " levels(%d) of subgraphs, last subgraph is %s",
  228. root_graph_->GetName().c_str(), depth_, graph_->GetName().c_str());
  229. return PARAM_INVALID;
  230. }
  231. return RunPassesOneGraph(names_to_passes);
  232. // todo debug mode is on, find first node in topo order which is not passed. and give a warning
  233. }
  234. void NotifyPassGraphStart(const ComputeGraphPtr &graph, const NamesToPass &names_to_pass) {
  235. for (auto &name_to_pass : names_to_pass) {
  236. name_to_pass.second->OnStartPassGraph(graph);
  237. }
  238. }
  239. Status GEPass::HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state) {
  240. std::unordered_map<NodePtr, std::string> resume_nodes_to_pass_names;
  241. for (auto &name_to_pass : names_to_passes) {
  242. name_to_pass.second->init();
  243. auto ret = name_to_pass.second->OnSuspendNodesLeaked();
  244. if (ret != SUCCESS) {
  245. GELOGE(ret, "Internal error with OnSuspendNodesLeaked on pass %s.", name_to_pass.first.c_str());
  246. return ret;
  247. }
  248. for (const auto &resume_node : name_to_pass.second->GetNodesResume()){
  249. resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ",");
  250. }
  251. }
  252. AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state);
  253. return SUCCESS;
  254. }
  255. Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
  256. GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size());
  257. NotifyPassGraphStart(graph_, names_to_passes);
  258. GraphLevelState g_state;
  259. g_state.re_pass_times = 0;
  260. GetAllNodesNoInputEdge(graph_, g_state);
  261. GELOGD("Start points count %zu", g_state.nodes.size());
  262. do {
  263. if (!g_state.nodes_suspend.empty()) {
  264. auto ret = HandleLeakedSuspendNodes(names_to_passes, g_state);
  265. if (ret != SUCCESS) {
  266. // log inside upper function
  267. return ret;
  268. }
  269. if (g_state.nodes.empty()) {
  270. GELOGE(INTERNAL_ERROR, "There are some suspended nodes leaked and no pass resume them.");
  271. return INTERNAL_ERROR;
  272. }
  273. }
  274. auto ret = RunPassesGraphRepass(names_to_passes, g_state);
  275. if (ret != SUCCESS) {
  276. return ret;
  277. }
  278. } while (!g_state.nodes_suspend.empty());
  279. return SUCCESS;
  280. }
  281. Status GEPass::RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state) {
  282. RepassLevelState rp_state;
  283. do {
  284. for (auto &node : rp_state.nodes_re_pass) {
  285. if (rp_state.nodes_re_pass_set.count(node) > 0) {
  286. GELOGD("Add node %s to queue for re-pass", node->GetName().c_str());
  287. g_state.AddNodeToQueue(node);
  288. }
  289. }
  290. rp_state.ClearRepass();
  291. while (!g_state.nodes.empty()) {
  292. auto node = g_state.PopFront();
  293. if (g_state.nodes_deleted.count(node) > 0) {
  294. GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str());
  295. continue;
  296. }
  297. rp_state.EraseNodeFromRepass(node);
  298. g_state.nodes_seen.insert(node.get());
  299. // collect out nodes before pass
  300. std::unordered_set<NodePtr> out_nodes_before_pass;
  301. for (const auto &out_node : node->GetOutNodes()) {
  302. out_nodes_before_pass.insert(out_node);
  303. }
  304. auto ret = RunPassesNodeOnce(node, names_to_passes, g_state, rp_state);
  305. if (ret != SUCCESS) {
  306. GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(),
  307. node->GetType().c_str(), ret);
  308. return ret;
  309. }
  310. AddNextIterNodes(node, out_nodes_before_pass, g_state);
  311. }
  312. AddLastNodesToQueue(g_state);
  313. } while ((!rp_state.nodes_re_pass.empty() || !g_state.nodes.empty()) && ++g_state.re_pass_times < kMaxRePassTimes);
  314. if (g_state.re_pass_times == kMaxRePassTimes) {
  315. GELOGW("re_pass_times should not come to %d", kMaxRePassTimes);
  316. }
  317. GELOGD("All passes runs end");
  318. return SUCCESS;
  319. }
  320. Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) {
  321. auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames();
  322. has_sub_graph = false;
  323. for (const auto &name : sub_graph_names) {
  324. auto graph = root_graph_->GetSubgraph(name);
  325. if (graph == nullptr) {
  326. GELOGW("Can not find the sub graph %s from node %s, the pass-process will skip it",
  327. name.c_str(), node->GetName().c_str());
  328. continue;
  329. }
  330. has_sub_graph = true;
  331. GELOGI("Begin to run passes on the sub graph %s of node %s", name.c_str(), node->GetName().c_str());
  332. GEPass pass(graph, root_graph_, depth_ + 1);
  333. auto ret = pass.Run(names_to_passes);
  334. if (ret != SUCCESS) {
  335. GELOGE(ret, "[Run][Passes] for sub graph:%s from node:%s failed", name.c_str(), node->GetName().c_str());
  336. return ret;
  337. }
  338. }
  339. return SUCCESS;
  340. }
  341. Status GEPass::RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes,
  342. GraphLevelState &g_state, RepassLevelState &rp_state) {
  343. auto ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state);
  344. if (ret != SUCCESS) {
  345. GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(),
  346. node->GetType().c_str(), ret);
  347. return ret;
  348. }
  349. bool has_sub_graph = false;
  350. ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph);
  351. if (ret != SUCCESS) {
  352. GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str());
  353. return ret;
  354. }
  355. if (has_sub_graph) {
  356. GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str());
  357. SetFlagOption(kOptimizeAfterSubGraph, names_to_passes);
  358. ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state);
  359. if (ret != SUCCESS) {
  360. GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u", node->GetName().c_str(),
  361. node->GetType().c_str(), ret);
  362. return ret;
  363. }
  364. // There is only one option scene, so set and clear options around the `RunPasses` func.
  365. // if there are more than one scene to set options, the `ClearOption` function
  366. // should be called each time at the begin of the iteration
  367. ClearOption(names_to_passes);
  368. }
  369. return SUCCESS;
  370. }
  371. Status GEPass::RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state,
  372. RepassLevelState &rp_state) {
  373. if (node == nullptr) {
  374. REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid.");
  375. GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
  376. return FAILED;
  377. }
  378. GELOGD("Begin to run pass for node %s", node->GetName().c_str());
  379. for (const auto &name_to_pass : names_to_passes) {
  380. GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str());
  381. name_to_pass.second->init();
  382. auto result = name_to_pass.second->Run(node);
  383. if (result != SUCCESS) {
  384. REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u",
  385. name_to_pass.first.c_str(), node->GetName().c_str(), result);
  386. GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result "
  387. "%u, the passes will be terminated immediately.",
  388. name_to_pass.first.c_str(), node->GetName().c_str(), result);
  389. return result;
  390. }
  391. if (name_to_pass.second->GetNodesDeleted().count(node) > 0) {
  392. GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(),
  393. name_to_pass.first.c_str());
  394. break;
  395. }
  396. }
  397. g_state.nodes_passed.insert(node);
  398. std::unordered_map<NodePtr, std::string> re_pass_imm_nodes_to_pass_names;
  399. std::unordered_map<NodePtr, std::string> resume_nodes_to_pass_names;
  400. // if muti psss repass one same node, it will add to queue many times, so collect and duplicate
  401. for (const auto &name_to_pass : names_to_passes) {
  402. PushToRePassIfSeen(node, name_to_pass, g_state.nodes_seen,
  403. name_to_pass.second->GetNodesNeedRePass(),
  404. rp_state);
  405. // collect imm_node && resume_node among these passes
  406. for (const auto &imm_node : name_to_pass.second->GetNodesNeedRePassImmediately()){
  407. re_pass_imm_nodes_to_pass_names[imm_node].append(name_to_pass.first + ",");
  408. }
  409. for (const auto &resume_node : name_to_pass.second->GetNodesResume()){
  410. resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ",");
  411. }
  412. for (const auto &suspend_node : name_to_pass.second->GetNodesSuspend()) {
  413. GELOGD("The iteration suspend of node %s has been set by pass %s", suspend_node->GetName().c_str(),
  414. name_to_pass.first.c_str());
  415. g_state.nodes_suspend.insert(suspend_node);
  416. }
  417. const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
  418. g_state.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end());
  419. }
  420. AddImmediateRepassNodesToQueue(node, re_pass_imm_nodes_to_pass_names, g_state);
  421. AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state);
  422. return SUCCESS;
  423. }
  424. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示