You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transop_without_reshape_fusion_pass.h 6.6 kB

5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_GRAPH_PASSES_TRANSOP_WITHOUT_RESHAPE_FUSION_PASS_H_
  17. #define GE_GRAPH_PASSES_TRANSOP_WITHOUT_RESHAPE_FUSION_PASS_H_
  18. #include <vector>
  19. #include <utility>
  20. #include "inc/graph_pass.h"
  21. namespace ge {
  22. ///
  23. /// Transform operators depth fusion
  24. ///
  25. class TransOpWithoutReshapeFusionPass : public GraphPass {
  26. public:
  27. TransOpWithoutReshapeFusionPass() {}
  28. virtual ~TransOpWithoutReshapeFusionPass() {}
  29. graphStatus Run(ge::ComputeGraphPtr graph) override;
  30. private:
  31. void SetRemainNode(const vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_anchor);
  32. bool FormatContinuousCheck(const OutDataAnchorPtr &out_anchor, const InDataAnchorPtr &in_anchor);
  33. void RemoveNousedNodes(const ComputeGraphPtr &graph);
  34. void GetBeginOutDescAndEndInDesc(const int index, GeTensorDesc &out_desc, GeTensorDesc &in_desc);
  35. void GetFormatTransferDesc(const GeTensorDesc &out_desc,
  36. const GeTensorDesc &in_desc,
  37. GeTensorDesc &format_transfer_input,
  38. GeTensorDesc &format_transfer_output);
  39. void GetCastOpDesc(const GeTensorDesc &out_desc,
  40. const GeTensorDesc &in_desc,
  41. GeTensorDesc &cast_input,
  42. GeTensorDesc &cast_output);
  43. graphStatus FormatFusion(const int index,
  44. OpDescPtr &format_transfer_op,
  45. int32_t &fusion_op_count,
  46. bool &fusion_continue);
  47. graphStatus DataTypeFusion(const int index, OpDescPtr &cast_op, int32_t &fusion_op_count);
  48. void GetOutDataPeerInControlAnchors(const size_t index,
  49. vector<vector<InControlAnchorPtr>> &out_data_peer_in_control_anchors);
  50. void GetInControlPeerOutControlAnchors(
  51. const size_t index,
  52. vector<vector<OutControlAnchorPtr>> &in_control_peer_out_control_anchors);
  53. void GetOutControlPeerAnchors(
  54. const size_t index,
  55. vector<vector<InControlAnchorPtr>> &out_control_peer_in_control_anchors,
  56. vector<vector<InDataAnchorPtr>> &out_control_peer_in_data_anchors);
  57. graphStatus TransOpFuse(const ComputeGraphPtr &graph);
  58. bool OpAccuracyAbilityCheck(const OpDescPtr &op_desc);
  59. graphStatus GetSubGraphsBetweenNormalNode(
  60. const OutDataAnchorPtr &out_anchor,
  61. vector<vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>>
  62. >& sub_graphs_out,
  63. vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_list
  64. );
  65. graphStatus GetSubGraphNodesInfo();
  66. void GetControlAnchors();
  67. graphStatus InsertNewTransOp(const ComputeGraphPtr &graph, const OpDescPtr &cast_op,
  68. const OpDescPtr &format_transfer_op, const int index,
  69. const bool insert_cast_first);
  70. void EraseInvalidAnchorsPair();
  71. graphStatus RelinkNodesWhenDescNotChanged(const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
  72. const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair,
  73. const int index);
  74. OpDescPtr GetFormatTransferOp(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc);
  75. OpDescPtr GetCastOp(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc);
  76. graphStatus TransOpFuseHandle(const ge::ComputeGraphPtr &graph, const int index);
  77. graphStatus AddTransNode(const ComputeGraphPtr &graph, const OpDescPtr &transop, NodePtr &trans_node);
  78. bool DescEqualCheck(ConstGeTensorDescPtr &desc_src, ConstGeTensorDescPtr &desc_dst) const;
  79. bool ShapeEqualCheck(const GeShape &src, const GeShape &dst) const;
  80. bool InsertCastFirstCheck(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc) const;
  81. graphStatus RelinkControlEdge(const int index, const OutDataAnchorPtr &out_anchor,
  82. const vector<NodePtr> &new_trans_nodes);
  83. graphStatus GetTransNode(const ComputeGraphPtr &graph,
  84. const OpDescPtr &cast_op,
  85. const OpDescPtr &format_transfer_op,
  86. const bool insert_cast_first,
  87. std::vector<NodePtr> &new_trans_nodes);
  88. void UpdateOutputName(const OutDataAnchorPtr &out_anchor, const InDataAnchorPtr &old_peer_in_anchor,
  89. const NodePtr &in_owner_node);
  90. void UpdateInputName(const OutDataAnchorPtr &old_peer_out_anchor, const InDataAnchorPtr &in_anchor,
  91. const NodePtr &out_owner_node);
  92. graphStatus RelinkControlEdgesWhenDescNotChanged(const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
  93. const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair,
  94. const int index);
  95. graphStatus RelinkSubGraphControlEdges(const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
  96. const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair,
  97. const int index);
  98. ///
  99. /// judge whether an operator is a transform op or not
  100. /// @param node
  101. /// @return True or False
  102. ///
  103. static bool IsTransOp(const NodePtr &node);
  104. static bool FusionFormatSupport(Format format);
  105. vector<vector<pair<OutDataAnchorPtr, InDataAnchorPtr>>>
  106. sub_graph_anchors_;
  107. vector<vector<NodePtr>> sub_graph_nodes_;
  108. vector<int> transop_num_count_;
  109. vector<bool> sub_graph_has_reshape_node_;
  110. vector<vector<OutControlAnchorPtr>> in_control_peer_out_control_anchors_;
  111. vector<vector<InControlAnchorPtr>> out_control_peer_in_control_anchors_;
  112. vector<vector<InDataAnchorPtr>> out_control_peer_in_data_anchors_;
  113. vector<vector<InControlAnchorPtr>> out_data_peer_in_control_anchors_;
  114. vector<bool> sub_graph_has_control_edge_;
  115. vector<bool> sub_graph_has_out_data_peer_in_control_edge_;
  116. };
  117. } // namespace ge
  118. #endif // GE_GRAPH_PASSES_TRANSOP_WITHOUT_RESHAPE_FUSION_PASS_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户无需感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示（图略）。