@@ -38,6 +38,8 @@ class ModelRunner { | |||||
const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const; | const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const; | ||||
const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const; | |||||
bool UnloadModel(uint32_t model_id); | bool UnloadModel(uint32_t model_id); | ||||
bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); | bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); | ||||
@@ -60,6 +60,17 @@ const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const | |||||
return model_iter->second->GetTaskIdList(); | return model_iter->second->GetTaskIdList(); | ||||
} | } | ||||
const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const { | |||||
auto model_iter = runtime_models_.find(model_id); | |||||
if (model_iter == runtime_models_.end()) { | |||||
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||||
static const std::vector<uint32_t> empty_ret; | |||||
return empty_ret; | |||||
} | |||||
return model_iter->second->GetStreamIdList(); | |||||
} | |||||
bool ModelRunner::UnloadModel(uint32_t model_id) { | bool ModelRunner::UnloadModel(uint32_t model_id) { | ||||
auto iter = runtime_models_.find(model_id); | auto iter = runtime_models_.find(model_id); | ||||
if (iter != runtime_models_.end()) { | if (iter != runtime_models_.end()) { | ||||
@@ -207,6 +207,7 @@ bool RuntimeModel::LoadTask() { | |||||
return false; | return false; | ||||
} | } | ||||
task_id_list_.push_back(task_id); | task_id_list_.push_back(task_id); | ||||
stream_id_list_.push_back(stream_id); | |||||
} | } | ||||
GELOGI("Distribute task succ."); | GELOGI("Distribute task succ."); | ||||
@@ -486,5 +487,6 @@ void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutp | |||||
const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; } | const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; } | ||||
const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; } | |||||
} // namespace model_runner | } // namespace model_runner | ||||
} // namespace ge | } // namespace ge |
@@ -36,6 +36,7 @@ class RuntimeModel { | |||||
bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model); | bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model); | ||||
const std::vector<uint32_t> &GetTaskIdList() const; | const std::vector<uint32_t> &GetTaskIdList() const; | ||||
const std::vector<uint32_t> &GetStreamIdList() const; | |||||
bool Run(); | bool Run(); | ||||
bool CopyInputData(const InputData &input_data); | bool CopyInputData(const InputData &input_data); | ||||
bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | ||||
@@ -77,6 +78,7 @@ class RuntimeModel { | |||||
std::vector<std::shared_ptr<OpInfo>> constant_info_list_{}; | std::vector<std::shared_ptr<OpInfo>> constant_info_list_{}; | ||||
std::vector<uint32_t> task_id_list_{}; | std::vector<uint32_t> task_id_list_{}; | ||||
std::vector<uint32_t> stream_id_list_{}; | |||||
}; | }; | ||||
} // namespace model_runner | } // namespace model_runner | ||||