diff --git a/ge/graph/passes/merge_pass.cc b/ge/graph/passes/merge_pass.cc index d2340037..c6bae2a2 100644 --- a/ge/graph/passes/merge_pass.cc +++ b/ge/graph/passes/merge_pass.cc @@ -34,6 +34,11 @@ using domi::SUCCESS; namespace ge { const int kValueIndexOutputIndex = 1; +bool IsEmptyTensor(const GeShape &shpae) { + const auto &dims = shape.GetDims(); + return std::any_of(dims.begin(), dims.end(), [](int64_t dim) { return dim == 0; }); +} + Status MergePass::Run(NodePtr &node) { GELOGD("MergePass running"); if (node == nullptr) { @@ -53,6 +58,11 @@ Status MergePass::Run(NodePtr &node) { return PARAM_INVALID; } + if (OptimizeEmptyTensorInput(node) != SUCCESS) { + GELOGE(FAILED, "[%s] remove empty_tensor inputs failed.", node->GetName().c_str()); + return FAILED; + } + auto in_data_nodes = node->GetInDataNodes(); switch (in_data_nodes.size()) { case 0: { @@ -202,4 +212,21 @@ bool MergePass::IsMergeInputNeedOptimized(NodePtr &node) const { } return true; } + +Status MergePass::OptimizeEmptyTensorInput(const NodePtr &node) const { + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_data_anchor == nullptr) { + continue; + } + const auto &op_desc = peer_data_anchor->GetOwnerNode()->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + if (IsEmptyTensor(op_desc->GetOutputDesc(peer_data_anchor->GetIdx()).GetShape())) { + return GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) == GRAPH_SUCCESS ? SUCCESS : FAILED; + } + } + return SUCCESS; +} } // namespace ge diff --git a/ge/graph/passes/merge_pass.h b/ge/graph/passes/merge_pass.h index 2cdb5022..c297a86e 100755 --- a/ge/graph/passes/merge_pass.h +++ b/ge/graph/passes/merge_pass.h @@ -29,6 +29,7 @@ class MergePass : public BaseNodePass { Status ChangeIndexToConstant(NodePtr &node, int &value_index); Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc); bool IsMergeInputNeedOptimized(NodePtr &node) const; + static Status OptimizeEmptyTensorInput(const NodePtr &node) const; }; } // namespace ge #endif // GE_GRAPH_PASSES_MERGE_PASS_H_