Browse Source

Fix hccl_node_executor_unittest

tags/v1.3.0
zhangxiaokun 4 years ago
parent
commit
8852766766
1 changed files with 3 additions and 14 deletions
  1. +3
    -14
      tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc

+ 3
- 14
tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc View File

@@ -94,18 +94,17 @@ TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) {
tensor.SetData(data);
ctx->SetTensor(1, 0, tensor.Clone());
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
vector<HcomRemoteAccessAddrInfo> addr_infos;
shared_ptr<RdmaNodeTask> task = MakeShared<RdmaNodeTask>();
task->remote_index_ = {1, 0};
ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID);
ASSERT_EQ(task->ExtractTensor(*node_state->GetTaskContext(), addr_infos), PARAM_INVALID);
Shape s2({1});
TensorDesc tensor_desc2(s2);
Tensor tensor2(tensor_desc2);
ctx->SetTensor(1, 0, tensor2.Clone());
task->ExtractTensor(*unique_task_context, addr_infos);
ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID);
task->ExtractTensor(*node_state->GetTaskContext(), addr_infos);
ASSERT_EQ(task->ExtractTensor(*node_state->GetTaskContext(), addr_infos), PARAM_INVALID);
RuntimeInferenceContext::DestroyContext(std::to_string(graph_context.context_id));
}
@@ -140,11 +139,6 @@ TEST_F(UtestHcclNodeExecutor, gatheralltoallv_execute) {
auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);
for (int i=0; i<4; ++i) {
uint64_t value_0 = 512;
TensorValue in_tensor0(&value_0, sizeof(value_0));
@@ -206,11 +200,6 @@ TEST_F(UtestHcclNodeExecutor, alltoallv_execute) {
auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);
for (int i=0; i<5; ++i) {
uint64_t value_0 = 512;
TensorValue in_tensor0(&value_0, sizeof(value_0));


Loading…
Cancel
Save