diff --git a/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc index 76e2ca5f..9b8ef739 100755 --- a/ge/single_op/task/op_task.cc +++ b/ge/single_op/task/op_task.cc @@ -346,49 +346,95 @@ Status TbeOpTask::AllocateWorkspaces(const vector &workspace_sizes) { return SUCCESS; } -Status TbeOpTask::LaunchKernel(const vector &input_desc, - const vector &input_buffers, - vector &output_desc, - vector &output_buffers, - rtStream_t stream) { - GELOGD("[%s] Start to launch kernel", node_->GetName().c_str()); - GE_CHK_STATUS_RET_NOLOG(UpdateNodeByShape(input_desc, output_desc)); - GE_CHK_STATUS_RET_NOLOG(UpdateRunInfo()); - GE_CHK_STATUS_RET(AllocateWorkspaces(run_info_workspaces_), "[Allocate][Workspaces] failed."); - std::vector args; - for (auto &buffer : input_buffers) { - args.emplace_back(buffer.data); +Status TbeOpTask::UpdateTilingArgs(rtStream_t stream) { + size_t args_size = input_num_ + output_num_ + workspaces_.size(); + if (tiling_buffer_ != nullptr) { + args_size++; } - for (auto &buffer : output_buffers) { - args.emplace_back(buffer.data); + size_t temp_size = args_size * sizeof(void *); + if (arg_size_ < temp_size) { + GELOGD("Need to reset size of args_ from %zu to %zu.", arg_size_, temp_size); + std::unique_ptr args(new (std::nothrow) uint8_t[temp_size]()); + GE_CHECK_NOTNULL(args); + if (memcpy_s(args.get(), temp_size, args_.get(), arg_size_) != EOK) { + GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Update][KernelArgs] failed for [%s].", node_->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "update kernel args failed for %s.", node_->GetName().c_str()); + return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; + } + + args_ = std::move(args); + arg_size_ = temp_size; } - for (auto &buffer : workspaces_) { - args.emplace_back(buffer); + + uintptr_t *arg_base = reinterpret_cast(args_.get()); + size_t arg_index = input_num_ + output_num_; + for (size_t i = 0; i < workspaces_.size(); ++i) { + arg_base[arg_index++] = reinterpret_cast(workspaces_[i]); } if (tiling_buffer_ != nullptr) { GELOGD("[%s] Start to copy tiling info. size = %zu", node_->GetName().c_str(), tiling_data_.size()); GE_CHK_RT_RET(rtMemcpyAsync(tiling_buffer_, max_tiling_size_, tiling_data_.data(), tiling_data_.size(), RT_MEMCPY_HOST_TO_DEVICE_EX, stream)); + arg_base[arg_index] = reinterpret_cast(tiling_buffer_); + } + + return SUCCESS; +} + +Status TbeOpTask::SetArgIndex() { + const vector v_is_input_const = op_desc_->GetIsInputConst(); + size_t input_index = 0; + for (size_t i = 0; i < op_desc_->GetAllInputsSize(); ++i) { + const GeTensorDescPtr tensor_desc = op_desc_->MutableInputDesc(static_cast(i)); + if (tensor_desc == nullptr) { + GELOGD("SingleOp: %s, Index: %zu, has no input", op_desc_->GetName().c_str(), i); + continue; + } + if (i < v_is_input_const.size() && v_is_input_const[i]) { + GELOGD("SingleOp: %s, Index: %zu, input is const", op_desc_->GetName().c_str(), i); + input_index++; + continue; + } + arg_index_.emplace_back(input_index); + input_index++; + } + return SUCCESS; +} - args.emplace_back(tiling_buffer_); +Status TbeOpTask::UpdateIoAddr(const vector &inputs, const vector &outputs) { + if (arg_index_.size() != inputs.size()) { + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Size] Args size is %zu, but get input size is %zu.", + arg_index_.size(), inputs.size()); + REPORT_INNER_ERROR("E19999", "[Check][Size] Args size is %zu, but get input size is %zu.", + arg_index_.size(), inputs.size()); + return ACL_ERROR_GE_PARAM_INVALID; } - GELOGD("Dst size is %zu, src size is %zu.", arg_size_, args.size() * sizeof(void *)); - // node with workspace: build can not get size of workspace, need to update arg_size_ when execute - if (arg_size_ < (args.size() * sizeof(void *))) { - size_t temp_size = args.size() * sizeof(void *); - GELOGD("Need to reset size of args_ from %zu to %zu.", arg_size_, temp_size); - args_.reset(new(std::nothrow) uint8_t[temp_size]()); - GE_CHECK_NOTNULL(args_); - arg_size_ = temp_size; + uintptr_t *arg_base = reinterpret_cast(args_.get()); + for (size_t i = 0; i < arg_index_.size(); ++i) { + arg_base[arg_index_[i]] = reinterpret_cast(inputs[i].data); } - if (memcpy_s(args_.get(), arg_size_, args.data(), args.size() * sizeof(void *)) != EOK) { - GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Update][KernelArgs] failed for [%s].", node_->GetName().c_str()); - REPORT_INNER_ERROR("E19999", "update kernel args failed for %s.", node_->GetName().c_str()); - return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; + + for (size_t i = 0; i < op_desc_->GetOutputsSize(); ++i) { + arg_base[input_num_ + i] = reinterpret_cast(outputs[i].data); } + return SUCCESS; +} + +Status TbeOpTask::LaunchKernel(const vector &input_desc, + const vector &input_buffers, + vector &output_desc, + vector &output_buffers, + rtStream_t stream) { + GELOGD("[%s] Start to launch kernel", node_->GetName().c_str()); + GE_CHK_STATUS_RET(UpdateIoAddr(input_buffers, output_buffers), "[Update][IoAddr] failed."); + GE_CHK_STATUS_RET_NOLOG(UpdateNodeByShape(input_desc, output_desc)); + GE_CHK_STATUS_RET_NOLOG(UpdateRunInfo()); + GE_CHK_STATUS_RET(AllocateWorkspaces(run_info_workspaces_), "[Allocate][Workspaces] failed."); + GE_CHK_STATUS_RET(UpdateTilingArgs(stream), "[Update][TilingArgs] failed."); + GELOGD("[%s] Start to invoke rtKernelLaunch", node_->GetName().c_str()); GE_CHK_STATUS_RET(DoLaunchKernel(stream), "Failed to do launch kernel."); diff --git a/ge/single_op/task/op_task.h b/ge/single_op/task/op_task.h index eb660aa5..0aa2b617 100644 --- a/ge/single_op/task/op_task.h +++ b/ge/single_op/task/op_task.h @@ -87,6 +87,7 @@ class TbeOpTask : public OpTask { const OpDescPtr &op_desc, const domi::KernelDefWithHandle& kernel_def_with_handle); Status UpdateRunInfo() override; + Status SetArgIndex(); const void *GetArgs() const; size_t GetArgSize() const; @@ -102,7 +103,9 @@ class TbeOpTask : public OpTask { Status UpdateNodeByShape(const vector &input_desc, const vector &output_desc); Status AllocateWorkspaces(const std::vector &workspace_sizes); + Status UpdateTilingArgs(rtStream_t stream); Status DoLaunchKernel(rtStream_t stream); + Status UpdateIoAddr(const vector &inputs, const vector &outputs); const void *stub_func_ = nullptr; std::unique_ptr args_; @@ -122,6 +125,9 @@ class TbeOpTask : public OpTask { void* handle_ = nullptr; std::string original_kernel_key_; std::string node_info_; + std::vector arg_index_; // data index in args + size_t input_num_; // include const input + size_t output_num_; }; class AiCpuBaseTask : public OpTask { diff --git a/ge/single_op/task/tbe_task_builder.cc b/ge/single_op/task/tbe_task_builder.cc index c7ff13d1..db8ecfe2 100644 --- a/ge/single_op/task/tbe_task_builder.cc +++ b/ge/single_op/task/tbe_task_builder.cc @@ -387,6 +387,9 @@ Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam ¶ } task.SetStubFunc(stub_name_, stub_func); } + GE_CHK_STATUS_RET(task.SetArgIndex(), "[Set][ArgTable] failed."); + task.input_num_ = op_desc_->GetInputsSize(); + task.output_num_ = op_desc_->GetOutputsSize(); return SUCCESS; } diff --git a/tests/ut/ge/single_op/single_op_task_unittest.cc b/tests/ut/ge/single_op/single_op_task_unittest.cc index a17c9012..b0c98205 100644 --- a/tests/ut/ge/single_op/single_op_task_unittest.cc +++ b/tests/ut/ge/single_op/single_op_task_unittest.cc @@ -91,10 +91,11 @@ TEST_F(UtestSingleOpTask, test_build_kernel_task) { TbeOpTask task_tmp; TbeOpTask *task = &task_tmp; ASSERT_EQ(model.BuildKernelTask(task_def, &task), SUCCESS); + ge::DataBuffer data_buffer; vector input_desc; - vector input_buffers; + vector input_buffers = { data_buffer }; vector output_desc; - vector output_buffers; + vector output_buffers = { data_buffer }; task->node_ = node; OpTilingFunc op_tiling_func = [](const TeOpParas &, const OpCompileInfo &, OpRunInfo &) -> bool {return true;}; OpTilingRegistryInterf("Add", op_tiling_func); @@ -106,12 +107,49 @@ TEST_F(UtestSingleOpTask, test_build_kernel_task) { task->max_tiling_size_ = 64; task->tiling_data_ = "tiling_data"; task->arg_size_ = 64; - uint8_t task_args{0}; - task->args_.reset(&task_args); + task->args_.reset(new (std::nothrow) uint8_t[sizeof(void *) * 3]); ASSERT_EQ(task->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream_), SUCCESS); - char handle_tmp = '0'; - char *handle = &handle_tmp; + char *handle = "00"; task->SetHandle(handle); ASSERT_EQ(task->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream_), SUCCESS); -} \ No newline at end of file +} + +TEST_F(UtestSingleOpTask, test_update_ioaddr) { + auto graph = make_shared("graph"); + auto op_desc = make_shared("Add", "Add"); + + GeTensorDesc desc; + op_desc->AddInputDesc(desc); + op_desc->AddInputDesc(desc); + op_desc->AddOutputDesc(desc); + vector is_input_const = { true, false }; + op_desc->SetIsInputConst(is_input_const); + auto node = graph->AddNode(op_desc); + + TbeOpTask task; + task.op_desc_ = op_desc; + task.node_ = node; + ASSERT_EQ(task.SetArgIndex(), SUCCESS); + task.arg_size_ = sizeof(void *) * 4; + task.args_.reset(new (std::nothrow) uint8_t[task.arg_size_]); + task.arg_index_ = {0}; + task.input_num_ = 2; + task.output_num_ = 1; + + vector args; + vector inputs; + vector outputs; + ASSERT_EQ(task.UpdateIoAddr(inputs, outputs), ACL_ERROR_GE_PARAM_INVALID); + + ge::DataBuffer data_buffer; + inputs = { data_buffer }; + outputs = { data_buffer }; + ASSERT_EQ(task.UpdateIoAddr(inputs, outputs), SUCCESS); + + task.tiling_buffer_ = (void *)0x0001; + task.workspaces_ = { (void *)0x0002 }; + ASSERT_EQ(task.UpdateTilingArgs(nullptr), SUCCESS); + task.tiling_buffer_ = nullptr; +} + diff --git a/tests/ut/ge/single_op/single_op_unittest.cc b/tests/ut/ge/single_op/single_op_unittest.cc index 831f3f16..181805ff 100644 --- a/tests/ut/ge/single_op/single_op_unittest.cc +++ b/tests/ut/ge/single_op/single_op_unittest.cc @@ -104,7 +104,7 @@ TEST_F(UtestSingleOp, test_dynamic_singleop_execute_async1) { EXPECT_EQ(desc_ptr->AddInputDesc("x", GeTensorDesc(GeShape({2}), FORMAT_NCHW)), GRAPH_SUCCESS); dynamic_single_op.op_task_->op_desc_ = desc_ptr; // UpdateRunInfo failed - EXPECT_EQ(dynamic_single_op.ExecuteAsync(input_desc, input_buffers, output_desc, output_buffers), PARAM_INVALID); + EXPECT_EQ(dynamic_single_op.ExecuteAsync(input_desc, input_buffers, output_desc, output_buffers), ACL_ERROR_GE_PARAM_INVALID); }