From: @HW_KK Reviewed-by: @ji_chen,@selfws,@wqtshg Signed-off-by: @lbisdaddypull/1525/MERGE
@@ -121,6 +121,7 @@ const uint32_t kInitGraphCount = 1; | |||||
const uint32_t kNotAdded = 0; | const uint32_t kNotAdded = 0; | ||||
const uint32_t kStartAdd = 1; | const uint32_t kStartAdd = 1; | ||||
const uint32_t kDoneAdded = 2; | const uint32_t kDoneAdded = 2; | ||||
const uint32_t kNeverLoaded = 0; | |||||
bool IsTailingOptimization() { | bool IsTailingOptimization() { | ||||
string is_tailing_optimization_option; | string is_tailing_optimization_option; | ||||
@@ -2584,6 +2585,15 @@ void GraphManager::ReleaseMemory(const GeModelPtr &ge_model, GraphNodePtr &graph | |||||
GELOGI("CheckAndReleaseMemory UnloadGraph[%u], model[%u] success.", graph_id, model_id); | GELOGI("CheckAndReleaseMemory UnloadGraph[%u], model[%u] success.", graph_id, model_id); | ||||
} | } | ||||
graph_node->SetLoadFlag(false); | graph_node->SetLoadFlag(false); | ||||
// Allow model to be loaded agagin without adding graph again | |||||
graph_node->SetLoadCount(graph_node->GetLoadRecord()); | |||||
graph_node->SetLoadRecord(kNeverLoaded); | |||||
GeRootModelPtr ge_root_model = graph_node->GetGeRootModel(); | |||||
if (ge_root_model == nullptr) { | |||||
GELOGW("ge_root_model is null, graph_id:%u", graph_id); | |||||
return; | |||||
} | |||||
ge_root_model->ClearAllModelId(); | |||||
rt_ret = rtDeviceReset(GetContext().DeviceId()); | rt_ret = rtDeviceReset(GetContext().DeviceId()); | ||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, when GraphManager %s", | REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, when GraphManager %s", | ||||
@@ -178,9 +178,12 @@ class GraphNode { | |||||
void SetSemSize(uint32_t size) { sem_.SetMaxSize(size); } | void SetSemSize(uint32_t size) { sem_.SetMaxSize(size); } | ||||
uint32_t GetLoadCount() const { return load_count_; } | uint32_t GetLoadCount() const { return load_count_; } | ||||
void SetLoadCount(uint32_t count) { load_count_ = count; } | |||||
uint32_t GetLoadRecord() const { return load_record_; } | |||||
void SetLoadRecord(uint32_t record) { load_record_ = record; } | |||||
void IncreaseLoadRecord() { ++load_record_; } | |||||
void IncreaseLoadCount(); | void IncreaseLoadCount(); | ||||
void DecreaseLoadCount() { --load_count_; } | void DecreaseLoadCount() { --load_count_; } | ||||
void IncreaseLoadRecord() { ++load_record_; } | |||||
// run graph asynchronous listener | // run graph asynchronous listener | ||||
std::shared_ptr<RunAsyncListener> graph_run_async_listener_; | std::shared_ptr<RunAsyncListener> graph_run_async_listener_; | ||||
@@ -40,10 +40,13 @@ class GeRootModel { | |||||
} | } | ||||
uint32_t GetModelId() const { return model_id_; } | uint32_t GetModelId() const { return model_id_; } | ||||
void SetModelName(const std::string &model_name) { model_name_ = model_name; } | void SetModelName(const std::string &model_name) { model_name_ = model_name; } | ||||
const std::string &GetModelName() const { return model_name_; } | const std::string &GetModelName() const { return model_name_; } | ||||
std::vector<uint32_t> GetAllModelId() const { return model_ids_; } | std::vector<uint32_t> GetAllModelId() const { return model_ids_; } | ||||
void ClearAllModelId() { model_ids_.clear(); } | |||||
Status CheckIsUnknownShape(bool &is_dynamic_shape); | Status CheckIsUnknownShape(bool &is_dynamic_shape); | ||||
void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } | void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } | ||||