diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index 0555929d..05b1b5fc 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -228,13 +228,19 @@ bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDe } graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { - changed = false; - if (SameTensorDesc(src, dst)) { + changed = !SameTensorDesc(src, dst); + // refresh src itself + src->SetOriginShape(src->GetShape()); + src->SetOriginDataType(src->GetDataType()); + TensorUtils::SetRealDimCnt(*src, static_cast(src->GetOriginShape().GetDims().size())); + vector> src_shape_range; + src->GetShapeRange(src_shape_range); + src->SetOriginShapeRange(src_shape_range); + + if (!changed) { GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update."); return SUCCESS; } - - changed = true; UpdateShapeAndDType(src, dst); GELOGD( "UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s."