@@ -28,7 +28,7 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace model_runner { | namespace model_runner { | ||||
class RuntimeModel; | class RuntimeModel; | ||||
using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>; | |||||
class ModelRunner { | class ModelRunner { | ||||
public: | public: | ||||
static ModelRunner &Instance(); | static ModelRunner &Instance(); | ||||
@@ -36,18 +36,8 @@ class ModelRunner { | |||||
bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id, | bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id, | ||||
std::shared_ptr<DavinciModel> davinci_model, std::shared_ptr<ModelListener> listener); | std::shared_ptr<DavinciModel> davinci_model, std::shared_ptr<ModelListener> listener); | ||||
bool DistributeTask(uint32_t model_id); | |||||
bool LoadModelComplete(uint32_t model_id); | |||||
const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const; | const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const; | ||||
const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const; | |||||
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap(uint32_t model_id) const; | |||||
void *GetModelHandle(uint32_t model_id) const; | |||||
bool UnloadModel(uint32_t model_id); | bool UnloadModel(uint32_t model_id); | ||||
bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); | bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); | ||||
@@ -21,7 +21,6 @@ | |||||
#include <functional> | #include <functional> | ||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include <utility> | |||||
#include <vector> | #include <vector> | ||||
#include "cce/taskdown_api.h" | #include "cce/taskdown_api.h" | ||||
@@ -53,27 +52,21 @@ class TaskInfo { | |||||
virtual ~TaskInfo() {} | virtual ~TaskInfo() {} | ||||
uint32_t stream_id() const { return stream_id_; } | uint32_t stream_id() const { return stream_id_; } | ||||
TaskInfoType type() const { return type_; } | TaskInfoType type() const { return type_; } | ||||
std::string op_name() const { return op_name_; } | |||||
bool dump_flag() const { return dump_flag_; } | |||||
protected: | protected: | ||||
TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag) | |||||
: op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {} | |||||
TaskInfo(uint32_t stream_id, TaskInfoType type) : stream_id_(stream_id), type_(type) {} | |||||
private: | private: | ||||
std::string op_name_; | |||||
uint32_t stream_id_; | uint32_t stream_id_; | ||||
TaskInfoType type_; | TaskInfoType type_; | ||||
bool dump_flag_; | |||||
}; | }; | ||||
class CceTaskInfo : public TaskInfo { | class CceTaskInfo : public TaskInfo { | ||||
public: | public: | ||||
CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, | |||||
uint32_t block_dim, const std::vector<uint8_t> &args, uint32_t args_size, | |||||
const std::vector<uint8_t> &sm_desc, const std::vector<uint8_t> &flow_table, | |||||
const std::vector<uint8_t> &args_offset, bool is_flowtable) | |||||
: TaskInfo(op_name, stream_id, TaskInfoType::CCE, false), | |||||
CceTaskInfo(uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, uint32_t block_dim, | |||||
const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, | |||||
const std::vector<uint8_t> &flow_table, const std::vector<uint8_t> &args_offset, bool is_flowtable) | |||||
: TaskInfo(stream_id, TaskInfoType::CCE), | |||||
ctx_(ctx), | ctx_(ctx), | ||||
stub_func_(stub_func), | stub_func_(stub_func), | ||||
block_dim_(block_dim), | block_dim_(block_dim), | ||||
@@ -109,11 +102,11 @@ class CceTaskInfo : public TaskInfo { | |||||
class TbeTaskInfo : public TaskInfo { | class TbeTaskInfo : public TaskInfo { | ||||
public: | public: | ||||
TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, | |||||
const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, | |||||
uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs, | |||||
const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag) | |||||
: TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag), | |||||
TbeTaskInfo(uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, const std::vector<uint8_t> &args, | |||||
uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, uint32_t binary_size, | |||||
const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs, | |||||
const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs) | |||||
: TaskInfo(stream_id, TaskInfoType::TBE), | |||||
stub_func_(stub_func), | stub_func_(stub_func), | ||||
block_dim_(block_dim), | block_dim_(block_dim), | ||||
args_(args), | args_(args), | ||||
@@ -160,10 +153,9 @@ class TbeTaskInfo : public TaskInfo { | |||||
class AicpuTaskInfo : public TaskInfo { | class AicpuTaskInfo : public TaskInfo { | ||||
public: | public: | ||||
AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name, | |||||
const std::string &node_def, const std::vector<void *> &input_data_addrs, | |||||
const std::vector<void *> &output_data_addrs, bool dump_flag) | |||||
: TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), | |||||
AicpuTaskInfo(uint32_t stream_id, const string &so_name, const std::string &kernel_name, const std::string &node_def, | |||||
const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs) | |||||
: TaskInfo(stream_id, TaskInfoType::AICPU), | |||||
so_name_(so_name), | so_name_(so_name), | ||||
kernel_name_(kernel_name), | kernel_name_(kernel_name), | ||||
node_def_(node_def), | node_def_(node_def), | ||||
@@ -185,45 +177,37 @@ class AicpuTaskInfo : public TaskInfo { | |||||
std::vector<void *> output_data_addrs_; | std::vector<void *> output_data_addrs_; | ||||
}; | }; | ||||
class LabelSetTaskInfo : public TaskInfo { | |||||
class LabelTaskInfo : public TaskInfo { | |||||
public: | public: | ||||
LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id) | |||||
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {} | |||||
~LabelSetTaskInfo() override {} | |||||
uint32_t label_id() const { return label_id_; } | uint32_t label_id() const { return label_id_; } | ||||
private: | |||||
protected: | |||||
LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id) | |||||
: TaskInfo(stream_id, type), label_id_(label_id) {} | |||||
virtual ~LabelTaskInfo() override {} | |||||
uint32_t label_id_; | uint32_t label_id_; | ||||
}; | }; | ||||
class LabelGotoTaskInfo : public TaskInfo { | |||||
class LabelSetTaskInfo : public LabelTaskInfo { | |||||
public: | public: | ||||
LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id) | |||||
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {} | |||||
~LabelGotoTaskInfo() override {} | |||||
uint32_t label_id() const { return label_id_; } | |||||
private: | |||||
uint32_t label_id_; | |||||
LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id) | |||||
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {} | |||||
~LabelSetTaskInfo() override {} | |||||
}; | }; | ||||
class LabelSwitchTaskInfo : public TaskInfo { | |||||
class LabelSwitchTaskInfo : public LabelTaskInfo { | |||||
public: | public: | ||||
LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size, | |||||
const std::vector<uint32_t> &label_list, void *cond) | |||||
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false), | |||||
label_size_(label_size), | |||||
label_list_(label_list), | |||||
cond_(cond) {} | |||||
LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id) | |||||
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {} | |||||
~LabelSwitchTaskInfo() override {} | ~LabelSwitchTaskInfo() override {} | ||||
uint32_t label_size() { return label_size_; }; | |||||
const std::vector<uint32_t> &label_list() { return label_list_; }; | |||||
void *cond() { return cond_; }; | |||||
}; | |||||
private: | |||||
uint32_t label_size_; | |||||
std::vector<uint32_t> label_list_; | |||||
void *cond_; | |||||
class LabelGotoTaskInfo : public LabelTaskInfo { | |||||
public: | |||||
LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id) | |||||
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {} | |||||
~LabelGotoTaskInfo() override {} | |||||
}; | }; | ||||
class EventTaskInfo : public TaskInfo { | class EventTaskInfo : public TaskInfo { | ||||
@@ -231,8 +215,8 @@ class EventTaskInfo : public TaskInfo { | |||||
uint32_t event_id() const { return event_id_; } | uint32_t event_id() const { return event_id_; } | ||||
protected: | protected: | ||||
EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id) | |||||
: TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {} | |||||
EventTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t event_id) | |||||
: TaskInfo(stream_id, type), event_id_(event_id) {} | |||||
virtual ~EventTaskInfo() override {} | virtual ~EventTaskInfo() override {} | ||||
uint32_t event_id_; | uint32_t event_id_; | ||||
@@ -240,41 +224,39 @@ class EventTaskInfo : public TaskInfo { | |||||
class EventRecordTaskInfo : public EventTaskInfo { | class EventRecordTaskInfo : public EventTaskInfo { | ||||
public: | public: | ||||
EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id) | |||||
: EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {} | |||||
EventRecordTaskInfo(uint32_t stream_id, uint32_t event_id) | |||||
: EventTaskInfo(stream_id, TaskInfoType::EVENT_RECORD, event_id) {} | |||||
~EventRecordTaskInfo() override {} | ~EventRecordTaskInfo() override {} | ||||
}; | }; | ||||
class EventWaitTaskInfo : public EventTaskInfo { | class EventWaitTaskInfo : public EventTaskInfo { | ||||
public: | public: | ||||
EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id) | |||||
: EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {} | |||||
EventWaitTaskInfo(uint32_t stream_id, uint32_t event_id) | |||||
: EventTaskInfo(stream_id, TaskInfoType::EVENT_WAIT, event_id) {} | |||||
~EventWaitTaskInfo() override {} | ~EventWaitTaskInfo() override {} | ||||
}; | }; | ||||
class FusionStartTaskInfo : public TaskInfo { | class FusionStartTaskInfo : public TaskInfo { | ||||
public: | public: | ||||
explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id) | |||||
: TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {} | |||||
explicit FusionStartTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_START) {} | |||||
~FusionStartTaskInfo() override {} | ~FusionStartTaskInfo() override {} | ||||
}; | }; | ||||
class FusionEndTaskInfo : public TaskInfo { | class FusionEndTaskInfo : public TaskInfo { | ||||
public: | public: | ||||
explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id) | |||||
: TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {} | |||||
explicit FusionEndTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_END) {} | |||||
~FusionEndTaskInfo() override {} | ~FusionEndTaskInfo() override {} | ||||
}; | }; | ||||
class HcclTaskInfo : public TaskInfo { | class HcclTaskInfo : public TaskInfo { | ||||
public: | public: | ||||
HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr, | |||||
void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, | |||||
HcclTaskInfo(uint32_t stream_id, const std::string hccl_type, void *input_data_addr, void *output_data_addr, | |||||
void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, | |||||
const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id, | const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id, | ||||
int64_t op_type, int64_t data_type, const std::string &group, | |||||
std::function<bool(void *, void *)> hcom_bind_model, std::function<bool(void *)> hcom_unbind_model, | |||||
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task, bool dump_flag) | |||||
: TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag), | |||||
int64_t op_type, int64_t data_type, std::function<bool(void *, void *)> hcom_bind_model, | |||||
std::function<bool(void *)> hcom_unbind_model, | |||||
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task) | |||||
: TaskInfo(stream_id, TaskInfoType::HCCL), | |||||
hccl_type_(hccl_type), | hccl_type_(hccl_type), | ||||
input_data_addr_(input_data_addr), | input_data_addr_(input_data_addr), | ||||
output_data_addr_(output_data_addr), | output_data_addr_(output_data_addr), | ||||
@@ -287,7 +269,6 @@ class HcclTaskInfo : public TaskInfo { | |||||
root_id_(root_id), | root_id_(root_id), | ||||
op_type_(op_type), | op_type_(op_type), | ||||
data_type_(data_type), | data_type_(data_type), | ||||
group_(group), | |||||
hcom_bind_model_(hcom_bind_model), | hcom_bind_model_(hcom_bind_model), | ||||
hcom_unbind_model_(hcom_unbind_model), | hcom_unbind_model_(hcom_unbind_model), | ||||
hcom_distribute_task_(hcom_distribute_task) {} | hcom_distribute_task_(hcom_distribute_task) {} | ||||
@@ -305,7 +286,6 @@ class HcclTaskInfo : public TaskInfo { | |||||
int64_t root_id() const { return root_id_; } | int64_t root_id() const { return root_id_; } | ||||
int64_t op_type() const { return op_type_; } | int64_t op_type() const { return op_type_; } | ||||
int64_t data_type() const { return data_type_; } | int64_t data_type() const { return data_type_; } | ||||
const std::string &group() const { return group_; } | |||||
std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; } | std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; } | ||||
std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; } | std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; } | ||||
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const { | std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const { | ||||
@@ -325,7 +305,6 @@ class HcclTaskInfo : public TaskInfo { | |||||
int64_t root_id_; | int64_t root_id_; | ||||
int64_t op_type_; | int64_t op_type_; | ||||
int64_t data_type_; | int64_t data_type_; | ||||
std::string group_; | |||||
std::function<bool(void *, void *)> hcom_bind_model_; | std::function<bool(void *, void *)> hcom_bind_model_; | ||||
std::function<bool(void *)> hcom_unbind_model_; | std::function<bool(void *)> hcom_unbind_model_; | ||||
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task_; | std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task_; | ||||
@@ -333,11 +312,8 @@ class HcclTaskInfo : public TaskInfo { | |||||
class ProfilerTraceTaskInfo : public TaskInfo { | class ProfilerTraceTaskInfo : public TaskInfo { | ||||
public: | public: | ||||
ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat) | |||||
: TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false), | |||||
log_id_(log_id), | |||||
notify_(notify), | |||||
flat_(flat) {} | |||||
ProfilerTraceTaskInfo(uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat) | |||||
: TaskInfo(stream_id, TaskInfoType::PROFILER_TRACE), log_id_(log_id), notify_(notify), flat_(flat) {} | |||||
~ProfilerTraceTaskInfo() override {} | ~ProfilerTraceTaskInfo() override {} | ||||
uint64_t log_id() const { return log_id_; } | uint64_t log_id() const { return log_id_; } | ||||
@@ -352,9 +328,8 @@ class ProfilerTraceTaskInfo : public TaskInfo { | |||||
class MemcpyAsyncTaskInfo : public TaskInfo { | class MemcpyAsyncTaskInfo : public TaskInfo { | ||||
public: | public: | ||||
MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src, | |||||
uint64_t count, uint32_t kind, bool dump_flag) | |||||
: TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag), | |||||
MemcpyAsyncTaskInfo(uint32_t stream_id, void *dst, uint64_t dst_max, void *src, uint64_t count, uint32_t kind) | |||||
: TaskInfo(stream_id, TaskInfoType::MEMCPY_ASYNC), | |||||
dst_(dst), | dst_(dst), | ||||
dst_max_(dst_max), | dst_max_(dst_max), | ||||
src_(src), | src_(src), | ||||
@@ -378,9 +353,9 @@ class MemcpyAsyncTaskInfo : public TaskInfo { | |||||
class StreamSwitchTaskInfo : public TaskInfo { | class StreamSwitchTaskInfo : public TaskInfo { | ||||
public: | public: | ||||
StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr, | |||||
void *value_addr, int64_t cond, int64_t data_type) | |||||
: TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false), | |||||
StreamSwitchTaskInfo(uint32_t stream_id, int64_t true_stream_id, void *input_addr, void *value_addr, int64_t cond, | |||||
int64_t data_type) | |||||
: TaskInfo(stream_id, TaskInfoType::STREAM_SWITCH), | |||||
true_stream_id_(true_stream_id), | true_stream_id_(true_stream_id), | ||||
input_addr_(input_addr), | input_addr_(input_addr), | ||||
value_addr_(value_addr), | value_addr_(value_addr), | ||||
@@ -404,8 +379,8 @@ class StreamSwitchTaskInfo : public TaskInfo { | |||||
class StreamActiveTaskInfo : public TaskInfo { | class StreamActiveTaskInfo : public TaskInfo { | ||||
public: | public: | ||||
StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id) | |||||
: TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {} | |||||
StreamActiveTaskInfo(uint32_t stream_id, uint32_t active_stream_id) | |||||
: TaskInfo(stream_id, TaskInfoType::STREAM_ACTIVE), active_stream_id_(active_stream_id) {} | |||||
~StreamActiveTaskInfo() override {} | ~StreamActiveTaskInfo() override {} | ||||
uint32_t active_stream_id() const { return active_stream_id_; } | uint32_t active_stream_id() const { return active_stream_id_; } | ||||
@@ -49,24 +49,6 @@ bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint | |||||
return true; | return true; | ||||
} | } | ||||
bool ModelRunner::DistributeTask(uint32_t model_id) { | |||||
auto model_iter = runtime_models_.find(model_id); | |||||
if (model_iter == runtime_models_.end()) { | |||||
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||||
return false; | |||||
} | |||||
return model_iter->second->DistributeTask(); | |||||
} | |||||
bool ModelRunner::LoadModelComplete(uint32_t model_id) { | |||||
auto model_iter = runtime_models_.find(model_id); | |||||
if (model_iter == runtime_models_.end()) { | |||||
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||||
return false; | |||||
} | |||||
return model_iter->second->LoadComplete(); | |||||
} | |||||
const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const { | const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const { | ||||
auto model_iter = runtime_models_.find(model_id); | auto model_iter = runtime_models_.find(model_id); | ||||
if (model_iter == runtime_models_.end()) { | if (model_iter == runtime_models_.end()) { | ||||
@@ -78,38 +60,6 @@ const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const | |||||
return model_iter->second->GetTaskIdList(); | return model_iter->second->GetTaskIdList(); | ||||
} | } | ||||
const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const { | |||||
auto model_iter = runtime_models_.find(model_id); | |||||
if (model_iter == runtime_models_.end()) { | |||||
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||||
static const std::vector<uint32_t> empty_ret; | |||||
return empty_ret; | |||||
} | |||||
return model_iter->second->GetStreamIdList(); | |||||
} | |||||
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const { | |||||
auto model_iter = runtime_models_.find(model_id); | |||||
if (model_iter == runtime_models_.end()) { | |||||
GELOGW("Model id %u not found.", model_id); | |||||
static const std::map<std::string, std::shared_ptr<RuntimeInfo>> empty_ret; | |||||
return empty_ret; | |||||
} | |||||
return model_iter->second->GetRuntimeInfoMap(); | |||||
} | |||||
void *ModelRunner::GetModelHandle(uint32_t model_id) const { | |||||
auto model_iter = runtime_models_.find(model_id); | |||||
if (model_iter == runtime_models_.end()) { | |||||
GELOGW("Model id %u not found.", model_id); | |||||
return nullptr; | |||||
} | |||||
return model_iter->second->GetModelHandle(); | |||||
} | |||||
bool ModelRunner::UnloadModel(uint32_t model_id) { | bool ModelRunner::UnloadModel(uint32_t model_id) { | ||||
auto iter = runtime_models_.find(model_id); | auto iter = runtime_models_.find(model_id); | ||||
if (iter != runtime_models_.end()) { | if (iter != runtime_models_.end()) { | ||||
@@ -76,7 +76,7 @@ bool Output::CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_inde | |||||
DataBuffer data_buf = rslt->blobs[data_begin + data_count]; | DataBuffer data_buf = rslt->blobs[data_begin + data_count]; | ||||
bool ret = SetDataBuf(data_buf, data_begin, data_count, i, support_mem_share); | bool ret = SetDataBuf(data_buf, data_begin, data_count, i, support_mem_share); | ||||
if (!ret) { | if (!ret) { | ||||
GELOGE(FAILED, "Copy data to host error. index: %lu, addr: %p", i, v_input_data_addr_[i]); | |||||
GELOGE(FAILED, "Copy data to host failed. index: %lu, addr: %p", i, v_input_data_addr_[i]); | |||||
return ret; | return ret; | ||||
} | } | ||||
data_index = data_begin + data_count; | data_index = data_begin + data_count; | ||||
@@ -28,6 +28,7 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace model_runner { | namespace model_runner { | ||||
RuntimeModel::~RuntimeModel() { | RuntimeModel::~RuntimeModel() { | ||||
GELOGI("RuntimeModel destructor start"); | GELOGI("RuntimeModel destructor start"); | ||||
@@ -115,34 +116,23 @@ bool RuntimeModel::InitEvent(uint32_t event_num) { | |||||
return true; | return true; | ||||
} | } | ||||
bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) { | |||||
GELOGI("batch number:%u.", davinci_model->GetBatchNum()); | |||||
label_list_.resize(davinci_model->GetBatchNum()); | |||||
for (auto &task_info : davinci_model->GetTaskInfoList()) { | |||||
if (task_info == nullptr) { | |||||
GELOGE(PARAM_INVALID, "task_info is null."); | |||||
continue; | |||||
} | |||||
if (task_info->type() != TaskInfoType::LABEL_SET) { | |||||
continue; | |||||
} | |||||
auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info); | |||||
if (label_set_task_info->stream_id() >= stream_list_.size()) { | |||||
GELOGE(PARAM_INVALID, "Invalid stream id."); | |||||
bool RuntimeModel::InitLabel(uint32_t batch_num) { | |||||
GELOGI("batch number:%u.", batch_num); | |||||
for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) { | |||||
rtLabel_t rt_lLabel = nullptr; | |||||
rtError_t rt_ret = rtLabelCreate(&rt_lLabel); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret); | |||||
return false; | return false; | ||||
} | } | ||||
rtLabel_t rt_label = nullptr; | |||||
rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret); | |||||
if (rt_lLabel == nullptr) { | |||||
GELOGE(RT_FAILED, "rtLabel is nullptr!"); | |||||
return false; | return false; | ||||
} | } | ||||
label_list_[label_set_task_info->label_id()] = rt_label; | |||||
} | |||||
label_list_.emplace_back(rt_lLabel); | |||||
} | |||||
return true; | return true; | ||||
} | } | ||||
@@ -174,7 +164,7 @@ bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) { | |||||
return false; | return false; | ||||
} | } | ||||
if (!InitLabel(davinci_model)) { | |||||
if (!InitLabel(davinci_model->GetBatchNum())) { | |||||
return false; | return false; | ||||
} | } | ||||
@@ -219,41 +209,20 @@ bool RuntimeModel::LoadTask() { | |||||
return false; | return false; | ||||
} | } | ||||
task_id_list_.push_back(task_id); | task_id_list_.push_back(task_id); | ||||
stream_id_list_.push_back(stream_id); | |||||
if (task->Args() != nullptr) { | |||||
std::shared_ptr<RuntimeInfo> runtime_tuple = nullptr; | |||||
GE_MAKE_SHARED(runtime_tuple = std::make_shared<RuntimeInfo>(task_id, stream_id, task->Args()), return false); | |||||
auto emplace_ret = runtime_info_map_.emplace(task->task_name(), runtime_tuple); | |||||
if (!emplace_ret.second) { | |||||
GELOGW("Task name exist:%s", task->task_name().c_str()); | |||||
} | |||||
} | |||||
} | } | ||||
if (task_list_.empty()) { | if (task_list_.empty()) { | ||||
GELOGE(FAILED, "Task list is empty"); | GELOGE(FAILED, "Task list is empty"); | ||||
return false; | return false; | ||||
} | } | ||||
GELOGI("Distribute task succ."); | |||||
GELOGI("LoadTask succ."); | |||||
return true; | |||||
} | |||||
bool RuntimeModel::LoadComplete() { | |||||
uint32_t task_id = 0; | |||||
uint32_t stream_id = 0; | |||||
auto rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rtModelGetTaskId failed, ret:0x%X", rt_ret); | |||||
return RT_FAILED; | |||||
} | |||||
task_id_list_.push_back(task_id); | |||||
stream_id_list_.push_back(stream_id); | |||||
rt_ret = rtModelLoadComplete(rt_model_handle_); | |||||
auto rt_ret = rtModelLoadComplete(rt_model_handle_); | |||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
GELOGE(RT_FAILED, "Call rt api rtModelLoadComplete failed, ret: 0x%X.", rt_ret); | GELOGE(RT_FAILED, "Call rt api rtModelLoadComplete failed, ret: 0x%X.", rt_ret); | ||||
return false; | return false; | ||||
} | } | ||||
GELOGI("LoadTask succ."); | |||||
return true; | return true; | ||||
} | } | ||||
@@ -283,16 +252,14 @@ bool RuntimeModel::Load(uint32_t device_id, uint64_t session_id, std::shared_ptr | |||||
} | } | ||||
GenerateTask(device_id, session_id, davinci_model); | GenerateTask(device_id, session_id, davinci_model); | ||||
return status; | |||||
} | |||||
bool RuntimeModel::DistributeTask() { | |||||
bool status = LoadTask(); | |||||
status = LoadTask(); | |||||
if (!status) { | if (!status) { | ||||
GELOGE(FAILED, "DistributeTask failed"); | GELOGE(FAILED, "DistributeTask failed"); | ||||
return false; | |||||
return status; | |||||
} | } | ||||
return true; | |||||
return status; | |||||
} | } | ||||
bool RuntimeModel::Run() { | bool RuntimeModel::Run() { | ||||
@@ -303,14 +270,10 @@ bool RuntimeModel::Run() { | |||||
return false; | return false; | ||||
} | } | ||||
GELOGI("Run rtModelExecute success, ret = 0x%X", ret); | |||||
GELOGI("Run rtModelExecute success"); | |||||
ret = rtStreamSynchronize(rt_model_stream_); | ret = rtStreamSynchronize(rt_model_stream_); | ||||
if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
if (ret == RT_ERROR_END_OF_SEQUENCE) { | |||||
GELOGI("Model stream RT_ERROR_END_OF_SEQUENCE signal received, ret = 0x%X", ret); | |||||
return true; | |||||
} | |||||
GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret); | GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret); | ||||
return false; | return false; | ||||
} | } | ||||
@@ -470,7 +433,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model | |||||
} | } | ||||
if (constant->output_tensors[0].size < constant->weight_data.size()) { | if (constant->output_tensors[0].size < constant->weight_data.size()) { | ||||
GELOGE(PARAM_INVALID, "Output size:%u less than weight data size:%zu", constant->output_tensors[0].size, | |||||
GELOGE(PARAM_INVALID, "Output size:%u is less than weight data size:%zu", constant->output_tensors[0].size, | |||||
constant->weight_data.size()); | constant->weight_data.size()); | ||||
return false; | return false; | ||||
} | } | ||||
@@ -485,8 +448,11 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model | |||||
/// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero | /// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero | ||||
/// and that of unknown shape is zero too. | /// and that of unknown shape is zero too. | ||||
/// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. | /// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. | ||||
int64_t elem_num = | |||||
(constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize(); | |||||
int64_t elem_num = constant->weight_tensors[0].GetShapeSize(); | |||||
if (elem_num == 0 && constant->weight_tensors[0].size == 0) { | |||||
elem_num = 1; | |||||
} | |||||
if (constant->weight_data.size() < sizeof(uint64_t)) { | if (constant->weight_data.size() < sizeof(uint64_t)) { | ||||
GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); | GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); | ||||
return false; | return false; | ||||
@@ -529,6 +495,5 @@ void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutp | |||||
const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; } | const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; } | ||||
const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; } | |||||
} // namespace model_runner | } // namespace model_runner | ||||
} // namespace ge | } // namespace ge |
@@ -27,7 +27,7 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace model_runner { | namespace model_runner { | ||||
using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>; | |||||
class Task; | class Task; | ||||
class RuntimeModel { | class RuntimeModel { | ||||
public: | public: | ||||
@@ -35,12 +35,7 @@ class RuntimeModel { | |||||
~RuntimeModel(); | ~RuntimeModel(); | ||||
bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model); | bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model); | ||||
bool DistributeTask(); | |||||
bool LoadComplete(); | |||||
const std::vector<uint32_t> &GetTaskIdList() const; | const std::vector<uint32_t> &GetTaskIdList() const; | ||||
const std::vector<uint32_t> &GetStreamIdList() const; | |||||
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; } | |||||
rtModel_t GetModelHandle() const { return rt_model_handle_; } | |||||
bool Run(); | bool Run(); | ||||
bool CopyInputData(const InputData &input_data); | bool CopyInputData(const InputData &input_data); | ||||
bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | ||||
@@ -53,7 +48,7 @@ class RuntimeModel { | |||||
bool LoadTask(); | bool LoadTask(); | ||||
bool InitStream(std::shared_ptr<DavinciModel> &davinci_model); | bool InitStream(std::shared_ptr<DavinciModel> &davinci_model); | ||||
bool InitEvent(uint32_t event_num); | bool InitEvent(uint32_t event_num); | ||||
bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model); | |||||
bool InitLabel(uint32_t batch_num); | |||||
bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model); | bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model); | ||||
bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model); | bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model); | ||||
bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model); | bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model); | ||||
@@ -82,8 +77,6 @@ class RuntimeModel { | |||||
std::vector<std::shared_ptr<OpInfo>> constant_info_list_{}; | std::vector<std::shared_ptr<OpInfo>> constant_info_list_{}; | ||||
std::vector<uint32_t> task_id_list_{}; | std::vector<uint32_t> task_id_list_{}; | ||||
std::vector<uint32_t> stream_id_list_{}; | |||||
std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_; | |||||
}; | }; | ||||
} // namespace model_runner | } // namespace model_runner | ||||
@@ -85,15 +85,11 @@ bool AicpuTask::Distribute() { | |||||
return false; | return false; | ||||
} | } | ||||
input_output_addr_ = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset); | |||||
auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; | |||||
GELOGI( | |||||
"Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s, dump_flag = %d.", | |||||
args_size, io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data(), dump_flag); | |||||
rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(task_info_->so_name().data()), | |||||
reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_, | |||||
args_size, nullptr, stream_, dump_flag); | |||||
GELOGI("Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s.", args_size, | |||||
io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data()); | |||||
rt_ret = rtCpuKernelLaunch(reinterpret_cast<const void *>(task_info_->so_name().data()), | |||||
reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_, args_size, | |||||
nullptr, stream_); | |||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | ||||
return false; | return false; | ||||
@@ -18,7 +18,6 @@ | |||||
#define GE_GE_RUNTIME_TASK_AICPU_TASK_H_ | #define GE_GE_RUNTIME_TASK_AICPU_TASK_H_ | ||||
#include <memory> | #include <memory> | ||||
#include <string> | |||||
#include "ge_runtime/task/task.h" | #include "ge_runtime/task/task.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -31,17 +30,12 @@ class AicpuTask : public TaskRepeater<AicpuTaskInfo> { | |||||
bool Distribute() override; | bool Distribute() override; | ||||
void *Args() override { return input_output_addr_; } | |||||
std::string task_name() const override { return task_info_->op_name(); } | |||||
private: | private: | ||||
static void ReleaseRtMem(void **ptr) noexcept; | static void ReleaseRtMem(void **ptr) noexcept; | ||||
std::shared_ptr<AicpuTaskInfo> task_info_; | std::shared_ptr<AicpuTaskInfo> task_info_; | ||||
void *stream_; | void *stream_; | ||||
void *args_; | void *args_; | ||||
void *input_output_addr_; | |||||
}; | }; | ||||
} // namespace model_runner | } // namespace model_runner | ||||
} // namespace ge | } // namespace ge | ||||
@@ -115,6 +115,7 @@ bool HcclTask::Distribute() { | |||||
rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM); | rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM); | ||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | ||||
(void)rtStreamDestroy(stream); | |||||
return false; | return false; | ||||
} | } | ||||
@@ -1,70 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_runtime/task/label_goto_task.h" | |||||
#include "ge_runtime/task/task_factory.h" | |||||
namespace ge { | |||||
namespace model_runner { | |||||
/// Resolves the stream and label handles this task will use.
///
/// Both handles are looked up in the model context by the ids carried in
/// @p task_info. If @p task_info is null, or either id is out of range, a
/// warning is logged and stream_/label_ are left null.
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info)
    : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      label_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }

  const auto streams = model_context.stream_list();
  const auto labels = model_context.label_list();
  const uint32_t stream_index = task_info->stream_id();
  const uint32_t label_index = task_info->label_id();
  GELOGI("Stream list size:%zu, stream id:%u.", streams.size(), stream_index);
  GELOGI("Label list size:%zu, label id:%u.", labels.size(), label_index);

  // Only adopt the handles when both ids index into their lists.
  const bool ids_in_range = (stream_index < streams.size()) && (label_index < labels.size());
  if (ids_in_range) {
    stream_ = streams[stream_index];
    label_ = labels[label_index];
    return;
  }
  GELOGW("Stream/Label id invalid.");
}
/// Nothing to release: the destructor body is empty.
LabelGotoTask::~LabelGotoTask() = default;
bool LabelGotoTask::Distribute() { | |||||
GELOGI("LabelGotoTask Distribute start."); | |||||
if (stream_ == nullptr) { | |||||
GELOGE(PARAM_INVALID, "stream is null!"); | |||||
return false; | |||||
} | |||||
if (label_ == nullptr) { | |||||
GELOGE(PARAM_INVALID, "label is null!"); | |||||
return false; | |||||
} | |||||
rtError_t rt_ret = rtLabelGotoEx(label_, stream_); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return false; | |||||
} | |||||
GELOGI("DistributeTask end."); | |||||
return true; | |||||
} | |||||
// Passes the LABEL_GOTO task type together with LabelGotoTask and LabelGotoTaskInfo to
// REGISTER_TASK. NOTE(review): the macro is defined outside this view (presumably in
// ge_runtime/task/task_factory.h) — confirm its exact registration semantics there.
REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); | |||||
} // namespace model_runner | |||||
} // namespace ge |
@@ -1,41 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||||
#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||||
#include <memory> | |||||
#include "ge_runtime/task/task.h" | |||||
namespace ge { | |||||
namespace model_runner { | |||||
/// Task built from a LabelGotoTaskInfo; holds one stream handle and one label
/// handle alongside the task info it was created from.
class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> {
 public:
  /// @param model_context  model context handed to the TaskRepeater base.
  /// @param task_info      task description; stored in task_info_.
  LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info);

