diff --git a/ge/graph/passes/merge_pass.cc b/ge/graph/passes/merge_pass.cc index 0b367614..d2340037 100644 --- a/ge/graph/passes/merge_pass.cc +++ b/ge/graph/passes/merge_pass.cc @@ -34,11 +34,6 @@ using domi::SUCCESS; namespace ge { const int kValueIndexOutputIndex = 1; -bool IsEmptyTensor(const GeShape &shape) { - const auto &dims = shape.GetDims(); - return std::any_of(dims.begin(), dims.end(), [](int64_t dim) { return dim == 0; }); -} - Status MergePass::Run(NodePtr &node) { GELOGD("MergePass running"); if (node == nullptr) { @@ -58,11 +53,6 @@ Status MergePass::Run(NodePtr &node) { return PARAM_INVALID; } - if (OptimizeEmptyTensorInput(node) != SUCCESS) { - GELOGE(FAILED, "[%s] remove empty_tensor inputs failed.", node->GetName().c_str()); - return FAILED; - } - auto in_data_nodes = node->GetInDataNodes(); switch (in_data_nodes.size()) { case 0: { @@ -212,30 +202,4 @@ bool MergePass::IsMergeInputNeedOptimized(NodePtr &node) const { } return true; } - -Status MergePass::OptimizeEmptyTensorInput(const NodePtr &node) { - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor(); - if (peer_data_anchor == nullptr) { - continue; - } - if ((peer_data_anchor->GetOwnerNode() == nullptr) || - (peer_data_anchor->GetOwnerNode()->GetOpDesc() == nullptr)) { - continue; - } - const auto &op_desc = peer_data_anchor->GetOwnerNode()->GetOpDesc(); - if (IsEmptyTensor(op_desc->GetOutputDesc(peer_data_anchor->GetIdx()).GetShape())) { - if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", - op_desc->GetName().c_str(), peer_data_anchor->GetIdx(), - node->GetName().c_str(), in_data_anchor->GetIdx()); - return FAILED; - } - GELOGD("Remove data edge %s:%d->%s:%d", - op_desc->GetName().c_str(), peer_data_anchor->GetIdx(), - node->GetName().c_str(), in_data_anchor->GetIdx()); - } - } - return SUCCESS; -} } // namespace ge diff --git a/ge/graph/passes/merge_pass.h b/ge/graph/passes/merge_pass.h index 464f2172..2cdb5022 100755 --- a/ge/graph/passes/merge_pass.h +++ b/ge/graph/passes/merge_pass.h @@ -29,7 +29,6 @@ class MergePass : public BaseNodePass { Status ChangeIndexToConstant(NodePtr &node, int &value_index); Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc); bool IsMergeInputNeedOptimized(NodePtr &node) const; - static Status OptimizeEmptyTensorInput(const NodePtr &node); }; } // namespace ge #endif // GE_GRAPH_PASSES_MERGE_PASS_H_ diff --git a/ge/graph/passes/replace_with_empty_const_pass.cc b/ge/graph/passes/replace_with_empty_const_pass.cc index 171c76d0..f3887867 100644 --- a/ge/graph/passes/replace_with_empty_const_pass.cc +++ b/ge/graph/passes/replace_with_empty_const_pass.cc @@ -57,81 +57,30 @@ Status ReplaceWithEmptyConstPass::Run(NodePtr &node) { if (is_all_output_empty) { GELOGI("Node %s has empty tensor output. It will be replaced by empty const.", node->GetName().c_str()); // Replace op which all output is empty with empty const - Status ret = ReplaceWithEmptyConst(node); + vector outputs; + Status ret = GetOutputsOfCurrNode(node, outputs); if (ret != SUCCESS) { // If replace failed, it should not break whole process, so still return success - GELOGW("Failed to repalce node %s with empty const.", node->GetName().c_str()); + GELOGW("Failed to get outputs of node %s.", node->GetName().c_str()); } - } - GELOGD("ReplaceWithEmptyConstPass end."); - return SUCCESS; -} - -Status ReplaceWithEmptyConstPass::ReplaceWithEmptyConst(NodePtr &node_to_replace) { - std::map> shape_out_idx_map; - auto op_desc = node_to_replace->GetOpDesc(); - // Collect out_idx follow different out shape - for (const auto &out_anchor : node_to_replace->GetAllOutDataAnchors()) { - auto out_desc = op_desc->GetOutputDesc(out_anchor->GetIdx()); - shape_out_idx_map[GetDimStr(out_desc.GetShape())].emplace_back(out_anchor->GetIdx()); - } - - for (const auto &shape_2_out_idx : shape_out_idx_map) { - // Create empty const - // The out_desc in one group should be same shape, so here only get first out_desc. its valid index. - auto out_desc = op_desc->GetOutputDesc(shape_2_out_idx.second[0]); - NodePtr const_node; - auto graph = node_to_replace->GetOwnerComputeGraph(); - Status ret = InsertEmptyConst(out_desc, const_node, graph); - if (ret != SUCCESS) { - GELOGE(FAILED, "Failed insert const node."); - return FAILED; - } - - // Repalce data anchors - for (const auto &anchor_idx: shape_2_out_idx.second) { - if (GraphUtils::ReplaceNodeDataAnchors(const_node, node_to_replace, {}, {anchor_idx}) != GRAPH_SUCCESS) { - GELOGE(FAILED, "[%s] ReplaceNodeAnchors failed.", node_to_replace->GetName().c_str()); - return FAILED; + else { + ret = Folding(node, outputs); + if (ret != SUCCESS) { + // If replace failed, it should not break whole process, so still return success + GELOGW("Failed to repalce node %s with empty const.", node->GetName().c_str()); } } - - // Copy in control edge - if (GraphUtils::CopyInCtrlEdges(node_to_replace, const_node) != GRAPH_SUCCESS) { - GELOGE(FAILED, "CopyInCtrlEdges from %s to %s failed.", node_to_replace->GetName().c_str(), - const_node->GetName().c_str()); - return FAILED; - } - // Copy out control edge - if (GraphUtils::CopyOutCtrlEdges(node_to_replace, const_node) != GRAPH_SUCCESS) { - GELOGE(FAILED, "CopyOutCtrlEdges from %s to %s failed.", node_to_replace->GetName().c_str(), - const_node->GetName().c_str()); - return FAILED; - } - AddRePassNodesWithInOut(const_node); - GELOGI("Node %s has been replaced by empty const %s.", node_to_replace->GetName().c_str(), - const_node->GetName().c_str()); } - IsolateAndDeleteNode(node_to_replace, {}); + GELOGD("ReplaceWithEmptyConstPass end."); return SUCCESS; } -Status ReplaceWithEmptyConstPass::InsertEmptyConst(const GeTensorDesc &out_desc, NodePtr &const_node, - ComputeGraphPtr &graph) { - GeTensorPtr empty_tensor = MakeShared(out_desc); - if (empty_tensor == nullptr) { - GELOGE(OUT_OF_MEMORY, "Failed create empty tensor."); - return OUT_OF_MEMORY; - } - auto const_desc = OpDescUtils::CreateConstOp(empty_tensor); - if (const_desc == nullptr) { - GELOGE(OUT_OF_MEMORY, "Failed to get const desc from tensor"); - return OUT_OF_MEMORY; - } - - const_node = graph->AddNode(const_desc); - if (const_node == nullptr) { - GELOGE(FAILED, "Failed insert const node."); - return FAILED; +Status ReplaceWithEmptyConstPass::GetOutputsOfCurrNode(const NodePtr &node_to_replace, vector &outputs) { + for (const auto &out_anchor : node_to_replace->GetAllOutDataAnchors()) { + GE_CHECK_NOTNULL(node_to_replace->GetOpDesc()); + auto out_desc = node_to_replace->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx()); + GeTensorPtr empty_tensor = MakeShared(out_desc); + GE_CHECK_NOTNULL(empty_tensor); + outputs.emplace_back(empty_tensor); } return SUCCESS; } @@ -144,12 +93,4 @@ bool ReplaceWithEmptyConstPass::IsEmptyTenor(const GeShape &shape) const { } return false; } - -string ReplaceWithEmptyConstPass::GetDimStr(const GeShape &shape) { - std::stringstream dim_str; - for (auto dim : shape.GetDims()) { - dim_str << dim << '-'; - } - return dim_str.str(); -} } // namespace ge diff --git a/ge/graph/passes/replace_with_empty_const_pass.h b/ge/graph/passes/replace_with_empty_const_pass.h index 5083c699..fde75358 100644 --- a/ge/graph/passes/replace_with_empty_const_pass.h +++ b/ge/graph/passes/replace_with_empty_const_pass.h @@ -17,18 +17,16 @@ #ifndef GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_ #define GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_ -#include "graph/passes/base_pass.h" +#include "graph/passes/folding_pass.h" namespace ge { -class ReplaceWithEmptyConstPass : public BaseNodePass { +class ReplaceWithEmptyConstPass : public FoldingPass { public: Status Run(NodePtr &node) override; private: - Status ReplaceWithEmptyConst(NodePtr &node_to_replace); - Status InsertEmptyConst(const GeTensorDesc &out_desc, NodePtr &const_node, ComputeGraphPtr &graph); + Status GetOutputsOfCurrNode(const NodePtr &node_to_replace, vector &outputs); bool IsEmptyTenor(const GeShape &shape) const; - std::string GetDimStr(const GeShape &shape); }; } // namespace ge #endif // GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_