You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gnode.cc 31 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/gnode.h"
  17. #include <utility>
  18. #include "debug/ge_util.h"
  19. #include "framework/common/debug/ge_log.h"
  20. #include "graph/anchor.h"
  21. #include "graph/node.h"
  22. #include "graph/utils/node_adapter.h"
  23. #include "graph/utils/tensor_adapter.h"
  24. #include <graph/utils/graph_utils.h>
  25. #include "graph/debug/ge_attr_define.h"
  26. #include "graph/debug/ge_op_types.h"
  27. #include "utils/node_utils.h"
  28. #include "utils/op_desc_utils.h"
  29. namespace ge {
  30. class NodeImpl {
  31. public:
  32. NodeImpl() = default;
  33. ~NodeImpl() = default;
  34. NodeImpl(NodeImpl &) = delete;
  35. NodeImpl &operator=(const NodeImpl &) = delete;
  36. std::weak_ptr<Node> node_ptr_;
  37. };
  38. NodePtr NodeAdapter::GNode2Node(const ge::GNode &graph_node) {
  39. if (graph_node.impl_ == nullptr) {
  40. GELOGE(GRAPH_FAILED, "GNode2Node: gnode impl is nullptr.");
  41. return nullptr;
  42. }
  43. return graph_node.impl_->node_ptr_.lock();
  44. }
  45. GNode NodeAdapter::Node2GNode(const ge::NodePtr &node) {
  46. if (node == nullptr) {
  47. GELOGE(GRAPH_FAILED, "Node2GNode: node is nullptr");
  48. return GNode();
  49. }
  50. GNode graph_node;
  51. if (graph_node.impl_ == nullptr) {
  52. GELOGW("Node2GNode: gnode impl is nullptr, node[%s].", node->GetName().c_str());
  53. return graph_node;
  54. }
  55. graph_node.impl_->node_ptr_ = node;
  56. return graph_node;
  57. }
  58. GNodePtr NodeAdapter::Node2GNodePtr(const ge::NodePtr &node) {
  59. if (node == nullptr) {
  60. GELOGE(GRAPH_FAILED, "Node2GNodePtr: node is nullptr");
  61. return nullptr;
  62. }
  63. GNodePtr gnode = std::shared_ptr<GNode>(new (std::nothrow) GNode());
  64. if (gnode == nullptr) {
  65. GELOGE(GRAPH_FAILED, "Node2GNodePtr: gnode is nullptr, node[%s].", node->GetName().c_str());
  66. return nullptr;
  67. }
  68. if (gnode->impl_ == nullptr) {
  69. GELOGW("Node2GNode: gnode impl is nullptr, node[%s].", node->GetName().c_str());
  70. return nullptr;
  71. }
  72. gnode->impl_->node_ptr_ = node;
  73. return gnode;
  74. }
  75. GNode::GNode() { impl_ = ComGraphMakeShared<NodeImpl>(); }
  76. graphStatus GNode::GetType(AscendString &type) const {
  77. if (impl_ == nullptr) {
  78. GELOGE(GRAPH_FAILED, "GetType: node impl is nullptr.");
  79. return GRAPH_FAILED;
  80. }
  81. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  82. if (node_ptr == nullptr) {
  83. GELOGE(GRAPH_FAILED, "GetType: the shared ptr is not valid.");
  84. return GRAPH_FAILED;
  85. }
  86. std::string node_type = node_ptr->GetType();
  87. AscendString ascend_type(node_type.c_str());
  88. type = ascend_type;
  89. return GRAPH_SUCCESS;
  90. }
  91. graphStatus GNode::GetName(AscendString &name) const {
  92. if (impl_ == nullptr) {
  93. GELOGE(GRAPH_FAILED, "GetName: node impl is nullptr.");
  94. return GRAPH_FAILED;
  95. }
  96. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  97. if (node_ptr == nullptr) {
  98. GELOGE(GRAPH_FAILED, "GetName: the shared ptr is not valid.");
  99. return GRAPH_FAILED;
  100. }
  101. std::string node_name = node_ptr->GetName();
  102. AscendString ascend_name(node_name.c_str());
  103. name = ascend_name;
  104. return GRAPH_SUCCESS;
  105. }
  106. std::pair<GNodePtr, int32_t> GNode::GetInDataNodesAndPortIndexs(const int32_t index) const {
  107. pair<GNodePtr, int32_t> gnode_idx = {nullptr, 0xFF};
  108. if (impl_ == nullptr) {
  109. GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr.");
  110. return gnode_idx;
  111. }
  112. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  113. if (node_ptr == nullptr) {
  114. GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid.");
  115. return gnode_idx;
  116. }
  117. auto in_anchor = node_ptr->GetInDataAnchor(index);
  118. if (in_anchor == nullptr) {
  119. GELOGE(GRAPH_FAILED, "Failed to get in data node of index[%d] from node[%s], the anchor does not exist", index,
  120. node_ptr->GetName().c_str());
  121. return gnode_idx;
  122. }
  123. auto out_anchor = in_anchor->GetPeerOutAnchor();
  124. if (out_anchor == nullptr) {
  125. GELOGE(GRAPH_FAILED, "Failed to get in data node of index[%d] from node [%s], the data input does not exist", index,
  126. node_ptr->GetName().c_str());
  127. return gnode_idx;
  128. }
  129. NodePtr peer_node_ptr = out_anchor->GetOwnerNode();
  130. GNodePtr gnode = NodeAdapter::Node2GNodePtr(peer_node_ptr);
  131. if (gnode == nullptr) {
  132. GELOGE(GRAPH_FAILED, "Peer node of node[%s] to gnode faild.", node_ptr->GetName().c_str());
  133. return gnode_idx;
  134. }
  135. return {gnode, out_anchor->GetIdx()};
  136. }
  137. std::vector<GNodePtr> GNode::GetInControlNodes() const {
  138. if (impl_ == nullptr) {
  139. GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr.");
  140. return {};
  141. }
  142. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  143. if (node_ptr == nullptr) {
  144. GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid.");
  145. return {};
  146. }
  147. std::vector<GNodePtr> gnodes;
  148. auto in_control_nodes = node_ptr->GetInControlNodes();
  149. for (auto &in_control_node : in_control_nodes) {
  150. GNodePtr gnode = NodeAdapter::Node2GNodePtr(in_control_node);
  151. if (gnode == nullptr) {
  152. GELOGE(GRAPH_FAILED, "In control_node of node[%s] to gnode faild.", node_ptr->GetName().c_str());
  153. return {};
  154. }
  155. gnodes.emplace_back(gnode);
  156. }
  157. return gnodes;
  158. }
  159. std::vector<std::pair<GNodePtr, int32_t>> GNode::GetOutDataNodesAndPortIndexs(const int32_t index) const {
  160. if (impl_ == nullptr) {
  161. GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr.");
  162. return {};
  163. }
  164. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  165. if (node_ptr == nullptr) {
  166. GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid.");
  167. return {};
  168. }
  169. auto out_anchor = node_ptr->GetOutDataAnchor(index);
  170. if (out_anchor == nullptr) {
  171. GELOGE(GRAPH_FAILED, "Failed to get out data node of index %d from node %s, the anchor does not exists", index,
  172. node_ptr->GetName().c_str());
  173. return {};
  174. }
  175. vector<std::pair<GNodePtr, int32_t>> gnode_index;
  176. auto in_data_anchors = out_anchor->GetPeerInDataAnchors();
  177. for (auto &in_data_anchor : in_data_anchors) {
  178. if (in_data_anchor == nullptr) {
  179. GELOGE(GRAPH_FAILED, "In data anchor of node[%s] is nullptr.", node_ptr->GetName().c_str());
  180. return {};
  181. }
  182. NodePtr peer_node_ptr = in_data_anchor->GetOwnerNode();
  183. GNodePtr gnode = NodeAdapter::Node2GNodePtr(peer_node_ptr);
  184. if (gnode == nullptr) {
  185. GELOGE(GRAPH_FAILED, "Peer node of node[%s] to gnode faild.", node_ptr->GetName().c_str());
  186. return {};
  187. }
  188. gnode_index.emplace_back(std::pair<GNodePtr, int32_t>(gnode, in_data_anchor->GetIdx()));
  189. }
  190. return gnode_index;
  191. }
  192. std::vector<GNodePtr> GNode::GetOutControlNodes() const {
  193. if (impl_ == nullptr) {
  194. GELOGE(GRAPH_FAILED, "GetOutControlNodes: node impl is nullptr.");
  195. return {};
  196. }
  197. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  198. if (node_ptr == nullptr) {
  199. GELOGE(GRAPH_FAILED, "GetOutControlNodes: the node shared ptr is not valid.");
  200. return {};
  201. }
  202. std::vector<GNodePtr> gnodes;
  203. auto out_control_nodes = node_ptr->GetOutControlNodes();
  204. for (auto &out_control_node : out_control_nodes) {
  205. GNodePtr gnode = NodeAdapter::Node2GNodePtr(out_control_node);
  206. if (gnode == nullptr) {
  207. GELOGE(GRAPH_FAILED, "In control_node of node[%s] to gnode faild.", node_ptr->GetName().c_str());
  208. return {};
  209. }
  210. gnodes.emplace_back(gnode);
  211. }
  212. return gnodes;
  213. }
  214. graphStatus GNode::GetInputConstData(const int32_t index, Tensor &data) const {
  215. if (impl_ == nullptr) {
  216. GELOGE(GRAPH_FAILED, "GetInputConstData: node impl is nullptr.");
  217. return GRAPH_FAILED;
  218. }
  219. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  220. if (node_ptr == nullptr) {
  221. GELOGE(GRAPH_FAILED, "GetInputConstData: the node shared ptr is not valid.");
  222. return GRAPH_FAILED;
  223. }
  224. NodePtr input_data_node = NodeUtils::GetInDataNodeByIndex(*node_ptr, index);
  225. GE_CHECK_NOTNULL(input_data_node);
  226. string op_type = input_data_node->GetType();
  227. if (op_type == CONSTANT || op_type == CONSTANTOP) {
  228. Operator const_op = OpDescUtils::CreateOperatorFromNode(input_data_node);
  229. if (const_op.GetAttr(ATTR_NAME_WEIGHTS, data) != GRAPH_SUCCESS) {
  230. GELOGE(GRAPH_FAILED, "Input data node[%s] of node[%s] get data failed.", input_data_node->GetName().c_str(),
  231. node_ptr->GetName().c_str());
  232. return GRAPH_FAILED;
  233. }
  234. return SUCCESS;
  235. } else if (op_type == DATA) {
  236. auto parent_node = NodeUtils::GetParentInput(input_data_node);
  237. while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) {
  238. parent_node = NodeUtils::GetParentInput(parent_node);
  239. }
  240. if ((parent_node != nullptr) && ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) {
  241. Operator const_op = OpDescUtils::CreateOperatorFromNode(parent_node);
  242. if (const_op.GetAttr(ATTR_NAME_WEIGHTS, data) != GRAPH_SUCCESS) {
  243. GELOGE(GRAPH_FAILED, "Input data node[%s] of node[%s] get data failed.", parent_node->GetName().c_str(),
  244. node_ptr->GetName().c_str());
  245. return GRAPH_FAILED;
  246. }
  247. return GRAPH_SUCCESS;
  248. }
  249. }
  250. GELOGE(GRAPH_NODE_WITHOUT_CONST_INPUT, "Node[%s] has no const input.", node_ptr->GetName().c_str());
  251. return GRAPH_NODE_WITHOUT_CONST_INPUT;
  252. }
  253. graphStatus GNode::GetInputIndexByName(const AscendString &name, int32_t &index) {
  254. const char *ascend_name = name.GetString();
  255. if (ascend_name == nullptr) {
  256. GELOGE(GRAPH_PARAM_INVALID, "GetInputIndexByName: ascend string error.");
  257. return GRAPH_PARAM_INVALID;
  258. }
  259. if (impl_ == nullptr) {
  260. GELOGE(GRAPH_FAILED, "GetInputIndexByName: node impl is nullptr.");
  261. return GRAPH_FAILED;
  262. }
  263. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  264. if (node_ptr == nullptr) {
  265. GELOGE(GRAPH_FAILED, "GetInputIndexByName: the node shared ptr is not valid.");
  266. return GRAPH_FAILED;
  267. }
  268. OpDescPtr op_desc = node_ptr->GetOpDesc();
  269. if (op_desc == nullptr) {
  270. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  271. return GRAPH_FAILED;
  272. }
  273. std::string node_name = ascend_name;
  274. index = op_desc->GetInputIndexByName(node_name);
  275. return GRAPH_SUCCESS;
  276. }
  277. graphStatus GNode::GetOutputIndexByName(const AscendString &name, int32_t &index) {
  278. const char *ascend_name = name.GetString();
  279. if (ascend_name == nullptr) {
  280. GELOGE(GRAPH_PARAM_INVALID, "GetOutputIndexByName: ascend string error.");
  281. return GRAPH_PARAM_INVALID;
  282. }
  283. if (impl_ == nullptr) {
  284. GELOGE(GRAPH_FAILED, "GetOutputIndexByName: node impl is nullptr.");
  285. return GRAPH_FAILED;
  286. }
  287. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  288. if (node_ptr == nullptr) {
  289. GELOGE(GRAPH_FAILED, "GetOutputIndexByName: the node shared ptr is not valid.");
  290. return GRAPH_FAILED;
  291. }
  292. OpDescPtr op_desc = node_ptr->GetOpDesc();
  293. if (op_desc == nullptr) {
  294. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  295. return GRAPH_FAILED;
  296. }
  297. std::string node_name = ascend_name;
  298. index = op_desc->GetOutputIndexByName(node_name);
  299. return GRAPH_SUCCESS;
  300. }
  301. size_t GNode::GetInputsSize() const {
  302. if (impl_ == nullptr) {
  303. GELOGE(GRAPH_FAILED, "GetInputsSize: node impl is nullptr.");
  304. return GRAPH_FAILED;
  305. }
  306. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  307. if (node_ptr == nullptr) {
  308. GELOGE(GRAPH_FAILED, "GetInputsSize: the node shared ptr is not valid.");
  309. return GRAPH_FAILED;
  310. }
  311. OpDescPtr op_desc = node_ptr->GetOpDesc();
  312. if (op_desc == nullptr) {
  313. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  314. return GRAPH_FAILED;
  315. }
  316. return op_desc->GetInputsSize();
  317. }
  318. size_t GNode::GetOutputsSize() const {
  319. if (impl_ == nullptr) {
  320. GELOGE(GRAPH_FAILED, "GetOutputsSize: node impl is nullptr.");
  321. return GRAPH_FAILED;
  322. }
  323. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  324. if (node_ptr == nullptr) {
  325. GELOGE(GRAPH_FAILED, "GetOutputsSize: the shared ptr is not valid.");
  326. return GRAPH_FAILED;
  327. }
  328. OpDescPtr op_desc = node_ptr->GetOpDesc();
  329. if (op_desc == nullptr) {
  330. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  331. return GRAPH_FAILED;
  332. }
  333. return op_desc->GetOutputsSize();
  334. }
  335. graphStatus GNode::GetInputDesc(const int32_t index, TensorDesc &tensor_desc) const {
  336. if (index < 0) {
  337. GELOGE(GRAPH_PARAM_INVALID, "GetInputDesc: index[%d] cannot be less than zero.", index);
  338. return GRAPH_PARAM_INVALID;
  339. }
  340. if (impl_ == nullptr) {
  341. GELOGE(GRAPH_FAILED, "GetInputDesc: node impl is nullptr.");
  342. return GRAPH_FAILED;
  343. }
  344. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  345. if (node_ptr == nullptr) {
  346. GELOGE(GRAPH_FAILED, "GetInputDesc: the node shared ptr is not valid.");
  347. return GRAPH_FAILED;
  348. }
  349. OpDescPtr op_desc = node_ptr->GetOpDesc();
  350. if (op_desc == nullptr) {
  351. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  352. return GRAPH_FAILED;
  353. }
  354. ConstGeTensorDescPtr ge_tensor_desc = op_desc->GetInputDescPtr(static_cast<uint32_t>(index));
  355. if (ge_tensor_desc == nullptr) {
  356. GELOGE(GRAPH_FAILED, "Get tensor desc of node[%s] failed.", node_ptr->GetName().c_str());
  357. return GRAPH_FAILED;
  358. }
  359. tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc);
  360. return GRAPH_SUCCESS;
  361. }
  362. graphStatus GNode::UpdateInputDesc(const int32_t index, const TensorDesc &tensor_desc) {
  363. if (index < 0) {
  364. GELOGE(GRAPH_PARAM_INVALID, "UpdateInputDesc: index[%d] cannot be less than zero.", index);
  365. return GRAPH_PARAM_INVALID;
  366. }
  367. if (impl_ == nullptr) {
  368. GELOGE(GRAPH_FAILED, "UpdateInputDesc: node impl is nullptr.");
  369. return GRAPH_FAILED;
  370. }
  371. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  372. if (node_ptr == nullptr) {
  373. GELOGE(GRAPH_FAILED, "UpdateInputDesc: the node shared ptr is not valid.");
  374. return GRAPH_FAILED;
  375. }
  376. OpDescPtr op_desc = node_ptr->GetOpDesc();
  377. if (op_desc == nullptr) {
  378. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  379. return GRAPH_FAILED;
  380. }
  381. GeTensorDesc ge_tensor_desc = TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc);
  382. if (op_desc->UpdateInputDesc(static_cast<uint32_t>(index), ge_tensor_desc) != GRAPH_SUCCESS) {
  383. GELOGE(GRAPH_FAILED, "Update input desc of node[%s] failed.", node_ptr->GetName().c_str());
  384. return GRAPH_FAILED;
  385. }
  386. return GRAPH_SUCCESS;
  387. }
  388. graphStatus GNode::GetOutputDesc(const int32_t index, TensorDesc &tensor_desc) const {
  389. if (index < 0) {
  390. GELOGE(GRAPH_PARAM_INVALID, "GetOutputDesc: index[%d] cannot be less than zero.", index);
  391. return GRAPH_PARAM_INVALID;
  392. }
  393. if (impl_ == nullptr) {
  394. GELOGE(GRAPH_FAILED, "GetOutputDesc: node impl is nullptr.");
  395. return GRAPH_FAILED;
  396. }
  397. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  398. if (node_ptr == nullptr) {
  399. GELOGE(GRAPH_FAILED, "GetOutputDesc: the node shared ptr is not valid.");
  400. return GRAPH_FAILED;
  401. }
  402. OpDescPtr op_desc = node_ptr->GetOpDesc();
  403. if (op_desc == nullptr) {
  404. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  405. return GRAPH_FAILED;
  406. }
  407. ConstGeTensorDescPtr ge_tensor_desc = op_desc->GetOutputDescPtr(static_cast<uint32_t>(index));
  408. if (ge_tensor_desc == nullptr) {
  409. GELOGE(GRAPH_FAILED, "Get tensor desc of node[%s] failed.", node_ptr->GetName().c_str());
  410. return GRAPH_FAILED;
  411. }
  412. tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc);
  413. return GRAPH_SUCCESS;
  414. }
  415. graphStatus GNode::UpdateOutputDesc(const int32_t index, const TensorDesc &tensor_desc) {
  416. if (index < 0) {
  417. GELOGE(GRAPH_PARAM_INVALID, "Gnode: index[%d] cannot be less than zero.", index);
  418. return GRAPH_PARAM_INVALID;
  419. }
  420. if (impl_ == nullptr) {
  421. GELOGE(GRAPH_FAILED, "UpdateOutputDesc: node impl is nullptr.");
  422. return GRAPH_FAILED;
  423. }
  424. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  425. if (node_ptr == nullptr) {
  426. GELOGE(GRAPH_FAILED, "UpdateOutputDesc: the shared ptr is not valid.");
  427. return GRAPH_FAILED;
  428. }
  429. OpDescPtr op_desc = node_ptr->GetOpDesc();
  430. if (op_desc == nullptr) {
  431. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  432. return GRAPH_FAILED;
  433. }
  434. GeTensorDesc ge_tensor_desc = TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc);
  435. if (op_desc->UpdateOutputDesc(static_cast<uint32_t>(index), ge_tensor_desc) != GRAPH_SUCCESS) {
  436. GELOGE(GRAPH_FAILED, "Update input desc of node[%s] failed.", node_ptr->GetName().c_str());
  437. return GRAPH_FAILED;
  438. }
  439. return GRAPH_SUCCESS;
  440. }
  441. #define NODE_ATTR_GET_IMP(ArgType) \
  442. graphStatus GNode::GetAttr(const AscendString &name, ArgType &attr_value) const { \
  443. const char *ascend_name = name.GetString(); \
  444. if (ascend_name == nullptr) { \
  445. GELOGE(GRAPH_PARAM_INVALID, "GetAttr: ascend string error."); \
  446. return GRAPH_PARAM_INVALID; \
  447. } \
  448. \
  449. if (impl_ == nullptr) { \
  450. GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr."); \
  451. return GRAPH_FAILED; \
  452. } \
  453. \
  454. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); \
  455. if (node_ptr == nullptr) { \
  456. GELOGE(GRAPH_FAILED, "GetAttr: the shared ptr is not valid."); \
  457. return GRAPH_FAILED; \
  458. } \
  459. \
  460. std::string node_name = ascend_name; \
  461. Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); \
  462. if (op.GetAttr(node_name, attr_value) != GRAPH_SUCCESS) { \
  463. GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str()); \
  464. return GRAPH_FAILED; \
  465. } \
  466. \
  467. return GRAPH_SUCCESS; \
  468. }
  469. #define NODE_ATTR_SET_IMP(ArgType) \
  470. graphStatus GNode::SetAttr(const AscendString &name, ArgType &attr_value) const { \
  471. const char *ascend_name = name.GetString(); \
  472. if (ascend_name == nullptr) { \
  473. GELOGE(GRAPH_PARAM_INVALID, "SetAttr: ascend string error."); \
  474. return GRAPH_PARAM_INVALID; \
  475. } \
  476. \
  477. if (impl_ == nullptr) { \
  478. GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr."); \
  479. return GRAPH_FAILED; \
  480. } \
  481. \
  482. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); \
  483. if (node_ptr == nullptr) { \
  484. GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid."); \
  485. return GRAPH_FAILED; \
  486. } \
  487. \
  488. std::string node_name = ascend_name; \
  489. Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); \
  490. (void)op.SetAttr(node_name, attr_value); \
  491. return GRAPH_SUCCESS; \
  492. }
  493. NODE_ATTR_GET_IMP(int64_t)
  494. NODE_ATTR_GET_IMP(int32_t)
  495. NODE_ATTR_GET_IMP(uint32_t)
  496. NODE_ATTR_GET_IMP(float)
  497. NODE_ATTR_GET_IMP(bool)
  498. NODE_ATTR_GET_IMP(Tensor)
  499. NODE_ATTR_GET_IMP(std::vector<int64_t>)
  500. NODE_ATTR_GET_IMP(std::vector<int32_t>)
  501. NODE_ATTR_GET_IMP(std::vector<uint32_t>)
  502. NODE_ATTR_GET_IMP(std::vector<float>)
  503. NODE_ATTR_GET_IMP(std::vector<bool>)
  504. NODE_ATTR_GET_IMP(std::vector<Tensor>)
  505. NODE_ATTR_GET_IMP(OpBytes)
  506. NODE_ATTR_GET_IMP(std::vector<std::vector<int64_t>>)
  507. NODE_ATTR_GET_IMP(std::vector<ge::DataType>)
  508. NODE_ATTR_GET_IMP(ge::DataType)
  509. NODE_ATTR_GET_IMP(AttrValue)
  510. NODE_ATTR_SET_IMP(int64_t)
  511. NODE_ATTR_SET_IMP(int32_t)
  512. NODE_ATTR_SET_IMP(uint32_t)
  513. NODE_ATTR_SET_IMP(float)
  514. NODE_ATTR_SET_IMP(bool)
  515. NODE_ATTR_SET_IMP(Tensor)
  516. NODE_ATTR_SET_IMP(std::vector<int64_t>)
  517. NODE_ATTR_SET_IMP(std::vector<int32_t>)
  518. NODE_ATTR_SET_IMP(std::vector<uint32_t>)
  519. NODE_ATTR_SET_IMP(std::vector<float>)
  520. NODE_ATTR_SET_IMP(std::vector<bool>)
  521. NODE_ATTR_SET_IMP(std::vector<Tensor>)
  522. NODE_ATTR_SET_IMP(OpBytes)
  523. NODE_ATTR_SET_IMP(std::vector<std::vector<int64_t>>)
  524. NODE_ATTR_SET_IMP(std::vector<ge::DataType>)
  525. NODE_ATTR_SET_IMP(ge::DataType)
  526. graphStatus GNode::SetAttr(const AscendString &name, AttrValue &attr_value) const {
  527. const char *ascend_name = name.GetString();
  528. if (ascend_name == nullptr) {
  529. GELOGE(GRAPH_PARAM_INVALID, "SetAttr: ascend string error.");
  530. return GRAPH_PARAM_INVALID;
  531. }
  532. if (impl_ == nullptr) {
  533. GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr.");
  534. return GRAPH_FAILED;
  535. }
  536. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  537. if (node_ptr == nullptr) {
  538. GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid.");
  539. return GRAPH_FAILED;
  540. }
  541. std::string node_name = ascend_name;
  542. Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
  543. (void)op.SetAttr(node_name, std::move(attr_value));
  544. return GRAPH_SUCCESS;
  545. }
  546. graphStatus GNode::SetAttr(const AscendString &name, AscendString &attr_value) const {
  547. const char *ascend_name = name.GetString();
  548. if (ascend_name == nullptr) {
  549. GELOGE(GRAPH_PARAM_INVALID, "SetAttr: name ascend string error.");
  550. return GRAPH_PARAM_INVALID;
  551. }
  552. const char *ascend_attr_value = attr_value.GetString();
  553. if (ascend_attr_value == nullptr) {
  554. GELOGE(GRAPH_PARAM_INVALID, "SetAttr: attr value ascend string error.");
  555. return GRAPH_PARAM_INVALID;
  556. }
  557. if (impl_ == nullptr) {
  558. GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr.");
  559. return GRAPH_FAILED;
  560. }
  561. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  562. if (node_ptr == nullptr) {
  563. GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid.");
  564. return GRAPH_FAILED;
  565. }
  566. std::string node_name = ascend_name;
  567. std::string node_attr_value = ascend_attr_value;
  568. Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
  569. (void)op.SetAttr(node_name, node_attr_value);
  570. return GRAPH_SUCCESS;
  571. }
  572. graphStatus GNode::SetAttr(const AscendString &name, std::vector<AscendString> &attr_values) const {
  573. const char *ascend_name = name.GetString();
  574. if (ascend_name == nullptr) {
  575. GELOGE(GRAPH_PARAM_INVALID, "SetAttr: name ascend string error.");
  576. return GRAPH_PARAM_INVALID;
  577. }
  578. for (auto &attr_val : attr_values) {
  579. const char *ascend_attr_value = attr_val.GetString();
  580. if (ascend_attr_value == nullptr) {
  581. GELOGE(GRAPH_PARAM_INVALID, "SetAttr: attr val error.");
  582. return GRAPH_PARAM_INVALID;
  583. }
  584. }
  585. if (impl_ == nullptr) {
  586. GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr.");
  587. return GRAPH_FAILED;
  588. }
  589. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  590. if (node_ptr == nullptr) {
  591. GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid.");
  592. return GRAPH_FAILED;
  593. }
  594. vector<std::string> node_attr_vals;
  595. for (auto attr_val : attr_values) {
  596. if (attr_val.GetString() != nullptr) {
  597. std::string node_attr_val = attr_val.GetString();
  598. node_attr_vals.emplace_back(node_attr_val);
  599. }
  600. }
  601. std::string node_name = ascend_name;
  602. Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
  603. (void)op.SetAttr(node_name, node_attr_vals);
  604. return GRAPH_SUCCESS;
  605. }
  606. graphStatus GNode::GetAttr(const AscendString &name, AscendString &attr_value) const {
  607. const char *ascend_name = name.GetString();
  608. if (ascend_name == nullptr) {
  609. GELOGE(GRAPH_PARAM_INVALID, "GetAttr: name ascend string error.");
  610. return GRAPH_PARAM_INVALID;
  611. }
  612. if (impl_ == nullptr) {
  613. GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr.");
  614. return GRAPH_FAILED;
  615. }
  616. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  617. if (node_ptr == nullptr) {
  618. GELOGE(GRAPH_FAILED, "GetAttr: the node shared ptr is not valid.");
  619. return GRAPH_FAILED;
  620. }
  621. std::string node_name = ascend_name;
  622. Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
  623. std::string op_name;
  624. if (op.GetAttr(node_name, op_name) != GRAPH_SUCCESS) {
  625. GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str());
  626. return GRAPH_FAILED;
  627. }
  628. AscendString attr_value_get(op_name.c_str());
  629. attr_value = attr_value_get;
  630. return GRAPH_SUCCESS;
  631. }
  632. graphStatus GNode::GetAttr(const AscendString &name, std::vector<AscendString> &attr_values) const {
  633. const char *ascend_name = name.GetString();
  634. if (ascend_name == nullptr) {
  635. GELOGE(GRAPH_PARAM_INVALID, "GetAttr: name ascend string error.");
  636. return GRAPH_PARAM_INVALID;
  637. }
  638. if (impl_ == nullptr) {
  639. GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr.");
  640. return GRAPH_FAILED;
  641. }
  642. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  643. if (node_ptr == nullptr) {
  644. GELOGE(GRAPH_FAILED, "GetAttr: the node shared ptr is not valid.");
  645. return GRAPH_FAILED;
  646. }
  647. std::string node_name = ascend_name;
  648. Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
  649. vector<std::string> attr_names;
  650. if (op.GetAttr(node_name, attr_names) != GRAPH_SUCCESS) {
  651. GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str());
  652. return GRAPH_FAILED;
  653. }
  654. for (auto &attr_name : attr_names) {
  655. AscendString ascend_attr_name(attr_name.c_str());
  656. attr_values.push_back(ascend_attr_name);
  657. }
  658. return GRAPH_SUCCESS;
  659. }
  660. bool GNode::HasAttr(const AscendString &name) {
  661. const char *ascend_name = name.GetString();
  662. if (ascend_name == nullptr) {
  663. GELOGE(GRAPH_PARAM_INVALID, "HasAttr: ascend string error.");
  664. return false;
  665. }
  666. if (impl_ == nullptr) {
  667. GELOGE(GRAPH_FAILED, "HasAttr: node impl is nullptr.");
  668. return false;
  669. }
  670. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  671. if (node_ptr == nullptr) {
  672. GELOGE(GRAPH_FAILED, "HasAttr: the node shared ptr is not valid.");
  673. return false;
  674. }
  675. OpDescPtr op_desc = node_ptr->GetOpDesc();
  676. if (op_desc == nullptr) {
  677. GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
  678. return false;
  679. }
  680. std::string attr_name = ascend_name;
  681. if (!op_desc->HasAttr(attr_name)) {
  682. GELOGE(GRAPH_FAILED, "Node[%s] has no attr name[%s]", node_ptr->GetName().c_str(), attr_name.c_str());
  683. return false;
  684. }
  685. return true;
  686. }
  687. graphStatus GNode::GetSubgraph(uint32_t index, GraphPtr &graph) const {
  688. if (impl_ == nullptr) {
  689. GELOGE(GRAPH_FAILED, "GetSubgraph: node impl is nullptr.");
  690. return GRAPH_FAILED;
  691. }
  692. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  693. if (node_ptr == nullptr) {
  694. GELOGE(GRAPH_FAILED, "GetSubgraph: the node shared ptr is not valid.");
  695. return GRAPH_FAILED;
  696. }
  697. ComputeGraphPtr compute_graph_ptr = NodeUtils::GetSubgraph(*node_ptr, index);
  698. if (compute_graph_ptr == nullptr) {
  699. GELOGE(GRAPH_FAILED, "GetSubgraph: get subgraph[%u] failed from node[%s].", index, node_ptr->GetName().c_str());
  700. return GRAPH_FAILED;
  701. }
  702. graph = GraphUtils::CreateGraphPtrFromComputeGraph(compute_graph_ptr);
  703. if (graph == nullptr) {
  704. GELOGE(GRAPH_FAILED, "GetSubgraph: get subgraph[%u] failed from node[%s].", index, node_ptr->GetName().c_str());
  705. return GRAPH_FAILED;
  706. }
  707. return GRAPH_SUCCESS;
  708. }
  709. graphStatus GNode::GetALLSubgraphs(std::vector<GraphPtr> &graph_list) const {
  710. if (impl_ == nullptr) {
  711. GELOGE(GRAPH_FAILED, "GetALLSubgraphs: node impl is nullptr.");
  712. return GRAPH_FAILED;
  713. }
  714. std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  715. if (node_ptr == nullptr) {
  716. GELOGE(GRAPH_FAILED, "GetALLSubgraphs: the node shared ptr is not valid.");
  717. return GRAPH_FAILED;
  718. }
  719. std::vector<ComputeGraphPtr> sub_graphs = NodeUtils::GetAllSubgraphs(*node_ptr);
  720. if (sub_graphs.empty()) {
  721. GELOGE(GRAPH_FAILED, "GetALLSubgraphs: get all subgraphs failed from node[%s].", node_ptr->GetName().c_str());
  722. return GRAPH_FAILED;
  723. }
  724. for (auto &sub_graph : sub_graphs) {
  725. if (sub_graph == nullptr) {
  726. GELOGE(GRAPH_FAILED, "Get subgraph failed from node[%s].", node_ptr->GetName().c_str());
  727. return GRAPH_FAILED;
  728. }
  729. GraphPtr graph = GraphUtils::CreateGraphPtrFromComputeGraph(sub_graph);
  730. if (graph == nullptr) {
  731. GELOGE(GRAPH_FAILED, "Subgraph create compute graph failed from node[%s].", node_ptr->GetName().c_str());
  732. return GRAPH_FAILED;
  733. }
  734. graph_list.emplace_back(graph);
  735. }
  736. if (graph_list.empty()) {
  737. GELOGW("Node[%s] has no subgraph.", node_ptr->GetName().c_str());
  738. }
  739. return GRAPH_SUCCESS;
  740. }
  741. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：