You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_utils.h 27 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_GRAPH_UTILS_GRAPH_UTILS_H_
  17. #define INC_GRAPH_UTILS_GRAPH_UTILS_H_
  18. #include <fstream>
  19. #include <iostream>
  20. #include <map>
  21. #include <string>
  22. #include <vector>
  23. #include "graph/anchor.h"
  24. #include "graph/node.h"
  25. #include "graph/compute_graph.h"
  26. #include "graph/utils/anchor_utils.h"
  27. #include "graph/graph.h"
  28. #include "graph/model.h"
  29. #define REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \
  30. do { \
  31. DataType ret; \
  32. attr.GetValue<DataType>(ret); \
  33. } while (0)
  34. #define PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \
  35. do { \
  36. if (value_type == VT_ENUM) { \
  37. REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \
  38. stream << ret; \
  39. } \
  40. } while (0)
  41. #define PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \
  42. do { \
  43. if (value_type == VT_ENUM) { \
  44. REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \
  45. stream << "["; \
  46. for (int i = 0; i < ret.size(); i++) { \
  47. stream << ret[i]; \
  48. if (i + 1 != ret.size()) stream << ", "; \
  49. } \
  50. stream << "]"; \
  51. } \
  52. } while (0)
  53. #define PRINT_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \
  54. else PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream)
  55. #define PRINT_LIST_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \
  56. else PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream)
  57. #define PRINT_SHAPE(i_o, n, idx, stream) \
  58. do { \
  59. auto op = n->GetOpDesc(); \
  60. GeTensorDesc td = i_o == "input" ? op->GetInputDesc(idx) : op->GetOutputDesc(idx); \
  61. auto shape = td.GetShape().GetDims(); \
  62. stream << "["; \
  63. for (int i = 0; i < shape.size(); i++) { \
  64. stream << shape[i]; \
  65. if (i + 1 < shape.size()) stream << ", "; \
  66. } \
  67. stream << "]"; \
  68. } while (0)
  69. #define PRINT_ATTR_FUNC(stream) \
  70. [&](GeAttrValue attr) { \
  71. auto type = attr.GetValueType(); \
  72. PRINT_ATTR_VALUE_IF(type, GeAttrValue::ValueType::VT_STRING, GeAttrValue::STR, attr, stream) \
  73. PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_FLOAT, GeAttrValue::FLOAT, attr, stream) \
  74. PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_BOOL, GeAttrValue::BOOL, attr, stream) \
  75. PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_INT, GeAttrValue::INT, attr, stream) \
  76. PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_STRING, GeAttrValue::LIST_STR, attr, stream) \
  77. PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_FLOAT, GeAttrValue::LIST_FLOAT, attr, stream) \
  78. PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_BOOL, GeAttrValue::LIST_BOOL, attr, stream) \
  79. PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_INT, GeAttrValue::LIST_INT, attr, stream) \
  80. else if (type == GeAttrValue::ValueType::VT_TENSOR_DESC) stream << "TENSOR_DESC"; \
  81. else if (type == GeAttrValue::ValueType::VT_TENSOR) stream << "TENSOR"; \
  82. else if (type == GeAttrValue::ValueType::VT_BYTES) stream << "BYTES"; \
  83. else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR_DESC) stream << "LIST_TENSOR_DESC"; \
  84. else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR) stream << "LIST_TENSOR"; \
  85. else if (type == GeAttrValue::ValueType::VT_LIST_BYTES) stream << "LIST_BYTES"; \
  86. };
  87. namespace ge {
  /// Which side of a node a NodeIndexIO index refers to: an input (kIn) or an
  /// output (kOut). NodeIndexIO::ToString renders these as "_in_" / "_out_".
  enum IOType {
    kIn = 0,
    kOut = 1
  };
  89. struct NodeIndexIO {
  90. NodeIndexIO(ge::NodePtr node, uint32_t index, IOType io_type)
  91. : node(std::move(node)), index(index), io_type(io_type) {}
  92. NodeIndexIO(ge::NodePtr node, int index, IOType io_type)
  93. : node(std::move(node)), index(static_cast<uint32_t>(index)), io_type(io_type) {}
  94. ~NodeIndexIO() {}
  95. NodePtr node = nullptr;
  96. uint32_t index = 0;
  97. IOType io_type = kOut;
  98. std::string ToString() const {
  99. if ((node == nullptr) || (node->GetOwnerComputeGraph() == nullptr)) {
  100. return "";
  101. }
  102. return node->GetName() + (io_type == kOut ? "_out_" : "_in_") + std::to_string(index);
  103. }
  104. };
  105. class GraphUtils {
  106. public:
  107. static ComputeGraphPtr GetComputeGraph(const Graph &graph);
  108. static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph);
  109. static graphStatus RecoverGraphOperators(const Graph &graph);
  110. static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector<Operator> &inputs);
  111. static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);
  112. static graphStatus AddEdge(const OutDataAnchorPtr &src, const Format &src_format, const InDataAnchorPtr &dst,
  113. const Format &dst_format);
  114. static graphStatus AddEdge(const AnchorPtr &src, const AnchorPtr &dst);
  115. static graphStatus AddEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst);
  116. static graphStatus AddEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst);
  117. // check whether src is link to dst and then remove
  118. static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);
  119. static graphStatus RemoveEdge(const AnchorPtr &src, const AnchorPtr &dst);
  120. static graphStatus RemoveEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst);
  121. static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst);
  122. static graphStatus ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst,
  123. const InDataAnchorPtr &new_dst);
  124. static graphStatus ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst,
  125. const InControlAnchorPtr &new_dst);
  126. static graphStatus InsertNodeBetweenDataAnchors(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst,
  127. const NodePtr &new_node);
  128. static graphStatus RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node);
  129. static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor,
  130. const std::vector<OpDescPtr> &vec_op_desc);
  131. ///
  132. /// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst
  133. /// @param [in] src
  134. /// @param [in] dsts
  135. /// @param [in] insert_node
  136. /// @param [in] input_index
  137. /// @param [in] output_index
  138. /// @return graphStatus
  139. ///
  140. static graphStatus InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts,
  141. const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0);
  142. static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node);
  143. static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node);
  144. static void RecordOriginalNames(std::vector<ge::NodePtr> original_nodes, const ge::NodePtr &node);
  145. static void RecordOriginalNames(std::vector<std::string> names_tmp, const ge::NodePtr &node);
  146. static bool MatchDumpStr(const std::string &suffix);
  147. static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false);
  148. static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph);
  149. static void BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos);
  150. static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix);
  151. static bool LoadGEGraphFromOnnx(const char *file, ge::ComputeGraph &compute_graph);
  152. static bool ReadProtoFromTextFile(const char *file, google::protobuf::Message *message);
  153. static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *real_path);
  154. static graphStatus AppendInputNode(const ComputeGraphPtr &graph, const NodePtr &node);
  155. ///
  156. /// Isolating `node`, relinking data links from the in-anchor peer nodes to
  157. /// the out-anchor peer nodes according to `io_map`, relinking control links
  158. /// to ensure that input nodes of `node` are before out nodes
  159. ///
  160. /// Link the `io_map[i]` input anchor peer node to `i` output anchor peer
  161. /// nodes, then unlink all links connecting with `node`. If `io_map[i]` < 0,
  162. /// unlink all links from `i` output anchor without any relinking.
  163. ///
  164. /// @param node
  165. /// @param io_map
  166. /// @return
  167. ///
  168. static graphStatus IsolateNode(const NodePtr &node, const std::initializer_list<int> &io_map);
  169. static graphStatus IsolateNode(const NodePtr &node, const std::vector<int> &io_map);
  170. ///
  171. /// Isolate `node` which must be one input one output, equivalent to
  172. /// `IsolateNode(node, {0})`
  173. /// @param node
  174. /// @return
  175. ///
  176. static graphStatus IsolateNodeOneIO(const NodePtr &node);
  177. ///
  178. /// The data anchors replacing behavior is the same with
  179. /// `ReplaceNodeDataAnchors`. In addition, replace all `old_node` control
  180. /// anchors with `new_node`'s.
  181. /// @param new_node
  182. /// @param old_node
  183. /// @param inputs_map
  184. /// @param outputs_map
  185. /// @return
  186. ///
  187. static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node,
  188. std::initializer_list<int> inputs_map, std::initializer_list<int> outputs_map);
  189. static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node,
  190. const std::vector<int> &inputs_map, const std::vector<int> &outputs_map);
  191. ///
  192. /// Replace `old_node` data anchors with `new_node`'s according to `inputs_map` and `outputs_map`.
  193. /// Replace the `i` in/out data anchor on `old_node` with
  194. /// `inputs_map[i]`/`outputs_map[i]` data anchor on `new_node`.
  195. /// If `inputs_map[i]`/`outputs_map[i]` < 0 or the index not contained in
  196. /// `inputs_map[i]`/`outputs_map[i]`, the `i` data anchor will remain
  197. /// on `old_node`.
  198. /// @param new_node
  199. /// @param old_node
  200. /// @param inputs_map
  201. /// @param outputs_map
  202. /// @return
  203. ///
  204. static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node,
  205. std::initializer_list<int> inputs_map,
  206. std::initializer_list<int> outputs_map);
  207. static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node,
  208. const std::vector<int> &inputs_map, const std::vector<int> &outputs_map);
  209. ///
  210. /// Copy all in-control edges from `src_node` to `dst_node`
  211. /// @param src_node
  212. /// @param dst_node
  213. /// @return
  214. ///
  215. static graphStatus CopyInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node);
  216. static graphStatus MoveInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node);
  217. ///
  218. /// Copy all out-control edges from `src_node` to `dst_node`
  219. /// @param src_node
  220. /// @param dst_node
  221. /// @return success: GRAPH_SUCESS
  222. ///
  223. static graphStatus CopyOutCtrlEdges(const NodePtr &src_node, NodePtr &dst_node);
  224. ///
  225. /// Move all out-control edges from `src_node` to `dst_node`
  226. /// @param src_node
  227. /// @param dst_node
  228. /// @return success: GRAPH_SUCESS
  229. ///
  230. static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node);
  231. static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph);
  232. static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec);
  233. ///
  234. /// Get reference-mapping of all data_anchors in graph
  235. /// @param [in] graph
  236. /// @param [out] symbol_to_anchors
  237. /// @param [out] anchor_to_symbol
  238. /// @return success: GRAPH_SUCESS
  239. ///
  240. static graphStatus GetRefMapping(const ComputeGraphPtr &graph,
  241. std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors,
  242. std::map<std::string, std::string> &anchor_to_symbol);
  243. private:
  244. ///
  245. /// Get reference-mapping for in_data_anchors of node
  246. /// @param [in] node
  247. /// @param [out] symbol_to_anchors
  248. /// @param [out] anchor_to_symbol
  249. /// @return success: GRAPH_SUCESS
  250. ///
  251. static graphStatus HandleInAnchorMapping(const NodePtr &node,
  252. std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors,
  253. std::map<std::string, std::string> &anchor_to_symbol);
  254. ///
  255. /// Get reference-mapping for out_data_anchors of node
  256. /// @param [in] node
  257. /// @param [out] symbol_to_anchors
  258. /// @param [out] anchor_to_symbol
  259. /// @return success: GRAPH_SUCESS
  260. ///
  261. static graphStatus HandleOutAnchorMapping(const NodePtr &node,
  262. std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors,
  263. std::map<std::string, std::string> &anchor_to_symbol);
  264. ///
  265. /// Handle input of subgraph
  266. /// @param [in] node
  267. /// @param [out] symbol_to_anchors
  268. /// @param [out] anchor_to_symbol
  269. /// @return success: GRAPH_SUCESS
  270. ///
  271. static graphStatus HandleSubgraphInput(const NodePtr &node,
  272. std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors,
  273. std::map<std::string, std::string> &anchor_to_symbol);
  274. ///
  275. /// Handle input of Merge op
  276. /// @param [in] node
  277. /// @param [out] symbol_to_anchors
  278. /// @param [out] anchor_to_symbol
  279. /// @return success: GRAPH_SUCESS
  280. ///
  281. static graphStatus HandleMergeInput(const NodePtr &node,
  282. std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors,
  283. std::map<std::string, std::string> &anchor_to_symbol);
  284. ///
  285. /// Handle output of subgraph
  286. /// @param [in] node
  287. /// @param [out] symbol_to_anchors
  288. /// @param [out] anchor_to_symbol
  289. /// @return success: GRAPH_SUCESS
  290. ///
  291. static graphStatus HandleSubgraphOutput(const NodePtr &node,
  292. std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors,
  293. std::map<std::string, std::string> &anchor_to_symbol);
  294. ///
  295. /// Union ref-mapping
  296. /// @param [in] exist_node_info1
  297. /// @param [in] exist_node_info2
  298. /// @param [out] symbol_to_anchors
  299. /// @param [out] anchor_to_symbol
  300. /// @param [out] symbol
  301. /// @return success: GRAPH_SUCESS
  302. ///
  303. static graphStatus UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2,
  304. std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors,
  305. std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol);
  306. ///
  307. /// Update symbol mapping with a new reference pair
  308. /// @param [in] cur_node_info
  309. /// @param [in] exist_node_info
  310. /// @param [out] symbol_to_anchors
  311. /// @param [out] anchor_to_symbol
  312. /// @return success: GRAPH_SUCESS
  313. ///
  314. static graphStatus UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info,
  315. std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors,
  316. std::map<std::string, std::string> &anchor_to_symbol);
  317. ///
  318. /// Check if out_data_anchor is reference of input
  319. /// @param [in] out_data_anchor
  320. /// @param [out] reuse_in_index
  321. /// @return bool
  322. ///
  323. static bool IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index);
  324. };
  325. class ComputeGraphBuilder {
  326. public:
  327. ComputeGraphBuilder() : owner_graph_(nullptr) {}
  328. ComputeGraphBuilder(const ComputeGraphBuilder &) = delete;
  329. ComputeGraphBuilder &operator=(const ComputeGraphBuilder &) = delete;
  330. ComputeGraphBuilder(const ComputeGraphBuilder &&) = delete;
  331. ComputeGraphBuilder &operator=(const ComputeGraphBuilder &&) = delete;
  332. ~ComputeGraphBuilder() = default;
  333. ///
  334. /// @brief Add node to graph
  335. /// @param [in] op_desc
  336. /// @return ComputeGraphBuilder
  337. ///
  338. virtual ComputeGraphBuilder &AddNode(const OpDescPtr &op_desc);
  339. ///
  340. /// @brief Add data-link among nodes in graph
  341. /// @param [in] src_name
  342. /// @param [in] out_anchor_ind
  343. /// @param [in] dst_name
  344. /// @param [in] in_anchor_ind
  345. /// @return ComputeGraphBuilder
  346. ///
  347. virtual ComputeGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind,
  348. const std::string &dst_name, uint32_t in_anchor_ind);
  349. ///
  350. /// @brief Add ctrl-link among nodes in graph
  351. /// @param [in] src_name
  352. /// @param [in] dst_name
  353. /// @return ComputeGraphBuilder
  354. ///
  355. virtual ComputeGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name);
  356. ///
  357. /// @brief Build graph
  358. /// @param [out] error_code
  359. /// @param [out] error_msg
  360. /// @return ComputeGraphPtr
  361. ///
  362. virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) = 0;
  363. /// @brief Get node with name
  364. /// @param [in] name
  365. /// @return NodePtr
  366. ///
  367. NodePtr GetNode(const std::string &name);
  368. protected:
  369. ///
  370. /// @brief Build nodes
  371. /// @param [out] error_code
  372. /// @param [out] error_msg
  373. /// @return void
  374. ///
  375. void BuildNodes(graphStatus &error_code, std::string &error_msg);
  376. ///
  377. /// @brief Build data-links
  378. /// @param [out] error_code
  379. /// @param [out] error_msg
  380. /// @return void
  381. ///
  382. void BuildDataLinks(graphStatus &error_code, std::string &error_msg);
  383. ///
  384. /// @brief Build ctrl-links
  385. /// @param [out] error_code
  386. /// @param [out] error_msg
  387. /// @return void
  388. ///
  389. void BuildCtrlLinks(graphStatus &error_code, std::string &error_msg);
  390. ComputeGraphPtr owner_graph_;
  391. // node_name -> node
  392. std::map<std::string, NodePtr> node_names_;
  393. std::vector<OpDescPtr> nodes_;
  394. // <src_node_name, out_anchor_ind> -> <dst_node_name, in_anchor_ind>
  395. std::vector<std::pair<std::pair<std::string, uint32_t>, std::pair<std::string, uint32_t>>> data_links_;
  396. // src_node_name -> dst_node_name
  397. std::vector<std::pair<std::string, std::string>> ctrl_links_;
  398. };
  399. class CompleteGraphBuilder : public ComputeGraphBuilder {
  400. public:
  401. explicit CompleteGraphBuilder(std::string name) : name_(std::move(name)), parent_node_(nullptr) {}
  402. CompleteGraphBuilder(const CompleteGraphBuilder &) = delete;
  403. CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete;
  404. CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete;
  405. CompleteGraphBuilder &operator=(const CompleteGraphBuilder &&) = delete;
  406. ~CompleteGraphBuilder() = default;
  407. ///
  408. /// @brief Add node to graph
  409. /// @param [in] op_desc
  410. /// @return CompleteGraphBuilder
  411. ///
  412. CompleteGraphBuilder &AddNode(const OpDescPtr &op_desc) override;
  413. ///
  414. /// @brief Add data-link among nodes in graph
  415. /// @param [in] src_name
  416. /// @param [in] out_anchor_ind
  417. /// @param [in] dst_name
  418. /// @param [in] in_anchor_ind
  419. /// @return CompleteGraphBuilder
  420. ///
  421. CompleteGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name,
  422. uint32_t in_anchor_ind) override;
  423. ///
  424. /// @brief Add ctrl-link among nodes in graph
  425. /// @param [in] src_name
  426. /// @param [in] dst_name
  427. /// @return CompleteGraphBuilder
  428. ///
  429. CompleteGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override;
  430. ///
  431. /// @brief Set index_th input anchor for graph
  432. /// @param [in] index
  433. /// @param [in] node_names
  434. /// @param [in] anchor_inds
  435. /// @return CompleteGraphBuilder
  436. ///
  437. CompleteGraphBuilder &SetInput(uint32_t index, const std::vector<std::string> &node_names,
  438. const std::vector<uint32_t> &anchor_inds);
  439. ///
  440. /// @brief Set index_th input of graph as useless
  441. /// @param [in] index
  442. /// @return CompleteGraphBuilder
  443. ///
  444. CompleteGraphBuilder &SetUselessInput(uint32_t index);
  445. ///
  446. /// @brief Add output anchor for graph
  447. /// @param [in] owner_node_name
  448. /// @param [in] anchor_ind
  449. /// @return CompleteGraphBuilder
  450. ///
  451. CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind);
  452. ///
  453. /// @brief Set parent-node of graph
  454. /// @param [in] parent_node
  455. /// @return CompleteGraphBuilder
  456. ///
  457. CompleteGraphBuilder &SetParentNode(const NodePtr &parent_node);
  458. ///
  459. /// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node
  460. /// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node
  461. /// @return CompleteGraphBuilder
  462. ///
  463. CompleteGraphBuilder &SetInputMapping(const std::map<uint32_t, uint32_t> &input_mapping);
  464. ///
  465. /// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind
  466. /// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node
  467. /// @return CompleteGraphBuilder
  468. ///
  469. CompleteGraphBuilder &SetOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping);
  470. ///
  471. /// @brief Build graph
  472. /// @param [out] error_code
  473. /// @param [out] error_msg
  474. /// @return ComputeGraphPtr
  475. ///
  476. ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override;
  477. private:
  478. ///
  479. /// @brief Add data nodes
  480. /// @param [out] error_code
  481. /// @param [out] error_msg
  482. /// @return void
  483. ///
  484. void AddDataNodes(graphStatus &error_code, std::string &error_msg);
  485. ///
  486. /// @brief Add data node
  487. /// @param [in] index
  488. /// @param [out] error_code
  489. /// @param [out] error_msg
  490. /// @return void
  491. ///
  492. NodePtr AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg);
  493. ///
  494. /// @brief Add RetVal nodes
  495. /// @param [out] error_code
  496. /// @param [out] error_msg
  497. /// @return void
  498. ///
  499. void AddRetValNodes(graphStatus &error_code, std::string &error_msg);
  500. std::string name_;
  501. NodePtr parent_node_;
  502. std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_;
  503. std::vector<std::pair<std::string, uint32_t>> graph_outputs_;
  504. // index_of_graph_input -> in_anchor_index_of_parent_node
  505. std::map<uint32_t, uint32_t> input_mapping_;
  506. // index_of_graph_output -> out_anchor_index_of_parent_node
  507. std::map<uint32_t, uint32_t> output_mapping_;
  508. };
  509. class PartialGraphBuilder : public ComputeGraphBuilder {
  510. public:
  511. PartialGraphBuilder() = default;
  512. PartialGraphBuilder(const PartialGraphBuilder &) = delete;
  513. PartialGraphBuilder &operator=(const PartialGraphBuilder &) = delete;
  514. PartialGraphBuilder(const PartialGraphBuilder &&) = delete;
  515. PartialGraphBuilder &operator=(const PartialGraphBuilder &&) = delete;
  516. ~PartialGraphBuilder() = default;
  517. ///
  518. /// @brief Add node to graph
  519. /// @param [in] op_desc
  520. /// @return PartialGraphBuilder
  521. ///
  522. PartialGraphBuilder &AddNode(const OpDescPtr &op_desc) override;
  523. ///
  524. /// @brief Add data-link among nodes in graph
  525. /// @param [in] src_name
  526. /// @param [in] out_anchor_ind
  527. /// @param [in] dst_name
  528. /// @param [in] in_anchor_ind
  529. /// @return PartialGraphBuilder
  530. ///
  531. PartialGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name,
  532. uint32_t in_anchor_ind) override;
  533. ///
  534. /// @brief Add ctrl-link among nodes in graph
  535. /// @param [in] src_name
  536. /// @param [in] dst_name
  537. /// @return PartialGraphBuilder
  538. ///
  539. PartialGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override;
  540. ///
  541. /// @brief Set owner graph
  542. /// @param [in] graph
  543. /// @return PartialGraphBuilder
  544. ///
  545. PartialGraphBuilder &SetOwnerGraph(const ComputeGraphPtr &graph);
  546. ///
  547. /// @brief Add exist node
  548. /// @param [in] node
  549. /// @return PartialGraphBuilder
  550. ///
  551. PartialGraphBuilder &AddExistNode(const NodePtr &node);
  552. ///
  553. /// @brief Build multi nodes with links
  554. /// @param [out] error_code
  555. /// @param [out] error_msg
  556. /// @return ComputeGraphPtr
  557. ///
  558. ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override;
  559. private:
  560. ///
  561. /// @brief Build exist nodes
  562. /// @param [out] error_code
  563. /// @param [out] error_msg
  564. /// @return void
  565. ///
  566. void BuildExistNodes(graphStatus &error_code, std::string &error_msg);
  567. std::vector<NodePtr> exist_nodes_;
  568. };
  569. } // namespace ge
  570. #endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：