Browse Source

!432 rm empty_tensor input for merge node

From: @chen_yemeng
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
13ff4ac8c0
2 changed files with 37 additions and 0 deletions
  1. +36
    -0
      ge/graph/passes/merge_pass.cc
  2. +1
    -0
      ge/graph/passes/merge_pass.h

+ 36
- 0
ge/graph/passes/merge_pass.cc View File

@@ -34,6 +34,11 @@ using domi::SUCCESS;
namespace ge { namespace ge {
const int kValueIndexOutputIndex = 1; const int kValueIndexOutputIndex = 1;


bool IsEmptyTensor(const GeShape &shape) {
const auto &dims = shape.GetDims();
return std::any_of(dims.begin(), dims.end(), [](int64_t dim) { return dim == 0; });
}

Status MergePass::Run(NodePtr &node) { Status MergePass::Run(NodePtr &node) {
GELOGD("MergePass running"); GELOGD("MergePass running");
if (node == nullptr) { if (node == nullptr) {
@@ -53,6 +58,11 @@ Status MergePass::Run(NodePtr &node) {
return PARAM_INVALID; return PARAM_INVALID;
} }


if (OptimizeEmptyTensorInput(node) != SUCCESS) {
GELOGE(FAILED, "[%s] remove empty_tensor inputs failed.", node->GetName().c_str());
return FAILED;
}

auto in_data_nodes = node->GetInDataNodes(); auto in_data_nodes = node->GetInDataNodes();
switch (in_data_nodes.size()) { switch (in_data_nodes.size()) {
case 0: { case 0: {
@@ -202,4 +212,30 @@ bool MergePass::IsMergeInputNeedOptimized(NodePtr &node) const {
} }
return true; return true;
} }

Status MergePass::OptimizeEmptyTensorInput(const NodePtr &node) {
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_data_anchor == nullptr) {
continue;
}
if ((peer_data_anchor->GetOwnerNode() == nullptr) ||
(peer_data_anchor->GetOwnerNode()->GetOpDesc() == nullptr)) {
continue;
}
const auto &op_desc = peer_data_anchor->GetOwnerNode()->GetOpDesc();
if (IsEmptyTensor(op_desc->GetOutputDesc(peer_data_anchor->GetIdx()).GetShape())) {
if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.",
op_desc->GetName().c_str(), peer_data_anchor->GetIdx(),
node->GetName().c_str(), in_data_anchor->GetIdx());
return FAILED;
}
GELOGD("Remove data edge %s:%d->%s:%d",
op_desc->GetName().c_str(), peer_data_anchor->GetIdx(),
node->GetName().c_str(), in_data_anchor->GetIdx());
}
}
return SUCCESS;
}
} // namespace ge } // namespace ge

+ 1
- 0
ge/graph/passes/merge_pass.h View File

@@ -29,6 +29,7 @@ class MergePass : public BaseNodePass {
Status ChangeIndexToConstant(NodePtr &node, int &value_index); Status ChangeIndexToConstant(NodePtr &node, int &value_index);
Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc); Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc);
bool IsMergeInputNeedOptimized(NodePtr &node) const; bool IsMergeInputNeedOptimized(NodePtr &node) const;
static Status OptimizeEmptyTensorInput(const NodePtr &node);
}; };
} // namespace ge } // namespace ge
#endif // GE_GRAPH_PASSES_MERGE_PASS_H_ #endif // GE_GRAPH_PASSES_MERGE_PASS_H_

Loading…
Cancel
Save