You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_utils.h 23 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_GRAPH_UTILS_GRAPH_UTILS_H_
  17. #define INC_GRAPH_UTILS_GRAPH_UTILS_H_
  18. #include <fstream>
  19. #include <iostream>
  20. #include <map>
  21. #include <string>
  22. #include <vector>
  23. #include "graph/anchor.h"
  24. #include "graph/node.h"
  25. #include "graph/compute_graph.h"
  26. #include "graph/utils/anchor_utils.h"
  27. #include "graph/graph.h"
  28. #include "graph/model.h"
  ///
  /// Helper macros that stream attribute values and tensor shapes in a human readable form.
  ///
  /// The `*_IF` / `*_ELIF` macros are building blocks for one `if / else if` chain (see
  /// PRINT_ATTR_FUNC). They intentionally expand to a bare `if (...) { ... }` statement so that
  /// a following `*_ELIF` (which starts with `else`) can attach to it. Use them only inside a
  /// braced block, never as the unbraced body of another `if`.
  ///

  // Declares a local variable `ret` of type `DataType` in the enclosing scope and asks `attr`
  // to fill it through `GetValue<DataType>`.
  // It is deliberately not wrapped in `do { } while (0)`: the callers below read `ret` right
  // after the expansion, which a nested block scope would make impossible.
  // `VT_ENUM` is unused here; it is kept so existing call sites keep their argument list.
  // The expansion carries its own trailing semicolon.
  #define REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \
    DataType ret;                                        \
    (attr).GetValue<DataType>(ret);

  // If `value_type` equals `VT_ENUM`, reads the attribute as `DataType` and streams it.
  #define PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \
    if ((value_type) == (VT_ENUM)) {                                       \
      REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret)                       \
      stream << ret;                                                       \
    }

  // If `value_type` equals `VT_ENUM`, reads the attribute as the container type `DataType`
  // and streams it as "[e0, e1, ...]". The index uses the container's own size type to
  // avoid a signed/unsigned comparison.
  #define PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \
    if ((value_type) == (VT_ENUM)) {                                            \
      REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret)                            \
      stream << "[";                                                            \
      for (decltype(ret.size()) i = 0; i < ret.size(); ++i) {                   \
        stream << ret[i];                                                       \
        if (i + 1 != ret.size()) stream << ", ";                                \
      }                                                                         \
      stream << "]";                                                            \
    }

  // `else` continuation of a chain started by PRINT_ATTR_VALUE_IF / PRINT_LIST_ATTR_VALUE_IF.
  #define PRINT_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \
    else PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream)

  // `else` continuation of a chain, list flavour.
  #define PRINT_LIST_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \
    else PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream)

  // Streams the dims of the `idx`-th input or output tensor descriptor of node `n` as
  // "[d0, d1, ...]". The input descriptor is used when `i_o` spells "input", otherwise the
  // output descriptor. `i_o` is converted to std::string before the comparison so that a
  // plain C string is compared by content rather than by pointer value.
  #define PRINT_SHAPE(i_o, n, idx, stream)                                                  \
    do {                                                                                    \
      auto print_shape_op_ = (n)->GetOpDesc();                                              \
      auto print_shape_td_ = (std::string(i_o) == "input") ? print_shape_op_->GetInputDesc(idx) \
                                                           : print_shape_op_->GetOutputDesc(idx); \
      auto print_shape_dims_ = print_shape_td_.GetShape().GetDims();                        \
      stream << "[";                                                                        \
      for (decltype(print_shape_dims_.size()) i = 0; i < print_shape_dims_.size(); ++i) {   \
        stream << print_shape_dims_[i];                                                     \
        if (i + 1 < print_shape_dims_.size()) stream << ", ";                               \
      }                                                                                     \
      stream << "]";                                                                        \
    } while (0)

  // Expands to a lambda (capturing by reference) that takes a GeAttrValue by value and
  // streams it: scalar and list values of string/float/bool/int are printed, the remaining
  // listed value types are printed as a fixed tag, and any other type prints nothing.
  // The expansion keeps the trailing semicolon of the original definition.
  #define PRINT_ATTR_FUNC(stream)                                                                                     \
    [&](GeAttrValue attr) {                                                                                           \
      auto type = attr.GetValueType();                                                                                \
      PRINT_ATTR_VALUE_IF(type, GeAttrValue::ValueType::VT_STRING, GeAttrValue::STR, attr, stream)                    \
      PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_FLOAT, GeAttrValue::FLOAT, attr, stream)                 \
      PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_BOOL, GeAttrValue::BOOL, attr, stream)                   \
      PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_INT, GeAttrValue::INT, attr, stream)                     \
      PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_STRING, GeAttrValue::LIST_STR, attr, stream)   \
      PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_FLOAT, GeAttrValue::LIST_FLOAT, attr, stream)  \
      PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_BOOL, GeAttrValue::LIST_BOOL, attr, stream)    \
      PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_INT, GeAttrValue::LIST_INT, attr, stream)      \
      else if (type == GeAttrValue::ValueType::VT_TENSOR_DESC) stream << "TENSOR_DESC";                               \
      else if (type == GeAttrValue::ValueType::VT_TENSOR) stream << "TENSOR";                                         \
      else if (type == GeAttrValue::ValueType::VT_BYTES) stream << "BYTES";                                           \
      else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR_DESC) stream << "LIST_TENSOR_DESC";                     \
      else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR) stream << "LIST_TENSOR";                               \
      else if (type == GeAttrValue::ValueType::VT_LIST_BYTES) stream << "LIST_BYTES";                                 \
    };
  87. namespace ge {
  88. class GraphUtils {
  89. public:
  90. static ComputeGraphPtr GetComputeGraph(const Graph &graph);
  91. static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph);
  92. static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector<Operator> &inputs);
  93. static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);
  94. static graphStatus AddEdge(const OutDataAnchorPtr &src, const Format &src_format, const InDataAnchorPtr &dst,
  95. const Format &dst_format);
  96. static graphStatus AddEdge(const AnchorPtr &src, const AnchorPtr &dst);
  97. static graphStatus AddEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst);
  98. static graphStatus AddEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst);
  99. // check whether src is link to dst and then remove
  100. static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);
  101. static graphStatus RemoveEdge(const AnchorPtr &src, const AnchorPtr &dst);
  102. static graphStatus RemoveEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst);
  103. static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst);
  104. static graphStatus ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst,
  105. const InDataAnchorPtr &new_dst);
  106. static graphStatus ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst,
  107. const InControlAnchorPtr &new_dst);
  108. static graphStatus InsertNodeBetweenDataAnchors(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst,
  109. const NodePtr &new_node);
  110. static graphStatus RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node);
  111. static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor,
  112. const std::vector<OpDescPtr> &vec_op_desc);
  113. ///
  114. /// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst
  115. /// @param [in] src
  116. /// @param [in] dsts
  117. /// @param [in] insert_node
  118. /// @param [in] input_index
  119. /// @param [in] output_index
  120. /// @return graphStatus
  121. ///
  122. static graphStatus InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts,
  123. const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0);
  124. static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node);
  125. static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node);
  126. static void RecordOriginalNames(std::vector<ge::NodePtr> original_nodes, const ge::NodePtr &node);
  127. static void RecordOriginalNames(std::vector<std::string> names_tmp, const ge::NodePtr &node);
  128. static bool MatchDumpStr(const std::string &suffix);
  129. static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false);
  130. static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph);
  131. static void BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos);
  132. static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix);
  133. static bool LoadGEGraphFromOnnx(const char *file, ge::ComputeGraph &compute_graph);
  134. static bool ReadProtoFromTextFile(const char *file, google::protobuf::Message *message);
  135. static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *real_path);
  136. static graphStatus AppendInputNode(const ComputeGraphPtr &graph, const NodePtr &node);
  137. ///
  138. /// Isolating `node`, relinking data links from the in-anchor peer nodes to
  139. /// the out-anchor peer nodes according to `io_map`, relinking control links
  140. /// to ensure that input nodes of `node` are before out nodes
  141. ///
  142. /// Link the `io_map[i]` input anchor peer node to `i` output anchor peer
  143. /// nodes, then unlink all links connecting with `node`. If `io_map[i]` < 0,
  144. /// unlink all links from `i` output anchor without any relinking.
  145. ///
  146. /// @param node
  147. /// @param io_map
  148. /// @return
  149. ///
  150. static graphStatus IsolateNode(const NodePtr &node, const std::initializer_list<int> &io_map);
  151. static graphStatus IsolateNode(const NodePtr &node, const std::vector<int> &io_map);
  152. ///
  153. /// Isolate `node` which must be one input one output, equivalent to
  154. /// `IsolateNode(node, {0})`
  155. /// @param node
  156. /// @return
  157. ///
  158. static graphStatus IsolateNodeOneIO(const NodePtr &node);
  159. ///
  160. /// The data anchors replacing behavior is the same with
  161. /// `ReplaceNodeDataAnchors`. In addition, replace all `old_node` control
  162. /// anchors with `new_node`'s.
  163. /// @param new_node
  164. /// @param old_node
  165. /// @param inputs_map
  166. /// @param outputs_map
  167. /// @return
  168. ///
  169. static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node,
  170. std::initializer_list<int> inputs_map, std::initializer_list<int> outputs_map);
  171. static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node,
  172. const std::vector<int> &inputs_map, const std::vector<int> &outputs_map);
  173. ///
  174. /// Replace `old_node` data anchors with `new_node`'s according to `inputs_map` and `outputs_map`.
  175. /// Replace the `i` in/out data anchor on `old_node` with
  176. /// `inputs_map[i]`/`outputs_map[i]` data anchor on `new_node`.
  177. /// If `inputs_map[i]`/`outputs_map[i]` < 0 or the index not contained in
  178. /// `inputs_map[i]`/`outputs_map[i]`, the `i` data anchor will remain
  179. /// on `old_node`.
  180. /// @param new_node
  181. /// @param old_node
  182. /// @param inputs_map
  183. /// @param outputs_map
  184. /// @return
  185. ///
  186. static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node,
  187. std::initializer_list<int> inputs_map,
  188. std::initializer_list<int> outputs_map);
  189. static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node,
  190. const std::vector<int> &inputs_map, const std::vector<int> &outputs_map);
  191. ///
  192. /// Copy all in-control edges from `src_node` to `dst_node`
  193. /// @param src_node
  194. /// @param dst_node
  195. /// @return
  196. ///
  197. static graphStatus CopyInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node);
  198. static graphStatus MoveInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node);
  199. ///
  200. /// Copy all out-control edges from `src_node` to `dst_node`
  201. /// @param src_node
  202. /// @param dst_node
  203. /// @return success: GRAPH_SUCESS
  204. ///
  205. static graphStatus CopyOutCtrlEdges(const NodePtr &src_node, NodePtr &dst_node);
  206. ///
  207. /// Move all out-control edges from `src_node` to `dst_node`
  208. /// @param src_node
  209. /// @param dst_node
  210. /// @return success: GRAPH_SUCESS
  211. ///
  212. static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node);
  213. static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph);
  214. };
  215. class ComputeGraphBuilder {
  216. public:
  217. ComputeGraphBuilder() : owner_graph_(nullptr) {}
  218. ComputeGraphBuilder(const ComputeGraphBuilder &) = delete;
  219. ComputeGraphBuilder &operator=(const ComputeGraphBuilder &) = delete;
  220. ComputeGraphBuilder(const ComputeGraphBuilder &&) = delete;
  221. ComputeGraphBuilder &operator=(const ComputeGraphBuilder &&) = delete;
  222. ~ComputeGraphBuilder() = default;
  223. ///
  224. /// @brief Add node to graph
  225. /// @param [in] op_desc
  226. /// @return ComputeGraphBuilder
  227. ///
  228. virtual ComputeGraphBuilder &AddNode(const OpDescPtr &op_desc);
  229. ///
  230. /// @brief Add data-link among nodes in graph
  231. /// @param [in] src_name
  232. /// @param [in] out_anchor_ind
  233. /// @param [in] dst_name
  234. /// @param [in] in_anchor_ind
  235. /// @return ComputeGraphBuilder
  236. ///
  237. virtual ComputeGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind,
  238. const std::string &dst_name, uint32_t in_anchor_ind);
  239. ///
  240. /// @brief Add ctrl-link among nodes in graph
  241. /// @param [in] src_name
  242. /// @param [in] dst_name
  243. /// @return ComputeGraphBuilder
  244. ///
  245. virtual ComputeGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name);
  246. ///
  247. /// @brief Build graph
  248. /// @param [out] error_code
  249. /// @param [out] error_msg
  250. /// @return ComputeGraphPtr
  251. ///
  252. virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) = 0;
  253. /// @brief Get node with name
  254. /// @param [in] name
  255. /// @return NodePtr
  256. ///
  257. NodePtr GetNode(const std::string &name);
  258. protected:
  259. ///
  260. /// @brief Build nodes
  261. /// @param [out] error_code
  262. /// @param [out] error_msg
  263. /// @return void
  264. ///
  265. void BuildNodes(graphStatus &error_code, std::string &error_msg);
  266. ///
  267. /// @brief Build data-links
  268. /// @param [out] error_code
  269. /// @param [out] error_msg
  270. /// @return void
  271. ///
  272. void BuildDataLinks(graphStatus &error_code, std::string &error_msg);
  273. ///
  274. /// @brief Build ctrl-links
  275. /// @param [out] error_code
  276. /// @param [out] error_msg
  277. /// @return void
  278. ///
  279. void BuildCtrlLinks(graphStatus &error_code, std::string &error_msg);
  280. ComputeGraphPtr owner_graph_;
  281. // node_name -> node
  282. std::map<std::string, NodePtr> node_names_;
  283. std::vector<OpDescPtr> nodes_;
  284. // <src_node_name, out_anchor_ind> -> <dst_node_name, in_anchor_ind>
  285. std::vector<std::pair<std::pair<std::string, uint32_t>, std::pair<std::string, uint32_t>>> data_links_;
  286. // src_node_name -> dst_node_name
  287. std::vector<std::pair<std::string, std::string>> ctrl_links_;
  288. };
  289. class CompleteGraphBuilder : public ComputeGraphBuilder {
  290. public:
  291. explicit CompleteGraphBuilder(std::string name) : name_(std::move(name)), parent_node_(nullptr) {}
  292. CompleteGraphBuilder(const CompleteGraphBuilder &) = delete;
  293. CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete;
  294. CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete;
  295. CompleteGraphBuilder &operator=(const CompleteGraphBuilder &&) = delete;
  296. ~CompleteGraphBuilder() = default;
  297. ///
  298. /// @brief Add node to graph
  299. /// @param [in] op_desc
  300. /// @return CompleteGraphBuilder
  301. ///
  302. CompleteGraphBuilder &AddNode(const OpDescPtr &op_desc) override;
  303. ///
  304. /// @brief Add data-link among nodes in graph
  305. /// @param [in] src_name
  306. /// @param [in] out_anchor_ind
  307. /// @param [in] dst_name
  308. /// @param [in] in_anchor_ind
  309. /// @return CompleteGraphBuilder
  310. ///
  311. CompleteGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name,
  312. uint32_t in_anchor_ind) override;
  313. ///
  314. /// @brief Add ctrl-link among nodes in graph
  315. /// @param [in] src_name
  316. /// @param [in] dst_name
  317. /// @return CompleteGraphBuilder
  318. ///
  319. CompleteGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override;
  320. ///
  321. /// @brief Set index_th input anchor for graph
  322. /// @param [in] index
  323. /// @param [in] node_names
  324. /// @param [in] anchor_inds
  325. /// @return CompleteGraphBuilder
  326. ///
  327. CompleteGraphBuilder &SetInput(uint32_t index, const std::vector<std::string> &node_names,
  328. const std::vector<uint32_t> &anchor_inds);
  329. ///
  330. /// @brief Set index_th input of graph as useless
  331. /// @param [in] index
  332. /// @return CompleteGraphBuilder
  333. ///
  334. CompleteGraphBuilder &SetUselessInput(uint32_t index);
  335. ///
  336. /// @brief Add output anchor for graph
  337. /// @param [in] owner_node_name
  338. /// @param [in] anchor_ind
  339. /// @return CompleteGraphBuilder
  340. ///
  341. CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind);
  342. ///
  343. /// @brief Set parent-node of graph
  344. /// @param [in] parent_node
  345. /// @return CompleteGraphBuilder
  346. ///
  347. CompleteGraphBuilder &SetParentNode(const NodePtr &parent_node);
  348. ///
  349. /// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node
  350. /// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node
  351. /// @return CompleteGraphBuilder
  352. ///
  353. CompleteGraphBuilder &SetInputMapping(const std::map<uint32_t, uint32_t> &input_mapping);
  354. ///
  355. /// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind
  356. /// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node
  357. /// @return CompleteGraphBuilder
  358. ///
  359. CompleteGraphBuilder &SetOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping);
  360. ///
  361. /// @brief Build graph
  362. /// @param [out] error_code
  363. /// @param [out] error_msg
  364. /// @return ComputeGraphPtr
  365. ///
  366. ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override;
  367. private:
  368. ///
  369. /// @brief Build inputs
  370. /// @param [out] error_code
  371. /// @param [out] error_msg
  372. /// @return void
  373. ///
  374. void BuildInputs(graphStatus &error_code, std::string &error_msg);
  375. ///
  376. /// @brief Add data node
  377. /// @param [in] index
  378. /// @param [out] error_code
  379. /// @param [out] error_msg
  380. /// @return void
  381. ///
  382. NodePtr AddDateNode(uint32_t index, graphStatus &error_code, std::string &error_msg);
  383. ///
  384. /// @brief Build outputs
  385. /// @param [out] error_code
  386. /// @param [out] error_msg
  387. /// @return void
  388. ///
  389. void BuildOutputs(graphStatus &error_code, std::string &error_msg);
  390. ///
  391. /// @brief Add NetOutput node
  392. /// @param [out] error_code
  393. /// @param [out] error_msg
  394. /// @return NodePtr
  395. ///
  396. NodePtr AddNetOutputNode(graphStatus &error_code, std::string &error_msg);
  397. ///
  398. /// @brief Add input/output tensor for NetOutput node
  399. /// @param [in] out_nodes_info
  400. /// @param [out] net_output_desc
  401. /// @return graphStatus
  402. ///
  403. graphStatus BuildInOutForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info,
  404. OpDescPtr &net_output_desc);
  405. ///
  406. /// @brief Add edge for NetOutput node
  407. /// @param [in] out_nodes_info
  408. /// @param [out] net_output_node
  409. /// @return graphStatus
  410. ///
  411. graphStatus AddEdgeForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info,
  412. const NodePtr &net_output_node);
  413. std::string name_;
  414. NodePtr parent_node_;
  415. std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_;
  416. std::vector<std::pair<std::string, uint32_t>> graph_outputs_;
  417. // index_of_graph_input -> in_anchor_index_of_parent_node
  418. std::map<uint32_t, uint32_t> input_mapping_;
  419. // index_of_graph_output -> out_anchor_index_of_parent_node
  420. std::map<uint32_t, uint32_t> output_mapping_;
  421. };
  422. class PartialGraphBuilder : public ComputeGraphBuilder {
  423. public:
  424. PartialGraphBuilder() = default;
  425. PartialGraphBuilder(const PartialGraphBuilder &) = delete;
  426. PartialGraphBuilder &operator=(const PartialGraphBuilder &) = delete;
  427. PartialGraphBuilder(const PartialGraphBuilder &&) = delete;
  428. PartialGraphBuilder &operator=(const PartialGraphBuilder &&) = delete;
  429. ~PartialGraphBuilder() = default;
  430. ///
  431. /// @brief Add node to graph
  432. /// @param [in] op_desc
  433. /// @return PartialGraphBuilder
  434. ///
  435. PartialGraphBuilder &AddNode(const OpDescPtr &op_desc) override;
  436. ///
  437. /// @brief Add data-link among nodes in graph
  438. /// @param [in] src_name
  439. /// @param [in] out_anchor_ind
  440. /// @param [in] dst_name
  441. /// @param [in] in_anchor_ind
  442. /// @return PartialGraphBuilder
  443. ///
  444. PartialGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name,
  445. uint32_t in_anchor_ind) override;
  446. ///
  447. /// @brief Add ctrl-link among nodes in graph
  448. /// @param [in] src_name
  449. /// @param [in] dst_name
  450. /// @return PartialGraphBuilder
  451. ///
  452. PartialGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override;
  453. ///
  454. /// @brief Set owner graph
  455. /// @param [in] graph
  456. /// @return PartialGraphBuilder
  457. ///
  458. PartialGraphBuilder &SetOwnerGraph(const ComputeGraphPtr &graph);
  459. ///
  460. /// @brief Add exist node
  461. /// @param [in] node
  462. /// @return PartialGraphBuilder
  463. ///
  464. PartialGraphBuilder &AddExistNode(const NodePtr &node);
  465. ///
  466. /// @brief Build multi nodes with links
  467. /// @param [out] error_code
  468. /// @param [out] error_msg
  469. /// @return ComputeGraphPtr
  470. ///
  471. ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override;
  472. private:
  473. ///
  474. /// @brief Build exist nodes
  475. /// @param [out] error_code
  476. /// @param [out] error_msg
  477. /// @return void
  478. ///
  479. void BuildExistNodes(graphStatus &error_code, std::string &error_msg);
  480. std::vector<NodePtr> exist_nodes_;
  481. };
  482. } // namespace ge
  483. #endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示