From 46156bf04feaaa7a3e711cb8846b0c25d0f49e86 Mon Sep 17 00:00:00 2001 From: wjm Date: Mon, 12 Apr 2021 16:31:46 +0800 Subject: [PATCH] fix dump --- ge/common/helper/model_helper.cc | 1 + ge/hybrid/model/hybrid_model_builder.cc | 6 +++--- ge/model/ge_root_model.h | 3 +++ tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 2 ++ 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/ge/common/helper/model_helper.cc b/ge/common/helper/model_helper.cc index 74238bc1..e95c3429 100644 --- a/ge/common/helper/model_helper.cc +++ b/ge/common/helper/model_helper.cc @@ -599,6 +599,7 @@ Status ModelHelper::GenerateGeRootModel(OmFileLoadHelper &om_load_helper) { is_first_model = false; root_model_->SetRootGraph(GraphUtils::GetComputeGraph(cur_model->GetGraph())); root_model_->SetModelId(cur_model->GetModelId()); + root_model_->SetModelName(cur_model->GetName()); model_ = cur_model; continue; } diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index bd4df10d..d413e167 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -130,7 +130,7 @@ HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model) Status HybridModelBuilder::Build() { GE_CHK_STATUS_RET(ValidateParams(), "Failed to validate GeRootModel"); - hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); + hybrid_model_.model_name_ = ge_root_model_->GetName(); GELOGI("[%s] Start to build hybrid model.", GetGraphName()); GE_CHK_STATUS_RET(InitRuntimeParams(), "[%s] Failed to InitRuntimeParams", GetGraphName()); GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[%s] Failed to RecoverGraphUnknownFlag", GetGraphName()); @@ -154,7 +154,7 @@ Status HybridModelBuilder::Build() { Status HybridModelBuilder::BuildForSingleOp() { GE_CHK_STATUS_RET(ValidateParams(), "Failed to validate GeRootModel"); - hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); + hybrid_model_.model_name_ = ge_root_model_->GetName(); GELOGI("[%s] Start to build hybrid model.", GetGraphName()); auto ret = ge_root_model_->GetSubgraphInstanceNameToModel(); const GeModelPtr ge_model = ret[ge_root_model_->GetRootGraph()->GetName()]; @@ -272,7 +272,7 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); // not care result, if no this attr, stand for the op does not need force infershape - (void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); + (void) AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); GELOGD("node [%s] is need do infershape , flag is %d", op_desc->GetName().c_str(), node_item.is_need_force_infershape); diff --git a/ge/model/ge_root_model.h b/ge/model/ge_root_model.h index aa5a4d47..32b6ec4e 100755 --- a/ge/model/ge_root_model.h +++ b/ge/model/ge_root_model.h @@ -35,12 +35,15 @@ class GeRootModel { const ComputeGraphPtr &GetRootGraph() const { return root_graph_; }; void SetModelId(uint32_t model_id) { model_id_ = model_id; } uint32_t GetModelId() const { return model_id_; } + void SetModelName(const std::string &model_name) { model_name_ = model_name; } + const std::string &GetModelName() const { return model_name_; } Status CheckIsUnknownShape(bool &is_dynamic_shape); void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } private: ComputeGraphPtr root_graph_ = nullptr; std::map subgraph_instance_name_to_model_; uint32_t model_id_ = 0; + std::string model_name_; }; } // namespace ge using GeRootModelPtr = std::shared_ptr; diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index f5a802a2..033dff18 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -150,9 +150,11 @@ TEST_F(UtestGeHybrid, index_taskdefs_failed) { ComputeGraphPtr graph = std::make_shared("test"); GeRootModelPtr ge_root_model = make_shared(graph); + ge_root_model->SetModelName("test_name"); HybridModel hybrid_model(ge_root_model); HybridModelBuilder hybrid_model_builder(hybrid_model); + ASSERT_EQ(hybrid_model_builder.Build(), INTERNAL_ERROR); ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR); }