You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

node_utils.cc 37 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/utils/node_utils.h"
  17. #include "graph/utils/op_desc_utils.h"
  18. #include "graph/utils/graph_utils.h"
  19. #include "debug/ge_op_types.h"
  20. #include "debug/ge_util.h"
  21. #include "framework/common/debug/ge_log.h"
  22. #include "graph/anchor.h"
  23. #include "graph/debug/ge_attr_define.h"
  24. #include "graph/types.h"
  25. #include "external/graph/operator.h"
  26. #include "graph/ge_context.h"
  27. #include "graph/runtime_inference_context.h"
  28. #include "graph/utils/op_desc_utils.h"
  29. #include "graph/utils/tensor_utils.h"
  30. #include "graph/utils/tensor_adapter.h"
  31. #include "graph/utils/type_utils.h"
  32. namespace ge {
  33. std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{};
  34. std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{};
  35. const std::set<std::string> kConstOpTypes = { "Const", "Constant" };
  36. const std::set<std::string> kIfOpTypes = { "If", "_If", "StatelessIf" };
  37. const std::set<std::string> kWhileOpTypes = { "While", "_While", "StatelessWhile" };
  38. const std::set<std::string> kCaseOpTypes = { "Case" };
  39. const std::set<std::string> kForOpTypes = { "For" };
  40. bool OpShapeIsUnknown(const OpDescPtr &desc) {
  41. for (const auto &ptr : desc->GetAllInputsDescPtr()) {
  42. auto ge_shape = ptr->GetShape();
  43. for (const auto &dim : ge_shape.GetDims()) {
  44. if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
  45. return true;
  46. }
  47. }
  48. }
  49. for (const auto &ptr : desc->GetAllOutputsDescPtr()) {
  50. auto ge_shape = ptr->GetShape();
  51. for (const auto &dim : ge_shape.GetDims()) {
  52. if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
  53. return true;
  54. }
  55. }
  56. }
  57. return false;
  58. }
  59. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node,
  60. const uint32_t &event_id) {
  61. GE_CHECK_NOTNULL(node);
  62. map_send_info_[node].push_back(event_id);
  63. return GRAPH_SUCCESS;
  64. }
  65. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddRecvEventId(const NodePtr &node,
  66. const uint32_t &event_id) {
  67. GE_CHECK_NOTNULL(node);
  68. map_recv_info_[node].push_back(event_id);
  69. return GRAPH_SUCCESS;
  70. }
  71. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  72. NodeUtils::GetSendEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_send) {
  73. GE_CHECK_NOTNULL(node);
  74. auto find = map_send_info_.find(node);
  75. if (find == map_send_info_.end()) {
  76. return GRAPH_FAILED;
  77. } else {
  78. vec_send = find->second;
  79. return GRAPH_SUCCESS;
  80. }
  81. }
  82. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  83. NodeUtils::GetRecvEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_recv) {
  84. GE_CHECK_NOTNULL(node);
  85. auto find = map_recv_info_.find(node);
  86. if (find == map_recv_info_.end()) {
  87. return GRAPH_FAILED;
  88. } else {
  89. vec_recv = find->second;
  90. return GRAPH_SUCCESS;
  91. }
  92. }
  93. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearSendInfo() {
  94. map_send_info_.clear();
  95. return GRAPH_SUCCESS;
  96. }
  97. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearRecvInfo() {
  98. map_recv_info_.clear();
  99. return GRAPH_SUCCESS;
  100. }
  101. graphStatus NodeUtils::GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst) {
  102. GE_CHECK_NOTNULL(src);
  103. NodePtr cur_ptr;
  104. if (depth < 1) {
  105. return GRAPH_FAILED;
  106. }
  107. for (int i = 0; i < depth; i++) {
  108. if (src->GetOutDataNodes().size() != 1) {
  109. return GRAPH_FAILED;
  110. }
  111. cur_ptr = src->GetOutDataNodes().at(0);
  112. GE_CHECK_NOTNULL(cur_ptr);
  113. }
  114. dst = cur_ptr;
  115. return GRAPH_SUCCESS;
  116. }
  117. graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data,
  118. InControlAnchorPtr &in_control) {
  119. GE_CHECK_NOTNULL(node_ptr);
  120. for (const auto &p : node_ptr->GetAllOutDataAnchors()) {
  121. GE_CHK_BOOL_EXEC((p != nullptr), continue, "GetAllOutDataAnchors is nullptr");
  122. for (const auto &p_in : p->GetPeerInControlAnchors()) {
  123. GE_CHK_BOOL_EXEC((p_in != nullptr), continue, "GetPeerInDataAnchors is nullptr");
  124. out_data = p;
  125. in_control = p_in;
  126. return GRAPH_SUCCESS;
  127. }
  128. }
  129. return GRAPH_FAILED;
  130. }
  131. graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) {
  132. GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED,
  133. "node or in_data_anchor is nullptr");
  134. bool find_flag = false;
  135. uint32_t index = 0;
  136. vector<InDataAnchorPtr>::iterator it = node_ptr->in_data_anchors_.end();
  137. for (const auto &tmp : node_ptr->in_data_anchors_) {
  138. if (tmp == in_data_anchor) {
  139. find_flag = true;
  140. auto iter = node_ptr->in_data_anchors_.begin() + index;
  141. if (iter != node_ptr->in_data_anchors_.end()) {
  142. it = node_ptr->in_data_anchors_.erase(iter);
  143. }
  144. break;
  145. }
  146. index++;
  147. }
  148. for (; it != node_ptr->in_data_anchors_.end(); ++it) {
  149. (*it)->SetIdx(index);
  150. index++;
  151. }
  152. if (!find_flag) {
  153. return GRAPH_FAILED;
  154. }
  155. return GRAPH_SUCCESS;
  156. }
  157. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::SetAllAnchorStatus(const NodePtr &node_ptr) {
  158. GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "node is nullptr");
  159. GE_CHK_BOOL_EXEC(SetAllAnchorStatus(*node_ptr) == GRAPH_SUCCESS, return GRAPH_FAILED, "set all anchor status failed");
  160. return GRAPH_SUCCESS;
  161. }
  162. graphStatus NodeUtils::SetAllAnchorStatus(Node &node) {
  163. node.anchor_status_updated_ = true;
  164. return GRAPH_SUCCESS;
  165. }
  166. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::IsAnchorStatusSet(const NodePtr &node_ptr) {
  167. GE_CHK_BOOL_EXEC(node_ptr != nullptr, return false, "node is nullptr");
  168. return IsAnchorStatusSet(*node_ptr);
  169. }
  170. bool NodeUtils::IsAnchorStatusSet(const Node &node) { return node.anchor_status_updated_; }
  171. graphStatus NodeUtils::MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node) {
  172. if ((origin_node == nullptr) || (new_node == nullptr)) {
  173. return GRAPH_FAILED;
  174. }
  175. auto origin_out_data_anchors = origin_node->GetAllOutDataAnchors();
  176. auto new_out_data_anchors = new_node->GetAllOutDataAnchors();
  177. if (origin_out_data_anchors.size() != new_out_data_anchors.size()) {
  178. return GRAPH_FAILED;
  179. }
  180. for (size_t i = 0; i < origin_out_data_anchors.size(); ++i) {
  181. for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInDataAnchors()) {
  182. GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue,
  183. "unlink peer_anchor failed");
  184. GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  185. "linkto peer_anchor failed");
  186. }
  187. for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInControlAnchors()) {
  188. GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue,
  189. "unlink peer_anchor failed");
  190. GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  191. "linkto peer_anchor failed");
  192. }
  193. }
  194. auto origin_out_control_anchor = origin_node->GetOutControlAnchor();
  195. GE_CHECK_NOTNULL(origin_out_control_anchor);
  196. auto new_out_control_anchor = new_node->GetOutControlAnchor();
  197. GE_CHECK_NOTNULL(new_out_control_anchor);
  198. for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInControlAnchors()) {
  199. GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  200. "linkto peer_anchor failed");
  201. }
  202. for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInDataAnchors()) {
  203. GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  204. "linkto peer_anchor failed");
  205. }
  206. origin_out_control_anchor->UnlinkAll();
  207. return GRAPH_SUCCESS;
  208. }
  209. bool NodeUtils::IsConst(const Node &node) {
  210. auto src_node_type = node.GetType();
  211. bool is_const = ((src_node_type == CONSTANT) || (src_node_type == CONSTANTOP));
  212. return is_const;
  213. }
  214. void NodeUtils::UpdateIsInputConst(const NodePtr &node_ptr) {
  215. if (node_ptr == nullptr) {
  216. GELOGE(GRAPH_FAILED, "node is null");
  217. return;
  218. }
  219. UpdateIsInputConst(*node_ptr);
  220. }
  221. ///
  222. /// update is_input_const
  223. /// @param node
  224. /// @return void
  225. ///
  226. void NodeUtils::UpdateIsInputConst(Node &node) {
  227. std::vector<bool> is_input_const;
  228. size_t anchor_num = node.GetAllInDataAnchors().size();
  229. for (size_t i = 0; i < anchor_num; i++) {
  230. auto in_anchor = node.GetInDataAnchor(static_cast<int>(i));
  231. if (in_anchor == nullptr) {
  232. is_input_const.push_back(false);
  233. continue;
  234. }
  235. auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
  236. if (peer_out_anchor == nullptr) {
  237. is_input_const.push_back(false);
  238. continue;
  239. }
  240. auto src_node = peer_out_anchor->GetOwnerNode();
  241. if (src_node == nullptr) {
  242. is_input_const.push_back(false);
  243. continue;
  244. }
  245. if (IsConst(*(src_node))) {
  246. is_input_const.push_back(true);
  247. } else {
  248. is_input_const.push_back(false);
  249. }
  250. }
  251. if (node.GetOpDesc() == nullptr) {
  252. GELOGE(GRAPH_FAILED, "Node get opdesc is nullptr");
  253. return;
  254. }
  255. node.GetOpDesc()->SetIsInputConst(is_input_const);
  256. }
  257. void NodeUtils::UnlinkAll(const Node &node) {
  258. for (const auto &anchor : node.GetAllOutAnchors()) {
  259. anchor->UnlinkAll();
  260. }
  261. for (const auto &anchor : node.GetAllInAnchors()) {
  262. anchor->UnlinkAll();
  263. }
  264. }
  265. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeerNodeInputDesc(const NodePtr &node_ptr) {
  266. if (node_ptr == nullptr) {
  267. GELOGE(GRAPH_FAILED, "Nodeptr is nullptr");
  268. return GRAPH_FAILED;
  269. }
  270. auto op_desc = node_ptr->GetOpDesc();
  271. if (op_desc == nullptr) {
  272. return GRAPH_FAILED;
  273. }
  274. bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag();
  275. if (is_unknown_graph) {
  276. return GRAPH_SUCCESS;
  277. }
  278. for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
  279. auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
  280. auto out_dims = output_tensor->GetShape().GetDims();
  281. auto out_dtype = output_tensor->GetDataType();
  282. ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
  283. output_tensor->SetOriginShape(output_tensor->GetShape());
  284. output_tensor->SetOriginDataType(output_tensor->GetDataType());
  285. GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
  286. node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
  287. TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
  288. TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
  289. for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
  290. auto peer_anchor_opdesc = peer_anchor->GetOwnerNode()->GetOpDesc();
  291. if (peer_anchor_opdesc == nullptr) {
  292. GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null");
  293. continue;
  294. }
  295. if (op_desc->GetId() < peer_anchor_opdesc->GetId() ||
  296. peer_anchor_opdesc->GetType() == CONSTANT ||
  297. peer_anchor_opdesc->GetType() == CONSTANTOP) {
  298. GELOGD("no need to UpdatePeerNodeInputDesc");
  299. continue;
  300. }
  301. auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx());
  302. if (peer_input_desc == nullptr) {
  303. GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr");
  304. continue;
  305. }
  306. // check shape and dtype continuity. do not stop process
  307. auto peer_input_dims = peer_input_desc->GetShape().GetDims();
  308. auto peer_input_dtype = peer_input_desc->GetDataType();
  309. if (out_dtype != peer_input_dtype) {
  310. GELOGW("current node [%s] [%d]\'th out_dtype is [%s].peer input node [%s] [%d]\'th "
  311. "input_dtype is [%s].The two dtype should be same! Please check graph and fix it",
  312. node_ptr->GetName().c_str(), out_anchor->GetIdx(), TypeUtils::DataTypeToSerialString(out_dtype).c_str(),
  313. peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(),
  314. TypeUtils::DataTypeToSerialString(peer_input_dtype).c_str());
  315. } else if ((!peer_input_dims.empty()) && (out_dims != peer_input_dims)) {
  316. GELOGW("current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th "
  317. "input_shape is [%s].The two shape should be same! Please check graph and fix it",
  318. node_ptr->GetName().c_str(), out_anchor->GetIdx(), output_tensor->GetShape().ToString().c_str(),
  319. peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(),
  320. peer_input_desc->GetShape().ToString().c_str());
  321. }
  322. GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d",
  323. peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(),
  324. output_tensor->GetShape().GetDimNum(), output_tensor->GetDataType(),
  325. output_tensor->GetOriginDataType());
  326. peer_input_desc->SetOriginShape(output_tensor->GetOriginShape());
  327. peer_input_desc->SetShape(output_tensor->GetShape());
  328. peer_input_desc->SetDataType(output_tensor->GetDataType());
  329. peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType());
  330. std::vector<std::pair<int64_t, int64_t>> shape_range;
  331. (void) output_tensor->GetShapeRange(shape_range);
  332. peer_input_desc->SetShapeRange(shape_range);
  333. ge::TensorUtils::SetRealDimCnt(*peer_input_desc,
  334. static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
  335. GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d",
  336. peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(),
  337. peer_input_desc->GetShape().GetDimNum(), peer_input_desc->GetDataType(),
  338. peer_input_desc->GetOriginDataType());
  339. }
  340. }
  341. return GRAPH_SUCCESS;
  342. }
  343. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  344. graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node, uint32_t num) {
  345. if (node == nullptr) {
  346. GELOGE(GRAPH_FAILED, "Input node is null");
  347. return GRAPH_FAILED;
  348. }
  349. GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT);
  350. const auto &op_desc = node->GetOpDesc();
  351. for (size_t i = op_desc->GetInputsSize(); i < num; ++i) {
  352. if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) {
  353. GELOGE(GRAPH_FAILED, "Add input desc failed");
  354. return GRAPH_FAILED;
  355. }
  356. }
  357. for (size_t i = node->in_data_anchors_.size(); i < num; ++i) {
  358. auto anchor = ComGraphMakeShared<InDataAnchor>(node, i);
  359. if (anchor == nullptr) {
  360. GELOGE(OUT_OF_MEMORY, "Current in data anchor is null, make shared_ptr failed.");
  361. return GRAPH_FAILED;
  362. }
  363. node->in_data_anchors_.push_back(anchor);
  364. }
  365. return GRAPH_SUCCESS;
  366. }
  367. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  368. graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node, uint32_t num) {
  369. if (node == nullptr) {
  370. GELOGE(GRAPH_FAILED, "Input node is null");
  371. return GRAPH_FAILED;
  372. }
  373. const auto &op_desc = node->GetOpDesc();
  374. while (op_desc->GetInputsSize() > num) {
  375. if (!OpDescUtils::ClearInputDesc(op_desc, num)) {
  376. return GRAPH_FAILED;
  377. }
  378. }
  379. auto input_names = op_desc->GetAllInputName();
  380. (void)op_desc->UpdateInputName(input_names);
  381. auto is_input_const = op_desc->GetIsInputConst();
  382. is_input_const.resize(num);
  383. op_desc->SetIsInputConst(is_input_const);
  384. while (node->in_data_anchors_.size() > num) {
  385. node->in_data_anchors_.pop_back();
  386. }
  387. return GRAPH_SUCCESS;
  388. }
  389. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  390. graphStatus NodeUtils::AppendOutputAnchor(const NodePtr &node, uint32_t num) {
  391. if (node == nullptr) {
  392. GELOGE(GRAPH_FAILED, "Input node is null");
  393. return GRAPH_FAILED;
  394. }
  395. GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT);
  396. const OpDescPtr &op_desc = node->GetOpDesc();
  397. for (size_t i = op_desc->GetOutputsSize(); i < num; ++i) {
  398. if (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) {
  399. GELOGE(GRAPH_FAILED, "Add output desc failed");
  400. return GRAPH_FAILED;
  401. }
  402. }
  403. for (size_t i = node->out_data_anchors_.size(); i < num; ++i) {
  404. auto anchor = ComGraphMakeShared<OutDataAnchor>(node, i);
  405. if (anchor == nullptr) {
  406. GELOGE(OUT_OF_MEMORY, "Current out data anchor is null, make shared_ptr failed.");
  407. return GRAPH_FAILED;
  408. }
  409. node->out_data_anchors_.push_back(anchor);
  410. }
  411. return GRAPH_SUCCESS;
  412. }
  413. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  414. graphStatus NodeUtils::RemoveOutputAnchor(const NodePtr &node, uint32_t num) {
  415. if (node == nullptr) {
  416. GELOGE(GRAPH_FAILED, "Input node is null");
  417. return GRAPH_FAILED;
  418. }
  419. const auto &op_desc = node->GetOpDesc();
  420. auto output_names = op_desc->GetAllOutputName();
  421. while (op_desc->GetOutputsSize() > num) {
  422. if (!OpDescUtils::ClearOutputDesc(op_desc, num)) {
  423. return GRAPH_FAILED;
  424. }
  425. }
  426. (void)op_desc->UpdateOutputName(output_names);
  427. while (node->out_data_anchors_.size() > num) {
  428. node->out_data_anchors_.pop_back();
  429. }
  430. return GRAPH_SUCCESS;
  431. }
  432. bool NodeUtils::IsInNodesEmpty(const Node &node) {
  433. for (const auto &in_anchor : node.in_data_anchors_) {
  434. if (in_anchor != nullptr) {
  435. auto out_anchor = in_anchor->GetPeerOutAnchor();
  436. if (out_anchor != nullptr) {
  437. if (out_anchor->GetOwnerNode() != nullptr) {
  438. return false;
  439. }
  440. }
  441. }
  442. }
  443. if ((node.in_control_anchor_ != nullptr) && (!node.in_control_anchor_->IsPeerOutAnchorsEmpty())) {
  444. auto peer_out_control_anchors = node.in_control_anchor_->GetPeerOutControlAnchors();
  445. for (const auto &out_control_anchor : peer_out_control_anchors) {
  446. if (out_control_anchor != nullptr) {
  447. if (out_control_anchor->GetOwnerNode() != nullptr) {
  448. return false;
  449. }
  450. }
  451. }
  452. }
  453. return true;
  454. }
  455. GeTensorDesc NodeUtils::GetOutputDesc(const Node &node, uint32_t index) {
  456. auto desc = node.GetOpDesc();
  457. if (desc == nullptr) {
  458. return GeTensorDesc();
  459. }
  460. return desc->GetOutputDesc(index);
  461. }
  462. GeTensorDesc NodeUtils::GetInputDesc(const Node &node, uint32_t index) {
  463. auto desc = node.GetOpDesc();
  464. if (desc == nullptr) {
  465. return GeTensorDesc();
  466. }
  467. return desc->GetInputDesc(index);
  468. }
  469. graphStatus NodeUtils::UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape) {
  470. auto desc = node.GetOpDesc();
  471. if (desc == nullptr) {
  472. return GRAPH_PARAM_INVALID;
  473. }
  474. auto output_desc = desc->MutableOutputDesc(index);
  475. if (output_desc == nullptr) {
  476. return GRAPH_PARAM_INVALID;
  477. }
  478. output_desc->SetShape(shape);
  479. return GRAPH_SUCCESS;
  480. }
  481. graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape) {
  482. auto desc = node.GetOpDesc();
  483. if (desc == nullptr) {
  484. return GRAPH_PARAM_INVALID;
  485. }
  486. auto input_desc = desc->MutableInputDesc(index);
  487. if (input_desc == nullptr) {
  488. return GRAPH_PARAM_INVALID;
  489. }
  490. input_desc->SetShape(shape);
  491. return GRAPH_SUCCESS;
  492. }
  493. graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) {
  494. auto desc = node.GetOpDesc();
  495. GE_CHECK_NOTNULL(desc);
  496. // check self
  497. is_unknow = OpShapeIsUnknown(desc);
  498. if (is_unknow) {
  499. return GRAPH_SUCCESS;
  500. }
  501. auto sub_graph_names = desc->GetSubgraphInstanceNames();
  502. if (sub_graph_names.empty()) {
  503. return GRAPH_SUCCESS;
  504. } else {
  505. auto owner_graph = node.GetOwnerComputeGraph();
  506. GE_CHECK_NOTNULL(owner_graph);
  507. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  508. if (root_graph == nullptr) {
  509. GE_LOGE("Node %s gets null root graph", node.GetName().c_str());
  510. return GRAPH_PARAM_INVALID;
  511. }
  512. for (auto &sub_graph_name : sub_graph_names) {
  513. auto sub_graph = root_graph->GetSubgraph(sub_graph_name);
  514. GE_CHECK_NOTNULL(sub_graph);
  515. for (const auto &node_ptr : sub_graph->GetDirectNode()) {
  516. auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow);
  517. if (status != GRAPH_SUCCESS) {
  518. GE_LOGE("get node unknown shape status failed!");
  519. return status;
  520. }
  521. if (is_unknow) {
  522. return GRAPH_SUCCESS;
  523. }
  524. }
  525. }
  526. }
  527. return GRAPH_SUCCESS;
  528. }
  529. graphStatus NodeUtils::GetInputConstData(const ConstNodePtr& node_ptr,
  530. const string &dst_name,
  531. GeTensorPtr &ge_tensor) {
  532. GE_CHECK_NOTNULL(node_ptr);
  533. return NodeUtils::GetInputConstData(*node_ptr, dst_name, ge_tensor);
  534. }
  535. graphStatus NodeUtils::GetInputConstData(const Node &node,
  536. const string &dst_name,
  537. GeTensorPtr &ge_tensor) {
  538. // For inner compute graph
  539. auto op_desc = node.GetOpDesc();
  540. GE_CHECK_NOTNULL(op_desc);
  541. auto index = op_desc->GetInputIndexByName(dst_name);
  542. auto in_data_anchor = node.GetInDataAnchor(index);
  543. GE_CHECK_NOTNULL(in_data_anchor);
  544. auto out_data_anchor = in_data_anchor->GetPeerOutAnchor();
  545. GE_CHECK_NOTNULL(out_data_anchor);
  546. auto peer_node = out_data_anchor->GetOwnerNode();
  547. if (peer_node->GetType() == ENTER || peer_node->GetType() == REFENTER) {
  548. auto enter_in_data_anchor = peer_node->GetInDataAnchor(0);
  549. GE_CHECK_NOTNULL(enter_in_data_anchor);
  550. auto enter_peer_out_data_anchor = enter_in_data_anchor->GetPeerOutAnchor();
  551. GE_CHECK_NOTNULL(enter_peer_out_data_anchor);
  552. peer_node = enter_peer_out_data_anchor->GetOwnerNode();
  553. }
  554. auto peer_op_desc = peer_node->GetOpDesc();
  555. GE_CHECK_NOTNULL(peer_op_desc);
  556. auto peer_op_type = peer_op_desc->GetType();
  557. if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) {
  558. if (!AttrUtils::MutableTensor(peer_node->GetOpDesc(), ATTR_NAME_WEIGHTS, ge_tensor)) {
  559. GELOGW("get attr name %s failed.", ATTR_NAME_WEIGHTS.c_str());
  560. return GRAPH_FAILED;
  561. }
  562. return GRAPH_SUCCESS;
  563. } else if (peer_op_type == DATA) {
  564. auto parent_node = NodeUtils::GetParentInput(peer_node);
  565. while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) {
  566. parent_node = NodeUtils::GetParentInput(parent_node);
  567. }
  568. if ((parent_node != nullptr)
  569. && ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) {
  570. if (!AttrUtils::MutableTensor(parent_node->GetOpDesc(), ATTR_NAME_WEIGHTS, ge_tensor)) {
  571. GELOGW("get attr name %s failed.", ATTR_NAME_WEIGHTS.c_str());
  572. return GRAPH_FAILED;
  573. }
  574. return GRAPH_SUCCESS;
  575. }
  576. }
  577. // Try get from runtime inference context
  578. auto session_id = std::to_string(GetContext().SessionId());
  579. RuntimeInferenceContext *runtime_infer_ctx = nullptr;
  580. if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) {
  581. GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str());
  582. auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(),
  583. out_data_anchor->GetIdx(), ge_tensor);
  584. if (ret == GRAPH_SUCCESS) {
  585. return GRAPH_SUCCESS;
  586. }
  587. }
  588. GELOGW("node[%s]'s input[%s]'s peer node is not const", node.GetName().c_str(), dst_name.c_str());
  589. return GRAPH_FAILED;
  590. }
  591. std::string NodeUtils::GetNodeType(const Node &node) {
  592. if (node.GetType() != FRAMEWORKOP) {
  593. return node.GetType();
  594. }
  595. std::string type;
  596. (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
  597. return type;
  598. }
  599. std::string NodeUtils::GetNodeType(const NodePtr &node) {
  600. return node == nullptr ? "" : GetNodeType(*node);
  601. }
  602. std::vector<ComputeGraphPtr> NodeUtils::GetAllSubgraphs(const Node &node) {
  603. auto op_desc = node.GetOpDesc();
  604. if (op_desc == nullptr) {
  605. GELOGE(GRAPH_FAILED, "Failed to get op desc from node %s ", node.GetName().c_str());
  606. return {};
  607. }
  608. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  609. if (root_graph == nullptr) {
  610. GELOGE(GRAPH_FAILED, "Failed to find root graph from node %s ", node.GetName().c_str());
  611. return {};
  612. }
  613. return root_graph->GetAllSubgraphs();
  614. }
  615. ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) {
  616. auto op_desc = node.GetOpDesc();
  617. if (op_desc == nullptr) {
  618. return nullptr;
  619. }
  620. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  621. if (root_graph == nullptr) {
  622. return nullptr;
  623. }
  624. return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index));
  625. }
  626. graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) {
  627. if (subgraph == nullptr) {
  628. GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index);
  629. return GRAPH_PARAM_INVALID;
  630. }
  631. auto op_desc = node.GetOpDesc();
  632. if (op_desc == nullptr) {
  633. return GRAPH_PARAM_INVALID;
  634. }
  635. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  636. if (root_graph == nullptr) {
  637. GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str());
  638. return GRAPH_PARAM_INVALID;
  639. }
  640. auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName());
  641. if (ret != GRAPH_SUCCESS) {
  642. GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index);
  643. return ret;
  644. }
  645. subgraph->SetParentNode(node.shared_from_this());
  646. subgraph->SetParentGraph(node.GetOwnerComputeGraph());
  647. return root_graph->AddSubgraph(subgraph);
  648. }
  649. ///
  650. /// Check if node is input of subgraph
  651. /// @param [in] node
  652. /// @return bool
  653. ///
  654. bool NodeUtils::IsSubgraphInput(const NodePtr &node) {
  655. if ((node == nullptr) || (node->GetOpDesc() == nullptr) ||
  656. (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) {
  657. return false;
  658. }
  659. auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
  660. if (parent_op_desc == nullptr) {
  661. return false;
  662. }
  663. // dynamic shape unknown graph false
  664. // dynamic shape known graph with functional subgraph maybe true
  665. if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
  666. if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) {
  667. return false;
  668. } else {
  669. if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) {
  670. return false;
  671. }
  672. }
  673. }
  674. return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX);
  675. }
  676. ///
  677. /// Check if node is output of subgraph
  678. /// @param [in] node
  679. /// @return bool
  680. ///
  681. bool NodeUtils::IsSubgraphOutput(const NodePtr &node) {
  682. if ((node == nullptr) || (node->GetOpDesc() == nullptr) ||
  683. (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) {
  684. return false;
  685. }
  686. auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
  687. if (parent_op_desc == nullptr) {
  688. return false;
  689. }
  690. if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
  691. if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) {
  692. return false;
  693. } else {
  694. if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) {
  695. return false;
  696. }
  697. }
  698. }
  699. for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) {
  700. if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) {
  701. return true;
  702. }
  703. }
  704. return false;
  705. }
  706. ///
  707. /// @brief Get subgraph original input node.
  708. /// @param [in] node
  709. /// @return Node
  710. ///
  711. NodePtr NodeUtils::GetParentInput(const Node &node) {
  712. uint32_t parent_index = 0;
  713. if (!AttrUtils::GetInt(node.GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
  714. return nullptr;
  715. }
  716. // Subgraph Data Node, check for constant input.
  717. const ComputeGraphPtr &graph = node.GetOwnerComputeGraph();
  718. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  719. const NodePtr &parent_node = graph->GetParentNode();
  720. GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr);
  721. const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index);
  722. GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr);
  723. const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor();
  724. GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr);
  725. return peer_out_anchor->GetOwnerNode();
  726. }
  727. NodePtr NodeUtils::GetParentInput(const NodePtr &node) {
  728. return node == nullptr ? node : GetParentInput(*node);
  729. }
  730. ///
  731. /// @brief Get is dynamic shape graph from node.
  732. /// @param [in] node
  733. /// @return bool
  734. ///
  735. bool NodeUtils::IsDynamicShape(const Node &node) {
  736. const auto graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  737. if (graph == nullptr) {
  738. return false;
  739. }
  740. bool is_dynamic_shape = false;
  741. (void)AttrUtils::GetBool(graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape);
  742. return is_dynamic_shape;
  743. }
  744. bool NodeUtils::IsDynamicShape(const NodePtr &node) {
  745. return node == nullptr ? false : IsDynamicShape(*node);
  746. }
  747. ///
  748. /// @brief Check is varying_input for while node
  749. /// @param [in] node: Data node for subgraph
  750. /// @return bool
  751. ///
  752. bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) {
  753. if (node == nullptr) {
  754. return false;
  755. }
  756. if (node->GetType() != DATA) {
  757. return false; // not input_node for subgraph
  758. }
  759. const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode();
  760. if (parent_node == nullptr) {
  761. return false; // root graph
  762. }
  763. if (kWhileOpTypes.count(parent_node->GetType()) == 0) {
  764. return false; // not input_node for while subgraph
  765. }
  766. uint32_t index_i = 0;
  767. if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) {
  768. GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str());
  769. return false;
  770. }
  771. bool varying_flag = true;
  772. for (const auto &item : node->GetOutDataNodesAndAnchors()) {
  773. if (item.first->GetType() != NETOUTPUT) {
  774. continue;
  775. }
  776. OpDescPtr op_desc = item.first->GetOpDesc();
  777. uint32_t index_o = 0;
  778. if ((op_desc == nullptr) ||
  779. !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) {
  780. continue; // input for while-cond subgraph
  781. }
  782. if (index_i != index_o) {
  783. continue; // varying input for while-body subgraph
  784. }
  785. varying_flag = false;
  786. break;
  787. }
  788. return varying_flag;
  789. }
  790. ///
  791. /// @brief Get subgraph input is constant.
  792. /// @param [in] node
  793. /// @param [out] string
  794. /// @return bool
  795. ///
  796. bool NodeUtils::GetConstOpType(const NodePtr &node, std::string &type) {
  797. if (node == nullptr) {
  798. return false;
  799. }
  800. if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) {
  801. type = node->GetType();
  802. return true;
  803. }
  804. if (node->GetType() != DATA) {
  805. return false; // not subgraph input node
  806. }
  807. const auto &parent = GetParentInput(node);
  808. return GetConstOpType(parent, type);
  809. }
  810. ///
  811. /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph.
  812. /// @param [in] node
  813. /// @return return GRAPH_SUCCESS if remove successfully, other for failed.
  814. ///
  815. Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) {
  816. GE_CHECK_NOTNULL(node);
  817. auto op_desc = node->GetOpDesc();
  818. GE_CHECK_NOTNULL(op_desc);
  819. auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  820. if (subgraph_names.empty()) {
  821. return GRAPH_SUCCESS;
  822. } else {
  823. auto owner_graph = node->GetOwnerComputeGraph();
  824. GE_CHECK_NOTNULL(owner_graph);
  825. auto root_graph = GraphUtils::FindRootGraph(owner_graph);
  826. GE_CHECK_NOTNULL(root_graph);
  827. std::unordered_set<std::string> subgraph_to_remove;
  828. for (auto &subgraph_name : subgraph_names) {
  829. std::deque<std::string> queue;
  830. queue.push_back(subgraph_name);
  831. subgraph_to_remove.insert(subgraph_name);
  832. op_desc->RemoveSubgraphInstanceName(subgraph_name);
  833. while (!queue.empty()) {
  834. auto graph_name = queue.front();
  835. queue.pop_front();
  836. auto subgraph = root_graph->GetSubgraph(graph_name);
  837. GE_CHECK_NOTNULL(subgraph);
  838. for (const auto &sub_node : subgraph->GetDirectNode()) {
  839. auto sub_op_desc = sub_node->GetOpDesc();
  840. GE_CHECK_NOTNULL(sub_op_desc);
  841. auto sub_names = sub_op_desc->GetSubgraphInstanceNames();
  842. // Subgraph and all nodes in it will be removed later,
  843. // no need to remove 'SubgraphInstanceName' in op desc here.
  844. for (auto &name : sub_names) {
  845. if (subgraph_to_remove.insert(name).second) {
  846. queue.push_back(name);
  847. }
  848. }
  849. }
  850. }
  851. }
  852. // Remove subgraph from root_graph
  853. for (const auto &name : subgraph_to_remove) {
  854. GELOGI("Remove subgraph:%s.", name.c_str());
  855. root_graph->RemoveSubgraph(name);
  856. }
  857. }
  858. return GRAPH_SUCCESS;
  859. }
  860. ///
  861. /// @brief Get subgraph input data node by index.
  862. /// @param [in] node
  863. /// @return Node
  864. ///
  865. vector<NodePtr> NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) {
  866. vector<NodePtr> in_data_node_vec;
  867. auto op_desc = node.GetOpDesc();
  868. GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec);
  869. auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  870. if (subgraph_names.empty()) {
  871. GELOGW("Node %s is single node without sub graph.", node.GetName().c_str());
  872. return in_data_node_vec;
  873. }
  874. auto compute_graph = node.GetOwnerComputeGraph();
  875. for (const std::string &instance_name : subgraph_names) {
  876. auto subgraph = compute_graph->GetSubgraph(instance_name);
  877. for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
  878. int parent_index = -1;
  879. if (NodeUtils::IsSubgraphInput(node_in_subgraph)) {
  880. (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index);
  881. if (parent_index == index) {
  882. in_data_node_vec.emplace_back(node_in_subgraph);
  883. }
  884. }
  885. }
  886. }
  887. return in_data_node_vec;
  888. }
  889. ///
  890. /// @brief Get subgraph input data node by index.
  891. /// @param [in] node
  892. /// @return Node
  893. ///
  894. vector<NodePtr> NodeUtils::GetSubgraphOutputNodes(const Node &node) {
  895. vector<NodePtr> out_data_node_vec;
  896. auto op_desc = node.GetOpDesc();
  897. GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec);
  898. auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  899. if (subgraph_names.empty()) {
  900. GELOGI("Node %s is single node without sub graph.", node.GetName().c_str());
  901. return out_data_node_vec;
  902. }
  903. auto compute_graph = node.GetOwnerComputeGraph();
  904. for (const std::string &instance_name : subgraph_names) {
  905. auto subgraph = compute_graph->GetSubgraph(instance_name);
  906. for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
  907. if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) {
  908. out_data_node_vec.emplace_back(node_in_subgraph);
  909. }
  910. }
  911. }
  912. return out_data_node_vec;
  913. }
  914. NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, const int index) {
  915. if (node.GetInDataAnchor(index) == nullptr) {
  916. return nullptr;
  917. }
  918. if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) {
  919. return nullptr;
  920. }
  921. return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode();
  922. }
  923. vector<pair<InDataAnchorPtr, NodePtr>> NodeUtils::GetOutDataNodesWithAnchorByIndex(const Node &node, const int index) {
  924. vector<pair<InDataAnchorPtr, NodePtr>> out_data_nodes;
  925. auto out_data_anchor = node.GetOutDataAnchor(index);
  926. if (out_data_anchor == nullptr) {
  927. return out_data_nodes;
  928. }
  929. for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  930. if (peer_in_anchor == nullptr) {
  931. continue;
  932. }
  933. if (peer_in_anchor->GetOwnerNode() == nullptr) {
  934. continue;
  935. }
  936. out_data_nodes.emplace_back(std::make_pair(peer_in_anchor, peer_in_anchor->GetOwnerNode()));
  937. }
  938. return out_data_nodes;
  939. }
  940. ConstNodePtr NodeUtils::GetNodeFromOperator(const Operator &oprt) {
  941. return oprt.GetNode();
  942. }
  943. std::string NodeUtils::GetInConstNodeTypeCrossSubgraph(const NodePtr &node) {
  944. NodePtr input_node = node;
  945. while (input_node != nullptr) {
  946. if (input_node->GetType() != DATA) {
  947. return input_node->GetType();
  948. }
  949. auto owner_graph = input_node->GetOwnerComputeGraph();
  950. auto parent_node = owner_graph->GetParentNode();
  951. if ((parent_node == nullptr) || (kWhileOpTypes.count(parent_node->GetType()) > 0)) {
  952. return node->GetType(); // not in subgraph or while subgraph.
  953. }
  954. input_node = GetParentInput(input_node);
  955. }
  956. return "";
  957. }
  958. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：