You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node_utils.cc 29 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/node_utils.h"
  17. #include "graph/utils/graph_utils.h"
  18. #include "debug/ge_op_types.h"
  19. #include "debug/ge_util.h"
  20. #include "framework/common/debug/ge_log.h"
  21. #include "graph/anchor.h"
  22. #include "graph/debug/ge_attr_define.h"
  23. #include "graph/types.h"
  24. #include "utils/tensor_utils.h"
  25. #include "utils/type_utils.h"
  26. namespace ge {
  27. std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{};
  28. std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{};
  29. const std::set<std::string> kConstOpTypes = {"Const", "Constant"};
  30. const std::set<std::string> kIfOpTypes = {"If", "_If", "StatelessIf"};
  31. const std::set<std::string> kWhileOpTypes = {"While", "_While", "StatelessWhile"};
  32. const std::set<std::string> kCaseOpTypes = {"Case"};
  33. const std::set<std::string> kForOpTypes = {"For"};
  34. bool OpShapeIsUnknown(const OpDescPtr &desc) {
  35. for (const auto &ptr : desc->GetAllInputsDescPtr()) {
  36. auto ge_shape = ptr->GetShape();
  37. for (const auto &dim : ge_shape.GetDims()) {
  38. if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
  39. return true;
  40. }
  41. }
  42. }
  43. for (const auto &ptr : desc->GetAllOutputsDescPtr()) {
  44. auto ge_shape = ptr->GetShape();
  45. for (const auto &dim : ge_shape.GetDims()) {
  46. if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
  47. return true;
  48. }
  49. }
  50. }
  51. return false;
  52. }
  53. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node,
  54. const uint32_t &event_id) {
  55. GE_CHECK_NOTNULL(node);
  56. map_send_info_[node].push_back(event_id);
  57. return GRAPH_SUCCESS;
  58. }
  59. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddRecvEventId(const NodePtr &node,
  60. const uint32_t &event_id) {
  61. GE_CHECK_NOTNULL(node);
  62. map_recv_info_[node].push_back(event_id);
  63. return GRAPH_SUCCESS;
  64. }
  65. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  66. NodeUtils::GetSendEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_send) {
  67. GE_CHECK_NOTNULL(node);
  68. auto find = map_send_info_.find(node);
  69. if (find == map_send_info_.end()) {
  70. return GRAPH_FAILED;
  71. } else {
  72. vec_send = find->second;
  73. return GRAPH_SUCCESS;
  74. }
  75. }
  76. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  77. NodeUtils::GetRecvEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_recv) {
  78. GE_CHECK_NOTNULL(node);
  79. auto find = map_recv_info_.find(node);
  80. if (find == map_recv_info_.end()) {
  81. return GRAPH_FAILED;
  82. } else {
  83. vec_recv = find->second;
  84. return GRAPH_SUCCESS;
  85. }
  86. }
  87. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearSendInfo() {
  88. map_send_info_.clear();
  89. return GRAPH_SUCCESS;
  90. }
  91. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearRecvInfo() {
  92. map_recv_info_.clear();
  93. return GRAPH_SUCCESS;
  94. }
  95. graphStatus NodeUtils::GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst) {
  96. GE_CHECK_NOTNULL(src);
  97. NodePtr cur_ptr;
  98. if (depth < 1) {
  99. return GRAPH_FAILED;
  100. }
  101. for (int i = 0; i < depth; i++) {
  102. if (src->GetOutDataNodes().size() != 1) {
  103. return GRAPH_FAILED;
  104. }
  105. cur_ptr = src->GetOutDataNodes().at(0);
  106. GE_CHECK_NOTNULL(cur_ptr);
  107. }
  108. dst = cur_ptr;
  109. return GRAPH_SUCCESS;
  110. }
  111. graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data,
  112. InControlAnchorPtr &in_control) {
  113. GE_CHECK_NOTNULL(node_ptr);
  114. for (const auto &p : node_ptr->GetAllOutDataAnchors()) {
  115. GE_CHK_BOOL_EXEC((p != nullptr), continue, "GetAllOutDataAnchors is nullptr");
  116. for (const auto &p_in : p->GetPeerInControlAnchors()) {
  117. GE_CHK_BOOL_EXEC((p_in != nullptr), continue, "GetPeerInDataAnchors is nullptr");
  118. out_data = p;
  119. in_control = p_in;
  120. return GRAPH_SUCCESS;
  121. }
  122. }
  123. return GRAPH_FAILED;
  124. }
  125. graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) {
  126. GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED,
  127. "node or in_data_anchor is nullptr");
  128. bool find_flag = false;
  129. uint32_t index = 0;
  130. vector<InDataAnchorPtr>::iterator it = node_ptr->in_data_anchors_.end();
  131. for (const auto &tmp : node_ptr->in_data_anchors_) {
  132. if (tmp == in_data_anchor) {
  133. find_flag = true;
  134. auto iter = node_ptr->in_data_anchors_.begin() + index;
  135. if (iter != node_ptr->in_data_anchors_.end()) {
  136. it = node_ptr->in_data_anchors_.erase(iter);
  137. }
  138. break;
  139. }
  140. index++;
  141. }
  142. for (; it != node_ptr->in_data_anchors_.end(); ++it) {
  143. (*it)->SetIdx(index);
  144. index++;
  145. }
  146. if (!find_flag) {
  147. return GRAPH_FAILED;
  148. }
  149. return GRAPH_SUCCESS;
  150. }
  151. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::SetAllAnchorStatus(const NodePtr &node_ptr) {
  152. GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "node is nullptr");
  153. GE_CHK_BOOL_EXEC(SetAllAnchorStatus(*node_ptr) == GRAPH_SUCCESS, return GRAPH_FAILED, "set all anchor status failed");
  154. return GRAPH_SUCCESS;
  155. }
  156. graphStatus NodeUtils::SetAllAnchorStatus(Node &node) {
  157. node.anchor_status_updated_ = true;
  158. return GRAPH_SUCCESS;
  159. }
  160. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::IsAnchorStatusSet(const NodePtr &node_ptr) {
  161. GE_CHK_BOOL_EXEC(node_ptr != nullptr, return false, "node is nullptr");
  162. return IsAnchorStatusSet(*node_ptr);
  163. }
  164. bool NodeUtils::IsAnchorStatusSet(const Node &node) { return node.anchor_status_updated_; }
  165. graphStatus NodeUtils::MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node) {
  166. if ((origin_node == nullptr) || (new_node == nullptr)) {
  167. return GRAPH_FAILED;
  168. }
  169. auto origin_out_data_anchors = origin_node->GetAllOutDataAnchors();
  170. auto new_out_data_anchors = new_node->GetAllOutDataAnchors();
  171. if (origin_out_data_anchors.size() != new_out_data_anchors.size()) {
  172. return GRAPH_FAILED;
  173. }
  174. for (size_t i = 0; i < origin_out_data_anchors.size(); ++i) {
  175. for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInDataAnchors()) {
  176. GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue,
  177. "unlink peer_anchor failed");
  178. GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  179. "linkto peer_anchor failed");
  180. }
  181. for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInControlAnchors()) {
  182. GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue,
  183. "unlink peer_anchor failed");
  184. GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  185. "linkto peer_anchor failed");
  186. }
  187. }
  188. auto origin_out_control_anchor = origin_node->GetOutControlAnchor();
  189. GE_CHECK_NOTNULL(origin_out_control_anchor);
  190. auto new_out_control_anchor = new_node->GetOutControlAnchor();
  191. GE_CHECK_NOTNULL(new_out_control_anchor);
  192. for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInControlAnchors()) {
  193. GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  194. "linkto peer_anchor failed");
  195. }
  196. for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInDataAnchors()) {
  197. GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  198. "linkto peer_anchor failed");
  199. }
  200. origin_out_control_anchor->UnlinkAll();
  201. return GRAPH_SUCCESS;
  202. }
  203. bool NodeUtils::IsConst(const Node &node) {
  204. auto src_node_type = node.GetType();
  205. bool is_const = ((src_node_type == CONSTANT) || (src_node_type == CONSTANTOP));
  206. return is_const;
  207. }
  208. void NodeUtils::UpdateIsInputConst(const NodePtr &node_ptr) {
  209. if (node_ptr == nullptr) {
  210. GELOGE(GRAPH_FAILED, "node is null");
  211. return;
  212. }
  213. UpdateIsInputConst(*node_ptr);
  214. }
  215. ///
  216. /// update is_input_const
  217. /// @param node
  218. /// @return void
  219. ///
  220. void NodeUtils::UpdateIsInputConst(Node &node) {
  221. std::vector<bool> is_input_const;
  222. size_t anchor_num = node.GetAllInDataAnchors().size();
  223. for (size_t i = 0; i < anchor_num; i++) {
  224. auto in_anchor = node.GetInDataAnchor(static_cast<int>(i));
  225. if (in_anchor == nullptr) {
  226. is_input_const.push_back(false);
  227. continue;
  228. }
  229. auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
  230. if (peer_out_anchor == nullptr) {
  231. is_input_const.push_back(false);
  232. continue;
  233. }
  234. auto src_node = peer_out_anchor->GetOwnerNode();
  235. if (src_node == nullptr) {
  236. is_input_const.push_back(false);
  237. continue;
  238. }
  239. if (IsConst(*(src_node))) {
  240. is_input_const.push_back(true);
  241. } else {
  242. is_input_const.push_back(false);
  243. }
  244. }
  245. if (node.GetOpDesc() == nullptr) {
  246. GELOGE(GRAPH_FAILED, "Node get opdesc is nullptr");
  247. return;
  248. }
  249. node.GetOpDesc()->SetIsInputConst(is_input_const);
  250. }
  251. void NodeUtils::UnlinkAll(const Node &node) {
  252. for (const auto &anchor : node.GetAllOutAnchors()) {
  253. anchor->UnlinkAll();
  254. }
  255. for (const auto &anchor : node.GetAllInAnchors()) {
  256. anchor->UnlinkAll();
  257. }
  258. }
  259. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeerNodeInputDesc(const NodePtr &node_ptr) {
  260. if (node_ptr == nullptr) {
  261. GELOGE(GRAPH_FAILED, "Nodeptr is nullptr");
  262. return GRAPH_FAILED;
  263. }
  264. auto op_desc = node_ptr->GetOpDesc();
  265. if (op_desc == nullptr) {
  266. return GRAPH_FAILED;
  267. }
  268. bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag();
  269. if (is_unknown_graph) {
  270. return GRAPH_SUCCESS;
  271. }
  272. for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
  273. auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
  274. ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
  275. output_tensor->SetOriginShape(output_tensor->GetShape());
  276. output_tensor->SetOriginDataType(output_tensor->GetDataType());
  277. GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
  278. node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
  279. TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
  280. TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
  281. for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
  282. if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
  283. GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null");
  284. continue;
  285. }
  286. auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx());
  287. if (peer_input_desc == nullptr) {
  288. GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr");
  289. continue;
  290. }
  291. GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d",
  292. peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(),
  293. output_tensor->GetDataType(), output_tensor->GetOriginDataType());
  294. peer_input_desc->SetOriginShape(output_tensor->GetOriginShape());
  295. peer_input_desc->SetShape(output_tensor->GetShape());
  296. peer_input_desc->SetDataType(output_tensor->GetDataType());
  297. peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType());
  298. std::vector<std::pair<int64_t, int64_t>> shape_range;
  299. (void)output_tensor->GetShapeRange(shape_range);
  300. peer_input_desc->SetShapeRange(shape_range);
  301. ge::TensorUtils::SetRealDimCnt(*peer_input_desc,
  302. static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
  303. GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d",
  304. peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(),
  305. peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType());
  306. }
  307. }
  308. return GRAPH_SUCCESS;
  309. }
  310. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node,
  311. uint32_t index) {
  312. if (node == nullptr) {
  313. GELOGE(GRAPH_FAILED, "Nodeptr is nullptr");
  314. return GRAPH_FAILED;
  315. }
  316. GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT);
  317. OpDescPtr op_desc = node->op_;
  318. for (size_t i = op_desc->GetInputsSize(); i < index; ++i) {
  319. if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) {
  320. GELOGE(GRAPH_FAILED, "Add input desc failed");
  321. return GRAPH_FAILED;
  322. }
  323. auto anchor = ComGraphMakeShared<InDataAnchor>(node, i);
  324. if (anchor == nullptr) {
  325. GELOGE(GRAPH_FAILED, "Current in_data_anchor is null, malloc shared_ptr failed.");
  326. return GRAPH_FAILED;
  327. }
  328. node->in_data_anchors_.push_back(anchor);
  329. }
  330. return GRAPH_SUCCESS;
  331. }
  332. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node,
  333. uint32_t index) {
  334. if (node == nullptr) {
  335. GELOGE(GRAPH_FAILED, "Nodeptr is nullptr");
  336. return GRAPH_FAILED;
  337. }
  338. OpDescPtr op_desc = node->op_;
  339. op_desc->RemoveInputDesc(index);
  340. while (node->in_data_anchors_.size() > index) {
  341. node->in_data_anchors_.pop_back();
  342. }
  343. return GRAPH_SUCCESS;
  344. }
  345. bool NodeUtils::IsInNodesEmpty(const Node &node) {
  346. for (const auto &in_anchor : node.in_data_anchors_) {
  347. if (in_anchor != nullptr) {
  348. auto out_anchor = in_anchor->GetPeerOutAnchor();
  349. if (out_anchor != nullptr) {
  350. if (out_anchor->GetOwnerNode() != nullptr) {
  351. return false;
  352. }
  353. }
  354. }
  355. }
  356. if ((node.in_control_anchor_ != nullptr) && (!node.in_control_anchor_->IsPeerOutAnchorsEmpty())) {
  357. auto peer_out_control_anchors = node.in_control_anchor_->GetPeerOutControlAnchors();
  358. for (const auto &out_control_anchor : peer_out_control_anchors) {
  359. if (out_control_anchor != nullptr) {
  360. if (out_control_anchor->GetOwnerNode() != nullptr) {
  361. return false;
  362. }
  363. }
  364. }
  365. }
  366. return true;
  367. }
  368. GeTensorDesc NodeUtils::GetOutputDesc(const Node &node, uint32_t index) {
  369. auto desc = node.GetOpDesc();
  370. if (desc == nullptr) {
  371. return GeTensorDesc();
  372. }
  373. return desc->GetOutputDesc(index);
  374. }
  375. GeTensorDesc NodeUtils::GetInputDesc(const Node &node, uint32_t index) {
  376. auto desc = node.GetOpDesc();
  377. if (desc == nullptr) {
  378. return GeTensorDesc();
  379. }
  380. return desc->GetInputDesc(index);
  381. }
  382. graphStatus NodeUtils::UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape) {
  383. auto desc = node.GetOpDesc();
  384. if (desc == nullptr) {
  385. return GRAPH_PARAM_INVALID;
  386. }
  387. auto output_desc = desc->MutableOutputDesc(index);
  388. if (output_desc == nullptr) {
  389. return GRAPH_PARAM_INVALID;
  390. }
  391. output_desc->SetShape(shape);
  392. return GRAPH_SUCCESS;
  393. }
  394. graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape) {
  395. auto desc = node.GetOpDesc();
  396. if (desc == nullptr) {
  397. return GRAPH_PARAM_INVALID;
  398. }
  399. auto input_desc = desc->MutableInputDesc(index);
  400. if (input_desc == nullptr) {
  401. return GRAPH_PARAM_INVALID;
  402. }
  403. input_desc->SetShape(shape);
  404. return GRAPH_SUCCESS;
  405. }
  406. graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) {
  407. auto desc = node.GetOpDesc();
  408. GE_CHECK_NOTNULL(desc);
  409. // check self
  410. is_unknow = OpShapeIsUnknown(desc);
  411. if (is_unknow) {
  412. return GRAPH_SUCCESS;
  413. }
  414. auto sub_graph_names = desc->GetSubgraphInstanceNames();
  415. if (sub_graph_names.empty()) {
  416. return GRAPH_SUCCESS;
  417. } else {
  418. auto owner_graph = node.GetOwnerComputeGraph();
  419. GE_CHECK_NOTNULL(owner_graph);
  420. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  421. if (root_graph == nullptr) {
  422. GE_LOGE("Node %s gets null root graph", node.GetName().c_str());
  423. return GRAPH_PARAM_INVALID;
  424. }
  425. for (auto &sub_graph_name : sub_graph_names) {
  426. auto sub_graph = root_graph->GetSubgraph(sub_graph_name);
  427. GE_CHECK_NOTNULL(sub_graph);
  428. for (const auto &node_ptr : sub_graph->GetDirectNode()) {
  429. auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow);
  430. if (status != GRAPH_SUCCESS) {
  431. GE_LOGE("get node unknown shape status failed!");
  432. return status;
  433. }
  434. if (is_unknow) {
  435. return GRAPH_SUCCESS;
  436. }
  437. }
  438. }
  439. }
  440. return GRAPH_SUCCESS;
  441. }
  442. std::string NodeUtils::GetNodeType(const Node &node) {
  443. if (node.GetType() != FRAMEWORKOP) {
  444. return node.GetType();
  445. }
  446. std::string type;
  447. (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
  448. return type;
  449. }
  450. ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) {
  451. auto op_desc = node.GetOpDesc();
  452. if (op_desc == nullptr) {
  453. return nullptr;
  454. }
  455. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  456. if (root_graph == nullptr) {
  457. return nullptr;
  458. }
  459. return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index));
  460. }
  461. graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) {
  462. if (subgraph == nullptr) {
  463. GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index);
  464. return GRAPH_PARAM_INVALID;
  465. }
  466. auto op_desc = node.GetOpDesc();
  467. if (op_desc == nullptr) {
  468. return GRAPH_PARAM_INVALID;
  469. }
  470. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  471. if (root_graph == nullptr) {
  472. GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str());
  473. return GRAPH_PARAM_INVALID;
  474. }
  475. auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName());
  476. if (ret != GRAPH_SUCCESS) {
  477. GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index);
  478. return ret;
  479. }
  480. subgraph->SetParentNode(node.shared_from_this());
  481. subgraph->SetParentGraph(node.GetOwnerComputeGraph());
  482. return root_graph->AddSubgraph(subgraph);
  483. }
  484. ///
  485. /// Check if node is input of subgraph
  486. /// @param [in] node
  487. /// @return bool
  488. ///
  489. bool NodeUtils::IsSubgraphInput(const NodePtr &node) {
  490. if ((node == nullptr) || (node->GetOpDesc() == nullptr) ||
  491. (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) {
  492. return false;
  493. }
  494. auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
  495. if (parent_op_desc == nullptr) {
  496. return false;
  497. }
  498. if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
  499. bool is_unknown_shape = false;
  500. (void)AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape);
  501. if (is_unknown_shape) return false;
  502. }
  503. if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE) &&
  504. kCaseOpTypes.count(parent_op_desc->GetType()) == 0 && kWhileOpTypes.count(parent_op_desc->GetType()) == 0 &&
  505. kForOpTypes.count(parent_op_desc->GetType()) == 0 && kIfOpTypes.count(parent_op_desc->GetType()) == 0) {
  506. return false;
  507. }
  508. return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX);
  509. }
  510. ///
  511. /// Check if node is output of subgraph
  512. /// @param [in] node
  513. /// @return bool
  514. ///
  515. bool NodeUtils::IsSubgraphOutput(const NodePtr &node) {
  516. if ((node == nullptr) || (node->GetOpDesc() == nullptr) ||
  517. (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) {
  518. return false;
  519. }
  520. auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
  521. if (parent_op_desc == nullptr) {
  522. return false;
  523. }
  524. if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
  525. bool is_unknown_shape = false;
  526. (void)AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape);
  527. if (is_unknown_shape) return false;
  528. }
  529. if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE) &&
  530. kCaseOpTypes.count(parent_op_desc->GetType()) == 0 && kWhileOpTypes.count(parent_op_desc->GetType()) == 0 &&
  531. kForOpTypes.count(parent_op_desc->GetType()) == 0 && kIfOpTypes.count(parent_op_desc->GetType()) == 0) {
  532. return false;
  533. }
  534. for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) {
  535. if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) {
  536. return true;
  537. }
  538. }
  539. return false;
  540. }
  541. ///
  542. /// @brief Get subgraph original input node.
  543. /// @param [in] node
  544. /// @return Node
  545. ///
  546. NodePtr NodeUtils::GetParentInput(const NodePtr &node) {
  547. GE_CHECK_NOTNULL_EXEC(node, return nullptr);
  548. uint32_t parent_index = 0;
  549. if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
  550. return nullptr;
  551. }
  552. // Subgraph Data Node, check for constant input.
  553. const ComputeGraphPtr &graph = node->GetOwnerComputeGraph();
  554. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  555. const NodePtr &parent_node = graph->GetParentNode();
  556. GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr);
  557. const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index);
  558. GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr);
  559. const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor();
  560. GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr);
  561. return peer_out_anchor->GetOwnerNode();
  562. }
  563. ///
  564. /// @brief Check is varying_input for while node
  565. /// @param [in] node: Data node for subgraph
  566. /// @return bool
  567. ///
  568. bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) {
  569. if (node == nullptr) {
  570. return false;
  571. }
  572. if (node->GetType() != DATA) {
  573. return false; // not input_node for subgraph
  574. }
  575. const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode();
  576. if (parent_node == nullptr) {
  577. return false; // root graph
  578. }
  579. if (kWhileOpTypes.count(parent_node->GetType()) == 0) {
  580. return false; // not input_node for while subgraph
  581. }
  582. uint32_t index_i = 0;
  583. if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) {
  584. GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str());
  585. return false;
  586. }
  587. bool varying_flag = true;
  588. for (const auto &item : node->GetOutDataNodesAndAnchors()) {
  589. if (item.first->GetType() != NETOUTPUT) {
  590. continue;
  591. }
  592. OpDescPtr op_desc = item.first->GetOpDesc();
  593. uint32_t index_o = 0;
  594. if ((op_desc == nullptr) ||
  595. !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) {
  596. continue; // input for while-cond subgraph
  597. }
  598. if (index_i != index_o) {
  599. continue; // varying input for while-body subgraph
  600. }
  601. varying_flag = false;
  602. break;
  603. }
  604. return varying_flag;
  605. }
  606. ///
  607. /// @brief Get subgraph input is constant.
  608. /// @param [in] node
  609. /// @param [out] string
  610. /// @return bool
  611. ///
  612. bool NodeUtils::GetConstOpType(const NodePtr &in_node, std::string &op_type) {
  613. GE_CHECK_NOTNULL_EXEC(in_node, return false);
  614. if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
  615. op_type = in_node->GetType();
  616. return true;
  617. }
  618. if (in_node->GetType() == DATA) {
  619. std::string const_type;
  620. if (!AttrUtils::GetStr(in_node->GetOpDesc(), ATTR_NAME_PARENT_CONST_TYPE, const_type)) {
  621. return false;
  622. }
  623. if ((const_type == CONSTANT) || (const_type == CONSTANTOP)) {
  624. op_type = const_type;
  625. return true;
  626. }
  627. }
  628. return false;
  629. }
  630. ///
  631. /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph.
  632. /// @param [in] node
  633. /// @return return GRAPH_SUCCESS if remove successfully, other for failed.
  634. ///
  635. Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) {
  636. GE_CHECK_NOTNULL(node);
  637. auto op_desc = node->GetOpDesc();
  638. GE_CHECK_NOTNULL(op_desc);
  639. auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  640. if (subgraph_names.empty()) {
  641. return GRAPH_SUCCESS;
  642. } else {
  643. auto owner_graph = node->GetOwnerComputeGraph();
  644. GE_CHECK_NOTNULL(owner_graph);
  645. auto root_graph = GraphUtils::FindRootGraph(owner_graph);
  646. GE_CHECK_NOTNULL(root_graph);
  647. std::unordered_set<std::string> subgraph_to_remove;
  648. for (auto &subgraph_name : subgraph_names) {
  649. std::deque<std::string> queue;
  650. queue.push_back(subgraph_name);
  651. subgraph_to_remove.insert(subgraph_name);
  652. op_desc->RemoveSubgraphInstanceName(subgraph_name);
  653. while (!queue.empty()) {
  654. auto graph_name = queue.front();
  655. queue.pop_front();
  656. auto subgraph = root_graph->GetSubgraph(graph_name);
  657. GE_CHECK_NOTNULL(subgraph);
  658. for (const auto &sub_node : subgraph->GetDirectNode()) {
  659. auto sub_op_desc = sub_node->GetOpDesc();
  660. GE_CHECK_NOTNULL(sub_op_desc);
  661. auto sub_names = sub_op_desc->GetSubgraphInstanceNames();
  662. // Subgraph and all nodes in it will be removed later,
  663. // no need to remove 'SubgraphInstanceName' in op desc here.
  664. for (auto &name : sub_names) {
  665. if (subgraph_to_remove.insert(name).second) {
  666. queue.push_back(name);
  667. }
  668. }
  669. }
  670. }
  671. }
  672. // Remove subgraph from root_graph
  673. for (const auto &name : subgraph_to_remove) {
  674. GELOGI("Remove subgraph:%s.", name.c_str());
  675. root_graph->RemoveSubgraph(name);
  676. }
  677. }
  678. return GRAPH_SUCCESS;
  679. }
  680. ///
  681. /// @brief Get subgraph input data node by index.
  682. /// @param [in] node
  683. /// @return Node
  684. ///
  685. vector<NodePtr> NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) {
  686. vector<NodePtr> in_data_node_vec;
  687. auto op_desc = node.GetOpDesc();
  688. GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec);
  689. auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  690. if (subgraph_names.empty()) {
  691. GELOGW("Node %s is single node without sub graph.", node.GetName().c_str());
  692. return in_data_node_vec;
  693. }
  694. auto compute_graph = node.GetOwnerComputeGraph();
  695. for (const std::string &instance_name : subgraph_names) {
  696. auto subgraph = compute_graph->GetSubgraph(instance_name);
  697. for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
  698. int parent_index = -1;
  699. if (NodeUtils::IsSubgraphInput(node_in_subgraph)) {
  700. (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index);
  701. if (parent_index == index) {
  702. in_data_node_vec.emplace_back(node_in_subgraph);
  703. }
  704. }
  705. }
  706. }
  707. return in_data_node_vec;
  708. }
  709. ///
  710. /// @brief Get subgraph input data node by index.
  711. /// @param [in] node
  712. /// @return Node
  713. ///
  714. vector<NodePtr> NodeUtils::GetSubgraphOutputNodes(const Node &node) {
  715. vector<NodePtr> out_data_node_vec;
  716. auto op_desc = node.GetOpDesc();
  717. GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec);
  718. auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  719. if (subgraph_names.empty()) {
  720. GELOGI("Node %s is single node without sub graph.", node.GetName().c_str());
  721. return out_data_node_vec;
  722. }
  723. auto compute_graph = node.GetOwnerComputeGraph();
  724. for (const std::string &instance_name : subgraph_names) {
  725. auto subgraph = compute_graph->GetSubgraph(instance_name);
  726. for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
  727. if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) {
  728. out_data_node_vec.emplace_back(node_in_subgraph);
  729. }
  730. }
  731. }
  732. return out_data_node_vec;
  733. }
  734. NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, int index) {
  735. if (node.GetInDataAnchor(index) == nullptr) {
  736. return nullptr;
  737. }
  738. if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) {
  739. return nullptr;
  740. }
  741. return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode();
  742. }
  743. vector<NodePtr> NodeUtils::GetOutDataNodesByIndex(const Node &node, int index) {
  744. vector<NodePtr> out_data_nodes;
  745. auto out_data_anchor = node.GetOutDataAnchor(index);
  746. if (out_data_anchor == nullptr) {
  747. return out_data_nodes;
  748. }
  749. for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  750. if (peer_in_anchor == nullptr) {
  751. continue;
  752. }
  753. if (peer_in_anchor->GetOwnerNode() == nullptr) {
  754. continue;
  755. }
  756. out_data_nodes.emplace_back(peer_in_anchor->GetOwnerNode());
  757. }
  758. return out_data_nodes;
  759. }
  760. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。