Browse Source

rm empty_tensor inputs for merge

tags/v1.1.0
chenyemeng 4 years ago
parent
commit
eb058ac969
2 changed files with 28 additions and 0 deletions
  1. +27
    -0
      ge/graph/passes/merge_pass.cc
  2. +1
    -0
      ge/graph/passes/merge_pass.h

+ 27
- 0
ge/graph/passes/merge_pass.cc View File

@@ -34,6 +34,11 @@ using domi::SUCCESS;
namespace ge { namespace ge {
const int kValueIndexOutputIndex = 1; const int kValueIndexOutputIndex = 1;


bool IsEmptyTensor(const GeShape &shpae) {
const auto &dims = shape.GetDims();
return std::any_of(dims.begin(), dims.end(), [](int64_t dim) { return dim == 0; });
}

Status MergePass::Run(NodePtr &node) { Status MergePass::Run(NodePtr &node) {
GELOGD("MergePass running"); GELOGD("MergePass running");
if (node == nullptr) { if (node == nullptr) {
@@ -53,6 +58,11 @@ Status MergePass::Run(NodePtr &node) {
return PARAM_INVALID; return PARAM_INVALID;
} }


if (OptimizeEmptyTensorInput(node) != SUCCESS) {
GELOGE(FAILED, "[%s] remove empty_tensor inputs failed.", node->GetName().c_str());
return FAILED;
}

auto in_data_nodes = node->GetInDataNodes(); auto in_data_nodes = node->GetInDataNodes();
switch (in_data_nodes.size()) { switch (in_data_nodes.size()) {
case 0: { case 0: {
@@ -202,4 +212,21 @@ bool MergePass::IsMergeInputNeedOptimized(NodePtr &node) const {
} }
return true; return true;
} }

Status MergePass::OptimizeEmptyTensorInput(const NodePtr &node) const {
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_data_anchor == nullptr) {
continue;
}
const auto &op_desc = peer_data_anchor->GetOwnerNode()->GetOpDesc();
if (op_desc == nullptr) {
continue;
}
if (IsEmptyTensor(op_desc->GetOutputDesc(peer_data_anchor->GetIdx()).GetShape())) {
return GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) == GRAPH_SUCCESS ? SUCCESS : FAILED;
}
}
return SUCCESS;
}
} // namespace ge } // namespace ge

+ 1
- 0
ge/graph/passes/merge_pass.h View File

@@ -29,6 +29,7 @@ class MergePass : public BaseNodePass {
Status ChangeIndexToConstant(NodePtr &node, int &value_index); Status ChangeIndexToConstant(NodePtr &node, int &value_index);
Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc); Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc);
bool IsMergeInputNeedOptimized(NodePtr &node) const; bool IsMergeInputNeedOptimized(NodePtr &node) const;
static Status OptimizeEmptyTensorInput(const NodePtr &node) const;
}; };
} // namespace ge } // namespace ge
#endif // GE_GRAPH_PASSES_MERGE_PASS_H_ #endif // GE_GRAPH_PASSES_MERGE_PASS_H_

Loading…
Cancel
Save