diff --git a/ge/graph/load/model_manager/task_info/kernel_task_info.cc b/ge/graph/load/model_manager/task_info/kernel_task_info.cc index d69d0a8b..919a56cd 100755 --- a/ge/graph/load/model_manager/task_info/kernel_task_info.cc +++ b/ge/graph/load/model_manager/task_info/kernel_task_info.cc @@ -1066,10 +1066,6 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k } InitDumpArgs(sizeof(aicpu::AicpuParamHead)); - if (kernel_type_ == ccKernelType::CUST_AI_CPU) { - dump_flag_ |= RT_KERNEL_CUSTOM_AICPU; - } - davinci_model_->SetZeroCopyAddr(op_desc, io_addrs, args_addr.get(), args_, args_size_, sizeof(aicpu::AicpuParamHead)); return SUCCESS; @@ -1095,6 +1091,9 @@ void KernelTaskInfo::InitDumpArgs(uint32_t offset) { GELOGD("Op debug is open in kernel task info"); dump_args_ = static_cast(args_) + offset; } + if (kernel_type_ == ccKernelType::CUST_AI_CPU) { + dump_flag_ |= RT_KERNEL_CUSTOM_AICPU; + } } Status KernelTaskInfo::InitAicpuTaskExtInfo(const std::string &ext_info) { diff --git a/tests/ut/ge/graph/load/kernel_task_info_unittest.cc b/tests/ut/ge/graph/load/kernel_task_info_unittest.cc index 2cfb2a76..0c8da4b5 100644 --- a/tests/ut/ge/graph/load/kernel_task_info_unittest.cc +++ b/tests/ut/ge/graph/load/kernel_task_info_unittest.cc @@ -1184,6 +1184,22 @@ TEST_F(UtestKernelTaskInfo, kernel_task_info_calculate_args_aicpu) { EXPECT_EQ(kernel_task_info.CalculateArgs(task_def, &model), SUCCESS); } +TEST_F(UtestKernelTaskInfo, kernel_task_info_calculate_args_custom_aicpu) { + DavinciModel model(0, nullptr); + domi::TaskDef task_def; + + domi::KernelDef *kernel_def = task_def.mutable_kernel(); + domi::KernelContext *ctx = kernel_def->mutable_context(); + ctx->set_kernel_type(7); + + KernelTaskInfo kernel_task_info; + kernel_task_info.davinci_model_ = &model; + kernel_task_info.kernel_type_ = ccKernelType::CUST_AI_CPU; + kernel_task_info.op_desc_ = std::make_shared("concat", "TensorArrayWrite"); + kernel_task_info.InitDumpArgs(0); + EXPECT_EQ(kernel_task_info.CalculateArgs(task_def, &model), SUCCESS); +} + TEST_F(UtestKernelTaskInfo, kernel_task_info_update_args_te) { DavinciModel model(0, nullptr);