@@ -100,9 +100,13 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vector<TensorValue | |||||
GE_CHECK_LE(i + 1, input_desc.size()); | GE_CHECK_LE(i + 1, input_desc.size()); | ||||
const auto &tensor_desc = input_desc[i]; | const auto &tensor_desc = input_desc[i]; | ||||
GE_CHECK_NOTNULL(tensor_desc); | GE_CHECK_NOTNULL(tensor_desc); | ||||
auto node_state = subgraph_context_->GetOrCreateNodeState(input_node); | |||||
GE_CHECK_NOTNULL(node_state); | |||||
node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc); | |||||
auto op_desc = input_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
auto output_desc = op_desc->MutableOutputDesc(kDataInputIndex); | |||||
GE_CHECK_NOTNULL(output_desc); | |||||
output_desc->SetShape(tensor_desc->GetShape()); | |||||
output_desc->SetOriginShape(tensor_desc->GetOriginShape()); | |||||
output_desc->SetDataType(tensor_desc->GetDataType()); | |||||
} | } | ||||
} | } | ||||
@@ -68,8 +68,9 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { | |||||
} | } | ||||
// Do shape inference | // Do shape inference | ||||
// Skipping infer shape of input node. | |||||
GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str()); | GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str()); | ||||
{ | |||||
if (node_state.GetType() != DATA_TYPE && node_state.GetType() != AIPP_DATA_TYPE) { | |||||
RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); | RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); | ||||
GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), | GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), | ||||
"[Invoke][InferShapeAndType] for %s failed.", node_item.NodeName().c_str()); | "[Invoke][InferShapeAndType] for %s failed.", node_item.NodeName().c_str()); | ||||