From 60476ce26e63352caf5806073162b346bba54bba Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Mon, 30 Nov 2020 11:03:33 +0800 Subject: [PATCH] modified: ge/graph/passes/replace_with_empty_const_pass.cc modified: ge/graph/passes/replace_with_empty_const_pass.h --- ge/graph/passes/replace_with_empty_const_pass.cc | 88 ++++-------------------- ge/graph/passes/replace_with_empty_const_pass.h | 8 +-- 2 files changed, 17 insertions(+), 79 deletions(-) diff --git a/ge/graph/passes/replace_with_empty_const_pass.cc b/ge/graph/passes/replace_with_empty_const_pass.cc index 171c76d0..2334daac 100644 --- a/ge/graph/passes/replace_with_empty_const_pass.cc +++ b/ge/graph/passes/replace_with_empty_const_pass.cc @@ -57,81 +57,29 @@ Status ReplaceWithEmptyConstPass::Run(NodePtr &node) { if (is_all_output_empty) { GELOGI("Node %s has empty tensor output. It will be replaced by empty const.", node->GetName().c_str()); // Replace op which all output is empty with empty const - Status ret = ReplaceWithEmptyConst(node); + vector outputs; + Status ret = GetOutputsOfCurrNode(node, outputs); if (ret != SUCCESS) { // If replace failed, it should not break whole process, so still return success - GELOGW("Failed to repalce node %s with empty const.", node->GetName().c_str()); + GELOGW("Failed to get outputs of node %s.", node->GetName().c_str()); + } + else { + ret = Folding(node, outputs); + if (ret != SUCCESS) { + // If replace failed, it should not break whole process, so still return success + GELOGW("Failed to repalce node %s with empty const.", node->GetName().c_str()); + } } } GELOGD("ReplaceWithEmptyConstPass end."); return SUCCESS; } - -Status ReplaceWithEmptyConstPass::ReplaceWithEmptyConst(NodePtr &node_to_replace) { - std::map> shape_out_idx_map; - auto op_desc = node_to_replace->GetOpDesc(); - // Collect out_idx follow different out shape +Status GetOutputsOfCurrNode(const NodePtr &node_to_repalce, vector &outputs) { for (const auto &out_anchor : node_to_replace->GetAllOutDataAnchors()) { auto out_desc = op_desc->GetOutputDesc(out_anchor->GetIdx()); - shape_out_idx_map[GetDimStr(out_desc.GetShape())].emplace_back(out_anchor->GetIdx()); - } - - for (const auto &shape_2_out_idx : shape_out_idx_map) { - // Create empty const - // The out_desc in one group should be same shape, so here only get first out_desc. its valid index. - auto out_desc = op_desc->GetOutputDesc(shape_2_out_idx.second[0]); - NodePtr const_node; - auto graph = node_to_replace->GetOwnerComputeGraph(); - Status ret = InsertEmptyConst(out_desc, const_node, graph); - if (ret != SUCCESS) { - GELOGE(FAILED, "Failed insert const node."); - return FAILED; - } - - // Repalce data anchors - for (const auto &anchor_idx: shape_2_out_idx.second) { - if (GraphUtils::ReplaceNodeDataAnchors(const_node, node_to_replace, {}, {anchor_idx}) != GRAPH_SUCCESS) { - GELOGE(FAILED, "[%s] ReplaceNodeAnchors failed.", node_to_replace->GetName().c_str()); - return FAILED; - } - } - - // Copy in control edge - if (GraphUtils::CopyInCtrlEdges(node_to_replace, const_node) != GRAPH_SUCCESS) { - GELOGE(FAILED, "CopyInCtrlEdges from %s to %s failed.", node_to_replace->GetName().c_str(), - const_node->GetName().c_str()); - return FAILED; - } - // Copy out control edge - if (GraphUtils::CopyOutCtrlEdges(node_to_replace, const_node) != GRAPH_SUCCESS) { - GELOGE(FAILED, "CopyOutCtrlEdges from %s to %s failed.", node_to_replace->GetName().c_str(), - const_node->GetName().c_str()); - return FAILED; - } - AddRePassNodesWithInOut(const_node); - GELOGI("Node %s has been replaced by empty const %s.", node_to_replace->GetName().c_str(), - const_node->GetName().c_str()); - } - IsolateAndDeleteNode(node_to_replace, {}); - return SUCCESS; -} -Status ReplaceWithEmptyConstPass::InsertEmptyConst(const GeTensorDesc &out_desc, NodePtr &const_node, - ComputeGraphPtr &graph) { - GeTensorPtr empty_tensor = MakeShared(out_desc); - if (empty_tensor == nullptr) { - GELOGE(OUT_OF_MEMORY, "Failed create empty tensor."); - return OUT_OF_MEMORY; - } - auto const_desc = OpDescUtils::CreateConstOp(empty_tensor); - if (const_desc == nullptr) { - GELOGE(OUT_OF_MEMORY, "Failed to get const desc from tensor"); - return OUT_OF_MEMORY; - } - - const_node = graph->AddNode(const_desc); - if (const_node == nullptr) { - GELOGE(FAILED, "Failed insert const node."); - return FAILED; + GeTensorPtr empty_tensor = MakeShared(out_desc); + GE_CHECK_NOTNULL(empty_tensor); + outputs.emplace_back(empty_tensor); } return SUCCESS; } @@ -144,12 +92,4 @@ bool ReplaceWithEmptyConstPass::IsEmptyTenor(const GeShape &shape) const { } return false; } - -string ReplaceWithEmptyConstPass::GetDimStr(const GeShape &shape) { - std::stringstream dim_str; - for (auto dim : shape.GetDims()) { - dim_str << dim << '-'; - } - return dim_str.str(); -} } // namespace ge diff --git a/ge/graph/passes/replace_with_empty_const_pass.h b/ge/graph/passes/replace_with_empty_const_pass.h index 5083c699..3b2f0a2d 100644 --- a/ge/graph/passes/replace_with_empty_const_pass.h +++ b/ge/graph/passes/replace_with_empty_const_pass.h @@ -17,18 +17,16 @@ #ifndef GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_ #define GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_ -#include "graph/passes/base_pass.h" +#include "graph/passes/folding_pass.h" namespace ge { -class ReplaceWithEmptyConstPass : public BaseNodePass { +class ReplaceWithEmptyConstPass : public FoldingPass { public: Status Run(NodePtr &node) override; private: - Status ReplaceWithEmptyConst(NodePtr &node_to_replace); - Status InsertEmptyConst(const GeTensorDesc &out_desc, NodePtr &const_node, ComputeGraphPtr &graph); + Status GetOutputsOfCurrNode(const NodePtr &node_to_repalce, vector &outputs); bool IsEmptyTenor(const GeShape &shape) const; - std::string GetDimStr(const GeShape &shape); }; } // namespace ge #endif // GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_