Browse Source

Pre Merge pull request !2054 from zhupuxu/use_new_func

pull/2054/MERGE
zhupuxu Gitee 3 years ago
parent
commit
c2e44e7323
5 changed files with 75 additions and 3 deletions
  1. +1
    -1
      ge/ge_runtime/task/profiler_task.cc
  2. +1
    -1
      ge/graph/load/model_manager/task_info/profiler_trace_task_info.cc
  3. +1
    -1
      ge/hybrid/node_executor/rts/rts_node_executor.cc
  4. +1
    -0
      tests/ut/ge/CMakeLists.txt
  5. +71
    -0
      tests/ut/ge/graph/load/profiler_trace_task_info_unittest.cc

+ 1
- 1
ge/ge_runtime/task/profiler_task.cc View File

@@ -40,7 +40,7 @@ ProfilerTask::~ProfilerTask() {}
bool ProfilerTask::Distribute() {
GELOGI("ProfilerTask Distribute start.");
GELOGI("logid = %lu, notify = %d, flat = %u.", task_info_->log_id(), task_info_->notify(), task_info_->flat());
rtError_t rt_ret = rtProfilerTrace(task_info_->log_id(), task_info_->notify(), task_info_->flat(), stream_);
rtError_t rt_ret = rtProfilerTraceEx(task_info_->log_id(), task_info_->notify(), task_info_->flat(), stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;


+ 1
- 1
ge/graph/load/model_manager/task_info/profiler_trace_task_info.cc View File

@@ -45,7 +45,7 @@ Status ProfilerTraceTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *
Status ProfilerTraceTaskInfo::Distribute() {
GELOGI("ProfilerTraceTaskInfo Distribute Start. logid = %lu. notify = %d.", log_id_, notify_);

rtError_t rt_ret = rtProfilerTrace(log_id_, notify_, flat_, stream_);
rtError_t rt_ret = rtProfilerTraceEx(log_id_, notify_, flat_, stream_);
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtProfilerTrace failed, ret:0x%X, logid:%lu. notify:%d",
rt_ret, log_id_, notify_);


+ 1
- 1
ge/hybrid/node_executor/rts/rts_node_executor.cc View File

@@ -116,7 +116,7 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function<
uint32_t flat = log_time_stamp_def.flat();

GELOGD("ProfilingTraceTask execute async start. logid = %lu, notify = %d.", log_id, notify);
rtError_t rt_ret = rtProfilerTrace(log_id, notify, flat, context.GetStream());
rtError_t rt_ret = rtProfilerTraceEx(log_id, notify, flat, context.GetStream());
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);


+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -533,6 +533,7 @@ set(DISTINCT_GRAPH_LOAD_TEST_FILES
"graph/ge_executor_unittest.cc"
"graph/load/model_helper_unittest.cc"
"graph/load/model_utils_unittest.cc"
"graph/load/profiler_trace_task_info_unittest.cc"
)

set(PASS_TEST_FILES


+ 71
- 0
tests/ut/ge/graph/load/profiler_trace_task_info_unittest.cc View File

@@ -0,0 +1,71 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>

#define private public
#define protected public

#include "graph/load/model_manager/davinci_model.h"

#include "graph/load/model_manager/task_info/profiler_trace_task_info"
#include "cce/aicpu_engine_struct.h"
#include "tests/depends/runtime/src/runtime_stub.h"

namespace ge {
extern OpDescPtr CreateOpDesc(string name, string type);

class UtestProfTraceTaskInfo : public testing::Test {
protected:
void SetUp() {
RTS_STUB_SETUP();
}

void TearDown() {
RTS_STUB_TEARDOWN();
}
};

// test KernelTaskInfo Init.
TEST_F(UtestKernelTaskInfo, success_kernel_taskInfo_not_te) {
DavinciModel model(0, nullptr);
domi::ModelTaskDef model_task_def;
domi::TaskDef *task = model_task_def.add_task();
task->set_type(RT_MODEL_TASK_KERNEL);
TaskInfoPtr task_info = TaskInfoFactory::Instance().Create(static_cast<rtModelTaskType_t>(task->type()));

task->stream_id_ = 0;
rtStream_t stream = nullptr;
rtStreamCreate(&stream, 0);
model.stream_list_ = { stream };

domi::KernelDef *kernel_def = task->mutable_kernel();
domi::KernelContext *ctx = kernel_def->mutable_context();
model.op_list_[0] = CreateOpDesc("relu", RELU);
ctx->set_op_index(0);

EXPECT_EQ(task_info->Init(*task, &model), FAILED);

kernel_def->set_block_dim(10);
kernel_def->set_args("args111111", 10);
kernel_def->set_args_size(10);

ctx->set_kernel_type(0);
EXPECT_EQ(task_info->Init(*task, &model), INTERNAL_ERROR);

task_info->Release();
}
} // namespace ge

Loading…
Cancel
Save