@@ -100,13 +100,16 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vector<TensorValue | |||||
GE_CHECK_LE(i + 1, input_desc.size()); | GE_CHECK_LE(i + 1, input_desc.size()); | ||||
const auto &tensor_desc = input_desc[i]; | const auto &tensor_desc = input_desc[i]; | ||||
GE_CHECK_NOTNULL(tensor_desc); | GE_CHECK_NOTNULL(tensor_desc); | ||||
auto node_state = subgraph_context_->GetOrCreateNodeState(input_node); | |||||
GE_CHECK_NOTNULL(node_state); | |||||
node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc); | |||||
auto op_desc = input_node->GetOpDesc(); | auto op_desc = input_node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
auto output_desc = op_desc->MutableOutputDesc(kDataInputIndex); | auto output_desc = op_desc->MutableOutputDesc(kDataInputIndex); | ||||
GE_CHECK_NOTNULL(output_desc); | GE_CHECK_NOTNULL(output_desc); | ||||
output_desc->SetShape(tensor_desc->GetShape()); | output_desc->SetShape(tensor_desc->GetShape()); | ||||
output_desc->SetOriginShape(tensor_desc->GetOriginShape()); | output_desc->SetOriginShape(tensor_desc->GetOriginShape()); | ||||
output_desc->SetDataType(tensor_desc->GetDataType()); | |||||
output_desc->SetDataType(tensor_desc->GetDataType()); | |||||
} | } | ||||
} | } | ||||
@@ -87,7 +87,7 @@ TEST_F(UtestHybridModelAsyncExecutor, BuildDeviceTensor) { | |||||
ASSERT_EQ(size, 100); | ASSERT_EQ(size, 100); | ||||
} | } | ||||
TEST_F(UtestHybridModelAsyncExecutor, Test_execute_internal) { | |||||
TEST_F(UtestHybridModelAsyncExecutor, Test_execute) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ||||
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | ||||
ge_root_model->SetModelName("test_name"); | ge_root_model->SetModelName("test_name"); | ||||
@@ -101,6 +101,6 @@ TEST_F(UtestHybridModelAsyncExecutor, Test_execute_internal) { | |||||
std::pair<rtEvent_t, std::pair<rtCallback_t, void *>> eof_entry; | std::pair<rtEvent_t, std::pair<rtCallback_t, void *>> eof_entry; | ||||
eof_entry.first = nullptr; | eof_entry.first = nullptr; | ||||
context.callback_manager->callback_queue_.Push(eof_entry); | context.callback_manager->callback_queue_.Push(eof_entry); | ||||
ASSERT_EQ(executor.ExecuteGraphInternal(args), SUCCESS); | |||||
ASSERT_EQ(executor.Execute(args), SUCCESS); | |||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -224,7 +224,6 @@ TEST_F(UtestSingleOpModel, test_build_dynamic_op) { | |||||
model.model_helper_.model_->SetGraph(graph); | model.model_helper_.model_->SetGraph(graph); | ||||
auto op_desc = transdata->GetOpDesc(); | auto op_desc = transdata->GetOpDesc(); | ||||
op_desc->impl_->input_name_idx_["Data"] = 0; | |||||
const vector<string> depend_names = { "Data" }; | const vector<string> depend_names = { "Data" }; | ||||
op_desc->SetOpInferDepends(depend_names); | op_desc->SetOpInferDepends(depend_names); | ||||
(void)AttrUtils::SetBool(op_desc, kAttrSupportDynamicShape, true); | (void)AttrUtils::SetBool(op_desc, kAttrSupportDynamicShape, true); | ||||
@@ -247,6 +246,9 @@ TEST_F(UtestSingleOpModel, test_build_dynamic_op) { | |||||
DynamicSingleOp dynamic_single_op(0, &stream_mu_, nullptr); | DynamicSingleOp dynamic_single_op(0, &stream_mu_, nullptr); | ||||
StreamResource res((uintptr_t)1); | StreamResource res((uintptr_t)1); | ||||
model.BuildDynamicOp(res, dynamic_single_op); | model.BuildDynamicOp(res, dynamic_single_op); | ||||
op_desc->impl_->input_name_idx_["Data"] = 0; | |||||
model.BuildDynamicOp(res, dynamic_single_op); | |||||
} | } | ||||
TEST_F(UtestSingleOpModel, test_host_mem) { | TEST_F(UtestSingleOpModel, test_host_mem) { | ||||