
add ut001

Tag: tags/v1.3.0
Author: zhangxiaokun (4 years ago)
Commit: 0b4313e8a5
8 changed files with 297 additions and 28 deletions:

1. ge/graph/passes/infershape_pass.cc (+2, -0)
2. ge/hybrid/model/node_item.h (+2, -1)
3. tests/ut/ge/CMakeLists.txt (+20, -18)
4. tests/ut/ge/graph/ge_executor_unittest.cc (+1, -1)
5. tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc (+6, -5)
6. tests/ut/ge/graph/load/model_manager_unittest.cc (+1, -1)
7. tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc (+249, -0)
8. tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc (+16, -2)

ge/graph/passes/infershape_pass.cc (+2, -0)

@@ -145,8 +145,10 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) {

if (node->GetType() == MERGE || node->GetType() == REFMERGE) {
if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
+ node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
return RePassNode({SWITCH, REFSWITCH}); // Re-Pass Switch
}
+ return SUCCESS;
}

if (node->GetType() == SWITCH || node->GetType() == REFSWITCH) {


ge/hybrid/model/node_item.h (+2, -1)

@@ -41,8 +41,9 @@ bool IsControlFlowV2Op(const std::string &op_type);

class OptionalMutexGuard {
public:
- OptionalMutexGuard(std::mutex *mutex, const string &name);
+ OptionalMutexGuard(std::mutex *mutex, const std::string &name);
~OptionalMutexGuard();

private:
std::mutex *mu_{nullptr};
std::string name_;


tests/ut/ge/CMakeLists.txt (+20, -18)

@@ -816,6 +816,7 @@ set(PROFILING_MNG_TEST_FILES
set(HYBRID_TEST_FILES
"hybrid/ge_hybrid_unittest.cc"
"hybrid/known_node_executor_unittest.cc"
"hybrid/executor/subgraph_executor_unittest.cc"
"hybrid/executor/worker/execution_engine_unittest.cc"
"hybrid/model/hybrid_model_builder_unittest.cc"
"hybrid/node_executor/rts/rts_node_task_unittest.cc"
@@ -834,6 +835,8 @@ list(APPEND COMMON_SHARED_LIBRARIES
mmpa_stub
hccl_stub
error_manager_stub
+ ascend_protobuf
+ json
)

# build graph
@@ -879,7 +882,7 @@ target_link_libraries(ge_ut_common PRIVATE
)

# build common format
- add_library(ge_ut_common_format STATIC ${COMMON_SRC_FILES} ${COMMON_FORMAT_SRC_FILES} ${PROTO_HDRS})
+ add_library(ge_ut_common_format STATIC ${COMMON_FORMAT_SRC_FILES} ${PROTO_HDRS})

target_compile_definitions(ge_ut_common_format PRIVATE
google=ascend_private
@@ -1056,7 +1059,6 @@ target_link_libraries(ge_single_op PRIVATE
# libge_mutiparts_utest
add_executable(ut_libge_multiparts_utest
${COMMON_TEST_FILES}
- ${COMMON_FORMAT_SRC_FILES}
${MULTI_PARTS_TEST_FILES}
)

@@ -1071,14 +1073,14 @@ target_compile_definitions(ut_libge_multiparts_utest PRIVATE

target_link_libraries(ut_libge_multiparts_utest
$<BUILD_INTERFACE:intf_pub>
- ge_build_common ge_load_common ge_execute_common ge_optimize_common ge_partition_common ge_prepare_common ge_single_op ge_ut_common
- gtest gtest_main gmock gmock_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov
+ ge_build_common ge_load_common ge_execute_common ge_optimize_common ge_partition_common ge_prepare_common
+ ge_single_op ge_ut_common_format ge_ut_common
+ gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lgcov
)

# libge_others_utest
add_executable(ut_libge_others_utest
${COMMON_TEST_FILES}
- ${COMMON_FORMAT_SRC_FILES}
${PASS_TEST_FILES}
${EXECUTE_TEST_FILES}
${OTHERS_TEST_FILES}
@@ -1091,16 +1093,15 @@ target_compile_options(ut_libge_others_utest PRIVATE

target_link_libraries(ut_libge_others_utest
$<BUILD_INTERFACE:intf_pub>
- ge_load_common ge_execute_common ge_ut_common
- gtest gtest_main gmock gmock_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov
+ ge_load_common ge_execute_common ge_ut_common ge_ut_common_format
+ gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lgcov
)

# libge_kernel_utest
add_executable(ut_libge_kernel_utest
- ${COMMON_TEST_FILES}
- ${COMMON_FORMAT_SRC_FILES}
- ${KERNEL_TEST_FILES}
- ${KERNEL_SRC_FILES}
+ ${COMMON_TEST_FILES}
+ ${KERNEL_TEST_FILES}
+ ${KERNEL_SRC_FILES}
)

target_compile_options(ut_libge_kernel_utest PRIVATE
@@ -1110,8 +1111,8 @@ target_compile_options(ut_libge_kernel_utest PRIVATE

target_link_libraries(ut_libge_kernel_utest
$<BUILD_INTERFACE:intf_pub>
- ge_load_common ge_ut_common
- gtest gtest_main gmock gmock_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov
+ ge_load_common ge_ut_common ge_ut_common_format
+ gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lgcov
)

# libge_distinct_load_utest
@@ -1137,10 +1138,11 @@ target_compile_definitions(ut_libge_distinct_load_utest PRIVATE
)

target_link_libraries(ut_libge_distinct_load_utest
- ${COMMON_SHARED_LIBRARIES}
$<BUILD_INTERFACE:intf_pub>
- ge_execute_common ge_ut_common_format ge_load_common
- ge_single_op ge_prepare_common
- ge_optimize_common ge_build_common ge_partition_common ge_ut_common
- gtest gtest_main gmock gmock_main ascend_protobuf json c_sec -lrt -ldl -lpthread -lgcov
+ -Wl,--whole-archive
+ ge_single_op
+ -Wl,--no-whole-archive
+ ge_execute_common ge_load_common
+ ge_prepare_common ge_optimize_common ge_build_common ge_partition_common ge_ut_common ge_ut_common_format
+ gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lpthread -lgcov
)

tests/ut/ge/graph/ge_executor_unittest.cc (+1, -1)

@@ -115,7 +115,7 @@ TEST_F(UtestGeExecutor, load_data_from_file) {

string test_smap = "/tmp/" + std::to_string(getpid()) + "_maps";
string self_smap = "/proc/" + std::to_string(getpid()) + "/maps";
string copy_smap = "cp " + self_smap + " " + test_smap;
string copy_smap = "cp -f " + self_smap + " " + test_smap;
EXPECT_EQ(system(copy_smap.c_str()), 0);

ModelData model_data;


tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc (+6, -5)

@@ -91,8 +91,8 @@ TEST_F(UtestKernelExTaskInfo, success_kernel_ex_task_release) {
// test kernel_ex_task_Release
TEST_F(UtestKernelExTaskInfo, success_kernel_ex_task_info_copy) {
DavinciModel model(0, nullptr);
- model.runtime_param_.mem_base = (uint8_t *)0x12345;
- model.runtime_param_.mem_size = 100332000;
+ model.runtime_param_.mem_size = 10240;
+ model.runtime_param_.mem_base = new uint8_t[model.runtime_param_.mem_size];

rtStream_t stream = nullptr;
rtStreamCreate(&stream, 0);
@@ -108,19 +108,20 @@ TEST_F(UtestKernelExTaskInfo, success_kernel_ex_task_info_copy) {

EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace empty.

- model.op_list_[0]->SetWorkspace({100331008}); // offset
+ model.op_list_[0]->SetWorkspace({1008}); // offset
model.op_list_[0]->SetWorkspaceBytes({0}); // length
EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace addr is null.

- model.op_list_[0]->SetWorkspace({100331008}); // offset
+ model.op_list_[0]->SetWorkspace({1208}); // offset
model.op_list_[0]->SetWorkspaceBytes({10}); // length
EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace addr is small.

- model.op_list_[0]->SetWorkspace({100331008}); // offset
+ model.op_list_[0]->SetWorkspace({1308}); // offset
model.op_list_[0]->SetWorkspaceBytes({150}); // length
EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), SUCCESS);

task_def.clear_kernel_ex();
+ delete [] model.runtime_param_.mem_base;
model.runtime_param_.mem_base = nullptr;
}



tests/ut/ge/graph/load/model_manager_unittest.cc (+1, -1)

@@ -418,6 +418,6 @@ TEST_F(UtestModelManagerModelManager, test_data_input_tensor) {
vector<InputTensorInfo> inputs;
inputs.emplace_back(input_tensor);
auto ret = mm.DataInputTensor(model_id,inputs);
- EXPECT_EQ(UNSUPPORTED, ret);
+ EXPECT_EQ(PARAM_INVALID, ret); // HybridDavinciModel::impl_ is null.
}
} // namespace ge

tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc (+249, -0)

@@ -0,0 +1,249 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <vector>

#define private public
#define protected public
#include "hybrid/executor/subgraph_executor.h"
#include "hybrid/node_executor/node_executor.h"
#include "hybrid/node_executor/rts/rts_node_executor.h"
#include "hybrid/node_executor/ge_local/ge_local_node_executor.h"
#include "hybrid/model/hybrid_model_builder.h"
#include "graph/utils/graph_utils.h"

using namespace std;
using namespace testing;

namespace ge {
using namespace hybrid;

class UtestSubgraphExecutor : public testing::Test {
protected:
void SetUp() {
NodeExecutorManager::GetInstance().engine_mapping_.clear();
auto &engine_mapping = NodeExecutorManager::GetInstance().engine_mapping_;
engine_mapping.emplace("DNN_VM_RTS_OP_STORE", NodeExecutorManager::ExecutorType::RTS);
engine_mapping.emplace("DNN_VM_GE_LOCAL_OP_STORE", NodeExecutorManager::ExecutorType::GE_LOCAL);

NodeExecutorManager::GetInstance().executors_.clear();
auto &task_executor = NodeExecutorManager::GetInstance().executors_;
task_executor.emplace(NodeExecutorManager::ExecutorType::RTS, std::unique_ptr<NodeExecutor>(new RtsNodeExecutor()));
task_executor.emplace(NodeExecutorManager::ExecutorType::GE_LOCAL, std::unique_ptr<NodeExecutor>(new GeLocalNodeExecutor()));
}
void TearDown() {
NodeExecutorManager::GetInstance().engine_mapping_.clear();
NodeExecutorManager::GetInstance().executors_.clear();
}
};

static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
op_desc->SetStreamId(0);
static int32_t index = 0;
op_desc->SetId(index++);

GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64);
TensorUtils::SetSize(tensor, 64);
vector<int64_t> input_offset;
for (int i = 0; i < in_num; i++) {
op_desc->AddInputDesc(tensor);
input_offset.emplace_back(index * 64 + i * 64);
}
op_desc->SetInputOffset(input_offset);

vector<int64_t> output_offset;
for (int i = 0; i < out_num; i++) {
op_desc->AddOutputDesc(tensor);
output_offset.emplace_back(index * 64 + in_num * 64 + i * 64);
}
op_desc->SetOutputOffset(output_offset);

op_desc->SetWorkspace({});
op_desc->SetWorkspaceBytes({});
op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE");

return graph.AddNode(op_desc);
}

static void CreateSimpleCondGraph(ComputeGraph &graph) {
/*******************************************************************************
* |
* Merge
* / \.
* / \.
* / \.
* Add Sub
* | \ / |
* | \ _ / |
* | / \ |
* | / \ |
* Switch Switch
* | \ / |
* | \ / |
* | \ / |
* | \ / |
* | Less |
* | / \ |
* | / \ |
* Data Data
******************************************************************************/
const auto data0 = CreateNode(graph, "data", DATA, 1, 1);
const auto data1 = CreateNode(graph, "data1", DATA, 1, 1);
data0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE");
data1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE");

const auto const0 = CreateNode(graph, "const", CONSTANT, 0, 1);
const auto const1 = CreateNode(graph, "const1", CONSTANT, 0, 1);
const0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE");
const1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE");
{
uint64_t const_value = 0;
const auto op_desc = const0->GetOpDesc();
auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t));
AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight);
}
{
uint64_t const_value = 1;
const auto op_desc = const1->GetOpDesc();
auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t));
AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight);
}

const auto less1 = CreateNode(graph, "less", ENTER, 2, 1);

const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0);
const auto switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0);
const auto switch_f = CreateNode(graph, "switch_f", STREAMSWITCH, 2, 0);
AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 1);
AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 1);

const auto add1 = CreateNode(graph, "add", ENTER, 2, 1);
const auto sub1 = CreateNode(graph, "sub", ENTER, 2, 1);

const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2);
const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0);
const auto active3 = CreateNode(graph, "active3", STREAMACTIVE, 0, 0);

