From ba2fcefa041e62220687ab4de70934808b816763 Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Wed, 24 Mar 2021 18:49:54 +0800 Subject: [PATCH] refactor label manager Signed-off-by: zhoufeng --- ge/ge_runtime/CMakeLists.txt | 1 + ge/ge_runtime/task/label_goto_task.cc | 74 +------------------- ge/ge_runtime/task/label_goto_task.h | 28 +------- ge/ge_runtime/task/label_manager.cc | 119 ++++++++++++++++++++++++++++++++ ge/ge_runtime/task/label_manager.h | 54 +++++++++++++++ ge/ge_runtime/task/label_switch_task.cc | 72 ++++--------------- ge/ge_runtime/task/label_switch_task.h | 7 +- 7 files changed, 195 insertions(+), 160 deletions(-) create mode 100644 ge/ge_runtime/task/label_manager.cc create mode 100644 ge/ge_runtime/task/label_manager.h diff --git a/ge/ge_runtime/CMakeLists.txt b/ge/ge_runtime/CMakeLists.txt index b00dd5b3..40113285 100644 --- a/ge/ge_runtime/CMakeLists.txt +++ b/ge/ge_runtime/CMakeLists.txt @@ -16,6 +16,7 @@ set(GE_SRC_LIST "task/label_goto_task.cc" "task/label_set_task.cc" "task/label_switch_task.cc" + "task/label_manager.cc" ) add_library(ge_runtime SHARED ${GE_SRC_LIST}) diff --git a/ge/ge_runtime/task/label_goto_task.cc b/ge/ge_runtime/task/label_goto_task.cc index b464ccab..c04bd5cf 100644 --- a/ge/ge_runtime/task/label_goto_task.cc +++ b/ge/ge_runtime/task/label_goto_task.cc @@ -19,14 +19,10 @@ namespace ge { namespace model_runner { -std::weak_ptr LabelGotoTask::LabelManager::instance_; -std::mutex LabelGotoTask::LabelManager::instance_mutex_; - LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr &task_info) : TaskRepeater(model_context, task_info), task_info_(task_info), stream_(nullptr), - label_(nullptr), index_value_(nullptr) { if (task_info_ == nullptr) { GELOGW("task_info_ is null!"); @@ -44,13 +40,12 @@ LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::share return; } stream_ = stream_list[stream_id]; - label_ = label_list[label_id_]; label_manager_ = LabelManager::GetInstance(); if (label_manager_ == nullptr) { GELOGW("Get label manager instance failed."); return; } - label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, label_id_, label_); + label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, {label_id_}, label_list); } LabelGotoTask::~LabelGotoTask() { @@ -69,10 +64,6 @@ bool LabelGotoTask::Distribute() { GELOGE(PARAM_INVALID, "stream is null!"); return false; } - if (label_ == nullptr) { - GELOGE(PARAM_INVALID, "label is null!"); - return false; - } if (label_info_ == nullptr) { GELOGE(PARAM_INVALID, "label info is null!"); @@ -105,69 +96,6 @@ bool LabelGotoTask::Distribute() { return true; } -LabelGotoTask::LabelGuard::~LabelGuard() { - void *label_info = GetLabelInfo(); - if (label_info != nullptr) { - rtError_t rt_ret = rtFree(label_info); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "rtFree label_info failed! ret: 0x%X.", rt_ret); - } - } -} - -std::shared_ptr LabelGotoTask::LabelManager::GetInstance() { - std::lock_guard lock(instance_mutex_); - auto instance = instance_.lock(); - if (instance != nullptr) { - return instance; - } - - instance = std::make_shared(); - instance_ = instance; - return instance; -} - -std::shared_ptr LabelGotoTask::LabelManager::GetLabelInfo(rtModel_t model, uint32_t label_id, - void *label) { - std::lock_guard lock(model_info_mapping_mutex_); - rtError_t rt_ret; - auto model_iter = model_info_mapping_.find(model); - if (model_iter == model_info_mapping_.end()) { - model_info_mapping_.emplace(model, std::map>()); - model_iter = model_info_mapping_.find(model); - } - - std::map> &label_map = model_iter->second; - auto label_iter = label_map.find(label_id); - if (label_iter != label_map.end()) { - auto label_guard = label_iter->second.lock(); - if (label_guard != nullptr) { - GELOGI("model %p find same label id.", model, label_id); - return label_guard; - } - } - - GELOGI("Alloc label id %u for model %p.", label_id, model); - void *label_info; - std::vector label_list = {label}; - uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); - rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return nullptr; - } - - rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info, label_info_size); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return nullptr; - } - - auto label_guard = std::make_shared(label_info); - label_map.emplace(label_id, label_guard); - return label_guard; -} REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); - } // namespace model_runner } // namespace ge diff --git a/ge/ge_runtime/task/label_goto_task.h b/ge/ge_runtime/task/label_goto_task.h index 9a69780d..e579c683 100644 --- a/ge/ge_runtime/task/label_goto_task.h +++ b/ge/ge_runtime/task/label_goto_task.h @@ -22,6 +22,7 @@ #include #include #include "ge_runtime/task/task.h" +#include "ge_runtime/task/label_manager.h" namespace ge { namespace model_runner { @@ -34,41 +35,14 @@ class LabelGotoTask : public TaskRepeater { bool Distribute() override; private: - class LabelGuard; - class LabelManager; - std::shared_ptr task_info_; void *stream_; - void *label_; std::shared_ptr label_info_; void *index_value_; uint32_t label_id_; rtModel_t rt_model_handle_; std::shared_ptr label_manager_; }; - -class LabelGotoTask::LabelGuard { - public: - explicit LabelGuard(void *label_info) : label_info_(reinterpret_cast(label_info)) {} - ~LabelGuard(); - void *GetLabelInfo() { return reinterpret_cast(label_info_); } - - private: - uintptr_t label_info_; -}; - -class LabelGotoTask::LabelManager { - public: - static std::shared_ptr GetInstance(); - std::shared_ptr GetLabelInfo(rtModel_t model, uint32_t label_id, void *label); - - private: - std::mutex model_info_mapping_mutex_; - std::map>> model_info_mapping_; - - static std::weak_ptr instance_; - static std::mutex instance_mutex_; -}; } // namespace model_runner } // namespace ge diff --git a/ge/ge_runtime/task/label_manager.cc b/ge/ge_runtime/task/label_manager.cc new file mode 100644 index 00000000..a2b0c3aa --- /dev/null +++ b/ge/ge_runtime/task/label_manager.cc @@ -0,0 +1,119 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ge_runtime/task/label_manager.h" +#include +#include +#include "runtime/mem.h" +#include "runtime/rt_model.h" +#include "common/ge_inner_error_codes.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +namespace model_runner { +std::weak_ptr LabelManager::instance_; +std::mutex LabelManager::instance_mutex_; + +template +static std::string GetVectorString(const std::vector &vec) { + std::string ret; + for (size_t i = 0; i < vec.size(); ++i) { + if (i != 0) { + ret.push_back(','); + } + ret += std::to_string(vec[i]); + } + return ret; +} + +LabelGuard::~LabelGuard() { + void *label_info = GetLabelInfo(); + if (label_info != nullptr) { + rtError_t rt_ret = rtFree(label_info); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtFree label_info failed! ret: 0x%X.", rt_ret); + } + } +} + +std::shared_ptr LabelManager::GetInstance() { + std::lock_guard lock(instance_mutex_); + auto instance = instance_.lock(); + if (instance != nullptr) { + return instance; + } + + instance = std::make_shared(); + instance_ = instance; + return instance; +} + +std::shared_ptr LabelManager::GetLabelInfo(rtModel_t model, const std::vector &label_ids, + const std::vector &all_label) { + std::lock_guard lock(model_info_mapping_mutex_); + rtError_t rt_ret; + auto model_iter = model_info_mapping_.find(model); + if (model_iter == model_info_mapping_.end()) { + model_info_mapping_.emplace(model, std::map>()); + model_iter = model_info_mapping_.find(model); + } + + std::string label_id_str = GetVectorString(label_ids); + auto &label_map = model_iter->second; + auto label_iter = label_map.find(label_id_str); + if (label_iter != label_map.end()) { + auto label_guard = label_iter->second.lock(); + if (label_guard != nullptr) { + GELOGI("model %p find same label id %s.", model, label_id_str.c_str()); + return label_guard; + } + } + + GELOGI("Alloc label id %s for model %p.", label_id_str.c_str(), model); + void *label_info; + std::vector label_list; + bool status = true; + std::transform(label_ids.begin(), label_ids.end(), std::back_inserter(label_list), + [&all_label, &status](uint32_t idx) -> void * { + if (idx >= all_label.size()) { + GELOGE(PARAM_INVALID, "Invalid label id %u, all label list size %zu.", idx, all_label.size()); + status = false; + return nullptr; + } + return all_label[idx]; + }); + if (!status) { + GELOGE(PARAM_INVALID, "Get label info failed."); + return nullptr; + } + uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); + rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return nullptr; + } + + rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info, label_info_size); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return nullptr; + } + + auto label_guard = std::make_shared(label_info); + label_map.emplace(label_id_str, label_guard); + return label_guard; +} +} // namespace model_runner +} // namespace ge diff --git a/ge/ge_runtime/task/label_manager.h b/ge/ge_runtime/task/label_manager.h new file mode 100644 index 00000000..f2c42c29 --- /dev/null +++ b/ge/ge_runtime/task/label_manager.h @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_ +#define GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_ + +#include +#include +#include +#include +#include + +namespace ge { +namespace model_runner { +class LabelGuard { + public: + explicit LabelGuard(void *label_info) : label_info_(reinterpret_cast(label_info)) {} + ~LabelGuard(); + void *GetLabelInfo() { return reinterpret_cast(label_info_); } + + private: + uintptr_t label_info_; +}; + +class LabelManager { + public: + static std::shared_ptr GetInstance(); + std::shared_ptr GetLabelInfo(rtModel_t model, const std::vector &label_ids, + const std::vector &all_label); + + private: + std::mutex model_info_mapping_mutex_; + std::map>> model_info_mapping_; + + static std::weak_ptr instance_; + static std::mutex instance_mutex_; +}; + + +} // namespace model_runner +} // namespace ge +#endif // GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_ \ No newline at end of file diff --git a/ge/ge_runtime/task/label_switch_task.cc b/ge/ge_runtime/task/label_switch_task.cc index 6e11a6ce..1f913d74 100644 --- a/ge/ge_runtime/task/label_switch_task.cc +++ b/ge/ge_runtime/task/label_switch_task.cc @@ -15,7 +15,6 @@ */ #include "ge_runtime/task/label_switch_task.h" -#include #include "ge_runtime/task/task_factory.h" namespace ge { @@ -25,14 +24,14 @@ LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, : TaskRepeater(model_context, task_info), task_info_(task_info), stream_(nullptr), - all_label_resource_(), label_info_(nullptr) { if (task_info_ == nullptr) { GELOGW("task_info_ is null!"); return; } - all_label_resource_ = model_context.label_list(); + rt_model_handle_ = model_context.rt_model_handle(); + auto all_label_resource = model_context.label_list(); auto stream_list = model_context.stream_list(); uint32_t stream_id = task_info->stream_id(); GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); @@ -41,31 +40,24 @@ LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, return; } stream_ = stream_list[stream_id]; - CopyLabelList(); -} - -LabelSwitchTask::~LabelSwitchTask() { - if (label_info_ != nullptr) { - rtError_t rt_ret = rtFree(label_info_); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); - } - label_info_ = nullptr; + label_manager_ = LabelManager::GetInstance(); + if (label_manager_ == nullptr) { + GELOGW("Get label manager instance failed."); + return; } + label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, task_info_->label_list(), all_label_resource); } +LabelSwitchTask::~LabelSwitchTask() {} + bool LabelSwitchTask::Distribute() { GELOGI("LabelSwitchTask Distribute start."); if (!CheckParamValid()) { return false; } - if (label_info_ == nullptr) { - GELOGE(PARAM_INVALID, "CopyLabelList failed, label info is null."); - return false; - } - - rtError_t rt_ret = rtLabelSwitchByIndex(task_info_->cond(), task_info_->label_size(), label_info_, stream_); + void *label_info = label_info_->GetLabelInfo(); + rtError_t rt_ret = rtLabelSwitchByIndex(task_info_->cond(), task_info_->label_size(), label_info, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return false; @@ -97,48 +89,14 @@ bool LabelSwitchTask::CheckParamValid() { return false; } - return true; -} - -void LabelSwitchTask::CopyLabelList() { - if (!CheckParamValid()) { - return; - } - - if (label_info_ != nullptr) { - GELOGE(PARAM_INVALID, "label_info_ has dirty data."); - return; - } - - const std::vector &label_index_list = task_info_->label_list(); - std::vector label_list(task_info_->label_size(), nullptr); - - for (size_t i = 0; i < task_info_->label_size(); ++i) { - uint32_t label_index = label_index_list[i]; - if (label_index >= all_label_resource_.size()) { - GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, - all_label_resource_.size()); - return; - } - label_list[i] = all_label_resource_[label_index]; - GELOGI("Case %zu: label id %zu.", i, label_index); - } - - uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); - rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return; + if (label_info_ == nullptr) { + GELOGE(PARAM_INVALID, "CopyLabelList failed, label info is null."); + return false; } - rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return; - } + return true; } REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); - } // namespace model_runner } // namespace ge diff --git a/ge/ge_runtime/task/label_switch_task.h b/ge/ge_runtime/task/label_switch_task.h index afd8e474..cfa6877c 100644 --- a/ge/ge_runtime/task/label_switch_task.h +++ b/ge/ge_runtime/task/label_switch_task.h @@ -19,6 +19,7 @@ #include #include "ge_runtime/task/task.h" +#include "ge_runtime/task/label_manager.h" namespace ge { namespace model_runner { @@ -32,12 +33,12 @@ class LabelSwitchTask : public TaskRepeater { private: bool CheckParamValid(); - void CopyLabelList(); std::shared_ptr task_info_; void *stream_; - std::vector all_label_resource_; - void *label_info_; + rtModel_t rt_model_handle_; + std::shared_ptr label_info_; + std::shared_ptr label_manager_; }; } // namespace model_runner } // namespace ge