You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

compute_graph.cc 52 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/compute_graph.h"
  17. #include <deque>
  18. #include "./format_refiner.h"
  19. #include "./ge_context.h"
  20. #include "debug/ge_attr_define.h"
  21. #include "debug/ge_log.h"
  22. #include "debug/ge_op_types.h"
  23. #include "debug/ge_util.h"
  24. #include "framework/common/debug/ge_log.h"
  25. #include "ge/ge_api_types.h"
  26. #include "graph/shape_refiner.h"
  27. #include "proto/ge_ir.pb.h"
  28. #include "utils/ge_ir_utils.h"
  29. #include "utils/graph_utils.h"
  30. #include "utils/node_utils.h"
  31. #include "utils/op_desc_utils.h"
  32. #include "utils/string_utils.h"
  33. #include "utils/tensor_utils.h"
  34. namespace ge {
  35. namespace {
  36. const size_t OUTPUT_PARAM_SIZE = 2;
  37. const std::string alias_name_attr = "_aliasName";
  38. bool IsUseBFS() {
  39. string run_mode;
  40. const int base = 10;
  41. if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) {
  42. if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) >= TRAIN) {
  43. return true;
  44. }
  45. } else {
  46. GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default.");
  47. }
  48. return false;
  49. }
  50. } // namespace
  51. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name)
  52. : name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) {
  53. attrs_.InitDefault();
  54. }
  55. ComputeGraph::~ComputeGraph() {}
  56. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string ComputeGraph::GetName() const { return name_; }
  57. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetName(const string &name) { name_ = name; }
  58. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesSize() const {
  59. return GetAllNodes().size();
  60. }
  61. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetAllNodes() const {
  62. std::vector<std::shared_ptr<ComputeGraph>> subgraphs;
  63. return AllGraphNodes(subgraphs);
  64. }
  65. ComputeGraph::Vistor<NodePtr> ComputeGraph::AllGraphNodes(std::vector<std::shared_ptr<ComputeGraph>> &subgraphs) const {
  66. std::vector<NodePtr> all_nodes;
  67. std::deque<NodePtr> candidates;
  68. candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end());
  69. while (!candidates.empty()) {
  70. NodePtr node = candidates.front();
  71. all_nodes.emplace_back(node);
  72. candidates.pop_front();
  73. OpDescPtr op_desc = node->GetOpDesc();
  74. if (op_desc == nullptr) {
  75. continue;
  76. }
  77. const auto &subgraph_names = op_desc->GetSubgraphInstanceNames();
  78. for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) {
  79. auto subgraph = GetSubgraph(*name_iter);
  80. if (subgraph != nullptr) {
  81. subgraphs.emplace_back(subgraph);
  82. candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end());
  83. }
  84. }
  85. }
  86. return Vistor<NodePtr>(shared_from_this(), all_nodes);
  87. }
  88. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetNodes(
  89. bool is_unknown_shape) const {
  90. if (is_unknown_shape) {
  91. return GetDirectNode();
  92. } else {
  93. return GetAllNodes();
  94. }
  95. }
  96. size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); }
  97. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetDirectNode() const {
  98. return Vistor<NodePtr>(shared_from_this(), nodes_);
  99. }
  100. ComputeGraph::Vistor<NodePtr> ComputeGraph::GetInputNodes() const {
  101. return Vistor<NodePtr>(shared_from_this(), input_nodes_);
  102. }
  103. ComputeGraph::Vistor<NodePtr> ComputeGraph::GetOutputNodes() const {
  104. std::vector<NodePtr> result;
  105. for (auto iter = output_nodes_info_.begin(); iter != output_nodes_info_.end(); ++iter) {
  106. result.push_back(iter->first);
  107. }
  108. return Vistor<NodePtr>(shared_from_this(), result);
  109. }
  110. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::FindNode(const std::string &name) const {
  111. for (const auto &node : nodes_) {
  112. if (node == nullptr) {
  113. continue;
  114. }
  115. if (node->GetName() == name) {
  116. return node;
  117. }
  118. std::vector<string> out_alias_name;
  119. if (AttrUtils::GetListStr(node->GetOpDesc(), alias_name_attr, out_alias_name)) {
  120. for (const auto &alias_name : out_alias_name) {
  121. if (alias_name == name) {
  122. return node;
  123. }
  124. }
  125. }
  126. }
  127. return nullptr;
  128. }
  129. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr
  130. ComputeGraph::FindFirstNodeMatchType(const std::string &name) const {
  131. for (const auto &node : nodes_) {
  132. if (node == nullptr) {
  133. continue;
  134. }
  135. if (node->GetType() == name) {
  136. return node;
  137. }
  138. }
  139. return nullptr;
  140. }
  141. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphAttrsAreEqual(
  142. const ComputeGraph &r_graph) const {
  143. // ProtoMsgOwner <::google::protobuf::Message> is temporarily ignored
  144. if ((this->attrs_.protoMsg_ != nullptr) && (r_graph.attrs_.protoMsg_ != nullptr)) {
  145. const auto &proto_attr_map = *(this->attrs_.protoMsg_);
  146. const auto &r_proto_attr_map = *(r_graph.attrs_.protoMsg_);
  147. // 1.Verify graph's ProtoAttrMap size
  148. if (proto_attr_map.size() != r_proto_attr_map.size()) {
  149. GELOGE(GRAPH_FAILED, "Size of compute graph's ProtoAttrMap verify failed, graph name: %s.",
  150. this->GetName().c_str());
  151. return false;
  152. }
  153. // 2.Verify graph's ProtoAttrMap key, verify values is temporarily not implemented
  154. for (const auto &it : proto_attr_map) {
  155. if (r_proto_attr_map.count(it.first) == 0) {
  156. GELOGE(GRAPH_FAILED, "Key of compute graph's ProtoAttrMap verify failed, graph name: %s key name: %s.",
  157. this->GetName().c_str(), it.first.c_str());
  158. return false;
  159. }
  160. }
  161. return true;
  162. }
  163. return ((this->attrs_.protoMsg_ == nullptr) && (r_graph.attrs_.protoMsg_ == nullptr));
  164. }
  165. /// Since there may be different input nodes
  166. /// chosen by user in the same graph, special judgment is needed
  167. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::VectorInputNodePtrIsEqual(
  168. const std::vector<NodePtr> &left_nodes, const std::vector<NodePtr> &right_nodes) const {
  169. const auto left_nodes_size = left_nodes.size();
  170. const auto right_nodes_size = right_nodes.size();
  171. if (left_nodes_size != right_nodes_size) {
  172. GELOGE(GRAPH_FAILED,
  173. "Check failed with graph input_nodes_: "
  174. "left inputNodes size %zu is different with right inputNodes size %zu .",
  175. left_nodes_size, right_nodes_size);
  176. return false;
  177. }
  178. for (size_t j = 0; j < left_nodes_size; j++) {
  179. if (left_nodes.at(j) == nullptr || right_nodes.at(j) == nullptr) {
  180. GELOGE(GRAPH_FAILED, "left_nodes.at(%zu) or right_nodes.at(%zu) is nullptr", j, j);
  181. return false;
  182. }
  183. const auto &left_input_name = left_nodes.at(j)->GetName();
  184. const auto &right_input_name = right_nodes.at(j)->GetName();
  185. if (left_input_name != right_input_name) {
  186. GELOGE(GRAPH_FAILED,
  187. "Check failed with graph input_nodes_: "
  188. "left inputNode name %s is different with right inputNode name %s at inputNodes index %zu.",
  189. left_input_name.c_str(), right_input_name.c_str(), j);
  190. return false;
  191. }
  192. }
  193. return true;
  194. }
  195. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual(
  196. const ComputeGraph &r_graph) const {
  197. return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.subgraphs_.size()") &&
  198. IsEqual(this->nodes_.size(), r_graph.nodes_.size(), "graph.nodes_.size()") &&
  199. VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) &&
  200. IsEqual(this->name_, r_graph.name_, "graph.name_") &&
  201. IsEqual(this->is_valid_flag_, r_graph.is_valid_flag_, "graph.is_valid_flag_") &&
  202. IsEqual(this->need_iteration_, r_graph.need_iteration_, "graph.need_iteration_") &&
  203. IsEqual(this->params_share_map_, r_graph.params_share_map_, "graph.params_share_map_") &&
  204. IsEqual(this->out_nodes_map_, r_graph.out_nodes_map_, "graph.out_nodes_map_") &&
  205. IsEqual(this->inputs_order_, r_graph.inputs_order_, "graph.inputs_order_") &&
  206. IsEqual(this->output_size_, r_graph.output_size_, "graph.output_size_") &&
  207. IsEqual(this->input_size_, r_graph.input_size_, "graph.input_size_") &&
  208. IsEqual(this->output_nodes_info_, r_graph.output_nodes_info_, "graph.output_nodes_info_"));
  209. }
  210. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::operator==(const ComputeGraph &r_graph) const {
  211. // Firstly: Graph's members equal
  212. if ((!GraphMembersAreEqual(r_graph)) || (!GraphAttrsAreEqual(r_graph))) {
  213. return false;
  214. }
  215. // Secondly: Node equal means the link relationship between node and node itself equal
  216. for (const auto &left_node : nodes_) {
  217. if (left_node == nullptr) {
  218. GELOGE(GRAPH_FAILED, "left_node is nullptr");
  219. return false;
  220. }
  221. const auto &node_name = left_node->GetName();
  222. // After TopologicalSorting, node order can change, so find node by name
  223. const auto &right_node = r_graph.FindNode(node_name);
  224. GE_IF_BOOL_EXEC(right_node == nullptr, GELOGE(GRAPH_FAILED, "right_node is NULL!!!"); return false);
  225. if (!(*right_node == *left_node)) {
  226. GELOGE(GRAPH_FAILED, "Compare graph failed, node name: %s.", node_name.c_str());
  227. return false;
  228. }
  229. }
  230. // Thirdly: Recursively determine whether the sub graphs are equal
  231. for (size_t i = 0; i < this->sub_graph_.size(); i++) {
  232. if (!(*((this->sub_graph_)[i]) == *((r_graph.sub_graph_)[i]))) {
  233. return false;
  234. }
  235. }
  236. return true;
  237. }
  238. NodePtr ComputeGraph::AddNodeFront(NodePtr node) {
  239. if (node == nullptr || node->GetOpDesc() == nullptr) {
  240. GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null.");
  241. return nullptr;
  242. }
  243. node->SetHostNode(is_valid_flag_);
  244. node->GetOpDesc()->SetId(nodes_.size());
  245. if (nodes_.size() > 0 && nodes_[0]->GetType() == DATA) {
  246. (void)nodes_.insert(nodes_.begin() + 1, node);
  247. } else {
  248. (void)nodes_.insert(nodes_.begin(), node);
  249. }
  250. return node;
  251. }
  252. NodePtr ComputeGraph::AddNodeFront(const OpDescPtr &op) {
  253. if (op == nullptr) {
  254. GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null.");
  255. return nullptr;
  256. }
  257. op->SetId(nodes_.size());
  258. NodePtr node_ptr = shared_ptr<Node>(new (std::nothrow) Node(op, shared_from_this()));
  259. GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr);
  260. GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr);
  261. return AddNodeFront(node_ptr);
  262. }
  263. NodePtr ComputeGraph::AddNodeAfter(NodePtr node, const NodePtr &pre_node) {
  264. if (node == nullptr || node->GetOpDesc() == nullptr || pre_node == nullptr) {
  265. GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null.");
  266. return nullptr;
  267. }
  268. node->SetHostNode(is_valid_flag_);
  269. node->GetOpDesc()->SetId(nodes_.size());
  270. auto node_iter = std::find(nodes_.begin(), nodes_.end(), pre_node);
  271. if (node_iter != nodes_.end()) {
  272. nodes_.insert(node_iter + 1, node);
  273. } else {
  274. GELOGE(GRAPH_FAILED, "Cannot find pre_node in nodes_.");
  275. return nullptr;
  276. }
  277. return node;
  278. }
  279. NodePtr ComputeGraph::AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node) {
  280. if (op == nullptr) {
  281. GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null.");
  282. return nullptr;
  283. }
  284. op->SetId(nodes_.size());
  285. NodePtr node_ptr = shared_ptr<Node>(new (std::nothrow) Node(op, shared_from_this()));
  286. GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr);
  287. GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init failed."); return nullptr);
  288. return AddNodeAfter(node_ptr, pre_node);
  289. }
  290. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(NodePtr node) {
  291. if (node == nullptr || node->GetOpDesc() == nullptr) {
  292. GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
  293. return nullptr;
  294. }
  295. node->SetHostNode(is_valid_flag_);
  296. node->GetOpDesc()->SetId((int64_t)GetDirectNodesSize());
  297. nodes_.push_back(node);
  298. return node;
  299. }
  300. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpDescPtr op) {
  301. if (op == nullptr) {
  302. GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null.");
  303. return nullptr;
  304. }
  305. op->SetId(GetDirectNodesSize());
  306. NodePtr node_ptr = shared_ptr<Node>(new (std::nothrow) Node(op, shared_from_this()));
  307. GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr);
  308. GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr);
  309. return AddNode(node_ptr);
  310. }
  311. NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize.
  312. if (op == nullptr) {
  313. GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null.");
  314. return nullptr;
  315. }
  316. op->SetId(id);
  317. NodePtr node = shared_ptr<Node>(new (std::nothrow) Node(op, shared_from_this()));
  318. GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr);
  319. GE_IF_BOOL_EXEC(node->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr);
  320. node->SetHostNode(is_valid_flag_);
  321. nodes_.push_back(node);
  322. return node;
  323. }
  324. NodePtr ComputeGraph::AddInputNode(NodePtr node) {
  325. if (node == nullptr) {
  326. GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
  327. return nullptr;
  328. }
  329. input_nodes_.push_back(node);
  330. if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) {
  331. GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "add node failed");
  332. }
  333. return node;
  334. }
  335. NodePtr ComputeGraph::AddOutputNode(NodePtr node) { return AddOutputNodeByIndex(node, 0); }
  336. NodePtr ComputeGraph::AddOutputNodeByIndex(NodePtr node, int32_t index) {
  337. if (node == nullptr || node->GetOpDesc() == nullptr) {
  338. GELOGE(GRAPH_FAILED, "The node ptr or opdesc should not be null.");
  339. return nullptr;
  340. }
  341. bool already_have = false;
  342. NodePtr result = node;
  343. // [output_nodes_info_ : should not be null]
  344. for (const auto &item : output_nodes_info_) {
  345. if (item.first->GetName() == node->GetName() && item.second == index) {
  346. already_have = true;
  347. result = item.first;
  348. break;
  349. }
  350. }
  351. if (!already_have) {
  352. output_nodes_info_.emplace_back(std::make_pair(node, index));
  353. GELOGI("Push back node name:%s, index:%ld, into output_nodes_info_.", node->GetName().c_str(), index);
  354. }
  355. if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) {
  356. GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "add node failed");
  357. }
  358. return result;
  359. }
  360. graphStatus ComputeGraph::RemoveConstInput(const NodePtr &node) {
  361. GE_CHECK_NOTNULL(node);
  362. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  363. auto out_anchor = in_anchor->GetPeerOutAnchor();
  364. if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) {
  365. continue;
  366. }
  367. if (out_anchor->GetOwnerNode()->GetType() == CONSTANT || out_anchor->GetOwnerNode()->GetType() == CONSTANTOP) {
  368. GE_CHK_BOOL_RET_STATUS(GraphUtils::RemoveEdge(out_anchor, in_anchor) == GRAPH_SUCCESS, GRAPH_FAILED,
  369. "Remove edge from const op failed.");
  370. if (out_anchor->GetOwnerNode()->GetOutNodes().size() == 0) {
  371. GELOGI("Remove const op %s.", out_anchor->GetOwnerNode()->GetName().c_str());
  372. auto iter = find(nodes_.begin(), nodes_.end(), out_anchor->GetOwnerNode());
  373. if (iter != nodes_.end()) {
  374. (void)nodes_.erase(iter);
  375. }
  376. }
  377. }
  378. }
  379. return GRAPH_SUCCESS;
  380. }
  381. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveNode(const NodePtr &node) {
  382. if (node == nullptr) {
  383. GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
  384. return GRAPH_FAILED;
  385. }
  386. // delete const op for this node
  387. (void)RemoveConstInput(node);
  388. // if the node save as input node, delete it
  389. (void)RemoveInputNode(node);
  390. // if the node save as input node, delete it
  391. (void)RemoveOutputNode(node);
  392. if (GRAPH_SUCCESS != IsolateNode(node)) {
  393. GELOGE(GRAPH_FAILED, "Isolate node failed, node name: %s.", node->GetName().c_str());
  394. return GRAPH_FAILED;
  395. }
  396. auto iter = find(nodes_.begin(), nodes_.end(), node);
  397. if (iter != nodes_.end()) {
  398. (void)nodes_.erase(iter);
  399. return GRAPH_SUCCESS;
  400. }
  401. return GRAPH_FAILED;
  402. }
  403. // Used in sub_graph scenes
  404. graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) {
  405. if (node == nullptr) {
  406. GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
  407. return GRAPH_FAILED;
  408. }
  409. auto iter = find(input_nodes_.begin(), input_nodes_.end(), node);
  410. if (iter != input_nodes_.end()) {
  411. (void)input_nodes_.erase(iter);
  412. return GRAPH_SUCCESS;
  413. }
  414. return GRAPH_FAILED;
  415. }
  416. // Used in sub_graph scenes
  417. graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) {
  418. if (node == nullptr) {
  419. GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
  420. return GRAPH_FAILED;
  421. }
  422. auto iter = output_nodes_info_.begin();
  423. bool find_node = false;
  424. // [output_nodes_info_ : should not be null]
  425. while (iter != output_nodes_info_.end()) {
  426. if (node->GetName() == iter->first->GetName()) {
  427. iter = output_nodes_info_.erase(iter);
  428. find_node = true;
  429. } else {
  430. ++iter;
  431. }
  432. }
  433. GE_IF_BOOL_EXEC(find_node == false, return GRAPH_FAILED);
  434. return GRAPH_SUCCESS;
  435. }
  436. std::shared_ptr<ComputeGraph> ComputeGraph::AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph) {
  437. if (sub_graph == nullptr) {
  438. GELOGE(GRAPH_FAILED, "The graph ptr should not be null.");
  439. return nullptr;
  440. }
  441. sub_graph_.push_back(sub_graph);
  442. names_to_subgraph_[sub_graph->GetName()] = sub_graph;
  443. return sub_graph;
  444. }
  445. graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph) {
  446. if (sub_graph == nullptr) {
  447. GELOGE(GRAPH_FAILED, "The graph ptr should not be null.");
  448. return GRAPH_FAILED;
  449. }
  450. names_to_subgraph_.erase(sub_graph->GetName());
  451. auto iter = find(sub_graph_.begin(), sub_graph_.end(), sub_graph);
  452. if (iter != sub_graph_.end()) {
  453. (void)sub_graph_.erase(iter);
  454. return GRAPH_SUCCESS;
  455. } else {
  456. GELOGW("find sub_graph failed");
  457. return GRAPH_SUCCESS;
  458. }
  459. }
  460. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  461. ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph) {
  462. if (subgraph == nullptr) {
  463. GE_LOGE("Try to add a null subgraph, name %s", name.c_str());
  464. return GRAPH_PARAM_INVALID;
  465. }
  466. auto parent_graph = subgraph->GetParentGraph();
  467. if (parent_graph == nullptr) {
  468. GE_LOGE("Try to add subgraph without parent graph, name %s", name.c_str());
  469. return GRAPH_PARAM_INVALID;
  470. }
  471. auto parent_node = subgraph->GetParentNode();
  472. if (parent_node == nullptr) {
  473. GE_LOGE("Try to add a subgraph without parent node, name %s", name.c_str());
  474. return GRAPH_PARAM_INVALID;
  475. }
  476. if (parent_node->GetOwnerComputeGraph() != parent_graph) {
  477. GE_LOGE(
  478. "Try to add a subgraph which parent node's parent graph is not equal to "
  479. "the subgraph's parent graph, subgraph name %s, parent node name %s",
  480. subgraph->GetName().c_str(), parent_graph->GetName().c_str());
  481. return GRAPH_PARAM_INVALID;
  482. }
  483. if (!this->parent_graph_.expired()) {
  484. GELOGW("The subgraphs should only be added to the root graph");
  485. }
  486. if (name != subgraph->GetName()) {
  487. GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str());
  488. }
  489. if (names_to_subgraph_.find(name) != names_to_subgraph_.end()) {
  490. GE_LOGE("The subgraph %s existed", name.c_str());
  491. return GRAPH_PARAM_INVALID;
  492. }
  493. sub_graph_.push_back(subgraph);
  494. names_to_subgraph_[name] = subgraph;
  495. return GRAPH_SUCCESS;
  496. }
  497. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  498. ComputeGraph::AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph) {
  499. if (subgraph == nullptr) {
  500. return GRAPH_PARAM_INVALID;
  501. }
  502. return AddSubgraph(subgraph->GetName(), subgraph);
  503. }
  504. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(const std::string &name) {
  505. auto iter = names_to_subgraph_.find(name);
  506. if (iter == names_to_subgraph_.end()) {
  507. return;
  508. }
  509. for (auto vec_iter = sub_graph_.begin(); vec_iter != sub_graph_.end(); ++vec_iter) {
  510. if (*vec_iter == iter->second) {
  511. sub_graph_.erase(vec_iter);
  512. break;
  513. }
  514. }
  515. names_to_subgraph_.erase(iter);
  516. }
  517. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(
  518. const std::shared_ptr<ComputeGraph> &subgraph) {
  519. if (subgraph != nullptr) {
  520. RemoveSubgraph(subgraph->GetName());
  521. }
  522. }
  523. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr<ComputeGraph> ComputeGraph::GetSubgraph(
  524. const std::string &name) const {
  525. std::shared_ptr<ComputeGraph> parent = parent_graph_.lock();
  526. if (parent == nullptr) {
  527. auto iter = names_to_subgraph_.find(name);
  528. return iter == names_to_subgraph_.end() ? nullptr : iter->second;
  529. } else {
  530. return parent->GetSubgraph(name);
  531. }
  532. }
  533. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector<std::shared_ptr<ComputeGraph>>
  534. ComputeGraph::GetAllSubgraphs() const {
  535. return sub_graph_;
  536. }
  537. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr<ComputeGraph> ComputeGraph::GetParentGraph() {
  538. return parent_graph_.lock();
  539. }
  540. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentGraph(
  541. const shared_ptr<ComputeGraph> &parent) {
  542. parent_graph_ = parent;
  543. }
  544. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr<Node> ComputeGraph::GetParentNode() {
  545. return parent_node_.lock();
  546. }
  547. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode(const shared_ptr<Node> &parent) {
  548. parent_node_ = parent;
  549. }
  550. ///
  551. /// @brief Update input-mapping
  552. /// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input
  553. /// @return graphStatus
  554. ///
  555. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  556. ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping) {
  557. for (auto &input : nodes_) {
  558. if (input->GetType() == DATA) {
  559. uint32_t cur_index = 0;
  560. if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) {
  561. continue;
  562. }
  563. auto iter = input_mapping.find(cur_index);
  564. if (iter == input_mapping.end()) {
  565. continue;
  566. }
  567. if (!ge::AttrUtils::SetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, iter->second)) {
  568. GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed.");
  569. return GRAPH_FAILED;
  570. }
  571. }
  572. }
  573. return GRAPH_SUCCESS;
  574. }
  575. ///
  576. /// @brief Update output-mapping
  577. /// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output
  578. /// @return graphStatus
  579. ///
  580. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  581. ComputeGraph::UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping) {
  582. NodePtr net_output = FindFirstNodeMatchType(NETOUTPUT);
  583. if (net_output == nullptr) {
  584. GE_LOGE("UpdateOutputMapping failed: node type %s not exist in graph.", NETOUTPUT);
  585. return GRAPH_FAILED;
  586. }
  587. OpDescPtr op_desc = net_output->GetOpDesc();
  588. if (op_desc == nullptr) {
  589. GE_LOGE("UpdateOutputMapping failed: op_desc is NULL.");
  590. return GRAPH_FAILED;
  591. }
  592. size_t num = op_desc->GetAllInputsSize();
  593. for (size_t i = 0; i < num; i++) {
  594. GeTensorDesc tensor = op_desc->GetInputDesc(i);
  595. uint32_t cur_index = 0;
  596. if (!ge::AttrUtils::GetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, cur_index)) {
  597. continue;
  598. }
  599. auto iter = output_mapping.find(cur_index);
  600. if (iter == output_mapping.end()) {
  601. continue;
  602. }
  603. if (!ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) {
  604. GE_LOGE("UpdateOutputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed.");
  605. return GRAPH_FAILED;
  606. }
  607. if (op_desc->UpdateInputDesc(i, tensor) != GRAPH_SUCCESS) {
  608. GE_LOGE("UpdateOutputMapping failed: update %u input_tensor failed.", i);
  609. return GRAPH_FAILED;
  610. }
  611. }
  612. return GRAPH_SUCCESS;
  613. }
  614. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() {
  615. std::vector<NodePtr> node_vec = nodes_;
  616. for (const auto &node : GetDirectNode()) {
  617. if (node == nullptr || node->GetOpDesc() == nullptr) {
  618. GELOGW("node or OpDescPtr is nullptr.");
  619. continue;
  620. }
  621. GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should not be null."); return GRAPH_FAILED);
  622. if (node->GetOpDesc()->GetType() == RECV) {
  623. auto iter = find(node_vec.begin(), node_vec.end(), node);
  624. if (iter == node_vec.end()) {
  625. GELOGW("no node found.");
  626. } else {
  627. (void)node_vec.erase(iter);
  628. }
  629. auto dst_iter = find(node_vec.begin(), node_vec.end(), node->GetOutControlNodes().at(0));
  630. (void)node_vec.insert(dst_iter, node);
  631. }
  632. if (node->GetOpDesc()->GetType() == SEND) {
  633. auto iter = find(node_vec.begin(), node_vec.end(), node);
  634. if (iter == node_vec.end()) {
  635. GELOGW("no node found.");
  636. } else {
  637. (void)node_vec.erase(iter);
  638. }
  639. auto src_iter = find(node_vec.begin(), node_vec.end(), node->GetInControlNodes().at(0));
  640. (void)node_vec.insert(src_iter + 1, node);
  641. }
  642. }
  643. nodes_.clear();
  644. for (size_t i = 0; i < node_vec.size(); ++i) {
  645. NodePtr node = node_vec[i];
  646. if (node == nullptr || node->GetOpDesc() == nullptr) {
  647. GELOGW("node or OpDescPtr is nullptr.");
  648. } else {
  649. node->GetOpDesc()->SetId((int64_t)i);
  650. nodes_.push_back(node);
  651. }
  652. }
  653. return GRAPH_SUCCESS;
  654. }
  655. graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec,
  656. std::map<NodePtr, uint32_t> &map_in_edge_num,
  657. std::vector<NodePtr> &stack) {
  658. GELOGI("Runing_Dfs_Sort: %s", name_.c_str());
  659. // Record the number of non data nodes but no input nodes
  660. GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed");
  661. // Only data nodes here
  662. while (!stack.empty()) {
  663. NodePtr node = stack.back();
  664. stack.pop_back();
  665. node_vec.push_back(node);
  666. GE_CHECK_NOTNULL(node->GetOpDesc());
  667. GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str());
  668. for (const auto &anchor : node->GetAllOutDataAnchors()) {
  669. GE_CHECK_NOTNULL(anchor);
  670. for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) {
  671. GE_CHECK_NOTNULL(peer_in_anchor);
  672. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  673. if (iter != map_in_edge_num.end() && --iter->second == 0) {
  674. stack.push_back(peer_in_anchor->GetOwnerNode());
  675. }
  676. }
  677. for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) {
  678. GE_CHECK_NOTNULL(peer_in_anchor);
  679. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  680. if (iter != map_in_edge_num.end() && --iter->second == 0) {
  681. stack.push_back(peer_in_anchor->GetOwnerNode());
  682. }
  683. }
  684. }
  685. GE_IF_BOOL_EXEC(
  686. node->GetOutControlAnchor() != nullptr, for (AnchorPtr peer_in_anchor
  687. : node->GetOutControlAnchor()->GetPeerAnchors()) {
  688. GE_CHECK_NOTNULL(peer_in_anchor);
  689. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  690. if (iter != map_in_edge_num.end() && --iter->second == 0) {
  691. stack.push_back(peer_in_anchor->GetOwnerNode());
  692. }
  693. })
  694. }
  695. return GRAPH_SUCCESS;
  696. }
  697. graphStatus ComputeGraph::BFSTopologicalSorting(std::vector<NodePtr> &node_vec,
  698. std::map<NodePtr, uint32_t> &map_in_edge_num,
  699. std::deque<NodePtr> &stack) {
  700. GELOGI("Runing_Bfs_Sort: %s", name_.c_str());
  701. std::vector<NodePtr> stack_input;
  702. std::map<string, NodePtr> breadth_node_map;
  703. // Record the number of non data nodes but no input nodes
  704. GE_CHK_BOOL_EXEC(SortNodes(stack_input, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed");
  705. // Only data nodes here
  706. while (!stack_input.empty() || !stack.empty()) {
  707. NodePtr node = nullptr;
  708. if (!stack.empty()) {
  709. node = stack.back();
  710. stack.pop_back();
  711. } else {
  712. node = stack_input.back();
  713. stack_input.pop_back();
  714. }
  715. node_vec.push_back(node);
  716. GE_CHECK_NOTNULL(node->GetOpDesc());
  717. GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str());
  718. CollectBreadthOutNode(node, map_in_edge_num, breadth_node_map);
  719. for (const auto &name_node : breadth_node_map) {
  720. (void)stack.push_front(name_node.second);
  721. }
  722. breadth_node_map.clear();
  723. }
  724. return GRAPH_SUCCESS;
  725. }
  726. graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num,
  727. std::map<string, NodePtr> &breadth_node_map) {
  728. for (const auto &anchor : node->GetAllOutDataAnchors()) {
  729. for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) {
  730. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  731. if (iter != map_in_edge_num.end() && 0 == --iter->second) {
  732. (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode());
  733. }
  734. }
  735. for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) {
  736. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  737. if (iter != map_in_edge_num.end() && 0 == --iter->second) {
  738. (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode());
  739. }
  740. }
  741. }
  742. if (node->GetOutControlAnchor() != nullptr) {
  743. for (AnchorPtr peer_in_anchor : node->GetOutControlAnchor()->GetPeerAnchors()) {
  744. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  745. if (iter != map_in_edge_num.end() && 0 == --iter->second) {
  746. (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode());
  747. }
  748. }
  749. }
  750. return GRAPH_SUCCESS;
  751. }
  752. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() {
  753. auto ret = TopologicalSortingGraph();
  754. if (ret != SUCCESS) {
  755. GraphUtils::DumpGEGraphToOnnx(*this, "black_box");
  756. GELOGE(ret, "Graph [%s] topological sort failed, saved to file black_box", name_.c_str());
  757. return ret;
  758. }
  759. if (sub_graph_.empty()) {
  760. return SUCCESS;
  761. }
  762. // partition sub graph
  763. for (const auto &sub_graph : sub_graph_) {
  764. ret = sub_graph->TopologicalSortingGraph();
  765. if (ret != SUCCESS) {
  766. GELOGE(ret, "Sub graph topological sort Failed");
  767. return ret;
  768. }
  769. }
  770. std::vector<std::shared_ptr<ComputeGraph>> subgraphs;
  771. auto nodes = AllGraphNodes(subgraphs);
  772. for (size_t i = 0; i < nodes.size(); i++) {
  773. NodePtr node = nodes.at(i); // [node: should not be null]
  774. node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null]
  775. }
  776. if (sub_graph_.size() != subgraphs.size()) { // Graph Partition use subgraph, Keep original
  777. GELOGW("Keep original subgraph for graph size %zu not equal %zu.", sub_graph_.size(), subgraphs.size());
  778. return SUCCESS;
  779. }
  780. sub_graph_.swap(subgraphs);
  781. return SUCCESS;
  782. }
  783. graphStatus ComputeGraph::TopologicalSortingGraph() {
  784. std::vector<NodePtr> node_vec;
  785. std::map<NodePtr, uint32_t> map_in_edge_num;
  786. bool use_BFS = IsUseBFS();
  787. if (use_BFS) {
  788. std::deque<NodePtr> stack;
  789. if (BFSTopologicalSorting(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) {
  790. return GRAPH_FAILED;
  791. }
  792. } else {
  793. std::vector<NodePtr> stack;
  794. if (DFSTopologicalSorting(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) {
  795. return GRAPH_FAILED;
  796. }
  797. }
  798. // If they are not equal, there is a closed loop
  799. if (node_vec.size() != nodes_.size()) {
  800. std::set<Node *> itered_nodes_set;
  801. for (auto &node : node_vec) {
  802. itered_nodes_set.insert(node.get());
  803. }
  804. GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", nodes_.size(),
  805. node_vec.size());
  806. for (auto &node : nodes_) {
  807. if (itered_nodes_set.count(node.get()) == 0) {
  808. GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str());
  809. }
  810. }
  811. return GRAPH_FAILED;
  812. }
  813. nodes_.clear();
  814. for (size_t i = 0; i < node_vec.size(); i++) {
  815. NodePtr node = node_vec[i]; // [node: should not be null]
  816. node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null]
  817. nodes_.push_back(node);
  818. }
  819. is_valid_flag_ = true;
  820. return GRAPH_SUCCESS;
  821. }
  822. graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &map_in_edge_num) {
  823. // Record the number of non data nodes but no input nodes
  824. uint32_t spec_node_size = 0;
  825. bool verify_isolated = false;
  826. string run_mode;
  827. const int base = 10;
  828. // Need verify isolated point in PREDICTION mode.
  829. if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) {
  830. if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) < TRAIN) {
  831. verify_isolated = true;
  832. }
  833. }
  834. for (const auto &node : GetDirectNode()) {
  835. GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
  836. map_in_edge_num[node] = static_cast<uint32_t>(GetInEdgeSize(node));
  837. if (map_in_edge_num[node] == 0) {
  838. if ((node->GetOpDesc()->GetType() != DATA) && (node->GetOpDesc()->GetType() != AIPPDATA) &&
  839. (node->GetOpDesc()->GetType() != INPUT_TYPE) && (node->GetOpDesc()->GetType() != ANN_DATA)) {
  840. // At present, can only judge the isolated point without input and output.
  841. // It is impossible to judge the situation with multiple output nodes.
  842. if (verify_isolated && GetOutEdgeSize(node) == 0) {
  843. GELOGE(GRAPH_FAILED, "May has isolated nodes in graph, node name: %s.", node->GetName().c_str());
  844. return GRAPH_FAILED;
  845. }
  846. (void)stack.insert(stack.begin(), node);
  847. spec_node_size++;
  848. continue;
  849. }
  850. // Need to insert the data nodes in reverse order
  851. (void)stack.insert(stack.begin() + spec_node_size, node);
  852. }
  853. }
  854. /// Make sure the inputs order matches with user-designated
  855. /// 1. Get the index of two input nodes in the user-inputs-order(inputs_order_)
  856. /// 2. Compare two indices, if not match, swap the positions of two inputs
  857. /// *: Remind: stack is reverse-order
  858. for (size_t i = 0; i < stack.size(); ++i) {
  859. // If not found in 'inputs_order_', skip it
  860. auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName());
  861. GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue);
  862. auto inx_i = it_i - inputs_order_.begin();
  863. for (size_t j = i + 1; j < stack.size(); ++j) {
  864. // If not found in 'inputs_order_', skip it
  865. auto it_j = std::find(inputs_order_.begin(), inputs_order_.end(), stack[j]->GetName());
  866. GE_IF_BOOL_EXEC(it_j == inputs_order_.end(), continue);
  867. // Compare index, swap them if it should be
  868. auto inx_j = it_j - inputs_order_.begin();
  869. GE_IF_BOOL_EXEC(inx_i < inx_j, std::swap(stack[i], stack[j]));
  870. }
  871. }
  872. return GRAPH_SUCCESS;
  873. }
  874. size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) {
  875. size_t in_edge_size = 0;
  876. if (node == nullptr) {
  877. return in_edge_size;
  878. }
  879. for (const auto &anchor : node->GetAllInDataAnchors()) {
  880. in_edge_size = in_edge_size + anchor->GetPeerAnchorsSize();
  881. // Break flow control data loop.
  882. OutDataAnchorPtr out_anchor = anchor->GetPeerOutAnchor();
  883. if ((out_anchor != nullptr) && (out_anchor->GetOwnerNode() != nullptr)) {
  884. NodePtr out_node = out_anchor->GetOwnerNode();
  885. if (out_node == nullptr) {
  886. GELOGW("out node is nullptr");
  887. continue;
  888. }
  889. if ((out_node->GetType() == NEXTITERATION) || (out_node->GetType() == REFNEXTITERATION)) {
  890. GE_IF_BOOL_EXEC(in_edge_size == 0, GELOGE(GRAPH_FAILED, "If [in_edge_size = 0], the result will be reversed");
  891. return in_edge_size);
  892. in_edge_size -= 1;
  893. }
  894. }
  895. }
  896. if (node->GetInControlAnchor() != nullptr) {
  897. in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchorsSize();
  898. }
  899. return in_edge_size;
  900. }
  901. size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) {
  902. size_t out_edge_size = 0;
  903. if (node == nullptr) {
  904. return out_edge_size;
  905. }
  906. // Break flow control data loop.
  907. if ((node->GetType() != NEXTITERATION) && (node->GetType() != REFNEXTITERATION)) {
  908. for (const auto &anchor : node->GetAllOutDataAnchors()) {
  909. if (anchor != nullptr) {
  910. out_edge_size = out_edge_size + anchor->GetPeerAnchors().size();
  911. }
  912. }
  913. }
  914. if (node->GetOutControlAnchor() != nullptr) {
  915. if (out_edge_size > (UINT64_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) {
  916. return 0;
  917. }
  918. out_edge_size = out_edge_size + node->GetOutControlAnchor()->GetPeerAnchors().size();
  919. }
  920. return out_edge_size;
  921. }
  922. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::IsValid() const { return is_valid_flag_; }
  923. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const {
  924. GELOGI("graph name = %s.", GetName().c_str());
  925. for (const auto &node : GetAllNodes()) {
  926. GELOGI("node name = %s.", node->GetName().c_str());
  927. for (const auto &anchor : node->GetAllOutDataAnchors()) {
  928. for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) {
  929. GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr,
  930. GELOGI("node name = %s, out data node name = %s.", node->GetName().c_str(),
  931. peer_in_anchor->GetOwnerNode()->GetName().c_str()));
  932. }
  933. for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) {
  934. GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr,
  935. GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(),
  936. peer_in_anchor->GetOwnerNode()->GetName().c_str()));
  937. }
  938. }
  939. auto out_control_anchor = node->GetOutControlAnchor();
  940. if (out_control_anchor != nullptr) {
  941. for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) {
  942. GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr,
  943. GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(),
  944. peer_in_anchor->GetOwnerNode()->GetName().c_str()));
  945. }
  946. for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) {
  947. GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr,
  948. GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(),
  949. peer_in_anchor->GetOwnerNode()->GetName().c_str()));
  950. }
  951. }
  952. }
  953. }
  954. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Swap(ComputeGraph &graph) {
  955. this->AttrHolder::Swap(graph);
  956. origGraph_.swap(graph.origGraph_);
  957. name_.swap(graph.name_);
  958. std::swap(graph_id_, graph.graph_id_);
  959. attrs_.Swap(graph.attrs_);
  960. nodes_.swap(graph.nodes_);
  961. all_nodes_infos_.swap(graph.all_nodes_infos_);
  962. target_nodes_info_.swap(graph.target_nodes_info_);
  963. input_nodes_.swap(graph.input_nodes_);
  964. inputs_order_.swap(graph.inputs_order_);
  965. std::swap(input_size_, graph.input_size_);
  966. out_nodes_map_.swap(graph.out_nodes_map_);
  967. std::swap(output_size_, graph.output_size_);
  968. output_nodes_info_.swap(graph.output_nodes_info_);
  969. sub_graph_.swap(graph.sub_graph_);
  970. names_to_subgraph_.swap(graph.names_to_subgraph_);
  971. parent_graph_.swap(graph.parent_graph_);
  972. parent_node_.swap(graph.parent_node_);
  973. // the members followed should not in the ComputeGraph class
  974. std::swap(is_valid_flag_, graph.is_valid_flag_);
  975. std::swap(is_summary_graph_, graph.is_summary_graph_);
  976. std::swap(need_iteration_, graph.need_iteration_);
  977. params_share_map_.swap(graph.params_share_map_);
  978. op_name_map_.swap(graph.op_name_map_);
  979. std::swap(session_id_, graph.session_id_);
  980. std::swap(data_format_, graph.data_format_);
  981. std::swap(is_unknown_shape_graph_, graph.is_unknown_shape_graph_);
  982. // Update Node owner.
  983. SetNodesOwner();
  984. graph.SetNodesOwner();
  985. }
  986. void ComputeGraph::SetNodesOwner() {
  987. for (const auto &node : nodes_) {
  988. if (node == nullptr) {
  989. continue;
  990. }
  991. node->SetOwnerComputeGraph(shared_from_this());
  992. }
  993. }
  994. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::IsolateNode(const NodePtr &node) {
  995. GE_CHECK_NOTNULL(node);
  996. auto next_nodes = node->GetOutAllNodes();
  997. // If there is input data side
  998. for (size_t i = 0; i < node->GetAllInDataAnchors().size(); i++) {
  999. auto in_data_anchor = node->GetInDataAnchor(static_cast<int>(i));
  1000. auto pre_out_data_anchor = in_data_anchor->GetPeerOutAnchor();
  1001. if (pre_out_data_anchor != nullptr) {
  1002. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_data_anchor, in_data_anchor) == GRAPH_SUCCESS,
  1003. return GRAPH_FAILED, "remove edge failed");
  1004. GE_IF_BOOL_EXEC(pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANT ||
  1005. pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANTOP,
  1006. continue);
  1007. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  1008. for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  1009. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS,
  1010. return GRAPH_FAILED, "remove edge failed");
  1011. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS,
  1012. return GRAPH_FAILED, "add edge failed");
  1013. }
  1014. for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) {
  1015. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1016. return GRAPH_FAILED, "remove edge failed");
  1017. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1018. return GRAPH_FAILED, "add edge failed");
  1019. }
  1020. }
  1021. auto out_ctrl_anchor = node->GetOutControlAnchor();
  1022. GE_CHECK_NOTNULL(out_ctrl_anchor);
  1023. auto pre_out_ctrl_anchor = pre_out_data_anchor->GetOwnerNode()->GetOutControlAnchor();
  1024. GE_CHECK_NOTNULL(pre_out_ctrl_anchor);
  1025. for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
  1026. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1027. return GRAPH_FAILED, "remove edge failed");
  1028. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1029. return GRAPH_FAILED, "add edge failed");
  1030. }
  1031. }
  1032. }
  1033. // If there is an input control side
  1034. auto in_ctrl_anchor = node->GetInControlAnchor();
  1035. GE_CHECK_NOTNULL(in_ctrl_anchor);
  1036. for (const auto &pre_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) {
  1037. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_ctrl_anchor, in_ctrl_anchor) == GRAPH_SUCCESS, return GRAPH_FAILED,
  1038. "remove edge failed");
  1039. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  1040. for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) {
  1041. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1042. return GRAPH_FAILED, "remove edge failed");
  1043. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1044. return GRAPH_FAILED, "add edge failed");
  1045. }
  1046. }
  1047. auto out_ctrl_anchor = node->GetOutControlAnchor();
  1048. if (out_ctrl_anchor != nullptr) {
  1049. for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
  1050. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1051. return GRAPH_FAILED, "remove edge failed");
  1052. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1053. return GRAPH_FAILED, "add edge failed");
  1054. }
  1055. }
  1056. }
  1057. for (const auto &out_peer_data_anchor : in_ctrl_anchor->GetPeerOutDataAnchors()) {
  1058. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_peer_data_anchor, in_ctrl_anchor) == GRAPH_SUCCESS, return GRAPH_FAILED,
  1059. "remove edge failed");
  1060. for (const auto &next_node : next_nodes) {
  1061. auto next_in_control_anchor = next_node->GetInControlAnchor();
  1062. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(out_peer_data_anchor, next_in_control_anchor) == GRAPH_SUCCESS,
  1063. return GRAPH_FAILED, "add edge failed");
  1064. }
  1065. }
  1066. return RemoveExtraOutEdge(node);
  1067. }
  1068. graphStatus ComputeGraph::RemoveExtraOutEdge(const NodePtr &node) {
  1069. GE_CHECK_NOTNULL(node);
  1070. // Remove redundant output edges
  1071. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  1072. for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  1073. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS,
  1074. return GRAPH_FAILED, "remove edge failed");
  1075. }
  1076. for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) {
  1077. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1078. return GRAPH_FAILED, "remove edge failed");
  1079. }
  1080. }
  1081. auto out_ctrl_anchor = node->GetOutControlAnchor();
  1082. if (out_ctrl_anchor != nullptr) {
  1083. for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
  1084. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1085. return GRAPH_FAILED, "remove edge failed");
  1086. }
  1087. }
  1088. return GRAPH_SUCCESS;
  1089. }
  1090. graphStatus ComputeGraph::Verify() {
  1091. bool is_unknown_graph = GetGraphUnknownFlag();
  1092. for (const auto &node_ptr : GetAllNodes()) {
  1093. GE_CHECK_NOTNULL(node_ptr);
  1094. GE_CHECK_NOTNULL(node_ptr->GetOpDesc());
  1095. GE_IF_BOOL_EXEC(is_unknown_graph, continue);
  1096. GE_CHK_BOOL_EXEC(node_ptr->GetOpDesc()->CommonVerify() == GRAPH_SUCCESS, return GRAPH_FAILED,
  1097. "Verifying %s failed.", node_ptr->GetName().c_str());
  1098. }
  1099. return GRAPH_SUCCESS;
  1100. }
  1101. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferOriginFormat() {
  1102. return ge::FormatRefiner::InferOrigineFormat(shared_from_this());
  1103. }
  1104. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferShapeInNeed() {
  1105. GE_CHK_BOOL_ONLY_LOG(TopologicalSorting() == GRAPH_SUCCESS, "Verifying failed.");
  1106. for (const auto &node_ptr : GetAllNodes()) {
  1107. GE_CHECK_NOTNULL(node_ptr);
  1108. auto op_desc = node_ptr->GetOpDesc();
  1109. bool is_need_infer = false;
  1110. (void)ge::AttrUtils::GetBool(op_desc, NEED_INFER, is_need_infer);
  1111. if (is_need_infer) {
  1112. GE_CHK_BOOL_EXEC(node_ptr->Verify() == GRAPH_SUCCESS, return GRAPH_FAILED, "Verifying %s failed.",
  1113. node_ptr->GetName().c_str());
  1114. graphStatus status = node_ptr->InferShapeAndType();
  1115. GE_CHK_BOOL_EXEC_INFO(node_ptr->GetType() == DATA || GRAPH_PARAM_INVALID != status, break,
  1116. "Op %s does not have the IMPLEMT_INFERFUNC definition,"
  1117. " and subsequent operators no longer perform shape inference.",
  1118. node_ptr->GetName().c_str());
  1119. GE_CHK_BOOL_EXEC(status == GRAPH_SUCCESS, return GRAPH_FAILED, "Inferring %s failed.",
  1120. node_ptr->GetName().c_str());
  1121. for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
  1122. GE_CHECK_NOTNULL(out_anchor->GetOwnerNode()->GetOpDesc());
  1123. auto output_tensor = out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx());
  1124. ge::TensorUtils::SetRealDimCnt(output_tensor, output_tensor.GetShape().GetDims().size());
  1125. (void)out_anchor->GetOwnerNode()->GetOpDesc()->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor);
  1126. for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
  1127. (void)peer_anchor->GetOwnerNode()->GetOpDesc()->UpdateInputDesc(peer_anchor->GetIdx(), output_tensor);
  1128. }
  1129. }
  1130. }
  1131. }
  1132. return GRAPH_SUCCESS;
  1133. }
  1134. ProtoAttrMapHelper ComputeGraph::MutableAttrMap() { return attrs_; }
  1135. ConstProtoAttrMapHelper ComputeGraph::GetAttrMap() const {
  1136. return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg());
  1137. }
  1138. const std::map<OperatorImplPtr, NodePtr> &ComputeGraph::GetAllNodesInfo() const { return all_nodes_infos_; }
  1139. void ComputeGraph::SetUserDefOutput(const std::string &output_name) {
  1140. if (output_name.empty()) {
  1141. return;
  1142. }
  1143. vector<string> nodes = StringUtils::Split(output_name, ';');
  1144. for (string node : nodes) {
  1145. vector<string> item = StringUtils::Split(node, ':');
  1146. if (item.size() != OUTPUT_PARAM_SIZE) {
  1147. GELOGW("invalid output param!input:%s", output_name.c_str());
  1148. continue;
  1149. }
  1150. int32_t index;
  1151. try {
  1152. index = stoi(StringUtils::Trim(item[1]));
  1153. } catch (const std::out_of_range &) {
  1154. GELOGW("outputname cause out of range execption!output_name:%s", output_name.c_str());
  1155. continue;
  1156. } catch (const std::invalid_argument &) {
  1157. GELOGW("outputname cause invalid argument!output_name:%s", output_name.c_str());
  1158. continue;
  1159. } catch (...) {
  1160. GELOGW("stoi fail! output_name:%s", output_name.c_str());
  1161. continue;
  1162. }
  1163. auto iter = out_nodes_map_.find(item[0]);
  1164. if (iter == out_nodes_map_.end()) {
  1165. out_nodes_map_[item[0]] = std::vector<int32_t>(1, index);
  1166. } else {
  1167. auto idx_iter = std::find(iter->second.begin(), iter->second.end(), index);
  1168. if (idx_iter == iter->second.end()) {
  1169. iter->second.push_back(index);
  1170. }
  1171. }
  1172. }
  1173. }
  1174. const std::string ComputeGraph::GetOutput() {
  1175. static const int resultDefaultSize = 2048;
  1176. string result;
  1177. result.reserve(resultDefaultSize);
  1178. auto iter = out_nodes_map_.begin();
  1179. while (iter != out_nodes_map_.end()) {
  1180. auto idxes = iter->second;
  1181. for (auto idx : idxes) {
  1182. (void)result.append(iter->first).append(":").append(std::to_string(idx)).append(";");
  1183. }
  1184. ++iter;
  1185. }
  1186. return result.substr(0, result.length() - 1);
  1187. }
  1188. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：