You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

multi_batch_copy_graph.h 6.4 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_GRAPH_PREPROCESS_MULTI_BATCH_COPY_GRAPH_H_
  17. #define GE_GRAPH_PREPROCESS_MULTI_BATCH_COPY_GRAPH_H_
#include <cstdint>
#include <map>
#include <queue>
#include <string>
#include <utility>
#include <vector>
#include "external/ge/ge_api_error_codes.h"
#include "graph/compute_graph.h"
  23. namespace ge {
  24. namespace multibatch {
  25. Status ProcessMultiBatch(ComputeGraphPtr &graph);
  26. Status GetDynamicOutputShape(ComputeGraphPtr &graph);
  27. enum NodeStatus {
  28. kNodeInBatchBranch,
  29. kNodeOutBatchBranch,
  30. kNodeStartNode,
  31. kNodeNotSupportNode,
  32. };
  33. enum DynamicType {
  34. kDynamicBatch,
  35. kDynamicImageSize,
  36. kDynamicDims,
  37. kDynamicUnknown,
  38. };
  39. class MultiBatchGraphCopyer {
  40. public:
  41. explicit MultiBatchGraphCopyer(ComputeGraphPtr &graph) : graph_(graph) {}
  42. ~MultiBatchGraphCopyer() = default;
  43. void AddShape(const std::vector<int64_t> &shape) { shapes_.emplace_back(shape); }
  44. void SetUserDesignateShape(const vector<pair<string, vector<int64_t>>> &designate_shape) {
  45. user_designate_shape_ = designate_shape;
  46. for (auto &item : designate_shape) {
  47. data_name_order_.push_back(item.first);
  48. }
  49. }
  50. void SetDynamicType(const DynamicType dynamic_type) {
  51. dynamic_type_ = dynamic_type;
  52. }
  53. Status CopyGraph();
  54. private:
  55. Status Init();
  56. Status CheckArguments();
  57. // label status for origin_all_nodes_
  58. Status LabelStatus();
  59. Status LabelInBatchBranchStatus();
  60. void LabelStatusForData(const NodePtr &data);
  61. void LabelStatusForGetNextSink(const NodePtr &data);
  62. // add nodes functions
  63. Status CreateNewNodes();
  64. NodePtr InsertShapeDataNode();
  65. NodePtr InsertGetDynamicDimsNode();
  66. Status InsertSwitchNAndUpdateMaxShape(const NodePtr &node);
  67. Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index,
  68. std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn);
  69. Status InsertIdentityAfterSwitchN();
  70. Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index);
  71. Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index);
  72. Status InsertMergeForEdgeNode(const NodePtr &node);
  73. Status LinkGetDynamicDimsToNetOutput(const NodePtr &node);
  74. /// Insert a merge node for src node `node` on output index `index`. The merge node will be used to merge all nodes
  75. /// in batch-branch to one output to the node out of the batch-branch.
  76. /// Cond 1: If the `index` is -1, then the src node link a data edge(at output 0) to the merge node,
  77. /// Cond 2: In condition 1, if the src node does not have any data output, we create a const node after it,
  78. /// the result like this:
  79. /// src_node ---------> const_for_src_node --------> merge
  80. /// control data
  81. /// Cond 3: If the src node is a data-like node, the SwitchN after it will be link to the merge node.
  82. /// @param node
  83. /// @param index
  84. /// @return
  85. NodePtr InsertMergeNode(const NodePtr &node, int index);
  86. Status CopyNodeInBatchBranch(const NodePtr &node);
  87. // link edges functions
  88. Status LinkEdges();
  89. Status AddAttrForGetDynamicDims(const NodePtr &node);
  90. Status AddLinkForGetDynamicDims(const NodePtr &node);
  91. Status LinkDataToSwitchN(const NodePtr &data, const NodePtr &switchn, const int &out_index);
  92. Status LinkToMerge(const NodePtr &node);
  93. Status LinkToNodeInBranch(const NodePtr &node);
  94. Status LinkToNodeOutBranch(const NodePtr &node);
  95. Status LinkDataToMerge(const NodePtr &data, const NodePtr &merge, const NodePtr &switchn);
  96. Status LinkNodeToMerge(const NodePtr &node, int out_index, const NodePtr &merge);
  97. NodePtr FindSwitchnNodeForDataEdge(const OutDataAnchorPtr &data_out_anchor, const NodePtr &origin_node);
  98. Status CopyInDataEdges(const NodePtr &origin_node, int batch_num, const NodePtr &copyed_node);
  99. Status CopyInControlEdges(const NodePtr &node, int batch_num, const NodePtr &copyed_node);
  100. Status CheckAndParseDynamicData();
  101. bool IsInBatchBranch(const NodePtr &node);
  102. NodeStatus GetNodeStatus(const NodePtr &node) { return origin_nodes_status_[node.get()]; };
  103. Status CheckCopyResult(const std::vector<NodePtr> &start_nodes);
  104. // arguments
  105. ComputeGraphPtr graph_;
  106. std::vector<std::vector<int64_t>> shapes_;
  107. // the shape data node created
  108. NodePtr shape_data_;
  109. // all nodes in the origin graph
  110. std::vector<NodePtr> origin_all_nodes_;
  111. // all data nodes in the origin graph
  112. std::vector<NodePtr> origin_data_nodes_;
  113. // the nodes in-batch-branch, and the nodes copyed by shapes
  114. std::map<Node *, std::vector<NodePtr>> nodes_to_batch_nodes_;
  115. // the data nodes, and the SwitchN nodes inserted after it
  116. std::map<Node *, NodePtr> data_nodes_to_switchn_;
  117. // the getnext_sink nodes, and the SwitchN nodes inserted after it
  118. std::vector<std::vector<std::pair<Node *, NodePtr>>> getnext_nodes_to_switchn_;
  119. std::vector<std::vector<std::pair<int, int>>> outidx_inidx_mappings_;
  120. std::vector<std::pair<int, int>> outidx_inidx_mapping_;
  121. // the nodes on the in/out-batch-branch edge, and the merge nodes inserted after it
  122. std::map<Node *, std::vector<NodePtr>> nodes_to_merge_nodes_;
  123. // all nodes and their status
  124. std::map<Node *, NodeStatus> origin_nodes_status_;
  125. // user designate shape, decord the order of each input data
  126. std::vector<std::pair<std::string, std::vector<int64_t>>> user_designate_shape_;
  127. std::vector<std::string> data_name_order_;
  128. // each data's own dynamic info
  129. map<string, vector<vector<int64_t>>> data_to_dynamic_info_;
  130. // dynamic type : dynamic batch,, dynamic image size, dynamic dims.
  131. DynamicType dynamic_type_ = DynamicType::kDynamicUnknown;
  132. std::vector<std::pair<size_t, size_t>> getnext_sink_dynamic_out_mapping_;
  133. bool getnext_sink_dynamic_dims_ = false;
  134. };
  135. } // namespace multibatch
  136. } // namespace ge
  137. #endif // GE_GRAPH_PREPROCESS_MULTI_BATCH_COPY_GRAPH_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。