|
|
@@ -136,12 +136,12 @@ Status HybridModelBuilder::Build() { |
|
|
|
GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[%s] Failed to RecoverGraphUnknownFlag", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(IndexSpecialNodes(), "[%s] Failed to index nodes", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(IndexTaskDefs(), "[%s] Failed to index task defs", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(LoadGraph(), "[%s] Failed to load graph", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(AssignUninitializedConstantOps(), "[%s] Failed to assign uninitialized constants", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(TransAllVarData(), "[%s] Failed to trans all var data", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(CopyVarData(), "[%s] Failed to copy var data", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(InitModelMem(), "[%s] Failed to init memory", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(InitConstantOps(), "[%s] Failed to init constant op", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(InitVariableTensors(), "[%s] Failed to init variables", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(LoadTasks(), "[%s] Failed to load tasks", GetGraphName()); |
|
|
@@ -599,9 +599,10 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) { |
|
|
|
Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) { |
|
|
|
merged_graph = MakeShared<ComputeGraph>("MergedGraph"); |
|
|
|
for (const auto &node : root_graph.GetDirectNode()) { |
|
|
|
merged_graph->SetGraphUnknownFlag(root_graph->GetGraphUnknownFlag()); |
|
|
|
for (const auto &node : root_graph->GetDirectNode()) { |
|
|
|
GE_CHECK_NOTNULL(node); |
|
|
|
auto op_desc = node->GetOpDesc(); |
|
|
|
GE_CHECK_NOTNULL(op_desc); |
|
|
@@ -631,7 +632,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph), |
|
|
|
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph), |
|
|
|
"[%s] Failed to merge subgraph.", |
|
|
|
subgraph->GetName().c_str()); |
|
|
|
} |
|
|
@@ -647,18 +648,19 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap |
|
|
|
return a_level < b_level; |
|
|
|
}); |
|
|
|
|
|
|
|
for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) { |
|
|
|
for (auto &remained_subgraph : root_graph->GetAllSubgraphs()) { |
|
|
|
GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); |
|
|
|
GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph), |
|
|
|
"Failed to add subgraph [%s]", |
|
|
|
remained_subgraph->GetName().c_str()); |
|
|
|
remained_subgraph->SetParentGraph(merged_graph); |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, |
|
|
|
ComputeGraph &parent_graph, |
|
|
|
Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph, |
|
|
|
ComputeGraphPtr &parent_graph, |
|
|
|
ComputeGraph &sub_graph) { |
|
|
|
auto parent_node = sub_graph.GetParentNode(); |
|
|
|
GE_CHECK_NOTNULL(parent_node); |
|
|
@@ -687,15 +689,23 @@ Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
parent_graph.AddNode(sub_node); |
|
|
|
if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { |
|
|
|
for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) { |
|
|
|
auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i); |
|
|
|
GE_CHECK_NOTNULL(sub_sub_graph); |
|
|
|
sub_sub_graph->SetParentGraph(parent_graph); |
|
|
|
} |
|
|
|
} |
|
|
|
parent_graph->AddNode(sub_node); |
|
|
|
GELOGD("[%s::%s] added to parent graph: [%s].", |
|
|
|
sub_graph.GetName().c_str(), |
|
|
|
sub_node->GetName().c_str(), |
|
|
|
parent_graph.GetName().c_str()); |
|
|
|
parent_graph->GetName().c_str()); |
|
|
|
sub_node->SetOwnerComputeGraph(parent_graph); |
|
|
|
} |
|
|
|
|
|
|
|
GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str()); |
|
|
|
root_graph.RemoveSubgraph(sub_graph.GetName()); |
|
|
|
root_graph->RemoveSubgraph(sub_graph.GetName()); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
@@ -747,14 +757,14 @@ Status HybridModelBuilder::LoadGraph() { |
|
|
|
GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", |
|
|
|
root_graph->GetDirectNodesSize(), |
|
|
|
root_graph->GetAllNodesSize()); |
|
|
|
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(*root_graph, merged_graph), "Failed to unfold subgraphs."); |
|
|
|
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), "Failed to unfold subgraphs."); |
|
|
|
root_graph = std::move(merged_graph); |
|
|
|
GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", |
|
|
|
root_graph->GetDirectNodesSize(), |
|
|
|
root_graph->GetAllNodesSize()); |
|
|
|
} |
|
|
|
|
|
|
|
root_graph_ = root_graph; |
|
|
|
hybrid_model_.root_graph_ = root_graph; |
|
|
|
// Reset node id by topological order across all subgraphs |
|
|
|
int64_t index = 0; |
|
|
|
for (const auto &node : root_graph->GetAllNodes()) { |
|
|
@@ -1030,9 +1040,13 @@ Status HybridModelBuilder::InitWeights() { |
|
|
|
GELOGI("Init weight mem successfully, weight base %p, weight size = %zu", |
|
|
|
weight_base, |
|
|
|
sub_weight_buffer->GetSize()); |
|
|
|
auto root_graph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); |
|
|
|
hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(),std::move(sub_weight_buffer)); |
|
|
|
for (auto &node : root_graph->GetDirectNode()) { |
|
|
|
auto subgraph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); |
|
|
|
if (subgraph != ge_root_model_->GetRootGraph()) { |
|
|
|
subgraph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first); |
|
|
|
} |
|
|
|
GE_CHECK_NOTNULL(subgraph); |
|
|
|
hybrid_model_.weight_buffer_map_.emplace(subgraph->GetName(), std::move(sub_weight_buffer)); |
|
|
|
for (auto &node : subgraph->GetDirectNode()) { |
|
|
|
if (node->GetType() != CONSTANT) { |
|
|
|
continue; |
|
|
|
} |
|
|
@@ -2044,7 +2058,7 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { |
|
|
|
GELOGD("[%s] Start to get parallel group from subgraph: %s", |
|
|
|
node_item->NodeName().c_str(), |
|
|
|
subgraph_name.c_str()); |
|
|
|
auto subgraph = root_graph_->GetSubgraph(subgraph_name); |
|
|
|
auto subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_name); |
|
|
|
GE_CHECK_NOTNULL(subgraph); |
|
|
|
for (const auto &sub_node : subgraph->GetAllNodes()) { |
|
|
|
std::string parallel_group; |
|
|
|