diff --git a/ge/single_op/single_op_manager.cc b/ge/single_op/single_op_manager.cc index 3cdb7f7d..fddbeec2 100644 --- a/ge/single_op/single_op_manager.cc +++ b/ge/single_op/single_op_manager.cc @@ -46,14 +46,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOpManager::GetOpFr return ACL_ERROR_GE_MEMORY_ALLOCATION; } - SingleOp *op = res->GetOperator(model_data.model_data); + SingleOp *op = res->GetOperator(model_id); if (op != nullptr) { GELOGD("Got operator from stream cache"); *single_op = op; return SUCCESS; } - return res->BuildOperator(model_name, model_data, single_op); + return res->BuildOperator(model_data, single_op, model_id); } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOpManager::ReleaseResource(void *stream) { @@ -116,14 +116,14 @@ Status SingleOpManager::GetDynamicOpFromModel(const string &model_name, return ACL_ERROR_GE_MEMORY_ALLOCATION; } - DynamicSingleOp *op = res->GetDynamicOperator(model_data.model_data); + DynamicSingleOp *op = res->GetDynamicOperator(model_id); if (op != nullptr) { GELOGD("Got operator from stream cache"); *single_op = op; return SUCCESS; } - return res->BuildDynamicOperator(model_name, model_data, single_op); + return res->BuildDynamicOperator(model_data, single_op, model_id); } void SingleOpManager::RegisterTilingFunc() { diff --git a/ge/single_op/stream_resource.cc b/ge/single_op/stream_resource.cc index a3acf6b7..21d127ec 100755 --- a/ge/single_op/stream_resource.cc +++ b/ge/single_op/stream_resource.cc @@ -41,7 +41,7 @@ StreamResource::~StreamResource() { } } -SingleOp *StreamResource::GetOperator(const void *key) { +SingleOp *StreamResource::GetOperator(const uint64_t key) { std::lock_guard lk(mu_); auto it = op_map_.find(key); if (it == op_map_.end()) { @@ -51,7 +51,7 @@ SingleOp *StreamResource::GetOperator(const void *key) { return it->second.get(); } -DynamicSingleOp *StreamResource::GetDynamicOperator(const void *key) { +DynamicSingleOp *StreamResource::GetDynamicOperator(const uint64_t key) { std::lock_guard lk(mu_); auto it = dynamic_op_map_.find(key); if (it == dynamic_op_map_.end()) { @@ -138,11 +138,12 @@ uint8_t *StreamResource::MallocWeight(const std::string &purpose, size_t size) { return buffer; } -Status StreamResource::BuildDynamicOperator(const string &model_name, - const ModelData &model_data, - DynamicSingleOp **single_op) { +Status StreamResource::BuildDynamicOperator(const ModelData &model_data, + DynamicSingleOp **single_op, + const uint64_t model_id) { + const string &model_name = std::to_string(model_id); std::lock_guard lk(mu_); - auto it = dynamic_op_map_.find(model_data.model_data); + auto it = dynamic_op_map_.find(model_id); if (it != dynamic_op_map_.end()) { *single_op = it->second.get(); return SUCCESS; @@ -162,13 +163,14 @@ Status StreamResource::BuildDynamicOperator(const string &model_name, GE_CHK_STATUS_RET(model.BuildDynamicOp(*this, *new_op), "Build op failed. op = %s, ret = %u", model_name.c_str(), ret); *single_op = new_op.get(); - dynamic_op_map_[model_data.model_data] = std::move(new_op); + dynamic_op_map_[model_id] = std::move(new_op); return SUCCESS; } -Status StreamResource::BuildOperator(const string &model_name, const ModelData &model_data, SingleOp **single_op) { +Status StreamResource::BuildOperator(const ModelData &model_data, SingleOp **single_op, const uint64_t model_id) { + const string &model_name = std::to_string(model_id); std::lock_guard lk(mu_); - auto it = op_map_.find(model_data.model_data); + auto it = op_map_.find(model_id); if (it != op_map_.end()) { *single_op = it->second.get(); return SUCCESS; @@ -191,7 +193,7 @@ Status StreamResource::BuildOperator(const string &model_name, const ModelData & GE_CHK_STATUS_RET(model.BuildOp(*this, *new_op), "Build op failed. op = %s, ret = %u", model_name.c_str(), ret); *single_op = new_op.get(); - op_map_[model_data.model_data] = std::move(new_op); + op_map_[model_id] = std::move(new_op); return SUCCESS; } diff --git a/ge/single_op/stream_resource.h b/ge/single_op/stream_resource.h index d2c1ca36..73a6231b 100755 --- a/ge/single_op/stream_resource.h +++ b/ge/single_op/stream_resource.h @@ -40,11 +40,11 @@ class StreamResource { rtStream_t GetStream() const; void SetStream(rtStream_t stream); - SingleOp *GetOperator(const void *key); - DynamicSingleOp *GetDynamicOperator(const void *key); + SingleOp *GetOperator(const uint64_t key); + DynamicSingleOp *GetDynamicOperator(const uint64_t key); - Status BuildOperator(const std::string &model_name, const ModelData &model_data, SingleOp **single_op); - Status BuildDynamicOperator(const std::string &model_name, const ModelData &model_data, DynamicSingleOp **single_op); + Status BuildOperator(const ModelData &model_data, SingleOp **single_op, const uint64_t model_id); + Status BuildDynamicOperator(const ModelData &model_data, DynamicSingleOp **single_op, const uint64_t model_id); uint8_t *MallocMemory(const std::string &purpose, size_t size, bool holding_lock = true); uint8_t *MallocWeight(const std::string &purpose, size_t size); @@ -60,8 +60,8 @@ class StreamResource { size_t max_memory_size_ = 0; std::vector memory_list_; std::vector weight_list_; - std::unordered_map> op_map_; - std::unordered_map> dynamic_op_map_; + std::unordered_map> op_map_; + std::unordered_map> dynamic_op_map_; rtStream_t stream_ = nullptr; std::mutex mu_; std::mutex stream_mu_; diff --git a/tests/ut/ge/single_op/stream_resource_unittest.cc b/tests/ut/ge/single_op/stream_resource_unittest.cc index b7306815..8a5124ef 100644 --- a/tests/ut/ge/single_op/stream_resource_unittest.cc +++ b/tests/ut/ge/single_op/stream_resource_unittest.cc @@ -58,6 +58,18 @@ TEST_F(UtestStreamResource, test_malloc_memory) { ASSERT_NE(res.MallocMemory(purpose, 100), nullptr); } +TEST_F(UtestStreamResource, test_malloc_memory) { + StreamResource res((uintptr_t)1); + ModelData model_data; + SingleOp *single_op = nullptr; + DynamicSingleOp *dynamic_single_op = nullptr; + res.op_map_[0] = &single_op; + res.dynamic_op_map_[1] = &dynamic_single_op; + + ASSERT_EQ(res.BuildOperator(model_data, &single_op, 0), SUCCESS); + ASSERT_EQ(res.BuildDynamicOperator(model_data, &dynamic_single_op, 1), SUCCESS); +} + /* TEST_F(UtestStreamResource, test_do_malloc_memory) { size_t max_allocated = 0;