Browse Source

!1307 fix bug of dynamic shape load error

From: @wan_xuelei
Reviewed-by: @ji_chen,@xchu42
Signed-off-by: @lbisdaddy
tags/v1.2.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
3050d3984a
2 changed files with 13 additions and 3 deletions
  1. +12
    -3
      ge/graph/load/model_manager/model_manager.cc
  2. +1
    -0
      ge/graph/load/model_manager/model_manager.h

+ 12
- 3
ge/graph/load/model_manager/model_manager.cc View File

@@ -286,6 +286,17 @@ ge::Status ModelManager::DoLoadHybridModelOnline(uint32_t model_id, const string
return SUCCESS;
}

bool ModelManager::IsNeedHybridLoad(ge::GeRootModel &ge_root_model) {
auto root_graph = ge_root_model.GetRootGraph();
if (root_graph == nullptr) {
GELOGE(FAILED, "no model on root model");
return false;
}
bool is_shape_unknown = root_graph->GetGraphUnknownFlag();
bool is_dsp_partitioned_graph = false;
(void)AttrUtils::GetBool(root_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dsp_partitioned_graph);
return is_shape_unknown || is_dsp_partitioned_graph || GetContext().GetHostExecFlag();
}
///
/// @ingroup domi_ome
/// @brief load model online
@@ -299,9 +310,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge
}
auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel();
string model_name = "";
bool is_shape_unknown = ge_root_model->GetRootGraph()->GetGraphUnknownFlag();
// if multi subgraph is known, do hybrid load process
if (is_shape_unknown || GetContext().GetHostExecFlag() || (name_to_model.size() > 1)) {
if (IsNeedHybridLoad(*ge_root_model)) {
return DoLoadHybridModelOnline(model_id, model_name, ge_root_model, listener);
}



+ 1
- 0
ge/graph/load/model_manager/model_manager.h View File

@@ -294,6 +294,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager {
std::vector<InputOutputDims> &output_dims);

bool IsDynamicShape(uint32_t model_id);
bool IsNeedHybridLoad(ge::GeRootModel &ge_root_model);
ge::Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info);

ge::Status EnableExceptionDump(const std::map<string, string> &options);


Loading…
Cancel
Save