You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

compute_graph.h 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_GRAPH_COMPUTE_GRAPH_H_
  17. #define INC_GRAPH_COMPUTE_GRAPH_H_
  18. #include <map>
  19. #include <memory>
  20. #include <string>
  21. #include <utility>
  22. #include <vector>
  23. #include <deque>
  24. #include "detail/attributes_holder.h"
  25. #include "graph/anchor.h"
  26. #include "graph/node.h"
  27. #include "graph/op_desc.h"
  28. #include "graph/range_vistor.h"
  29. namespace ge {
  30. class Node;
  31. using NodePtr = std::shared_ptr<Node>;
  32. class Edge;
  33. using EdgePtr = std::shared_ptr<Edge>;
  34. class InDataAnchor;
  35. using InDataAnchorPtr = std::shared_ptr<InDataAnchor>;
  36. class OutDataAnchor;
  37. using OutDataAnchorPtr = std::shared_ptr<OutDataAnchor>;
  38. class ControlAnchor;
  39. using ControlAnchorPtr = std::shared_ptr<ControlAnchor>;
  40. class InControlAnchor;
  41. using InControlAnchorPtr = std::shared_ptr<InControlAnchor>;
  42. class OutControlAnchor;
  43. using OutControlAnchorPtr = std::shared_ptr<OutControlAnchor>;
  44. class GeAttrValue;
  45. using AttrValuePtr = std::shared_ptr<GeAttrValue>;
  46. using ConstComputeGraph = const ComputeGraph;
  47. class OperatorImpl;
  48. using OperatorImplPtr = std::shared_ptr<OperatorImpl>;
  49. class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public AttrHolder {
  50. friend class GraphUtils;
  51. public:
  52. template <class T>
  53. using Vistor = RangeVistor<T, std::shared_ptr<ConstComputeGraph>>;
  54. explicit ComputeGraph(const std::string &name);
  55. ~ComputeGraph() override;
  56. std::string GetName() const;
  57. void SetName(const std::string &name);
  58. using AttrHolder::DelAttr;
  59. using AttrHolder::GetAttr;
  60. using AttrHolder::HasAttr;
  61. using AttrHolder::SetAttr;
  62. size_t GetAllNodesSize() const;
  63. Vistor<NodePtr> GetAllNodes() const;
  64. // is_unknown_shape: false, same with GetAllNodes func
  65. // is_unknown_shape: true, same with GetDirectNodes func
  66. Vistor<NodePtr> GetNodes(bool is_unknown_shape) const;
  67. size_t GetDirectNodesSize() const;
  68. Vistor<NodePtr> GetDirectNode() const;
  69. Vistor<NodePtr> GetInputNodes() const;
  70. Vistor<NodePtr> GetOutputNodes() const;
  71. NodePtr FindNode(const std::string &name) const;
  72. NodePtr FindFirstNodeMatchType(const std::string &name) const;
  73. // AddNode with NodePtr
  74. NodePtr AddNode(NodePtr node);
  75. NodePtr AddNode(OpDescPtr op);
  76. NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize
  77. NodePtr AddNodeFront(NodePtr node);
  78. NodePtr AddNodeFront(const OpDescPtr &op);
  79. NodePtr AddInputNode(NodePtr node);
  80. NodePtr AddOutputNode(NodePtr node);
  81. NodePtr AddOutputNodeByIndex(NodePtr node, int32_t index);
  82. // insert node with specific pre_node
  83. NodePtr AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node);
  84. NodePtr AddNodeAfter(NodePtr node, const NodePtr &pre_node);
  85. graphStatus RemoveNode(const NodePtr &node);
  86. graphStatus RemoveInputNode(const NodePtr &node);
  87. graphStatus RemoveOutputNode(const NodePtr &node);
  88. graphStatus RemoveConstInput(const NodePtr &node);
  89. /// Add a subgraph to this graph. The subgraph must has a parent graph and parent node,
  90. /// which means the member functions `SetParentGraph` and `SetParentNode` of the subgraph
  91. /// must be called before add it to the root graph. and subgraph->GetParentNode()->GetOwnerGraph()
  92. /// must equal to subgraph->GetOwnerGraph().
  93. /// The subgraphs can only be added to a *root graph*. A root graph is a graph without any parent graph.
  94. /// The subgraph's name SHOULD(not must) be the same as the parameter `name`
  95. graphStatus AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph);
  96. graphStatus AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph);
  97. void RemoveSubgraph(const std::string &name);
  98. void RemoveSubgraph(const std::shared_ptr<ComputeGraph> &subgraph);
  99. std::shared_ptr<ComputeGraph> GetSubgraph(const std::string &name) const;
  100. std::vector<std::shared_ptr<ComputeGraph>> GetAllSubgraphs() const;
  101. // obsolete
  102. std::shared_ptr<ComputeGraph> AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph);
  103. // obsolete
  104. graphStatus RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph);
  105. ///
  106. /// @brief Update input-mapping
  107. /// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input
  108. /// @return graphStatus
  109. ///
  110. graphStatus UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping);
  111. ///
  112. /// @brief Update output-mapping
  113. /// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output
  114. /// @return graphStatus
  115. ///
  116. graphStatus UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping);
  117. graphStatus TopologicalSorting();
  118. bool IsValid() const;
  119. void InValid() { is_valid_flag_ = false; }
  120. void Dump() const;
  121. void Swap(ComputeGraph &graph);
  122. graphStatus IsolateNode(const NodePtr &node);
  123. graphStatus Verify();
  124. graphStatus InferShape();
  125. graphStatus InferOriginFormat();
  126. graphStatus InferShapeInNeed();
  127. graphStatus InsertEventNodes();
  128. bool operator==(const ComputeGraph &r_compute_graph) const;
  129. const std::map<std::vector<std::string>, std::vector<std::string>> &GetShareParamLayer() const {
  130. return params_share_map_;
  131. }
  132. void SetShareParamLayer(const std::map<std::vector<std::string>, std::vector<std::string>> params_share_map) {
  133. params_share_map_ = params_share_map;
  134. }
  135. void SetInputsOrder(const std::vector<std::string> &inputs_order) { inputs_order_ = inputs_order; }
  136. void SetGraphOutNodes(std::map<std::string, std::vector<int32_t>> out_nodes_map) { out_nodes_map_ = out_nodes_map; }
  137. void AppendGraphOutNodes(std::map<std::string, std::vector<int32_t>> out_nodes_map) {
  138. for (auto &item : out_nodes_map) {
  139. (void)out_nodes_map_.emplace(item.first, item.second);
  140. }
  141. }
  142. shared_ptr<ComputeGraph> GetParentGraph();
  143. void SetParentGraph(const shared_ptr<ComputeGraph> &parent);
  144. shared_ptr<Node> GetParentNode();
  145. void SetParentNode(const shared_ptr<Node> &parent);
  146. const std::map<std::string, std::vector<int32_t>> &GetGraphOutNodes() const { return out_nodes_map_; }
  147. void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = orig_graph; }
  148. ComputeGraphPtr GetOrigGraph(void) { return origGraph_; }
  149. void SetOutputSize(uint32_t size) { output_size_ = size; }
  150. uint32_t GetOutputSize() const { return output_size_; }
  151. void SetInputSize(uint32_t size) { input_size_ = size; }
  152. uint32_t GetInputSize() const { return input_size_; }
  153. // false: known shape true: unknow shape
  154. bool GetGraphUnknownFlag() const { return is_unknown_shape_graph_; }
  155. void SetGraphUnknownFlag(bool flag) { is_unknown_shape_graph_ = flag; }
  156. ///
  157. /// Set is need train iteration.
  158. /// If set true, it means this graph need to be run iteration some
  159. /// times(according variant "npu_runconfig/iterations_per_loop").
  160. /// @param need_iteration is need iteration
  161. ///
  162. void SetNeedIteration(bool need_iteration) { need_iteration_ = need_iteration; }
  163. void SetUserDefOutput(const std::string &output_name);
  164. const std::string GetOutput();
  165. ///
  166. /// Get is need train iteration.
  167. /// @return is need iteration
  168. ///
  169. bool GetNeedIteration() const { return need_iteration_; }
  170. void SetGraphOpName(const std::map<uint32_t, std::string> &op_name_map) { op_name_map_ = op_name_map; }
  171. const std::map<uint32_t, std::string> &GetGraphOpName() const { return op_name_map_; }
  172. const std::map<OperatorImplPtr, NodePtr> &GetAllNodesInfo() const;
  173. void SetAllNodesInfo(const std::map<OperatorImplPtr, NodePtr> &nodes) { all_nodes_infos_ = nodes; }
  174. void SetGraphOutNodesInfo(std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info) {
  175. output_nodes_info_ = out_nodes_info;
  176. }
  177. void AppendGraphOutNodesInfo(std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info) {
  178. output_nodes_info_.insert(output_nodes_info_.end(), out_nodes_info.begin(), out_nodes_info.end());
  179. }
  180. const std::vector<std::pair<NodePtr, int32_t>> &GetGraphOutNodesInfo() const { return output_nodes_info_; }
  181. void SetGraphTargetNodesInfo(const std::vector<NodePtr> &target_nodes_info) {
  182. target_nodes_info_ = target_nodes_info;
  183. }
  184. const std::vector<NodePtr> &GetGraphTargetNodesInfo() const { return target_nodes_info_; }
  185. void SetSessionID(uint64_t session_id) { session_id_ = session_id; }
  186. uint64_t GetSessionID() const { return session_id_; }
  187. void SetGraphID(uint32_t graph_id) { graph_id_ = graph_id; }
  188. uint32_t GetGraphID() const { return graph_id_; }
  189. void SaveDataFormat(ge::Format data_format) { data_format_ = data_format; }
  190. ge::Format GetDataFormat() const { return data_format_; }
  191. bool IsSummaryGraph() const { return is_summary_graph_; }
  192. void SetSummaryFlag(bool is_summary_graph) { is_summary_graph_ = is_summary_graph; }
  193. // Graph Before BFE
  194. ComputeGraphPtr origGraph_;
  195. protected:
  196. ProtoAttrMapHelper MutableAttrMap() override;
  197. ConstProtoAttrMapHelper GetAttrMap() const override;
  198. private:
  199. graphStatus DFSTopologicalSorting(std::vector<NodePtr> &node_vec, std::map<NodePtr, uint32_t> &map_in_edge_num,
  200. std::vector<NodePtr> &stack, bool reverse);
  201. graphStatus BFSTopologicalSorting(std::vector<NodePtr> &node_vec, std::map<NodePtr, uint32_t> &map_in_edge_num,
  202. std::deque<NodePtr> &stack);
  203. graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num,
  204. std::map<string, NodePtr> &breadth_node_map);
  205. /// nodes like : (a) <--- (c) ---> (b)
  206. /// node a and b have only one parent node c, and a is connected to c firstly
  207. /// topo order of DFS is `c, b, a` with `dfs_reverse=false` as default
  208. /// in same case, user could get `c, a, b` with `dfs_reverse=true`
  209. graphStatus TopologicalSortingGraph(bool dfs_reverse = false);
  210. graphStatus SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &mapInEdgeNum);
  211. Vistor<NodePtr> AllGraphNodes(std::vector<std::shared_ptr<ComputeGraph>> &subgraphs) const;
  212. size_t GetInEdgeSize(const NodePtr &node);
  213. size_t GetOutEdgeSize(const NodePtr &node);
  214. graphStatus RemoveExtraOutEdge(const NodePtr &node);
  215. bool GraphMembersAreEqual(const ComputeGraph &r_graph) const;
  216. bool GraphAttrsAreEqual(const ComputeGraph &r_graph) const;
  217. bool VectorInputNodePtrIsEqual(const std::vector<NodePtr> &r_node_ptr_vector,
  218. const std::vector<NodePtr> &l_node_ptr_vector) const;
  219. void SetNodesOwner();
  220. friend class ModelSerializeImp;
  221. friend class GraphDebugImp;
  222. friend class OnnxUtils;
  223. friend class TuningUtils;
  224. std::string name_;
  225. uint32_t graph_id_ = 0;
  226. ProtoAttrMapHelper attrs_;
  227. std::vector<NodePtr> nodes_;
  228. std::map<OperatorImplPtr, NodePtr> all_nodes_infos_;
  229. std::vector<NodePtr> target_nodes_info_;
  230. std::vector<NodePtr> input_nodes_;
  231. std::vector<std::string> inputs_order_;
  232. uint32_t input_size_ = 1;
  233. std::map<std::string, std::vector<int32_t>> out_nodes_map_;
  234. uint32_t output_size_ = 1;
  235. std::vector<std::pair<NodePtr, int32_t>> output_nodes_info_;
  236. std::vector<std::shared_ptr<ComputeGraph>> sub_graph_;
  237. std::map<std::string, std::shared_ptr<ComputeGraph>> names_to_subgraph_;
  238. std::weak_ptr<ComputeGraph> parent_graph_;
  239. std::weak_ptr<Node> parent_node_;
  240. // the members followed should not in the ComputeGraph class
  241. bool is_valid_flag_;
  242. bool is_summary_graph_ = false;
  243. // Indicates whether it is need iteration
  244. bool need_iteration_ = false;
  245. std::map<std::vector<std::string>, std::vector<std::string>> params_share_map_;
  246. // TaskIdx -> op_name Map
  247. std::map<uint32_t, std::string> op_name_map_;
  248. uint64_t session_id_ = 0;
  249. ge::Format data_format_ = ge::FORMAT_ND;
  250. // unknown graph indicator, default is false, mean known shape
  251. bool is_unknown_shape_graph_ = false;
  252. };
  253. } // namespace ge
  254. #endif // INC_GRAPH_COMPUTE_GRAPH_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。