@@ -84,9 +84,6 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) { | |||
SubgraphContext subgraph_context(nullptr, &execution_context); | |||
NodeState node_state(*node_item, &subgraph_context); | |||
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); | |||
auto shared_task_context = std::shared_ptr<TaskContext>(task_context.release()); | |||
node_state.SetTaskContext(shared_task_context); | |||
ExecutionEngine execution_engine; | |||
ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | |||
@@ -119,14 +116,11 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | |||
SubgraphContext subgraph_context(nullptr, &execution_context); | |||
NodeState node_state(*node_item, &subgraph_context); | |||
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); | |||
uint32_t task_id = 0; | |||
uint32_t stream_id = 1; | |||
std::string task_type = "rts"; | |||
uint32_t block_dim = 0; | |||
task_context->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim); | |||
auto shared_task_context = std::shared_ptr<TaskContext>(task_context.release()); | |||
node_state.SetTaskContext(shared_task_context); | |||
node_state.GetTaskContext()->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim); | |||
ExecutionEngine execution_engine; | |||
ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | |||
@@ -160,10 +160,8 @@ TEST_F(UtestGeHybrid, task_update_tiling_info) { | |||
GraphExecutionContext execution_context; | |||
SubgraphContext subgraph_context(nullptr, &execution_context); | |||
NodeState node_state(*node_item, &subgraph_context); | |||
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); | |||
ASSERT_TRUE(task_context != nullptr); | |||
ASSERT_EQ(aicore_task->InitTilingInfo(*op_desc), SUCCESS); | |||
ASSERT_EQ(aicore_task->UpdateTilingInfo(*task_context), SUCCESS); | |||
ASSERT_EQ(aicore_task->UpdateTilingInfo(*node_state.GetTaskContext()), SUCCESS); | |||
} | |||
TEST_F(UtestGeHybrid, index_taskdefs_failed) { | |||
@@ -481,7 +479,7 @@ TEST_F(UtestGeHybrid, TestTaskContext) { | |||
subgraph_context.all_outputs_.resize(1); | |||
NodeState node_state(*node_item, &subgraph_context); | |||
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); | |||
auto task_context = node_state.GetTaskContext(); | |||
ASSERT_TRUE(task_context != nullptr); | |||
auto desc = task_context->MutableInputDesc(2); | |||
ASSERT_TRUE(desc == nullptr); | |||
@@ -526,7 +524,7 @@ TEST_F(UtestGeHybrid, hybrid_model_executor_update_args) { | |||
subgraph_context.all_outputs_.resize(1); | |||
NodeState node_state(*node_item, &subgraph_context); | |||
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); | |||
auto task_context = node_state.GetTaskContext(); | |||
int32_t buffer[1]; | |||
aicore_task->tiling_buffer_ = TensorBuffer::Create(buffer, sizeof(buffer)); | |||
@@ -97,11 +97,6 @@ TEST_F(UtestGeLocalNodeExecutor, test_no_op_task) { | |||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
ASSERT_NE(node_state, nullptr); | |||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
ASSERT_NE(unique_task_context, nullptr); | |||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
node_state->SetTaskContext(shared_task_context); | |||
NodeTaskPtr task = nullptr; | |||
GeLocalNodeExecutor node_executor; | |||
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||
@@ -94,7 +94,7 @@ TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) { | |||
tensor.SetData(data); | |||
ctx->SetTensor(1, 0, tensor.Clone()); | |||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
auto unique_task_context = node_state->GetTaskContext(); | |||
vector<HcomRemoteAccessAddrInfo> addr_infos; | |||
shared_ptr<RdmaNodeTask> task = MakeShared<RdmaNodeTask>(); | |||
task->remote_index_ = {1, 0}; | |||
@@ -140,11 +140,6 @@ TEST_F(UtestHcclNodeExecutor, gatheralltoallv_execute) { | |||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
ASSERT_NE(node_state, nullptr); | |||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
ASSERT_NE(unique_task_context, nullptr); | |||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
node_state->SetTaskContext(shared_task_context); | |||
for (int i=0; i<4; ++i) { | |||
uint64_t value_0 = 512; | |||
TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||
@@ -206,11 +201,6 @@ TEST_F(UtestHcclNodeExecutor, alltoallv_execute) { | |||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
ASSERT_NE(node_state, nullptr); | |||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
ASSERT_NE(unique_task_context, nullptr); | |||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
node_state->SetTaskContext(shared_task_context); | |||
for (int i=0; i<5; ++i) { | |||
uint64_t value_0 = 512; | |||
TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||
@@ -96,11 +96,6 @@ TEST_F(UtestRtsNodeTask, test_stream_switch_task) { | |||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
ASSERT_NE(node_state, nullptr); | |||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
ASSERT_NE(unique_task_context, nullptr); | |||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
node_state->SetTaskContext(shared_task_context); | |||
uint64_t value_0 = 110; | |||
uint64_t value_1 = 120; | |||
TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||
@@ -153,11 +148,6 @@ TEST_F(UtestRtsNodeTask, test_stream_active_task) { | |||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
ASSERT_NE(node_state, nullptr); | |||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
ASSERT_NE(unique_task_context, nullptr); | |||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
node_state->SetTaskContext(shared_task_context); | |||
NodeTaskPtr task = nullptr; | |||
RtsNodeExecutor node_executor; | |||
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||
@@ -203,11 +193,6 @@ TEST_F(UtestRtsNodeTask, test_stream_merge_task) { | |||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
ASSERT_NE(node_state, nullptr); | |||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
ASSERT_NE(unique_task_context, nullptr); | |||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
node_state->SetTaskContext(shared_task_context); | |||
uint64_t value_0 = 110; | |||
TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||
subgraph_context.SetInput(*node_item, 0, in_tensor0); | |||
@@ -271,11 +256,6 @@ TEST_F(UtestRtsNodeTask, test_memcpy_async_task) { | |||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
ASSERT_NE(node_state, nullptr); | |||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
ASSERT_NE(unique_task_context, nullptr); | |||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
node_state->SetTaskContext(shared_task_context); | |||
uint64_t value_0 = 110; | |||
TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||
subgraph_context.SetInput(*node_item, 0, in_tensor0); | |||
@@ -328,11 +308,6 @@ TEST_F(UtestRtsNodeTask, test_pass_through_task) { | |||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
ASSERT_NE(node_state, nullptr); | |||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
ASSERT_NE(unique_task_context, nullptr); | |||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
node_state->SetTaskContext(shared_task_context); | |||
uint64_t value_0 = 110; | |||
TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||
subgraph_context.SetInput(*node_item, 0, in_tensor0); | |||
@@ -384,11 +359,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_set) { | |||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
ASSERT_NE(node_state, nullptr); | |||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
ASSERT_NE(unique_task_context, nullptr); | |||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
node_state->SetTaskContext(shared_task_context); | |||
NodeTaskPtr task = nullptr; | |||
RtsNodeExecutor node_executor; | |||
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||
@@ -428,11 +398,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_goto) { | |||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
ASSERT_NE(node_state, nullptr); | |||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
ASSERT_NE(unique_task_context, nullptr); | |||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
node_state->SetTaskContext(shared_task_context); | |||
NodeTaskPtr task = nullptr; | |||
RtsNodeExecutor node_executor; | |||
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||
@@ -472,11 +437,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_switch) { | |||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
ASSERT_NE(node_state, nullptr); | |||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
ASSERT_NE(unique_task_context, nullptr); | |||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
node_state->SetTaskContext(shared_task_context); | |||
NodeTaskPtr task = nullptr; | |||
RtsNodeExecutor node_executor; | |||
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||