diff --git a/ge/graph/load/model_manager/model_manager.cc b/ge/graph/load/model_manager/model_manager.cc index 97ad0054..27cbd526 100755 --- a/ge/graph/load/model_manager/model_manager.cc +++ b/ge/graph/load/model_manager/model_manager.cc @@ -286,6 +286,17 @@ ge::Status ModelManager::DoLoadHybridModelOnline(uint32_t model_id, const string return SUCCESS; } +bool ModelManager::IsNeedHybridLoad(ge::GeRootModel &ge_root_model) { + auto root_graph = ge_root_model.GetRootGraph(); + if (root_graph == nullptr) { + GELOGE(FAILED, "no model on root model"); + return false; + } + bool is_shape_unknown = root_graph->GetGraphUnknownFlag(); + bool is_dsp_partitioned_graph = false; + (void)AttrUtils::GetBool(root_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dsp_partitioned_graph); + return is_shape_unknown || is_dsp_partitioned_graph || GetContext().GetHostExecFlag(); +} /// /// @ingroup domi_ome /// @brief load model online @@ -299,9 +310,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrGetSubgraphInstanceNameToModel(); string model_name = ""; - bool is_shape_unknown = ge_root_model->GetRootGraph()->GetGraphUnknownFlag(); - // if multi subgraph is known, do hybrid load process - if (is_shape_unknown || GetContext().GetHostExecFlag() || (name_to_model.size() > 1)) { + if (IsNeedHybridLoad(*ge_root_model)) { return DoLoadHybridModelOnline(model_id, model_name, ge_root_model, listener); } diff --git a/ge/graph/load/model_manager/model_manager.h b/ge/graph/load/model_manager/model_manager.h index f2d55db7..489320f4 100755 --- a/ge/graph/load/model_manager/model_manager.h +++ b/ge/graph/load/model_manager/model_manager.h @@ -294,6 +294,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { std::vector &output_dims); bool IsDynamicShape(uint32_t model_id); + bool IsNeedHybridLoad(ge::GeRootModel &ge_root_model); ge::Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); ge::Status EnableExceptionDump(const std::map &options);