Browse Source

support unknown while subgraph

tags/v1.2.0
lichun 4 years ago
parent
commit
12cef9e9b9
4 changed files with 83 additions and 19 deletions
  1. +1
    -0
      ge/hybrid/model/hybrid_model.h
  2. +30
    -16
      ge/hybrid/model/hybrid_model_builder.cc
  3. +2
    -3
      ge/hybrid/model/hybrid_model_builder.h
  4. +50
    -0
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 1
- 0
ge/hybrid/model/hybrid_model.h View File

@@ -135,6 +135,7 @@ class HybridModel {
std::string model_name_; std::string model_name_;
GeRootModelPtr ge_root_model_; GeRootModelPtr ge_root_model_;
std::map<uint32_t, NodeItem *> input_nodes_; std::map<uint32_t, NodeItem *> input_nodes_;
ComputeGraphPtr root_graph_;
std::map<std::string, NodePtr> device_variable_nodes_; //lint !e148 std::map<std::string, NodePtr> device_variable_nodes_; //lint !e148
std::map<std::string, NodePtr> host_variable_nodes_; //lint !e148 std::map<std::string, NodePtr> host_variable_nodes_; //lint !e148
std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_; std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_;


+ 30
- 16
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -136,12 +136,12 @@ Status HybridModelBuilder::Build() {
GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[%s] Failed to RecoverGraphUnknownFlag", GetGraphName()); GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[%s] Failed to RecoverGraphUnknownFlag", GetGraphName());
GE_CHK_STATUS_RET(IndexSpecialNodes(), "[%s] Failed to index nodes", GetGraphName()); GE_CHK_STATUS_RET(IndexSpecialNodes(), "[%s] Failed to index nodes", GetGraphName());
GE_CHK_STATUS_RET(IndexTaskDefs(), "[%s] Failed to index task defs", GetGraphName()); GE_CHK_STATUS_RET(IndexTaskDefs(), "[%s] Failed to index task defs", GetGraphName());
GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName());
GE_CHK_STATUS_RET(LoadGraph(), "[%s] Failed to load graph", GetGraphName()); GE_CHK_STATUS_RET(LoadGraph(), "[%s] Failed to load graph", GetGraphName());
GE_CHK_STATUS_RET(AssignUninitializedConstantOps(), "[%s] Failed to assign uninitialized constants", GetGraphName()); GE_CHK_STATUS_RET(AssignUninitializedConstantOps(), "[%s] Failed to assign uninitialized constants", GetGraphName());
GE_CHK_STATUS_RET(TransAllVarData(), "[%s] Failed to trans all var data", GetGraphName()); GE_CHK_STATUS_RET(TransAllVarData(), "[%s] Failed to trans all var data", GetGraphName());
GE_CHK_STATUS_RET(CopyVarData(), "[%s] Failed to copy var data", GetGraphName()); GE_CHK_STATUS_RET(CopyVarData(), "[%s] Failed to copy var data", GetGraphName());
GE_CHK_STATUS_RET(InitModelMem(), "[%s] Failed to init memory", GetGraphName()); GE_CHK_STATUS_RET(InitModelMem(), "[%s] Failed to init memory", GetGraphName());
GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName());
GE_CHK_STATUS_RET(InitConstantOps(), "[%s] Failed to init constant op", GetGraphName()); GE_CHK_STATUS_RET(InitConstantOps(), "[%s] Failed to init constant op", GetGraphName());
GE_CHK_STATUS_RET(InitVariableTensors(), "[%s] Failed to init variables", GetGraphName()); GE_CHK_STATUS_RET(InitVariableTensors(), "[%s] Failed to init variables", GetGraphName());
GE_CHK_STATUS_RET(LoadTasks(), "[%s] Failed to load tasks", GetGraphName()); GE_CHK_STATUS_RET(LoadTasks(), "[%s] Failed to load tasks", GetGraphName());
@@ -599,9 +599,10 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) {
return SUCCESS; return SUCCESS;
} }


Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) {
Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) {
merged_graph = MakeShared<ComputeGraph>("MergedGraph"); merged_graph = MakeShared<ComputeGraph>("MergedGraph");
for (const auto &node : root_graph.GetDirectNode()) {
merged_graph->SetGraphUnknownFlag(root_graph->GetGraphUnknownFlag());
for (const auto &node : root_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
auto op_desc = node->GetOpDesc(); auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc); GE_CHECK_NOTNULL(op_desc);
@@ -631,7 +632,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap
} }
} }
} }
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph),
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph),
"[%s] Failed to merge subgraph.", "[%s] Failed to merge subgraph.",
subgraph->GetName().c_str()); subgraph->GetName().c_str());
} }
@@ -647,18 +648,19 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap
return a_level < b_level; return a_level < b_level;
}); });