const auto output1 = CreateNode(graph, "net_output", NETOUTPUT, 1, 1);
output1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE");

GraphUtils::AddEdge(data0->GetOutDataAnchor(0), less1->GetInDataAnchor(0));
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), less1->GetInDataAnchor(1));
GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch_t->GetInDataAnchor(0));
GraphUtils::AddEdge(const0->GetOutDataAnchor(0), switch_t->GetInDataAnchor(1));
GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch_f->GetInDataAnchor(0));
GraphUtils::AddEdge(const1->GetOutDataAnchor(0), switch_f->GetInDataAnchor(1));

GraphUtils::AddEdge(less1->GetOutControlAnchor(), active1->GetInControlAnchor());
GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor());
GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_f->GetInControlAnchor());

GraphUtils::AddEdge(data0->GetOutDataAnchor(0), add1->GetInDataAnchor(0));
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), add1->GetInDataAnchor(1));
GraphUtils::AddEdge(add1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0));
GraphUtils::AddEdge(switch_t->GetOutControlAnchor(), add1->GetInControlAnchor());
GraphUtils::AddEdge(add1->GetOutControlAnchor(), active2->GetInControlAnchor());
GraphUtils::AddEdge(active2->GetOutControlAnchor(), merge1->GetInControlAnchor());

