You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

folding_pass.cc 13 kB

5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/folding_pass.h"
  17. #include <memory>
  18. #include <string>
  19. #include <utility>
  20. #include <vector>
  21. #include <unordered_set>
  22. #include "framework/common/debug/ge_log.h"
  23. #include "graph/utils/graph_utils.h"
  24. #include "graph/utils/node_utils.h"
  25. #include "inc/kernel.h"
  26. #include "inc/kernel_factory.h"
  27. #include "graph/debug/ge_attr_define.h"
  28. #include "ge_local_engine/engine/host_cpu_engine.h"
  29. namespace ge {
  30. namespace folding_pass {
  31. shared_ptr<Kernel> GetKernelByType(const NodePtr &node) {
  32. if (node == nullptr) {
  33. GELOGE(FAILED, "parameter is null.");
  34. return nullptr;
  35. }
  36. KernelFactory &factory = KernelFactory::Instance();
  37. string type = node->GetType();
  38. if (type == FRAMEWORKOP) {
  39. if (!ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) {
  40. return nullptr;
  41. }
  42. }
  43. return factory.Create(type);
  44. }
  45. bool IsNoNeedConstantFolding(const NodePtr &node) {
  46. auto node_desc = node->GetOpDesc();
  47. return node_desc == nullptr || node_desc->HasAttr(ATTR_NO_NEED_CONSTANT_FOLDING);
  48. }
  49. } // namespace folding_pass
  50. namespace {
  51. IndexsToAnchors GetIndexAndPeerInDataAnchors(NodePtr &node) {
  52. IndexsToAnchors indexes_to_anchors;
  53. for (auto &out_anchor : node->GetAllOutDataAnchors()) {
  54. if (out_anchor == nullptr) {
  55. continue;
  56. }
  57. for (auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
  58. if (peer_in_anchor == nullptr) {
  59. continue;
  60. }
  61. const auto &peer_node = peer_in_anchor->GetOwnerNode();
  62. if (peer_node == nullptr) {
  63. continue;
  64. }
  65. indexes_to_anchors[out_anchor->GetIdx()].push_back(peer_in_anchor);
  66. }
  67. }
  68. return indexes_to_anchors;
  69. }
  70. NodePtr AddConstNodeToGraph(GeTensorPtr &tensor, ComputeGraphPtr &graph) {
  71. auto const_desc = OpDescUtils::CreateConstOp(tensor);
  72. if (const_desc == nullptr) {
  73. GELOGE(OUT_OF_MEMORY, "Failed to get const desc from tensor");
  74. return nullptr;
  75. }
  76. GE_IF_BOOL_EXEC(graph == nullptr, GELOGW("input param graph is null"); return nullptr);
  77. return graph->AddNodeFront(const_desc);
  78. }
  79. NodePtr AddIdentityNodeToGraph(const std::string &name, const GeTensorDesc &tensor, ComputeGraphPtr &graph) {
  80. if (graph == nullptr) {
  81. GELOGE(INTERNAL_ERROR, "Compute graph ptr is null in creating identity node.");
  82. return nullptr;
  83. }
  84. OpDescPtr desc = MakeShared<OpDesc>("", "");
  85. if (desc == nullptr) {
  86. GELOGE(MEMALLOC_FAILED, "Failed to create op desc.");
  87. return nullptr;
  88. }
  89. desc->SetName(name);
  90. desc->SetType(IDENTITY);
  91. auto ret = desc->AddInputDesc(tensor);
  92. auto ret2 = desc->AddOutputDesc(tensor);
  93. if ((ret != GRAPH_SUCCESS) || (ret2 != GRAPH_SUCCESS)) {
  94. GELOGE(INTERNAL_ERROR, "Failed to add input/output desc in creating Identity.");
  95. return nullptr;
  96. }
  97. return graph->AddNodeFront(desc);
  98. }
  99. } // namespace
  100. Status FoldingPass::RunOpKernel(NodePtr &node,
  101. const vector<ConstGeTensorPtr> &inputs,
  102. std::vector<GeTensorPtr> &outputs) {
  103. return HostCpuEngine::GetInstance().Run(node, inputs, outputs);
  104. }
  105. Status FoldingPass::Folding(NodePtr &node, vector<GeTensorPtr> &outputs) {
  106. GE_CHECK_NOTNULL(node);
  107. GELOGD("begin folding node:%s", node->GetName().c_str());
  108. // Before processing nodes, collect the relations between the out anchor and the peer out data nodes
  109. // to prepare for const reconnection
  110. auto indexes_to_anchors = GetIndexAndPeerInDataAnchors(node);
  111. auto ret = DealWithInNodes(node);
  112. if (ret != SUCCESS) {
  113. return ret;
  114. }
  115. if (AddConstNode(node, indexes_to_anchors, outputs) != SUCCESS) {
  116. return INTERNAL_ERROR;
  117. }
  118. auto in_data_nodes = node->GetInDataNodes();
  119. std::unordered_set<NodePtr> in_data_nodes_set(in_data_nodes.begin(), in_data_nodes.end());
  120. if (IsolateAndDeleteNode(node, {}) != SUCCESS) {
  121. GELOGE(INTERNAL_ERROR, "Failed to isolate and delete node %s, type %s.",
  122. node->GetName().c_str(), node->GetType().c_str());
  123. return INTERNAL_ERROR;
  124. }
  125. for (auto iter = in_data_nodes_set.begin(); iter != in_data_nodes_set.end(); ++iter) {
  126. auto pre_node = *iter;
  127. if (pre_node->GetOutDataNodesSize() == 0) {
  128. if ((pre_node->GetType() == DATA) || (pre_node->GetType() == ENTER)) {
  129. GELOGI("No need to remove data/enter, node name:%s.", pre_node->GetName().c_str());
  130. continue;
  131. }
  132. if (IsolateAndDeleteNode(pre_node, {}) != SUCCESS) {
  133. GELOGE(INTERNAL_ERROR, "Failed to isolate and delete in data node %s, type %s.",
  134. pre_node->GetName().c_str(), pre_node->GetType().c_str());
  135. return INTERNAL_ERROR;
  136. }
  137. }
  138. }
  139. return SUCCESS;
  140. }
  141. Status FoldingPass::DealWithInNodes(NodePtr &node) {
  142. GE_CHECK_NOTNULL(node);
  143. GE_CHECK_NOTNULL(node->GetOpDesc());
  144. auto graph = node->GetOwnerComputeGraph();
  145. auto in_data_anchors = node->GetAllInDataAnchors();
  146. for (auto &in_data_anchor : in_data_anchors) {
  147. if (in_data_anchor == nullptr) {
  148. continue;
  149. }
  150. auto in_node_anchor = in_data_anchor->GetPeerOutAnchor();
  151. if (in_node_anchor == nullptr) {
  152. continue;
  153. }
  154. auto in_node = in_node_anchor->GetOwnerNode();
  155. if (in_node == nullptr) {
  156. continue;
  157. }
  158. if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH)) {
  159. GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str());
  160. auto ret = in_node_anchor->Unlink(in_data_anchor);
  161. if (ret != SUCCESS) {
  162. GELOGE(INTERNAL_ERROR, "Failed to unlink anchor between const node %s to constant-folding-node %s, type %s.",
  163. in_node->GetName().c_str(), node->GetName().c_str(), node->GetType().c_str());
  164. return INTERNAL_ERROR;
  165. }
  166. GELOGI("Unlink anchor between in_node %s and node %s success.", in_node->GetName().c_str(),
  167. node->GetName().c_str());
  168. auto identity_name = node->GetName() + "_ctrl_identity_" + std::to_string(in_data_anchor->GetIdx());
  169. auto identity =
  170. AddIdentityNodeToGraph(identity_name, node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()), graph);
  171. if (identity == nullptr) {
  172. GELOGE(INTERNAL_ERROR, "Failed to add identity node to graph.");
  173. return INTERNAL_ERROR;
  174. }
  175. ret = GraphUtils::AddEdge(in_node_anchor, identity->GetInDataAnchor(0));
  176. if (ret != GRAPH_SUCCESS) {
  177. GELOGE(INTERNAL_ERROR, "Failed to add edge, from node %s to node %s.", in_node->GetName().c_str(),
  178. identity->GetName().c_str());
  179. return INTERNAL_ERROR;
  180. }
  181. GELOGI("Create new identity node success.");
  182. ret = GraphUtils::AddEdge(identity->GetOutControlAnchor(), node->GetInControlAnchor());
  183. if (ret != GRAPH_SUCCESS) {
  184. GELOGE(INTERNAL_ERROR, "Failed to add edge, from node %s to node %s.", in_node->GetName().c_str(),
  185. node->GetName().c_str());
  186. return INTERNAL_ERROR;
  187. }
  188. }
  189. }
  190. return SUCCESS;
  191. }
  192. Status FoldingPass::AddConstNode(NodePtr &node, IndexsToAnchors indexes_to_anchors,
  193. std::vector<GeTensorPtr> &v_weight) {
  194. if (node == nullptr) {
  195. GELOGE(PARAM_INVALID, "node is null");
  196. return FAILED;
  197. }
  198. auto graph = node->GetOwnerComputeGraph();
  199. for (auto &index_to_anchors : indexes_to_anchors) {
  200. auto index = static_cast<size_t>(index_to_anchors.first);
  201. if (index >= v_weight.size()) {
  202. GELOGE(INTERNAL_ERROR,
  203. "Failed to constant fold on node %s type %s, "
  204. "the out nodes num %lu calculated is less than the node out anchor index %zu",
  205. node->GetName().c_str(), node->GetType().c_str(), v_weight.size(), index);
  206. return INTERNAL_ERROR;
  207. }
  208. GeTensorPtr weight = v_weight[index];
  209. if (weight == nullptr) {
  210. GELOGE(INTERNAL_ERROR, "Failed to constant fold on node %s type %s, the %lust node calculated is null",
  211. node->GetName().c_str(), node->GetType().c_str(), index);
  212. return INTERNAL_ERROR;
  213. }
  214. auto const_node = AddConstNodeToGraph(weight, graph);
  215. if (const_node == nullptr) {
  216. GELOGE(INTERNAL_ERROR, "Failed to add dynamic const node, node name:%s, index:%zu.",
  217. node->GetName().c_str(), index);
  218. return INTERNAL_ERROR;
  219. }
  220. vector<string> curr_origin_op_names;
  221. (void)AttrUtils::GetListStr(node->GetOpDesc(), curr_origin_op_names);
  222. if (curr_origin_op_names.empty()) {
  223. (void)AttrUtils::SetListStr(const_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, std::move({node->GetName()}));
  224. } else {
  225. (void)AttrUtils::SetListStr(const_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, curr_origin_op_names);
  226. }
  227. GELOGI("add const_node:%s, replace node %s, type %s, index %zu.", const_node->GetName().c_str(),
  228. node->GetName().c_str(), node->GetType().c_str(), index);
  229. // add new const to re-pass node
  230. for (auto &in_anchor : index_to_anchors.second) {
  231. if (in_anchor == nullptr) {
  232. GELOGE(INTERNAL_ERROR, "In anchor is nullptr.");
  233. return INTERNAL_ERROR;
  234. }
  235. auto ret = ConnectNodeToInAnchor(in_anchor, const_node, 0);
  236. if (ret != SUCCESS) {
  237. return ret;
  238. }
  239. NodeUtils::UpdateIsInputConst(*(in_anchor->GetOwnerNode()));
  240. }
  241. Status ret = GraphUtils::AddEdge(node->GetOutControlAnchor(), const_node->GetInControlAnchor());
  242. if (ret != GRAPH_SUCCESS) {
  243. GELOGE(INTERNAL_ERROR, "Failed to add control edge, from node %s to const node %s.", node->GetName().c_str(),
  244. const_node->GetName().c_str());
  245. return INTERNAL_ERROR;
  246. }
  247. GE_CHECK_NOTNULL(node->GetOpDesc());
  248. std::string stream_label;
  249. if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) {
  250. GE_CHECK_NOTNULL(const_node->GetOpDesc());
  251. if (!AttrUtils::SetStr(const_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) {
  252. GELOGE(INTERNAL_ERROR, "Failed to set stream label on dynamic const node %s, with stream label:%s.",
  253. const_node->GetName().c_str(), stream_label.c_str());
  254. return INTERNAL_ERROR;
  255. }
  256. }
  257. GELOGD("Add control edge when insert dynamic const, from node %s to const node %s, with stream label:%s.",
  258. node->GetName().c_str(), const_node->GetName().c_str(), stream_label.c_str());
  259. }
  260. return SUCCESS;
  261. }
  262. Status FoldingPass::RemoveNodeKeepingCtrlEdges(NodePtr &node) {
  263. GE_IF_BOOL_EXEC(node == nullptr, GELOGE(PARAM_INVALID, "node is null"); return PARAM_INVALID);
  264. auto ret = GraphUtils::IsolateNode(node, {});
  265. if (ret != GRAPH_SUCCESS) {
  266. GELOGE(INTERNAL_ERROR, "Failed to isolate the folding-node %s type %s", node->GetName().c_str(),
  267. node->GetType().c_str());
  268. return INTERNAL_ERROR;
  269. }
  270. auto graph = node->GetOwnerComputeGraph();
  271. ret = GraphUtils::RemoveNodeWithoutRelink(graph, node);
  272. if (ret != GRAPH_SUCCESS) {
  273. GELOGE(INTERNAL_ERROR, "Failed to remove node %s from graph", node->GetName().c_str());
  274. return INTERNAL_ERROR;
  275. }
  276. AddNodeDeleted(node);
  277. return SUCCESS;
  278. }
  279. Status FoldingPass::ConnectNodeToInAnchor(InDataAnchorPtr &in_anchor, NodePtr &node, int node_index) {
  280. // the origin edge must be removed before add
  281. if (in_anchor == nullptr || node == nullptr) {
  282. GELOGE(PARAM_INVALID, "in anchor or node is null");
  283. return PARAM_INVALID;
  284. }
  285. auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
  286. if (peer_out_anchor != nullptr) {
  287. if (ge::GraphUtils::RemoveEdge(peer_out_anchor, in_anchor) != GRAPH_SUCCESS) {
  288. GELOGW("RemoveEdge failed.");
  289. }
  290. }
  291. auto new_out_anchor = node->GetOutDataAnchor(node_index);
  292. if (new_out_anchor == nullptr) {
  293. GELOGE(INTERNAL_ERROR,
  294. "Failed to add node to in anchor,"
  295. " the index %d for node %s, type %s is invalid",
  296. node_index, node->GetName().c_str(), node->GetType().c_str());
  297. return INTERNAL_ERROR;
  298. }
  299. if (GraphUtils::AddEdge(new_out_anchor, in_anchor) != GRAPH_SUCCESS) {
  300. GELOGE(INTERNAL_ERROR,
  301. "Failed to add edge between anchors,"
  302. " new node %s, type %s",
  303. node->GetName().c_str(), node->GetType().c_str());
  304. return INTERNAL_ERROR;
  305. }
  306. AddRePassNodesWithInOut(node);
  307. return SUCCESS;
  308. }
  309. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：