You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gnode.cc 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/gnode.h"
  17. #include <utility>
  18. #include "debug/ge_util.h"
  19. #include "framework/common/debug/ge_log.h"
  20. #include "graph/anchor.h"
  21. #include "graph/node.h"
  22. #include "graph/utils/node_adapter.h"
  23. #include "graph/utils/tensor_adapter.h"
  24. #include <graph/utils/graph_utils.h>
  25. #include "graph/debug/ge_attr_define.h"
  26. #include "graph/debug/ge_op_types.h"
  27. #include "utils/node_utils.h"
  28. #include "utils/op_desc_utils.h"
  29. namespace ge {
  30. class NodeImpl {
  31. public:
  32. NodeImpl() = default;
  33. ~NodeImpl() = default;
  34. NodeImpl(NodeImpl &) = delete;
  35. NodeImpl &operator=(const NodeImpl &) = delete;
  36. std::weak_ptr<Node> node_ptr_;
  37. };
  38. NodePtr NodeAdapter::GNode2Node(const ge::GNode &graph_node) {
  39. if (graph_node.impl_ == nullptr) {
  40. GELOGE(GRAPH_FAILED, "GNode2Node: gnode impl is nullptr.");
  41. return nullptr;
  42. }
  43. return graph_node.impl_->node_ptr_.lock();
  44. }
  45. GNode NodeAdapter::Node2GNode(const ge::NodePtr &node) {
  46. if (node == nullptr) {
  47. GELOGE(GRAPH_FAILED, "Node2GNode: node is nullptr");
  48. return GNode();
  49. }
  50. GNode graph_node;
  51. if (graph_node.impl_ == nullptr) {
  52. GELOGW("Node2GNode: gnode impl is nullptr, node[%s].", node->GetName().c_str());
  53. return graph_node;
  54. }
  55. graph_node.impl_->node_ptr_ = node;
  56. return graph_node;
  57. }
  58. GNodePtr NodeAdapter::Node2GNodePtr(const ge::NodePtr &node) {
  59. if (node == nullptr) {
  60. GELOGE(GRAPH_FAILED, "Node2GNodePtr: node is nullptr");
  61. return nullptr;
  62. }
  63. GNodePtr gnode = std::shared_ptr<GNode>(new (std::nothrow) GNode());
  64. if (gnode == nullptr) {
  65. GELOGE(GRAPH_FAILED, "Node2GNodePtr: gnode is nullptr, node[%s].", node->GetName().c_str());
  66. return nullptr;
  67. }
  68. if (gnode->impl_ == nullptr) {
  69. GELOGW("Node2GNode: gnode impl is nullptr, node[%s].", node->GetName().c_str());
  70. return nullptr;
  71. }
  72. gnode->impl_->node_ptr_ = node;
  73. return gnode;
  74. }
  75. GNode::GNode() { impl_ = ComGraphMakeShared<NodeImpl>(); }
  76. graphStatus GNode::GetType(AscendString &type) const {
  77. if (impl_ == nullptr) {
  78. GELOGE(GRAPH_FAILED, "GetType: node impl is nullptr.");
  79. return GRAPH_FAILED;
  80. }
  81. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  82. if (node_ptr == nullptr) {
  83. GELOGE(GRAPH_FAILED, "GetType: the shared ptr is not valid.");
  84. return GRAPH_FAILED;
  85. }
  86. std::string node_type = node_ptr->GetType();
  87. AscendString ascend_type(node_type.c_str());
  88. type = ascend_type;
  89. return GRAPH_SUCCESS;
  90. }
  91. graphStatus GNode::GetName(AscendString &name) const {
  92. if (impl_ == nullptr) {
  93. GELOGE(GRAPH_FAILED, "GetName: node impl is nullptr.");
  94. return GRAPH_FAILED;
  95. }
  96. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  97. if (node_ptr == nullptr) {
  98. GELOGE(GRAPH_FAILED, "GetName: the shared ptr is not valid.");
  99. return GRAPH_FAILED;
  100. }
  101. std::string node_name = node_ptr->GetName();
  102. AscendString ascend_name(node_name.c_str());
  103. name = ascend_name;
  104. return GRAPH_SUCCESS;
  105. }
  106. std::pair<GNodePtr, int32_t> GNode::GetInDataNodesAndPortIndexs(const int32_t index) const {
  107. pair<GNodePtr, int32_t> gnode_idx = {nullptr, 0xFF};
  108. if (impl_ == nullptr) {
  109. GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr.");
  110. return gnode_idx;
  111. }
  112. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  113. if (node_ptr == nullptr) {
  114. GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid.");
  115. return gnode_idx;
  116. }
  117. auto in_anchor = node_ptr->GetInDataAnchor(index);
  118. if (in_anchor == nullptr) {
  119. GELOGE(GRAPH_FAILED, "Failed to get in data node of index[%d] from node[%s], the anchor does not exist",
  120. index, node_ptr->GetName().c_str());
  121. return gnode_idx;
  122. }
  123. auto out_anchor = in_anchor->GetPeerOutAnchor();
  124. if (out_anchor == nullptr) {
  125. GELOGE(GRAPH_FAILED, "Failed to get in data node of index[%d] from node [%s], the data input does not exist",
  126. index, node_ptr->GetName().c_str());
  127. return gnode_idx;
  128. }
  129. NodePtr peer_node_ptr = out_anchor->GetOwnerNode();
  130. GNodePtr gnode = NodeAdapter::Node2GNodePtr(peer_node_ptr);
  131. if (gnode == nullptr) {
  132. GELOGE(GRAPH_FAILED, "Peer node of node[%s] to gnode faild.", node_ptr->GetName().c_str());
  133. return gnode_idx;
  134. }
  135. return {gnode, out_anchor->GetIdx()};
  136. }
  137. std::vector<GNodePtr> GNode::GetInControlNodes() const {
  138. if (impl_ == nullptr) {
  139. GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr.");
  140. return {};
  141. }
  142. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  143. if (node_ptr == nullptr) {
  144. GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid.");
  145. return {};
  146. }
  147. std::vector<GNodePtr> gnodes;
  148. auto in_control_nodes = node_ptr->GetInControlNodes();
  149. for (auto &in_control_node : in_control_nodes) {
  150. GNodePtr gnode = NodeAdapter::Node2GNodePtr(in_control_node);
  151. if (gnode == nullptr) {
  152. GELOGE(GRAPH_FAILED, "In control_node of node[%s] to gnode faild.", node_ptr->GetName().c_str());
  153. return {};
  154. }
  155. gnodes.emplace_back(gnode);
  156. }
  157. return gnodes;
  158. }
  159. std::vector<std::pair<GNodePtr, int32_t>> GNode::GetOutDataNodesAndPortIndexs(const int32_t index) const {
  160. if (impl_ == nullptr) {
  161. GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr.");
  162. return {};
  163. }
  164. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  165. if (node_ptr == nullptr) {
  166. GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid.");
  167. return {};
  168. }
  169. auto out_anchor = node_ptr->GetOutDataAnchor(index);
  170. if (out_anchor == nullptr) {
  171. GELOGE(GRAPH_FAILED, "Failed to get out data node of index %d from node %s, the anchor does not exists",
  172. index, node_ptr->GetName().c_str());
  173. return {};
  174. }
  175. vector<std::pair<GNodePtr, int32_t>> gnode_index;
  176. auto in_data_anchors = out_anchor->GetPeerInDataAnchors();
  177. for (auto &in_data_anchor : in_data_anchors) {
  178. if (in_data_anchor == nullptr) {
  179. GELOGE(GRAPH_FAILED, "In data anchor of node[%s] is nullptr.", node_ptr->GetName().c_str());
  180. return {};
  181. }
  182. NodePtr peer_node_ptr = in_data_anchor->GetOwnerNode();
  183. GNodePtr gnode = NodeAdapter::Node2GNodePtr(peer_node_ptr);
  184. if (gnode == nullptr) {
  185. GELOGE(GRAPH_FAILED, "Peer node of node[%s] to gnode faild.", node_ptr->GetName().c_str());
  186. return {};
  187. }
  188. gnode_index.emplace_back(std::pair<GNodePtr, int32_t>(gnode, in_data_anchor->GetIdx()));
  189. }
  190. return gnode_index;
  191. }
  192. std::vector<GNodePtr> GNode::GetOutControlNodes() const {
  193. if (impl_ == nullptr) {
  194. GELOGE(GRAPH_FAILED, "GetOutControlNodes: node impl is nullptr.");
  195. return {};
  196. }
  197. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  198. if (node_ptr == nullptr) {
  199. GELOGE(GRAPH_FAILED, "GetOutControlNodes: the node shared ptr is not valid.");
  200. return {};
  201. }
  202. std::vector<GNodePtr> gnodes;
  203. auto out_control_nodes = node_ptr->GetOutControlNodes();
  204. for (auto &out_control_node : out_control_nodes) {
  205. GNodePtr gnode = NodeAdapter::Node2GNodePtr(out_control_node);
  206. if (gnode == nullptr) {
  207. GELOGE(GRAPH_FAILED, "In control_node of node[%s] to gnode faild.", node_ptr->GetName().c_str());
  208. return {};
  209. }
  210. gnodes.emplace_back(gnode);
  211. }
  212. return gnodes;
  213. }
  214. graphStatus GNode::GetInputConstData(const int32_t index, Tensor &data) const {
  215. if (impl_ == nullptr) {
  216. GELOGE(GRAPH_FAILED, "GetInputConstData: node impl is nullptr.");
  217. return GRAPH_FAILED;
  218. }
  219. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  220. if (node_ptr == nullptr) {
  221. GELOGE(GRAPH_FAILED, "GetInputConstData: the node shared ptr is not valid.");
  222. return GRAPH_FAILED;
  223. }
  224. NodePtr input_data_node = NodeUtils::GetInDataNodeByIndex(*node_ptr, index);
  225. GE_CHECK_NOTNULL(input_data_node);
  226. string op_type = input_data_node->GetType();
  227. if (op_type == CONSTANT || op_type == CONSTANTOP) {
  228. Operator const_op = OpDescUtils::CreateOperatorFromNode(input_data_node);
  229. if (const_op.GetAttr(ATTR_NAME_WEIGHTS, data) != GRAPH_SUCCESS) {
  230. GELOGE(GRAPH_FAILED, "Input data node[%s] of node[%s] get data failed.",
  231. input_data_node->GetName().c_str(), node_ptr->GetName().c_str());
  232. return GRAPH_FAILED;
  233. }
  234. return SUCCESS;
  235. } else if (op_type == DATA) {
  236. auto parent_node = NodeUtils::GetParentInput(input_data_node);
  237. while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) {
  238. parent_node = NodeUtils::GetParentInput(parent_node);
  239. }
  240. if ((parent_node != nullptr) &&
  241. ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) {
  242. Operator const_op = OpDescUtils::CreateOperatorFromNode(parent_node);
  243. if (const_op.GetAttr(ATTR_NAME_WEIGHTS, data) != GRAPH_SUCCESS) {
  244. GELOGE(GRAPH_FAILED, "Input data node[%s] of node[%s] get data failed.",
  245. parent_node->GetName().c_str(), node_ptr->GetName().c_str());
  246. return GRAPH_FAILED;
  247. }
  248. return GRAPH_SUCCESS;
  249. }
  250. }
  251. GELOGE(GRAPH_NODE_WITHOUT_CONST_INPUT, "Node[%s] has no const input.", node_ptr->GetName().c_str());
  252. return GRAPH_NODE_WITHOUT_CONST_INPUT;
  253. }
  254. graphStatus GNode::GetInputIndexByName(const AscendString &name, int32_t &index) {
  255. const char* ascend_name = name.GetString();
  256. if (ascend_name == nullptr) {
  257. GELOGE(GRAPH_PARAM_INVALID, "GetInputIndexByName: ascend string error.");
  258. return GRAPH_PARAM_INVALID;
  259. }
  260. if (impl_ == nullptr) {
  261. GELOGE(GRAPH_FAILED, "GetInputIndexByName: node impl is nullptr.");
  262. return GRAPH_FAILED;
  263. }
  264. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  265. if (node_ptr == nullptr) {
  266. GELOGE(GRAPH_FAILED, "GetInputIndexByName: the node shared ptr is not valid.");
  267. return GRAPH_FAILED;
  268. }
  269. OpDescPtr op_desc = node_ptr->GetOpDesc();
  270. if (op_desc == nullptr) {
  271. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  272. return GRAPH_FAILED;
  273. }
  274. std::string node_name = ascend_name;
  275. index = op_desc->GetInputIndexByName(node_name);
  276. return GRAPH_SUCCESS;
  277. }
  278. graphStatus GNode::GetOutputIndexByName(const AscendString &name, int32_t &index) {
  279. const char* ascend_name = name.GetString();
  280. if (ascend_name == nullptr) {
  281. GELOGE(GRAPH_PARAM_INVALID, "GetOutputIndexByName: ascend string error.");
  282. return GRAPH_PARAM_INVALID;
  283. }
  284. if (impl_ == nullptr) {
  285. GELOGE(GRAPH_FAILED, "GetOutputIndexByName: node impl is nullptr.");
  286. return GRAPH_FAILED;
  287. }
  288. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  289. if (node_ptr == nullptr) {
  290. GELOGE(GRAPH_FAILED, "GetOutputIndexByName: the node shared ptr is not valid.");
  291. return GRAPH_FAILED;
  292. }
  293. OpDescPtr op_desc = node_ptr->GetOpDesc();
  294. if (op_desc == nullptr) {
  295. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  296. return GRAPH_FAILED;
  297. }
  298. std::string node_name = ascend_name;
  299. index = op_desc->GetOutputIndexByName(node_name);
  300. return GRAPH_SUCCESS;
  301. }
  302. size_t GNode::GetInputsSize() const {
  303. if (impl_ == nullptr) {
  304. GELOGE(GRAPH_FAILED, "GetInputsSize: node impl is nullptr.");
  305. return GRAPH_FAILED;
  306. }
  307. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  308. if (node_ptr == nullptr) {
  309. GELOGE(GRAPH_FAILED, "GetInputsSize: the node shared ptr is not valid.");
  310. return GRAPH_FAILED;
  311. }
  312. OpDescPtr op_desc = node_ptr->GetOpDesc();
  313. if (op_desc == nullptr) {
  314. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  315. return GRAPH_FAILED;
  316. }
  317. return op_desc->GetInputsSize();
  318. }
  319. size_t GNode::GetOutputsSize() const {
  320. if (impl_ == nullptr) {
  321. GELOGE(GRAPH_FAILED, "GetOutputsSize: node impl is nullptr.");
  322. return GRAPH_FAILED;
  323. }
  324. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  325. if (node_ptr == nullptr) {
  326. GELOGE(GRAPH_FAILED, "GetOutputsSize: the shared ptr is not valid.");
  327. return GRAPH_FAILED;
  328. }
  329. OpDescPtr op_desc = node_ptr->GetOpDesc();
  330. if (op_desc == nullptr) {
  331. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  332. return GRAPH_FAILED;
  333. }
  334. return op_desc->GetOutputsSize();
  335. }
  336. graphStatus GNode::GetInputDesc(const int32_t index, TensorDesc &tensor_desc) const {
  337. if (index < 0) {
  338. GELOGE(GRAPH_PARAM_INVALID, "GetInputDesc: index[%d] cannot be less than zero.", index);
  339. return GRAPH_PARAM_INVALID;
  340. }
  341. if (impl_ == nullptr) {
  342. GELOGE(GRAPH_FAILED, "GetInputDesc: node impl is nullptr.");
  343. return GRAPH_FAILED;
  344. }
  345. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  346. if (node_ptr == nullptr) {
  347. GELOGE(GRAPH_FAILED, "GetInputDesc: the node shared ptr is not valid.");
  348. return GRAPH_FAILED;
  349. }
  350. OpDescPtr op_desc = node_ptr->GetOpDesc();
  351. if (op_desc == nullptr) {
  352. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  353. return GRAPH_FAILED;
  354. }
  355. ConstGeTensorDescPtr ge_tensor_desc = op_desc->GetInputDescPtr(static_cast<uint32_t>(index));
  356. if (ge_tensor_desc == nullptr) {
  357. GELOGE(GRAPH_FAILED, "Get tensor desc of node[%s] failed.", node_ptr->GetName().c_str());
  358. return GRAPH_FAILED;
  359. }
  360. tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc);
  361. return GRAPH_SUCCESS;
  362. }
  363. graphStatus GNode::UpdateInputDesc(const int32_t index, const TensorDesc &tensor_desc) {
  364. if (index < 0) {
  365. GELOGE(GRAPH_PARAM_INVALID, "UpdateInputDesc: index[%d] cannot be less than zero.", index);
  366. return GRAPH_PARAM_INVALID;
  367. }
  368. if (impl_ == nullptr) {
  369. GELOGE(GRAPH_FAILED, "UpdateInputDesc: node impl is nullptr.");
  370. return GRAPH_FAILED;
  371. }
  372. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  373. if (node_ptr == nullptr) {
  374. GELOGE(GRAPH_FAILED, "UpdateInputDesc: the node shared ptr is not valid.");
  375. return GRAPH_FAILED;
  376. }
  377. OpDescPtr op_desc = node_ptr->GetOpDesc();
  378. if (op_desc == nullptr) {
  379. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  380. return GRAPH_FAILED;
  381. }
  382. GeTensorDesc ge_tensor_desc = TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc);
  383. if (op_desc->UpdateInputDesc(static_cast<uint32_t>(index), ge_tensor_desc) != GRAPH_SUCCESS) {
  384. GELOGE(GRAPH_FAILED, "Update input desc of node[%s] failed.", node_ptr->GetName().c_str());
  385. return GRAPH_FAILED;
  386. }
  387. return GRAPH_SUCCESS;
  388. }
  389. graphStatus GNode::GetOutputDesc(const int32_t index, TensorDesc &tensor_desc) const {
  390. if (index < 0) {
  391. GELOGE(GRAPH_PARAM_INVALID, "GetOutputDesc: index[%d] cannot be less than zero.", index);
  392. return GRAPH_PARAM_INVALID;
  393. }
  394. if (impl_ == nullptr) {
  395. GELOGE(GRAPH_FAILED, "GetOutputDesc: node impl is nullptr.");
  396. return GRAPH_FAILED;
  397. }
  398. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  399. if (node_ptr == nullptr) {
  400. GELOGE(GRAPH_FAILED, "GetOutputDesc: the node shared ptr is not valid.");
  401. return GRAPH_FAILED;
  402. }
  403. OpDescPtr op_desc = node_ptr->GetOpDesc();
  404. if (op_desc == nullptr) {
  405. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  406. return GRAPH_FAILED;
  407. }
  408. ConstGeTensorDescPtr ge_tensor_desc = op_desc->GetOutputDescPtr(static_cast<uint32_t>(index));
  409. if (ge_tensor_desc == nullptr) {
  410. GELOGE(GRAPH_FAILED, "Get tensor desc of node[%s] failed.", node_ptr->GetName().c_str());
  411. return GRAPH_FAILED;
  412. }
  413. tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc);
  414. return GRAPH_SUCCESS;
  415. }
  416. graphStatus GNode::UpdateOutputDesc(const int32_t index, const TensorDesc &tensor_desc) {
  417. if (index < 0) {
  418. GELOGE(GRAPH_PARAM_INVALID, "Gnode: index[%d] cannot be less than zero.", index);
  419. return GRAPH_PARAM_INVALID;
  420. }
  421. if (impl_ == nullptr) {
  422. GELOGE(GRAPH_FAILED, "UpdateOutputDesc: node impl is nullptr.");
  423. return GRAPH_FAILED;
  424. }
  425. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  426. if (node_ptr == nullptr) {
  427. GELOGE(GRAPH_FAILED, "UpdateOutputDesc: the shared ptr is not valid.");
  428. return GRAPH_FAILED;
  429. }
  430. OpDescPtr op_desc = node_ptr->GetOpDesc();
  431. if (op_desc == nullptr) {
  432. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  433. return GRAPH_FAILED;
  434. }
  435. GeTensorDesc ge_tensor_desc = TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc);
  436. if (op_desc->UpdateOutputDesc(static_cast<uint32_t>(index), ge_tensor_desc) != GRAPH_SUCCESS) {
  437. GELOGE(GRAPH_FAILED, "Update input desc of node[%s] failed.", node_ptr->GetName().c_str());
  438. return GRAPH_FAILED;
  439. }
  440. return GRAPH_SUCCESS;
  441. }
  // NODE_ATTR_GET_IMP(ArgType) defines GNode::GetAttr(name, ArgType &) const.
  // The generated function validates `name` and resolves the referenced Node.
  // It then reads the attribute through an Operator created from that Node.
  // Returns GRAPH_PARAM_INVALID for a null name and GRAPH_FAILED on any lookup
  // failure.
  #define NODE_ATTR_GET_IMP(ArgType)                                                     \
    graphStatus GNode::GetAttr(const AscendString &name, ArgType &attr_value) const {    \
      const char *const attr_key = name.GetString();                                     \
      if (attr_key == nullptr) {                                                         \
        GELOGE(GRAPH_PARAM_INVALID, "GetAttr: ascend string error.");                    \
        return GRAPH_PARAM_INVALID;                                                      \
      }                                                                                  \
      if (impl_ == nullptr) {                                                            \
        GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr.");                          \
        return GRAPH_FAILED;                                                             \
      }                                                                                  \
      const std::shared_ptr<Node> target = impl_->node_ptr_.lock();                      \
      if (target == nullptr) {                                                           \
        GELOGE(GRAPH_FAILED, "GetAttr: the shared ptr is not valid.");                   \
        return GRAPH_FAILED;                                                             \
      }                                                                                  \
      const std::string attr_key_str(attr_key);                                          \
      Operator view = OpDescUtils::CreateOperatorFromNode(target);                       \
      if (view.GetAttr(attr_key_str, attr_value) != GRAPH_SUCCESS) {                     \
        GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", target->GetName().c_str()); \
        return GRAPH_FAILED;                                                             \
      }                                                                                  \
      return GRAPH_SUCCESS;                                                              \
    }
  // NODE_ATTR_SET_IMP(ArgType) defines GNode::SetAttr(name, ArgType &) const.
  // The generated function validates `name` and resolves the referenced Node.
  // It then writes the attribute through an Operator created from that Node.
  // The Operator::SetAttr result is discarded, so once the Node resolves this
  // always reports GRAPH_SUCCESS.
  #define NODE_ATTR_SET_IMP(ArgType)                                                  \
    graphStatus GNode::SetAttr(const AscendString &name, ArgType &attr_value) const { \
      const char *const attr_key = name.GetString();                                  \
      if (attr_key == nullptr) {                                                      \
        GELOGE(GRAPH_PARAM_INVALID, "SetAttr: ascend string error.");                 \
        return GRAPH_PARAM_INVALID;                                                   \
      }                                                                               \
      if (impl_ == nullptr) {                                                         \
        GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr.");                       \
        return GRAPH_FAILED;                                                          \
      }                                                                               \
      const std::shared_ptr<Node> target = impl_->node_ptr_.lock();                   \
      if (target == nullptr) {                                                        \
        GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid.");                \
        return GRAPH_FAILED;                                                          \
      }                                                                               \
      const std::string attr_key_str(attr_key);                                       \
      Operator view = OpDescUtils::CreateOperatorFromNode(target);                    \
      (void)view.SetAttr(attr_key_str, attr_value);                                   \
      return GRAPH_SUCCESS;                                                           \
    }
  494. NODE_ATTR_GET_IMP(int64_t)
  495. NODE_ATTR_GET_IMP(int32_t)
  496. NODE_ATTR_GET_IMP(uint32_t)
  497. NODE_ATTR_GET_IMP(float)
  498. NODE_ATTR_GET_IMP(bool)
  499. NODE_ATTR_GET_IMP(Tensor)
  500. NODE_ATTR_GET_IMP(std::vector<int64_t>)
  501. NODE_ATTR_GET_IMP(std::vector<int32_t>)
  502. NODE_ATTR_GET_IMP(std::vector<uint32_t>)
  503. NODE_ATTR_GET_IMP(std::vector<float>)
  504. NODE_ATTR_GET_IMP(std::vector<bool>)
  505. NODE_ATTR_GET_IMP(std::vector<Tensor>)
  506. NODE_ATTR_GET_IMP(OpBytes)
  507. NODE_ATTR_GET_IMP(std::vector<std::vector<int64_t>>)
  508. NODE_ATTR_GET_IMP(std::vector<ge::DataType>)
  509. NODE_ATTR_GET_IMP(ge::DataType)
  510. NODE_ATTR_GET_IMP(AttrValue)
  511. NODE_ATTR_SET_IMP(int64_t)
  512. NODE_ATTR_SET_IMP(int32_t)
  513. NODE_ATTR_SET_IMP(uint32_t)
  514. NODE_ATTR_SET_IMP(float)
  515. NODE_ATTR_SET_IMP(bool)
  516. NODE_ATTR_SET_IMP(Tensor)
  517. NODE_ATTR_SET_IMP(std::vector<int64_t>)
  518. NODE_ATTR_SET_IMP(std::vector<int32_t>)
  519. NODE_ATTR_SET_IMP(std::vector<uint32_t>)
  520. NODE_ATTR_SET_IMP(std::vector<float>)
  521. NODE_ATTR_SET_IMP(std::vector<bool>)
  522. NODE_ATTR_SET_IMP(std::vector<Tensor>)
  523. NODE_ATTR_SET_IMP(OpBytes)
  524. NODE_ATTR_SET_IMP(std::vector<std::vector<int64_t>>)
  525. NODE_ATTR_SET_IMP(std::vector<ge::DataType>)
  526. NODE_ATTR_SET_IMP(ge::DataType)
  527. graphStatus GNode::SetAttr(const AscendString &name, AttrValue &attr_value) const {
  528. const char* ascend_name = name.GetString();
  529. if (ascend_name == nullptr) {
  530. GELOGE(GRAPH_PARAM_INVALID, "SetAttr: ascend string error.");
  531. return GRAPH_PARAM_INVALID;
  532. }
  533. if (impl_ == nullptr) {
  534. GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr.");
  535. return GRAPH_FAILED;
  536. }
  537. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  538. if (node_ptr == nullptr) {
  539. GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid.");
  540. return GRAPH_FAILED;
  541. }
  542. std::string node_name = ascend_name;
  543. Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
  544. (void)op.SetAttr(node_name, std::move(attr_value));
  545. return GRAPH_SUCCESS;
  546. }
  547. graphStatus GNode::SetAttr(const AscendString &name, AscendString &attr_value) const {
  548. const char* ascend_name = name.GetString();
  549. if (ascend_name == nullptr) {
  550. GELOGE(GRAPH_PARAM_INVALID, "SetAttr: name ascend string error.");
  551. return GRAPH_PARAM_INVALID;
  552. }
  553. const char* ascend_attr_value = attr_value.GetString();
  554. if (ascend_attr_value == nullptr) {
  555. GELOGE(GRAPH_PARAM_INVALID, "SetAttr: attr value ascend string error.");
  556. return GRAPH_PARAM_INVALID;
  557. }
  558. if (impl_ == nullptr) {
  559. GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr.");
  560. return GRAPH_FAILED;
  561. }
  562. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  563. if (node_ptr == nullptr) {
  564. GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid.");
  565. return GRAPH_FAILED;
  566. }
  567. std::string node_name = ascend_name;
  568. std::string node_attr_value = ascend_attr_value;
  569. Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
  570. (void)op.SetAttr(node_name, node_attr_value);
  571. return GRAPH_SUCCESS;
  572. }
  573. graphStatus GNode::SetAttr(const AscendString &name, std::vector<AscendString> &attr_values) const {
  574. const char* ascend_name = name.GetString();
  575. if (ascend_name == nullptr) {
  576. GELOGE(GRAPH_PARAM_INVALID, "SetAttr: name ascend string error.");
  577. return GRAPH_PARAM_INVALID;
  578. }
  579. for (auto &attr_val : attr_values) {
  580. const char* ascend_attr_value = attr_val.GetString();
  581. if (ascend_attr_value == nullptr) {
  582. GELOGE(GRAPH_PARAM_INVALID, "SetAttr: attr val error.");
  583. return GRAPH_PARAM_INVALID;
  584. }
  585. }
  586. if (impl_ == nullptr) {
  587. GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr.");
  588. return GRAPH_FAILED;
  589. }
  590. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  591. if (node_ptr == nullptr) {
  592. GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid.");
  593. return GRAPH_FAILED;
  594. }
  595. vector<std::string> node_attr_vals;
  596. for (auto attr_val : attr_values) {
  597. if (attr_val.GetString() != nullptr) {
  598. std::string node_attr_val = attr_val.GetString();
  599. node_attr_vals.emplace_back(node_attr_val);
  600. }
  601. }
  602. std::string node_name = ascend_name;
  603. Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
  604. (void)op.SetAttr(node_name, node_attr_vals);
  605. return GRAPH_SUCCESS;
  606. }
  607. graphStatus GNode::GetAttr(const AscendString &name, AscendString &attr_value) const {
  608. const char* ascend_name = name.GetString();
  609. if (ascend_name == nullptr) {
  610. GELOGE(GRAPH_PARAM_INVALID, "GetAttr: name ascend string error.");
  611. return GRAPH_PARAM_INVALID;
  612. }
  613. if (impl_ == nullptr) {
  614. GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr.");
  615. return GRAPH_FAILED;
  616. }
  617. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  618. if (node_ptr == nullptr) {
  619. GELOGE(GRAPH_FAILED, "GetAttr: the node shared ptr is not valid.");
  620. return GRAPH_FAILED;
  621. }
  622. std::string node_name = ascend_name;
  623. Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
  624. std::string op_name;
  625. if (op.GetAttr(node_name, op_name) != GRAPH_SUCCESS) {
  626. GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str());
  627. return GRAPH_FAILED;
  628. }
  629. AscendString attr_value_get(op_name.c_str());
  630. attr_value = attr_value_get;
  631. return GRAPH_SUCCESS;
  632. }
  633. graphStatus GNode::GetAttr(const AscendString &name, std::vector<AscendString> &attr_values) const {
  634. const char* ascend_name = name.GetString();
  635. if (ascend_name == nullptr) {
  636. GELOGE(GRAPH_PARAM_INVALID, "GetAttr: name ascend string error.");
  637. return GRAPH_PARAM_INVALID;
  638. }
  639. if (impl_ == nullptr) {
  640. GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr.");
  641. return GRAPH_FAILED;
  642. }
  643. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  644. if (node_ptr == nullptr) {
  645. GELOGE(GRAPH_FAILED, "GetAttr: the node shared ptr is not valid.");
  646. return GRAPH_FAILED;
  647. }
  648. std::string node_name = ascend_name;
  649. Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
  650. vector<std::string> attr_names;
  651. if (op.GetAttr(node_name, attr_names) != GRAPH_SUCCESS) {
  652. GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str());
  653. return GRAPH_FAILED;
  654. }
  655. for (auto &attr_name : attr_names) {
  656. AscendString ascend_attr_name(attr_name.c_str());
  657. attr_values.push_back(ascend_attr_name);
  658. }
  659. return GRAPH_SUCCESS;
  660. }
  661. bool GNode::HasAttr(const AscendString &name) {
  662. const char* ascend_name = name.GetString();
  663. if (ascend_name == nullptr) {
  664. GELOGE(GRAPH_PARAM_INVALID, "HasAttr: ascend string error.");
  665. return false;
  666. }
  667. if (impl_ == nullptr) {
  668. GELOGE(GRAPH_FAILED, "HasAttr: node impl is nullptr.");
  669. return false;
  670. }
  671. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  672. if (node_ptr == nullptr) {
  673. GELOGE(GRAPH_FAILED, "HasAttr: the node shared ptr is not valid.");
  674. return false;
  675. }
  676. OpDescPtr op_desc = node_ptr->GetOpDesc();
  677. if (op_desc == nullptr) {
  678. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  679. return false;
  680. }
  681. std::string attr_name = ascend_name;
  682. if (!op_desc->HasAttr(attr_name)) {
  683. GELOGE(GRAPH_FAILED, "Node[%s] has no attr name[%s]", node_ptr->GetName().c_str(), attr_name.c_str());
  684. return false;
  685. }
  686. return true;
  687. }
  688. graphStatus GNode::GetSubgraph(uint32_t index, GraphPtr &graph) const {
  689. if (impl_ == nullptr) {
  690. GELOGE(GRAPH_FAILED, "GetSubgraph: node impl is nullptr.");
  691. return GRAPH_FAILED;
  692. }
  693. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  694. if (node_ptr == nullptr) {
  695. GELOGE(GRAPH_FAILED, "GetSubgraph: the node shared ptr is not valid.");
  696. return GRAPH_FAILED;
  697. }
  698. ComputeGraphPtr compute_graph_ptr = NodeUtils::GetSubgraph(*node_ptr, index);
  699. if (compute_graph_ptr == nullptr) {
  700. GELOGE(GRAPH_FAILED, "GetSubgraph: get subgraph[%u] failed from node[%s].", index, node_ptr->GetName().c_str());
  701. return GRAPH_FAILED;
  702. }
  703. graph = GraphUtils::CreateGraphPtrFromComputeGraph(compute_graph_ptr);
  704. if (graph == nullptr) {
  705. GELOGE(GRAPH_FAILED, "GetSubgraph: get subgraph[%u] failed from node[%s].", index, node_ptr->GetName().c_str());
  706. return GRAPH_FAILED;
  707. }
  708. return GRAPH_SUCCESS;
  709. }
  710. graphStatus GNode::GetALLSubgraphs(std::vector<GraphPtr> &graph_list) const {
  711. if (impl_ == nullptr) {
  712. GELOGE(GRAPH_FAILED, "GetALLSubgraphs: node impl is nullptr.");
  713. return GRAPH_FAILED;
  714. }
  715. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  716. if (node_ptr == nullptr) {
  717. GELOGE(GRAPH_FAILED, "GetALLSubgraphs: the node shared ptr is not valid.");
  718. return GRAPH_FAILED;
  719. }
  720. std::vector<ComputeGraphPtr> sub_graphs = NodeUtils::GetAllSubgraphs(*node_ptr);
  721. if (sub_graphs.empty()) {
  722. GELOGE(GRAPH_FAILED, "GetALLSubgraphs: get all subgraphs failed from node[%s].", node_ptr->GetName().c_str());
  723. return GRAPH_FAILED;
  724. }
  725. for (auto &sub_graph : sub_graphs) {
  726. if (sub_graph == nullptr) {
  727. GELOGE(GRAPH_FAILED, "Get subgraph failed from node[%s].", node_ptr->GetName().c_str());
  728. return GRAPH_FAILED;
  729. }
  730. GraphPtr graph = GraphUtils::CreateGraphPtrFromComputeGraph(sub_graph);
  731. if (graph == nullptr) {
  732. GELOGE(GRAPH_FAILED, "Subgraph create compute graph failed from node[%s].", node_ptr->GetName().c_str());
  733. return GRAPH_FAILED;
  734. }
  735. graph_list.emplace_back(graph);
  736. }
  737. if (graph_list.empty()) {
  738. GELOGW("Node[%s] has no subgraph.", node_ptr->GetName().c_str());
  739. }
  740. return GRAPH_SUCCESS;
  741. }
  742. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。