@@ -599,6 +599,7 @@ Status ModelHelper::GenerateGeRootModel(OmFileLoadHelper &om_load_helper) { | |||||
is_first_model = false; | is_first_model = false; | ||||
root_model_->SetRootGraph(GraphUtils::GetComputeGraph(cur_model->GetGraph())); | root_model_->SetRootGraph(GraphUtils::GetComputeGraph(cur_model->GetGraph())); | ||||
root_model_->SetModelId(cur_model->GetModelId()); | root_model_->SetModelId(cur_model->GetModelId()); | ||||
root_model_->SetModelName(cur_model->GetName()); | |||||
model_ = cur_model; | model_ = cur_model; | ||||
continue; | continue; | ||||
} | } | ||||
@@ -130,7 +130,7 @@ HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model) | |||||
Status HybridModelBuilder::Build() { | Status HybridModelBuilder::Build() { | ||||
GE_CHK_STATUS_RET(ValidateParams(), "Failed to validate GeRootModel"); | GE_CHK_STATUS_RET(ValidateParams(), "Failed to validate GeRootModel"); | ||||
hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); | |||||
hybrid_model_.model_name_ = ge_root_model_->GetName(); | |||||
GELOGI("[%s] Start to build hybrid model.", GetGraphName()); | GELOGI("[%s] Start to build hybrid model.", GetGraphName()); | ||||
GE_CHK_STATUS_RET(InitRuntimeParams(), "[%s] Failed to InitRuntimeParams", GetGraphName()); | GE_CHK_STATUS_RET(InitRuntimeParams(), "[%s] Failed to InitRuntimeParams", GetGraphName()); | ||||
GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[%s] Failed to RecoverGraphUnknownFlag", GetGraphName()); | GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[%s] Failed to RecoverGraphUnknownFlag", GetGraphName()); | ||||
@@ -154,7 +154,7 @@ Status HybridModelBuilder::Build() { | |||||
Status HybridModelBuilder::BuildForSingleOp() { | Status HybridModelBuilder::BuildForSingleOp() { | ||||
GE_CHK_STATUS_RET(ValidateParams(), "Failed to validate GeRootModel"); | GE_CHK_STATUS_RET(ValidateParams(), "Failed to validate GeRootModel"); | ||||
hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); | |||||
hybrid_model_.model_name_ = ge_root_model_->GetName(); | |||||
GELOGI("[%s] Start to build hybrid model.", GetGraphName()); | GELOGI("[%s] Start to build hybrid model.", GetGraphName()); | ||||
auto ret = ge_root_model_->GetSubgraphInstanceNameToModel(); | auto ret = ge_root_model_->GetSubgraphInstanceNameToModel(); | ||||
const GeModelPtr ge_model = ret[ge_root_model_->GetRootGraph()->GetName()]; | const GeModelPtr ge_model = ret[ge_root_model_->GetRootGraph()->GetName()]; | ||||
@@ -272,7 +272,7 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt | |||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
// not care result, if no this attr, stand for the op does not need force infershape | // not care result, if no this attr, stand for the op does not need force infershape | ||||
(void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); | |||||
(void) AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); | |||||
GELOGD("node [%s] is need do infershape , flag is %d", | GELOGD("node [%s] is need do infershape , flag is %d", | ||||
op_desc->GetName().c_str(), | op_desc->GetName().c_str(), | ||||
node_item.is_need_force_infershape); | node_item.is_need_force_infershape); | ||||
@@ -35,12 +35,15 @@ class GeRootModel { | |||||
const ComputeGraphPtr &GetRootGraph() const { return root_graph_; }; | const ComputeGraphPtr &GetRootGraph() const { return root_graph_; }; | ||||
void SetModelId(uint32_t model_id) { model_id_ = model_id; } | void SetModelId(uint32_t model_id) { model_id_ = model_id; } | ||||
uint32_t GetModelId() const { return model_id_; } | uint32_t GetModelId() const { return model_id_; } | ||||
void SetModelName(const std::string &model_name) { model_name_ = model_name; } | |||||
const std::string &GetModelName() const { return model_name_; } | |||||
Status CheckIsUnknownShape(bool &is_dynamic_shape); | Status CheckIsUnknownShape(bool &is_dynamic_shape); | ||||
void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } | void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } | ||||
private: | private: | ||||
ComputeGraphPtr root_graph_ = nullptr; | ComputeGraphPtr root_graph_ = nullptr; | ||||
std::map<std::string, GeModelPtr> subgraph_instance_name_to_model_; | std::map<std::string, GeModelPtr> subgraph_instance_name_to_model_; | ||||
uint32_t model_id_ = 0; | uint32_t model_id_ = 0; | ||||
std::string model_name_; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; | using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; | ||||
@@ -150,9 +150,11 @@ TEST_F(UtestGeHybrid, index_taskdefs_failed) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ||||
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | ||||
ge_root_model->SetModelName("test_name"); | |||||
HybridModel hybrid_model(ge_root_model); | HybridModel hybrid_model(ge_root_model); | ||||
HybridModelBuilder hybrid_model_builder(hybrid_model); | HybridModelBuilder hybrid_model_builder(hybrid_model); | ||||
ASSERT_EQ(hybrid_model_builder.Build(), INTERNAL_ERROR); | |||||
ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR); | ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR); | ||||
} | } | ||||