Browse Source

Add GetStreamIdList in ge_runtime

tags/v0.6.0-beta
caifubi 5 years ago
parent
commit
9ad993ae2c
4 changed files with 17 additions and 0 deletions
  1. +2
    -0
      inc/framework/ge_runtime/model_runner.h
  2. +11
    -0
      src/ge/ge_runtime/model_runner.cc
  3. +2
    -0
      src/ge/ge_runtime/runtime_model.cc
  4. +2
    -0
      src/ge/ge_runtime/runtime_model.h

+ 2
- 0
inc/framework/ge_runtime/model_runner.h View File

@@ -38,6 +38,8 @@ class ModelRunner {


const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const; const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;


const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const;

bool UnloadModel(uint32_t model_id); bool UnloadModel(uint32_t model_id);


bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data);


+ 11
- 0
src/ge/ge_runtime/model_runner.cc View File

@@ -60,6 +60,17 @@ const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const
return model_iter->second->GetTaskIdList(); return model_iter->second->GetTaskIdList();
} }


const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id);
static const std::vector<uint32_t> empty_ret;
return empty_ret;
}

return model_iter->second->GetStreamIdList();
}

bool ModelRunner::UnloadModel(uint32_t model_id) { bool ModelRunner::UnloadModel(uint32_t model_id) {
auto iter = runtime_models_.find(model_id); auto iter = runtime_models_.find(model_id);
if (iter != runtime_models_.end()) { if (iter != runtime_models_.end()) {


+ 2
- 0
src/ge/ge_runtime/runtime_model.cc View File

@@ -207,6 +207,7 @@ bool RuntimeModel::LoadTask() {
return false; return false;
} }
task_id_list_.push_back(task_id); task_id_list_.push_back(task_id);
stream_id_list_.push_back(stream_id);
} }
GELOGI("Distribute task succ."); GELOGI("Distribute task succ.");


@@ -486,5 +487,6 @@ void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutp


const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; } const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; }


const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; }
} // namespace model_runner } // namespace model_runner
} // namespace ge } // namespace ge

+ 2
- 0
src/ge/ge_runtime/runtime_model.h View File

@@ -36,6 +36,7 @@ class RuntimeModel {


bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model); bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model);
const std::vector<uint32_t> &GetTaskIdList() const; const std::vector<uint32_t> &GetTaskIdList() const;
const std::vector<uint32_t> &GetStreamIdList() const;
bool Run(); bool Run();
bool CopyInputData(const InputData &input_data); bool CopyInputData(const InputData &input_data);
bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc,
@@ -77,6 +78,7 @@ class RuntimeModel {
std::vector<std::shared_ptr<OpInfo>> constant_info_list_{}; std::vector<std::shared_ptr<OpInfo>> constant_info_list_{};


std::vector<uint32_t> task_id_list_{}; std::vector<uint32_t> task_id_list_{};
std::vector<uint32_t> stream_id_list_{};
}; };


} // namespace model_runner } // namespace model_runner


Loading…
Cancel
Save