diff --git a/ge/ge_runtime/task/profiler_task.cc b/ge/ge_runtime/task/profiler_task.cc index dbbc8b8b..a9b80f01 100644 --- a/ge/ge_runtime/task/profiler_task.cc +++ b/ge/ge_runtime/task/profiler_task.cc @@ -40,7 +40,7 @@ ProfilerTask::~ProfilerTask() {} bool ProfilerTask::Distribute() { GELOGI("ProfilerTask Distribute start."); GELOGI("logid = %lu, notify = %d, flat = %u.", task_info_->log_id(), task_info_->notify(), task_info_->flat()); - rtError_t rt_ret = rtProfilerTrace(task_info_->log_id(), task_info_->notify(), task_info_->flat(), stream_); + rtError_t rt_ret = rtProfilerTraceEx(task_info_->log_id(), task_info_->notify(), task_info_->flat(), stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return false; diff --git a/ge/graph/load/model_manager/task_info/profiler_trace_task_info.cc b/ge/graph/load/model_manager/task_info/profiler_trace_task_info.cc index ce696978..218777b4 100755 --- a/ge/graph/load/model_manager/task_info/profiler_trace_task_info.cc +++ b/ge/graph/load/model_manager/task_info/profiler_trace_task_info.cc @@ -45,7 +45,7 @@ Status ProfilerTraceTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel * Status ProfilerTraceTaskInfo::Distribute() { GELOGI("ProfilerTraceTaskInfo Distribute Start. logid = %lu. notify = %d.", log_id_, notify_); - rtError_t rt_ret = rtProfilerTrace(log_id_, notify_, flat_, stream_); + rtError_t rt_ret = rtProfilerTraceEx(log_id_, notify_, flat_, stream_); if (rt_ret != RT_ERROR_NONE) { REPORT_CALL_ERROR("E19999", "Call rtProfilerTrace failed, ret:0x%X, logid:%lu. notify:%d", rt_ret, log_id_, notify_); diff --git a/ge/hybrid/node_executor/rts/rts_node_executor.cc b/ge/hybrid/node_executor/rts/rts_node_executor.cc index e3058ee3..64405820 100644 --- a/ge/hybrid/node_executor/rts/rts_node_executor.cc +++ b/ge/hybrid/node_executor/rts/rts_node_executor.cc @@ -116,7 +116,7 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function< uint32_t flat = log_time_stamp_def.flat(); GELOGD("ProfilingTraceTask execute async start. logid = %lu, notify = %d.", log_id, notify); - rtError_t rt_ret = rtProfilerTrace(log_id, notify, flat, context.GetStream()); + rtError_t rt_ret = rtProfilerTraceEx(log_id, notify, flat, context.GetStream()); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret); diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index a7afee3f..5c0cbb0d 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -533,6 +533,7 @@ set(DISTINCT_GRAPH_LOAD_TEST_FILES "graph/ge_executor_unittest.cc" "graph/load/model_helper_unittest.cc" "graph/load/model_utils_unittest.cc" + "graph/load/profiler_trace_task_info_unittest.cc" ) set(PASS_TEST_FILES diff --git a/tests/ut/ge/graph/load/profiler_trace_task_info_unittest.cc b/tests/ut/ge/graph/load/profiler_trace_task_info_unittest.cc new file mode 100644 index 00000000..7e8fe1e2 --- /dev/null +++ b/tests/ut/ge/graph/load/profiler_trace_task_info_unittest.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define private public +#define protected public + +#include "graph/load/model_manager/davinci_model.h" + +#include "graph/load/model_manager/task_info/profiler_trace_task_info" +#include "cce/aicpu_engine_struct.h" +#include "tests/depends/runtime/src/runtime_stub.h" + +namespace ge { +extern OpDescPtr CreateOpDesc(string name, string type); + +class UtestProfTraceTaskInfo : public testing::Test { + protected: + void SetUp() { + RTS_STUB_SETUP(); + } + + void TearDown() { + RTS_STUB_TEARDOWN(); + } +}; + +// test KernelTaskInfo Init. +TEST_F(UtestKernelTaskInfo, success_kernel_taskInfo_not_te) { + DavinciModel model(0, nullptr); + domi::ModelTaskDef model_task_def; + domi::TaskDef *task = model_task_def.add_task(); + task->set_type(RT_MODEL_TASK_KERNEL); + TaskInfoPtr task_info = TaskInfoFactory::Instance().Create(static_cast(task->type())); + +task->stream_id_ = 0; +rtStream_t stream = nullptr; +rtStreamCreate(&stream, 0); +model.stream_list_ = { stream }; + +domi::KernelDef *kernel_def = task->mutable_kernel(); +domi::KernelContext *ctx = kernel_def->mutable_context(); +model.op_list_[0] = CreateOpDesc("relu", RELU); +ctx->set_op_index(0); + +EXPECT_EQ(task_info->Init(*task, &model), FAILED); + +kernel_def->set_block_dim(10); +kernel_def->set_args("args111111", 10); +kernel_def->set_args_size(10); + +ctx->set_kernel_type(0); +EXPECT_EQ(task_info->Init(*task, &model), INTERNAL_ERROR); + +task_info->Release(); +} +} // namespace ge \ No newline at end of file