Browse Source

Pre Merge pull request !2075 from zhupuxu/task_info

pull/2075/MERGE
zhupuxu Gitee 3 years ago
parent
commit
773844e998
4 changed files with 102 additions and 9 deletions
  1. +47
    -3
      ge/graph/load/model_manager/task_info/profiler_trace_task_info.cc
  2. +6
    -1
      ge/graph/load/model_manager/task_info/profiler_trace_task_info.h
  3. +48
    -5
      ge/hybrid/node_executor/rts/rts_node_executor.cc
  4. +1
    -0
      ge/hybrid/node_executor/rts/rts_node_executor.h

+ 47
- 3
ge/graph/load/model_manager/task_info/profiler_trace_task_info.cc View File

@@ -19,6 +19,49 @@
#include "framework/common/debug/ge_log.h"
#include "graph/load/model_manager/davinci_model.h"

namespace {
const uint64_t kProfilingFpStartLogid = 1;
const uint64_t kProfilingBpEndLogid = 2;
const uint64_t kProfilingArStartLogid = 3;
const uint64_t kProfilingArEndLogid = 4;
const uint64_t kProfilingArMax = 10002;
const uint64_t kProfilingIterEndLogid = 65535;

const uint16_t kProfilingFpStartTagid = 2;
const uint16_t kProfilingBpEndTagid = 3;
const uint16_t kProfilingIterEndTagid = 4;
const uint16_t kProfilingArStartTagid = 10000;
const uint16_t kProfilingArEndTagid = 10001;

const map<uint64_t, uint16_t> kLogToTagMap = {
{kProfilingFpStartLogid, kProfilingFpStartTagid},
{kProfilingBpEndLogid, kProfilingBpEndTagid},
{kProfilingArStartLogid, kProfilingArStartTagid},
{kProfilingArEndLogid, kProfilingArEndTagid}
};

void GetTagFromLogid(const uint64_t &log_id, uint16_t &tag_id) {
GELOGD("log id is %u", log_id);
if (log_id == 0 || log_id > kProfilingArMax) {
GELOGW("log id:%u is out of range", log_id);
return;
}
if (log_id == kProfilingIterEndLogid) {
tag_id = kProfilingIterEndTagid;
return;
}
auto iter = kLogToTagMap.find(log_id);
if (iter != kLogToTagMap.end()) {
tag_id = iter->second;
return;
}
else {
tag_id = log_id + kProfilingArEndTagid - kProfilingArEndLogid;
return;
}
}
} // namespace

