You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_utils.h 32 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_GRAPH_UTILS_GRAPH_UTILS_H_
  17. #define INC_GRAPH_UTILS_GRAPH_UTILS_H_
  18. #include <fstream>
  19. #include <iostream>
  20. #include <list>
  21. #include <map>
  22. #include <string>
  23. #include <unordered_map>
  24. #include <vector>
  25. #include "graph/anchor.h"
  26. #include "graph/compute_graph.h"
  27. #include "graph/graph.h"
  28. #include "graph/model.h"
  29. #include "graph/node.h"
  30. #include "graph/utils/anchor_utils.h"
  /// Dump `compute_graph` and each of its subgraphs through both GraphUtils::DumpGEGraph and
  /// GraphUtils::DumpGEGraphToOnnx. The root graph is dumped with suffix `name`; the k-th entry of
  /// GetAllSubgraphs() (k counted from 0) is dumped with suffix "<name>_sub_graph_<k>".
  ///
  /// Each macro argument is evaluated exactly once. All locals carry a `ge_dump_` prefix so that a
  /// caller expression such as `names[i]` is not captured by the macro's own loop counter (which
  /// would also read and increment that counter unsequenced in one expression).
  #define GE_DUMP(compute_graph, name)                                                          \
    do {                                                                                        \
      const auto &ge_dump_root_graph_ = (compute_graph);                                        \
      const std::string ge_dump_base_name_ = std::string(name);                                 \
      GraphUtils::DumpGEGraph(ge_dump_root_graph_, ge_dump_base_name_);                         \
      GraphUtils::DumpGEGraphToOnnx(*ge_dump_root_graph_, ge_dump_base_name_);                  \
      uint64_t ge_dump_sub_index_ = 0;                                                          \
      for (const auto &ge_dump_sub_graph_ : ge_dump_root_graph_->GetAllSubgraphs()) {           \
        const std::string ge_dump_sub_name_ =                                                   \
            ge_dump_base_name_ + "_sub_graph_" + std::to_string(ge_dump_sub_index_++);          \
        GraphUtils::DumpGEGraph(ge_dump_sub_graph_, ge_dump_sub_name_);                         \
        GraphUtils::DumpGEGraphToOnnx(*ge_dump_sub_graph_, ge_dump_sub_name_);                  \
      }                                                                                         \
    } while (0)
  /// Declare `DataType ret;` in the ENCLOSING scope and fill it with `attr.GetValue<DataType>(ret)`.
  /// The expansion already ends with ';'. `VT_ENUM` is unused and kept only for call compatibility.
  /// Why no do { } while (0): that wrapper would end the lifetime of `ret` at its closing brace,
  /// so code following the macro (e.g. `stream << ret`) could never see the value.
  #define REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \
    DataType ret;                                        \
    (attr).GetValue<DataType>(ret);

  /// If `value_type == VT_ENUM`, read the attribute as `DataType` and write it to `stream`.
  /// Expands to a bare `if (...) { ... }` (no do/while, no trailing ';') so that the *_ELIF
  /// macros below can chain onto it as `else if`. Do NOT put a ';' between an IF and a following
  /// ELIF, and do not use it as the sole body of an outer `if` that has its own `else`
  /// (the `else` would bind to this inner `if`).
  #define PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \
    if ((value_type) == (VT_ENUM)) {                                       \
      REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ge_attr_ret_)              \
      (stream) << ge_attr_ret_;                                            \
    }

  /// List variant of PRINT_ATTR_VALUE_IF: writes "[e0, e1, ...]" ("[]" for an empty list).
  /// Same chaining rules as PRINT_ATTR_VALUE_IF. The index uses the container's own size type
  /// to avoid a signed/unsigned comparison.
  #define PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream)                 \
    if ((value_type) == (VT_ENUM)) {                                                            \
      REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ge_attr_ret_)                                   \
      (stream) << "[";                                                                          \
      for (decltype(ge_attr_ret_.size()) ge_attr_idx_ = 0; ge_attr_idx_ < ge_attr_ret_.size();  \
           ++ge_attr_idx_) {                                                                    \
        (stream) << ge_attr_ret_[ge_attr_idx_];                                                 \
        if (ge_attr_idx_ + 1 != ge_attr_ret_.size()) {                                          \
          (stream) << ", ";                                                                     \
        }                                                                                       \
      }                                                                                         \
      (stream) << "]";                                                                          \
    }

  /// `else`-chained form of PRINT_ATTR_VALUE_IF; must directly follow an IF/ELIF expansion.
  #define PRINT_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \
    else PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream)

  /// `else`-chained form of PRINT_LIST_ATTR_VALUE_IF; must directly follow an IF/ELIF expansion.
  #define PRINT_LIST_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \
    else PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream)
  /// Write the dims of tensor `idx` of node `n` to `stream` as "[d0, d1, ...]" ("[]" when the dims
  /// list is empty). The input tensor is used when `i_o` equals "input", otherwise the output one.
  ///  - `i_o` may be a string literal, `const char *` or std::string and is compared by content;
  ///    `==` on two `const char *` compares addresses, which is unspecified for string literals.
  ///  - Locals use a `ge_print_` prefix so caller arguments (e.g. a node variable named `op`)
  ///    are not captured by the macro's own declarations.
  #define PRINT_SHAPE(i_o, n, idx, stream)                                                       \
    do {                                                                                         \
      const auto ge_print_op_desc_ = (n)->GetOpDesc();                                           \
      GeTensorDesc ge_print_tensor_desc_ = (std::string(i_o) == "input")                         \
                                               ? ge_print_op_desc_->GetInputDesc(idx)            \
                                               : ge_print_op_desc_->GetOutputDesc(idx);          \
      const auto ge_print_dims_ = ge_print_tensor_desc_.GetShape().GetDims();                    \
      (stream) << "[";                                                                           \
      for (decltype(ge_print_dims_.size()) ge_print_i_ = 0; ge_print_i_ < ge_print_dims_.size(); \
           ++ge_print_i_) {                                                                      \
        (stream) << ge_print_dims_[ge_print_i_];                                                 \
        if (ge_print_i_ + 1 < ge_print_dims_.size()) {                                           \
          (stream) << ", ";                                                                      \
        }                                                                                        \
      }                                                                                          \
      (stream) << "]";                                                                           \
    } while (0)
  /// Expands to a lambda `[&](GeAttrValue)` that writes a readable form of the attribute's value
  /// to `stream` (captured by reference):
  ///  - STRING / FLOAT / BOOL / INT: the value itself;
  ///  - LIST_STRING / LIST_FLOAT / LIST_BOOL / LIST_INT: "[e0, e1, ...]";
  ///  - TENSOR_DESC / TENSOR / BYTES and their LIST_ forms: only the type tag, e.g. "TENSOR";
  ///  - any other value type: nothing.
  /// The macro is an expression with no trailing ';', so it can be stored
  /// (`auto print_attr = PRINT_ATTR_FUNC(ss);`) or passed directly as an argument.
  /// It relies on PRINT_ATTR_VALUE_IF / PRINT_*_ELIF expanding to a chainable if / else-if.
  /// Locals use a `ge_print_` prefix so they cannot capture a caller's `stream` expression.
  #define PRINT_ATTR_FUNC(stream)                                                                        \
    [&](GeAttrValue ge_print_attr_) {                                                                    \
      const auto ge_print_type_ = ge_print_attr_.GetValueType();                                         \
      PRINT_ATTR_VALUE_IF(ge_print_type_, GeAttrValue::ValueType::VT_STRING, GeAttrValue::STR,           \
                          ge_print_attr_, stream)                                                        \
      PRINT_ATTR_VALUE_ELIF(ge_print_type_, GeAttrValue::ValueType::VT_FLOAT, GeAttrValue::FLOAT,        \
                            ge_print_attr_, stream)                                                      \
      PRINT_ATTR_VALUE_ELIF(ge_print_type_, GeAttrValue::ValueType::VT_BOOL, GeAttrValue::BOOL,          \
                            ge_print_attr_, stream)                                                      \
      PRINT_ATTR_VALUE_ELIF(ge_print_type_, GeAttrValue::ValueType::VT_INT, GeAttrValue::INT,            \
                            ge_print_attr_, stream)                                                      \
      PRINT_LIST_ATTR_VALUE_ELIF(ge_print_type_, GeAttrValue::ValueType::VT_LIST_STRING,                 \
                                 GeAttrValue::LIST_STR, ge_print_attr_, stream)                          \
      PRINT_LIST_ATTR_VALUE_ELIF(ge_print_type_, GeAttrValue::ValueType::VT_LIST_FLOAT,                  \
                                 GeAttrValue::LIST_FLOAT, ge_print_attr_, stream)                        \
      PRINT_LIST_ATTR_VALUE_ELIF(ge_print_type_, GeAttrValue::ValueType::VT_LIST_BOOL,                   \
                                 GeAttrValue::LIST_BOOL, ge_print_attr_, stream)                         \
      PRINT_LIST_ATTR_VALUE_ELIF(ge_print_type_, GeAttrValue::ValueType::VT_LIST_INT,                    \
                                 GeAttrValue::LIST_INT, ge_print_attr_, stream)                          \
      else if (ge_print_type_ == GeAttrValue::ValueType::VT_TENSOR_DESC) {                               \
        (stream) << "TENSOR_DESC";                                                                       \
      } else if (ge_print_type_ == GeAttrValue::ValueType::VT_TENSOR) {                                  \
        (stream) << "TENSOR";                                                                            \
      } else if (ge_print_type_ == GeAttrValue::ValueType::VT_BYTES) {                                   \
        (stream) << "BYTES";                                                                             \
      } else if (ge_print_type_ == GeAttrValue::ValueType::VT_LIST_TENSOR_DESC) {                        \
        (stream) << "LIST_TENSOR_DESC";                                                                  \
      } else if (ge_print_type_ == GeAttrValue::ValueType::VT_LIST_TENSOR) {                             \
        (stream) << "LIST_TENSOR";                                                                       \
      } else if (ge_print_type_ == GeAttrValue::ValueType::VT_LIST_BYTES) {                              \
        (stream) << "LIST_BYTES";                                                                        \
      }                                                                                                  \
    }
  100. namespace ge {
  101. enum IOType { kIn, kOut };
  102. struct NodeIndexIO {
  103. NodeIndexIO(ge::NodePtr node, uint32_t index, IOType io_type)
  104. : node_(std::move(node)), index_(index), io_type_(io_type) {
  105. if (node_ != nullptr) {
  106. value_ = node_->GetName() + (io_type_ == kOut ? "_out_" : "_in_") + std::to_string(index_);
  107. }
  108. }
  109. NodeIndexIO(ge::NodePtr node, int index, IOType io_type)
  110. : node_(std::move(node)), index_(static_cast<uint32_t>(index)), io_type_(io_type) {
  111. if (node_ != nullptr) {
  112. value_ = node_->GetName() + (io_type_ == kOut ? "_out_" : "_in_") + std::to_string(index_);
  113. }
  114. }
  115. ~NodeIndexIO() {}
  116. NodePtr node_ = nullptr;
  117. uint32_t index_ = 0;
  118. IOType io_type_ = kOut;
  119. std::string value_;
  120. const std::string &ToString() const { return value_; }
  121. };
  122. class GraphUtils {
  123. public:
  124. static ComputeGraphPtr GetComputeGraph(const Graph &graph);
  125. static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph);
  126. static GraphPtr CreateGraphPtrFromComputeGraph(const ComputeGraphPtr compute_graph);
  127. static graphStatus RecoverGraphOperators(const Graph &graph);
  128. static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector<Operator> &inputs);
  129. static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);
  130. static graphStatus AddEdge(const OutDataAnchorPtr &src, const Format &src_format, const InDataAnchorPtr &dst,
  131. const Format &dst_format);
  132. static graphStatus AddEdge(const AnchorPtr &src, const AnchorPtr &dst);
  133. static graphStatus AddEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst);
  134. static graphStatus AddEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst);
  135. // check whether src is link to dst and then remove
  136. static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);
  137. static graphStatus RemoveEdge(const AnchorPtr &src, const AnchorPtr &dst);
  138. static graphStatus RemoveEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst);
  139. static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst);
  140. static graphStatus ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst,
  141. const InDataAnchorPtr &new_dst);
  142. static graphStatus ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst,
  143. const InControlAnchorPtr &new_dst);
  144. static graphStatus InsertNodeBetweenDataAnchors(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst,
  145. const NodePtr &new_node);
  146. static graphStatus RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node);
  147. static graphStatus RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node);
  148. static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor,
  149. const std::vector<OpDescPtr> &vec_op_desc);
  150. ///
  151. /// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst
  152. /// @param [in] src
  153. /// @param [in] dsts
  154. /// @param [in] insert_node
  155. /// @param [in] input_index
  156. /// @param [in] output_index
  157. /// @return graphStatus
  158. ///
  159. static graphStatus InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts,
  160. const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0);
  161. static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node);
  162. static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node);
  163. static void RecordOriginalNames(std::vector<ge::NodePtr> original_nodes, const ge::NodePtr &node);
  164. static void RecordOriginalNames(std::vector<std::string> names_tmp, const ge::NodePtr &node);
  165. static bool MatchDumpStr(const std::string &suffix);
  166. static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false,
  167. const std::string &user_graph_name = "");
  168. static void DumpGEGrph(const ge::ComputeGraphPtr &graph, const std::string &path, const std::string &suffix);
  169. static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph);
  170. static bool LoadGEGraph(const char *file, ge::ComputeGraphPtr &compute_graph);
  171. static void BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos);
  172. static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix);
  173. static void DumpGrphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &path, const std::string &suffix);
  174. static bool LoadGEGraphFromOnnx(const char *file, ge::ComputeGraph &compute_graph);
  175. static bool ReadProtoFromTextFile(const char *file, google::protobuf::Message *message);
  176. static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *real_path);
  177. static graphStatus AppendInputNode(const ComputeGraphPtr &graph, const NodePtr &node);
  178. ///
  179. /// Isolating `node`, relinking data links from the in-anchor peer nodes to
  180. /// the out-anchor peer nodes according to `io_map`, relinking control links
  181. /// to ensure that input nodes of `node` are before out nodes
  182. ///
  183. /// Link the `io_map[i]` input anchor peer node to `i` output anchor peer
  184. /// nodes, then unlink all links connecting with `node`. If `io_map[i]` < 0,
  185. /// unlink all links from `i` output anchor without any relinking.
  186. ///
  187. /// @param node
  188. /// @param io_map
  189. /// @return
  190. ///
  191. static graphStatus IsolateNode(const NodePtr &node, const std::initializer_list<int> &io_map);
  192. static graphStatus IsolateNode(const NodePtr &node, const std::vector<int> &io_map);
  193. ///
  194. /// Isolate `node` which must be one input one output, equivalent to
  195. /// `IsolateNode(node, {0})`
  196. /// @param node
  197. /// @return
  198. ///
  199. static graphStatus IsolateNodeOneIO(const NodePtr &node);
  200. ///
  201. /// The data anchors replacing behavior is the same with
  202. /// `ReplaceNodeDataAnchors`. In addition, replace all `old_node` control
  203. /// anchors with `new_node`'s.
  204. /// @param new_node
  205. /// @param old_node
  206. /// @param inputs_map
  207. /// @param outputs_map
  208. /// @return
  209. ///
  210. static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node,
  211. std::initializer_list<int> inputs_map, std::initializer_list<int> outputs_map);
  212. static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node,
  213. const std::vector<int> &inputs_map, const std::vector<int> &outputs_map);
  214. ///
  215. /// Replace `old_node` data anchors with `new_node`'s according to `inputs_map` and `outputs_map`.
  216. /// Replace the `i` in/out data anchor on `old_node` with
  217. /// `inputs_map[i]`/`outputs_map[i]` data anchor on `new_node`.
  218. /// If `inputs_map[i]`/`outputs_map[i]` < 0 or the index not contained in
  219. /// `inputs_map[i]`/`outputs_map[i]`, the `i` data anchor will remain
  220. /// on `old_node`.
  221. /// @param new_node
  222. /// @param old_node
  223. /// @param inputs_map
  224. /// @param outputs_map
  225. /// @return
  226. ///
  227. static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node,
  228. std::initializer_list<int> inputs_map,
  229. std::initializer_list<int> outputs_map);
  230. static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node,
  231. const std::vector<int> &inputs_map, const std::vector<int> &outputs_map);
  232. ///
  233. /// Copy all in-control edges from `src_node` to `dst_node`
  234. /// @param src_node
  235. /// @param dst_node
  236. /// @return
  237. ///
  238. static graphStatus CopyInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node);
  239. static graphStatus MoveInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node);
  240. ///
  241. /// Copy all out-control edges from `src_node` to `dst_node`
  242. /// @param src_node
  243. /// @param dst_node
  244. /// @return success: GRAPH_SUCESS
  245. ///
  246. static graphStatus CopyOutCtrlEdges(const NodePtr &src_node, NodePtr &dst_node);
  247. ///
  248. /// Move all out-control edges from `src_node` to `dst_node`
  249. /// @param src_node
  250. /// @param dst_node
  251. /// @return success: GRAPH_SUCESS
  252. ///
  253. static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node);
  254. ///
  255. /// Copy all in-data edges from `src_node` to `dst_node`
  256. /// @param src_node
  257. /// @param dst_node
  258. /// @return
  259. ///
  260. static graphStatus CopyInDataEdges(const NodePtr &src_node, NodePtr &dst_node);
  261. static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph);
  262. ///
  263. /// Make a copy of ComputeGraph.
  264. /// @param graph: original graph.
  265. /// @param prefix: node name prefix of new graph.
  266. /// @return ComputeGraphPtr
  267. ///
  268. static ComputeGraphPtr CloneGraph(const ComputeGraphPtr &graph, const string &prefix,
  269. std::vector<NodePtr> &input_nodes, std::vector<NodePtr> &output_nodes);
  270. ///
  271. /// Copy tensor attribute to new node.
  272. /// @param [in] dst_desc: cloned node.
  273. /// @param [in] src_node: original node.
  274. /// @return success: GRAPH_SUCESS
  275. ///
  276. static graphStatus CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node);
  277. static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec);
  278. ///
  279. /// Get reference-mapping of all data_anchors in graph
  280. /// @param [in] graph
  281. /// @param [out] symbol_to_anchors
  282. /// @param [out] anchor_to_symbol
  283. /// @return success: GRAPH_SUCESS
  284. ///
  285. static graphStatus GetRefMapping(const ComputeGraphPtr &graph,
  286. std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
  287. std::map<std::string, std::string> &anchor_to_symbol);
  288. ///
  289. /// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs
  290. /// of the graph have UNKNOWN_SHAPE operators or not.
  291. /// Note: This function will only look 'down' from the graph, not 'up'. For example, the following
  292. /// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE
  293. /// ROOT graph: A -----> B -----> C
  294. /// K subgraph U
  295. /// |
  296. /// V
  297. /// SUB graph: D --> E --> F
  298. /// K K K
  299. /// @param [in] graph
  300. /// @return bool
  301. ///
  302. static bool IsUnknownShapeGraph(const ComputeGraphPtr &graph);
  303. static NodePtr FindNodeFromAllNodes(ComputeGraphPtr &graph, const std::string &name);
  304. private:
  305. ///
  306. /// Get reference-mapping for in_data_anchors of node
  307. /// @param [in] node
  308. /// @param [out] symbol_to_anchors
  309. /// @param [out] anchor_to_symbol
  310. /// @return success: GRAPH_SUCESS
  311. ///
  312. static graphStatus HandleInAnchorMapping(const NodePtr &node,
  313. std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
  314. std::map<std::string, std::string> &anchor_to_symbol);
  315. ///
  316. /// Get reference-mapping for out_data_anchors of node
  317. /// @param [in] node
  318. /// @param [out] symbol_to_anchors
  319. /// @param [out] anchor_to_symbol
  320. /// @return success: GRAPH_SUCESS
  321. ///
  322. static graphStatus HandleOutAnchorMapping(const NodePtr &node,
  323. std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
  324. std::map<std::string, std::string> &anchor_to_symbol);
  325. ///
  326. /// Handle input of subgraph
  327. /// @param [in] node
  328. /// @param [out] symbol_to_anchors
  329. /// @param [out] anchor_to_symbol
  330. /// @return success: GRAPH_SUCESS
  331. ///
  332. static graphStatus HandleSubgraphInput(const NodePtr &node,
  333. std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
  334. std::map<std::string, std::string> &anchor_to_symbol);
  335. ///
  336. /// Handle input of Merge op
  337. /// @param [in] node
  338. /// @param [out] symbol_to_anchors
  339. /// @param [out] anchor_to_symbol
  340. /// @return success: GRAPH_SUCESS
  341. ///
  342. static graphStatus HandleMergeInput(const NodePtr &node,
  343. std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
  344. std::map<std::string, std::string> &anchor_to_symbol);
  345. ///
  346. /// Handle output of subgraph
  347. /// @param [in] node
  348. /// @param [out] symbol_to_anchors
  349. /// @param [out] anchor_to_symbol
  350. /// @return success: GRAPH_SUCESS
  351. ///
  352. static graphStatus HandleSubgraphOutput(const NodePtr &node,
  353. std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
  354. std::map<std::string, std::string> &anchor_to_symbol);
  355. ///
  356. /// Relink all edges for cloned ComputeGraph.
  357. /// @param [in] node: original node.
  358. /// @param [in] prefix: node name prefix of new node.
  359. /// @param [in] all_nodes: all nodes in new graph.
  360. /// @return success: GRAPH_SUCESS
  361. ///
  362. static graphStatus RelinkGraphEdges(const NodePtr &node, const string &prefix,
  363. const std::unordered_map<string, NodePtr> &all_nodes);
  364. ///
  365. /// Union ref-mapping
  366. /// @param [in] exist_node_info1
  367. /// @param [in] exist_node_info2
  368. /// @param [out] symbol_to_anchors
  369. /// @param [out] anchor_to_symbol
  370. /// @param [out] symbol
  371. /// @return success: GRAPH_SUCESS
  372. ///
  373. static graphStatus UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2,
  374. std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
  375. std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol);
  376. ///
  377. /// Update symbol mapping with a new reference pair
  378. /// @param [in] cur_node_info
  379. /// @param [in] exist_node_info
  380. /// @param [out] symbol_to_anchors
  381. /// @param [out] anchor_to_symbol
  382. /// @return success: GRAPH_SUCESS
  383. ///
  384. static graphStatus UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info,
  385. std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
  386. std::map<std::string, std::string> &anchor_to_symbol);
  387. ///
  388. /// Check if out_data_anchor is reference of input
  389. /// @param [in] out_data_anchor
  390. /// @param [out] reuse_in_index
  391. /// @return bool
  392. ///
  393. static bool IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index);
  394. };
  395. class ComputeGraphBuilder {
  396. public:
  397. ComputeGraphBuilder() : owner_graph_(nullptr) {}
  398. ComputeGraphBuilder(const ComputeGraphBuilder &) = delete;
  399. ComputeGraphBuilder &operator=(const ComputeGraphBuilder &) = delete;
  400. ComputeGraphBuilder(const ComputeGraphBuilder &&) = delete;
  401. ComputeGraphBuilder &operator=(const ComputeGraphBuilder &&) = delete;
  402. ~ComputeGraphBuilder() = default;
  403. ///
  404. /// @brief Add node to graph
  405. /// @param [in] op_desc
  406. /// @return ComputeGraphBuilder
  407. ///
  408. virtual ComputeGraphBuilder &AddNode(const OpDescPtr &op_desc);
  409. ///
  410. /// @brief Add data-link among nodes in graph
  411. /// @param [in] src_name
  412. /// @param [in] out_anchor_ind
  413. /// @param [in] dst_name
  414. /// @param [in] in_anchor_ind
  415. /// @return ComputeGraphBuilder
  416. ///
  417. virtual ComputeGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind,
  418. const std::string &dst_name, uint32_t in_anchor_ind);
  419. ///
  420. /// @brief Add ctrl-link among nodes in graph
  421. /// @param [in] src_name
  422. /// @param [in] dst_name
  423. /// @return ComputeGraphBuilder
  424. ///
  425. virtual ComputeGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name);
  426. ///
  427. /// @brief Build graph
  428. /// @param [out] error_code
  429. /// @param [out] error_msg
  430. /// @return ComputeGraphPtr
  431. ///
  432. virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) = 0;
  433. /// @brief Get node with name
  434. /// @param [in] name
  435. /// @return NodePtr
  436. ///
  437. NodePtr GetNode(const std::string &name);
  438. /// @brief Get all nodes
  439. /// @return std::vector<NodePtr>
  440. ///
  441. std::vector<NodePtr> GetAllNodes();
  442. protected:
  443. ///
  444. /// @brief Build nodes
  445. /// @param [out] error_code
  446. /// @param [out] error_msg
  447. /// @return void
  448. ///
  449. void BuildNodes(graphStatus &error_code, std::string &error_msg);
  450. ///
  451. /// @brief Build data-links
  452. /// @param [out] error_code
  453. /// @param [out] error_msg
  454. /// @return void
  455. ///
  456. void BuildDataLinks(graphStatus &error_code, std::string &error_msg);
  457. ///
  458. /// @brief Build ctrl-links
  459. /// @param [out] error_code
  460. /// @param [out] error_msg
  461. /// @return void
  462. ///
  463. void BuildCtrlLinks(graphStatus &error_code, std::string &error_msg);
  464. ComputeGraphPtr owner_graph_;
  465. // node_name -> node
  466. std::map<std::string, NodePtr> node_names_;
  467. std::vector<OpDescPtr> nodes_;
  468. // <src_node_name, out_anchor_ind> -> <dst_node_name, in_anchor_ind>
  469. std::vector<std::pair<std::pair<std::string, uint32_t>, std::pair<std::string, uint32_t>>> data_links_;
  470. // src_node_name -> dst_node_name
  471. std::vector<std::pair<std::string, std::string>> ctrl_links_;
  472. };
  473. class CompleteGraphBuilder : public ComputeGraphBuilder {
  474. public:
  475. explicit CompleteGraphBuilder(std::string name, bool retval_flag = true)
  476. : name_(std::move(name)), parent_node_(nullptr), retval_flag_(retval_flag) {}
  477. CompleteGraphBuilder(const CompleteGraphBuilder &) = delete;
  478. CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete;
  479. CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete;
  480. CompleteGraphBuilder &operator=(const CompleteGraphBuilder &&) = delete;
  481. ~CompleteGraphBuilder() = default;
  482. ///
  483. /// @brief Add node to graph
  484. /// @param [in] op_desc
  485. /// @return CompleteGraphBuilder
  486. ///
  487. CompleteGraphBuilder &AddNode(const OpDescPtr &op_desc) override;
  488. ///
  489. /// @brief Add data-link among nodes in graph
  490. /// @param [in] src_name
  491. /// @param [in] out_anchor_ind
  492. /// @param [in] dst_name
  493. /// @param [in] in_anchor_ind
  494. /// @return CompleteGraphBuilder
  495. ///
  496. CompleteGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name,
  497. uint32_t in_anchor_ind) override;
  498. ///
  499. /// @brief Add ctrl-link among nodes in graph
  500. /// @param [in] src_name
  501. /// @param [in] dst_name
  502. /// @return CompleteGraphBuilder
  503. ///
  504. CompleteGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override;
  505. ///
  506. /// @brief Set index_th input anchor for graph
  507. /// @param [in] index
  508. /// @param [in] node_names
  509. /// @param [in] anchor_inds
  510. /// @return CompleteGraphBuilder
  511. ///
  512. CompleteGraphBuilder &SetInput(uint32_t index, const std::vector<std::string> &node_names,
  513. const std::vector<uint32_t> &anchor_inds);
  514. ///
  515. /// @brief Set index_th input of graph as useless
  516. /// @param [in] index
  517. /// @return CompleteGraphBuilder
  518. ///
  519. CompleteGraphBuilder &SetUselessInput(uint32_t index);
  520. ///
  521. /// @brief Add output anchor for graph
  522. /// @param [in] owner_node_name
  523. /// @param [in] anchor_ind
  524. /// @return CompleteGraphBuilder
  525. ///
  526. CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind);
  527. ///
  528. /// @brief Add target for graph
  529. /// @param [in] target_name
  530. /// @return CompleteGraphBuilder
  531. ///
  532. CompleteGraphBuilder &AddTarget(const std::string &target_name);
  533. ///
  534. /// @brief Set parent-node of graph
  535. /// @param [in] parent_node
  536. /// @return CompleteGraphBuilder
  537. ///
  538. CompleteGraphBuilder &SetParentNode(const NodePtr &parent_node);
  539. ///
  540. /// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node
  541. /// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node
  542. /// @return CompleteGraphBuilder
  543. ///
  544. CompleteGraphBuilder &SetInputMapping(const std::map<uint32_t, uint32_t> &input_mapping);
  545. ///
  546. /// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind
  547. /// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node
  548. /// @return CompleteGraphBuilder
  549. ///
  550. CompleteGraphBuilder &SetOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping);
  551. ///
  552. /// @brief Build graph
  553. /// @param [out] error_code
  554. /// @param [out] error_msg
  555. /// @return ComputeGraphPtr
  556. ///
  557. ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override;
  558. private:
  559. ///
  560. /// @brief Add data nodes
  561. /// @param [out] error_code
  562. /// @param [out] error_msg
  563. /// @return void
  564. ///
  565. void AddDataNodes(graphStatus &error_code, std::string &error_msg);
  566. ///
  567. /// @brief Add data node
  568. /// @param [in] index
  569. /// @param [out] error_code
  570. /// @param [out] error_msg
  571. /// @return void
  572. ///
  573. NodePtr AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg);
  574. ///
  575. /// @brief Add RetVal nodes
  576. /// @param [out] error_code
  577. /// @param [out] error_msg
  578. /// @return void
  579. ///
  580. void AddRetValNodes(graphStatus &error_code, std::string &error_msg);
  581. ///
  582. /// @brief Build target-nodes for graph
  583. /// @param [out] error_code
  584. /// @param [out] error_msg
  585. /// @return void
  586. ///
  587. void BuildGraphTargets(graphStatus &error_code, std::string &error_msg);
  588. ///
  589. /// @brief Add NetOutput node
  590. /// @param [out] error_code
  591. /// @param [out] error_msg
  592. /// @return void
  593. ///
  594. void AddNetOutputNode(graphStatus &error_code, std::string &error_msg);
  595. ///
  596. /// @brief Build NetOutput nodes with data & ctrl edges
  597. /// @param [in] net_output_desc
  598. /// @param [in] peer_out_anchors
  599. /// @param [out] error_code
  600. /// @param [out] error_msg
  601. /// @return void
  602. ///
  603. void BuildNetOutputNodeWithLink(const OpDescPtr &net_output_desc,
  604. const std::vector<OutDataAnchorPtr> &peer_out_anchors, graphStatus &error_code,
  605. std::string &error_msg);
  606. ///
  607. /// @brief process after build
  608. /// @param [out] error_code
  609. /// @param [out] error_msg
  610. /// @return void
  611. ///
  612. void PostProcess(graphStatus &error_code, std::string &error_msg);
  613. std::string name_;
  614. NodePtr parent_node_;
  615. bool retval_flag_;
  616. std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_;
  617. std::vector<std::pair<std::string, uint32_t>> graph_outputs_;
  618. std::vector<std::string> graph_targets_;
  619. // index_of_graph_input -> in_anchor_index_of_parent_node
  620. std::map<uint32_t, uint32_t> input_mapping_;
  621. // index_of_graph_output -> out_anchor_index_of_parent_node
  622. std::map<uint32_t, uint32_t> output_mapping_;
  623. };
  624. class PartialGraphBuilder : public ComputeGraphBuilder {
  625. public:
  626. PartialGraphBuilder() = default;
  627. PartialGraphBuilder(const PartialGraphBuilder &) = delete;
  628. PartialGraphBuilder &operator=(const PartialGraphBuilder &) = delete;
  629. PartialGraphBuilder(const PartialGraphBuilder &&) = delete;
  630. PartialGraphBuilder &operator=(const PartialGraphBuilder &&) = delete;
  631. ~PartialGraphBuilder() = default;
  632. ///
  633. /// @brief Add node to graph
  634. /// @param [in] op_desc
  635. /// @return PartialGraphBuilder
  636. ///
  637. PartialGraphBuilder &AddNode(const OpDescPtr &op_desc) override;
  638. ///
  639. /// @brief Add data-link among nodes in graph
  640. /// @param [in] src_name
  641. /// @param [in] out_anchor_ind
  642. /// @param [in] dst_name
  643. /// @param [in] in_anchor_ind
  644. /// @return PartialGraphBuilder
  645. ///
  646. PartialGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name,
  647. uint32_t in_anchor_ind) override;
  648. ///
  649. /// @brief Add ctrl-link among nodes in graph
  650. /// @param [in] src_name
  651. /// @param [in] dst_name
  652. /// @return PartialGraphBuilder
  653. ///
  654. PartialGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override;
  655. ///
  656. /// @brief Set owner graph
  657. /// @param [in] graph
  658. /// @return PartialGraphBuilder
  659. ///
  660. PartialGraphBuilder &SetOwnerGraph(const ComputeGraphPtr &graph);
  661. ///
  662. /// @brief Add exist node
  663. /// @param [in] node
  664. /// @return PartialGraphBuilder
  665. ///
  666. PartialGraphBuilder &AddExistNode(const NodePtr &node);
  667. ///
  668. /// @brief Build multi nodes with links
  669. /// @param [out] error_code
  670. /// @param [out] error_msg
  671. /// @return ComputeGraphPtr
  672. ///
  673. ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override;
  674. private:
  675. ///
  676. /// @brief Build exist nodes
  677. /// @param [out] error_code
  678. /// @param [out] error_msg
  679. /// @return void
  680. ///
  681. void BuildExistNodes(graphStatus &error_code, std::string &error_msg);
  682. std::vector<NodePtr> exist_nodes_;
  683. };
  684. } // namespace ge
  685. #endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：