diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index a27222ef..9c4d795c 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -113,6 +113,43 @@ TEST_F(UtestGeHybrid, aicore_op_task_init_success) { ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS); } +TEST_F(UtestGeHybrid, aicore_op_task_init_success2) { + ComputeGraphPtr graph = std::make_shared("test"); + GeRootModelPtr ge_root_model = make_shared(graph); + HybridModel hybrid_model(ge_root_model); + + // build aicore task + auto aicore_task = std::unique_ptr(new(std::nothrow)hybrid::AiCoreOpTask()); + domi::TaskDef task_def; + task_def.set_type(RT_MODEL_TASK_KERNEL); + domi::KernelDef *kernel = task_def.mutable_kernel(); + kernel->set_original_kernel_key(""); + kernel->set_node_info(""); + kernel->set_block_dim(32); + kernel->set_args_size(64); + string args(64, '1'); + kernel->set_args(args.data(), 64); + domi::KernelContext *context = kernel->mutable_context(); + context->set_op_index(1); + context->set_kernel_type(2); // ccKernelType::TE + uint16_t args_offset[9] = {0}; + context->set_args_offset(args_offset, 9 * sizeof(uint16_t)); + OpDescPtr op_desc = CreateOpDesc("Add", "Add"); + std::vector kernelBin; + TBEKernelPtr tbe_kernel = std::make_shared("name/Add", std::move(kernelBin)); + op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); + std::string kernel_name("kernel/Add"); + AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name); + ASSERT_EQ(aicore_task->InitWithTaskDef(hybrid_model, *op_desc.get(), task_def), SUCCESS); + rtStream_t stream = nullptr; + rtStreamCreate(&stream, 0); + ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS); + char *handle = ""; + aicore_task->handle_ = handle; + aicore_task->tiling_key_ = 1; + ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS); +} + TEST_F(UtestGeHybrid, task_update_tiling_info) { auto aicore_task = std::unique_ptr(new(std::nothrow)hybrid::AiCoreOpTask()); auto graph = make_shared("graph");