From cebc89dd3b88fd70246284e3ca9a9043417e51a5 Mon Sep 17 00:00:00 2001 From: wuweikang Date: Wed, 14 Apr 2021 22:39:11 +0800 Subject: [PATCH] bugfix for release memory --- ge/graph/manager/graph_manager.cc | 10 ++++++++++ ge/graph/manager/graph_manager_utils.h | 5 ++++- ge/model/ge_root_model.h | 3 +++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index a64cb1ec..d90c03d9 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -121,6 +121,7 @@ const uint32_t kInitGraphCount = 1; const uint32_t kNotAdded = 0; const uint32_t kStartAdd = 1; const uint32_t kDoneAdded = 2; +const uint32_t kNeverLoaded = 0; bool IsTailingOptimization() { string is_tailing_optimization_option; @@ -2584,6 +2585,15 @@ void GraphManager::ReleaseMemory(const GeModelPtr &ge_model, GraphNodePtr &graph GELOGI("CheckAndReleaseMemory UnloadGraph[%u], model[%u] success.", graph_id, model_id); } graph_node->SetLoadFlag(false); + // Allow model to be loaded agagin without adding graph again + graph_node->SetLoadCount(graph_node->GetLoadRecord()); + graph_node->SetLoadRecord(kNeverLoaded); + GeRootModelPtr ge_root_model = graph_node->GetGeRootModel(); + if (ge_root_model == nullptr) { + GELOGW("ge_root_model is null, graph_id:%u", graph_id); + return; + } + ge_root_model->ClearAllModelId(); rt_ret = rtDeviceReset(GetContext().DeviceId()); if (rt_ret != RT_ERROR_NONE) { REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, when GraphManager %s", diff --git a/ge/graph/manager/graph_manager_utils.h b/ge/graph/manager/graph_manager_utils.h index ffbc20cf..bebba93e 100644 --- a/ge/graph/manager/graph_manager_utils.h +++ b/ge/graph/manager/graph_manager_utils.h @@ -178,9 +178,12 @@ class GraphNode { void SetSemSize(uint32_t size) { sem_.SetMaxSize(size); } uint32_t GetLoadCount() const { return load_count_; } + void SetLoadCount(uint32_t count) { load_count_ = count; } + uint32_t GetLoadRecord() const { return load_record_; } + void SetLoadRecord(uint32_t record) { load_record_ = record; } + void IncreaseLoadRecord() { ++load_record_; } void IncreaseLoadCount(); void DecreaseLoadCount() { --load_count_; } - void IncreaseLoadRecord() { ++load_record_; } // run graph asynchronous listener std::shared_ptr graph_run_async_listener_; diff --git a/ge/model/ge_root_model.h b/ge/model/ge_root_model.h index d0e0af54..5eaa9b29 100755 --- a/ge/model/ge_root_model.h +++ b/ge/model/ge_root_model.h @@ -40,10 +40,13 @@ class GeRootModel { } uint32_t GetModelId() const { return model_id_; } void SetModelName(const std::string &model_name) { model_name_ = model_name; } + const std::string &GetModelName() const { return model_name_; } std::vector GetAllModelId() const { return model_ids_; } + void ClearAllModelId() { model_ids_.clear(); } + Status CheckIsUnknownShape(bool &is_dynamic_shape); void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; }