From: @lichun30 Reviewed-by: @xchu42,@ji_chen Signed-off-by: @lbisdaddy tags/v1.2.0
@@ -135,6 +135,7 @@ class HybridModel { | |||
std::string model_name_; | |||
GeRootModelPtr ge_root_model_; | |||
std::map<uint32_t, NodeItem *> input_nodes_; | |||
ComputeGraphPtr root_graph_; | |||
std::map<std::string, NodePtr> device_variable_nodes_; //lint !e148 | |||
std::map<std::string, NodePtr> host_variable_nodes_; //lint !e148 | |||
std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_; | |||
@@ -136,12 +136,12 @@ Status HybridModelBuilder::Build() { | |||
GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[%s] Failed to RecoverGraphUnknownFlag", GetGraphName()); | |||
GE_CHK_STATUS_RET(IndexSpecialNodes(), "[%s] Failed to index nodes", GetGraphName()); | |||
GE_CHK_STATUS_RET(IndexTaskDefs(), "[%s] Failed to index task defs", GetGraphName()); | |||
GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName()); | |||
GE_CHK_STATUS_RET(LoadGraph(), "[%s] Failed to load graph", GetGraphName()); | |||
GE_CHK_STATUS_RET(AssignUninitializedConstantOps(), "[%s] Failed to assign uninitialized constants", GetGraphName()); | |||
GE_CHK_STATUS_RET(TransAllVarData(), "[%s] Failed to trans all var data", GetGraphName()); | |||
GE_CHK_STATUS_RET(CopyVarData(), "[%s] Failed to copy var data", GetGraphName()); | |||
GE_CHK_STATUS_RET(InitModelMem(), "[%s] Failed to init memory", GetGraphName()); | |||
GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName()); | |||
GE_CHK_STATUS_RET(InitConstantOps(), "[%s] Failed to init constant op", GetGraphName()); | |||
GE_CHK_STATUS_RET(InitVariableTensors(), "[%s] Failed to init variables", GetGraphName()); | |||
GE_CHK_STATUS_RET(LoadTasks(), "[%s] Failed to load tasks", GetGraphName()); | |||
@@ -599,9 +599,10 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { | |||
return SUCCESS; | |||
} | |||
Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) { | |||
Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) { | |||
merged_graph = MakeShared<ComputeGraph>("MergedGraph"); | |||
for (const auto &node : root_graph.GetDirectNode()) { | |||
merged_graph->SetGraphUnknownFlag(root_graph->GetGraphUnknownFlag()); | |||
for (const auto &node : root_graph->GetDirectNode()) { | |||
GE_CHECK_NOTNULL(node); | |||
auto op_desc = node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op_desc); | |||
@@ -631,7 +632,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap | |||
} | |||
} | |||
} | |||
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph), | |||
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph), | |||
"[%s] Failed to merge subgraph.", | |||
subgraph->GetName().c_str()); | |||
} | |||
@@ -647,18 +648,19 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap | |||
return a_level < b_level; | |||
}); | |||
for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) { | |||
for (auto &remained_subgraph : root_graph->GetAllSubgraphs()) { | |||
GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); | |||
GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph), | |||
"Failed to add subgraph [%s]", | |||
remained_subgraph->GetName().c_str()); | |||
remained_subgraph->SetParentGraph(merged_graph); | |||
} | |||
return SUCCESS; | |||
} | |||
Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, | |||
ComputeGraph &parent_graph, | |||
Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph, | |||
ComputeGraphPtr &parent_graph, | |||
ComputeGraph &sub_graph) { | |||
auto parent_node = sub_graph.GetParentNode(); | |||
GE_CHECK_NOTNULL(parent_node); | |||
@@ -687,15 +689,23 @@ Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, | |||
} | |||
} | |||
parent_graph.AddNode(sub_node); | |||
if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { | |||
for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) { | |||
auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i); | |||
GE_CHECK_NOTNULL(sub_sub_graph); | |||
sub_sub_graph->SetParentGraph(parent_graph); | |||
} | |||
} | |||
parent_graph->AddNode(sub_node); | |||
GELOGD("[%s::%s] added to parent graph: [%s].", | |||
sub_graph.GetName().c_str(), | |||
sub_node->GetName().c_str(), | |||
parent_graph.GetName().c_str()); | |||
parent_graph->GetName().c_str()); | |||
sub_node->SetOwnerComputeGraph(parent_graph); | |||
} | |||
GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str()); | |||
root_graph.RemoveSubgraph(sub_graph.GetName()); | |||
root_graph->RemoveSubgraph(sub_graph.GetName()); | |||
return SUCCESS; | |||
} | |||
@@ -747,14 +757,14 @@ Status HybridModelBuilder::LoadGraph() { | |||
GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", | |||
root_graph->GetDirectNodesSize(), | |||
root_graph->GetAllNodesSize()); | |||
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(*root_graph, merged_graph), "Failed to unfold subgraphs."); | |||
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), "Failed to unfold subgraphs."); | |||
root_graph = std::move(merged_graph); | |||
GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", | |||
root_graph->GetDirectNodesSize(), | |||
root_graph->GetAllNodesSize()); | |||
} | |||
root_graph_ = root_graph; | |||
hybrid_model_.root_graph_ = root_graph; | |||
// Reset node id by topological order across all subgraphs | |||
int64_t index = 0; | |||
for (const auto &node : root_graph->GetAllNodes()) { | |||
@@ -1030,9 +1040,13 @@ Status HybridModelBuilder::InitWeights() { | |||
GELOGI("Init weight mem successfully, weight base %p, weight size = %zu", | |||
weight_base, | |||
sub_weight_buffer->GetSize()); | |||
auto root_graph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); | |||
hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(),std::move(sub_weight_buffer)); | |||
for (auto &node : root_graph->GetDirectNode()) { | |||
auto subgraph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); | |||
if (subgraph != ge_root_model_->GetRootGraph()) { | |||
subgraph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first); | |||
} | |||
GE_CHECK_NOTNULL(subgraph); | |||
hybrid_model_.weight_buffer_map_.emplace(subgraph->GetName(), std::move(sub_weight_buffer)); | |||
for (auto &node : subgraph->GetDirectNode()) { | |||
if (node->GetType() != CONSTANT) { | |||
continue; | |||
} | |||
@@ -2044,7 +2058,7 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { | |||
GELOGD("[%s] Start to get parallel group from subgraph: %s", | |||
node_item->NodeName().c_str(), | |||
subgraph_name.c_str()); | |||
auto subgraph = root_graph_->GetSubgraph(subgraph_name); | |||
auto subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_name); | |||
GE_CHECK_NOTNULL(subgraph); | |||
for (const auto &sub_node : subgraph->GetAllNodes()) { | |||
std::string parallel_group; | |||
@@ -47,8 +47,8 @@ class HybridModelBuilder { | |||
static Status HandleDtString(const GeTensor &tensor, void *var_addr); | |||
static Status MergeInputNodes(ComputeGraph &compute_graph); | |||
static Status MergeNetOutputNode(ComputeGraph &compute_graph); | |||
static Status UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph); | |||
static Status UnfoldSubgraph(ComputeGraph &root_graph, ComputeGraph &parent_graph, ComputeGraph &sub_graph); | |||
static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph); | |||
static Status UnfoldSubgraph(ComputeGraphPtr &root_graph, ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph); | |||
static Status BuildInputMapping(GraphItem &graph_item, | |||
std::vector<NodeItem *> &data_nodes, | |||
bool is_root_graph); | |||
@@ -100,7 +100,6 @@ class HybridModelBuilder { | |||
NodeItem *MutableNodeItem(const NodePtr &node); | |||
GeRootModelPtr ge_root_model_; | |||
ComputeGraphPtr root_graph_; | |||
std::map<std::string, GeModelPtr> subgraph_models_; | |||
std::map<std::string, NodePtr> constant_op_nodes_; | |||
std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_; | |||
@@ -256,3 +256,77 @@ TEST_F(UtestGeHybrid, init_weight_success) { | |||
HybridModelExecutor executor(model_ptr, device_id, stream); | |||
executor.Init(); | |||
} | |||
// Exercises HybridModelBuilder::UnfoldSubgraphs on a three-level graph:
//
//   root_graph  { partitioned_call }
//   sub_graph   { while }          (parent node: partitioned_call)
//   while_cond  { cond_data }      (parent node: while)
//   while_body  { body_data }      (parent node: while)
//
// All three subgraphs are registered on root_graph. After unfolding, the test
// expects the "while" node to be a direct node of the merged graph, and both
// while subgraphs to be attached to (and parented by) the merged graph.
TEST_F(UtestGeHybrid, unfold_subgraphs_success) {
  ComputeGraphPtr merged_graph = nullptr;

  // while_cond: a single Data node.
  ComputeGraphPtr sub_sub_graph1 = std::make_shared<ComputeGraph>("while_cond");
  OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA);
  NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc);
  ASSERT_NE(sub_sub_graph_while_cond_data_node, nullptr);

  // while_body: a single Data node, flagged as unknown-shape.
  ComputeGraphPtr sub_sub_graph2 = std::make_shared<ComputeGraph>("while_body");
  OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA);
  NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc);
  ASSERT_NE(sub_sub_graph_while_body_data_node, nullptr);
  sub_sub_graph2->SetGraphUnknownFlag(true);

  // sub_graph: holds the While node that references while_cond / while_body.
  ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
  OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE);
  NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc);
  ASSERT_NE(sub_graph_while_node, nullptr);
  sub_graph->SetGraphUnknownFlag(true);
  sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond");
  sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body");
  sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond");
  sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body");

  // root_graph: holds the PartitionedCall node that references sub_graph.
  ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("root_graph");
  auto partitioned_call_op_desc = MakeShared<OpDesc>("partitioned_call", PARTITIONEDCALL);
  auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc);
  ASSERT_NE(partitioned_call_node, nullptr);
  partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph");
  partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph");

  // Register every subgraph on the root graph and wire parent graph / parent node.
  root_graph->AddSubGraph(sub_sub_graph1);
  root_graph->AddSubGraph(sub_sub_graph2);
  sub_sub_graph1->SetParentGraph(root_graph);
  sub_sub_graph2->SetParentGraph(root_graph);
  sub_sub_graph1->SetParentNode(sub_graph_while_node);
  sub_sub_graph2->SetParentNode(sub_graph_while_node);

  root_graph->AddSubGraph(sub_graph);
  sub_graph->SetParentNode(partitioned_call_node);
  sub_graph->SetParentGraph(root_graph);

  GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(root_graph);
  HybridModel hybrid_model(root_model);
  HybridModelBuilder hybrid_model_builder(hybrid_model);

  // subgraph num before unfold: 3 (while_cond, while_body, sub_graph)
  EXPECT_EQ(root_graph->GetAllSubgraphs().size(), 3U);
  // num of nodes in root_graph before unfold: 1, name: partitioned_call
  // (fatal assertion: the .at(0) access below relies on it)
  ASSERT_EQ(root_graph->GetDirectNodesSize(), 1U);
  EXPECT_EQ(root_graph->GetDirectNode().at(0)->GetName(), "partitioned_call");
  // while_cond & while_body both have "root_graph" as parent graph before unfold
  EXPECT_EQ(sub_sub_graph1->GetParentGraph()->GetName(), "root_graph");
  EXPECT_EQ(sub_sub_graph2->GetParentGraph()->GetName(), "root_graph");
  // node "while" is owned by compute graph "sub_graph" before unfold
  EXPECT_EQ(sub_graph_while_node->GetOwnerComputeGraph()->GetName(), "sub_graph");

  // unfold success (fatal: everything below dereferences merged_graph)
  ASSERT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS);
  ASSERT_NE(merged_graph, nullptr);

  // subgraph num after unfold: 2 (while_cond, while_body); sub_graph has been merged
  EXPECT_EQ(merged_graph->GetAllSubgraphs().size(), 2U);
  // num of nodes in MergedGraph after unfold: 1, name: while
  ASSERT_EQ(merged_graph->GetDirectNodesSize(), 1U);
  EXPECT_EQ(merged_graph->GetDirectNode().at(0)->GetName(), "while");
  // while_cond & while_body both have "MergedGraph" as parent graph after unfold
  EXPECT_EQ(sub_sub_graph1->GetParentGraph()->GetName(), "MergedGraph");
  EXPECT_EQ(sub_sub_graph2->GetParentGraph()->GetName(), "MergedGraph");
  // node "while" is owned by compute graph "MergedGraph" after unfold
  EXPECT_EQ(sub_graph_while_node->GetOwnerComputeGraph()->GetName(), "MergedGraph");
}