@@ -84,9 +84,6 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) { | |||||
SubgraphContext subgraph_context(nullptr, &execution_context); | SubgraphContext subgraph_context(nullptr, &execution_context); | ||||
NodeState node_state(*node_item, &subgraph_context); | NodeState node_state(*node_item, &subgraph_context); | ||||
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); | |||||
auto shared_task_context = std::shared_ptr<TaskContext>(task_context.release()); | |||||
node_state.SetTaskContext(shared_task_context); | |||||
ExecutionEngine execution_engine; | ExecutionEngine execution_engine; | ||||
ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | ||||
@@ -119,14 +116,11 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | |||||
SubgraphContext subgraph_context(nullptr, &execution_context); | SubgraphContext subgraph_context(nullptr, &execution_context); | ||||
NodeState node_state(*node_item, &subgraph_context); | NodeState node_state(*node_item, &subgraph_context); | ||||
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); | |||||
uint32_t task_id = 0; | uint32_t task_id = 0; | ||||
uint32_t stream_id = 1; | uint32_t stream_id = 1; | ||||
std::string task_type = "rts"; | std::string task_type = "rts"; | ||||
uint32_t block_dim = 0; | uint32_t block_dim = 0; | ||||
task_context->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim); | |||||
auto shared_task_context = std::shared_ptr<TaskContext>(task_context.release()); | |||||
node_state.SetTaskContext(shared_task_context); | |||||
node_state.GetTaskContext()->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim); | |||||
ExecutionEngine execution_engine; | ExecutionEngine execution_engine; | ||||
ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | ||||
@@ -160,10 +160,8 @@ TEST_F(UtestGeHybrid, task_update_tiling_info) { | |||||
GraphExecutionContext execution_context; | GraphExecutionContext execution_context; | ||||
SubgraphContext subgraph_context(nullptr, &execution_context); | SubgraphContext subgraph_context(nullptr, &execution_context); | ||||
NodeState node_state(*node_item, &subgraph_context); | NodeState node_state(*node_item, &subgraph_context); | ||||
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); | |||||
ASSERT_TRUE(task_context != nullptr); | |||||
ASSERT_EQ(aicore_task->InitTilingInfo(*op_desc), SUCCESS); | ASSERT_EQ(aicore_task->InitTilingInfo(*op_desc), SUCCESS); | ||||
ASSERT_EQ(aicore_task->UpdateTilingInfo(*task_context), SUCCESS); | |||||
ASSERT_EQ(aicore_task->UpdateTilingInfo(*node_state.GetTaskContext()), SUCCESS); | |||||
} | } | ||||
TEST_F(UtestGeHybrid, index_taskdefs_failed) { | TEST_F(UtestGeHybrid, index_taskdefs_failed) { | ||||
@@ -481,7 +479,7 @@ TEST_F(UtestGeHybrid, TestTaskContext) { | |||||
subgraph_context.all_outputs_.resize(1); | subgraph_context.all_outputs_.resize(1); | ||||
NodeState node_state(*node_item, &subgraph_context); | NodeState node_state(*node_item, &subgraph_context); | ||||
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); | |||||
auto task_context = node_state.GetTaskContext(); | |||||
ASSERT_TRUE(task_context != nullptr); | ASSERT_TRUE(task_context != nullptr); | ||||
auto desc = task_context->MutableInputDesc(2); | auto desc = task_context->MutableInputDesc(2); | ||||
ASSERT_TRUE(desc == nullptr); | ASSERT_TRUE(desc == nullptr); | ||||
@@ -526,7 +524,7 @@ TEST_F(UtestGeHybrid, hybrid_model_executor_update_args) { | |||||
subgraph_context.all_outputs_.resize(1); | subgraph_context.all_outputs_.resize(1); | ||||
NodeState node_state(*node_item, &subgraph_context); | NodeState node_state(*node_item, &subgraph_context); | ||||
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); | |||||
auto task_context = node_state.GetTaskContext(); | |||||
int32_t buffer[1]; | int32_t buffer[1]; | ||||
aicore_task->tiling_buffer_ = TensorBuffer::Create(buffer, sizeof(buffer)); | aicore_task->tiling_buffer_ = TensorBuffer::Create(buffer, sizeof(buffer)); | ||||
@@ -97,11 +97,6 @@ TEST_F(UtestGeLocalNodeExecutor, test_no_op_task) { | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | ||||
ASSERT_NE(node_state, nullptr); | ASSERT_NE(node_state, nullptr); | ||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
ASSERT_NE(unique_task_context, nullptr); | |||||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
node_state->SetTaskContext(shared_task_context); | |||||
NodeTaskPtr task = nullptr; | NodeTaskPtr task = nullptr; | ||||
GeLocalNodeExecutor node_executor; | GeLocalNodeExecutor node_executor; | ||||
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | ||||
@@ -94,7 +94,7 @@ TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) { | |||||
tensor.SetData(data); | tensor.SetData(data); | ||||
ctx->SetTensor(1, 0, tensor.Clone()); | ctx->SetTensor(1, 0, tensor.Clone()); | ||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
auto unique_task_context = node_state->GetTaskContext(); | |||||
vector<HcomRemoteAccessAddrInfo> addr_infos; | vector<HcomRemoteAccessAddrInfo> addr_infos; | ||||
shared_ptr<RdmaNodeTask> task = MakeShared<RdmaNodeTask>(); | shared_ptr<RdmaNodeTask> task = MakeShared<RdmaNodeTask>(); | ||||
task->remote_index_ = {1, 0}; | task->remote_index_ = {1, 0}; | ||||
@@ -140,11 +140,6 @@ TEST_F(UtestHcclNodeExecutor, gatheralltoallv_execute) { | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | ||||
ASSERT_NE(node_state, nullptr); | ASSERT_NE(node_state, nullptr); | ||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
ASSERT_NE(unique_task_context, nullptr); | |||||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
node_state->SetTaskContext(shared_task_context); | |||||
for (int i=0; i<4; ++i) { | for (int i=0; i<4; ++i) { | ||||
uint64_t value_0 = 512; | uint64_t value_0 = 512; | ||||
TensorValue in_tensor0(&value_0, sizeof(value_0)); | TensorValue in_tensor0(&value_0, sizeof(value_0)); | ||||
@@ -206,11 +201,6 @@ TEST_F(UtestHcclNodeExecutor, alltoallv_execute) { | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | ||||
ASSERT_NE(node_state, nullptr); | ASSERT_NE(node_state, nullptr); | ||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
ASSERT_NE(unique_task_context, nullptr); | |||||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
node_state->SetTaskContext(shared_task_context); | |||||
for (int i=0; i<5; ++i) { | for (int i=0; i<5; ++i) { | ||||
uint64_t value_0 = 512; | uint64_t value_0 = 512; | ||||
TensorValue in_tensor0(&value_0, sizeof(value_0)); | TensorValue in_tensor0(&value_0, sizeof(value_0)); | ||||
@@ -96,11 +96,6 @@ TEST_F(UtestRtsNodeTask, test_stream_switch_task) { | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | ||||
ASSERT_NE(node_state, nullptr); | ASSERT_NE(node_state, nullptr); | ||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
ASSERT_NE(unique_task_context, nullptr); | |||||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
node_state->SetTaskContext(shared_task_context); | |||||
uint64_t value_0 = 110; | uint64_t value_0 = 110; | ||||
uint64_t value_1 = 120; | uint64_t value_1 = 120; | ||||
TensorValue in_tensor0(&value_0, sizeof(value_0)); | TensorValue in_tensor0(&value_0, sizeof(value_0)); | ||||
@@ -153,11 +148,6 @@ TEST_F(UtestRtsNodeTask, test_stream_active_task) { | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | ||||
ASSERT_NE(node_state, nullptr); | ASSERT_NE(node_state, nullptr); | ||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
ASSERT_NE(unique_task_context, nullptr); | |||||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
node_state->SetTaskContext(shared_task_context); | |||||
NodeTaskPtr task = nullptr; | NodeTaskPtr task = nullptr; | ||||
RtsNodeExecutor node_executor; | RtsNodeExecutor node_executor; | ||||
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | ||||
@@ -203,11 +193,6 @@ TEST_F(UtestRtsNodeTask, test_stream_merge_task) { | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | ||||
ASSERT_NE(node_state, nullptr); | ASSERT_NE(node_state, nullptr); | ||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
ASSERT_NE(unique_task_context, nullptr); | |||||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
node_state->SetTaskContext(shared_task_context); | |||||
uint64_t value_0 = 110; | uint64_t value_0 = 110; | ||||
TensorValue in_tensor0(&value_0, sizeof(value_0)); | TensorValue in_tensor0(&value_0, sizeof(value_0)); | ||||
subgraph_context.SetInput(*node_item, 0, in_tensor0); | subgraph_context.SetInput(*node_item, 0, in_tensor0); | ||||
@@ -271,11 +256,6 @@ TEST_F(UtestRtsNodeTask, test_memcpy_async_task) { | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | ||||
ASSERT_NE(node_state, nullptr); | ASSERT_NE(node_state, nullptr); | ||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
ASSERT_NE(unique_task_context, nullptr); | |||||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
node_state->SetTaskContext(shared_task_context); | |||||
uint64_t value_0 = 110; | uint64_t value_0 = 110; | ||||
TensorValue in_tensor0(&value_0, sizeof(value_0)); | TensorValue in_tensor0(&value_0, sizeof(value_0)); | ||||
subgraph_context.SetInput(*node_item, 0, in_tensor0); | subgraph_context.SetInput(*node_item, 0, in_tensor0); | ||||
@@ -328,11 +308,6 @@ TEST_F(UtestRtsNodeTask, test_pass_through_task) { | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | ||||
ASSERT_NE(node_state, nullptr); | ASSERT_NE(node_state, nullptr); | ||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
ASSERT_NE(unique_task_context, nullptr); | |||||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
node_state->SetTaskContext(shared_task_context); | |||||
uint64_t value_0 = 110; | uint64_t value_0 = 110; | ||||
TensorValue in_tensor0(&value_0, sizeof(value_0)); | TensorValue in_tensor0(&value_0, sizeof(value_0)); | ||||
subgraph_context.SetInput(*node_item, 0, in_tensor0); | subgraph_context.SetInput(*node_item, 0, in_tensor0); | ||||
@@ -384,11 +359,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_set) { | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | ||||
ASSERT_NE(node_state, nullptr); | ASSERT_NE(node_state, nullptr); | ||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
ASSERT_NE(unique_task_context, nullptr); | |||||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
node_state->SetTaskContext(shared_task_context); | |||||
NodeTaskPtr task = nullptr; | NodeTaskPtr task = nullptr; | ||||
RtsNodeExecutor node_executor; | RtsNodeExecutor node_executor; | ||||
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | ||||
@@ -428,11 +398,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_goto) { | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | ||||
ASSERT_NE(node_state, nullptr); | ASSERT_NE(node_state, nullptr); | ||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
ASSERT_NE(unique_task_context, nullptr); | |||||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
node_state->SetTaskContext(shared_task_context); | |||||
NodeTaskPtr task = nullptr; | NodeTaskPtr task = nullptr; | ||||
RtsNodeExecutor node_executor; | RtsNodeExecutor node_executor; | ||||
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | ||||
@@ -472,11 +437,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_switch) { | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | ||||
ASSERT_NE(node_state, nullptr); | ASSERT_NE(node_state, nullptr); | ||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
ASSERT_NE(unique_task_context, nullptr); | |||||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
node_state->SetTaskContext(shared_task_context); | |||||
NodeTaskPtr task = nullptr; | NodeTaskPtr task = nullptr; | ||||
RtsNodeExecutor node_executor; | RtsNodeExecutor node_executor; | ||||
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | ||||