You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

unused_op_remove_pass.cc 4.5 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/unused_op_remove_pass.h"
  17. #include <queue>
  18. #include <set>
  19. #include <string>
  20. #include <vector>
  21. #include "common/debug/log.h"
  22. #include "common/op/ge_op_utils.h"
  23. #include "common/types.h"
  24. #include "common/util.h"
  25. #include "graph/utils/attr_utils.h"
  26. #include "graph/utils/graph_utils.h"
  27. #include "graph/utils/op_desc_utils.h"
  28. #include "inc/pass_manager.h"
  29. #include "graph/passes/isolated_op_remove_pass.h"
  30. using domi::SUCCESS;
  31. namespace ge {
  32. const std::set<std::string> kRemoveOpSet = {DROPOUT, PERMUTE, UNUSEDCONST, ASSERT};
  33. const std::set<std::string> kOtherRemoveOpSet = {DROPOUT};
  34. Status UnusedOpRemovePass::Run(ComputeGraphPtr graph) {
  35. GE_CHECK_NOTNULL(graph);
  36. std::set<std::string> remove_op_set;
  37. vector<NodePtr> nodes_to_be_deleted;
  38. if (fmktype_ == TENSORFLOW) {
  39. remove_op_set = kRemoveOpSet;
  40. } else {
  41. remove_op_set = kOtherRemoveOpSet;
  42. }
  43. for (auto &node : graph->GetDirectNode()) {
  44. GE_CHECK_NOTNULL(node->GetOpDesc());
  45. std::string op_type_str = node->GetOpDesc()->GetType();
  46. if (remove_op_set.count(op_type_str)) {
  47. if (IsExceptions(node)) {
  48. continue;
  49. }
  50. for (auto &out_anchor : node->GetAllOutDataAnchors()) {
  51. for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
  52. NodePtr dst_node = in_anchor->GetOwnerNode();
  53. GE_CHECK_NOTNULL(dst_node->GetOpDesc());
  54. int dst_index = in_anchor->GetIdx();
  55. std::vector<bool> list_bool;
  56. GE_CHECK_NOTNULL(dst_node->GetOpDesc());
  57. list_bool = dst_node->GetOpDesc()->GetIsInputConst();
  58. GE_IF_BOOL_EXEC(list_bool.size() == 0, continue);
  59. list_bool.erase(list_bool.begin() + dst_index);
  60. dst_node->GetOpDesc()->SetIsInputConst(list_bool);
  61. }
  62. }
  63. if (op_type_str == ASSERT) {
  64. GE_CHK_STATUS_RET(CollectParentNode(graph, node, nodes_to_be_deleted), "remove node failed");
  65. } else {
  66. GE_CHK_STATUS_RET(graph->RemoveNode(node), "remove node failed");
  67. }
  68. }
  69. }
  70. for (auto &node : nodes_to_be_deleted) {
  71. for (InDataAnchorPtr &inAnchor : node->GetAllInDataAnchors()) {
  72. inAnchor->UnlinkAll();
  73. }
  74. for (OutDataAnchorPtr &outAnchorPtr : node->GetAllOutDataAnchors()) {
  75. outAnchorPtr->UnlinkAll();
  76. }
  77. if (node->GetOutControlAnchor() != nullptr) {
  78. node->GetOutControlAnchor()->UnlinkAll();
  79. }
  80. GE_CHK_STATUS_RET(graph->RemoveNode(node), "remove node:%s failed", node->GetName().c_str());
  81. }
  82. return SUCCESS;
  83. }
  84. Status UnusedOpRemovePass::CollectParentNode(const ComputeGraphPtr &graph, const NodePtr &node,
  85. vector<NodePtr> &node_vec) {
  86. GE_CHECK_NOTNULL(graph);
  87. GE_CHECK_NOTNULL(node);
  88. node_vec.push_back(node);
  89. std::queue<NodePtr> node_queue;
  90. for (auto &src_node : node->GetInDataNodes()) {
  91. if (src_node->GetOutDataNodesSize() == 1) {
  92. node_queue.push(src_node);
  93. }
  94. }
  95. while (!node_queue.empty()) {
  96. NodePtr temp = node_queue.front();
  97. node_queue.pop();
  98. for (auto &src_node : temp->GetInDataNodes()) {
  99. if (src_node->GetOutDataNodesSize() == 1) {
  100. node_queue.push(src_node);
  101. }
  102. }
  103. node_vec.push_back(temp);
  104. }
  105. return SUCCESS;
  106. }
  107. bool UnusedOpRemovePass::IsExceptions(const NodePtr &node) {
  108. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr");
  109. auto op_def = node->GetOpDesc();
  110. GE_CHK_BOOL_EXEC(op_def != nullptr, return false, "opdesc is nullptr");
  111. // permute optimised in permute_pass.cpp
  112. if (op_def->GetType() == PERMUTE) {
  113. GE_IF_BOOL_EXEC(
  114. (node->GetInDataNodes().size() != 0 &&
  115. (node->GetInDataNodes().at(0) != nullptr && node->GetInDataNodes().at(0)->GetOpDesc() != nullptr &&
  116. node->GetInDataNodes().at(0)->GetOpDesc()->GetType() == ATTENTIONDECODER)),
  117. return false);
  118. return true;
  119. }
  120. return false;
  121. }
  122. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示