Merge pull request !36 from zhoufeng/control-sinktags/v0.6.0-beta
@@ -177,37 +177,44 @@ class AicpuTaskInfo : public TaskInfo { | |||
std::vector<void *> output_data_addrs_; | |||
}; | |||
class LabelTaskInfo : public TaskInfo { | |||
class LabelSetTaskInfo : public TaskInfo { | |||
public: | |||
LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id) | |||
: TaskInfo(stream_id, TaskInfoType::LABEL_SET), label_id_(label_id) {} | |||
~LabelSetTaskInfo() override {} | |||
uint32_t label_id() const { return label_id_; } | |||
protected: | |||
LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id) | |||
: TaskInfo(stream_id, type), label_id_(label_id) {} | |||
virtual ~LabelTaskInfo() override {} | |||
private: | |||
uint32_t label_id_; | |||
}; | |||
class LabelSetTaskInfo : public LabelTaskInfo { | |||
class LabelGotoTaskInfo : public TaskInfo { | |||
public: | |||
LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id) | |||
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {} | |||
~LabelSetTaskInfo() override {} | |||
LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id) | |||
: TaskInfo(stream_id, TaskInfoType::LABEL_GOTO), label_id_(label_id) {} | |||
~LabelGotoTaskInfo() override {} | |||
uint32_t label_id() const { return label_id_; } | |||
private: | |||
uint32_t label_id_; | |||
}; | |||
class LabelSwitchTaskInfo : public LabelTaskInfo { | |||
class LabelSwitchTaskInfo : public TaskInfo { | |||
public: | |||
LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id) | |||
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {} | |||
LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_size, const std::vector<uint32_t> &label_list, void *cond) | |||
: TaskInfo(stream_id, TaskInfoType::LABEL_SWITCH), | |||
label_size_(label_size), | |||
label_list_(label_list), | |||
cond_(cond) {} | |||
~LabelSwitchTaskInfo() override {} | |||
}; | |||
uint32_t label_size() { return label_size_; }; | |||
const std::vector<uint32_t> &label_list() { return label_list_; }; | |||
void *cond() { return cond_; }; | |||
class LabelGotoTaskInfo : public LabelTaskInfo { | |||
public: | |||
LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id) | |||
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {} | |||
~LabelGotoTaskInfo() override {} | |||
private: | |||
uint32_t label_size_; | |||
std::vector<uint32_t> label_list_; | |||
void *cond_; | |||
}; | |||
class EventTaskInfo : public TaskInfo { | |||
@@ -116,23 +116,34 @@ bool RuntimeModel::InitEvent(uint32_t event_num) { | |||
return true; | |||
} | |||
bool RuntimeModel::InitLabel(uint32_t batch_num) { | |||
GELOGI("batch number:%u.", batch_num); | |||
for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) { | |||
rtLabel_t rt_lLabel = nullptr; | |||
rtError_t rt_ret = rtLabelCreate(&rt_lLabel); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret); | |||
return false; | |||
bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) { | |||
GELOGI("batch number:%u.", davinci_model->GetBatchNum()); | |||
label_list_.resize(davinci_model->GetBatchNum()); | |||
for (auto &task_info : davinci_model->GetTaskInfoList()) { | |||
if (task_info == nullptr) { | |||
GELOGE(PARAM_INVALID, "task_info is null."); | |||
continue; | |||
} | |||
if (task_info->type() != TaskInfoType::LABEL_SET) { | |||
continue; | |||
} | |||
auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info); | |||
if (rt_lLabel == nullptr) { | |||
GELOGE(RT_FAILED, "rtLabel is nullptr!"); | |||
if (label_set_task_info->stream_id() >= stream_list_.size()) { | |||
GELOGE(PARAM_INVALID, "Invalid stream id."); | |||
return false; | |||
} | |||
label_list_.emplace_back(rt_lLabel); | |||
rtLabel_t rt_label = nullptr; | |||
rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret); | |||
return false; | |||
} | |||
label_list_[label_set_task_info->label_id()] = rt_label; | |||
} | |||
return true; | |||
} | |||
@@ -164,7 +175,7 @@ bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) { | |||
return false; | |||
} | |||
if (!InitLabel(davinci_model->GetBatchNum())) { | |||
if (!InitLabel(davinci_model)) { | |||
return false; | |||
} | |||
@@ -48,7 +48,7 @@ class RuntimeModel { | |||
bool LoadTask(); | |||
bool InitStream(std::shared_ptr<DavinciModel> &davinci_model); | |||
bool InitEvent(uint32_t event_num); | |||
bool InitLabel(uint32_t batch_num); | |||
bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model); | |||
bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||
bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||
bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||
@@ -0,0 +1,70 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "ge_runtime/task/label_goto_task.h" | |||
#include "ge_runtime/task/task_factory.h" | |||
namespace ge { | |||
namespace model_runner { | |||
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | |||
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), | |||
task_info_(task_info), | |||
stream_(nullptr), | |||
label_(nullptr) { | |||
if (task_info_ == nullptr) { | |||
GELOGW("task_info_ is null!"); | |||
return; | |||
} | |||
auto stream_list = model_context.stream_list(); | |||
auto label_list = model_context.label_list(); | |||
uint32_t stream_id = task_info->stream_id(); | |||
uint32_t label_id = task_info->label_id(); | |||
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||
if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||
GELOGW("Stream/Label id invalid."); | |||
return; | |||
} | |||
stream_ = stream_list[stream_id]; | |||
label_ = label_list[label_id]; | |||
} | |||
LabelGotoTask::~LabelGotoTask() {} | |||
bool LabelGotoTask::Distribute() { | |||
GELOGI("LabelGotoTask Distribute start."); | |||
if (stream_ == nullptr) { | |||
GELOGE(PARAM_INVALID, "stream is null!"); | |||
return false; | |||
} | |||
if (label_ == nullptr) { | |||
GELOGE(PARAM_INVALID, "label is null!"); | |||
return false; | |||
} | |||
rtError_t rt_ret = rtLabelGotoEx(label_, stream_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return false; | |||
} | |||
GELOGI("DistributeTask end."); | |||
return true; | |||
} | |||
REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); | |||
} // namespace model_runner | |||
} // namespace ge |
@@ -0,0 +1,41 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
#include <memory> | |||
#include "ge_runtime/task/task.h" | |||
namespace ge { | |||
namespace model_runner { | |||
class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> { | |||
public: | |||
LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info); | |||
~LabelGotoTask() override; | |||
bool Distribute() override; | |||
private: | |||
std::shared_ptr<LabelGotoTaskInfo> task_info_; | |||
void *stream_; | |||
void *label_; | |||
}; | |||
} // namespace model_runner | |||
} // namespace ge | |||
#endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ |
@@ -0,0 +1,70 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "ge_runtime/task/label_set_task.h" | |||
#include "ge_runtime/task/task_factory.h" | |||
namespace ge { | |||
namespace model_runner { | |||
LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info) | |||
: TaskRepeater<LabelSetTaskInfo>(model_context, task_info), | |||
task_info_(task_info), | |||
stream_(nullptr), | |||
label_(nullptr) { | |||
if (task_info_ == nullptr) { | |||
GELOGW("task_info_ is null!"); | |||
return; | |||
} | |||
auto stream_list = model_context.stream_list(); | |||
auto label_list = model_context.label_list(); | |||
uint32_t stream_id = task_info->stream_id(); | |||
uint32_t label_id = task_info->label_id(); | |||
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||
if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||
GELOGW("Stream/Label id invalid."); | |||
return; | |||
} | |||
stream_ = stream_list[stream_id]; | |||
label_ = label_list[label_id]; | |||
} | |||
LabelSetTask::~LabelSetTask() {} | |||
bool LabelSetTask::Distribute() { | |||
GELOGI("LabelSetTask Distribute start."); | |||
if (stream_ == nullptr) { | |||
GELOGE(PARAM_INVALID, "stream is null!"); | |||
return false; | |||
} | |||
if (label_ == nullptr) { | |||
GELOGE(PARAM_INVALID, "label is null!"); | |||
return false; | |||
} | |||
rtError_t rt_ret = rtLabelSet(label_, stream_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return false; | |||
} | |||
GELOGI("DistributeTask end."); | |||
return true; | |||
} | |||
REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo); | |||
} // namespace model_runner | |||
} // namespace ge |
@@ -0,0 +1,41 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
#define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
#include <memory> | |||
#include "ge_runtime/task/task.h" | |||
namespace ge { | |||
namespace model_runner { | |||
class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> { | |||
public: | |||
LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info); | |||
~LabelSetTask() override; | |||
bool Distribute() override; | |||
private: | |||
std::shared_ptr<LabelSetTaskInfo> task_info_; | |||
void *stream_; | |||
void *label_; | |||
}; | |||
} // namespace model_runner | |||
} // namespace ge | |||
#endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ |
@@ -0,0 +1,131 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "ge_runtime/task/label_switch_task.h" | |||
#include "ge_runtime/task/task_factory.h" | |||
namespace ge { | |||
namespace model_runner { | |||
LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||
const std::shared_ptr<LabelSwitchTaskInfo> &task_info) | |||
: TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info), | |||
task_info_(task_info), | |||
stream_(nullptr), | |||
all_label_resource_(), | |||
label_info_(nullptr) { | |||
if (task_info_ == nullptr) { | |||
GELOGW("task_info_ is null!"); | |||
return; | |||
} | |||
all_label_resource_ = model_context.label_list(); | |||
auto stream_list = model_context.stream_list(); | |||
uint32_t stream_id = task_info->stream_id(); | |||
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
if (stream_id >= stream_list.size()) { | |||
GELOGW("Stream id invalid."); | |||
return; | |||
} | |||
stream_ = stream_list[stream_id]; | |||
} | |||
LabelSwitchTask::~LabelSwitchTask() { | |||
if (label_info_ != nullptr) { | |||
rtError_t rt_ret = rtFree(label_info_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); | |||
} | |||
label_info_ = nullptr; | |||
} | |||
} | |||
bool LabelSwitchTask::Distribute() { | |||
GELOGI("LabelSwitchTask Distribute start."); | |||
if (!CheckParamValid()) { | |||
return false; | |||
} | |||
const std::vector<uint32_t> &label_index_list = task_info_->label_list(); | |||
std::vector<void *> label_list(task_info_->label_size(), nullptr); | |||
for (size_t i = 0; i < task_info_->label_size(); ++i) { | |||
uint32_t label_index = label_index_list[i]; | |||
if (label_index >= all_label_resource_.size()) { | |||
GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, | |||
all_label_resource_.size()); | |||
return false; | |||
} | |||
label_list[i] = all_label_resource_[label_index]; | |||
GELOGI("Case %zu: label id %zu.", i, label_index); | |||
} | |||
uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); | |||
rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return false; | |||
} | |||
rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return false; | |||
} | |||
rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return false; | |||
} | |||
GELOGI("DistributeTask end."); | |||
return true; | |||
} | |||
bool LabelSwitchTask::CheckParamValid() { | |||
if (stream_ == nullptr) { | |||
GELOGE(PARAM_INVALID, "stream is null!"); | |||
return false; | |||
} | |||
if (task_info_->label_list().empty()) { | |||
GELOGE(PARAM_INVALID, "label_list is empty."); | |||
return false; | |||
} | |||
if (task_info_->label_size() != task_info_->label_list().size()) { | |||
GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(), | |||
task_info_->label_size()); | |||
return false; | |||
} | |||
if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) { | |||
GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size()); | |||
return false; | |||
} | |||
if (label_info_ != nullptr) { | |||
GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||
return false; | |||
} | |||
return true; | |||
} | |||
REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); | |||
} // namespace model_runner | |||
} // namespace ge |
@@ -0,0 +1,44 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
#define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
#include <memory> | |||
#include "ge_runtime/task/task.h" | |||
namespace ge { | |||
namespace model_runner { | |||
class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> { | |||
public: | |||
LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info); | |||
~LabelSwitchTask() override; | |||
bool Distribute() override; | |||
private: | |||
bool CheckParamValid(); | |||
std::shared_ptr<LabelSwitchTaskInfo> task_info_; | |||
void *stream_; | |||
std::vector<void *> all_label_resource_; | |||
void *label_info_; | |||
}; | |||
} // namespace model_runner | |||
} // namespace ge | |||
#endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ |