From 179e10f36b34c74648194c8dceb0ddfb82616633 Mon Sep 17 00:00:00 2001 From: zhupuxu Date: Thu, 18 Mar 2021 14:12:26 +0800 Subject: [PATCH] label switch Signed-off-by: zhupuxu --- ge/ge_runtime/task/label_goto_task.cc | 148 ++++++++++++++++++++++---------- ge/ge_runtime/task/label_goto_task.h | 40 +++++++-- ge/ge_runtime/task/label_switch_task.cc | 67 +++++++++------ ge/ge_runtime/task/label_switch_task.h | 1 + 4 files changed, 178 insertions(+), 78 deletions(-) diff --git a/ge/ge_runtime/task/label_goto_task.cc b/ge/ge_runtime/task/label_goto_task.cc index ad93a98f..b464ccab 100644 --- a/ge/ge_runtime/task/label_goto_task.cc +++ b/ge/ge_runtime/task/label_goto_task.cc @@ -16,71 +16,88 @@ #include "ge_runtime/task/label_goto_task.h" #include "ge_runtime/task/task_factory.h" -#include "framework/common/util.h" namespace ge { namespace model_runner { +std::weak_ptr LabelGotoTask::LabelManager::instance_; +std::mutex LabelGotoTask::LabelManager::instance_mutex_; + LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr &task_info) - : TaskRepeater(model_context, task_info), task_info_(task_info) { + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + label_(nullptr), + index_value_(nullptr) { if (task_info_ == nullptr) { GELOGW("task_info_ is null!"); return; } auto stream_list = model_context.stream_list(); auto label_list = model_context.label_list(); + rt_model_handle_ = model_context.rt_model_handle(); uint32_t stream_id = task_info->stream_id(); - uint32_t label_id = task_info->label_id(); + label_id_ = task_info->label_id(); GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); - GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); - if (stream_id >= stream_list.size() || label_id >= label_list.size()) { + GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id_); + if (stream_id >= stream_list.size() || label_id_ >= label_list.size()) { GELOGW("Stream/Label id invalid."); return; } stream_ = stream_list[stream_id]; - label_ = label_list[label_id]; + label_ = label_list[label_id_]; + label_manager_ = LabelManager::GetInstance(); + if (label_manager_ == nullptr) { + GELOGW("Get label manager instance failed."); + return; + } + label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, label_id_, label_); } LabelGotoTask::~LabelGotoTask() { - GE_FREE_RT_LOG(label_info_); - GE_FREE_RT_LOG(index_value_); + if (index_value_ != nullptr) { + rtError_t rt_ret = rtFree(index_value_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtFree index_value_ failed! ret: 0x%X.", rt_ret); + } + index_value_ = nullptr; + } } bool LabelGotoTask::Distribute() { GELOGI("LabelGotoTask Distribute start."); - if (!CheckParamValid()) { - return false; - } - - const std::vector label_list = { label_ }; - rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); + if (stream_ == nullptr) { + GELOGE(PARAM_INVALID, "stream is null!"); return false; } - - uint64_t branch_index = 0; - rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); + if (label_ == nullptr) { + GELOGE(PARAM_INVALID, "label is null!"); return false; } - uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); - rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); + if (label_info_ == nullptr) { + GELOGE(PARAM_INVALID, "label info is null!"); return false; } - rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); - return false; + if (index_value_ == nullptr) { + rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + uint64_t index = 0; + rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &index, sizeof(index), RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } } - rt_ret = rtLabelSwitchByIndex(index_value_, label_list.size(), label_info_, stream_); + void *label_info = label_info_->GetLabelInfo(); + rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, 1, label_info, stream_); if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return false; } @@ -88,30 +105,69 @@ bool LabelGotoTask::Distribute() { return true; } -bool LabelGotoTask::CheckParamValid() { - if (stream_ == nullptr) { - GELOGE(PARAM_INVALID, "stream is null!"); - return false; +LabelGotoTask::LabelGuard::~LabelGuard() { + void *label_info = GetLabelInfo(); + if (label_info != nullptr) { + rtError_t rt_ret = rtFree(label_info); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtFree label_info failed! ret: 0x%X.", rt_ret); + } } +} - if (label_ == nullptr) { - GELOGE(PARAM_INVALID, "label is null!"); - return false; +std::shared_ptr LabelGotoTask::LabelManager::GetInstance() { + std::lock_guard lock(instance_mutex_); + auto instance = instance_.lock(); + if (instance != nullptr) { + return instance; } - if (label_info_ != nullptr) { - GELOGE(PARAM_INVALID, "label_info_ has dirty data."); - return false; + instance = std::make_shared(); + instance_ = instance; + return instance; +} + +std::shared_ptr LabelGotoTask::LabelManager::GetLabelInfo(rtModel_t model, uint32_t label_id, + void *label) { + std::lock_guard lock(model_info_mapping_mutex_); + rtError_t rt_ret; + auto model_iter = model_info_mapping_.find(model); + if (model_iter == model_info_mapping_.end()) { + model_info_mapping_.emplace(model, std::map>()); + model_iter = model_info_mapping_.find(model); } - if (index_value_ != nullptr) { - GELOGE(PARAM_INVALID, "index_value_ has dirty data."); - return false; + std::map> &label_map = model_iter->second; + auto label_iter = label_map.find(label_id); + if (label_iter != label_map.end()) { + auto label_guard = label_iter->second.lock(); + if (label_guard != nullptr) { + GELOGI("model %p find same label id.", model, label_id); + return label_guard; + } } - return true; -} + GELOGI("Alloc label id %u for model %p.", label_id, model); + void *label_info; + std::vector label_list = {label}; + uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); + rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return nullptr; + } + rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info, label_info_size); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return nullptr; + } + + auto label_guard = std::make_shared(label_info); + label_map.emplace(label_id, label_guard); + return label_guard; +} REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); + } // namespace model_runner } // namespace ge diff --git a/ge/ge_runtime/task/label_goto_task.h b/ge/ge_runtime/task/label_goto_task.h index addbb700..9a69780d 100644 --- a/ge/ge_runtime/task/label_goto_task.h +++ b/ge/ge_runtime/task/label_goto_task.h @@ -18,6 +18,9 @@ #define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ #include +#include +#include +#include #include "ge_runtime/task/task.h" namespace ge { @@ -31,13 +34,40 @@ class LabelGotoTask : public TaskRepeater { bool Distribute() override; private: - bool CheckParamValid(); + class LabelGuard; + class LabelManager; std::shared_ptr task_info_; - void *stream_{nullptr}; - void *label_{nullptr}; - void *label_info_{nullptr}; - void *index_value_{nullptr}; + void *stream_; + void *label_; + std::shared_ptr label_info_; + void *index_value_; + uint32_t label_id_; + rtModel_t rt_model_handle_; + std::shared_ptr label_manager_; +}; + +class LabelGotoTask::LabelGuard { + public: + explicit LabelGuard(void *label_info) : label_info_(reinterpret_cast(label_info)) {} + ~LabelGuard(); + void *GetLabelInfo() { return reinterpret_cast(label_info_); } + + private: + uintptr_t label_info_; +}; + +class LabelGotoTask::LabelManager { + public: + static std::shared_ptr GetInstance(); + std::shared_ptr GetLabelInfo(rtModel_t model, uint32_t label_id, void *label); + + private: + std::mutex model_info_mapping_mutex_; + std::map>> model_info_mapping_; + + static std::weak_ptr instance_; + static std::mutex instance_mutex_; }; } // namespace model_runner } // namespace ge diff --git a/ge/ge_runtime/task/label_switch_task.cc b/ge/ge_runtime/task/label_switch_task.cc index a3c2d41a..6e11a6ce 100644 --- a/ge/ge_runtime/task/label_switch_task.cc +++ b/ge/ge_runtime/task/label_switch_task.cc @@ -15,6 +15,7 @@ */ #include "ge_runtime/task/label_switch_task.h" +#include #include "ge_runtime/task/task_factory.h" namespace ge { @@ -40,6 +41,7 @@ LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, return; } stream_ = stream_list[stream_id]; + CopyLabelList(); } LabelSwitchTask::~LabelSwitchTask() { @@ -58,34 +60,12 @@ bool LabelSwitchTask::Distribute() { return false; } - const std::vector &label_index_list = task_info_->label_list(); - std::vector label_list(task_info_->label_size(), nullptr); - - for (size_t i = 0; i < task_info_->label_size(); ++i) { - uint32_t label_index = label_index_list[i]; - if (label_index >= all_label_resource_.size()) { - GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, - all_label_resource_.size()); - return false; - } - label_list[i] = all_label_resource_[label_index]; - GELOGI("Case %zu: label id %zu.", i, label_index); - } - - uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); - rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return false; - } - - rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + if (label_info_ == nullptr) { + GELOGE(PARAM_INVALID, "CopyLabelList failed, label info is null."); return false; } - rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); + rtError_t rt_ret = rtLabelSwitchByIndex(task_info_->cond(), task_info_->label_size(), label_info_, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return false; @@ -117,12 +97,45 @@ bool LabelSwitchTask::CheckParamValid() { return false; } + return true; +} + +void LabelSwitchTask::CopyLabelList() { + if (!CheckParamValid()) { + return; + } + if (label_info_ != nullptr) { GELOGE(PARAM_INVALID, "label_info_ has dirty data."); - return false; + return; } - return true; + const std::vector &label_index_list = task_info_->label_list(); + std::vector label_list(task_info_->label_size(), nullptr); + + for (size_t i = 0; i < task_info_->label_size(); ++i) { + uint32_t label_index = label_index_list[i]; + if (label_index >= all_label_resource_.size()) { + GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, + all_label_resource_.size()); + return; + } + label_list[i] = all_label_resource_[label_index]; + GELOGI("Case %zu: label id %zu.", i, label_index); + } + + uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); + rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return; + } + + rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return; + } } REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); diff --git a/ge/ge_runtime/task/label_switch_task.h b/ge/ge_runtime/task/label_switch_task.h index 463faa31..afd8e474 100644 --- a/ge/ge_runtime/task/label_switch_task.h +++ b/ge/ge_runtime/task/label_switch_task.h @@ -32,6 +32,7 @@ class LabelSwitchTask : public TaskRepeater { private: bool CheckParamValid(); + void CopyLabelList(); std::shared_ptr task_info_; void *stream_;