diff --git a/ge/hybrid/node_executor/aicore/aicore_node_executor.cc b/ge/hybrid/node_executor/aicore/aicore_node_executor.cc index 6ed6866c..3dca8661 100755 --- a/ge/hybrid/node_executor/aicore/aicore_node_executor.cc +++ b/ge/hybrid/node_executor/aicore/aicore_node_executor.cc @@ -29,7 +29,7 @@ bool IsNoOp(const NodeItem &node_item) { const auto &tensor_desc = node_item.MutableOutputDesc(i); GE_CHECK_NOTNULL(tensor_desc); const auto &shape = tensor_desc->MutableShape(); - if (shape.IsScalar() || shape.GetShapeSize() > 0) { + if (shape.IsScalar() || shape.GetShapeSize() > 0 || (node_item.shape_inference_type == DEPEND_SHAPE_RANGE)) { return false; } } @@ -219,12 +219,28 @@ Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); } - if (done_callback != nullptr) { + auto callback = done_callback; + if (!tasks_.empty()) { + // only last task need update outputs shape + auto task = tasks_.back().get(); + if (task->GetUnknownShapeOpType() == DEPEND_SHAPE_RANGE) { + callback = [=, &context]() { + Status callback_ret = SUCCESS; + GELOGD("Node[%s] need update outputs shape.", context.GetNodeName()); + callback_ret = task->UpdateOutputsShape(context); + if (done_callback != nullptr) { + context.SetStatus(callback_ret); + done_callback(); + } + }; + } + } + + if (callback != nullptr) { RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeRegisterCallback] Start"); - GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback)); + GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(callback)); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeRegisterCallback] End"); } - GELOGD("[%s] ExecuteAsync End.", context.GetNodeName()); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeTaskExecuteAsync] End"); return SUCCESS; diff --git a/ge/hybrid/node_executor/aicore/aicore_op_task.cc b/ge/hybrid/node_executor/aicore/aicore_op_task.cc index fe9bba9a..3ed8aa64 100644 --- a/ge/hybrid/node_executor/aicore/aicore_op_task.cc +++ b/ge/hybrid/node_executor/aicore/aicore_op_task.cc @@ -15,13 +15,15 @@ */ #include "hybrid/node_executor/aicore/aicore_op_task.h" -#include "framework/common/taskdown_common.h" + +#include "common/formats/formats.h" +#include "external/graph/types.h" #include "framework/common/debug/log.h" +#include "framework/common/taskdown_common.h" #include "graph/ge_context.h" +#include "graph/load/model_manager/tbe_handle_store.h" #include "hybrid/executor/hybrid_execution_context.h" #include "hybrid/node_executor/aicore/aicore_task_builder.h" -#include "graph/load/model_manager/tbe_handle_store.h" -#include "external/graph/types.h" #include "single_op/task/build_task_utils.h" #include "single_op/task/tbe_task_builder.h" @@ -35,6 +37,9 @@ constexpr char const *kAttrOpParamSize = "op_para_size"; constexpr char const *kAttrAtomicOpParamSize = "atomic_op_para_size"; const string kAtomicOpType = "DynamicAtomicAddrClean"; std::atomic log_id(0); +const uint32_t kMaxDimNum = 8; +// size,dim1,...,dim8: 9*4=36 +const size_t kShapeBufferSize = sizeof(uint32_t) * (1 + kMaxDimNum); } // namespace TbeHandleHolder::TbeHandleHolder(void *bin_handle) @@ -52,6 +57,30 @@ bool TbeHandleRegistry::AddHandle(std::unique_ptr &&holder) { } Status AiCoreOpTask::Init(const OpDesc &op_desc, const domi::TaskDef &task_def) { + GE_CHK_STATUS_RET_NOLOG(DoInit(op_desc, task_def)); + int32_t unknown_shape_op_type_val = static_cast(DEPEND_IN_SHAPE); + (void)AttrUtils::GetInt(op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_op_type_val); + unknown_shape_op_type_ = static_cast(unknown_shape_op_type_val); + GELOGD("Op [%s] unknown shape type is %d", op_desc.GetName().c_str(), unknown_shape_op_type_); + if (unknown_shape_op_type_ == DEPEND_SHAPE_RANGE) { + size_t size = kShapeBufferSize * op_desc.GetOutputsSize(); + if (size == 0) { + GELOGE(PARAM_INVALID, "Op [%s] unknown shape type is %d, but outputs size is 0.", op_desc.GetName().c_str(), + unknown_shape_op_type_); + return PARAM_INVALID; + } + auto allocator = NpuMemoryAllocator::GetAllocator(); + GE_CHECK_NOTNULL(allocator); + shape_buffer_ = TensorBuffer::Create(allocator, size); + GE_CHECK_NOTNULL(shape_buffer_); + GELOGD("Op [%s] allocate memory for outputs shape success, size=%zu", op_desc.GetName().c_str(), size); + host_shape_buffer_.reset(new (std::nothrow) uint8_t[shape_buffer_->GetSize()]); + GE_CHECK_NOTNULL(host_shape_buffer_); + } + return SUCCESS; +} + +Status AiCoreOpTask::DoInit(const OpDesc &op_desc, const domi::TaskDef &task_def) { op_type_ = op_desc.GetType(); log_name_ = op_desc.GetName() + "_tvmbin"; log_id_ = log_id++; @@ -81,6 +110,74 @@ Status AiCoreOpTask::Init(const OpDesc &op_desc, const domi::TaskDef &task_def) return SUCCESS; } +Status AiCoreOpTask::UpdateOutputsShape(TaskContext &context) const { + GELOGD("Node[%s] start update outputs shape.", context.GetNodeName()); + GE_CHECK_NOTNULL(shape_buffer_); + GE_CHECK_NOTNULL(host_shape_buffer_); + GE_CHK_RT_RET(rtMemcpy(host_shape_buffer_.get(), shape_buffer_->GetSize(), shape_buffer_->GetData(), + shape_buffer_->GetSize(), RT_MEMCPY_DEVICE_TO_HOST)); + int num_outputs = context.NumOutputs(); + auto outputs_shape = reinterpret_cast(host_shape_buffer_.get()); + for (int i = 0; i < num_outputs; ++i) { + if (outputs_shape[i][0] != 0) { + uint32_t dim_num = outputs_shape[i][0]; + GE_CHECK_LE(dim_num, kMaxDimNum); + vector dims; + for (uint32_t j = 1; j <= dim_num; ++j) { + dims.emplace_back(static_cast(outputs_shape[i][j])); + } + auto shape_new = GeShape(dims); + GELOGD("Node[%s] output[%d] shape:%s.", context.GetNodeName(), i, ToString(dims).c_str()); + GE_CHK_STATUS_RET_NOLOG(UpdateShapeToOutputDesc(context, shape_new, i)); + } + } + return SUCCESS; +} + +Status AiCoreOpTask::UpdateShapeToOutputDesc(TaskContext &context, const GeShape &shape, const int output_index) const { + auto output_desc = context.MutableOutputDesc(output_index); + GE_CHECK_NOTNULL(output_desc); + auto shape_old = output_desc->GetShape(); + auto origin_shape_old = output_desc->GetOriginShape(); + auto origin_format = output_desc->GetOriginFormat(); + auto format = output_desc->GetFormat(); + auto node_state = context.GetNodeState(); + GE_CHECK_NOTNULL(node_state); + if (origin_format == format) { + GELOGD( + "Node[%s] try to update output[%d] shape from [%s] to [%s], origin_shape " + "from [%s] to [%s].", + context.GetNodeName(), output_index, shape_old.ToString().c_str(), shape.ToString().c_str(), + origin_shape_old.ToString().c_str(), shape.ToString().c_str()); + GE_CHK_STATUS_RET(node_state->UpdateOutputShapes(output_index, shape, shape), + "Node[%s] try to update output[%d] shape from [%s] to [%s], origin_shape " + "from [%s] to [%s] failed.", + context.GetNodeName(), output_index, shape_old.ToString().c_str(), shape.ToString().c_str(), + origin_shape_old.ToString().c_str(), shape.ToString().c_str()); + return SUCCESS; + } + // if format is not same need convert shape + std::vector origin_dims_new; + auto trans_ret = + formats::TransShape(format, shape.GetDims(), output_desc->GetDataType(), origin_format, origin_dims_new); + GE_CHK_STATUS_RET(trans_ret, + "[Trans][Shape] failed for node[%s] output[%d], origin_format[%d] " + "is not same as format[%d], shape=[%s].", + context.GetNodeName(), output_index, origin_format, format, shape.ToString().c_str()); + auto origin_shape_new = GeShape(origin_dims_new); + GE_CHK_STATUS_RET(node_state->UpdateOutputShapes(output_index, shape, origin_shape_new), + "Node[%s] try to update output[%d] shape from [%s] to [%s], origin_shape " + "from [%s] to [%s] failed.", + context.GetNodeName(), output_index, shape_old.ToString().c_str(), shape.ToString().c_str(), + origin_shape_old.ToString().c_str(), origin_shape_new.ToString().c_str()); + GELOGD( + "Node[%s] update output[%d] shape from [%s] to [%s], origin_shape " + "from [%s] to [%s].", + context.GetNodeName(), output_index, shape_old.ToString().c_str(), shape.ToString().c_str(), + origin_shape_old.ToString().c_str(), origin_shape_new.ToString().c_str()); + return SUCCESS; +} + Status AiCoreOpTask::RegisterTbeHandle(const OpDesc &op_desc) { rtError_t rt_ret = rtQueryFunctionRegistered(stub_name_.c_str()); if (rt_ret != RT_ERROR_NONE) { @@ -429,6 +526,11 @@ Status AiCoreOpTask::UpdateArgs(TaskContext &task_context) { if (tiling_buffer_ != nullptr) { ++expected_arg_count; } + + if (shape_buffer_ != nullptr) { + ++expected_arg_count; + } + if (expected_arg_count > max_arg_count_) { GELOGD("Need to reset size of args_ from %u to %zu.", max_arg_count_, expected_arg_count); auto length = expected_arg_count * sizeof(uintptr_t) + offset_; @@ -465,6 +567,12 @@ Status AiCoreOpTask::UpdateArgs(TaskContext &task_context) { arg_base_[index++] = reinterpret_cast(output->GetData()); } + if (shape_buffer_ != nullptr) { + GE_CHK_RT_RET(rtMemset(shape_buffer_->GetData(), shape_buffer_->GetSize(), 0, shape_buffer_->GetSize())); + arg_base_[index++] = reinterpret_cast(shape_buffer_->GetData()); + GELOGD("Node:%s add shape buffer addr to args.", task_context.GetNodeName()); + } + int workspace_num = static_cast(task_context.NumWorkspaces()); for (int i = 0; i < workspace_num; ++i) { const auto workspace = task_context.MutableWorkspace(i); @@ -567,7 +675,7 @@ std::string AiCoreOpTask::GetKeyForKernelName(const OpDesc &op_desc) const { } Status AtomicAddrCleanOpTask::Init(const OpDesc &op_desc, const domi::TaskDef &task_def) { - GE_CHK_STATUS_RET_NOLOG(AiCoreOpTask::Init(op_desc, task_def)); + GE_CHK_STATUS_RET_NOLOG(AiCoreOpTask::DoInit(op_desc, task_def)); return InitAtomicAddrCleanIndices(op_desc); } diff --git a/ge/hybrid/node_executor/aicore/aicore_op_task.h b/ge/hybrid/node_executor/aicore/aicore_op_task.h index 21a947f2..e2484fda 100755 --- a/ge/hybrid/node_executor/aicore/aicore_op_task.h +++ b/ge/hybrid/node_executor/aicore/aicore_op_task.h @@ -82,6 +82,12 @@ class AiCoreOpTask { virtual const std::string& GetOpType() const; + const UnknowShapeOpType GetUnknownShapeOpType() const { + return unknown_shape_op_type_; + } + + Status UpdateOutputsShape(TaskContext &context) const; + protected: Status UpdateTilingInfo(TaskContext &context); virtual std::string GetKeyForOpParamSize() const; @@ -90,6 +96,7 @@ class AiCoreOpTask { virtual std::string GetKeyForTvmMetaData() const; virtual std::string GetKeyForKernelName(const OpDesc &op_desc) const; virtual Status CalcTilingInfo(const NodePtr &node, optiling::utils::OpRunInfo &tiling_info); + Status DoInit(const OpDesc &op_desc, const domi::TaskDef &task_def); std::unique_ptr tiling_buffer_ = nullptr; std::string tiling_data_; @@ -104,6 +111,7 @@ class AiCoreOpTask { Status RegisterKernelHandle(const OpDesc &op_desc); Status InitWithKernelDef(const OpDesc &op_desc, const domi::TaskDef &task_def); Status InitWithKernelDefWithHandle(const OpDesc &node, const domi::TaskDef &task_def); + Status UpdateShapeToOutputDesc(TaskContext &context, const GeShape &shape, const int output_index) const; std::string stub_name_; void *stub_func_ = nullptr; @@ -122,6 +130,9 @@ class AiCoreOpTask { std::string log_name_; uint32_t offset_ = 0; std::string op_type_; + UnknowShapeOpType unknown_shape_op_type_ = DEPEND_IN_SHAPE; + std::unique_ptr shape_buffer_ = nullptr; + std::unique_ptr host_shape_buffer_ = nullptr; }; class AtomicAddrCleanOpTask : public AiCoreOpTask { diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index a7afee3f..f87770f3 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -735,6 +735,8 @@ set(HYBRID_TEST_FILES "hybrid/executor/hybrid_model_async_executor_unittest.cc" "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" + "hybrid/node_executor/aicore/aicore_op_task_unittest.cc" + "hybrid/node_executor/aicore/aicore_node_executor_unittest.cc" ) set(OTHERS_TEST_FILES diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 782a06d6..6f401bc7 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -821,6 +821,8 @@ TEST_F(UtestGeHybrid, TestTaskExecuteAsync) { node_item->output_start = 0; GraphExecutionContext execution_context; + execution_context.callback_manager = + std::unique_ptr(new (std::nothrow) CallbackManager()); GraphItem graph_item; SubgraphContext subgraph_context(&graph_item, &execution_context); ASSERT_EQ(subgraph_context.Init(), SUCCESS); diff --git a/tests/ut/ge/hybrid/node_executor/aicore/aicore_node_executor_unittest.cc b/tests/ut/ge/hybrid/node_executor/aicore/aicore_node_executor_unittest.cc new file mode 100644 index 00000000..174f322c --- /dev/null +++ b/tests/ut/ge/hybrid/node_executor/aicore/aicore_node_executor_unittest.cc @@ -0,0 +1,144 @@ +/** + * Copyright 2021-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include + +#define private public +#define protected public +#include "framework/common/taskdown_common.h" +#include "hybrid/executor/rt_callback_manager.h" +#include "hybrid/executor/subgraph_context.h" +#include "hybrid/node_executor/aicore/aicore_node_executor.h" +#include "init/gelib.h" +#undef private +#undef protected + +using namespace std; +using namespace testing; + +namespace ge { +using namespace hybrid; + +class UtestAiCoreNodeExecutor : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +static ge::OpDescPtr CreateOpDesc(string name = "", string type = "", + int in_num = 0, int out_num = 0) { + auto op_desc = std::make_shared(name, type); + op_desc->SetStreamId(0); + static int32_t index = 0; + op_desc->SetId(index++); + + GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); + TensorUtils::SetSize(tensor, 64); + vector input_offset; + for (int i = 0; i < in_num; ++i) { + op_desc->AddInputDesc(tensor); + input_offset.emplace_back(index * 64 + i * 64); + } + op_desc->SetInputOffset(input_offset); + + vector output_offset; + for (int i = 0; i < out_num; ++i) { + op_desc->AddOutputDesc(tensor); + output_offset.emplace_back(index * 64 + in_num * 64 + i * 64); + } + op_desc->SetOutputOffset(output_offset); + + op_desc->SetWorkspace({}); + op_desc->SetWorkspaceBytes({}); + + ge::AttrUtils::SetStr(op_desc, ge::TVM_ATTR_NAME_MAGIC, + "RT_DEV_BINARY_MAGIC_ELF_AIVEC"); + bool support_dynamic = true; + ge::AttrUtils::GetBool(op_desc, "support_dynamicshape", support_dynamic); + return op_desc; +} + +TEST_F(UtestAiCoreNodeExecutor, callback_success) { + dlog_setlevel(0, 0, 0); + std::unique_ptr task1(new AiCoreOpTask()); + OpDescPtr op_desc = CreateOpDesc("Add", "Add", 2, 1); + ge::AttrUtils::SetInt(*op_desc, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, + DEPEND_SHAPE_RANGE); + domi::TaskDef task_def; + task_def.set_type(RT_MODEL_TASK_KERNEL); + std::vector args(100, 0); + task_def.mutable_kernel()->set_args(args.data(), args.size()); + task_def.mutable_kernel()->set_args_size(100); + task_def.mutable_kernel()->mutable_context()->set_kernel_type( + ccKernelType::TE); + uint16_t args_offset = 20; + char *a = reinterpret_cast(&args_offset); + task_def.mutable_kernel()->mutable_context()->set_args_offset( + a, 2 * sizeof(uint16_t)); + EXPECT_EQ(task1->Init(*op_desc, task_def), ge::SUCCESS); + + ComputeGraphPtr graph = std::make_shared("test"); + NodePtr node = graph->AddNode(op_desc); + std::unique_ptr new_node; + ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); + NodeItem *node_item = new_node.get(); + node_item->input_start = 0; + node_item->output_start = 0; + node_item->is_dynamic = true; + node_item->shape_inference_type = DEPEND_SHAPE_RANGE; + + GraphItem graph_item; + graph_item.node_items_.emplace_back(node_item); + graph_item.total_inputs_ = 2; + graph_item.total_outputs_ = 1; + + GeRootModelPtr ge_root_model = std::make_shared(graph); + ge_root_model->SetModelName("test_name"); + HybridModel hybrid_model(ge_root_model); + + GraphExecutionContext graph_context; + graph_context.model = &hybrid_model; + SubgraphContext subgraph_context(&graph_item, &graph_context); + ASSERT_EQ(subgraph_context.Init(), SUCCESS); + graph_context.callback_manager = + std::unique_ptr(new CallbackManager()); + + auto node_state = subgraph_context.GetOrCreateNodeState(node_item); + ASSERT_NE(node_state, nullptr); + auto outputs_shape = + reinterpret_cast(task1->shape_buffer_->GetData()); + outputs_shape[0][0] = 2; + outputs_shape[0][1] = 1; + outputs_shape[0][2] = 2; + std::vector> tasks; + tasks.emplace_back(std::move(task1)); + std::unique_ptr aicore_node_task; + aicore_node_task.reset(new (std::nothrow) AiCoreNodeTask(std::move(tasks))); + ASSERT_EQ( + aicore_node_task->ExecuteAsync(*node_state->GetTaskContext(), nullptr), + SUCCESS); + std::pair> entry; + node_state->GetTaskContext() + ->execution_context_->callback_manager->callback_queue_.Pop(entry); + auto cb_func = entry.second.first; + auto cb_args = entry.second.second; + cb_func(cb_args); + dlog_setlevel(0, 3, 0); +} +} // namespace ge diff --git a/tests/ut/ge/hybrid/node_executor/aicore/aicore_op_task_unittest.cc b/tests/ut/ge/hybrid/node_executor/aicore/aicore_op_task_unittest.cc new file mode 100644 index 00000000..d6b39996 --- /dev/null +++ b/tests/ut/ge/hybrid/node_executor/aicore/aicore_op_task_unittest.cc @@ -0,0 +1,168 @@ +/** + * Copyright 2021-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include + +#define private public +#define protected public +#include "framework/common/taskdown_common.h" +#include "hybrid/executor/subgraph_context.h" +#include "hybrid/node_executor/aicore/aicore_op_task.h" +#include "init/gelib.h" +#undef private +#undef protected + +using namespace std; +using namespace testing; + +namespace ge { +using namespace hybrid; + +class UtestAiCoreOpTask : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +static ge::OpDescPtr CreateOpDesc(string name = "", string type = "", + int in_num = 0, int out_num = 0) { + auto op_desc = std::make_shared(name, type); + op_desc->SetStreamId(0); + static int32_t index = 0; + op_desc->SetId(index++); + + GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); + TensorUtils::SetSize(tensor, 64); + vector input_offset; + for (int i = 0; i < in_num; ++i) { + op_desc->AddInputDesc(tensor); + input_offset.emplace_back(index * 64 + i * 64); + } + op_desc->SetInputOffset(input_offset); + + vector output_offset; + for (int i = 0; i < out_num; ++i) { + op_desc->AddOutputDesc(tensor); + output_offset.emplace_back(index * 64 + in_num * 64 + i * 64); + } + op_desc->SetOutputOffset(output_offset); + + op_desc->SetWorkspace({}); + op_desc->SetWorkspaceBytes({}); + + ge::AttrUtils::SetStr(op_desc, ge::TVM_ATTR_NAME_MAGIC, + "RT_DEV_BINARY_MAGIC_ELF_AIVEC"); + bool support_dynamic = true; + ge::AttrUtils::GetBool(op_desc, "support_dynamicshape", support_dynamic); + return op_desc; +} + +TEST_F(UtestAiCoreOpTask, Init_failed) { + dlog_setlevel(0, 0, 0); + std::unique_ptr task1(new AiCoreOpTask()); + OpDescPtr op_desc = CreateOpDesc("Add", "Add"); + ge::AttrUtils::SetInt(*op_desc, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, + DEPEND_SHAPE_RANGE); + domi::TaskDef task_def; + task_def.set_type(RT_MODEL_TASK_KERNEL); + std::vector args(100, 0); + task_def.mutable_kernel()->set_args(args.data(), args.size()); + task_def.mutable_kernel()->set_args_size(100); + task_def.mutable_kernel()->mutable_context()->set_kernel_type( + ccKernelType::TE); + uint16_t args_offset = 20; + char *a = reinterpret_cast(&args_offset); + task_def.mutable_kernel()->mutable_context()->set_args_offset( + a, 2 * sizeof(uint16_t)); + EXPECT_EQ(task1->Init(*op_desc, task_def), ge::PARAM_INVALID); + dlog_setlevel(0, 3, 0); +} + +TEST_F(UtestAiCoreOpTask, Init_success) { + dlog_setlevel(0, 0, 0); + std::unique_ptr task1(new AiCoreOpTask()); + OpDescPtr op_desc = CreateOpDesc("Add", "Add", 2, 1); + ge::AttrUtils::SetInt(*op_desc, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, + DEPEND_SHAPE_RANGE); + domi::TaskDef task_def; + task_def.set_type(RT_MODEL_TASK_KERNEL); + std::vector args(100, 0); + task_def.mutable_kernel()->set_args(args.data(), args.size()); + task_def.mutable_kernel()->set_args_size(100); + task_def.mutable_kernel()->mutable_context()->set_kernel_type( + ccKernelType::TE); + uint16_t args_offset = 20; + char *a = reinterpret_cast(&args_offset); + task_def.mutable_kernel()->mutable_context()->set_args_offset( + a, 2 * sizeof(uint16_t)); + EXPECT_EQ(task1->Init(*op_desc, task_def), ge::SUCCESS); + dlog_setlevel(0, 3, 0); +} + +TEST_F(UtestAiCoreOpTask, UpdateOutputsShape_success) { + dlog_setlevel(0, 0, 0); + std::unique_ptr task1(new AiCoreOpTask()); + OpDescPtr op_desc = CreateOpDesc("Add", "Add", 2, 1); + ge::AttrUtils::SetInt(*op_desc, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, + DEPEND_SHAPE_RANGE); + domi::TaskDef task_def; + task_def.set_type(RT_MODEL_TASK_KERNEL); + std::vector args(100, 0); + task_def.mutable_kernel()->set_args(args.data(), args.size()); + task_def.mutable_kernel()->set_args_size(100); + task_def.mutable_kernel()->mutable_context()->set_kernel_type( + ccKernelType::TE); + uint16_t args_offset = 20; + char *a = reinterpret_cast(&args_offset); + task_def.mutable_kernel()->mutable_context()->set_args_offset( + a, 2 * sizeof(uint16_t)); + EXPECT_EQ(task1->Init(*op_desc, task_def), ge::SUCCESS); + + ComputeGraphPtr graph = std::make_shared("test"); + NodePtr node = graph->AddNode(op_desc); + std::unique_ptr new_node; + ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); + NodeItem *node_item = new_node.get(); + node_item->input_start = 0; + node_item->output_start = 0; + node_item->is_dynamic = true; + node_item->shape_inference_type = DEPEND_SHAPE_RANGE; + + GraphItem graph_item; + graph_item.node_items_.emplace_back(node_item); + graph_item.total_inputs_ = 2; + graph_item.total_outputs_ = 1; + + GraphExecutionContext graph_context; + SubgraphContext subgraph_context(&graph_item, &graph_context); + ASSERT_EQ(subgraph_context.Init(), SUCCESS); + graph_context.callback_manager = + std::unique_ptr(new CallbackManager()); + + auto node_state = subgraph_context.GetOrCreateNodeState(node_item); + ASSERT_NE(node_state, nullptr); + auto outputs_shape = + reinterpret_cast(task1->shape_buffer_->GetData()); + outputs_shape[0][0] = 2; + outputs_shape[0][1] = 1; + outputs_shape[0][2] = 2; + ASSERT_EQ(task1->UpdateOutputsShape(*node_state->GetTaskContext()), SUCCESS); + dlog_setlevel(0, 3, 0); +} +} // namespace ge