You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

compute_graph.cc 46 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/compute_graph.h"
  17. #include <deque>
  18. #include "./format_refiner.h"
  19. #include "./ge_context.h"
  20. #include "debug/ge_attr_define.h"
  21. #include "debug/ge_log.h"
  22. #include "debug/ge_op_types.h"
  23. #include "debug/ge_util.h"
  24. #include "framework/common/debug/ge_log.h"
  25. #include "ge/ge_api_types.h"
  26. #include "graph/shape_refiner.h"
  27. #include "proto/ge_ir.pb.h"
  28. #include "utils/ge_ir_utils.h"
  29. #include "utils/graph_utils.h"
  30. #include "utils/node_utils.h"
  31. #include "utils/op_desc_utils.h"
  32. #include "utils/string_utils.h"
  33. #include "utils/tensor_utils.h"
  34. namespace ge {
  35. namespace {
  36. const size_t OUTPUT_PARAM_SIZE = 2;
  37. } // namespace
  38. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name)
  39. : name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) {
  40. attrs_.InitDefault();
  41. }
  42. ComputeGraph::~ComputeGraph() {}
  43. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string ComputeGraph::GetName() const { return name_; }
  44. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetName(const string &name) { name_ = name; }
  45. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesSize() const {
  46. size_t s = nodes_.size();
  47. for (const auto &sub_graph : sub_graph_) {
  48. s += sub_graph->GetAllNodesSize();
  49. }
  50. return s;
  51. }
  52. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetAllNodes() const {
  53. if (sub_graph_.empty()) {
  54. return Vistor<NodePtr>(shared_from_this(), nodes_);
  55. }
  56. std::vector<NodePtr> all_nodes;
  57. std::deque<NodePtr> candidates;
  58. candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end());
  59. while (!candidates.empty()) {
  60. NodePtr node = candidates.front();
  61. all_nodes.emplace_back(node);
  62. candidates.pop_front();
  63. OpDescPtr op_desc = node->GetOpDesc();
  64. if (op_desc == nullptr) {
  65. continue;
  66. }
  67. const auto &subgraph_names = op_desc->GetSubgraphInstanceNames();
  68. for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) {
  69. auto subgraph = GetSubgraph(*name_iter);
  70. if (subgraph != nullptr) {
  71. candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end());
  72. }
  73. }
  74. }
  75. return Vistor<NodePtr>(shared_from_this(), all_nodes);
  76. }
  77. size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); }
  78. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetDirectNode() const {
  79. return Vistor<NodePtr>(shared_from_this(), nodes_);
  80. }
  81. ComputeGraph::Vistor<NodePtr> ComputeGraph::GetInputNodes() const {
  82. return Vistor<NodePtr>(shared_from_this(), input_nodes_);
  83. }
  84. ComputeGraph::Vistor<NodePtr> ComputeGraph::GetOutputNodes() const {
  85. std::vector<NodePtr> result;
  86. for (auto iter = output_nodes_info_.begin(); iter != output_nodes_info_.end(); ++iter) {
  87. result.push_back(iter->first);
  88. }
  89. return Vistor<NodePtr>(shared_from_this(), result);
  90. }
  91. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::FindNode(const std::string &name) const {
  92. for (const auto &node : nodes_) {
  93. if (node == nullptr) {
  94. continue;
  95. }
  96. if (node->GetName() == name) {
  97. return node;
  98. }
  99. }
  100. return nullptr;
  101. }
  102. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphAttrsAreEqual(
  103. const ComputeGraph &r_graph) const {
  104. // ProtoMsgOwner <::google::protobuf::Message> is temporarily ignored
  105. if ((this->attrs_.protoMsg_ != nullptr) && (r_graph.attrs_.protoMsg_ != nullptr)) {
  106. const auto &proto_attr_map = *(this->attrs_.protoMsg_);
  107. const auto &r_proto_attr_map = *(r_graph.attrs_.protoMsg_);
  108. // 1.Verify graph's ProtoAttrMap size
  109. if (proto_attr_map.size() != r_proto_attr_map.size()) {
  110. GELOGE(GRAPH_FAILED, "Size of compute graph's ProtoAttrMap verify failed, graph name: %s.",
  111. this->GetName().c_str());
  112. return false;
  113. }
  114. // 2.Verify graph's ProtoAttrMap key, verify values is temporarily not implemented
  115. for (const auto &it : proto_attr_map) {
  116. if (r_proto_attr_map.count(it.first) == 0) {
  117. GELOGE(GRAPH_FAILED, "Key of compute graph's ProtoAttrMap verify failed, graph name: %s key name: %s.",
  118. this->GetName().c_str(), it.first.c_str());
  119. return false;
  120. }
  121. }
  122. return true;
  123. }
  124. return ((this->attrs_.protoMsg_ == nullptr) && (r_graph.attrs_.protoMsg_ == nullptr));
  125. }
  126. /// Since there may be different input nodes
  127. /// chosen by user in the same graph, special judgment is needed
  128. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::VectorInputNodePtrIsEqual(
  129. const std::vector<NodePtr> &left_nodes, const std::vector<NodePtr> &right_nodes) const {
  130. const auto left_nodes_size = left_nodes.size();
  131. const auto right_nodes_size = right_nodes.size();
  132. if (left_nodes_size != right_nodes_size) {
  133. GELOGE(GRAPH_FAILED,
  134. "Check failed with graph input_nodes_: "
  135. "left inputNodes size %zu is different with right inputNodes size %zu .",
  136. left_nodes_size, right_nodes_size);
  137. return false;
  138. }
  139. for (size_t j = 0; j < left_nodes_size; j++) {
  140. if (left_nodes.at(j) == nullptr || right_nodes.at(j) == nullptr) {
  141. GELOGE(GRAPH_FAILED, "left_nodes.at(%zu) or right_nodes.at(%zu) is nullptr", j, j);
  142. return false;
  143. }
  144. const auto &left_input_name = left_nodes.at(j)->GetName();
  145. const auto &right_input_name = right_nodes.at(j)->GetName();
  146. if (left_input_name != right_input_name) {
  147. GELOGE(GRAPH_FAILED,
  148. "Check failed with graph input_nodes_: "
  149. "left inputNode name %s is different with right inputNode name %s at inputNodes index %zu.",
  150. left_input_name.c_str(), right_input_name.c_str(), j);
  151. return false;
  152. }
  153. }
  154. return true;
  155. }
  156. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual(
  157. const ComputeGraph &r_graph) const {
  158. return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.subgraphs_.size()") &&
  159. IsEqual(this->nodes_.size(), r_graph.nodes_.size(), "graph.nodes_.size()") &&
  160. VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) &&
  161. IsEqual(this->name_, r_graph.name_, "graph.name_") &&
  162. IsEqual(this->is_valid_flag_, r_graph.is_valid_flag_, "graph.is_valid_flag_") &&
  163. IsEqual(this->need_iteration_, r_graph.need_iteration_, "graph.need_iteration_") &&
  164. IsEqual(this->params_share_map_, r_graph.params_share_map_, "graph.params_share_map_") &&
  165. IsEqual(this->out_nodes_map_, r_graph.out_nodes_map_, "graph.out_nodes_map_") &&
  166. IsEqual(this->inputs_order_, r_graph.inputs_order_, "graph.inputs_order_") &&
  167. IsEqual(this->output_size_, r_graph.output_size_, "graph.output_size_") &&
  168. IsEqual(this->input_size_, r_graph.input_size_, "graph.input_size_") &&
  169. IsEqual(this->output_nodes_info_, r_graph.output_nodes_info_, "graph.output_nodes_info_"));
  170. }
  171. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::operator==(const ComputeGraph &r_graph) const {
  172. // Firstly: Graph's members equal
  173. if ((!GraphMembersAreEqual(r_graph)) || (!GraphAttrsAreEqual(r_graph))) {
  174. return false;
  175. }
  176. // Secondly: Node equal means the link relationship between node and node itself equal
  177. for (const auto &left_node : nodes_) {
  178. if (left_node == nullptr) {
  179. GELOGE(GRAPH_FAILED, "left_node is nullptr");
  180. return false;
  181. }
  182. const auto &node_name = left_node->GetName();
  183. // After TopologicalSorting, node order can change, so find node by name
  184. const auto &right_node = r_graph.FindNode(node_name);
  185. GE_IF_BOOL_EXEC(right_node == nullptr, GELOGE(GRAPH_FAILED, "right_node is NULL!!!"); return false);
  186. if (!(*right_node == *left_node)) {
  187. GELOGE(GRAPH_FAILED, "Compare graph failed, node name: %s.", node_name.c_str());
  188. return false;
  189. }
  190. }
  191. // Thirdly: Recursively determine whether the sub graphs are equal
  192. for (size_t i = 0; i < this->sub_graph_.size(); i++) {
  193. if (!(*((this->sub_graph_)[i]) == *((r_graph.sub_graph_)[i]))) {
  194. return false;
  195. }
  196. }
  197. return true;
  198. }
  199. NodePtr ComputeGraph::AddNodeFront(NodePtr node) {
  200. if (node == nullptr || node->GetOpDesc() == nullptr) {
  201. GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null.");
  202. return nullptr;
  203. }
  204. node->GetOpDesc()->SetId(nodes_.size());
  205. if (nodes_[0] == nullptr) {
  206. GELOGE(GRAPH_FAILED, "nodes_ size or nodes_[0] is nullptr");
  207. return nullptr;
  208. }
  209. if (nodes_.size() > 0 && nodes_[0]->GetType() == DATA) {
  210. (void)nodes_.insert(nodes_.begin() + 1, node);
  211. } else {
  212. (void)nodes_.insert(nodes_.begin(), node);
  213. }
  214. return node;
  215. }
  216. NodePtr ComputeGraph::AddNodeFront(const OpDescPtr &op) {
  217. if (op == nullptr) {
  218. GELOGE(GRAPH_FAILED, "The OpDesc ptr should be not null.");
  219. return nullptr;
  220. }
  221. op->SetId(nodes_.size());
  222. NodePtr node_ptr = shared_ptr<Node>(new (std::nothrow) Node(op, shared_from_this()));
  223. GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr);
  224. GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr);
  225. return AddNodeFront(node_ptr);
  226. }
  227. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(NodePtr node) {
  228. if (node == nullptr || node->GetOpDesc() == nullptr) {
  229. GELOGE(GRAPH_FAILED, "The node ptr should be not null.");
  230. return nullptr;
  231. }
  232. node->GetOpDesc()->SetId((int64_t)GetDirectNodesSize());
  233. nodes_.push_back(node);
  234. return node;
  235. }
  236. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpDescPtr op) {
  237. if (op == nullptr) {
  238. GELOGE(GRAPH_FAILED, "The OpDesc ptr should be not null.");
  239. return nullptr;
  240. }
  241. op->SetId(GetDirectNodesSize());
  242. NodePtr node_ptr = shared_ptr<Node>(new (std::nothrow) Node(op, shared_from_this()));
  243. GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr);
  244. GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr);
  245. return AddNode(node_ptr);
  246. }
  247. NodePtr ComputeGraph::AddInputNode(NodePtr node) {
  248. if (node == nullptr) {
  249. GELOGE(GRAPH_FAILED, "The node ptr should be not null.");
  250. return nullptr;
  251. }
  252. input_nodes_.push_back(node);
  253. if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) {
  254. GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "add node failed");
  255. }
  256. return node;
  257. }
  258. NodePtr ComputeGraph::AddOutputNode(NodePtr node) {
  259. if (node == nullptr || node->GetOpDesc() == nullptr) {
  260. GELOGE(GRAPH_FAILED, "The node ptr or opdesc should be not null.");
  261. return nullptr;
  262. }
  263. bool already_have = false;
  264. NodePtr result = node;
  265. // [output_nodes_info_ : should not be null]
  266. for (const auto &item : output_nodes_info_) {
  267. if (item.first->GetName() == node->GetName()) {
  268. already_have = true;
  269. result = item.first;
  270. break;
  271. }
  272. }
  273. if (!already_have) {
  274. output_nodes_info_.emplace_back(std::make_pair(node, 0));
  275. }
  276. if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) {
  277. GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "add node failed");
  278. }
  279. return result;
  280. }
  281. graphStatus ComputeGraph::RemoveConstInput(const NodePtr &node) {
  282. GE_CHECK_NOTNULL(node);
  283. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  284. auto out_anchor = in_anchor->GetPeerOutAnchor();
  285. if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) {
  286. continue;
  287. }
  288. if (out_anchor->GetOwnerNode()->GetType() == CONSTANT || out_anchor->GetOwnerNode()->GetType() == CONSTANTOP) {
  289. GE_CHK_BOOL_RET_STATUS(GraphUtils::RemoveEdge(out_anchor, in_anchor) == GRAPH_SUCCESS, GRAPH_FAILED,
  290. "Remove edge from const op failed.");
  291. if (out_anchor->GetOwnerNode()->GetOutDataNodes().size() == 0) {
  292. GELOGI("Remove const op %s.", out_anchor->GetOwnerNode()->GetName().c_str());
  293. auto iter = find(nodes_.begin(), nodes_.end(), out_anchor->GetOwnerNode());
  294. if (iter != nodes_.end()) {
  295. (void)nodes_.erase(iter);
  296. }
  297. }
  298. }
  299. }
  300. return GRAPH_SUCCESS;
  301. }
  302. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveNode(const NodePtr &node) {
  303. if (node == nullptr) {
  304. GELOGE(GRAPH_FAILED, "The node ptr should be not null.");
  305. return GRAPH_FAILED;
  306. }
  307. // delete const op for this node
  308. (void)RemoveConstInput(node);
  309. // if the node save as input node, delete it
  310. (void)RemoveInputNode(node);
  311. // if the node save as input node, delete it
  312. (void)RemoveOutputNode(node);
  313. if (GRAPH_SUCCESS != IsolateNode(node)) {
  314. GELOGE(GRAPH_FAILED, "Isolate node failed, node name: %s.", node->GetName().c_str());
  315. return GRAPH_FAILED;
  316. }
  317. auto iter = find(nodes_.begin(), nodes_.end(), node);
  318. if (iter != nodes_.end()) {
  319. (void)nodes_.erase(iter);
  320. return GRAPH_SUCCESS;
  321. }
  322. return GRAPH_FAILED;
  323. }
  324. // Used in sub_graph scenes
  325. graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) {
  326. if (node == nullptr) {
  327. GELOGE(GRAPH_FAILED, "The node ptr should be not null.");
  328. return GRAPH_FAILED;
  329. }
  330. auto iter = find(input_nodes_.begin(), input_nodes_.end(), node);
  331. if (iter != input_nodes_.end()) {
  332. (void)input_nodes_.erase(iter);
  333. return GRAPH_SUCCESS;
  334. }
  335. return GRAPH_FAILED;
  336. }
  337. // Used in sub_graph scenes
  338. graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) {
  339. if (node == nullptr) {
  340. GELOGE(GRAPH_FAILED, "The node ptr should be not null.");
  341. return GRAPH_FAILED;
  342. }
  343. auto iter = output_nodes_info_.begin();
  344. bool find_node = false;
  345. // [output_nodes_info_ : should not be null]
  346. while (iter != output_nodes_info_.end()) {
  347. if (node->GetName() == iter->first->GetName()) {
  348. iter = output_nodes_info_.erase(iter);
  349. find_node = true;
  350. } else {
  351. ++iter;
  352. }
  353. }
  354. GE_IF_BOOL_EXEC(find_node == false, return GRAPH_FAILED);
  355. return GRAPH_SUCCESS;
  356. }
  357. std::shared_ptr<ComputeGraph> ComputeGraph::AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph) {
  358. if (sub_graph == nullptr) {
  359. GELOGE(GRAPH_FAILED, "The graph ptr should be not null.");
  360. return nullptr;
  361. }
  362. sub_graph_.push_back(sub_graph);
  363. return sub_graph;
  364. }
  365. graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph) {
  366. if (sub_graph == nullptr) {
  367. GELOGE(GRAPH_FAILED, "The graph ptr should be not null.");
  368. return GRAPH_FAILED;
  369. }
  370. auto iter = find(sub_graph_.begin(), sub_graph_.end(), sub_graph);
  371. if (iter != sub_graph_.end()) {
  372. (void)sub_graph_.erase(iter);
  373. return GRAPH_SUCCESS;
  374. } else {
  375. GELOGW("find sub_graph failed");
  376. return GRAPH_SUCCESS;
  377. }
  378. }
  379. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  380. ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph) {
  381. if (subgraph == nullptr) {
  382. GE_LOGE("Try to add a null subgraph, name %s", name.c_str());
  383. return GRAPH_PARAM_INVALID;
  384. }
  385. auto parent_graph = subgraph->GetParentGraph();
  386. if (parent_graph == nullptr) {
  387. GE_LOGE("Try to add subgraph without parent graph, name %s", name.c_str());
  388. return GRAPH_PARAM_INVALID;
  389. }
  390. auto parent_node = subgraph->GetParentNode();
  391. if (parent_node == nullptr) {
  392. GE_LOGE("Try to add a subgraph without parent node, name %s", name.c_str());
  393. return GRAPH_PARAM_INVALID;
  394. }
  395. if (parent_node->GetOwnerComputeGraph() != parent_graph) {
  396. GE_LOGE(
  397. "Try to add a subgraph which parent node's parent graph is not equal to "
  398. "the subgraph's parent graph, subgraph name %s, parent node name %s",
  399. subgraph->GetName().c_str(), parent_graph->GetName().c_str());
  400. return GRAPH_PARAM_INVALID;
  401. }
  402. if (!this->parent_graph_.expired()) {
  403. GE_LOGE("The subgraphs can only be added to the root graph");
  404. return GRAPH_PARAM_INVALID;
  405. }
  406. if (name != subgraph->GetName()) {
  407. GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str());
  408. }
  409. sub_graph_.push_back(subgraph);
  410. names_to_subgraph_[name] = subgraph;
  411. return GRAPH_SUCCESS;
  412. }
  413. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  414. ComputeGraph::AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph) {
  415. if (subgraph == nullptr) {
  416. return GRAPH_PARAM_INVALID;
  417. }
  418. return AddSubgraph(subgraph->GetName(), subgraph);
  419. }
  420. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(const std::string &name) {
  421. auto iter = names_to_subgraph_.find(name);
  422. if (iter == names_to_subgraph_.end()) {
  423. return;
  424. }
  425. for (auto vec_iter = sub_graph_.begin(); vec_iter != sub_graph_.end(); ++vec_iter) {
  426. if (*vec_iter == iter->second) {
  427. sub_graph_.erase(vec_iter);
  428. break;
  429. }
  430. }
  431. names_to_subgraph_.erase(iter);
  432. }
  433. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(
  434. const std::shared_ptr<ComputeGraph> &subgraph) {
  435. if (subgraph != nullptr) {
  436. RemoveSubgraph(subgraph->GetName());
  437. }
  438. }
  439. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr<ComputeGraph> ComputeGraph::GetSubgraph(
  440. const std::string &name) const {
  441. auto iter = names_to_subgraph_.find(name);
  442. return iter == names_to_subgraph_.end() ? nullptr : iter->second;
  443. }
  444. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector<std::shared_ptr<ComputeGraph>>
  445. ComputeGraph::GetAllSubgraphs() const {
  446. return sub_graph_;
  447. }
  448. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr<ComputeGraph> ComputeGraph::GetParentGraph() {
  449. return parent_graph_.lock();
  450. }
  451. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentGraph(
  452. const shared_ptr<ComputeGraph> &parent) {
  453. parent_graph_ = parent;
  454. }
  455. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr<Node> ComputeGraph::GetParentNode() {
  456. return parent_node_.lock();
  457. }
  458. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode(const shared_ptr<Node> &parent) {
  459. parent_node_ = parent;
  460. }
  461. ///
  462. /// @brief Update input-mapping
  463. /// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input
  464. /// @return graphStatus
  465. ///
  466. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  467. ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping) {
  468. for (auto &input : input_nodes_) {
  469. uint32_t cur_index = 0;
  470. if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) {
  471. continue;
  472. }
  473. auto iter = input_mapping.find(cur_index);
  474. if (iter == input_mapping.end()) {
  475. continue;
  476. }
  477. if (!ge::AttrUtils::SetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, iter->second)) {
  478. GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed.");
  479. return GRAPH_FAILED;
  480. }
  481. }
  482. return GRAPH_SUCCESS;
  483. }
  484. ///
  485. /// @brief Update output-mapping
  486. /// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output
  487. /// @return graphStatus
  488. ///
  489. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  490. ComputeGraph::UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping) {
  491. NodePtr net_output = FindNode(kNodeNameNetOutput);
  492. if (net_output == nullptr) {
  493. GE_LOGE("UpdateOutputMapping failed: node %s not exist in graph.", kNodeNameNetOutput);
  494. return GRAPH_FAILED;
  495. }
  496. OpDescPtr op_desc = net_output->GetOpDesc();
  497. if (op_desc == nullptr) {
  498. GE_LOGE("UpdateOutputMapping failed: op_desc is NULL.");
  499. return GRAPH_FAILED;
  500. }
  501. size_t num = op_desc->GetInputsSize();
  502. for (size_t i = 0; i < num; i++) {
  503. GeTensorDesc tensor = op_desc->GetInputDesc(i);
  504. uint32_t cur_index = 0;
  505. if (!ge::AttrUtils::GetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, cur_index)) {
  506. continue;
  507. }
  508. auto iter = output_mapping.find(cur_index);
  509. if (iter == output_mapping.end()) {
  510. continue;
  511. }
  512. if (!ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) {
  513. GE_LOGE("UpdateOutputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed.");
  514. return GRAPH_FAILED;
  515. }
  516. if (op_desc->UpdateInputDesc(i, tensor) != GRAPH_SUCCESS) {
  517. GE_LOGE("UpdateOutputMapping failed: update %u input_tensor failed.", i);
  518. return GRAPH_FAILED;
  519. }
  520. }
  521. return GRAPH_SUCCESS;
  522. }
  523. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() {
  524. std::vector<NodePtr> node_vec = nodes_;
  525. for (const auto &node : GetAllNodes()) {
  526. if (node == nullptr || node->GetOpDesc() == nullptr) {
  527. GELOGW("node or OpDescPtr is nullptr.");
  528. continue;
  529. }
  530. GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should be not null."); return GRAPH_FAILED);
  531. if (node->GetOpDesc()->GetType() == kRecvType) {
  532. auto iter = find(node_vec.begin(), node_vec.end(), node);
  533. if (iter == node_vec.end()) {
  534. GELOGW("no node found.");
  535. } else {
  536. (void)node_vec.erase(iter);
  537. }
  538. auto dst_iter = find(node_vec.begin(), node_vec.end(), node->GetOutControlNodes().at(0));
  539. (void)node_vec.insert(dst_iter, node);
  540. }
  541. if (node->GetOpDesc()->GetType() == kSendType) {
  542. auto iter = find(node_vec.begin(), node_vec.end(), node);
  543. if (iter == node_vec.end()) {
  544. GELOGW("no node found.");
  545. } else {
  546. (void)node_vec.erase(iter);
  547. }
  548. auto src_iter = find(node_vec.begin(), node_vec.end(), node->GetInControlNodes().at(0));
  549. (void)node_vec.insert(src_iter + 1, node);
  550. }
  551. }
  552. nodes_.clear();
  553. for (size_t i = 0; i < node_vec.size(); ++i) {
  554. NodePtr node = node_vec[i];
  555. if (node == nullptr || node->GetOpDesc() == nullptr) {
  556. GELOGW("node or OpDescPtr is nullptr.");
  557. } else {
  558. node->GetOpDesc()->SetId((int64_t)i);
  559. nodes_.push_back(node);
  560. }
  561. }
  562. return GRAPH_SUCCESS;
  563. }
  564. graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec,
  565. std::map<NodePtr, uint32_t> &map_in_edge_num,
  566. std::vector<NodePtr> &stack) {
  567. GELOGI("Runing_Dfs_Sort: %s", name_.c_str());
  568. // Record the number of non data nodes but no input nodes
  569. GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed");
  570. // Only data nodes here
  571. while (!stack.empty()) {
  572. NodePtr node = stack.back();
  573. stack.pop_back();
  574. node_vec.push_back(node);
  575. GE_CHECK_NOTNULL(node->GetOpDesc());
  576. GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str());
  577. for (const auto &anchor : node->GetAllOutDataAnchors()) {
  578. GE_CHECK_NOTNULL(anchor);
  579. for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) {
  580. GE_CHECK_NOTNULL(peer_in_anchor);
  581. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  582. if (iter != map_in_edge_num.end() && --iter->second == 0) {
  583. stack.push_back(peer_in_anchor->GetOwnerNode());
  584. }
  585. }
  586. for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) {
  587. GE_CHECK_NOTNULL(peer_in_anchor);
  588. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  589. if (iter != map_in_edge_num.end() && --iter->second == 0) {
  590. stack.push_back(peer_in_anchor->GetOwnerNode());
  591. }
  592. }
  593. }
  594. GE_IF_BOOL_EXEC(
  595. node->GetOutControlAnchor() != nullptr, for (AnchorPtr peer_in_anchor
  596. : node->GetOutControlAnchor()->GetPeerAnchors()) {
  597. GE_CHECK_NOTNULL(peer_in_anchor);
  598. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  599. if (iter != map_in_edge_num.end() && --iter->second == 0) {
  600. stack.push_back(peer_in_anchor->GetOwnerNode());
  601. }
  602. })
  603. }
  604. return GRAPH_SUCCESS;
  605. }
  606. graphStatus ComputeGraph::BFSTopologicalSorting(std::vector<NodePtr> &node_vec,
  607. std::map<NodePtr, uint32_t> &map_in_edge_num,
  608. std::deque<NodePtr> &stack) {
  609. GELOGI("Runing_Bfs_Sort: %s", name_.c_str());
  610. std::vector<NodePtr> stack_input;
  611. std::map<string, NodePtr> breadth_node_map;
  612. // Record the number of non data nodes but no input nodes
  613. GE_CHK_BOOL_EXEC(SortNodes(stack_input, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed");
  614. // Only data nodes here
  615. while (!stack_input.empty() || !stack.empty()) {
  616. NodePtr node = nullptr;
  617. if (!stack.empty()) {
  618. node = stack.back();
  619. stack.pop_back();
  620. } else {
  621. node = stack_input.back();
  622. stack_input.pop_back();
  623. }
  624. node_vec.push_back(node);
  625. GE_CHECK_NOTNULL(node->GetOpDesc());
  626. GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str());
  627. CollectBreadthOutNode(node, map_in_edge_num, breadth_node_map);
  628. for (const auto &name_node : breadth_node_map) {
  629. (void)stack.push_front(name_node.second);
  630. }
  631. breadth_node_map.clear();
  632. }
  633. return GRAPH_SUCCESS;
  634. }
  635. graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num,
  636. std::map<string, NodePtr> &breadth_node_map) {
  637. for (const auto &anchor : node->GetAllOutDataAnchors()) {
  638. for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) {
  639. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  640. if (iter != map_in_edge_num.end() && 0 == --iter->second) {
  641. (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode());
  642. }
  643. }
  644. for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) {
  645. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  646. if (iter != map_in_edge_num.end() && 0 == --iter->second) {
  647. (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode());
  648. }
  649. }
  650. }
  651. GE_IF_BOOL_EXEC(
  652. node->GetOutControlAnchor() != nullptr, for (AnchorPtr peer_in_anchor
  653. : node->GetOutControlAnchor()->GetPeerAnchors()) {
  654. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  655. if (iter != map_in_edge_num.end() && 0 == --iter->second) {
  656. (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode());
  657. }
  658. })
  659. return GRAPH_SUCCESS;
  660. }
  661. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() {
  662. auto ret = TopologicalSortingSubgraph();
  663. if (ret != SUCCESS) {
  664. GELOGE(ret, "Sub graph partition Failed");
  665. return ret;
  666. }
  667. // partition sub graph
  668. for (const auto &sub_graph : GetAllSubgraphs()) {
  669. ret = sub_graph->TopologicalSortingSubgraph();
  670. if (ret != SUCCESS) {
  671. GELOGE(ret, "Sub graph topological sort Failed");
  672. return ret;
  673. }
  674. }
  675. return SUCCESS;
  676. }
  677. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSortingSubgraph() {
  678. std::vector<NodePtr> node_vec;
  679. std::map<NodePtr, uint32_t> map_in_edge_num;
  680. bool use_BFS = false;
  681. string run_mode;
  682. const int base = 10;
  683. if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) {
  684. if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) >= TRAIN) {
  685. use_BFS = true;
  686. }
  687. } else {
  688. GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default.");
  689. }
  690. if (use_BFS) {
  691. std::deque<NodePtr> stack;
  692. if (BFSTopologicalSorting(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) {
  693. return GRAPH_FAILED;
  694. }
  695. } else {
  696. std::vector<NodePtr> stack;
  697. if (DFSTopologicalSorting(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) {
  698. return GRAPH_FAILED;
  699. }
  700. }
  701. // If they are not equal, there is a closed loop
  702. if (node_vec.size() != nodes_.size()) {
  703. std::set<Node *> itered_nodes_set;
  704. for (auto &node : node_vec) {
  705. itered_nodes_set.insert(node.get());
  706. }
  707. GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", nodes_.size(),
  708. node_vec.size());
  709. for (auto &node : nodes_) {
  710. if (itered_nodes_set.count(node.get()) == 0) {
  711. GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str());
  712. }
  713. }
  714. return GRAPH_FAILED;
  715. }
  716. nodes_.clear();
  717. for (size_t i = 0; i < node_vec.size(); i++) {
  718. NodePtr node = node_vec[i]; // [node: should not be null]
  719. node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null]
  720. nodes_.push_back(node);
  721. }
  722. is_valid_flag_ = true;
  723. return GRAPH_SUCCESS;
  724. }
  725. graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &map_in_edge_num) {
  726. // Record the number of non data nodes but no input nodes
  727. uint32_t spec_node_size = 0;
  728. bool verify_isolated = false;
  729. string run_mode;
  730. const int base = 10;
  731. // Need verify isolated point in PREDICTION mode.
  732. if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) {
  733. if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) < TRAIN) {
  734. verify_isolated = true;
  735. }
  736. }
  737. for (const auto &node : GetDirectNode()) {
  738. GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
  739. map_in_edge_num[node] = static_cast<uint32_t>(GetInEdgeSize(node));
  740. if (map_in_edge_num[node] == 0) {
  741. if ((node->GetOpDesc()->GetType() != kDataType) && (node->GetOpDesc()->GetType() != kAippDataType) &&
  742. (node->GetOpDesc()->GetType() != kInputType) && (node->GetOpDesc()->GetType() != kAnnDataType)) {
  743. // At present, can only judge the isolated point without input and output.
  744. // It is impossible to judge the situation with multiple output nodes.
  745. if (verify_isolated && GetOutEdgeSize(node) == 0) {
  746. GELOGE(GRAPH_FAILED, "May has isolated nodes in graph, node name: %s.", node->GetName().c_str());
  747. return GRAPH_FAILED;
  748. }
  749. (void)stack.insert(stack.begin(), node);
  750. spec_node_size++;
  751. continue;
  752. }
  753. // Need to insert the data nodes in reverse order
  754. (void)stack.insert(stack.begin() + spec_node_size, node);
  755. }
  756. }
  757. /// Make sure the inputs order matches with user-designated
  758. /// 1. Get the index of two input nodes in the user-inputs-order(inputs_order_)
  759. /// 2. Compare two indices, if not match, swap the positions of two inputs
  760. /// *: Remind: stack is reverse-order
  761. for (size_t i = 0; i < stack.size(); ++i) {
  762. // If not found in 'inputs_order_', skip it
  763. auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName());
  764. GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue);
  765. auto inx_i = it_i - inputs_order_.begin();
  766. for (size_t j = i + 1; j < stack.size(); ++j) {
  767. // If not found in 'inputs_order_', skip it
  768. auto it_j = std::find(inputs_order_.begin(), inputs_order_.end(), stack[j]->GetName());
  769. GE_IF_BOOL_EXEC(it_j == inputs_order_.end(), continue);
  770. // Compare index, swap them if it should be
  771. auto inx_j = it_j - inputs_order_.begin();
  772. GE_IF_BOOL_EXEC(inx_i < inx_j, std::swap(stack[i], stack[j]));
  773. }
  774. }
  775. return GRAPH_SUCCESS;
  776. }
  777. size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) {
  778. size_t in_edge_size = 0;
  779. if (node == nullptr) {
  780. return in_edge_size;
  781. }
  782. for (const auto &anchor : node->GetAllInDataAnchors()) {
  783. in_edge_size = in_edge_size + anchor->GetPeerAnchorsSize();
  784. // Break flow control data loop.
  785. OutDataAnchorPtr out_anchor = anchor->GetPeerOutAnchor();
  786. if ((out_anchor != nullptr) && (out_anchor->GetOwnerNode() != nullptr)) {
  787. NodePtr out_node = out_anchor->GetOwnerNode();
  788. if (out_node == nullptr) {
  789. GELOGW("out node is nullptr");
  790. continue;
  791. }
  792. if ((out_node->GetType() == NEXTITERATION) || (out_node->GetType() == REFNEXTITERATION)) {
  793. GE_IF_BOOL_EXEC(in_edge_size == 0, GELOGE(GRAPH_FAILED, "If [in_edge_size = 0], the result will be reversed");
  794. return in_edge_size);
  795. in_edge_size -= 1;
  796. }
  797. }
  798. }
  799. if (node->GetInControlAnchor() != nullptr) {
  800. in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchorsSize();
  801. }
  802. return in_edge_size;
  803. }
  804. size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) {
  805. size_t out_edge_size = 0;
  806. if (node == nullptr) {
  807. return out_edge_size;
  808. }
  809. // Break flow control data loop.
  810. if ((node->GetType() != NEXTITERATION) && (node->GetType() != REFNEXTITERATION)) {
  811. for (const auto &anchor : node->GetAllOutDataAnchors()) {
  812. if (anchor != nullptr) {
  813. out_edge_size = out_edge_size + anchor->GetPeerAnchors().size();
  814. }
  815. }
  816. }
  817. if (node->GetOutControlAnchor() != nullptr) {
  818. if (out_edge_size > (UINT64_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) {
  819. return 0;
  820. }
  821. out_edge_size = out_edge_size + node->GetOutControlAnchor()->GetPeerAnchors().size();
  822. }
  823. return out_edge_size;
  824. }
  825. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::IsValid() const { return is_valid_flag_; }
  826. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const {
  827. GELOGI("graph name = %s.", GetName().c_str());
  828. for (const auto &node : GetAllNodes()) {
  829. GELOGI("node name = %s.", node->GetName().c_str());
  830. for (const auto &anchor : node->GetAllOutDataAnchors()) {
  831. for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) {
  832. GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr,
  833. GELOGI("node name = %s, out data node name = %s.", node->GetName().c_str(),
  834. peer_in_anchor->GetOwnerNode()->GetName().c_str()));
  835. }
  836. for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) {
  837. GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr,
  838. GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(),
  839. peer_in_anchor->GetOwnerNode()->GetName().c_str()));
  840. }
  841. }
  842. auto out_control_anchor = node->GetOutControlAnchor();
  843. if (out_control_anchor != nullptr) {
  844. for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) {
  845. GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr,
  846. GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(),
  847. peer_in_anchor->GetOwnerNode()->GetName().c_str()));
  848. }
  849. for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) {
  850. GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr,
  851. GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(),
  852. peer_in_anchor->GetOwnerNode()->GetName().c_str()));
  853. }
  854. }
  855. }
  856. }
  857. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::IsolateNode(const NodePtr &node) {
  858. GE_CHECK_NOTNULL(node);
  859. auto next_nodes = node->GetOutAllNodes();
  860. // If there is input data side
  861. for (size_t i = 0; i < node->GetAllInDataAnchors().size(); i++) {
  862. auto in_data_anchor = node->GetInDataAnchor(static_cast<int>(i));
  863. auto pre_out_data_anchor = in_data_anchor->GetPeerOutAnchor();
  864. if (pre_out_data_anchor != nullptr) {
  865. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_data_anchor, in_data_anchor) == GRAPH_SUCCESS,
  866. return GRAPH_FAILED, "remove edge failed");
  867. GE_IF_BOOL_EXEC(pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANT ||
  868. pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANTOP,
  869. continue);
  870. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  871. for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  872. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS,
  873. return GRAPH_FAILED, "remove edge failed");
  874. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS,
  875. return GRAPH_FAILED, "add edge failed");
  876. }
  877. for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) {
  878. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  879. return GRAPH_FAILED, "remove edge failed");
  880. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  881. return GRAPH_FAILED, "add edge failed");
  882. }
  883. }
  884. auto out_ctrl_anchor = node->GetOutControlAnchor();
  885. GE_CHECK_NOTNULL(out_ctrl_anchor);
  886. auto pre_out_ctrl_anchor = pre_out_data_anchor->GetOwnerNode()->GetOutControlAnchor();
  887. GE_CHECK_NOTNULL(pre_out_ctrl_anchor);
  888. for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
  889. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  890. return GRAPH_FAILED, "remove edge failed");
  891. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  892. return GRAPH_FAILED, "add edge failed");
  893. }
  894. }
  895. }
  896. // If there is an input control side
  897. auto in_ctrl_anchor = node->GetInControlAnchor();
  898. GE_CHECK_NOTNULL(in_ctrl_anchor);
  899. for (const auto &pre_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) {
  900. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_ctrl_anchor, in_ctrl_anchor) == GRAPH_SUCCESS, return GRAPH_FAILED,
  901. "remove edge failed");
  902. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  903. for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) {
  904. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  905. return GRAPH_FAILED, "remove edge failed");
  906. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  907. return GRAPH_FAILED, "add edge failed");
  908. }
  909. }
  910. auto out_ctrl_anchor = node->GetOutControlAnchor();
  911. if (out_ctrl_anchor != nullptr) {
  912. for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
  913. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  914. return GRAPH_FAILED, "remove edge failed");
  915. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  916. return GRAPH_FAILED, "add edge failed");
  917. }
  918. }
  919. }
  920. for (const auto &out_peer_data_anchor : in_ctrl_anchor->GetPeerOutDataAnchors()) {
  921. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_peer_data_anchor, in_ctrl_anchor) == GRAPH_SUCCESS, return GRAPH_FAILED,
  922. "remove edge failed");
  923. for (const auto &next_node : next_nodes) {
  924. auto next_in_control_anchor = next_node->GetInControlAnchor();
  925. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(out_peer_data_anchor, next_in_control_anchor) == GRAPH_SUCCESS,
  926. return GRAPH_FAILED, "add edge failed");
  927. }
  928. }
  929. return RemoveExtraOutEdge(node);
  930. }
  931. graphStatus ComputeGraph::RemoveExtraOutEdge(const NodePtr &node) {
  932. GE_CHECK_NOTNULL(node);
  933. // Remove redundant output edges
  934. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  935. for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  936. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS,
  937. return GRAPH_FAILED, "remove edge failed");
  938. }
  939. for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) {
  940. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  941. return GRAPH_FAILED, "remove edge failed");
  942. }
  943. }
  944. auto out_ctrl_anchor = node->GetOutControlAnchor();
  945. if (out_ctrl_anchor != nullptr) {
  946. for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
  947. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  948. return GRAPH_FAILED, "remove edge failed");
  949. }
  950. }
  951. return GRAPH_SUCCESS;
  952. }
  953. graphStatus ComputeGraph::Verify() {
  954. for (const auto &node_ptr : GetAllNodes()) {
  955. GE_CHECK_NOTNULL(node_ptr);
  956. GE_CHECK_NOTNULL(node_ptr->GetOpDesc());
  957. GE_CHK_BOOL_EXEC(node_ptr->GetOpDesc()->CommonVerify() == GRAPH_SUCCESS, return GRAPH_FAILED,
  958. "Verifying %s failed.", node_ptr->GetName().c_str());
  959. }
  960. return GRAPH_SUCCESS;
  961. }
  962. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferOriginFormat() {
  963. return ge::FormatRefiner::InferOrigineFormat(shared_from_this());
  964. }
  965. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferShapeInNeed() {
  966. GE_CHK_BOOL_ONLY_LOG(TopologicalSorting() == GRAPH_SUCCESS, "Verifying failed.");
  967. for (const auto &node_ptr : GetAllNodes()) {
  968. GE_CHECK_NOTNULL(node_ptr);
  969. auto op_desc = node_ptr->GetOpDesc();
  970. bool is_need_infer = false;
  971. (void)ge::AttrUtils::GetBool(op_desc, NEED_INFER, is_need_infer);
  972. if (is_need_infer) {
  973. GE_CHK_BOOL_EXEC(node_ptr->Verify() == GRAPH_SUCCESS, return GRAPH_FAILED, "Verifying %s failed.",
  974. node_ptr->GetName().c_str());
  975. graphStatus status = node_ptr->InferShapeAndType();
  976. GE_CHK_BOOL_EXEC_INFO(node_ptr->GetType() == kDataType || GRAPH_PARAM_INVALID != status, break,
  977. "Op %s does not have the IMPLEMT_INFERFUNC definition,"
  978. " and subsequent operators no longer perform shape inference.",
  979. node_ptr->GetName().c_str());
  980. GE_CHK_BOOL_EXEC(status == GRAPH_SUCCESS, return GRAPH_FAILED, "Inferring %s failed.",
  981. node_ptr->GetName().c_str());
  982. for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
  983. GE_CHECK_NOTNULL(out_anchor->GetOwnerNode()->GetOpDesc());
  984. auto output_tensor = out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx());
  985. ge::TensorUtils::SetRealDimCnt(output_tensor, output_tensor.GetShape().GetDims().size());
  986. (void)out_anchor->GetOwnerNode()->GetOpDesc()->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor);
  987. for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
  988. (void)peer_anchor->GetOwnerNode()->GetOpDesc()->UpdateInputDesc(peer_anchor->GetIdx(), output_tensor);
  989. }
  990. }
  991. }
  992. }
  993. return GRAPH_SUCCESS;
  994. }
  995. ProtoAttrMapHelper ComputeGraph::MutableAttrMap() { return attrs_; }
  996. ConstProtoAttrMapHelper ComputeGraph::GetAttrMap() const {
  997. return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg());
  998. }
  999. const std::map<OperatorImplPtr, NodePtr> &ComputeGraph::GetAllNodesInfo() const { return all_nodes_infos_; }
  1000. void ComputeGraph::SetUserDefOutput(const std::string &output_name) {
  1001. if (output_name.empty()) {
  1002. return;
  1003. }
  1004. vector<string> nodes = StringUtils::Split(output_name, ';');
  1005. for (string node : nodes) {
  1006. vector<string> item = StringUtils::Split(node, ':');
  1007. if (item.size() != OUTPUT_PARAM_SIZE) {
  1008. GELOGW("invalid output param!input:%s", output_name.c_str());
  1009. continue;
  1010. }
  1011. int32_t index;
  1012. try {
  1013. index = stoi(StringUtils::Trim(item[1]));
  1014. } catch (const std::out_of_range &) {
  1015. GELOGW("outputname cause out of range execption!output_name:%s", output_name.c_str());
  1016. continue;
  1017. } catch (const std::invalid_argument &) {
  1018. GELOGW("outputname cause invalid argument!output_name:%s", output_name.c_str());
  1019. continue;
  1020. } catch (...) {
  1021. GELOGW("stoi fail! output_name:%s", output_name.c_str());
  1022. continue;
  1023. }
  1024. auto iter = out_nodes_map_.find(item[0]);
  1025. if (iter == out_nodes_map_.end()) {
  1026. out_nodes_map_[item[0]] = std::vector<int32_t>(1, index);
  1027. } else {
  1028. auto idx_iter = std::find(iter->second.begin(), iter->second.end(), index);
  1029. if (idx_iter == iter->second.end()) {
  1030. iter->second.push_back(index);
  1031. }
  1032. }
  1033. }
  1034. }
  1035. const std::string ComputeGraph::GetOutput() {
  1036. static const int resultDefaultSize = 2048;
  1037. string result;
  1038. result.reserve(resultDefaultSize);
  1039. auto iter = out_nodes_map_.begin();
  1040. while (iter != out_nodes_map_.end()) {
  1041. auto idxes = iter->second;
  1042. for (auto idx : idxes) {
  1043. (void)result.append(iter->first).append(":").append(std::to_string(idx)).append(";");
  1044. }
  1045. ++iter;
  1046. }
  1047. return result.substr(0, result.length() - 1);
  1048. }
  1049. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户无需感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：