Browse Source

Bugfix: fix transpose fusion with input&output format check

tags/v1.1.0
zhaoxinxin 4 years ago
parent
commit
d42dabcd43
2 changed files with 17 additions and 0 deletions
  1. +8
    -0
      ge/graph/passes/transop_without_reshape_fusion_pass.cc
  2. +9
    -0
      ge/graph/passes/transpose_transdata_pass.cc

+ 8
- 0
ge/graph/passes/transop_without_reshape_fusion_pass.cc View File

@@ -131,6 +131,14 @@ graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphNodesInfo() {
sub_graph_has_reshape_node[i] = true;
break;
}
if (in_node->GetType() == TRANSOPSE || in_node->GetType() == TRANSOPSED) {
auto input_format = in_node->GetOpDesc()->GetInputDescPtr(0)->GetFormat();
auto output_format = in_node->GetOpDesc()->GetOutputDescPtr(0)->GetFormat();
if (input_format == output_format) {
sub_graph_has_reshape_node[i] = true;
break;
}
}

auto out_anchor = iter->first;
GE_CHECK_NOTNULL(out_anchor);


+ 9
- 0
ge/graph/passes/transpose_transdata_pass.cc View File

@@ -46,6 +46,15 @@ Status TransposeTransDataPass::Run(NodePtr &node) {
if (op_desc->GetType() != TRANSPOSED) {
return SUCCESS;
}
auto input_format = op_desc->GetInputDescPtr(0)->GetFormat();
auto output_format = op_desc->GetOutputDescPtr(0)->GetFormat();
if (intput_format == output_format) {
GELOGW("Node %s input format is %s, output format is %s, should not happend. Ignore pass.",
op_desc->GetName().c_str(),
TypeUtils::FormatToSerialString(input_format).c_str(),
TypeUtils::FormatToSerialString(output_format).c_str());
return SUCCESS;
}
if (CheckOneInAndOneOutDataAnchor(node) != SUCCESS) {
return FAILED;
}


Loading…
Cancel
Save