  ~LabelGotoTask() override;

  /// Overrides the base Distribute(); returns a success flag.
  bool Distribute() override;

 private:
  std::shared_ptr<LabelGotoTaskInfo> task_info_;  // task description supplied at construction
  void *stream_;                                  // opaque stream handle
  void *label_;                                   // opaque label handle
};
} // namespace model_runner | |||||
} // namespace ge | |||||
#endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ |
@@ -1,70 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_runtime/task/label_set_task.h" | |||||
#include "ge_runtime/task/task_factory.h" | |||||
namespace ge { | |||||
namespace model_runner { | |||||
/// Resolves the stream and label handles this task will use.
///
/// Both handles are looked up in the model context by the ids carried in
/// @p task_info. If @p task_info is null, or either id is out of range, a
/// warning is logged and stream_/label_ are left null.
LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info)
    : TaskRepeater<LabelSetTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      label_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }

  const auto streams = model_context.stream_list();
  const auto labels = model_context.label_list();
  const uint32_t stream_index = task_info->stream_id();
  const uint32_t label_index = task_info->label_id();
  GELOGI("Stream list size:%zu, stream id:%u.", streams.size(), stream_index);
  GELOGI("Label list size:%zu, label id:%u.", labels.size(), label_index);

  // Only adopt the handles when both ids index into their lists.
  const bool ids_in_range = (stream_index < streams.size()) && (label_index < labels.size());
  if (ids_in_range) {
    stream_ = streams[stream_index];
    label_ = labels[label_index];
    return;
  }
  GELOGW("Stream/Label id invalid.");
}
/// Nothing to release: the destructor body is empty.
LabelSetTask::~LabelSetTask() = default;
bool LabelSetTask::Distribute() { | |||||
GELOGI("LabelSetTask Distribute start."); | |||||
if (stream_ == nullptr) { | |||||
GELOGE(PARAM_INVALID, "stream is null!"); | |||||
return false; | |||||
} | |||||
if (label_ == nullptr) { | |||||
GELOGE(PARAM_INVALID, "label is null!"); | |||||
return false; | |||||
} | |||||
rtError_t rt_ret = rtLabelSet(label_, stream_); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return false; | |||||
} | |||||
GELOGI("DistributeTask end."); | |||||
return true; | |||||
} | |||||
// Passes the LABEL_SET task type together with LabelSetTask and LabelSetTaskInfo to
// REGISTER_TASK. NOTE(review): the macro is defined outside this view (presumably in
// ge_runtime/task/task_factory.h) — confirm its exact registration semantics there.
REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo); | |||||
} // namespace model_runner | |||||
} // namespace ge |
@@ -1,41 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||||
#define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||||
#include <memory> | |||||
#include "ge_runtime/task/task.h" | |||||
namespace ge { | |||||
namespace model_runner { | |||||
/// Task built from a LabelSetTaskInfo; holds one stream handle and one label
/// handle alongside the task info it was created from.
class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> {
 public:
  /// @param model_context  model context handed to the TaskRepeater base.
  /// @param task_info      task description; stored in task_info_.
  LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info);