GraphUtils::AddEdge(data0->GetOutDataAnchor(0), sub1->GetInDataAnchor(0));
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), sub1->GetInDataAnchor(1));
GraphUtils::AddEdge(sub1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1));
GraphUtils::AddEdge(switch_f->GetOutControlAnchor(), sub1->GetInControlAnchor());
GraphUtils::AddEdge(sub1->GetOutControlAnchor(), active3->GetInControlAnchor());
GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor());

GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0));
}

TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
const auto data0 = CreateNode(*graph, "data", DATA, 1, 1);
const auto output0 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1);
GraphUtils::AddEdge(data0->GetOutDataAnchor(0), output0->GetInDataAnchor(0));
data0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE");
output0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE");

GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
ge_root_model->SetModelName("test_name");
GeModelPtr ge_sub_model = make_shared<GeModel>();
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);

HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS);

GraphExecutionContext graph_context;
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());
graph_context.model = &hybrid_model;

uint64_t value_0 = 110;
TensorValue in_tensor0(&value_0, sizeof(value_0));
const std::vector<TensorValue> inputs{ in_tensor0 };

uint64_t value_1 = 123;
TensorValue out_tensor0(&value_1, sizeof(value_1));
const std::vector<TensorValue> outputs{ out_tensor0 };

auto input_desc = output0->GetOpDesc()->GetInputDescPtr(0);
const std::vector<ConstGeTensorDescPtr> input_descs{ input_desc };

SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &graph_context);
ASSERT_EQ(executor.ExecuteAsync(inputs, input_descs, outputs), SUCCESS);
}

TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
CreateSimpleCondGraph(*graph);

GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
ge_root_model->SetModelName("test_name");
GeModelPtr ge_sub_model = make_shared<GeModel>();
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
Buffer weights_buffer(1024, 0x76);
ge_sub_model->SetWeight(weights_buffer);
ge_sub_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph));

HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS);

GraphExecutionContext graph_context;
graph_context.model = &hybrid_model;
graph_context.allocator = NpuMemoryAllocator::GetAllocator(0);
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());
ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS);

uint64_t value_0 = 110;
TensorValue in_tensor0(&value_0, sizeof(value_0));
uint64_t value_1 = 110;
TensorValue in_tensor1(&value_1, sizeof(value_1));
const std::vector<TensorValue> inputs{ in_tensor0, in_tensor1 };
uint64_t value_2 = 123;
TensorValue out_tensor0(&value_2, sizeof(value_2));
const std::vector<TensorValue> outputs{ out_tensor0 };

GeTensorDescPtr tensor_desc = make_shared<GeTensorDesc>(GeShape(), FORMAT_ND, DT_INT64);
TensorUtils::SetSize(*tensor_desc, 64);
const std::vector<ConstGeTensorDescPtr> input_desc{ tensor_desc, tensor_desc };

SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &graph_context);
ASSERT_EQ(executor.ExecuteAsync(inputs, input_desc, outputs), SUCCESS);
ASSERT_EQ(graph_context.callback_manager->Destroy(), SUCCESS);
}
} // namespace ge

tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc (+16, -2)

@@ -26,6 +26,7 @@
#include "hybrid/executor/hybrid_execution_context.h"
#include "hybrid/executor/hybrid_model_executor.h"
#include "hybrid/executor/worker/execution_engine.h"
#include "hybrid/executor/subgraph_executor.h"
#undef private
#undef protected

@@ -75,6 +76,10 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) {
node_item->output_start = 0;

GraphExecutionContext execution_context;
+ GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
+ HybridModel hybrid_model(ge_root_model);
+ hybrid_model.root_graph_item_ = std::unique_ptr<GraphItem>(new(std::nothrow)GraphItem());
+ execution_context.model = &hybrid_model;
execution_context.profiling_level = 1;
SubgraphContext subgraph_context(nullptr, &execution_context);

@@ -85,7 +90,11 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) {

ExecutionEngine execution_engine;
ASSERT_TRUE(node_state.GetTaskContext() != nullptr);
- EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context), INTERNAL_ERROR);

+ std::function<void()> callback;
+ SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context);
+ executor.InitCallback(&node_state, callback);
+ EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
}

TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) {
@@ -105,6 +114,7 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) {
GraphExecutionContext execution_context;
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
HybridModel hybrid_model(ge_root_model);
+ hybrid_model.root_graph_item_ = std::unique_ptr<GraphItem>(new(std::nothrow)GraphItem());
execution_context.model = &hybrid_model;
SubgraphContext subgraph_context(nullptr, &execution_context);

@@ -115,5 +125,9 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) {

ExecutionEngine execution_engine;
ASSERT_TRUE(node_state.GetTaskContext() != nullptr);
- EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context), INTERNAL_ERROR);

+ std::function<void()> callback;
+ SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context);
+ executor.InitCallback(&node_state, callback);
+ EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
}