for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) {
for (auto &remained_subgraph : root_graph->GetAllSubgraphs()) {
GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str());
GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph), GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph),
"Failed to add subgraph [%s]", "Failed to add subgraph [%s]",
remained_subgraph->GetName().c_str()); remained_subgraph->GetName().c_str());
remained_subgraph->SetParentGraph(merged_graph);
} }


return SUCCESS; return SUCCESS;
} }


Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph,
ComputeGraph &parent_graph,
Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph,
ComputeGraphPtr &parent_graph,
ComputeGraph &sub_graph) { ComputeGraph &sub_graph) {
auto parent_node = sub_graph.GetParentNode(); auto parent_node = sub_graph.GetParentNode();
GE_CHECK_NOTNULL(parent_node); GE_CHECK_NOTNULL(parent_node);
@@ -687,15 +689,23 @@ Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph,
} }
} }


parent_graph.AddNode(sub_node);
if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) {
for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) {
auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i);
GE_CHECK_NOTNULL(sub_sub_graph);
sub_sub_graph->SetParentGraph(parent_graph);
}
}
parent_graph->AddNode(sub_node);
GELOGD("[%s::%s] added to parent graph: [%s].", GELOGD("[%s::%s] added to parent graph: [%s].",
sub_graph.GetName().c_str(), sub_graph.GetName().c_str(),
sub_node->GetName().c_str(), sub_node->GetName().c_str(),
parent_graph.GetName().c_str());
parent_graph->GetName().c_str());
sub_node->SetOwnerComputeGraph(parent_graph);
} }


GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str()); GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str());
root_graph.RemoveSubgraph(sub_graph.GetName());
root_graph->RemoveSubgraph(sub_graph.GetName());
return SUCCESS; return SUCCESS;
} }


@@ -747,14 +757,14 @@ Status HybridModelBuilder::LoadGraph() {
GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu",
root_graph->GetDirectNodesSize(), root_graph->GetDirectNodesSize(),
root_graph->GetAllNodesSize()); root_graph->GetAllNodesSize());
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(*root_graph, merged_graph), "Failed to unfold subgraphs.");
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), "Failed to unfold subgraphs.");
root_graph = std::move(merged_graph); root_graph = std::move(merged_graph);
GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu",
root_graph->GetDirectNodesSize(), root_graph->GetDirectNodesSize(),
root_graph->GetAllNodesSize()); root_graph->GetAllNodesSize());
} }


