From 1ac4bcb03380939a3aae8f2f556aaf32a2741890 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=B6=9B?= Date: Thu, 22 Jul 2021 15:12:57 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20!2037=20?= =?UTF-8?q?:=20Bug:=20Fix=20pytorch=20infershape=20origin=20shape=20wrong'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ge/graph/passes/infershape_pass.cc | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index 0555929d..05b1b5fc 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -228,13 +228,19 @@ bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDe } graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { - changed = false; - if (SameTensorDesc(src, dst)) { + changed = !SameTensorDesc(src, dst); + // refresh src itself + src->SetOriginShape(src->GetShape()); + src->SetOriginDataType(src->GetDataType()); + TensorUtils::SetRealDimCnt(*src, static_cast(src->GetOriginShape().GetDims().size())); + vector> src_shape_range; + src->GetShapeRange(src_shape_range); + src->SetOriginShapeRange(src_shape_range); + + if (!changed) { GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update."); return SUCCESS; } - - changed = true; UpdateShapeAndDType(src, dst); GELOGD( "UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s."