You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph.cc 27 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "external/graph/graph.h"
  17. #include <cstring>
  18. #include "debug/ge_util.h"
  19. #include "framework/common/debug/ge_log.h"
  20. #include "graph/debug/ge_attr_define.h"
  21. #include "graph/debug/ge_op_types.h"
  22. #include "graph/model.h"
  23. #include "graph/utils/graph_utils.h"
  24. #include "graph/utils/op_desc_utils.h"
  25. #include "graph/utils/node_adapter.h"
  26. #include "graph/utils/node_utils.h"
  27. using std::map;
  28. using std::pair;
  29. using std::string;
  30. using std::vector;
  31. namespace ge {
  32. class GraphImpl {
  33. public:
  34. friend class GraphUtils;
  35. GraphImpl(const GraphImpl &) = delete;
  36. GraphImpl &operator=(const GraphImpl &) = delete;
  37. explicit GraphImpl(const std::string &name) : name_(name) {}
  38. ~GraphImpl() {
  39. if (IsValid()) {
  40. if (compute_graph_ != nullptr) {
  41. GraphUtils::BreakConnect(compute_graph_->GetAllNodesInfo());
  42. }
  43. }
  44. for (const auto &it : op_list_) {
  45. Operator op = it.second;
  46. op.BreakConnect();
  47. }
  48. }
  49. graphStatus SetInputs(const std::vector<Operator> &inputs) {
  50. compute_graph_ = GraphUtils::CreateGraphFromOperator(name_, inputs);
  51. GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "Build Graph failed.");
  52. GE_CHK_BOOL_RET_STATUS(inputs.size() != 0, GRAPH_FAILED, "set input NULL.");
  53. compute_graph_->SetInputSize(static_cast<uint32_t>(inputs.size()));
  54. return GRAPH_SUCCESS;
  55. }
  56. graphStatus SetOutputs(const std::vector<Operator> &outputs) {
  57. if (compute_graph_ == nullptr) {
  58. GELOGE(GRAPH_FAILED, "set ComputeGraph failed.");
  59. return GRAPH_FAILED;
  60. }
  61. if (outputs.empty()) {
  62. GELOGW("set outputs size is 0.");
  63. return GRAPH_SUCCESS;
  64. }
  65. // Construct special output node
  66. std::vector<std::pair<Operator, std::vector<size_t>>> output_indexs;
  67. for (size_t i = 0; i < outputs.size(); ++i) {
  68. output_indexs.emplace_back(outputs[i], std::vector<size_t>{});
  69. }
  70. graphStatus ret = SetOutputs(output_indexs);
  71. return ret;
  72. }
  73. graphStatus SetOutputs(const std::vector<std::pair<Operator, std::vector<size_t>>> &output_indexs) {
  74. if (compute_graph_ == nullptr) {
  75. GELOGE(GRAPH_FAILED, "set ComputeGraph failed.");
  76. return GRAPH_FAILED;
  77. }
  78. if (output_indexs.empty()) {
  79. GELOGW("set outputs size is 0.");
  80. return GRAPH_SUCCESS;
  81. }
  82. // Construct special output node
  83. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes;
  84. for (const auto &item : output_indexs) {
  85. const Operator &output = item.first;
  86. const vector<size_t> &indexs = item.second;
  87. ge::NodePtr node = compute_graph_->FindNode(output.GetName());
  88. if (node == nullptr) {
  89. GELOGW("user designated out_node [%s] not exist in graph, will ignored!", output.GetName().c_str());
  90. continue;
  91. }
  92. ge::OpDescPtr tmp_op_ptr = node->GetOpDesc();
  93. GE_CHECK_NOTNULL_EXEC(tmp_op_ptr, continue);
  94. size_t out_size = tmp_op_ptr->GetOutputsSize();
  95. if (indexs.empty()) {
  96. for (size_t i = 0; i < out_size; ++i) {
  97. output_name_ += output.GetName() + ":" + std::to_string(i) + ";";
  98. output_nodes.emplace_back(node, i);
  99. }
  100. } else {
  101. for (size_t i = 0; i < indexs.size(); ++i) {
  102. if (indexs[i] >= out_size) {
  103. GELOGW("index[%zu] is not belong to out_node[%s]", indexs[i], output.GetName().c_str());
  104. } else {
  105. output_name_ += output.GetName() + ":" + std::to_string(i) + ";";
  106. output_nodes.emplace_back(node, indexs[i]);
  107. }
  108. }
  109. }
  110. }
  111. // Del last ";"
  112. if (!output_name_.empty()) {
  113. output_name_ = output_name_.substr(0, output_name_.length() - 1);
  114. }
  115. compute_graph_->SetUserDefOutput(output_name_);
  116. compute_graph_->SetOutputSize(static_cast<uint32_t>(output_indexs.size()));
  117. compute_graph_->SetGraphOutNodesInfo(output_nodes);
  118. return GRAPH_SUCCESS;
  119. }
  120. graphStatus SetOutputs(const std::vector<pair<Operator, string>> &outputs) {
  121. GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "set ComputeGraph faild.");
  122. GE_CHK_BOOL_EXEC_INFO(outputs.size() != 0, return GRAPH_SUCCESS, "set outputs size is 0.");
  123. // Construct specified output
  124. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes;
  125. for (auto item : outputs) {
  126. ge::NodePtr node = compute_graph_->FindNode(item.first.GetName());
  127. if (node == nullptr) {
  128. GELOGE(GRAPH_FAILED, " Warning, user designated out_node (%s) not exist in graph, this out_node ignored!",
  129. item.first.GetName().c_str());
  130. return GRAPH_FAILED;
  131. }
  132. ge::OpDescPtr tmp_op_ptr = node->GetOpDesc();
  133. GE_CHECK_NOTNULL_EXEC(tmp_op_ptr, continue);
  134. size_t out_size = tmp_op_ptr->GetOutputsSize();
  135. if (item.second.empty()) {
  136. for (size_t i = 0; i < out_size; ++i) {
  137. output_name_ += item.first.GetName() + ":" + std::to_string(i) + ";";
  138. output_nodes.push_back(std::make_pair(node, i));
  139. }
  140. } else {
  141. int32_t index = tmp_op_ptr->GetOutputIndexByName(item.second);
  142. if (index < 0) {
  143. GELOGE(GRAPH_FAILED,
  144. " Warning, user designated out_node (%s):(%s) not exist in graph, this out_node ignored!",
  145. item.first.GetName().c_str(), item.second.c_str());
  146. return GRAPH_FAILED;
  147. }
  148. output_name_ += item.first.GetName() + ":" + std::to_string(index) + ";";
  149. output_nodes.push_back(std::make_pair(node, index));
  150. }
  151. }
  152. // Del last ";"
  153. if (!output_name_.empty()) {
  154. output_name_ = output_name_.substr(0, output_name_.length() - 1);
  155. }
  156. compute_graph_->SetOutputSize(static_cast<uint32_t>(outputs.size()));
  157. compute_graph_->SetGraphOutNodesInfo(output_nodes);
  158. GELOGI("********************SetOutputs Success***********************");
  159. GE_IF_BOOL_EXEC(!output_name_.empty(), GELOGI(" NetOutputs: (%s)", output_name_.c_str()));
  160. return GRAPH_SUCCESS;
  161. }
  162. graphStatus SetTargets(const std::vector<Operator> &targets) {
  163. GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "set ComputeGraph faild.");
  164. GE_CHK_BOOL_EXEC_INFO(targets.size() != 0, return GRAPH_SUCCESS, "set targets size is 0.");
  165. std::vector<ge::NodePtr> target_nodes;
  166. for (auto item : targets) {
  167. ge::NodePtr node = compute_graph_->FindNode(item.GetName());
  168. if (node == nullptr) {
  169. GELOGW(" Warning, user designated target_node (%s) not exist in graph, this target_node ignored!",
  170. item.GetName().c_str());
  171. continue;
  172. }
  173. target_nodes.push_back(node);
  174. }
  175. compute_graph_->SetGraphTargetNodesInfo(target_nodes);
  176. return GRAPH_SUCCESS;
  177. }
  178. bool IsValid() const { return (compute_graph_ != nullptr); }
  179. graphStatus AddOp(const ge::Operator &op) {
  180. std::pair<std::map<string, ge::Operator>::iterator, bool> ret;
  181. ret = op_list_.emplace(std::pair<string, ge::Operator>(op.GetName(), op));
  182. GE_CHK_BOOL_RET_STATUS(ret.second != false, GRAPH_FAILED, "the op have added before, op name:%s.",
  183. op.GetName().c_str());
  184. return GRAPH_SUCCESS;
  185. }
  186. graphStatus GetAllOpName(std::vector<string> &op_name) const {
  187. for (const auto &it : op_list_) {
  188. op_name.push_back(it.second.GetName());
  189. }
  190. return GRAPH_SUCCESS;
  191. }
  192. graphStatus FindOpByName(const string &name, ge::Operator &op) const {
  193. auto it = op_list_.find(name);
  194. GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str());
  195. op = it->second;
  196. return GRAPH_SUCCESS;
  197. }
  198. graphStatus FindOpByType(const string &type, std::vector<ge::Operator> &ops) const {
  199. for (auto &op : op_list_) {
  200. auto op_type = op.second.GetOpType();
  201. if (op_type == type) {
  202. ops.push_back(op.second);
  203. continue;
  204. }
  205. if (op_type == ge::FRAMEWORKOP) {
  206. op.second.GetAttr(ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, op_type);
  207. if (op_type == type) {
  208. ops.push_back(op.second);
  209. }
  210. }
  211. }
  212. return GRAPH_SUCCESS;
  213. }
  214. void SetNeedIteration(bool need_iteration) {
  215. if (compute_graph_ == nullptr) {
  216. GELOGE(GRAPH_FAILED, "Set need iteration failed, as compute graph is null.");
  217. return;
  218. }
  219. compute_graph_->SetNeedIteration(need_iteration);
  220. }
  221. const std::string &GetName() const {
  222. return name_;
  223. }
  224. ComputeGraphPtr GetComputeGraph() const {
  225. return compute_graph_;
  226. }
  227. graphStatus RemoveEdge(NodePtr &src_node_ptr, const int32_t src_port_index,
  228. NodePtr &dst_node_ptr, const int32_t dst_port_index) {
  229. GE_CHECK_NOTNULL(src_node_ptr);
  230. GE_CHECK_NOTNULL(dst_node_ptr);
  231. graphStatus res = GRAPH_FAILED;
  232. if ((src_port_index == -1) && (dst_port_index == -1)) {
  233. if (src_node_ptr->GetOutControlAnchor() == nullptr) {
  234. GELOGE(GRAPH_FAILED, "RemoveEdge: src node[%s] out control anchor is null.", src_node_ptr->GetName().c_str());
  235. return GRAPH_FAILED;
  236. }
  237. res = GraphUtils::RemoveEdge(src_node_ptr->GetOutControlAnchor(), dst_node_ptr->GetInControlAnchor());
  238. if (res != GRAPH_SUCCESS) {
  239. GELOGE(GRAPH_FAILED, "RemoveEdge: remove control edge between [%s] and [%s]failed.",
  240. src_node_ptr->GetName().c_str(), dst_node_ptr->GetName().c_str());
  241. return GRAPH_FAILED;
  242. }
  243. return GRAPH_SUCCESS;
  244. }
  245. if (src_node_ptr->GetOutDataAnchor(src_port_index) == nullptr) {
  246. GELOGE(GRAPH_FAILED, "RemoveEdge: src node[%s] out data anchor[%d] is null.",
  247. src_node_ptr->GetName().c_str(), src_port_index);
  248. return GRAPH_FAILED;
  249. }
  250. if (src_port_index != -1 && dst_port_index == -1) {
  251. res = GraphUtils::RemoveEdge(src_node_ptr->GetOutDataAnchor(src_port_index), dst_node_ptr->GetInControlAnchor());
  252. if (res != GRAPH_SUCCESS) {
  253. GELOGE(GRAPH_FAILED, "RemoveEdge: remove data-control edge between [%s] and [%s]failed.",
  254. src_node_ptr->GetName().c_str(), dst_node_ptr->GetName().c_str());
  255. return GRAPH_FAILED;
  256. }
  257. return GRAPH_SUCCESS;
  258. }
  259. res = GraphUtils::RemoveEdge(src_node_ptr->GetOutDataAnchor(src_port_index),
  260. dst_node_ptr->GetInDataAnchor(dst_port_index));
  261. if (res != GRAPH_SUCCESS) {
  262. GELOGE(GRAPH_FAILED, "RemoveEdge: remove data edge between [%s] and [%s] failed.",
  263. src_node_ptr->GetName().c_str(), dst_node_ptr->GetName().c_str());
  264. return GRAPH_FAILED;
  265. }
  266. return GRAPH_SUCCESS;
  267. }
  268. private:
  269. std::string name_;
  270. std::string output_name_;
  271. std::map<string, ge::Operator> op_list_;
  272. ComputeGraphPtr compute_graph_{nullptr};
  273. };
  274. Graph::Graph(const std::string &name) {
  275. impl_ = ComGraphMakeShared<GraphImpl>(name);
  276. if (impl_ == nullptr) {
  277. GELOGW("GraphImpl make shared failed, impl_ is nullptr");
  278. }
  279. }
  280. Graph::Graph(const char *name) {
  281. if (name != nullptr) {
  282. std::string graph_name = name;
  283. impl_ = ComGraphMakeShared<GraphImpl>(graph_name);
  284. if (impl_ == nullptr) {
  285. GELOGW("GraphImpl make shared failed, impl_ is nullptr.");
  286. }
  287. } else {
  288. GELOGW("Graph name is nullptr.");
  289. }
  290. }
  291. graphStatus Graph::AddOp(const ge::Operator &op) {
  292. GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, "AddOp failed: graph can not be used, impl is nullptr.");
  293. return impl_->AddOp(op);
  294. }
  295. graphStatus Graph::GetAllOpName(std::vector<std::string> &op_name) const {
  296. GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED,
  297. "GetAllOpName failed: graph can not be used, impl is nullptr.");
  298. return impl_->GetAllOpName(op_name);
  299. }
  300. graphStatus Graph::GetAllOpName(std::vector<AscendString> &names) const {
  301. GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED,
  302. "GetAllOpName failed: graph can not be used, impl is nullptr.");
  303. std::vector<std::string> op_names;
  304. if (impl_->GetAllOpName(op_names) != GRAPH_SUCCESS) {
  305. GELOGE(GRAPH_FAILED, "Get all op name failed.");
  306. return GRAPH_FAILED;
  307. }
  308. for (auto &op_name : op_names) {
  309. names.emplace_back(op_name.c_str());
  310. }
  311. return GRAPH_SUCCESS;
  312. }
  313. graphStatus Graph::FindOpByName(const std::string &name, Operator &op) const {
  314. Operator op_find_op_def("NULL");
  315. op = op_find_op_def;
  316. GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED,
  317. "FindOpByName failed: graph can not be used, impl is nullptr.");
  318. return impl_->FindOpByName(name, op);
  319. }
  320. graphStatus Graph::FindOpByName(const char *name, Operator &op) const {
  321. if (name == nullptr) {
  322. GELOGE(GRAPH_FAILED, "FindOpByName: name is nullptr.");
  323. return GRAPH_FAILED;
  324. }
  325. Operator op_find_op_def("NULL");
  326. op = op_find_op_def;
  327. GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED,
  328. "FindOpByName failed: graph can not be used, impl is nullptr.");
  329. std::string op_name = name;
  330. return impl_->FindOpByName(op_name, op);
  331. }
  332. graphStatus Graph::FindOpByType(const string &type, std::vector<ge::Operator> &ops) const {
  333. GE_CHECK_NOTNULL(impl_);
  334. return impl_->FindOpByType(type, ops);
  335. }
  336. graphStatus Graph::FindOpByType(const char *type, std::vector<ge::Operator> &ops) const {
  337. if (type == nullptr) {
  338. GELOGE(GRAPH_FAILED, "FindOpByType: name is nullptr.");
  339. return GRAPH_FAILED;
  340. }
  341. GE_CHECK_NOTNULL(impl_);
  342. std::string op_type = type;
  343. return impl_->FindOpByType(op_type, ops);
  344. }
  345. Graph &Graph::SetInputs(const vector<ge::Operator> &inputs) {
  346. GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetInputs failed: graph can not be used, impl is nullptr.")
  347. GE_CHK_BOOL_EXEC(inputs.size() > 0, return *this, "SetInputs failed: input operator size can not be 0.");
  348. (void)impl_->SetInputs(inputs);
  349. return *this;
  350. }
  351. Graph &Graph::SetOutputs(const vector<ge::Operator> &outputs) {
  352. if (impl_ == nullptr) {
  353. GELOGE(GRAPH_FAILED, "SetOutputs failed: graph can not be used, impl is nullptr.");
  354. return *this;
  355. }
  356. (void)impl_->SetOutputs(outputs);
  357. return *this;
  358. }
  359. Graph &Graph::SetOutputs(const std::vector<std::pair<Operator, std::vector<size_t>>> &output_indexs) {
  360. if (impl_ == nullptr) {
  361. GELOGE(GRAPH_FAILED, "SetOutputs failed: graph can not be used, impl is nullptr.");
  362. return *this;
  363. }
  364. (void)impl_->SetOutputs(output_indexs);
  365. return *this;
  366. }
  367. Graph &Graph::SetOutputs(const std::vector<pair<Operator, string>> &outputs) {
  368. GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetOutputs failed: graph can not be used, impl is nullptr.")
  369. (void)impl_->SetOutputs(outputs);
  370. return *this;
  371. }
  372. Graph &Graph::SetOutputs(const std::vector<std::pair<ge::Operator, AscendString>> &outputs) {
  373. GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetOutputs failed: graph can not be used, impl is nullptr.")
  374. vector<std::pair<ge::Operator, std::string>> graph_outputs;
  375. for (auto &item : outputs) {
  376. const char *name = item.second.GetString();
  377. if (name != nullptr) {
  378. string output_name = name;
  379. graph_outputs.emplace_back((std::pair<ge::Operator, std::string>(item.first, name)));
  380. } else {
  381. GELOGW("Output name is nullptr.");
  382. }
  383. }
  384. (void)impl_->SetOutputs(graph_outputs);
  385. return *this;
  386. }
  387. Graph &Graph::SetTargets(const vector<ge::Operator> &targets) {
  388. if (impl_ == nullptr) {
  389. GELOGE(GRAPH_FAILED, "SetTargets failed: graph can not be used, impl is nullptr.");
  390. return *this;
  391. }
  392. (void)impl_->SetTargets(targets);
  393. return *this;
  394. }
  395. bool Graph::IsValid() const {
  396. if (impl_ == nullptr) {
  397. return false;
  398. }
  399. return impl_->IsValid();
  400. }
  401. void Graph::SetNeedIteration(bool need_iteration) {
  402. if (impl_ == nullptr) {
  403. GELOGE(GRAPH_FAILED, "Set need iteration failed, as impl is null.");
  404. return;
  405. }
  406. impl_->SetNeedIteration(need_iteration);
  407. }
  408. std::vector<GNode> Graph::GetAllNodes() const {
  409. std::vector<GNode> graph_nodes;
  410. if (impl_ == nullptr) {
  411. GELOGE(GRAPH_FAILED, "GetAllNodes: graph can not be used, impl is nullptr.");
  412. return graph_nodes;
  413. }
  414. ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph();
  415. if (compute_graph_ptr == nullptr) {
  416. GELOGE(GRAPH_FAILED, "GetAllNodes: compute graph ptr is nullptr.");
  417. return graph_nodes;
  418. }
  419. for (auto &node : compute_graph_ptr->GetAllNodes()) {
  420. GNode gnode = NodeAdapter::Node2GNode(node);
  421. graph_nodes.emplace_back(gnode);
  422. }
  423. return graph_nodes;
  424. }
  425. std::vector<GNode> Graph::GetDirectNode() const {
  426. std::vector<GNode> graph_nodes;
  427. if (impl_ == nullptr) {
  428. GELOGE(GRAPH_FAILED, "GetDirectNode: graph can not be used, impl is nullptr.");
  429. return graph_nodes;
  430. }
  431. ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph();
  432. if (compute_graph_ptr == nullptr) {
  433. GELOGE(GRAPH_FAILED, "GetDirectNode: compute graph ptr is nullptr.");
  434. return graph_nodes;
  435. }
  436. for (auto &node : compute_graph_ptr->GetDirectNode()) {
  437. GNode gnode = NodeAdapter::Node2GNode(node);
  438. graph_nodes.emplace_back(gnode);
  439. }
  440. return graph_nodes;
  441. }
  442. graphStatus Graph::RemoveNode(GNode &node) {
  443. if (impl_ == nullptr) {
  444. GELOGE(GRAPH_FAILED, "RemoveNode: graph can not be used, impl is nullptr.");
  445. return GRAPH_FAILED;
  446. }
  447. NodePtr node_ptr = NodeAdapter::GNode2Node(node);
  448. if (node_ptr == nullptr) {
  449. GELOGE(GRAPH_FAILED, "RemoveNode: gnode to node failed.");
  450. return GRAPH_FAILED;
  451. }
  452. if (node_ptr->GetOwnerComputeGraph() == nullptr) {
  453. GELOGE(GRAPH_FAILED, "RemoveNode: node[%s] is invalid.", node_ptr->GetName().c_str());
  454. return GRAPH_FAILED;
  455. }
  456. ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph();
  457. if (compute_graph_ptr == nullptr) {
  458. GELOGE(GRAPH_FAILED, "RemoveNde: compute graph ptr is nullptr.");
  459. return GRAPH_FAILED;
  460. }
  461. ge::NodeUtils::UnlinkAll(*node_ptr);
  462. if (GraphUtils::RemoveNodeWithoutRelink(compute_graph_ptr, node_ptr) != GRAPH_SUCCESS) {
  463. GELOGE(GRAPH_FAILED, "RemoveNode: remove node[%s] failed.", node_ptr->GetName().c_str());
  464. return GRAPH_FAILED;
  465. }
  466. node_ptr->SetAnyOwnerComputeGraph(nullptr);
  467. return GRAPH_SUCCESS;
  468. }
  469. graphStatus Graph::RemoveEdge(GNode &src_node, const int32_t src_port_index,
  470. GNode &dst_node, const int32_t dst_port_index) {
  471. if (impl_ == nullptr) {
  472. GELOGE(GRAPH_FAILED, "RemoveEdge: graph can not be used, impl is nullptr.");
  473. return GRAPH_FAILED;
  474. }
  475. if ((src_port_index == -1) && (dst_port_index != -1)) {
  476. GELOGE(GRAPH_FAILED, "RemoveEdge:src control anchor link to dst data anchor not exists.");
  477. return GRAPH_FAILED;
  478. }
  479. NodePtr src_node_ptr = NodeAdapter::GNode2Node(src_node);
  480. if (src_node_ptr == nullptr) {
  481. GELOGE(GRAPH_FAILED, "RemoveEdge: src gnode to node failed.");
  482. return GRAPH_FAILED;
  483. }
  484. NodePtr dst_node_ptr = NodeAdapter::GNode2Node(dst_node);
  485. if (dst_node_ptr == nullptr) {
  486. GELOGE(GRAPH_FAILED, "RemoveEdge: dst gnode to node failed.");
  487. return GRAPH_FAILED;
  488. }
  489. if (src_node_ptr->GetOwnerComputeGraph() == nullptr) {
  490. GELOGE(GRAPH_FAILED, "RemoveEdge: src node[%s] is invalid.", src_node_ptr->GetName().c_str());
  491. return GRAPH_FAILED;
  492. }
  493. if (dst_node_ptr->GetOwnerComputeGraph() == nullptr) {
  494. GELOGE(GRAPH_FAILED, "RemoveEdge: dst node[%s] is invalid.", dst_node_ptr->GetName().c_str());
  495. return GRAPH_FAILED;
  496. }
  497. if (impl_->RemoveEdge(src_node_ptr, src_port_index, dst_node_ptr, dst_port_index) != GRAPH_SUCCESS) {
  498. GELOGE(GRAPH_FAILED, "RemoveEdge: remove edge failed.");
  499. return GRAPH_FAILED;
  500. }
  501. return GRAPH_SUCCESS;
  502. }
  503. GNode Graph::AddNodeByOp(const Operator &op) {
  504. if (impl_ == nullptr) {
  505. GELOGE(GRAPH_FAILED, "AddNodeByOp: graph can not be used, impl is nullptr.");
  506. return GNode();
  507. }
  508. std::shared_ptr<ge::OpDesc> op_desc = ge::OpDescUtils::GetOpDescFromOperator(op);
  509. if (op_desc == nullptr) {
  510. GELOGE(GRAPH_FAILED, "AddNodeByOp: get op desc from op[%s] failed.", op.GetName().c_str());
  511. return GNode();
  512. }
  513. ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph();
  514. if (compute_graph_ptr == nullptr) {
  515. GELOGE(GRAPH_FAILED, "AddNodeByOp: compute graph ptr is nullptr.");
  516. return GNode();
  517. }
  518. NodePtr node_ptr = compute_graph_ptr->AddNode(op_desc);
  519. GNode gnode = NodeAdapter::Node2GNode(node_ptr);
  520. return gnode;
  521. }
  522. graphStatus Graph::AddDataEdge(GNode &src_node, const int32_t src_port_index,
  523. GNode &dst_node, const int32_t dst_port_index) {
  524. if (impl_ == nullptr) {
  525. GELOGE(GRAPH_FAILED, "AddDataEdge: graph can not be used, impl is nullptr.");
  526. return GRAPH_FAILED;
  527. }
  528. NodePtr src_node_ptr = NodeAdapter::GNode2Node(src_node);
  529. if (src_node_ptr == nullptr) {
  530. GELOGE(GRAPH_FAILED, "AddDataEdge: src gnode to node failed.");
  531. return GRAPH_FAILED;
  532. }
  533. NodePtr dst_node_ptr = NodeAdapter::GNode2Node(dst_node);
  534. if (dst_node_ptr == nullptr) {
  535. GELOGE(GRAPH_FAILED, "AddDataEdge: dst gnode to node failed.");
  536. return GRAPH_FAILED;
  537. }
  538. if (src_node_ptr->GetOwnerComputeGraph() == nullptr) {
  539. GELOGE(GRAPH_FAILED, "AddDataEdge: src node[%s] is invalid.", src_node_ptr->GetName().c_str());
  540. return GRAPH_FAILED;
  541. }
  542. if (dst_node_ptr->GetOwnerComputeGraph() == nullptr) {
  543. GELOGE(GRAPH_FAILED, "AddDataEdge: dst node[%s] is invalid.", dst_node_ptr->GetName().c_str());
  544. return GRAPH_FAILED;
  545. }
  546. graphStatus res = GraphUtils::AddEdge(src_node_ptr->GetOutDataAnchor(src_port_index),
  547. dst_node_ptr->GetInDataAnchor(dst_port_index));
  548. if (res != GRAPH_SUCCESS) {
  549. GELOGE(GRAPH_FAILED, "AddDataEdge: Add data edge failed.");
  550. return GRAPH_FAILED;
  551. }
  552. return GRAPH_SUCCESS;
  553. }
  554. graphStatus Graph::AddControlEdge (GNode &src_node, GNode &dst_node) {
  555. if (impl_ == nullptr) {
  556. GELOGE(GRAPH_FAILED, "AddControlEdge: graph can not be used, impl is nullptr.");
  557. return GRAPH_FAILED;
  558. }
  559. NodePtr src_node_ptr = NodeAdapter::GNode2Node(src_node);
  560. if (src_node_ptr == nullptr) {
  561. GELOGE(GRAPH_FAILED, "AddControlEdge: src gnode to node failed.");
  562. return GRAPH_FAILED;
  563. }
  564. NodePtr dst_node_ptr = NodeAdapter::GNode2Node(dst_node);
  565. if (dst_node_ptr == nullptr) {
  566. GELOGE(GRAPH_FAILED, "AddControlEdge: dst gnode to node failed.");
  567. return GRAPH_FAILED;
  568. }
  569. if (src_node_ptr->GetOwnerComputeGraph() == nullptr) {
  570. GELOGE(GRAPH_FAILED, "AddControlEdge: src node[%s] is invalid.", src_node_ptr->GetName().c_str());
  571. return GRAPH_FAILED;
  572. }
  573. if (dst_node_ptr->GetOwnerComputeGraph() == nullptr) {
  574. GELOGE(GRAPH_FAILED, "AddControlEdge: dst node[%s] is invalid.", dst_node_ptr->GetName().c_str());
  575. return GRAPH_FAILED;
  576. }
  577. graphStatus res = GraphUtils::AddEdge(src_node_ptr->GetOutControlAnchor(), dst_node_ptr->GetInControlAnchor());
  578. if (res != GRAPH_SUCCESS) {
  579. GELOGE(GRAPH_FAILED, "AddControlEdge: Add control edge failed.");
  580. return GRAPH_FAILED;
  581. }
  582. return SUCCESS;
  583. }
  584. GraphPtr Graph::ConstructFromInputs(const std::vector<Operator> &inputs, const AscendString &name) {
  585. const char* ascend_name = name.GetString();
  586. if (ascend_name == nullptr) {
  587. GELOGE(GRAPH_PARAM_INVALID, "ConstructFromInputs: ascend string error.");
  588. return nullptr;
  589. }
  590. if (inputs.empty()) {
  591. GELOGE(GRAPH_FAILED, "ConstructFromInputs: inputs size can not be 0.");
  592. return nullptr;
  593. }
  594. std::string graph_name = ascend_name;
  595. ComputeGraphPtr compute_graph = GraphUtils::CreateGraphFromOperator(graph_name, inputs);
  596. if (compute_graph == nullptr) {
  597. GELOGE(GRAPH_FAILED, "ConstructFromInputs: create compute graph failed.");
  598. return nullptr;
  599. }
  600. compute_graph->SetInputSize(static_cast<uint32_t>(inputs.size()));
  601. GraphPtr graph_ptr = GraphUtils::CreateGraphPtrFromComputeGraph(compute_graph);
  602. if (graph_ptr == nullptr) {
  603. GELOGE(GRAPH_FAILED, "ConstructFromInputs: create graph from compute graph failed.");
  604. return nullptr;
  605. }
  606. return graph_ptr;
  607. }
  608. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::GetComputeGraph(const ge::Graph &graph) {
  609. GE_CHK_BOOL_EXEC_NOLOG(graph.IsValid(), return nullptr);
  610. return graph.impl_->compute_graph_;
  611. }
  612. graphStatus Graph::SaveToFile(const string &file_name) const {
  613. Model model = Model();
  614. model.SetGraph(*this);
  615. return model.SaveToFile(file_name);
  616. }
  617. graphStatus Graph::SaveToFile(const char *file_name) const {
  618. if (file_name == nullptr) {
  619. GELOGE(GRAPH_FAILED, "SaveToFile: file name is nullptr.");
  620. return GRAPH_FAILED;
  621. }
  622. Model model = Model();
  623. model.SetGraph(*this);
  624. std::string file = file_name;
  625. return model.SaveToFile(file);
  626. }
  627. graphStatus Graph::LoadFromFile(const string &file_name) {
  628. Model model = Model();
  629. graphStatus ret = model.LoadFromFile(file_name);
  630. if (ret != GRAPH_SUCCESS) {
  631. return ret;
  632. }
  633. *this = model.GetGraph();
  634. return GRAPH_SUCCESS;
  635. }
  636. graphStatus Graph::LoadFromFile(const char *file_name) {
  637. if (file_name == nullptr) {
  638. GELOGE(GRAPH_FAILED, "SaveToFile: file name is nullptr.");
  639. return GRAPH_FAILED;
  640. }
  641. Model model = Model();
  642. std::string file = file_name;
  643. graphStatus ret = model.LoadFromFile(file);
  644. if (ret != GRAPH_SUCCESS) {
  645. return ret;
  646. }
  647. *this = model.GetGraph();
  648. return GRAPH_SUCCESS;
  649. }
  650. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  651. const std::string &Graph::GetName() const {
  652. return impl_->GetName();
  653. }
  654. graphStatus Graph::GetName(AscendString &name) const {
  655. if (impl_ == nullptr) {
  656. GELOGE(GRAPH_FAILED, "GetName: impl is nullptr.");
  657. return GRAPH_FAILED;
  658. }
  659. std::string graph_name = impl_->GetName();
  660. name = AscendString(graph_name.c_str());
  661. return GRAPH_SUCCESS;
  662. }
  663. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph
  664. GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) {
  665. GE_CHK_BOOL_EXEC_NOLOG(compute_graph != nullptr, return Graph(""));
  666. auto name = compute_graph->GetName();
  667. auto graph = Graph(name);
  668. GE_CHK_BOOL_EXEC_NOLOG(graph.impl_ != nullptr, return graph);
  669. graph.impl_->compute_graph_ = compute_graph;
  670. return graph;
  671. }
  672. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GraphPtr
  673. GraphUtils::CreateGraphPtrFromComputeGraph(const ge::ComputeGraphPtr compute_graph) {
  674. GE_CHK_BOOL_EXEC_NOLOG(compute_graph != nullptr, return nullptr);
  675. auto name = compute_graph->GetName();
  676. auto graph = ComGraphMakeShared<Graph>(name);
  677. GE_CHK_BOOL_EXEC_NOLOG(graph != nullptr, return nullptr);
  678. GE_CHK_BOOL_EXEC_NOLOG(graph->impl_ != nullptr, return nullptr);
  679. graph->impl_->compute_graph_ = compute_graph;
  680. return graph;
  681. }
  682. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  683. graphStatus GraphUtils::RecoverGraphOperators(const Graph &graph) {
  684. GE_CHECK_NOTNULL(graph.impl_);
  685. GE_CHECK_NOTNULL(graph.impl_->compute_graph_);
  686. graph.impl_->op_list_.clear();
  687. for (const auto &node : graph.impl_->compute_graph_->GetDirectNode()) {
  688. graph.impl_->op_list_[node->GetName()] = OpDescUtils::CreateOperatorFromNode(node);
  689. }
  690. return SUCCESS;
  691. }
  692. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：