  ~LabelSetTask() override;

  /// Overrides the base Distribute(); returns a success flag.
  bool Distribute() override;

 private:
  std::shared_ptr<LabelSetTaskInfo> task_info_;  // task description supplied at construction
  void *stream_;                                 // opaque stream handle
  void *label_;                                  // opaque label handle
};
} // namespace model_runner | |||||
} // namespace ge | |||||
#endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ |
@@ -1,131 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_runtime/task/label_switch_task.h" | |||||
#include "ge_runtime/task/task_factory.h" | |||||
namespace ge { | |||||
namespace model_runner { | |||||
/// Captures the model's label list and resolves the stream for this task.
///
/// If @p task_info is null, or its stream id is out of range, a warning is
/// logged and stream_ is left null. The label list is copied from the model
/// context whenever @p task_info is non-null.
LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context,
                                 const std::shared_ptr<LabelSwitchTaskInfo> &task_info)
    : TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      all_label_resource_(),
      label_info_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }

  all_label_resource_ = model_context.label_list();

  const auto streams = model_context.stream_list();
  const uint32_t stream_index = task_info->stream_id();
  GELOGI("Stream list size:%zu, stream id:%u.", streams.size(), stream_index);

  if (stream_index < streams.size()) {
    stream_ = streams[stream_index];
    return;
  }
  GELOGW("Stream id invalid.");
}
/// Frees the buffer held in label_info_ (allocated with rtMalloc in
/// Distribute()), if any. A failed rtFree is only logged; the pointer is
/// cleared either way.
LabelSwitchTask::~LabelSwitchTask() {
  if (label_info_ != nullptr) {
    rtError_t rt_ret = rtFree(label_info_);
    if (rt_ret != RT_ERROR_NONE) {
      // The message previously named "fwkOpBuf"; the buffer freed here is label_info_.
      GELOGE(RT_FAILED, "rtFree label_info_ failed! ret: 0x%X.", rt_ret);
    }
    label_info_ = nullptr;
  }
}
bool LabelSwitchTask::Distribute() { | |||||
GELOGI("LabelSwitchTask Distribute start."); | |||||
if (!CheckParamValid()) { | |||||
return false; | |||||
} | |||||
const std::vector<uint32_t> &label_index_list = task_info_->label_list(); | |||||
std::vector<void *> label_list(task_info_->label_size(), nullptr); | |||||
for (size_t i = 0; i < task_info_->label_size(); ++i) { | |||||
uint32_t label_index = label_index_list[i]; | |||||
if (label_index >= all_label_resource_.size()) { | |||||
GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, | |||||
all_label_resource_.size()); | |||||
return false; | |||||
} | |||||
label_list[i] = all_label_resource_[label_index]; | |||||
GELOGI("Case %zu: label id %zu.", i, label_index); | |||||
} | |||||
uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); | |||||
rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return false; | |||||
} | |||||
rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return false; | |||||
} | |||||
rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return false; | |||||
} | |||||
GELOGI("DistributeTask end."); | |||||
return true; | |||||
} | |||||
bool LabelSwitchTask::CheckParamValid() { | |||||
if (stream_ == nullptr) { | |||||
GELOGE(PARAM_INVALID, "stream is null!"); | |||||
return false; | |||||
} | |||||
if (task_info_->label_list().empty()) { | |||||
GELOGE(PARAM_INVALID, "label_list is empty."); | |||||
return false; | |||||
} | |||||
if (task_info_->label_size() != task_info_->label_list().size()) { | |||||
GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(), | |||||
task_info_->label_size()); | |||||
return false; | |||||
} | |||||
if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) { | |||||
GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size()); | |||||
return false; | |||||
} | |||||
if (label_info_ != nullptr) { | |||||
GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
// Passes the LABEL_SWITCH task type together with LabelSwitchTask and LabelSwitchTaskInfo to
// REGISTER_TASK. NOTE(review): the macro is defined outside this view (presumably in
// ge_runtime/task/task_factory.h) — confirm its exact registration semantics there.
REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); | |||||
} // namespace model_runner | |||||
} // namespace ge |
@@ -1,44 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||||
#define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||||
#include <memory> | |||||
#include "ge_runtime/task/task.h" | |||||
namespace ge { | |||||
namespace model_runner { | |||||
/// Task built from a LabelSwitchTaskInfo; holds a stream handle, a list of
/// label handles, and a runtime-allocated label table buffer.
class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> {
 public:
  /// @param model_context  model context handed to the TaskRepeater base.
  /// @param task_info      task description; stored in task_info_.
  LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info);