namespace ge {
Status ProfilerTraceTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) {
GELOGI("ProfilerTraceTaskInfo Init Start.");
@@ -27,7 +70,7 @@ Status ProfilerTraceTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *
GELOGE(PARAM_INVALID, "[Check][Param] davinci_model is null!");
return PARAM_INVALID;
}
model_id_ = davinci_model->GetModelId();
Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList());
if (ret != SUCCESS) {
return ret;
@@ -44,8 +87,9 @@ Status ProfilerTraceTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *

Status ProfilerTraceTaskInfo::Distribute() {
GELOGI("ProfilerTraceTaskInfo Distribute Start. logid = %lu. notify = %d.", log_id_, notify_);

rtError_t rt_ret = rtProfilerTrace(log_id_, notify_, flat_, stream_);
uint16_t tag_id = 0;
GetTagFromLogid(log_id_, tag_id);
rtError_t rt_ret = rtProfilerTraceEx(log_id_, model_id_, tag_id, stream_);
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtProfilerTrace failed, ret:0x%X, logid:%lu. notify:%d",
rt_ret, log_id_, notify_);


+ 6
- 1
ge/graph/load/model_manager/task_info/profiler_trace_task_info.h View File

@@ -21,7 +21,7 @@
namespace ge {
class ProfilerTraceTaskInfo : public TaskInfo {
public:
ProfilerTraceTaskInfo() : log_id_(0), notify_(false), flat_(0) {}
ProfilerTraceTaskInfo() : log_id_(0), notify_(false), flat_(0), model_id_(0) {}

~ProfilerTraceTaskInfo() override {}

@@ -29,10 +29,15 @@ class ProfilerTraceTaskInfo : public TaskInfo {

Status Distribute() override;

uint32_t GetModelId() { return model_id_; }



private:
uint64_t log_id_;
bool notify_;
uint32_t flat_;
uint32_t model_id_;
};
} // namespace ge
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_PROFILER_TRACE_TASK_INFO_H_

+ 48
- 5
ge/hybrid/node_executor/rts/rts_node_executor.cc View File

@@ -21,6 +21,49 @@
#include "graph/utils/tensor_utils.h"
#include "hybrid/model/hybrid_model.h"

namespace {
const uint64_t kProfilingFpStartLogid = 1;
const uint64_t kProfilingBpEndLogid = 2;
const uint64_t kProfilingArStartLogid = 3;
const uint64_t kProfilingArEndLogid = 4;
const uint64_t kProfilingArMax = 10002;
const uint64_t kProfilingIterEndLogid = 65535;

const uint16_t kProfilingFpStartTagid = 2;
const uint16_t kProfilingBpEndTagid = 3;
const uint16_t kProfilingIterEndTagid = 4;
const uint16_t kProfilingArStartTagid = 10000;
const uint16_t kProfilingArEndTagid = 10001;

const map<uint64_t, uint16_t> kLogToTagMap = {
{kProfilingFpStartLogid, kProfilingFpStartTagid},
{kProfilingBpEndLogid, kProfilingBpEndTagid},
{kProfilingArStartLogid, kProfilingArStartTagid},
{kProfilingArEndLogid, kProfilingArEndTagid}
};

void GetTagFromLogid(const uint64_t &log_id, uint16_t &tag_id) {
GELOGD("log id is %u", log_id);
if (log_id == 0 || log_id > kProfilingArMax) {
GELOGW("log id:%u is out of range", log_id);
return;
}
if (log_id == kProfilingIterEndLogid) {
tag_id = kProfilingIterEndTagid;
return;
}
auto iter = kLogToTagMap.find(log_id);
if (iter != kLogToTagMap.end()) {
tag_id = iter->second;
return;
}
else {
tag_id = log_id + kProfilingArEndTagid - kProfilingArEndLogid;
return;
}
}
} // namespace

namespace ge {
namespace hybrid {
REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::RTS, RtsNodeExecutor);
@@ -102,21 +145,21 @@ Status ProfilingTraceNodeTask::Init(const HybridModel &model, const NodePtr &nod
GELOGE(INTERNAL_ERROR, "Profiling node has no task to execute.");
return INTERNAL_ERROR;
}
model_id_ = model.GetModelId();
task_defs_ = *task_defs;
GELOGD("[%s] Done initialization successfully.", node->GetName().c_str());
return SUCCESS;
}

Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
uint16_t tag_id = 0;
for (const auto &task_def : task_defs_) {
auto log_time_stamp_def = task_def.log_timestamp();
uint64_t log_id = log_time_stamp_def.logid();
bool notify = log_time_stamp_def.notify();
uint32_t flat = log_time_stamp_def.flat();

GELOGD("ProfilingTraceTask execute async start. logid = %lu, notify = %d.", log_id, notify);
rtError_t rt_ret = rtProfilerTrace(log_id, notify, flat, context.GetStream());
GELOGD("ProfilingTraceTask execute async start. logid = %lu.", log_id);
GetTagFromLogid(log_id, tag_id);
rtError_t rt_ret = rtProfilerTraceEx(log_id, model_id_, tag_id, context.GetStream());
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);


+ 1
- 0
ge/hybrid/node_executor/rts/rts_node_executor.h View File

@@ -47,6 +47,7 @@ class ProfilingTraceNodeTask : public RtsNodeTask {
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;

private:
uint32_t model_id_ = 0;
std::vector<domi::TaskDef> task_defs_;
};



Loading…
Cancel
Save