Browse Source

fix tbe handle not unregisted

pull/1640/head
lichun 4 years ago
parent
commit
8f03c05bcc
1 changed files with 37 additions and 0 deletions
  1. +37
    -0
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 37
- 0
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -113,6 +113,43 @@ TEST_F(UtestGeHybrid, aicore_op_task_init_success) {
ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS);
}

TEST_F(UtestGeHybrid, aicore_op_task_init_success2) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
HybridModel hybrid_model(ge_root_model);

// build aicore task
auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask());
domi::TaskDef task_def;
task_def.set_type(RT_MODEL_TASK_KERNEL);
domi::KernelDef *kernel = task_def.mutable_kernel();
kernel->set_original_kernel_key("");
kernel->set_node_info("");
kernel->set_block_dim(32);
kernel->set_args_size(64);
string args(64, '1');
kernel->set_args(args.data(), 64);
domi::KernelContext *context = kernel->mutable_context();
context->set_op_index(1);
context->set_kernel_type(2); // ccKernelType::TE
uint16_t args_offset[9] = {0};
context->set_args_offset(args_offset, 9 * sizeof(uint16_t));
OpDescPtr op_desc = CreateOpDesc("Add", "Add");
std::vector<char> kernelBin;
TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin));
op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel);
std::string kernel_name("kernel/Add");
AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name);
ASSERT_EQ(aicore_task->InitWithTaskDef(hybrid_model, *op_desc.get(), task_def), SUCCESS);
rtStream_t stream = nullptr;
rtStreamCreate(&stream, 0);
ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS);
char *handle = "";
aicore_task->handle_ = handle;
aicore_task->tiling_key_ = 1;
ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS);
}

TEST_F(UtestGeHybrid, task_update_tiling_info) {
auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask());
auto graph = make_shared<ComputeGraph>("graph");


Loading…
Cancel
Save