Browse Source

回退 'Pull Request !2037 : Bug: Fix pytorch infershape origin shape wrong'

revert-merge-2037-master
王涛 Gitee 3 years ago
parent
commit
1ac4bcb033
1 changed files with 10 additions and 4 deletions
  1. +10
    -4
      ge/graph/passes/infershape_pass.cc

+ 10
- 4
ge/graph/passes/infershape_pass.cc View File

@@ -228,13 +228,19 @@ bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDe
}
graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
changed = false;
if (SameTensorDesc(src, dst)) {
changed = !SameTensorDesc(src, dst);
// refresh src itself
src->SetOriginShape(src->GetShape());
src->SetOriginDataType(src->GetDataType());
TensorUtils::SetRealDimCnt(*src, static_cast<uint32_t>(src->GetOriginShape().GetDims().size()));
vector<pair<int64_t, int64_t>> src_shape_range;
src->GetShapeRange(src_shape_range);
src->SetOriginShapeRange(src_shape_range);
if (!changed) {
GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update.");
return SUCCESS;
}
changed = true;
UpdateShapeAndDType(src, dst);
GELOGD(
"UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s."


Loading…
Cancel
Save