root_graph_ = root_graph;
hybrid_model_.root_graph_ = root_graph;
// Reset node id by topological order across all subgraphs // Reset node id by topological order across all subgraphs
int64_t index = 0; int64_t index = 0;
for (const auto &node : root_graph->GetAllNodes()) { for (const auto &node : root_graph->GetAllNodes()) {
@@ -1030,9 +1040,13 @@ Status HybridModelBuilder::InitWeights() {
GELOGI("Init weight mem successfully, weight base %p, weight size = %zu", GELOGI("Init weight mem successfully, weight base %p, weight size = %zu",
weight_base, weight_base,
sub_weight_buffer->GetSize()); sub_weight_buffer->GetSize());
auto root_graph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph());
hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(),std::move(sub_weight_buffer));
for (auto &node : root_graph->GetDirectNode()) {
auto subgraph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph());
if (subgraph != ge_root_model_->GetRootGraph()) {
subgraph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first);
}
GE_CHECK_NOTNULL(subgraph);
hybrid_model_.weight_buffer_map_.emplace(subgraph->GetName(), std::move(sub_weight_buffer));
for (auto &node : subgraph->GetDirectNode()) {
if (node->GetType() != CONSTANT) { if (node->GetType() != CONSTANT) {
continue; continue;
} }
@@ -2044,7 +2058,7 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) {
GELOGD("[%s] Start to get parallel group from subgraph: %s", GELOGD("[%s] Start to get parallel group from subgraph: %s",
node_item->NodeName().c_str(), node_item->NodeName().c_str(),
subgraph_name.c_str()); subgraph_name.c_str());
auto subgraph = root_graph_->GetSubgraph(subgraph_name);
auto subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_name);
GE_CHECK_NOTNULL(subgraph); GE_CHECK_NOTNULL(subgraph);
for (const auto &sub_node : subgraph->GetAllNodes()) { for (const auto &sub_node : subgraph->GetAllNodes()) {
std::string parallel_group; std::string parallel_group;


+ 2
- 3
ge/hybrid/model/hybrid_model_builder.h View File

@@ -47,8 +47,8 @@ class HybridModelBuilder {
static Status HandleDtString(const GeTensor &tensor, void *var_addr); static Status HandleDtString(const GeTensor &tensor, void *var_addr);
static Status MergeInputNodes(ComputeGraph &compute_graph); static Status MergeInputNodes(ComputeGraph &compute_graph);
static Status MergeNetOutputNode(ComputeGraph &compute_graph); static Status MergeNetOutputNode(ComputeGraph &compute_graph);
static Status UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph);
static Status UnfoldSubgraph(ComputeGraph &root_graph, ComputeGraph &parent_graph, ComputeGraph &sub_graph);
static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph);
static Status UnfoldSubgraph(ComputeGraphPtr &root_graph, ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph);
static Status BuildInputMapping(GraphItem &graph_item, static Status BuildInputMapping(GraphItem &graph_item,
std::vector<NodeItem *> &data_nodes, std::vector<NodeItem *> &data_nodes,
bool is_root_graph); bool is_root_graph);
@@ -100,7 +100,6 @@ class HybridModelBuilder {
NodeItem *MutableNodeItem(const NodePtr &node); NodeItem *MutableNodeItem(const NodePtr &node);


GeRootModelPtr ge_root_model_; GeRootModelPtr ge_root_model_;
ComputeGraphPtr root_graph_;
std::map<std::string, GeModelPtr> subgraph_models_; std::map<std::string, GeModelPtr> subgraph_models_;
std::map<std::string, NodePtr> constant_op_nodes_; std::map<std::string, NodePtr> constant_op_nodes_;
std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_; std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_;


+ 50
- 0
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -256,3 +256,53 @@ TEST_F(UtestGeHybrid, init_weight_success) {
HybridModelExecutor executor(model_ptr, device_id, stream); HybridModelExecutor executor(model_ptr, device_id, stream);
executor.Init(); executor.Init();
} }

TEST_F(UtestGeHybrid, unfold_subgraphs_success) {
ComputeGraphPtr merged_graph = nullptr;

ComputeGraphPtr sub_sub_graph1 = std::make_shared<ComputeGraph>("while_cond");
OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA);
NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc);

ComputeGraphPtr sub_sub_graph2 = std::make_shared<ComputeGraph>("while_body");
/*OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT);
NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/
OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA);
NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc);
sub_sub_graph2->SetGraphUnknownFlag(true);
/*OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD);
NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node);
sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node);
sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node);*/

ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE);
NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc);
sub_graph->SetGraphUnknownFlag(true);
sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond");
sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body");
sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond");
sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body");

ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("root_graph");
auto partitioned_call_op_desc = MakeShared<OpDesc>("partitioned_call", PARTITIONEDCALL);
auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc);
partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph");
partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph");

root_graph->AddSubGraph(sub_sub_graph1);
root_graph->AddSubGraph(sub_sub_graph2);
sub_sub_graph1->SetParentGraph(root_graph);
sub_sub_graph2->SetParentGraph(root_graph);
sub_sub_graph1->SetParentNode(sub_graph_while_node);
sub_sub_graph2->SetParentNode(sub_graph_while_node);

root_graph->AddSubGraph(sub_graph);
sub_graph->SetParentNode(partitioned_call_node);
sub_graph->SetParentGraph(root_graph);

GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(root_graph);
HybridModel hybrid_model(root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS);
}

Loading…
Cancel
Save