  ~LabelSwitchTask() override;

  /// Overrides the base Distribute(); returns a success flag.
  bool Distribute() override;

 private:
  /// Parameter validation helper; returns true when the task state is usable.
  bool CheckParamValid();

  std::shared_ptr<LabelSwitchTaskInfo> task_info_;  // task description supplied at construction
  void *stream_;                                    // opaque stream handle
  std::vector<void *> all_label_resource_;          // label handles, indexed by label id
  void *label_info_;                                // label table buffer; null until allocated
};
} // namespace model_runner | |||||
} // namespace ge | |||||
#endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ |
@@ -51,7 +51,7 @@ bool StreamSwitchTask::Distribute() { | |||||
} | } | ||||
if (static_cast<uint64_t>(task_info_->true_stream_id()) >= stream_list_.size()) { | if (static_cast<uint64_t>(task_info_->true_stream_id()) >= stream_list_.size()) { | ||||
GELOGE(PARAM_INVALID, "true_stream_id %ld must less than stream_list_ size %zu!", task_info_->true_stream_id(), | |||||
GELOGE(PARAM_INVALID, "true_stream_id %ld must be less than stream_list_ size %zu!", task_info_->true_stream_id(), | |||||
stream_list_.size()); | stream_list_.size()); | ||||
return false; | return false; | ||||
} | } | ||||
@@ -18,9 +18,7 @@ | |||||
#define GE_GE_RUNTIME_TASK_TASK_H_ | #define GE_GE_RUNTIME_TASK_TASK_H_ | ||||
#include <memory> | #include <memory> | ||||
#include <utility> | |||||
#include <vector> | #include <vector> | ||||
#include <string> | |||||
#include "runtime/rt_model.h" | #include "runtime/rt_model.h" | ||||
#include "ge_runtime/model_context.h" | #include "ge_runtime/model_context.h" | ||||
#include "ge_runtime/task_info.h" | #include "ge_runtime/task_info.h" | ||||
@@ -34,10 +32,6 @@ class Task { | |||||
virtual ~Task() {} | virtual ~Task() {} | ||||
virtual bool Distribute() = 0; | virtual bool Distribute() = 0; | ||||
virtual void *Args() { return nullptr; } | |||||
virtual std::string task_name() const { return ""; } | |||||
}; | }; | ||||
template <class T> | template <class T> | ||||
@@ -95,14 +95,15 @@ bool TbeTask::Distribute() { | |||||
return false; | return false; | ||||
} | } | ||||
GELOGI("InitTbeTask end."); | |||||
GELOGI("DistributeTbeTask start."); | GELOGI("DistributeTbeTask start."); | ||||
auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; | |||||
rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag); | |||||
rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_); | |||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
GELOGE(RT_FAILED, "Call rt api rtKernelLaunch failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api rtKernelLaunch failed, ret: 0x%X", rt_ret); | ||||
return false; | return false; | ||||
} | } | ||||
GELOGI("[DataDump] task name:%s, dump_flag:%d", task_info_->op_name().c_str(), dump_flag); | |||||
GELOGI("DistributeTbeTask end."); | |||||
return true; | return true; | ||||
} | } | ||||
@@ -30,10 +30,6 @@ class TbeTask : public TaskRepeater<TbeTaskInfo> { | |||||
bool Distribute() override; | bool Distribute() override; | ||||
void *Args() override { return args_; } | |||||
std::string task_name() const override { return task_info_->op_name(); } | |||||
private: | private: | ||||
std::shared_ptr<TbeTaskInfo> task_info_; | std::shared_ptr<TbeTaskInfo> task_info_; | ||||
void *stream_; | void *stream_; | ||||