You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

compute_graph.cc 51 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/compute_graph.h"
  17. #include <deque>
  18. #include "./format_refiner.h"
  19. #include "./ge_context.h"
  20. #include "debug/ge_attr_define.h"
  21. #include "debug/ge_log.h"
  22. #include "debug/ge_op_types.h"
  23. #include "debug/ge_util.h"
  24. #include "framework/common/debug/ge_log.h"
  25. #include "common/util/error_manager/error_manager.h"
  26. #include "ge/ge_api_types.h"
  27. #include "graph/shape_refiner.h"
  28. #include "proto/ge_ir.pb.h"
  29. #include "utils/ge_ir_utils.h"
  30. #include "utils/graph_utils.h"
  31. #include "utils/node_utils.h"
  32. #include "utils/op_desc_utils.h"
  33. #include "utils/string_utils.h"
  34. #include "utils/tensor_utils.h"
  35. namespace ge {
  36. namespace {
  37. const size_t OUTPUT_PARAM_SIZE = 2;
  38. const std::string alias_name_attr = "_aliasName";
  39. bool IsUseBFS() {
  40. string run_mode;
  41. const int base = 10;
  42. if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) {
  43. if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) >= TRAIN) {
  44. return true;
  45. }
  46. } else {
  47. GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default.");
  48. }
  49. return false;
  50. }
  51. } // namespace
  52. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name)
  53. : name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) {
  54. attrs_.InitDefault();
  55. }
  56. ComputeGraph::~ComputeGraph() {}
  57. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string ComputeGraph::GetName() const { return name_; }
  58. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetName(const string &name) { name_ = name; }
  59. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesSize() const {
  60. return GetAllNodes().size();
  61. }
  62. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetAllNodes() const {
  63. std::vector<std::shared_ptr<ComputeGraph>> subgraphs;
  64. return AllGraphNodes(subgraphs);
  65. }
  66. ComputeGraph::Vistor<NodePtr> ComputeGraph::AllGraphNodes(std::vector<std::shared_ptr<ComputeGraph>> &subgraphs) const {
  67. std::vector<NodePtr> all_nodes;
  68. std::deque<NodePtr> candidates;
  69. candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end());
  70. while (!candidates.empty()) {
  71. NodePtr node = candidates.front();
  72. all_nodes.emplace_back(node);
  73. candidates.pop_front();
  74. OpDescPtr op_desc = node->GetOpDesc();
  75. if (op_desc == nullptr) {
  76. continue;
  77. }
  78. const auto &subgraph_names = op_desc->GetSubgraphInstanceNames();
  79. for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) {
  80. auto subgraph = GetSubgraph(*name_iter);
  81. if (subgraph != nullptr) {
  82. subgraphs.emplace_back(subgraph);
  83. candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end());
  84. }
  85. }
  86. }
  87. return Vistor<NodePtr>(shared_from_this(), all_nodes);
  88. }
  89. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  90. ComputeGraph::Vistor<NodePtr> ComputeGraph::GetNodes(bool is_unknown_shape) const {
  91. if (is_unknown_shape) {
  92. return GetDirectNode();
  93. } else {
  94. return GetAllNodes();
  95. }
  96. }
  97. size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); }
  98. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetDirectNode() const {
  99. return Vistor<NodePtr>(shared_from_this(), nodes_);
  100. }
  101. ComputeGraph::Vistor<NodePtr> ComputeGraph::GetInputNodes() const {
  102. return Vistor<NodePtr>(shared_from_this(), input_nodes_);
  103. }
  104. ComputeGraph::Vistor<NodePtr> ComputeGraph::GetOutputNodes() const {
  105. std::vector<NodePtr> result;
  106. for (auto iter = output_nodes_info_.begin(); iter != output_nodes_info_.end(); ++iter) {
  107. result.push_back(iter->first);
  108. }
  109. return Vistor<NodePtr>(shared_from_this(), result);
  110. }
  111. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::FindNode(const std::string &name) const {
  112. for (const auto &node : nodes_) {
  113. if (node == nullptr) {
  114. continue;
  115. }
  116. if (node->GetName() == name) {
  117. return node;
  118. }
  119. std::vector<string> out_alias_name;
  120. if (AttrUtils::GetListStr(node->GetOpDesc(), alias_name_attr, out_alias_name)) {
  121. for (const auto &alias_name : out_alias_name) {
  122. if (alias_name == name) {
  123. return node;
  124. }
  125. }
  126. }
  127. }
  128. return nullptr;
  129. }
  130. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  131. NodePtr ComputeGraph::FindFirstNodeMatchType(const std::string &name) const {
  132. for (const auto &node : nodes_) {
  133. if (node == nullptr) {
  134. continue;
  135. }
  136. if (node->GetType() == name) {
  137. return node;
  138. }
  139. }
  140. return nullptr;
  141. }
  142. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphAttrsAreEqual(
  143. const ComputeGraph &r_graph) const {
  144. // ProtoMsgOwner <::google::protobuf::Message> is temporarily ignored
  145. if ((this->attrs_.protoMsg_ != nullptr) && (r_graph.attrs_.protoMsg_ != nullptr)) {
  146. const auto &proto_attr_map = *(this->attrs_.protoMsg_);
  147. const auto &r_proto_attr_map = *(r_graph.attrs_.protoMsg_);
  148. // 1.Verify graph's ProtoAttrMap size
  149. if (proto_attr_map.size() != r_proto_attr_map.size()) {
  150. GELOGE(GRAPH_FAILED, "Size of compute graph's ProtoAttrMap verify failed, graph name: %s.",
  151. this->GetName().c_str());
  152. return false;
  153. }
  154. // 2.Verify graph's ProtoAttrMap key, verify values is temporarily not implemented
  155. for (const auto &it : proto_attr_map) {
  156. if (r_proto_attr_map.count(it.first) == 0) {
  157. GELOGE(GRAPH_FAILED, "Key of compute graph's ProtoAttrMap verify failed, graph name: %s key name: %s.",
  158. this->GetName().c_str(), it.first.c_str());
  159. return false;
  160. }
  161. }
  162. return true;
  163. }
  164. return ((this->attrs_.protoMsg_ == nullptr) && (r_graph.attrs_.protoMsg_ == nullptr));
  165. }
  166. /// Since there may be different input nodes
  167. /// chosen by user in the same graph, special judgment is needed
  168. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::VectorInputNodePtrIsEqual(
  169. const std::vector<NodePtr> &left_nodes, const std::vector<NodePtr> &right_nodes) const {
  170. const auto left_nodes_size = left_nodes.size();
  171. const auto right_nodes_size = right_nodes.size();
  172. if (left_nodes_size != right_nodes_size) {
  173. GELOGE(GRAPH_FAILED,
  174. "Check failed with graph input_nodes_: "
  175. "left inputNodes size %zu is different with right inputNodes size %zu .",
  176. left_nodes_size, right_nodes_size);
  177. return false;
  178. }
  179. for (size_t j = 0; j < left_nodes_size; j++) {
  180. if (left_nodes.at(j) == nullptr || right_nodes.at(j) == nullptr) {
  181. GELOGE(GRAPH_FAILED, "left_nodes.at(%zu) or right_nodes.at(%zu) is nullptr", j, j);
  182. return false;
  183. }
  184. const auto &left_input_name = left_nodes.at(j)->GetName();
  185. const auto &right_input_name = right_nodes.at(j)->GetName();
  186. if (left_input_name != right_input_name) {
  187. GELOGE(GRAPH_FAILED,
  188. "Check failed with graph input_nodes_: "
  189. "left inputNode name %s is different with right inputNode name %s at inputNodes index %zu.",
  190. left_input_name.c_str(), right_input_name.c_str(), j);
  191. return false;
  192. }
  193. }
  194. return true;
  195. }
  196. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual(
  197. const ComputeGraph &r_graph) const {
  198. return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.subgraphs_.size()") &&
  199. IsEqual(this->nodes_.size(), r_graph.nodes_.size(), "graph.nodes_.size()") &&
  200. VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) &&
  201. IsEqual(this->name_, r_graph.name_, "graph.name_") &&
  202. IsEqual(this->is_valid_flag_, r_graph.is_valid_flag_, "graph.is_valid_flag_") &&
  203. IsEqual(this->need_iteration_, r_graph.need_iteration_, "graph.need_iteration_") &&
  204. IsEqual(this->params_share_map_, r_graph.params_share_map_, "graph.params_share_map_") &&
  205. IsEqual(this->out_nodes_map_, r_graph.out_nodes_map_, "graph.out_nodes_map_") &&
  206. IsEqual(this->inputs_order_, r_graph.inputs_order_, "graph.inputs_order_") &&
  207. IsEqual(this->output_size_, r_graph.output_size_, "graph.output_size_") &&
  208. IsEqual(this->input_size_, r_graph.input_size_, "graph.input_size_") &&
  209. IsEqual(this->output_nodes_info_, r_graph.output_nodes_info_, "graph.output_nodes_info_"));
  210. }
  211. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::operator==(const ComputeGraph &r_graph) const {
  212. // Firstly: Graph's members equal
  213. if ((!GraphMembersAreEqual(r_graph)) || (!GraphAttrsAreEqual(r_graph))) {
  214. return false;
  215. }
  216. // Secondly: Node equal means the link relationship between node and node itself equal
  217. for (const auto &left_node : nodes_) {
  218. if (left_node == nullptr) {
  219. GELOGE(GRAPH_FAILED, "left_node is nullptr");
  220. return false;
  221. }
  222. const auto &node_name = left_node->GetName();
  223. // After TopologicalSorting, node order can change, so find node by name
  224. const auto &right_node = r_graph.FindNode(node_name);
  225. GE_IF_BOOL_EXEC(right_node == nullptr, GELOGE(GRAPH_FAILED, "right_node is NULL!!!"); return false);
  226. if (!(*right_node == *left_node)) {
  227. GELOGE(GRAPH_FAILED, "Compare graph failed, node name: %s.", node_name.c_str());
  228. return false;
  229. }
  230. }
  231. // Thirdly: Recursively determine whether the sub graphs are equal
  232. for (size_t i = 0; i < this->sub_graph_.size(); i++) {
  233. if (!(*((this->sub_graph_)[i]) == *((r_graph.sub_graph_)[i]))) {
  234. return false;
  235. }
  236. }
  237. return true;
  238. }
  239. NodePtr ComputeGraph::AddNodeFront(NodePtr node) {
  240. if (node == nullptr || node->GetOpDesc() == nullptr) {
  241. GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null.");
  242. return nullptr;
  243. }
  244. node->SetHostNode(is_valid_flag_);
  245. node->GetOpDesc()->SetId(nodes_.size());
  246. if (nodes_.size() > 0 && nodes_[0]->GetType() == DATA) {
  247. (void)nodes_.insert(nodes_.begin() + 1, node);
  248. } else {
  249. (void)nodes_.insert(nodes_.begin(), node);
  250. }
  251. return node;
  252. }
  253. NodePtr ComputeGraph::AddNodeFront(const OpDescPtr &op) {
  254. if (op == nullptr) {
  255. GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null.");
  256. return nullptr;
  257. }
  258. op->SetId(nodes_.size());
  259. NodePtr node_ptr = shared_ptr<Node>(new (std::nothrow) Node(op, shared_from_this()));
  260. GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr);
  261. GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr);
  262. return AddNodeFront(node_ptr);
  263. }
  264. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(NodePtr node) {
  265. if (node == nullptr || node->GetOpDesc() == nullptr) {
  266. GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
  267. return nullptr;
  268. }
  269. node->SetHostNode(is_valid_flag_);
  270. node->GetOpDesc()->SetId((int64_t)GetDirectNodesSize());
  271. nodes_.push_back(node);
  272. return node;
  273. }
  274. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpDescPtr op) {
  275. if (op == nullptr) {
  276. GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null.");
  277. return nullptr;
  278. }
  279. op->SetId(GetDirectNodesSize());
  280. NodePtr node_ptr = shared_ptr<Node>(new (std::nothrow) Node(op, shared_from_this()));
  281. GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr);
  282. GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr);
  283. return AddNode(node_ptr);
  284. }
  285. NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize.
  286. if (op == nullptr) {
  287. GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null.");
  288. return nullptr;
  289. }
  290. op->SetId(id);
  291. NodePtr node = shared_ptr<Node>(new (std::nothrow) Node(op, shared_from_this()));
  292. GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr);
  293. GE_IF_BOOL_EXEC(node->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr);
  294. node->SetHostNode(is_valid_flag_);
  295. nodes_.push_back(node);
  296. return node;
  297. }
  298. NodePtr ComputeGraph::AddInputNode(NodePtr node) {
  299. if (node == nullptr) {
  300. GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
  301. return nullptr;
  302. }
  303. input_nodes_.push_back(node);
  304. if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) {
  305. GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "add node failed");
  306. }
  307. return node;
  308. }
  309. NodePtr ComputeGraph::AddOutputNode(NodePtr node) {
  310. return AddOutputNodeByIndex(node, 0);
  311. }
  312. NodePtr ComputeGraph::AddOutputNodeByIndex(NodePtr node, int32_t index) {
  313. if (node == nullptr || node->GetOpDesc() == nullptr) {
  314. GELOGE(GRAPH_FAILED, "The node ptr or opdesc should not be null.");
  315. return nullptr;
  316. }
  317. bool already_have = false;
  318. NodePtr result = node;
  319. // [output_nodes_info_ : should not be null]
  320. for (const auto &item : output_nodes_info_) {
  321. if (item.first->GetName() == node->GetName() && item.second == index) {
  322. already_have = true;
  323. result = item.first;
  324. break;
  325. }
  326. }
  327. if (!already_have) {
  328. output_nodes_info_.emplace_back(std::make_pair(node, index));
  329. GELOGI("Push back node name:%s, index:%ld, into output_nodes_info_.", node->GetName().c_str(), index);
  330. }
  331. if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) {
  332. GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "add node failed");
  333. }
  334. return result;
  335. }
  336. graphStatus ComputeGraph::RemoveConstInput(const NodePtr &node) {
  337. GE_CHECK_NOTNULL(node);
  338. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  339. auto out_anchor = in_anchor->GetPeerOutAnchor();
  340. if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) {
  341. continue;
  342. }
  343. if (out_anchor->GetOwnerNode()->GetType() == CONSTANT || out_anchor->GetOwnerNode()->GetType() == CONSTANTOP) {
  344. GE_CHK_BOOL_RET_STATUS(GraphUtils::RemoveEdge(out_anchor, in_anchor) == GRAPH_SUCCESS, GRAPH_FAILED,
  345. "Remove edge from const op failed.");
  346. if (out_anchor->GetOwnerNode()->GetOutNodes().size() == 0) {
  347. GELOGI("Remove const op %s.", out_anchor->GetOwnerNode()->GetName().c_str());
  348. auto iter = find(nodes_.begin(), nodes_.end(), out_anchor->GetOwnerNode());
  349. if (iter != nodes_.end()) {
  350. (void)nodes_.erase(iter);
  351. }
  352. }
  353. }
  354. }
  355. return GRAPH_SUCCESS;
  356. }
  357. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveNode(const NodePtr &node) {
  358. if (node == nullptr) {
  359. GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
  360. return GRAPH_FAILED;
  361. }
  362. // delete const op for this node
  363. (void)RemoveConstInput(node);
  364. // if the node save as input node, delete it
  365. (void)RemoveInputNode(node);
  366. // if the node save as input node, delete it
  367. (void)RemoveOutputNode(node);
  368. if (GRAPH_SUCCESS != IsolateNode(node)) {
  369. GELOGE(GRAPH_FAILED, "Isolate node failed, node name: %s.", node->GetName().c_str());
  370. return GRAPH_FAILED;
  371. }
  372. auto iter = find(nodes_.begin(), nodes_.end(), node);
  373. if (iter != nodes_.end()) {
  374. (void)nodes_.erase(iter);
  375. return GRAPH_SUCCESS;
  376. }
  377. return GRAPH_FAILED;
  378. }
  379. // Used in sub_graph scenes
  380. graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) {
  381. if (node == nullptr) {
  382. GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
  383. return GRAPH_FAILED;
  384. }
  385. auto iter = find(input_nodes_.begin(), input_nodes_.end(), node);
  386. if (iter != input_nodes_.end()) {
  387. (void)input_nodes_.erase(iter);
  388. return GRAPH_SUCCESS;
  389. }
  390. return GRAPH_FAILED;
  391. }
  392. // Used in sub_graph scenes
  393. graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) {
  394. if (node == nullptr) {
  395. GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
  396. return GRAPH_FAILED;
  397. }
  398. auto iter = output_nodes_info_.begin();
  399. bool find_node = false;
  400. // [output_nodes_info_ : should not be null]
  401. while (iter != output_nodes_info_.end()) {
  402. if (node->GetName() == iter->first->GetName()) {
  403. iter = output_nodes_info_.erase(iter);
  404. find_node = true;
  405. } else {
  406. ++iter;
  407. }
  408. }
  409. GE_IF_BOOL_EXEC(find_node == false, return GRAPH_FAILED);
  410. return GRAPH_SUCCESS;
  411. }
  412. std::shared_ptr<ComputeGraph> ComputeGraph::AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph) {
  413. if (sub_graph == nullptr) {
  414. GELOGE(GRAPH_FAILED, "The graph ptr should not be null.");
  415. return nullptr;
  416. }
  417. sub_graph_.push_back(sub_graph);
  418. names_to_subgraph_[sub_graph->GetName()] = sub_graph;
  419. return sub_graph;
  420. }
  421. graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph) {
  422. if (sub_graph == nullptr) {
  423. GELOGE(GRAPH_FAILED, "The graph ptr should not be null.");
  424. return GRAPH_FAILED;
  425. }
  426. names_to_subgraph_.erase(sub_graph->GetName());
  427. auto iter = find(sub_graph_.begin(), sub_graph_.end(), sub_graph);
  428. if (iter != sub_graph_.end()) {
  429. (void)sub_graph_.erase(iter);
  430. return GRAPH_SUCCESS;
  431. } else {
  432. GELOGW("find sub_graph failed");
  433. return GRAPH_SUCCESS;
  434. }
  435. }
  436. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  437. ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph) {
  438. if (subgraph == nullptr) {
  439. GE_LOGE("Try to add a null subgraph, name %s", name.c_str());
  440. return GRAPH_PARAM_INVALID;
  441. }
  442. auto parent_graph = subgraph->GetParentGraph();
  443. if (parent_graph == nullptr) {
  444. GE_LOGE("Try to add subgraph without parent graph, name %s", name.c_str());
  445. return GRAPH_PARAM_INVALID;
  446. }
  447. auto parent_node = subgraph->GetParentNode();
  448. if (parent_node == nullptr) {
  449. GE_LOGE("Try to add a subgraph without parent node, name %s", name.c_str());
  450. return GRAPH_PARAM_INVALID;
  451. }
  452. if (parent_node->GetOwnerComputeGraph() != parent_graph) {
  453. GE_LOGE(
  454. "Try to add a subgraph which parent node's parent graph is not equal to "
  455. "the subgraph's parent graph, subgraph name %s, parent node name %s",
  456. subgraph->GetName().c_str(), parent_graph->GetName().c_str());
  457. return GRAPH_PARAM_INVALID;
  458. }
  459. if (!this->parent_graph_.expired()) {
  460. GELOGW("The subgraphs should only be added to the root graph");
  461. }
  462. if (name != subgraph->GetName()) {
  463. GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str());
  464. }
  465. if (names_to_subgraph_.find(name) != names_to_subgraph_.end()) {
  466. GE_LOGE("The subgraph %s existed", name.c_str());
  467. return GRAPH_PARAM_INVALID;
  468. }
  469. sub_graph_.push_back(subgraph);
  470. names_to_subgraph_[name] = subgraph;
  471. return GRAPH_SUCCESS;
  472. }
  473. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  474. ComputeGraph::AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph) {
  475. if (subgraph == nullptr) {
  476. return GRAPH_PARAM_INVALID;
  477. }
  478. return AddSubgraph(subgraph->GetName(), subgraph);
  479. }
  480. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(const std::string &name) {
  481. auto iter = names_to_subgraph_.find(name);
  482. if (iter == names_to_subgraph_.end()) {
  483. return;
  484. }
  485. for (auto vec_iter = sub_graph_.begin(); vec_iter != sub_graph_.end(); ++vec_iter) {
  486. if (*vec_iter == iter->second) {
  487. sub_graph_.erase(vec_iter);
  488. break;
  489. }
  490. }
  491. names_to_subgraph_.erase(iter);
  492. }
  493. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(
  494. const std::shared_ptr<ComputeGraph> &subgraph) {
  495. if (subgraph != nullptr) {
  496. RemoveSubgraph(subgraph->GetName());
  497. }
  498. }
  499. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr<ComputeGraph> ComputeGraph::GetSubgraph(
  500. const std::string &name) const {
  501. std::shared_ptr<ComputeGraph> parent = parent_graph_.lock();
  502. if (parent == nullptr) {
  503. auto iter = names_to_subgraph_.find(name);
  504. return iter == names_to_subgraph_.end() ? nullptr : iter->second;
  505. } else {
  506. return parent->GetSubgraph(name);
  507. }
  508. }
  509. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector<std::shared_ptr<ComputeGraph>>
  510. ComputeGraph::GetAllSubgraphs() const {
  511. return sub_graph_;
  512. }
  513. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr<ComputeGraph> ComputeGraph::GetParentGraph() {
  514. return parent_graph_.lock();
  515. }
  516. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentGraph(
  517. const shared_ptr<ComputeGraph> &parent) {
  518. parent_graph_ = parent;
  519. }
  520. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr<Node> ComputeGraph::GetParentNode() {
  521. return parent_node_.lock();
  522. }
  523. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode(const shared_ptr<Node> &parent) {
  524. parent_node_ = parent;
  525. }
  526. ///
  527. /// @brief Update input-mapping
  528. /// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input
  529. /// @return graphStatus
  530. ///
  531. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  532. ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping) {
  533. for (auto &input : nodes_) {
  534. if (input->GetType() == DATA) {
  535. uint32_t cur_index = 0;
  536. if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) {
  537. continue;
  538. }
  539. auto iter = input_mapping.find(cur_index);
  540. if (iter == input_mapping.end()) {
  541. continue;
  542. }
  543. if (!ge::AttrUtils::SetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, iter->second)) {
  544. GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed.");
  545. return GRAPH_FAILED;
  546. }
  547. }
  548. }
  549. return GRAPH_SUCCESS;
  550. }
  551. ///
  552. /// @brief Update output-mapping
  553. /// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output
  554. /// @return graphStatus
  555. ///
  556. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  557. ComputeGraph::UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping) {
  558. NodePtr net_output = FindFirstNodeMatchType(NETOUTPUT);
  559. if (net_output == nullptr) {
  560. GE_LOGE("UpdateOutputMapping failed: node type %s not exist in graph.", NETOUTPUT);
  561. return GRAPH_FAILED;
  562. }
  563. OpDescPtr op_desc = net_output->GetOpDesc();
  564. if (op_desc == nullptr) {
  565. GE_LOGE("UpdateOutputMapping failed: op_desc is NULL.");
  566. return GRAPH_FAILED;
  567. }
  568. size_t num = op_desc->GetAllInputsSize();
  569. for (size_t i = 0; i < num; i++) {
  570. GeTensorDesc tensor = op_desc->GetInputDesc(i);
  571. uint32_t cur_index = 0;
  572. if (!ge::AttrUtils::GetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, cur_index)) {
  573. continue;
  574. }
  575. auto iter = output_mapping.find(cur_index);
  576. if (iter == output_mapping.end()) {
  577. continue;
  578. }
  579. if (!ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) {
  580. GE_LOGE("UpdateOutputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed.");
  581. return GRAPH_FAILED;
  582. }
  583. if (op_desc->UpdateInputDesc(i, tensor) != GRAPH_SUCCESS) {
  584. GE_LOGE("UpdateOutputMapping failed: update %u input_tensor failed.", i);
  585. return GRAPH_FAILED;
  586. }
  587. }
  588. return GRAPH_SUCCESS;
  589. }
  590. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() {
  591. std::vector<NodePtr> node_vec = nodes_;
  592. for (const auto &node : GetDirectNode()) {
  593. if (node == nullptr || node->GetOpDesc() == nullptr) {
  594. GELOGW("node or OpDescPtr is nullptr.");
  595. continue;
  596. }
  597. GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should not be null."); return GRAPH_FAILED);
  598. if (node->GetOpDesc()->GetType() == RECV) {
  599. auto iter = find(node_vec.begin(), node_vec.end(), node);
  600. if (iter == node_vec.end()) {
  601. GELOGW("no node found.");
  602. } else {
  603. (void)node_vec.erase(iter);
  604. }
  605. auto dst_iter = find(node_vec.begin(), node_vec.end(), node->GetOutControlNodes().at(0));
  606. (void)node_vec.insert(dst_iter, node);
  607. }
  608. if (node->GetOpDesc()->GetType() == SEND) {
  609. auto iter = find(node_vec.begin(), node_vec.end(), node);
  610. if (iter == node_vec.end()) {
  611. GELOGW("no node found.");
  612. } else {
  613. (void)node_vec.erase(iter);
  614. }
  615. auto src_iter = find(node_vec.begin(), node_vec.end(), node->GetInControlNodes().at(0));
  616. (void)node_vec.insert(src_iter + 1, node);
  617. }
  618. }
  619. nodes_.clear();
  620. for (size_t i = 0; i < node_vec.size(); ++i) {
  621. NodePtr node = node_vec[i];
  622. if (node == nullptr || node->GetOpDesc() == nullptr) {
  623. GELOGW("node or OpDescPtr is nullptr.");
  624. } else {
  625. node->GetOpDesc()->SetId((int64_t)i);
  626. nodes_.push_back(node);
  627. }
  628. }
  629. return GRAPH_SUCCESS;
  630. }
  631. graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec,
  632. std::map<NodePtr, uint32_t> &map_in_edge_num,
  633. std::vector<NodePtr> &stack, bool reverse) {
  634. GELOGD("Runing_Dfs_Sort: %s", name_.c_str());
  635. // Record the number of non data nodes but no input nodes
  636. GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed");
  637. std::vector<NodePtr> out_nodes;
  638. auto stack_push = [&reverse, &stack](std::vector<NodePtr>& out_nodes) {
  639. if (reverse) {
  640. std::reverse(out_nodes.begin(), out_nodes.end());
  641. }
  642. stack.insert(stack.end(), out_nodes.begin(), out_nodes.end());
  643. out_nodes.clear();
  644. };
  645. // Only data nodes here
  646. while (!stack.empty()) {
  647. NodePtr node = stack.back();
  648. stack.pop_back();
  649. node_vec.push_back(node);
  650. GE_CHECK_NOTNULL(node->GetOpDesc());
  651. GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str());
  652. for (const auto &anchor : node->GetAllOutDataAnchors()) {
  653. GE_CHECK_NOTNULL(anchor);
  654. for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) {
  655. GE_CHECK_NOTNULL(peer_in_anchor);
  656. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  657. if (iter != map_in_edge_num.end() && --iter->second == 0) {
  658. out_nodes.push_back(peer_in_anchor->GetOwnerNode());
  659. }
  660. }
  661. stack_push(out_nodes);
  662. for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) {
  663. GE_CHECK_NOTNULL(peer_in_anchor);
  664. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  665. if (iter != map_in_edge_num.end() && --iter->second == 0) {
  666. out_nodes.push_back(peer_in_anchor->GetOwnerNode());
  667. }
  668. }
  669. stack_push(out_nodes);
  670. }
  671. GE_IF_BOOL_EXEC(
  672. node->GetOutControlAnchor() != nullptr, for (AnchorPtr peer_in_anchor
  673. : node->GetOutControlAnchor()->GetPeerAnchors()) {
  674. GE_CHECK_NOTNULL(peer_in_anchor);
  675. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  676. if (iter != map_in_edge_num.end() && --iter->second == 0) {
  677. out_nodes.push_back(peer_in_anchor->GetOwnerNode());
  678. }
  679. }
  680. stack_push(out_nodes);)
  681. }
  682. return GRAPH_SUCCESS;
  683. }
  684. graphStatus ComputeGraph::BFSTopologicalSorting(std::vector<NodePtr> &node_vec,
  685. std::map<NodePtr, uint32_t> &map_in_edge_num,
  686. std::deque<NodePtr> &stack) {
  687. GELOGI("Runing_Bfs_Sort: %s", name_.c_str());
  688. std::vector<NodePtr> stack_input;
  689. std::map<string, NodePtr> breadth_node_map;
  690. // Record the number of non data nodes but no input nodes
  691. GE_CHK_BOOL_EXEC(SortNodes(stack_input, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed");
  692. // Only data nodes here
  693. while (!stack_input.empty() || !stack.empty()) {
  694. NodePtr node = nullptr;
  695. if (!stack.empty()) {
  696. node = stack.back();
  697. stack.pop_back();
  698. } else {
  699. node = stack_input.back();
  700. stack_input.pop_back();
  701. }
  702. node_vec.push_back(node);
  703. GE_CHECK_NOTNULL(node->GetOpDesc());
  704. GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str());
  705. CollectBreadthOutNode(node, map_in_edge_num, breadth_node_map);
  706. for (const auto &name_node : breadth_node_map) {
  707. (void)stack.push_front(name_node.second);
  708. }
  709. breadth_node_map.clear();
  710. }
  711. return GRAPH_SUCCESS;
  712. }
  713. graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num,
  714. std::map<string, NodePtr> &breadth_node_map) {
  715. for (const auto &anchor : node->GetAllOutDataAnchors()) {
  716. for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) {
  717. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  718. if (iter != map_in_edge_num.end() && --iter->second == 0) {
  719. (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode());
  720. }
  721. }
  722. for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) {
  723. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  724. if (iter != map_in_edge_num.end() && --iter->second == 0) {
  725. (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode());
  726. }
  727. }
  728. }
  729. if (node->GetOutControlAnchor() != nullptr) {
  730. for (AnchorPtr peer_in_anchor : node->GetOutControlAnchor()->GetPeerAnchors()) {
  731. auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode());
  732. if (iter != map_in_edge_num.end() && --iter->second == 0) {
  733. (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode());
  734. }
  735. }
  736. }
  737. return GRAPH_SUCCESS;
  738. }
  739. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() {
  740. auto ret = TopologicalSortingGraph();
  741. if (ret != SUCCESS) {
  742. GraphUtils::DumpGEGraphToOnnx(*this, "black_box");
  743. GELOGE(ret, "Graph [%s] topological sort failed, saved to file black_box", name_.c_str());
  744. return ret;
  745. }
  746. if (sub_graph_.empty()) {
  747. return SUCCESS;
  748. }
  749. // partition sub graph
  750. for (const auto &sub_graph : sub_graph_) {
  751. ret = sub_graph->TopologicalSortingGraph();
  752. if (ret != SUCCESS) {
  753. GELOGE(ret, "Sub graph topological sort Failed");
  754. return ret;
  755. }
  756. }
  757. std::vector<std::shared_ptr<ComputeGraph>> subgraphs;
  758. auto nodes = AllGraphNodes(subgraphs);
  759. for (size_t i = 0; i < nodes.size(); i++) {
  760. NodePtr node = nodes.at(i); // [node: should not be null]
  761. node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null]
  762. }
  763. if (sub_graph_.size() != subgraphs.size()) { // Graph Partition use subgraph, Keep original
  764. GELOGW("Keep original subgraph for graph size %zu not equal %zu.", sub_graph_.size(), subgraphs.size());
  765. return SUCCESS;
  766. }
  767. sub_graph_.swap(subgraphs);
  768. return SUCCESS;
  769. }
  770. graphStatus ComputeGraph::TopologicalSortingGraph(bool dfs_reverse) {
  771. std::vector<NodePtr> node_vec;
  772. std::map<NodePtr, uint32_t> map_in_edge_num;
  773. bool use_BFS = IsUseBFS();
  774. if (use_BFS) {
  775. std::deque<NodePtr> stack;
  776. if (BFSTopologicalSorting(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) {
  777. return GRAPH_FAILED;
  778. }
  779. } else {
  780. std::vector<NodePtr> stack;
  781. if (DFSTopologicalSorting(node_vec, map_in_edge_num, stack, dfs_reverse) != GRAPH_SUCCESS) {
  782. return GRAPH_FAILED;
  783. }
  784. }
  785. // If they are not equal, there is a closed loop
  786. if (node_vec.size() != nodes_.size()) {
  787. std::set<Node *> itered_nodes_set;
  788. for (auto &node : node_vec) {
  789. itered_nodes_set.insert(node.get());
  790. }
  791. ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"},
  792. {"TopologicalSortingGraph", "exist closed loop in graph"});
  793. GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", nodes_.size(),
  794. node_vec.size());
  795. for (auto &node : nodes_) {
  796. if (itered_nodes_set.count(node.get()) == 0) {
  797. ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"},
  798. {"TopologicalSortingGraph", "op[" + node->GetName() + "] does not itered when topological sorting"});
  799. GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str());
  800. }
  801. }
  802. return GRAPH_FAILED;
  803. }
  804. nodes_.clear();
  805. for (size_t i = 0; i < node_vec.size(); i++) {
  806. NodePtr node = node_vec[i]; // [node: should not be null]
  807. node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null]
  808. nodes_.push_back(node);
  809. }
  810. is_valid_flag_ = true;
  811. return GRAPH_SUCCESS;
  812. }
  813. graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &map_in_edge_num) {
  814. // Record the number of non data nodes but no input nodes
  815. uint32_t spec_node_size = 0;
  816. bool verify_isolated = false;
  817. string run_mode;
  818. const int base = 10;
  819. // Need verify isolated point in PREDICTION mode.
  820. if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) {
  821. if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) < TRAIN) {
  822. verify_isolated = true;
  823. }
  824. }
  825. for (const auto &node : GetDirectNode()) {
  826. GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
  827. map_in_edge_num[node] = static_cast<uint32_t>(GetInEdgeSize(node));
  828. if (map_in_edge_num[node] == 0) {
  829. if ((node->GetOpDesc()->GetType() != DATA) && (node->GetOpDesc()->GetType() != AIPPDATA) &&
  830. (node->GetOpDesc()->GetType() != INPUT_TYPE) && (node->GetOpDesc()->GetType() != ANN_DATA)) {
  831. // At present, can only judge the isolated point without input and output.
  832. // It is impossible to judge the situation with multiple output nodes.
  833. if (verify_isolated && GetOutEdgeSize(node) == 0) {
  834. ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"},
  835. {"SortNodes", "may has isolated node[" + node->GetName() + "] in graph"});
  836. GELOGE(GRAPH_FAILED, "May has isolated node in graph, node name: %s.", node->GetName().c_str());
  837. return GRAPH_FAILED;
  838. }
  839. (void)stack.insert(stack.begin(), node);
  840. spec_node_size++;
  841. continue;
  842. }
  843. // Need to insert the data nodes in reverse order
  844. (void)stack.insert(stack.begin() + spec_node_size, node);
  845. }
  846. }
  847. /// Make sure the inputs order matches with user-designated
  848. /// 1. Get the index of two input nodes in the user-inputs-order(inputs_order_)
  849. /// 2. Compare two indices, if not match, swap the positions of two inputs
  850. /// *: Remind: stack is reverse-order
  851. for (size_t i = 0; i < stack.size(); ++i) {
  852. // If not found in 'inputs_order_', skip it
  853. auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName());
  854. GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue);
  855. auto inx_i = it_i - inputs_order_.begin();
  856. for (size_t j = i + 1; j < stack.size(); ++j) {
  857. // If not found in 'inputs_order_', skip it
  858. auto it_j = std::find(inputs_order_.begin(), inputs_order_.end(), stack[j]->GetName());
  859. GE_IF_BOOL_EXEC(it_j == inputs_order_.end(), continue);
  860. // Compare index, swap them if it should be
  861. auto inx_j = it_j - inputs_order_.begin();
  862. GE_IF_BOOL_EXEC(inx_i < inx_j, std::swap(stack[i], stack[j]));
  863. }
  864. }
  865. return GRAPH_SUCCESS;
  866. }
  867. size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) {
  868. size_t in_edge_size = 0;
  869. if (node == nullptr) {
  870. return in_edge_size;
  871. }
  872. for (const auto &anchor : node->GetAllInDataAnchors()) {
  873. in_edge_size = in_edge_size + anchor->GetPeerAnchorsSize();
  874. // Break flow control data loop.
  875. OutDataAnchorPtr out_anchor = anchor->GetPeerOutAnchor();
  876. if ((out_anchor != nullptr) && (out_anchor->GetOwnerNode() != nullptr)) {
  877. NodePtr out_node = out_anchor->GetOwnerNode();
  878. if (out_node == nullptr) {
  879. GELOGW("out node is nullptr");
  880. continue;
  881. }
  882. if ((out_node->GetType() == NEXTITERATION) || (out_node->GetType() == REFNEXTITERATION)) {
  883. GE_IF_BOOL_EXEC(in_edge_size == 0, GELOGE(GRAPH_FAILED, "If [in_edge_size = 0], the result will be reversed");
  884. return in_edge_size);
  885. in_edge_size -= 1;
  886. }
  887. }
  888. }
  889. if (node->GetInControlAnchor() != nullptr) {
  890. in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchorsSize();
  891. }
  892. return in_edge_size;
  893. }
  894. size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) {
  895. size_t out_edge_size = 0;
  896. if (node == nullptr) {
  897. return out_edge_size;
  898. }
  899. // Break flow control data loop.
  900. if ((node->GetType() != NEXTITERATION) && (node->GetType() != REFNEXTITERATION)) {
  901. for (const auto &anchor : node->GetAllOutDataAnchors()) {
  902. if (anchor != nullptr) {
  903. out_edge_size = out_edge_size + anchor->GetPeerAnchors().size();
  904. }
  905. }
  906. }
  907. if (node->GetOutControlAnchor() != nullptr) {
  908. if (out_edge_size > (UINT64_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) {
  909. return 0;
  910. }
  911. out_edge_size = out_edge_size + node->GetOutControlAnchor()->GetPeerAnchors().size();
  912. }
  913. return out_edge_size;
  914. }
  915. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::IsValid() const { return is_valid_flag_; }
  916. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const {
  917. GELOGI("graph name = %s.", GetName().c_str());
  918. for (const auto &node : GetAllNodes()) {
  919. GELOGD("node name = %s.", node->GetName().c_str());
  920. for (const auto &anchor : node->GetAllOutDataAnchors()) {
  921. for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) {
  922. GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr,
  923. GELOGI("node name = %s, out data node name = %s.", node->GetName().c_str(),
  924. peer_in_anchor->GetOwnerNode()->GetName().c_str()));
  925. }
  926. for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) {
  927. GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr,
  928. GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(),
  929. peer_in_anchor->GetOwnerNode()->GetName().c_str()));
  930. }
  931. }
  932. auto out_control_anchor = node->GetOutControlAnchor();
  933. if (out_control_anchor != nullptr) {
  934. for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) {
  935. GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr,
  936. GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(),
  937. peer_in_anchor->GetOwnerNode()->GetName().c_str()));
  938. }
  939. for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) {
  940. GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr,
  941. GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(),
  942. peer_in_anchor->GetOwnerNode()->GetName().c_str()));
  943. }
  944. }
  945. }
  946. }
  947. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Swap(ComputeGraph &graph) {
  948. this->AttrHolder::Swap(graph);
  949. origGraph_.swap(graph.origGraph_);
  950. name_.swap(graph.name_);
  951. std::swap(graph_id_, graph.graph_id_);
  952. attrs_.Swap(graph.attrs_);
  953. nodes_.swap(graph.nodes_);
  954. all_nodes_infos_.swap(graph.all_nodes_infos_);
  955. target_nodes_info_.swap(graph.target_nodes_info_);
  956. input_nodes_.swap(graph.input_nodes_);
  957. inputs_order_.swap(graph.inputs_order_);
  958. std::swap(input_size_, graph.input_size_);
  959. out_nodes_map_.swap(graph.out_nodes_map_);
  960. std::swap(output_size_, graph.output_size_);
  961. output_nodes_info_.swap(graph.output_nodes_info_);
  962. sub_graph_.swap(graph.sub_graph_);
  963. names_to_subgraph_.swap(graph.names_to_subgraph_);
  964. parent_graph_.swap(graph.parent_graph_);
  965. parent_node_.swap(graph.parent_node_);
  966. // the members followed should not in the ComputeGraph class
  967. std::swap(is_valid_flag_, graph.is_valid_flag_);
  968. std::swap(is_summary_graph_, graph.is_summary_graph_);
  969. std::swap(need_iteration_, graph.need_iteration_);
  970. params_share_map_.swap(graph.params_share_map_);
  971. op_name_map_.swap(graph.op_name_map_);
  972. std::swap(session_id_, graph.session_id_);
  973. std::swap(data_format_, graph.data_format_);
  974. std::swap(is_unknown_shape_graph_, graph.is_unknown_shape_graph_);
  975. // Update Node owner.
  976. SetNodesOwner();
  977. graph.SetNodesOwner();
  978. }
  979. void ComputeGraph::SetNodesOwner() {
  980. for (const auto &node : nodes_) {
  981. if (node == nullptr) {
  982. continue;
  983. }
  984. node->SetOwnerComputeGraph(shared_from_this());
  985. }
  986. }
  987. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::IsolateNode(const NodePtr &node) {
  988. GE_CHECK_NOTNULL(node);
  989. auto next_nodes = node->GetOutAllNodes();
  990. // If there is input data side
  991. for (size_t i = 0; i < node->GetAllInDataAnchors().size(); i++) {
  992. auto in_data_anchor = node->GetInDataAnchor(static_cast<int>(i));
  993. auto pre_out_data_anchor = in_data_anchor->GetPeerOutAnchor();
  994. if (pre_out_data_anchor != nullptr) {
  995. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_data_anchor, in_data_anchor) == GRAPH_SUCCESS,
  996. return GRAPH_FAILED, "remove edge failed");
  997. GE_IF_BOOL_EXEC(pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANT ||
  998. pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANTOP,
  999. continue);
  1000. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  1001. for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  1002. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS,
  1003. return GRAPH_FAILED, "remove edge failed");
  1004. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS,
  1005. return GRAPH_FAILED, "add edge failed");
  1006. }
  1007. for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) {
  1008. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1009. return GRAPH_FAILED, "remove edge failed");
  1010. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1011. return GRAPH_FAILED, "add edge failed");
  1012. }
  1013. }
  1014. auto out_ctrl_anchor = node->GetOutControlAnchor();
  1015. GE_CHECK_NOTNULL(out_ctrl_anchor);
  1016. auto pre_out_ctrl_anchor = pre_out_data_anchor->GetOwnerNode()->GetOutControlAnchor();
  1017. GE_CHECK_NOTNULL(pre_out_ctrl_anchor);
  1018. for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
  1019. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1020. return GRAPH_FAILED, "remove edge failed");
  1021. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1022. return GRAPH_FAILED, "add edge failed");
  1023. }
  1024. }
  1025. }
  1026. // If there is an input control side
  1027. auto in_ctrl_anchor = node->GetInControlAnchor();
  1028. GE_CHECK_NOTNULL(in_ctrl_anchor);
  1029. for (const auto &pre_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) {
  1030. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_ctrl_anchor, in_ctrl_anchor) == GRAPH_SUCCESS, return GRAPH_FAILED,
  1031. "remove edge failed");
  1032. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  1033. for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) {
  1034. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1035. return GRAPH_FAILED, "remove edge failed");
  1036. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1037. return GRAPH_FAILED, "add edge failed");
  1038. }
  1039. }
  1040. auto out_ctrl_anchor = node->GetOutControlAnchor();
  1041. if (out_ctrl_anchor != nullptr) {
  1042. for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
  1043. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1044. return GRAPH_FAILED, "remove edge failed");
  1045. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1046. return GRAPH_FAILED, "add edge failed");
  1047. }
  1048. }
  1049. }
  1050. for (const auto &out_peer_data_anchor : in_ctrl_anchor->GetPeerOutDataAnchors()) {
  1051. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_peer_data_anchor, in_ctrl_anchor) == GRAPH_SUCCESS, return GRAPH_FAILED,
  1052. "remove edge failed");
  1053. for (const auto &next_node : next_nodes) {
  1054. auto next_in_control_anchor = next_node->GetInControlAnchor();
  1055. GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(out_peer_data_anchor, next_in_control_anchor) == GRAPH_SUCCESS,
  1056. return GRAPH_FAILED, "add edge failed");
  1057. }
  1058. }
  1059. return RemoveExtraOutEdge(node);
  1060. }
  1061. graphStatus ComputeGraph::RemoveExtraOutEdge(const NodePtr &node) {
  1062. GE_CHECK_NOTNULL(node);
  1063. // Remove redundant output edges
  1064. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  1065. for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  1066. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS,
  1067. return GRAPH_FAILED, "remove edge failed");
  1068. }
  1069. for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) {
  1070. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1071. return GRAPH_FAILED, "remove edge failed");
  1072. }
  1073. }
  1074. auto out_ctrl_anchor = node->GetOutControlAnchor();
  1075. if (out_ctrl_anchor != nullptr) {
  1076. for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
  1077. GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS,
  1078. return GRAPH_FAILED, "remove edge failed");
  1079. }
  1080. }
  1081. return GRAPH_SUCCESS;
  1082. }
  1083. graphStatus ComputeGraph::Verify() {
  1084. bool is_unknown_graph = GetGraphUnknownFlag();
  1085. for (const auto &node_ptr : GetAllNodes()) {
  1086. GE_CHECK_NOTNULL(node_ptr);
  1087. GE_CHECK_NOTNULL(node_ptr->GetOpDesc());
  1088. GE_IF_BOOL_EXEC(is_unknown_graph, continue);
  1089. GE_CHK_BOOL_EXEC(node_ptr->GetOpDesc()->CommonVerify() == GRAPH_SUCCESS, return GRAPH_FAILED,
  1090. "Verifying %s failed.", node_ptr->GetName().c_str());
  1091. }
  1092. return GRAPH_SUCCESS;
  1093. }
  1094. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferOriginFormat() {
  1095. return ge::FormatRefiner::InferOrigineFormat(shared_from_this());
  1096. }
  1097. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferShapeInNeed() {
  1098. GE_CHK_BOOL_ONLY_LOG(TopologicalSorting() == GRAPH_SUCCESS, "Verifying failed.");
  1099. for (const auto &node_ptr : GetAllNodes()) {
  1100. GE_CHECK_NOTNULL(node_ptr);
  1101. auto op_desc = node_ptr->GetOpDesc();
  1102. bool is_need_infer = false;
  1103. (void)ge::AttrUtils::GetBool(op_desc, NEED_INFER, is_need_infer);
  1104. if (is_need_infer) {
  1105. GE_CHK_BOOL_EXEC(node_ptr->Verify() == GRAPH_SUCCESS, return GRAPH_FAILED, "Verifying %s failed.",
  1106. node_ptr->GetName().c_str());
  1107. graphStatus status = node_ptr->InferShapeAndType();
  1108. GE_CHK_BOOL_EXEC_INFO(node_ptr->GetType() == DATA || GRAPH_PARAM_INVALID != status, break,
  1109. "Op %s does not have the IMPLEMT_INFERFUNC definition,"
  1110. " and subsequent operators no longer perform shape inference.",
  1111. node_ptr->GetName().c_str());
  1112. GE_CHK_BOOL_EXEC(status == GRAPH_SUCCESS, return GRAPH_FAILED, "Inferring %s failed.",
  1113. node_ptr->GetName().c_str());
  1114. for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
  1115. GE_CHECK_NOTNULL(out_anchor->GetOwnerNode()->GetOpDesc());
  1116. auto output_tensor = out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx());
  1117. ge::TensorUtils::SetRealDimCnt(output_tensor, output_tensor.GetShape().GetDims().size());
  1118. (void)out_anchor->GetOwnerNode()->GetOpDesc()->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor);
  1119. for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
  1120. (void)peer_anchor->GetOwnerNode()->GetOpDesc()->UpdateInputDesc(peer_anchor->GetIdx(), output_tensor);
  1121. }
  1122. }
  1123. }
  1124. }
  1125. return GRAPH_SUCCESS;
  1126. }
  1127. ProtoAttrMapHelper ComputeGraph::MutableAttrMap() { return attrs_; }
  1128. ConstProtoAttrMapHelper ComputeGraph::GetAttrMap() const {
  1129. return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg());
  1130. }
  1131. const std::map<OperatorImplPtr, NodePtr> &ComputeGraph::GetAllNodesInfo() const { return all_nodes_infos_; }
  1132. void ComputeGraph::SetUserDefOutput(const std::string &output_name) {
  1133. if (output_name.empty()) {
  1134. return;
  1135. }
  1136. vector<string> nodes = StringUtils::Split(output_name, ';');
  1137. for (string node : nodes) {
  1138. vector<string> item = StringUtils::Split(node, ':');
  1139. if (item.size() != OUTPUT_PARAM_SIZE) {
  1140. GELOGW("invalid output param!input:%s", output_name.c_str());
  1141. continue;
  1142. }
  1143. int32_t index;
  1144. try {
  1145. index = stoi(StringUtils::Trim(item[1]));
  1146. } catch (const std::out_of_range &) {
  1147. GELOGW("outputname cause out of range execption!output_name:%s", output_name.c_str());
  1148. continue;
  1149. } catch (const std::invalid_argument &) {
  1150. GELOGW("outputname cause invalid argument!output_name:%s", output_name.c_str());
  1151. continue;
  1152. } catch (...) {
  1153. GELOGW("stoi fail! output_name:%s", output_name.c_str());
  1154. continue;
  1155. }
  1156. auto iter = out_nodes_map_.find(item[0]);
  1157. if (iter == out_nodes_map_.end()) {
  1158. out_nodes_map_[item[0]] = std::vector<int32_t>(1, index);
  1159. } else {
  1160. auto idx_iter = std::find(iter->second.begin(), iter->second.end(), index);
  1161. if (idx_iter == iter->second.end()) {
  1162. iter->second.push_back(index);
  1163. }
  1164. }
  1165. }
  1166. }
  1167. const std::string ComputeGraph::GetOutput() {
  1168. static const int resultDefaultSize = 2048;
  1169. string result;
  1170. result.reserve(resultDefaultSize);
  1171. auto iter = out_nodes_map_.begin();
  1172. while (iter != out_nodes_map_.end()) {
  1173. auto idxes = iter->second;
  1174. for (auto idx : idxes) {
  1175. (void)result.append(iter->first).append(":").append(std::to_string(idx)).append(";");
  1176. }
  1177. ++iter;
  1178. }
  1179. return result.substr(0, result.length() - 1);
  1180. }
  1181. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示：