diff --git a/ge/graph/load/model_manager/davinci_model.cc b/ge/graph/load/model_manager/davinci_model.cc index 495ec28e..14e824bd 100755 --- a/ge/graph/load/model_manager/davinci_model.cc +++ b/ge/graph/load/model_manager/davinci_model.cc @@ -3530,6 +3530,31 @@ Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &outp return SUCCESS; } +void DavinciModel::BuildZeroCopyTasksLookupTable() { + std::lock_guard lk(lookup_table_build_lock_); + if (lookup_table_built_) { + return; + } + + const auto default_label_hash = std::hash{}(kDefaultBatchLable); + for (auto &task : zero_copy_tasks_) { + auto label_hash = std::hash{}(task.GetBatchLabel()); + auto addr2offsets = task.GetTaskArgsOffset(); + + label_hash2tasks_[label_hash].insert(&task); + if (label_hash == default_label_hash) { + for (auto &addr2offset : addr2offsets) { + addr2default_label_tasks_[addr2offset.first].insert(&task); + } + } else { + for (auto &addr2offset : addr2offsets) { + addr2specific_label_tasks_[addr2offset.first].insert(&task); + } + } + } + lookup_table_built_ = true; +} + /// /// @ingroup ge /// @brief Copy Data addr to model for direct use. @@ -3551,6 +3576,8 @@ Status DavinciModel::UpdateIoTaskArgs(const std::map & return ACL_ERROR_GE_PARAM_INVALID; } + BuildZeroCopyTasksLookupTable(); + for (const auto &data : data_info) { if (data.first >= blobs.size()) { // check data index. REPORT_INNER_ERROR("E19999", "is_input:%d, data index:%u from model >= blobs.size:%zu from user, mode_id:%u" diff --git a/ge/graph/load/model_manager/davinci_model.h b/ge/graph/load/model_manager/davinci_model.h index 76b0beef..dc9fba85 100755 --- a/ge/graph/load/model_manager/davinci_model.h +++ b/ge/graph/load/model_manager/davinci_model.h @@ -917,6 +917,7 @@ class DavinciModel { Status GetGearAndRealOutSizeInfo(const ComputeGraphPtr &graph, const NodePtr &node); Status GetRealOutputSizeOfCase(const ComputeGraphPtr &graph, size_t input_index, const NodePtr &case_node); Status GetGearAndRealOutShapeInfo(const ComputeGraphPtr &graph, const NodePtr &node); + void BuildZeroCopyTasksLookupTable(); bool is_weight_mem_has_inited_; bool is_feature_map_mem_has_inited_; @@ -1112,6 +1113,13 @@ class DavinciModel { // op name to attrs mapping std::map>> op_name_to_attrs_; + // fields for build fast search hash table for zero copy tasks + std::mutex lookup_table_build_lock_; + bool lookup_table_built_{false}; + std::unordered_map> label_hash2tasks_; + std::unordered_map> addr2specific_label_tasks_; + std::unordered_map> addr2default_label_tasks_; + std::map stream_2_event_; }; } // namespace ge diff --git a/ge/graph/load/model_manager/zero_copy_task.cc b/ge/graph/load/model_manager/zero_copy_task.cc index 85be6d7b..61a9713f 100755 --- a/ge/graph/load/model_manager/zero_copy_task.cc +++ b/ge/graph/load/model_manager/zero_copy_task.cc @@ -54,6 +54,10 @@ Status ZeroCopyTask::SetTaskArgsOffset(uintptr_t addr, size_t offset) { return SUCCESS; } +const std::map>& ZeroCopyTask::GetTaskArgsOffset() const { + return task_addr_offset_; +} + /** * @ingroup ge * @brief Save orignal data of task args. diff --git a/ge/graph/load/model_manager/zero_copy_task.h b/ge/graph/load/model_manager/zero_copy_task.h index efabc814..6420c91c 100644 --- a/ge/graph/load/model_manager/zero_copy_task.h +++ b/ge/graph/load/model_manager/zero_copy_task.h @@ -46,6 +46,8 @@ class ZeroCopyTask { */ ge::Status SetTaskArgsOffset(uintptr_t addr, size_t offset); + const std::map>& GetTaskArgsOffset() const; + /** * @ingroup ge * @brief Is need zero copy.