diff --git a/ge/graph/passes/transop_without_reshape_fusion_pass.cc b/ge/graph/passes/transop_without_reshape_fusion_pass.cc index c1eaf0f9..d2b3f1b1 100644 --- a/ge/graph/passes/transop_without_reshape_fusion_pass.cc +++ b/ge/graph/passes/transop_without_reshape_fusion_pass.cc @@ -130,6 +130,14 @@ graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphNodesInfo() { sub_graph_has_reshape_node[i] = true; break; } + if (in_node->GetType() == TRANSPOSE || in_node->GetType() == TRANSPOSED) { + auto input_format = in_node->GetOpDesc()->GetInputDescPtr(0)->GetFormat(); + auto output_format = in_node->GetOpDesc()->GetOutputDescPtr(0)->GetFormat(); + if (input_format == output_format) { + sub_graph_has_reshape_node[i] = true; + break; + } + } auto out_anchor = iter->first; GE_CHECK_NOTNULL(out_anchor); diff --git a/ge/graph/passes/transpose_transdata_pass.cc b/ge/graph/passes/transpose_transdata_pass.cc index 19bff563..7348f143 100644 --- a/ge/graph/passes/transpose_transdata_pass.cc +++ b/ge/graph/passes/transpose_transdata_pass.cc @@ -46,6 +46,15 @@ Status TransposeTransDataPass::Run(NodePtr &node) { if (op_desc->GetType() != TRANSPOSED) { return SUCCESS; } + auto input_format = op_desc->GetInputDescPtr(0)->GetFormat(); + auto output_format = op_desc->GetOutputDescPtr(0)->GetFormat(); + if (input_format == output_format) { + GELOGW("Node %s input format is %s, output format is %s, should not happend. Ignore pass.", + op_desc->GetName().c_str(), + TypeUtils::FormatToSerialString(input_format).c_str(), + TypeUtils::FormatToSerialString(output_format).c_str()); + return SUCCESS; + } if (CheckOneInAndOneOutDataAnchor(node) != SUCCESS) { return FAILED; }