@@ -177,37 +177,44 @@ class AicpuTaskInfo : public TaskInfo { | |||||
std::vector<void *> output_data_addrs_; | std::vector<void *> output_data_addrs_; | ||||
}; | }; | ||||
class LabelTaskInfo : public TaskInfo { | |||||
class LabelSetTaskInfo : public TaskInfo { | |||||
public: | public: | ||||
LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id) | |||||
: TaskInfo(stream_id, TaskInfoType::LABEL_SET), label_id_(label_id) {} | |||||
~LabelSetTaskInfo() override {} | |||||
uint32_t label_id() const { return label_id_; } | uint32_t label_id() const { return label_id_; } | ||||
protected: | |||||
LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id) | |||||
: TaskInfo(stream_id, type), label_id_(label_id) {} | |||||
virtual ~LabelTaskInfo() override {} | |||||
private: | |||||
uint32_t label_id_; | uint32_t label_id_; | ||||
}; | }; | ||||
class LabelSetTaskInfo : public LabelTaskInfo { | |||||
class LabelGotoTaskInfo : public TaskInfo { | |||||
public: | public: | ||||
LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id) | |||||
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {} | |||||
~LabelSetTaskInfo() override {} | |||||
LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id) | |||||
: TaskInfo(stream_id, TaskInfoType::LABEL_GOTO), label_id_(label_id) {} | |||||
~LabelGotoTaskInfo() override {} | |||||
uint32_t label_id() const { return label_id_; } | |||||
private: | |||||
uint32_t label_id_; | |||||
}; | }; | ||||
class LabelSwitchTaskInfo : public LabelTaskInfo { | |||||
class LabelSwitchTaskInfo : public TaskInfo { | |||||
public: | public: | ||||
LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id) | |||||
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {} | |||||
LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_size, const std::vector<uint32_t> &label_list, void *cond) | |||||
: TaskInfo(stream_id, TaskInfoType::LABEL_SWITCH), | |||||
label_size_(label_size), | |||||
label_list_(label_list), | |||||
cond_(cond) {} | |||||
~LabelSwitchTaskInfo() override {} | ~LabelSwitchTaskInfo() override {} | ||||
}; | |||||
uint32_t label_size() { return label_size_; }; | |||||
const std::vector<uint32_t> &label_list() { return label_list_; }; | |||||
void *cond() { return cond_; }; | |||||
class LabelGotoTaskInfo : public LabelTaskInfo { | |||||
public: | |||||
LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id) | |||||
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {} | |||||
~LabelGotoTaskInfo() override {} | |||||
private: | |||||
uint32_t label_size_; | |||||
std::vector<uint32_t> label_list_; | |||||
void *cond_; | |||||
}; | }; | ||||
class EventTaskInfo : public TaskInfo { | class EventTaskInfo : public TaskInfo { | ||||
@@ -116,23 +116,34 @@ bool RuntimeModel::InitEvent(uint32_t event_num) { | |||||
return true; | return true; | ||||
} | } | ||||
bool RuntimeModel::InitLabel(uint32_t batch_num) { | |||||
GELOGI("batch number:%u.", batch_num); | |||||
for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) { | |||||
rtLabel_t rt_lLabel = nullptr; | |||||
rtError_t rt_ret = rtLabelCreate(&rt_lLabel); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret); | |||||
return false; | |||||
bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) { | |||||
GELOGI("batch number:%u.", davinci_model->GetBatchNum()); | |||||
label_list_.resize(davinci_model->GetBatchNum()); | |||||
for (auto &task_info : davinci_model->GetTaskInfoList()) { | |||||
if (task_info == nullptr) { | |||||
GELOGE(PARAM_INVALID, "task_info is null."); | |||||
continue; | |||||
} | |||||
if (task_info->type() != TaskInfoType::LABEL_SET) { | |||||
continue; | |||||
} | } | ||||
auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info); | |||||
if (rt_lLabel == nullptr) { | |||||
GELOGE(RT_FAILED, "rtLabel is nullptr!"); | |||||
if (label_set_task_info->stream_id() >= stream_list_.size()) { | |||||
GELOGE(PARAM_INVALID, "Invalid stream id."); | |||||
return false; | return false; | ||||
} | } | ||||
label_list_.emplace_back(rt_lLabel); | |||||
rtLabel_t rt_label = nullptr; | |||||
rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret); | |||||
return false; | |||||
} | |||||
label_list_[label_set_task_info->label_id()] = rt_label; | |||||
} | } | ||||
return true; | return true; | ||||
} | } | ||||
@@ -164,7 +175,7 @@ bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) { | |||||
return false; | return false; | ||||
} | } | ||||
if (!InitLabel(davinci_model->GetBatchNum())) { | |||||
if (!InitLabel(davinci_model)) { | |||||
return false; | return false; | ||||
} | } | ||||
@@ -48,7 +48,7 @@ class RuntimeModel { | |||||
bool LoadTask(); | bool LoadTask(); | ||||
bool InitStream(std::shared_ptr<DavinciModel> &davinci_model); | bool InitStream(std::shared_ptr<DavinciModel> &davinci_model); | ||||
bool InitEvent(uint32_t event_num); | bool InitEvent(uint32_t event_num); | ||||
bool InitLabel(uint32_t batch_num); | |||||
bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model); | |||||
bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model); | bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model); | ||||
bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model); | bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model); | ||||
bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model); | bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model); | ||||
@@ -0,0 +1,70 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_runtime/task/label_goto_task.h" | |||||
#include "ge_runtime/task/task_factory.h" | |||||
namespace ge { | |||||
namespace model_runner { | |||||
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | |||||
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), | |||||
task_info_(task_info), | |||||
stream_(nullptr), | |||||
label_(nullptr) { | |||||
if (task_info_ == nullptr) { | |||||
GELOGW("task_info_ is null!"); | |||||
return; | |||||
} | |||||
auto stream_list = model_context.stream_list(); | |||||
auto label_list = model_context.label_list(); | |||||
uint32_t stream_id = task_info->stream_id(); | |||||
uint32_t label_id = task_info->label_id(); | |||||
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||||
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||||
if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||||
GELOGW("Stream/Label id invalid."); | |||||
return; | |||||
} | |||||
stream_ = stream_list[stream_id]; | |||||
label_ = label_list[label_id]; | |||||
} | |||||
LabelGotoTask::~LabelGotoTask() {} | |||||
bool LabelGotoTask::Distribute() { | |||||
GELOGI("LabelGotoTask Distribute start."); | |||||
if (stream_ == nullptr) { | |||||
GELOGE(PARAM_INVALID, "stream is null!"); | |||||
return false; | |||||
} | |||||
if (label_ == nullptr) { | |||||
GELOGE(PARAM_INVALID, "label is null!"); | |||||
return false; | |||||
} | |||||
rtError_t rt_ret = rtLabelGotoEx(label_, stream_); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return false; | |||||
} | |||||
GELOGI("DistributeTask end."); | |||||
return true; | |||||
} | |||||
REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); | |||||
} // namespace model_runner | |||||
} // namespace ge |
@@ -0,0 +1,41 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||||
#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||||
#include <memory> | |||||
#include "ge_runtime/task/task.h" | |||||
namespace ge { | |||||
namespace model_runner { | |||||
class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> { | |||||
public: | |||||
LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info); | |||||
~LabelGotoTask() override; | |||||
bool Distribute() override; | |||||
private: | |||||
std::shared_ptr<LabelGotoTaskInfo> task_info_; | |||||
void *stream_; | |||||
void *label_; | |||||
}; | |||||
} // namespace model_runner | |||||
} // namespace ge | |||||
#endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ |
@@ -0,0 +1,70 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_runtime/task/label_set_task.h" | |||||
#include "ge_runtime/task/task_factory.h" | |||||
namespace ge { | |||||
namespace model_runner { | |||||
LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info) | |||||
: TaskRepeater<LabelSetTaskInfo>(model_context, task_info), | |||||
task_info_(task_info), | |||||
stream_(nullptr), | |||||
label_(nullptr) { | |||||
if (task_info_ == nullptr) { | |||||
GELOGW("task_info_ is null!"); | |||||
return; | |||||
} | |||||
auto stream_list = model_context.stream_list(); | |||||
auto label_list = model_context.label_list(); | |||||
uint32_t stream_id = task_info->stream_id(); | |||||
uint32_t label_id = task_info->label_id(); | |||||
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||||
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||||
if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||||
GELOGW("Stream/Label id invalid."); | |||||
return; | |||||
} | |||||
stream_ = stream_list[stream_id]; | |||||
label_ = label_list[label_id]; | |||||
} | |||||
LabelSetTask::~LabelSetTask() {} | |||||
bool LabelSetTask::Distribute() { | |||||
GELOGI("LabelSetTask Distribute start."); | |||||
if (stream_ == nullptr) { | |||||
GELOGE(PARAM_INVALID, "stream is null!"); | |||||
return false; | |||||
} | |||||
if (label_ == nullptr) { | |||||
GELOGE(PARAM_INVALID, "label is null!"); | |||||
return false; | |||||
} | |||||
rtError_t rt_ret = rtLabelSet(label_, stream_); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return false; | |||||
} | |||||
GELOGI("DistributeTask end."); | |||||
return true; | |||||
} | |||||
REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo); | |||||
} // namespace model_runner | |||||
} // namespace ge |
@@ -0,0 +1,41 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||||
#define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||||
#include <memory> | |||||
#include "ge_runtime/task/task.h" | |||||
namespace ge { | |||||
namespace model_runner { | |||||
class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> { | |||||
public: | |||||
LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info); | |||||
~LabelSetTask() override; | |||||
bool Distribute() override; | |||||
private: | |||||
std::shared_ptr<LabelSetTaskInfo> task_info_; | |||||
void *stream_; | |||||
void *label_; | |||||
}; | |||||
} // namespace model_runner | |||||
} // namespace ge | |||||
#endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ |
@@ -0,0 +1,131 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_runtime/task/label_switch_task.h" | |||||
#include "ge_runtime/task/task_factory.h" | |||||
namespace ge { | |||||
namespace model_runner { | |||||
LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||||
const std::shared_ptr<LabelSwitchTaskInfo> &task_info) | |||||
: TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info), | |||||
task_info_(task_info), | |||||
stream_(nullptr), | |||||
all_label_resource_(), | |||||
label_info_(nullptr) { | |||||
if (task_info_ == nullptr) { | |||||
GELOGW("task_info_ is null!"); | |||||
return; | |||||
} | |||||
all_label_resource_ = model_context.label_list(); | |||||
auto stream_list = model_context.stream_list(); | |||||
uint32_t stream_id = task_info->stream_id(); | |||||
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||||
if (stream_id >= stream_list.size()) { | |||||
GELOGW("Stream id invalid."); | |||||
return; | |||||
} | |||||
stream_ = stream_list[stream_id]; | |||||
} | |||||
LabelSwitchTask::~LabelSwitchTask() { | |||||
if (label_info_ != nullptr) { | |||||
rtError_t rt_ret = rtFree(label_info_); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); | |||||
} | |||||
label_info_ = nullptr; | |||||
} | |||||
} | |||||
bool LabelSwitchTask::Distribute() { | |||||
GELOGI("LabelSwitchTask Distribute start."); | |||||
if (!CheckParamValid()) { | |||||
return false; | |||||
} | |||||
const std::vector<uint32_t> &label_index_list = task_info_->label_list(); | |||||
std::vector<void *> label_list(task_info_->label_size(), nullptr); | |||||
for (size_t i = 0; i < task_info_->label_size(); ++i) { | |||||
uint32_t label_index = label_index_list[i]; | |||||
if (label_index >= all_label_resource_.size()) { | |||||
GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, | |||||
all_label_resource_.size()); | |||||
return false; | |||||
} | |||||
label_list[i] = all_label_resource_[label_index]; | |||||
GELOGI("Case %zu: label id %zu.", i, label_index); | |||||
} | |||||
uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); | |||||
rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return false; | |||||
} | |||||
rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return false; | |||||
} | |||||
rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return false; | |||||
} | |||||
GELOGI("DistributeTask end."); | |||||
return true; | |||||
} | |||||
bool LabelSwitchTask::CheckParamValid() { | |||||
if (stream_ == nullptr) { | |||||
GELOGE(PARAM_INVALID, "stream is null!"); | |||||
return false; | |||||
} | |||||
if (task_info_->label_list().empty()) { | |||||
GELOGE(PARAM_INVALID, "label_list is empty."); | |||||
return false; | |||||
} | |||||
if (task_info_->label_size() != task_info_->label_list().size()) { | |||||
GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(), | |||||
task_info_->label_size()); | |||||
return false; | |||||
} | |||||
if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) { | |||||
GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size()); | |||||
return false; | |||||
} | |||||
if (label_info_ != nullptr) { | |||||
GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); | |||||
} // namespace model_runner | |||||
} // namespace ge |
@@ -0,0 +1,44 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||||
#define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||||
#include <memory> | |||||
#include "ge_runtime/task/task.h" | |||||
namespace ge { | |||||
namespace model_runner { | |||||
class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> { | |||||
public: | |||||
LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info); | |||||
~LabelSwitchTask() override; | |||||
bool Distribute() override; | |||||
private: | |||||
bool CheckParamValid(); | |||||
std::shared_ptr<LabelSwitchTaskInfo> task_info_; | |||||
void *stream_; | |||||
std::vector<void *> all_label_resource_; | |||||
void *label_info_; | |||||
}; | |||||
} // namespace model_runner | |||||
} // namespace ge | |||||
#endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ |