@@ -16,6 +16,7 @@ set(GE_SRC_LIST | |||||
"task/label_goto_task.cc" | "task/label_goto_task.cc" | ||||
"task/label_set_task.cc" | "task/label_set_task.cc" | ||||
"task/label_switch_task.cc" | "task/label_switch_task.cc" | ||||
"task/label_manager.cc" | |||||
) | ) | ||||
add_library(ge_runtime SHARED ${GE_SRC_LIST}) | add_library(ge_runtime SHARED ${GE_SRC_LIST}) | ||||
@@ -16,99 +16,83 @@ | |||||
#include "ge_runtime/task/label_goto_task.h" | #include "ge_runtime/task/label_goto_task.h" | ||||
#include "ge_runtime/task/task_factory.h" | #include "ge_runtime/task/task_factory.h" | ||||
#include "framework/common/util.h" | |||||
namespace ge { | namespace ge { | ||||
namespace model_runner { | namespace model_runner { | ||||
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | ||||
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), task_info_(task_info) { | |||||
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), | |||||
task_info_(task_info), | |||||
stream_(nullptr), | |||||
index_value_(nullptr) { | |||||
if (task_info_ == nullptr) { | if (task_info_ == nullptr) { | ||||
GELOGW("task_info_ is null!"); | GELOGW("task_info_ is null!"); | ||||
return; | return; | ||||
} | } | ||||
auto stream_list = model_context.stream_list(); | auto stream_list = model_context.stream_list(); | ||||
auto label_list = model_context.label_list(); | auto label_list = model_context.label_list(); | ||||
rt_model_handle_ = model_context.rt_model_handle(); | |||||
uint32_t stream_id = task_info->stream_id(); | uint32_t stream_id = task_info->stream_id(); | ||||
uint32_t label_id = task_info->label_id(); | |||||
label_id_ = task_info->label_id(); | |||||
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | ||||
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||||
if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||||
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id_); | |||||
if (stream_id >= stream_list.size() || label_id_ >= label_list.size()) { | |||||
GELOGW("Stream/Label id invalid."); | GELOGW("Stream/Label id invalid."); | ||||
return; | return; | ||||
} | } | ||||
stream_ = stream_list[stream_id]; | stream_ = stream_list[stream_id]; | ||||
label_ = label_list[label_id]; | |||||
label_manager_ = LabelManager::GetInstance(); | |||||
if (label_manager_ == nullptr) { | |||||
GELOGW("Get label manager instance failed."); | |||||
return; | |||||
} | |||||
label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, {label_id_}, label_list); | |||||
} | } | ||||
LabelGotoTask::~LabelGotoTask() { | LabelGotoTask::~LabelGotoTask() { | ||||
GE_FREE_RT_LOG(label_info_); | |||||
GE_FREE_RT_LOG(index_value_); | |||||
if (index_value_ != nullptr) { | |||||
rtError_t rt_ret = rtFree(index_value_); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "rtFree index_value_ failed! ret: 0x%X.", rt_ret); | |||||
} | |||||
index_value_ = nullptr; | |||||
} | |||||
} | } | ||||
bool LabelGotoTask::Distribute() { | bool LabelGotoTask::Distribute() { | ||||
GELOGI("LabelGotoTask Distribute start."); | GELOGI("LabelGotoTask Distribute start."); | ||||
if (!CheckParamValid()) { | |||||
return false; | |||||
} | |||||
const std::vector<void *> label_list = { label_ }; | |||||
rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||||
return false; | |||||
} | |||||
uint64_t branch_index = 0; | |||||
rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||||
return false; | |||||
} | |||||
uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); | |||||
rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||||
return false; | |||||
} | |||||
rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||||
return false; | |||||
} | |||||
rt_ret = rtLabelSwitchByIndex(index_value_, label_list.size(), label_info_, stream_); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||||
return false; | |||||
} | |||||
GELOGI("DistributeTask end."); | |||||
return true; | |||||
} | |||||
bool LabelGotoTask::CheckParamValid() { | |||||
if (stream_ == nullptr) { | if (stream_ == nullptr) { | ||||
GELOGE(PARAM_INVALID, "stream is null!"); | GELOGE(PARAM_INVALID, "stream is null!"); | ||||
return false; | return false; | ||||
} | } | ||||
if (label_ == nullptr) { | |||||
GELOGE(PARAM_INVALID, "label is null!"); | |||||
if (label_info_ == nullptr) { | |||||
GELOGE(PARAM_INVALID, "label info is null!"); | |||||
return false; | return false; | ||||
} | } | ||||
if (label_info_ != nullptr) { | |||||
GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||||
return false; | |||||
if (index_value_ == nullptr) { | |||||
rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return false; | |||||
} | |||||
uint64_t index = 0; | |||||
rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &index, sizeof(index), RT_MEMCPY_HOST_TO_DEVICE); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return false; | |||||
} | |||||
} | } | ||||
if (index_value_ != nullptr) { | |||||
GELOGE(PARAM_INVALID, "index_value_ has dirty data."); | |||||
void *label_info = label_info_->GetLabelInfo(); | |||||
rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, 1, label_info, stream_); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return false; | return false; | ||||
} | } | ||||
GELOGI("DistributeTask end."); | |||||
return true; | return true; | ||||
} | } | ||||
@@ -18,7 +18,11 @@ | |||||
#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | #define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | ||||
#include <memory> | #include <memory> | ||||
#include <vector> | |||||
#include <map> | |||||
#include <mutex> | |||||
#include "ge_runtime/task/task.h" | #include "ge_runtime/task/task.h" | ||||
#include "ge_runtime/task/label_manager.h" | |||||
namespace ge { | namespace ge { | ||||
namespace model_runner { | namespace model_runner { | ||||
@@ -31,13 +35,13 @@ class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> { | |||||
bool Distribute() override; | bool Distribute() override; | ||||
private: | private: | ||||
bool CheckParamValid(); | |||||
std::shared_ptr<LabelGotoTaskInfo> task_info_; | std::shared_ptr<LabelGotoTaskInfo> task_info_; | ||||
void *stream_{nullptr}; | |||||
void *label_{nullptr}; | |||||
void *label_info_{nullptr}; | |||||
void *index_value_{nullptr}; | |||||
void *stream_; | |||||
std::shared_ptr<LabelGuard> label_info_; | |||||
void *index_value_; | |||||
uint32_t label_id_; | |||||
rtModel_t rt_model_handle_; | |||||
std::shared_ptr<LabelManager> label_manager_; | |||||
}; | }; | ||||
} // namespace model_runner | } // namespace model_runner | ||||
} // namespace ge | } // namespace ge | ||||
@@ -0,0 +1,119 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_runtime/task/label_manager.h" | |||||
#include <algorithm> | |||||
#include <string> | |||||
#include "runtime/mem.h" | |||||
#include "runtime/rt_model.h" | |||||
#include "common/ge_inner_error_codes.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
namespace ge { | |||||
namespace model_runner { | |||||
std::weak_ptr<LabelManager> LabelManager::instance_; | |||||
std::mutex LabelManager::instance_mutex_; | |||||
template <class T> | |||||
static std::string GetVectorString(const std::vector<T> &vec) { | |||||
std::string ret; | |||||
for (size_t i = 0; i < vec.size(); ++i) { | |||||
if (i != 0) { | |||||
ret.push_back(','); | |||||
} | |||||
ret += std::to_string(vec[i]); | |||||
} | |||||
return ret; | |||||
} | |||||
LabelGuard::~LabelGuard() { | |||||
void *label_info = GetLabelInfo(); | |||||
if (label_info != nullptr) { | |||||
rtError_t rt_ret = rtFree(label_info); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "rtFree label_info failed! ret: 0x%X.", rt_ret); | |||||
} | |||||
} | |||||
} | |||||
std::shared_ptr<LabelManager> LabelManager::GetInstance() { | |||||
std::lock_guard<std::mutex> lock(instance_mutex_); | |||||
auto instance = instance_.lock(); | |||||
if (instance != nullptr) { | |||||
return instance; | |||||
} | |||||
instance = std::make_shared<LabelManager>(); | |||||
instance_ = instance; | |||||
return instance; | |||||
} | |||||
std::shared_ptr<LabelGuard> LabelManager::GetLabelInfo(rtModel_t model, const std::vector<uint32_t> &label_ids, | |||||
const std::vector<void *> &all_label) { | |||||
std::lock_guard<std::mutex> lock(model_info_mapping_mutex_); | |||||
rtError_t rt_ret; | |||||
auto model_iter = model_info_mapping_.find(model); | |||||
if (model_iter == model_info_mapping_.end()) { | |||||
model_info_mapping_.emplace(model, std::map<std::string, std::weak_ptr<LabelGuard>>()); | |||||
model_iter = model_info_mapping_.find(model); | |||||
} | |||||
std::string label_id_str = GetVectorString(label_ids); | |||||
auto &label_map = model_iter->second; | |||||
auto label_iter = label_map.find(label_id_str); | |||||
if (label_iter != label_map.end()) { | |||||
auto label_guard = label_iter->second.lock(); | |||||
if (label_guard != nullptr) { | |||||
GELOGI("model %p find same label id %s.", model, label_id_str.c_str()); | |||||
return label_guard; | |||||
} | |||||
} | |||||
GELOGI("Alloc label id %s for model %p.", label_id_str.c_str(), model); | |||||
void *label_info; | |||||
std::vector<void *> label_list; | |||||
bool status = true; | |||||
std::transform(label_ids.begin(), label_ids.end(), std::back_inserter(label_list), | |||||
[&all_label, &status](uint32_t idx) -> void * { | |||||
if (idx >= all_label.size()) { | |||||
GELOGE(PARAM_INVALID, "Invalid label id %u, all label list size %zu.", idx, all_label.size()); | |||||
status = false; | |||||
return nullptr; | |||||
} | |||||
return all_label[idx]; | |||||
}); | |||||
if (!status) { | |||||
GELOGE(PARAM_INVALID, "Get label info failed."); | |||||
return nullptr; | |||||
} | |||||
uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); | |||||
rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return nullptr; | |||||
} | |||||
rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info, label_info_size); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return nullptr; | |||||
} | |||||
auto label_guard = std::make_shared<LabelGuard>(label_info); | |||||
label_map.emplace(label_id_str, label_guard); | |||||
return label_guard; | |||||
} | |||||
} // namespace model_runner | |||||
} // namespace ge |
@@ -0,0 +1,54 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_ | |||||
#define GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_ | |||||
#include <vector> | |||||
#include <memory> | |||||
#include <mutex> | |||||
#include <map> | |||||
#include <runtime/base.h> | |||||
namespace ge { | |||||
namespace model_runner { | |||||
class LabelGuard { | |||||
public: | |||||
explicit LabelGuard(void *label_info) : label_info_(reinterpret_cast<uintptr_t>(label_info)) {} | |||||
~LabelGuard(); | |||||
void *GetLabelInfo() { return reinterpret_cast<void *>(label_info_); } | |||||
private: | |||||
uintptr_t label_info_; | |||||
}; | |||||
class LabelManager { | |||||
public: | |||||
static std::shared_ptr<LabelManager> GetInstance(); | |||||
std::shared_ptr<LabelGuard> GetLabelInfo(rtModel_t model, const std::vector<uint32_t> &label_ids, | |||||
const std::vector<void *> &all_label); | |||||
private: | |||||
std::mutex model_info_mapping_mutex_; | |||||
std::map<rtModel_t, std::map<std::string, std::weak_ptr<LabelGuard>>> model_info_mapping_; | |||||
static std::weak_ptr<LabelManager> instance_; | |||||
static std::mutex instance_mutex_; | |||||
}; | |||||
} // namespace model_runner | |||||
} // namespace ge | |||||
#endif // GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_ |
@@ -24,14 +24,14 @@ LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||||
: TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info), | : TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info), | ||||
task_info_(task_info), | task_info_(task_info), | ||||
stream_(nullptr), | stream_(nullptr), | ||||
all_label_resource_(), | |||||
label_info_(nullptr) { | label_info_(nullptr) { | ||||
if (task_info_ == nullptr) { | if (task_info_ == nullptr) { | ||||
GELOGW("task_info_ is null!"); | GELOGW("task_info_ is null!"); | ||||
return; | return; | ||||
} | } | ||||
all_label_resource_ = model_context.label_list(); | |||||
rt_model_handle_ = model_context.rt_model_handle(); | |||||
auto all_label_resource = model_context.label_list(); | |||||
auto stream_list = model_context.stream_list(); | auto stream_list = model_context.stream_list(); | ||||
uint32_t stream_id = task_info->stream_id(); | uint32_t stream_id = task_info->stream_id(); | ||||
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | ||||
@@ -40,52 +40,24 @@ LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||||
return; | return; | ||||
} | } | ||||
stream_ = stream_list[stream_id]; | stream_ = stream_list[stream_id]; | ||||
} | |||||
LabelSwitchTask::~LabelSwitchTask() { | |||||
if (label_info_ != nullptr) { | |||||
rtError_t rt_ret = rtFree(label_info_); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); | |||||
} | |||||
label_info_ = nullptr; | |||||
label_manager_ = LabelManager::GetInstance(); | |||||
if (label_manager_ == nullptr) { | |||||
GELOGW("Get label manager instance failed."); | |||||
return; | |||||
} | } | ||||
label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, task_info_->label_list(), all_label_resource); | |||||
} | } | ||||
LabelSwitchTask::~LabelSwitchTask() {} | |||||
bool LabelSwitchTask::Distribute() { | bool LabelSwitchTask::Distribute() { | ||||
GELOGI("LabelSwitchTask Distribute start."); | GELOGI("LabelSwitchTask Distribute start."); | ||||
if (!CheckParamValid()) { | if (!CheckParamValid()) { | ||||
return false; | return false; | ||||
} | } | ||||
const std::vector<uint32_t> &label_index_list = task_info_->label_list(); | |||||
std::vector<void *> label_list(task_info_->label_size(), nullptr); | |||||
for (size_t i = 0; i < task_info_->label_size(); ++i) { | |||||
uint32_t label_index = label_index_list[i]; | |||||
if (label_index >= all_label_resource_.size()) { | |||||
GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, | |||||
all_label_resource_.size()); | |||||
return false; | |||||
} | |||||
label_list[i] = all_label_resource_[label_index]; | |||||
GELOGI("Case %zu: label id %zu.", i, label_index); | |||||
} | |||||
uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); | |||||
rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return false; | |||||
} | |||||
rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return false; | |||||
} | |||||
rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); | |||||
void *label_info = label_info_->GetLabelInfo(); | |||||
rtError_t rt_ret = rtLabelSwitchByIndex(task_info_->cond(), task_info_->label_size(), label_info, stream_); | |||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | ||||
return false; | return false; | ||||
@@ -117,8 +89,8 @@ bool LabelSwitchTask::CheckParamValid() { | |||||
return false; | return false; | ||||
} | } | ||||
if (label_info_ != nullptr) { | |||||
GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||||
if (label_info_ == nullptr) { | |||||
GELOGE(PARAM_INVALID, "CopyLabelList failed, label info is null."); | |||||
return false; | return false; | ||||
} | } | ||||
@@ -126,6 +98,5 @@ bool LabelSwitchTask::CheckParamValid() { | |||||
} | } | ||||
REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); | REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); | ||||
} // namespace model_runner | } // namespace model_runner | ||||
} // namespace ge | } // namespace ge |
@@ -19,6 +19,7 @@ | |||||
#include <memory> | #include <memory> | ||||
#include "ge_runtime/task/task.h" | #include "ge_runtime/task/task.h" | ||||
#include "ge_runtime/task/label_manager.h" | |||||
namespace ge { | namespace ge { | ||||
namespace model_runner { | namespace model_runner { | ||||
@@ -35,8 +36,9 @@ class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> { | |||||
std::shared_ptr<LabelSwitchTaskInfo> task_info_; | std::shared_ptr<LabelSwitchTaskInfo> task_info_; | ||||
void *stream_; | void *stream_; | ||||
std::vector<void *> all_label_resource_; | |||||
void *label_info_; | |||||
rtModel_t rt_model_handle_; | |||||
std::shared_ptr<LabelGuard> label_info_; | |||||
std::shared_ptr<LabelManager> label_manager_; | |||||
}; | }; | ||||
} // namespace model_runner | } // namespace model_runner | ||||
} // namespace ge | } // namespace ge | ||||