From 0b4313e8a5d0fc5c021f2d627e541b82f91a7750 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Sun, 25 Apr 2021 17:05:49 +0800 Subject: [PATCH] add ut001 --- ge/graph/passes/infershape_pass.cc | 2 + ge/hybrid/model/node_item.h | 3 +- tests/ut/ge/CMakeLists.txt | 38 ++-- tests/ut/ge/graph/ge_executor_unittest.cc | 2 +- .../ge/graph/load/kernel_ex_task_info_unittest.cc | 11 +- tests/ut/ge/graph/load/model_manager_unittest.cc | 2 +- .../hybrid/executor/subgraph_executor_unittest.cc | 249 +++++++++++++++++++++ .../executor/worker/execution_engine_unittest.cc | 18 +- 8 files changed, 297 insertions(+), 28 deletions(-) create mode 100644 tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index 7181d824..acd240a5 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -145,8 +145,10 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { if (node->GetType() == MERGE || node->GetType() == REFMERGE) { if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { + node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); return RePassNode({SWITCH, REFSWITCH}); // Re-Pass Switch } + return SUCCESS; } if (node->GetType() == SWITCH || node->GetType() == REFSWITCH) { diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index 5c967920..606e58fe 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -41,8 +41,9 @@ bool IsControlFlowV2Op(const std::string &op_type); class OptionalMutexGuard { public: - OptionalMutexGuard(std::mutex *mutex, const string &name); + OptionalMutexGuard(std::mutex *mutex, const std::string &name); ~OptionalMutexGuard(); + private: std::mutex *mu_{nullptr}; std::string name_; diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 51b12514..2e28f1f2 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -816,6 +816,7 @@ set(PROFILING_MNG_TEST_FILES set(HYBRID_TEST_FILES "hybrid/ge_hybrid_unittest.cc" "hybrid/known_node_executor_unittest.cc" + "hybrid/executor/subgraph_executor_unittest.cc" "hybrid/executor/worker/execution_engine_unittest.cc" "hybrid/model/hybrid_model_builder_unittest.cc" "hybrid/node_executor/rts/rts_node_task_unittest.cc" @@ -834,6 +835,8 @@ list(APPEND COMMON_SHARED_LIBRARIES mmpa_stub hccl_stub error_manager_stub + ascend_protobuf + json ) # build graph @@ -879,7 +882,7 @@ target_link_libraries(ge_ut_common PRIVATE ) # build common format -add_library(ge_ut_common_format STATIC ${COMMON_SRC_FILES} ${COMMON_FORMAT_SRC_FILES} ${PROTO_HDRS}) +add_library(ge_ut_common_format STATIC ${COMMON_FORMAT_SRC_FILES} ${PROTO_HDRS}) target_compile_definitions(ge_ut_common_format PRIVATE google=ascend_private @@ -1056,7 +1059,6 @@ target_link_libraries(ge_single_op PRIVATE # libge_mutiparts_utest add_executable(ut_libge_multiparts_utest ${COMMON_TEST_FILES} - ${COMMON_FORMAT_SRC_FILES} ${MULTI_PARTS_TEST_FILES} ) @@ -1071,14 +1073,14 @@ target_compile_definitions(ut_libge_multiparts_utest PRIVATE target_link_libraries(ut_libge_multiparts_utest $ - ge_build_common ge_load_common ge_execute_common ge_optimize_common ge_partition_common ge_prepare_common ge_single_op ge_ut_common - gtest gtest_main gmock gmock_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov + ge_build_common ge_load_common ge_execute_common ge_optimize_common ge_partition_common ge_prepare_common + ge_single_op ge_ut_common_format ge_ut_common + gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lgcov ) # libge_others_utest add_executable(ut_libge_others_utest ${COMMON_TEST_FILES} - ${COMMON_FORMAT_SRC_FILES} ${PASS_TEST_FILES} ${EXECUTE_TEST_FILES} ${OTHERS_TEST_FILES} @@ -1091,16 +1093,15 @@ target_compile_options(ut_libge_others_utest PRIVATE target_link_libraries(ut_libge_others_utest $ - ge_load_common ge_execute_common ge_ut_common - gtest gtest_main gmock gmock_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov + ge_load_common ge_execute_common ge_ut_common ge_ut_common_format + gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lgcov ) # libge_kernel_utest add_executable(ut_libge_kernel_utest - ${COMMON_TEST_FILES} - ${COMMON_FORMAT_SRC_FILES} - ${KERNEL_TEST_FILES} - ${KERNEL_SRC_FILES} + ${COMMON_TEST_FILES} + ${KERNEL_TEST_FILES} + ${KERNEL_SRC_FILES} ) target_compile_options(ut_libge_kernel_utest PRIVATE @@ -1110,8 +1111,8 @@ target_compile_options(ut_libge_kernel_utest PRIVATE target_link_libraries(ut_libge_kernel_utest $ - ge_load_common ge_ut_common - gtest gtest_main gmock gmock_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov + ge_load_common ge_ut_common ge_ut_common_format + gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lgcov ) # libge_distinct_load_utest @@ -1137,10 +1138,11 @@ target_compile_definitions(ut_libge_distinct_load_utest PRIVATE ) target_link_libraries(ut_libge_distinct_load_utest - ${COMMON_SHARED_LIBRARIES} $ - ge_execute_common ge_ut_common_format ge_load_common - ge_single_op ge_prepare_common - ge_optimize_common ge_build_common ge_partition_common ge_ut_common - gtest gtest_main gmock gmock_main ascend_protobuf json c_sec -lrt -ldl -lpthread -lgcov + -Wl,--whole-archive + ge_single_op + -Wl,--no-whole-archive + ge_execute_common ge_load_common + ge_prepare_common ge_optimize_common ge_build_common ge_partition_common ge_ut_common ge_ut_common_format + gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lpthread -lgcov ) diff --git a/tests/ut/ge/graph/ge_executor_unittest.cc b/tests/ut/ge/graph/ge_executor_unittest.cc index e26aa86e..13293eea 100644 --- a/tests/ut/ge/graph/ge_executor_unittest.cc +++ b/tests/ut/ge/graph/ge_executor_unittest.cc @@ -115,7 +115,7 @@ TEST_F(UtestGeExecutor, load_data_from_file) { string test_smap = "/tmp/" + std::to_string(getpid()) + "_maps"; string self_smap = "/proc/" + std::to_string(getpid()) + "/maps"; - string copy_smap = "cp " + self_smap + " " + test_smap; + string copy_smap = "cp -f " + self_smap + " " + test_smap; EXPECT_EQ(system(copy_smap.c_str()), 0); ModelData model_data; diff --git a/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc b/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc index 44d4d042..63202a28 100644 --- a/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc +++ b/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc @@ -91,8 +91,8 @@ TEST_F(UtestKernelExTaskInfo, success_kernel_ex_task_release) { // test kernel_ex_task_Release TEST_F(UtestKernelExTaskInfo, success_kernel_ex_task_info_copy) { DavinciModel model(0, nullptr); - model.runtime_param_.mem_base = (uint8_t *)0x12345; - model.runtime_param_.mem_size = 100332000; + model.runtime_param_.mem_size = 10240; + model.runtime_param_.mem_base = new uint8_t[model.runtime_param_.mem_size]; rtStream_t stream = nullptr; rtStreamCreate(&stream, 0); @@ -108,19 +108,20 @@ TEST_F(UtestKernelExTaskInfo, success_kernel_ex_task_info_copy) { EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace empty. - model.op_list_[0]->SetWorkspace({100331008}); // offset + model.op_list_[0]->SetWorkspace({1008}); // offset model.op_list_[0]->SetWorkspaceBytes({0}); // length EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace addr is null. - model.op_list_[0]->SetWorkspace({100331008}); // offset + model.op_list_[0]->SetWorkspace({1208}); // offset model.op_list_[0]->SetWorkspaceBytes({10}); // length EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace addr is small. - model.op_list_[0]->SetWorkspace({100331008}); // offset + model.op_list_[0]->SetWorkspace({1308}); // offset model.op_list_[0]->SetWorkspaceBytes({150}); // length EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), SUCCESS); task_def.clear_kernel_ex(); + delete [] model.runtime_param_.mem_base; model.runtime_param_.mem_base = nullptr; } diff --git a/tests/ut/ge/graph/load/model_manager_unittest.cc b/tests/ut/ge/graph/load/model_manager_unittest.cc index 342f6362..83d694d4 100644 --- a/tests/ut/ge/graph/load/model_manager_unittest.cc +++ b/tests/ut/ge/graph/load/model_manager_unittest.cc @@ -418,6 +418,6 @@ TEST_F(UtestModelManagerModelManager, test_data_input_tensor) { vector inputs; inputs.emplace_back(input_tensor); auto ret = mm.DataInputTensor(model_id,inputs); - EXPECT_EQ(UNSUPPORTED, ret); + EXPECT_EQ(PARAM_INVALID, ret); // HybridDavinciModel::impl_ is null. } } // namespace ge diff --git a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc new file mode 100644 index 00000000..5e9aa0e8 --- /dev/null +++ b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc @@ -0,0 +1,249 @@ +/** + * Copyright 2019-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#define private public +#define protected public +#include "hybrid/executor/subgraph_executor.h" +#include "hybrid/node_executor/node_executor.h" +#include "hybrid/node_executor/rts/rts_node_executor.h" +#include "hybrid/node_executor/ge_local/ge_local_node_executor.h" +#include "hybrid/model/hybrid_model_builder.h" +#include "graph/utils/graph_utils.h" + +using namespace std; +using namespace testing; + +namespace ge { +using namespace hybrid; + +class UtestSubgraphExecutor : public testing::Test { + protected: + void SetUp() { + NodeExecutorManager::GetInstance().engine_mapping_.clear(); + auto &engine_mapping = NodeExecutorManager::GetInstance().engine_mapping_; + engine_mapping.emplace("DNN_VM_RTS_OP_STORE", NodeExecutorManager::ExecutorType::RTS); + engine_mapping.emplace("DNN_VM_GE_LOCAL_OP_STORE", NodeExecutorManager::ExecutorType::GE_LOCAL); + + NodeExecutorManager::GetInstance().executors_.clear(); + auto &task_executor = NodeExecutorManager::GetInstance().executors_; + task_executor.emplace(NodeExecutorManager::ExecutorType::RTS, std::unique_ptr(new RtsNodeExecutor())); + task_executor.emplace(NodeExecutorManager::ExecutorType::GE_LOCAL, std::unique_ptr(new GeLocalNodeExecutor())); + } + void TearDown() { + NodeExecutorManager::GetInstance().engine_mapping_.clear(); + NodeExecutorManager::GetInstance().executors_.clear(); + } +}; + +static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { + OpDescPtr op_desc = std::make_shared(name, type); + op_desc->SetStreamId(0); + static int32_t index = 0; + op_desc->SetId(index++); + + GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); + TensorUtils::SetSize(tensor, 64); + vector input_offset; + for (int i = 0; i < in_num; i++) { + op_desc->AddInputDesc(tensor); + input_offset.emplace_back(index * 64 + i * 64); + } + op_desc->SetInputOffset(input_offset); + + vector output_offset; + for (int i = 0; i < out_num; i++) { + op_desc->AddOutputDesc(tensor); + output_offset.emplace_back(index * 64 + in_num * 64 + i * 64); + } + op_desc->SetOutputOffset(output_offset); + + op_desc->SetWorkspace({}); + op_desc->SetWorkspaceBytes({}); + op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); + + return graph.AddNode(op_desc); +} + +static void CreateSimpleCondGraph(ComputeGraph &graph) { +/******************************************************************************* + * | + * Merge + * / \. + * / \. + * / \. + * Add Sub + * | \ / | + * | \ _ / | + * | / \ | + * | / \ | + * Switch Switch + * | \ / | + * | \ / | + * | \ / | + * | \ / | + * | Less | + * | / \ | + * | / \ | + * Data Data + ******************************************************************************/ + const auto data0 = CreateNode(graph, "data", DATA, 1, 1); + const auto data1 = CreateNode(graph, "data1", DATA, 1, 1); + data0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); + data1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); + + const auto const0 = CreateNode(graph, "const", CONSTANT, 0, 1); + const auto const1 = CreateNode(graph, "const1", CONSTANT, 0, 1); + const0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); + const1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); + { + uint64_t const_value = 0; + const auto op_desc = const0->GetOpDesc(); + auto weight = make_shared(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); + AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); + } + { + uint64_t const_value = 1; + const auto op_desc = const1->GetOpDesc(); + auto weight = make_shared(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); + AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); + } + + const auto less1 = CreateNode(graph, "less", ENTER, 2, 1); + + const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); + const auto switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); + const auto switch_f = CreateNode(graph, "switch_f", STREAMSWITCH, 2, 0); + AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 1); + AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 1); + + const auto add1 = CreateNode(graph, "add", ENTER, 2, 1); + const auto sub1 = CreateNode(graph, "sub", ENTER, 2, 1); + + const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); + const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); + const auto active3 = CreateNode(graph, "active3", STREAMACTIVE, 0, 0); + + const auto output1 = CreateNode(graph, "net_output", NETOUTPUT, 1, 1); + output1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); + + GraphUtils::AddEdge(data0->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); + GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch_t->GetInDataAnchor(0)); + GraphUtils::AddEdge(const0->GetOutDataAnchor(0), switch_t->GetInDataAnchor(1)); + GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch_f->GetInDataAnchor(0)); + GraphUtils::AddEdge(const1->GetOutDataAnchor(0), switch_f->GetInDataAnchor(1)); + + GraphUtils::AddEdge(less1->GetOutControlAnchor(), active1->GetInControlAnchor()); + GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); + GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_f->GetInControlAnchor()); + + GraphUtils::AddEdge(data0->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); + GraphUtils::AddEdge(add1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch_t->GetOutControlAnchor(), add1->GetInControlAnchor()); + GraphUtils::AddEdge(add1->GetOutControlAnchor(), active2->GetInControlAnchor()); + GraphUtils::AddEdge(active2->GetOutControlAnchor(), merge1->GetInControlAnchor()); + + GraphUtils::AddEdge(data0->GetOutDataAnchor(0), sub1->GetInDataAnchor(0)); + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), sub1->GetInDataAnchor(1)); + GraphUtils::AddEdge(sub1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); + GraphUtils::AddEdge(switch_f->GetOutControlAnchor(), sub1->GetInControlAnchor()); + GraphUtils::AddEdge(sub1->GetOutControlAnchor(), active3->GetInControlAnchor()); + GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); + + GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); +} + +TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) { + ComputeGraphPtr graph = std::make_shared("test"); + const auto data0 = CreateNode(*graph, "data", DATA, 1, 1); + const auto output0 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); + GraphUtils::AddEdge(data0->GetOutDataAnchor(0), output0->GetInDataAnchor(0)); + data0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); + output0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); + + GeRootModelPtr ge_root_model = make_shared(graph); + ge_root_model->SetModelName("test_name"); + GeModelPtr ge_sub_model = make_shared(); + ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); + + HybridModel hybrid_model(ge_root_model); + HybridModelBuilder hybrid_model_builder(hybrid_model); + ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); + + GraphExecutionContext graph_context; + graph_context.callback_manager = std::unique_ptr(new CallbackManager()); + graph_context.model = &hybrid_model; + + uint64_t value_0 = 110; + TensorValue in_tensor0(&value_0, sizeof(value_0)); + const std::vector inputs{ in_tensor0 }; + + uint64_t value_1 = 123; + TensorValue out_tensor0(&value_1, sizeof(value_1)); + const std::vector outputs{ out_tensor0 }; + + auto input_desc = output0->GetOpDesc()->GetInputDescPtr(0); + const std::vector input_descs{ input_desc }; + + SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &graph_context); + ASSERT_EQ(executor.ExecuteAsync(inputs, input_descs, outputs), SUCCESS); +} + +TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) { + ComputeGraphPtr graph = std::make_shared("test"); + CreateSimpleCondGraph(*graph); + + GeRootModelPtr ge_root_model = make_shared(graph); + ge_root_model->SetModelName("test_name"); + GeModelPtr ge_sub_model = make_shared(); + ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); + Buffer weights_buffer(1024, 0x76); + ge_sub_model->SetWeight(weights_buffer); + ge_sub_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); + + HybridModel hybrid_model(ge_root_model); + HybridModelBuilder hybrid_model_builder(hybrid_model); + ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); + + GraphExecutionContext graph_context; + graph_context.model = &hybrid_model; + graph_context.allocator = NpuMemoryAllocator::GetAllocator(0); + graph_context.callback_manager = std::unique_ptr(new CallbackManager()); + ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS); + + uint64_t value_0 = 110; + TensorValue in_tensor0(&value_0, sizeof(value_0)); + uint64_t value_1 = 110; + TensorValue in_tensor1(&value_1, sizeof(value_1)); + const std::vector inputs{ in_tensor0, in_tensor1 }; + uint64_t value_2 = 123; + TensorValue out_tensor0(&value_2, sizeof(value_2)); + const std::vector outputs{ out_tensor0 }; + + GeTensorDescPtr tensor_desc = make_shared(GeShape(), FORMAT_ND, DT_INT64); + TensorUtils::SetSize(*tensor_desc, 64); + const std::vector input_desc{ tensor_desc, tensor_desc }; + + SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &graph_context); + ASSERT_EQ(executor.ExecuteAsync(inputs, input_desc, outputs), SUCCESS); + ASSERT_EQ(graph_context.callback_manager->Destroy(), SUCCESS); +} +} // namespace ge \ No newline at end of file diff --git a/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc b/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc index 5fa0d22c..92315448 100644 --- a/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc +++ b/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc @@ -26,6 +26,7 @@ #include "hybrid/executor/hybrid_execution_context.h" #include "hybrid/executor/hybrid_model_executor.h" #include "hybrid/executor/worker/execution_engine.h" +#include "hybrid/executor/subgraph_executor.h" #undef private #undef protected @@ -75,6 +76,10 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) { node_item->output_start = 0; GraphExecutionContext execution_context; + GeRootModelPtr ge_root_model = make_shared(graph); + HybridModel hybrid_model(ge_root_model); + hybrid_model.root_graph_item_ = std::unique_ptr(new(std::nothrow)GraphItem()); + execution_context.model = &hybrid_model; execution_context.profiling_level = 1; SubgraphContext subgraph_context(nullptr, &execution_context); @@ -85,7 +90,11 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) { ExecutionEngine execution_engine; ASSERT_TRUE(node_state.GetTaskContext() != nullptr); - EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context), INTERNAL_ERROR); + + std::function callback; + SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); + executor.InitCallback(&node_state, callback); + EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR); } TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { @@ -105,6 +114,7 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { GraphExecutionContext execution_context; GeRootModelPtr ge_root_model = make_shared(graph); HybridModel hybrid_model(ge_root_model); + hybrid_model.root_graph_item_ = std::unique_ptr(new(std::nothrow)GraphItem()); execution_context.model = &hybrid_model; SubgraphContext subgraph_context(nullptr, &execution_context); @@ -115,5 +125,9 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { ExecutionEngine execution_engine; ASSERT_TRUE(node_state.GetTaskContext() != nullptr); - EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context), INTERNAL_ERROR); + + std::function callback; + SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); + executor.InitCallback(&node_state, callback); + EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR); }