diff --git a/tests/ut/ge/hybrid/node_executor/aicore/aicore_op_task_unittest.cc b/tests/ut/ge/hybrid/node_executor/aicore/aicore_op_task_unittest.cc index ff48946d..d6b39996 100644 --- a/tests/ut/ge/hybrid/node_executor/aicore/aicore_op_task_unittest.cc +++ b/tests/ut/ge/hybrid/node_executor/aicore/aicore_op_task_unittest.cc @@ -22,6 +22,7 @@ #define private public #define protected public #include "framework/common/taskdown_common.h" +#include "hybrid/executor/subgraph_context.h" #include "hybrid/node_executor/aicore/aicore_op_task.h" #include "init/gelib.h" #undef private @@ -80,12 +81,88 @@ TEST_F(UtestAiCoreOpTask, Init_failed) { DEPEND_SHAPE_RANGE); domi::TaskDef task_def; task_def.set_type(RT_MODEL_TASK_KERNEL); - std::vector args(10, 0); + std::vector args(100, 0); task_def.mutable_kernel()->set_args(args.data(), args.size()); - task_def.mutable_kernel()->set_args_size(10); + task_def.mutable_kernel()->set_args_size(100); task_def.mutable_kernel()->mutable_context()->set_kernel_type( ccKernelType::TE); + uint16_t args_offset = 20; + char *a = reinterpret_cast(&args_offset); + task_def.mutable_kernel()->mutable_context()->set_args_offset( + a, 2 * sizeof(uint16_t)); EXPECT_EQ(task1->Init(*op_desc, task_def), ge::PARAM_INVALID); dlog_setlevel(0, 3, 0); } + +TEST_F(UtestAiCoreOpTask, Init_success) { + dlog_setlevel(0, 0, 0); + std::unique_ptr task1(new AiCoreOpTask()); + OpDescPtr op_desc = CreateOpDesc("Add", "Add", 2, 1); + ge::AttrUtils::SetInt(*op_desc, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, + DEPEND_SHAPE_RANGE); + domi::TaskDef task_def; + task_def.set_type(RT_MODEL_TASK_KERNEL); + std::vector args(100, 0); + task_def.mutable_kernel()->set_args(args.data(), args.size()); + task_def.mutable_kernel()->set_args_size(100); + task_def.mutable_kernel()->mutable_context()->set_kernel_type( + ccKernelType::TE); + uint16_t args_offset = 20; + char *a = reinterpret_cast(&args_offset); + task_def.mutable_kernel()->mutable_context()->set_args_offset( + a, 2 * sizeof(uint16_t)); + EXPECT_EQ(task1->Init(*op_desc, task_def), ge::SUCCESS); + dlog_setlevel(0, 3, 0); +} + +TEST_F(UtestAiCoreOpTask, UpdateOutputsShape_success) { + dlog_setlevel(0, 0, 0); + std::unique_ptr task1(new AiCoreOpTask()); + OpDescPtr op_desc = CreateOpDesc("Add", "Add", 2, 1); + ge::AttrUtils::SetInt(*op_desc, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, + DEPEND_SHAPE_RANGE); + domi::TaskDef task_def; + task_def.set_type(RT_MODEL_TASK_KERNEL); + std::vector args(100, 0); + task_def.mutable_kernel()->set_args(args.data(), args.size()); + task_def.mutable_kernel()->set_args_size(100); + task_def.mutable_kernel()->mutable_context()->set_kernel_type( + ccKernelType::TE); + uint16_t args_offset = 20; + char *a = reinterpret_cast(&args_offset); + task_def.mutable_kernel()->mutable_context()->set_args_offset( + a, 2 * sizeof(uint16_t)); + EXPECT_EQ(task1->Init(*op_desc, task_def), ge::SUCCESS); + + ComputeGraphPtr graph = std::make_shared("test"); + NodePtr node = graph->AddNode(op_desc); + std::unique_ptr new_node; + ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); + NodeItem *node_item = new_node.get(); + node_item->input_start = 0; + node_item->output_start = 0; + node_item->is_dynamic = true; + node_item->shape_inference_type = DEPEND_SHAPE_RANGE; + + GraphItem graph_item; + graph_item.node_items_.emplace_back(node_item); + graph_item.total_inputs_ = 2; + graph_item.total_outputs_ = 1; + + GraphExecutionContext graph_context; + SubgraphContext subgraph_context(&graph_item, &graph_context); + ASSERT_EQ(subgraph_context.Init(), SUCCESS); + graph_context.callback_manager = + std::unique_ptr(new CallbackManager()); + + auto node_state = subgraph_context.GetOrCreateNodeState(node_item); + ASSERT_NE(node_state, nullptr); + auto outputs_shape = + reinterpret_cast(task1->shape_buffer_->GetData()); + outputs_shape[0][0] = 2; + outputs_shape[0][1] = 1; + outputs_shape[0][2] = 2; + ASSERT_EQ(task1->UpdateOutputsShape(*node_state->GetTaskContext()), SUCCESS); + dlog_setlevel(0, 3, 0); +} } // namespace ge