You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

multi_batch_copy_graph.h 7.3 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_GRAPH_PREPROCESS_MULTI_BATCH_COPY_GRAPH_H_
  17. #define GE_GRAPH_PREPROCESS_MULTI_BATCH_COPY_GRAPH_H_
  18. #include <map>
  19. #include <queue>
  20. #include <vector>
  21. #include <set>
  22. #include "external/ge/ge_api_error_codes.h"
  23. #include "graph/compute_graph.h"
  24. namespace ge {
  25. namespace multibatch {
  26. Status ProcessMultiBatch(ComputeGraphPtr &graph);
  27. Status GetDynamicOutputShape(ComputeGraphPtr &graph);
  /// Label attached to every node of the original graph during the labelling pass.
  /// The numeric values are spelled out explicitly; note that 0 (kNodeInBatchBranch)
  /// is also the value of a value-initialized NodeStatus.
  /// Per-label meanings beyond "in/out of the batch branch" are inferred from the
  /// names — confirm against the labelling code.
  enum NodeStatus {
    kNodeInBatchBranch = 0,   // node inside the batch branch (copied for each batch shape)
    kNodeOutBatchBranch = 1,  // node outside the batch branch
    kNodeStartNode = 2,       // presumably a start (data-like) node of the branch — confirm
    kNodeNotSupportNode = 3,  // presumably a node the pass cannot handle — confirm
  };
  /// Which flavour of dynamic shape the graph is being rewritten for.
  /// Values are explicit so the ordering is visible at a glance.
  enum DynamicType {
    kDynamicBatch = 0,      // dynamic batch
    kDynamicImageSize = 1,  // dynamic image size
    kDynamicDims = 2,       // dynamic dims
    kDynamicUnknown = 3,    // not set; the default of MultiBatchGraphCopyer::dynamic_type_
  };
  40. class MultiBatchGraphCopyer {
  41. public:
  42. explicit MultiBatchGraphCopyer(ComputeGraphPtr &graph) : graph_(graph) {}
  43. ~MultiBatchGraphCopyer() = default;
  44. void AddShape(const std::vector<int64_t> &shape) { shapes_.emplace_back(shape); }
  45. void SetUserDesignateShape(const vector<pair<string, vector<int64_t>>> &designate_shape) {
  46. user_designate_shape_ = designate_shape;
  47. for (auto &item : designate_shape) {
  48. data_name_order_.push_back(item.first);
  49. }
  50. }
  51. void SetDynamicType(const DynamicType dynamic_type) {
  52. dynamic_type_ = dynamic_type;
  53. }
  54. Status CopyGraph();
  55. private:
  56. Status Init();
  57. Status CheckArguments();
  58. Status RelinkConstCtrlEdge();
  59. Status ExtractUnchangedStructureOutofCycle();
  60. Status GetEnterNodesGroupByFrame(std::map<std::string, std::vector<NodePtr>> &frame_enter);
  61. Status GetNodeNeedExtract(const std::map<std::string, std::vector<NodePtr>> &frame_enter,
  62. std::queue<NodePtr> &nodes_to_extract);
  63. bool AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node);
  64. Status MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc);
  65. Status InsertEnterAfterNode(NodePtr &node, const OpDescPtr &enter_desc, std::set<NodePtr> &out_nodes);
  66. Status MoveCtrlEdgeToOutNodes(NodePtr &node, std::set<NodePtr> &out_nodes);
  67. Status DeleteEnterWithoutDataOut();
  68. // label status for origin_all_nodes_
  69. Status LabelStatus();
  70. Status LabelInBatchBranchStatus();
  71. void LabelStatusForData(const NodePtr &data);
  72. void LabelStatusForGetNextSink(const NodePtr &data);
  73. void InitStatus(std::map<std::string, std::vector<NodePtr>> &frame_enters);
  74. void ResetEnterStatus(std::map<std::string, std::vector<NodePtr>> &frame_enters, const NodePtr &node);
  75. // add nodes functions
  76. Status CreateNewNodes();
  77. NodePtr InsertShapeDataNode();
  78. NodePtr InsertGetDynamicDimsNode();
  79. Status InsertSwitchNAndUpdateMaxShape(const NodePtr &node);
  80. Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index,
  81. std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn);
  82. Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index);
  83. Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index);
  84. Status InsertMergeForEdgeNode(const NodePtr &node);
  85. Status LinkGetDynamicDimsToNetOutput(const NodePtr &node);
  86. /// Insert a merge node for src node `node` on output index `index`. The merge node will be used to merge all nodes
  87. /// in batch-branch to one output to the node out of the batch-branch.
  88. /// Cond 1: If the `index` is -1, then the src node link a data edge(at output 0) to the merge node,
  89. /// Cond 2: In condition 1, if the src node does not have any data output, we create a const node after it,
  90. /// the result like this:
  91. /// src_node ---------> const_for_src_node --------> merge
  92. /// control data
  93. /// Cond 3: If the src node is a data-like node, the SwitchN after it will be link to the merge node.
  94. /// @param node
  95. /// @param index
  96. /// @return
  97. NodePtr InsertMergeNode(const NodePtr &node, int index);
  98. Status CopyNodeInBatchBranch(const NodePtr &node);
  99. // link edges functions
  100. Status LinkEdges();
  101. Status AddAttrForGetDynamicDims(const NodePtr &node);
  102. Status AddLinkForGetDynamicDims(const NodePtr &node);
  103. Status LinkDataToSwitchN(const NodePtr &data, const NodePtr &switchn, const int &out_index);
  104. Status LinkToMerge(const NodePtr &node);
  105. Status LinkToNodeInBranch(const NodePtr &node);
  106. Status LinkToNodeOutBranch(const NodePtr &node);
  107. Status LinkDataToMerge(const NodePtr &data, const NodePtr &merge, const NodePtr &switchn);
  108. Status LinkNodeToMerge(const NodePtr &node, int out_index, const NodePtr &merge);
  109. NodePtr FindSwitchnNodeForDataEdge(const OutDataAnchorPtr &data_out_anchor, const NodePtr &origin_node);
  110. Status CopyInDataEdges(const NodePtr &origin_node, int batch_num, const NodePtr &copyed_node);
  111. Status CopyInControlEdges(const NodePtr &node, int batch_num, const NodePtr &copyed_node);
  112. Status CheckAndParseDynamicData();
  113. bool IsInBatchBranch(const NodePtr &node);
  114. NodeStatus GetNodeStatus(const NodePtr &node) { return origin_nodes_status_[node.get()]; };
  115. Status CheckCopyResult(const std::vector<NodePtr> &start_nodes);
  116. // arguments
  117. ComputeGraphPtr graph_;
  118. std::vector<std::vector<int64_t>> shapes_;
  119. // the shape data node created
  120. NodePtr shape_data_;
  121. // all nodes in the origin graph
  122. std::vector<NodePtr> origin_all_nodes_;
  123. // all data nodes in the origin graph
  124. std::vector<NodePtr> origin_data_nodes_;
  125. // the nodes in-batch-branch, and the nodes copyed by shapes
  126. std::map<Node *, std::vector<NodePtr>> nodes_to_batch_nodes_;
  127. // the data nodes, and the SwitchN nodes inserted after it
  128. std::map<Node *, NodePtr> data_nodes_to_switchn_;
  129. // the getnext_sink nodes, and the SwitchN nodes inserted after it
  130. std::vector<std::vector<std::pair<Node *, NodePtr>>> getnext_nodes_to_switchn_;
  131. std::vector<std::vector<std::pair<int, int>>> outidx_inidx_mappings_;
  132. std::vector<std::pair<int, int>> outidx_inidx_mapping_;
  133. // the nodes on the in/out-batch-branch edge, and the merge nodes inserted after it
  134. std::map<Node *, std::vector<NodePtr>> nodes_to_merge_nodes_;
  135. // all nodes and their status
  136. std::map<Node *, NodeStatus> origin_nodes_status_;
  137. // user designate shape, decord the order of each input data
  138. std::vector<std::pair<std::string, std::vector<int64_t>>> user_designate_shape_;
  139. std::vector<std::string> data_name_order_;
  140. // each data's own dynamic info
  141. map<string, vector<vector<int64_t>>> data_to_dynamic_info_;
  142. // dynamic type : dynamic batch,, dynamic image size, dynamic dims.
  143. DynamicType dynamic_type_ = DynamicType::kDynamicUnknown;
  144. std::vector<std::pair<size_t, size_t>> getnext_sink_dynamic_out_mapping_;
  145. bool getnext_sink_dynamic_dims_ = false;
  146. };
  147. } // namespace multibatch
  148. } // namespace ge
  149. #endif // GE_GRAPH_PREPROCESS_MULTI_BATCH_COPY_GRAPH_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示：