From 5d897e344ea703995aee5ad04b33d7f0d787b30c Mon Sep 17 00:00:00 2001 From: lichun Date: Mon, 17 May 2021 14:07:38 +0800 Subject: [PATCH] fix add tbe kernel failed --- ge/hybrid/node_executor/aicore/aicore_op_task.cc | 5 ++-- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 33 ++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/ge/hybrid/node_executor/aicore/aicore_op_task.cc b/ge/hybrid/node_executor/aicore/aicore_op_task.cc index 61942d51..40118af3 100644 --- a/ge/hybrid/node_executor/aicore/aicore_op_task.cc +++ b/ge/hybrid/node_executor/aicore/aicore_op_task.cc @@ -23,6 +23,7 @@ #include "graph/load/model_manager/tbe_handle_store.h" #include "graph/types.h" #include "single_op/task/build_task_utils.h" +#include "single_op/task/tbe_task_builder.h" using optiling::OpRunInfo; @@ -131,8 +132,8 @@ Status AiCoreOpTask::RegisterTbeHandle(const OpDesc &op_desc) { GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc_ptr, GetKeyForKernelName(op_desc), kernel_name), GELOGI("Get original type of kernel_name")); GELOGI("TBE: binfile_key=%s, kernel_name=%s", stub_name_.c_str(), kernel_name.c_str()); - GE_CHK_RT_RET(rtFunctionRegister(bin_handle, stub_name_.c_str(), - stub_name_.c_str(), kernel_name.c_str(), 0)); + auto stub_func = KernelBinRegistry::GetInstance().GetUnique(stub_name_); + GE_CHK_RT_RET(rtFunctionRegister(bin_handle, stub_func, stub_name_.c_str(), kernel_name.c_str(), 0)); } return SUCCESS; } diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 7cd8a30a..7a2a5dfe 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -109,6 +109,39 @@ TEST_F(UtestGeHybrid, aicore_op_task_init_success) { ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS); } +TEST_F(UtestGeHybrid, aicore_op_task_init_success2) { + // build aicore task + auto aicore_task = std::unique_ptr(new(std::nothrow)hybrid::AiCoreOpTask()); + aicore_task->is_single_op_ = true; + domi::TaskDef task_def; + task_def.set_type(RT_MODEL_TASK_KERNEL); + domi::KernelDef *kernel = task_def.mutable_kernel(); + kernel->set_block_dim(32); + kernel->set_args_size(64); + string args(64, '1'); + kernel->set_args(args.data(), 64); + domi::KernelContext *context = kernel->mutable_context(); + context->set_op_index(1); + context->set_kernel_type(2); // ccKernelType::TE + uint16_t args_offset[9] = {0}; + context->set_args_offset(args_offset, 9 * sizeof(uint16_t)); + + OpDescPtr op_desc = CreateOpDesc("Add", "Add"); + std::vector kernelBin; + TBEKernelPtr tbe_kernel = std::make_shared("name/Add", std::move(kernelBin)); + op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); + std::string kernel_name("kernel/Add"); + AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name); + ASSERT_EQ(aicore_task->InitWithTaskDef(*op_desc.get(), task_def), SUCCESS); + rtStream_t stream = nullptr; + rtStreamCreate(&stream, 0); + ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS); + char *handle = ""; + aicore_task->handle_ = handle; + aicore_task->tiling_key_ = 1; + ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS); +} + TEST_F(UtestGeHybrid, task_update_tiling_info) { auto aicore_task = std::unique_ptr(new(std::nothrow)hybrid::AiCoreOpTask()); auto graph = make_shared("graph");