From 0c2d07eb7250e5cad532a906691f48b4dd48b552 Mon Sep 17 00:00:00 2001 From: zhaozhixuan Date: Tue, 15 Jun 2021 21:44:09 +0800 Subject: [PATCH] Fix ut. --- ge/hybrid/executor/subgraph_executor.cc | 9 ++++++--- ge/hybrid/executor/worker/shape_inference_engine.cc | 2 +- tests/ut/ge/single_op/single_op_model_unittest.cc | 1 + 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ge/hybrid/executor/subgraph_executor.cc b/ge/hybrid/executor/subgraph_executor.cc index 612e7565..4f0566b4 100644 --- a/ge/hybrid/executor/subgraph_executor.cc +++ b/ge/hybrid/executor/subgraph_executor.cc @@ -100,9 +100,12 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vectorGetOrCreateNodeState(input_node); - GE_CHECK_NOTNULL(node_state); - node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc); + auto op_desc = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto output_desc = op_desc->MutableOutputDesc(kDataInputIndex); + output_desc.SetShape(tensor_desc->GetShape()); + output_desc.SetOriginShape(tensor_desc->GetOriginShape()); + output_desc.SetDataType(tensor_desc->GetDataType()); } } diff --git a/ge/hybrid/executor/worker/shape_inference_engine.cc b/ge/hybrid/executor/worker/shape_inference_engine.cc index a2efbb25..4dc5b79c 100755 --- a/ge/hybrid/executor/worker/shape_inference_engine.cc +++ b/ge/hybrid/executor/worker/shape_inference_engine.cc @@ -69,7 +69,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { // Do shape inference GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str()); - { + if (node_state.GetType() != DATA_TYPE && node_state.GetType() != AIPP_DATA_TYPE) { RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), "[Invoke][InferShapeAndType] for %s failed.", node_item.NodeName().c_str()); diff --git a/tests/ut/ge/single_op/single_op_model_unittest.cc b/tests/ut/ge/single_op/single_op_model_unittest.cc index fb772c33..cb0b497d 100644 --- a/tests/ut/ge/single_op/single_op_model_unittest.cc +++ b/tests/ut/ge/single_op/single_op_model_unittest.cc @@ -30,6 +30,7 @@ #include "single_op/single_op.h" #include "single_op/stream_resource.h" #include "graph/passes/graph_builder_utils.h" +#include "graph/op_desc_impl.h" #undef private #undef protected