From 2cf49ced1c080e8792302e539cb5603dbe708ee6 Mon Sep 17 00:00:00 2001 From: lianghao Date: Sat, 27 Mar 2021 14:41:23 +0800 Subject: [PATCH] online_inference c77 --- ge/graph/passes/attach_stream_label_pass.cc | 1 - ge/graph/passes/pass_utils.cc | 8 +- ge/graph/passes/pass_utils.h | 2 + ge/graph/passes/subexpression_migration_pass.cc | 2 +- ge/graph/passes/switch_dead_branch_elimination.cc | 10 +- ge/graph/passes/switch_to_stream_switch_pass.cc | 2 + metadef | 2 +- tests/ut/ge/CMakeLists.txt | 1 + .../switch_dead_branch_elimination_unittest.cc | 163 +++++++++++++++++++++ 9 files changed, 185 insertions(+), 6 deletions(-) create mode 100644 tests/ut/ge/graph/passes/switch_dead_branch_elimination_unittest.cc diff --git a/ge/graph/passes/attach_stream_label_pass.cc b/ge/graph/passes/attach_stream_label_pass.cc index cd3509c7..4927e3aa 100644 --- a/ge/graph/passes/attach_stream_label_pass.cc +++ b/ge/graph/passes/attach_stream_label_pass.cc @@ -137,7 +137,6 @@ Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &strea return INTERNAL_ERROR; } stream_label = node->GetInDataNodes().at(0)->GetName(); - GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); bool value = false; OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); diff --git a/ge/graph/passes/pass_utils.cc b/ge/graph/passes/pass_utils.cc index 3adfbde3..b827e88a 100644 --- a/ge/graph/passes/pass_utils.cc +++ b/ge/graph/passes/pass_utils.cc @@ -35,9 +35,9 @@ #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" +#include "utils/node_utils.h" namespace ge { - Status PassUtils::ConstructTensorDescWithData(const GeTensorDesc &out_desc, std::vector &data, std::vector &v_output, const bool scalar_output) { Status ret = SUCCESS; @@ -246,6 +246,12 @@ NodePtr PassUtils::GetInDataNode(const ConstNodePtr &node, int index) { return src_node; } +NodePtr PassUtils::GetInNodeCrossSubgraphByIndex(const ConstNodePtr &node, int index) { + auto src_node = GetInDataNode(node, index); + + return NodeUtils::GetInNodeCrossSubgraph(src_node); +} + bool PassUtils::IsNeedTrainIteFlowCtrl(const ComputeGraphPtr &compute_graph) { if (compute_graph == nullptr) { return false; diff --git a/ge/graph/passes/pass_utils.h b/ge/graph/passes/pass_utils.h index fbfb3b47..bd506d09 100755 --- a/ge/graph/passes/pass_utils.h +++ b/ge/graph/passes/pass_utils.h @@ -30,6 +30,8 @@ class PassUtils { static NodePtr GetInDataNode(const ConstNodePtr &node, int index); + static NodePtr GetInNodeCrossSubgraphByIndex(const ConstNodePtr &node, int index); + static bool IsConstant(const ConstNodePtr &node); static Status SetOutNodeWeight(const OutDataAnchorPtr &out_data_anchor, const NodePtr &src_node); diff --git a/ge/graph/passes/subexpression_migration_pass.cc b/ge/graph/passes/subexpression_migration_pass.cc index dc4d2185..05b7baa1 100755 --- a/ge/graph/passes/subexpression_migration_pass.cc +++ b/ge/graph/passes/subexpression_migration_pass.cc @@ -279,7 +279,7 @@ Status SubexpressionMigrationPass::GraphNodeMigration(const ComputeGraphPtr &gra const auto &in_anchor = in_anchors.at(i); const auto &base_node = in_anchor->GetOwnerNode(); GELOGD("Get Data direct node: %s", base_node->GetName().c_str()); - if (!base_node->GetHostNode()) { + if (!base_node->GetHostNode() || base_node->GetType() == SWITCH) { continue; } diff --git a/ge/graph/passes/switch_dead_branch_elimination.cc b/ge/graph/passes/switch_dead_branch_elimination.cc index 70105aea..20598f17 100644 --- a/ge/graph/passes/switch_dead_branch_elimination.cc +++ b/ge/graph/passes/switch_dead_branch_elimination.cc @@ -94,6 +94,12 @@ Status SwitchDeadBranchElimination::DeleteSwitchNode(NodePtr &node, NodePtr &pre GELOGE(FAILED, "parameter is null."); return FAILED; } + + // If two nodes aren't in same graph, get node's direct in_node instead of pred_node. + if (node->GetOwnerComputeGraph() != pred_node->GetOwnerComputeGraph()) { + pred_node = PassUtils::GetInDataNode(node, kPredInputIndex); + } + // link pred's in control nodes to switch if (GraphUtils::CopyInCtrlEdges(pred_node, node) != GRAPH_SUCCESS) { return FAILED; @@ -131,7 +137,7 @@ Status SwitchDeadBranchElimination::Run(NodePtr &node) { return SUCCESS; } - auto pred_node = PassUtils::GetInDataNode(node, kPredInputIndex); + auto pred_node = PassUtils::GetInNodeCrossSubgraphByIndex(node, kPredInputIndex); if (pred_node == nullptr) { GELOGD("[%s] Pred input is null.", node->GetName().c_str()); return SUCCESS; @@ -143,7 +149,7 @@ Status SwitchDeadBranchElimination::Run(NodePtr &node) { return SUCCESS; } - auto input_node = PassUtils::GetInDataNode(node, kDataInputIndex); + auto input_node = PassUtils::GetInNodeCrossSubgraphByIndex(node, kDataInputIndex); if (input_node == nullptr) { GELOGD("[%s] Data input is null.", node->GetName().c_str()); return SUCCESS; diff --git a/ge/graph/passes/switch_to_stream_switch_pass.cc b/ge/graph/passes/switch_to_stream_switch_pass.cc index 392968e7..d7fa8844 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.cc +++ b/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -448,6 +448,8 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) // select first stream_switch NodePtr stream_switch = switch_list.front(); + // set stream_label + GE_CHK_STATUS_RET(SetStreamLabel(stream_switch, cast_node->GetName()), "Set stream label failed."); OpDescPtr switch_desc = stream_switch->GetOpDesc(); GE_CHECK_NOTNULL(switch_desc); switch_desc->SetName(CheckDuplicateName(cond_group + "/" + STREAMSWITCH + (true_branch_flag ? "_t" : "_f"))); diff --git a/metadef b/metadef index f0dd9337..0c4602a4 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit f0dd933702e5224399e757e1b6174e49eb4e71fa +Subproject commit 0c4602a4615a9368b06633a5087e2114518f29ca diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 80636a20..12c33db7 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -667,6 +667,7 @@ set(PASS_TEST_FILES "graph/passes/merge_pass_unittest.cc" #"graph/passes/switch_pass_unittest.cc" "graph/passes/switch_logic_remove_pass_unittest.cc" + "graph/passes/switch_dead_branch_elimination_unittest.cc" "graph/passes/assert_pass_unittest.cc" "graph/passes/dropout_pass_unittest.cc" "graph/passes/unused_const_pass_unittest.cc" diff --git a/tests/ut/ge/graph/passes/switch_dead_branch_elimination_unittest.cc b/tests/ut/ge/graph/passes/switch_dead_branch_elimination_unittest.cc new file mode 100644 index 00000000..c3f21251 --- /dev/null +++ b/tests/ut/ge/graph/passes/switch_dead_branch_elimination_unittest.cc @@ -0,0 +1,163 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "common/ge_inner_error_codes.h" +#include "graph/passes/switch_dead_branch_elimination.h" +#include "graph_builder_utils.h" + +namespace ge { +class UtestSwitchDeadBranchElimination : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +namespace { +/* + * data1 const1 + * \ / + * case1 + * | + * relu1 + * | + * netoutput + */ +ut::GraphBuilder ParentGraphBuilder() { + ut::GraphBuilder builder = ut::GraphBuilder("g1"); + auto data1 = builder.AddNode("data1", "Data", 0, 1); + auto const1 = builder.AddNode("const1", "Const", 0, 1); + auto case1 = builder.AddNode("case1", CASE, 2, 1); + auto relu1 = builder.AddNode("relu1", "Relu", 1, 1); + auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); + + int32_t weight[1] = {1}; + GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32); + GeTensorPtr tensor = std::make_shared(weight_desc, (uint8_t *)weight, sizeof(weight)); + OpDescUtils::SetWeights(const1, {tensor}); + + builder.AddDataEdge(data1, 0, case1, 0); + builder.AddDataEdge(const1, 0, case1, 1); + builder.AddDataEdge(case1, 0, relu1, 0); + builder.AddDataEdge(relu1, 0, netoutput, 0); + return builder; +} + +/* + * data1 data2 + * \ / + * switch + * / \ + * relu1 relu2 + * \ / + * merge + * | + * netoutput + */ +ut::GraphBuilder SwitchSubgraphBuilder(string graph_name, uint32_t num) { + ut::GraphBuilder builder = ut::GraphBuilder(graph_name); + + string data1_name = "data1_" + std::to_string(num); + auto data1 = builder.AddNode(data1_name, "Data", 0, 1); + auto data1_desc = data1->GetOpDesc(); + EXPECT_NE(data1_desc, nullptr); + AttrUtils::SetInt(data1_desc, "_parent_node_index", 0); + + string data2_name = "data2_" + std::to_string(num); + auto data2 = builder.AddNode(data2_name, "Data", 0, 1); + auto data2_desc = data2->GetOpDesc(); + EXPECT_NE(data2_desc, nullptr); + AttrUtils::SetInt(data2_desc, "_parent_node_index", 1); + + string switch_name = "switch_" + std::to_string(num); + auto switch1 = builder.AddNode(switch_name, "Switch", 2, 2); + + string relu1_name = "relu1_" + std::to_string(num); + auto relu1 = builder.AddNode(relu1_name, "Relu", 1, 1); + + string relu2_name = "relu2_" + std::to_string(num); + auto relu2 = builder.AddNode(relu2_name, "Relu", 1, 1); + + string merge_name = "merge_" + std::to_string(num); + auto merge = builder.AddNode(merge_name, "Merge", 2, 1); + + string output_name = "output_" + std::to_string(num); + auto netoutput = builder.AddNode(output_name, NETOUTPUT, 1, 0); + + builder.AddDataEdge(data1, 0, switch1, 0); + builder.AddDataEdge(data2, 0, switch1, 1); + builder.AddDataEdge(switch1, 0, relu1, 0); + builder.AddDataEdge(switch1, 1, relu2, 0); + builder.AddDataEdge(relu1, 0, merge, 0); + builder.AddDataEdge(relu2, 0, merge, 1); + builder.AddDataEdge(merge, 0, netoutput, 0); + + return builder; +} + +void AddCaseSubgraph(ComputeGraphPtr &parent_graph, uint32_t branch_num) { + auto case_node = parent_graph->FindNode("case1"); + EXPECT_NE(case_node, nullptr); + + for (uint32_t i = 0; i < branch_num; ++i) { + string name = "Branch_Graph_" + std::to_string(i); + + auto builder_subgraph = SwitchSubgraphBuilder(name, i); + auto switch_subgraph = builder_subgraph.GetGraph(); + + case_node->GetOpDesc()->AddSubgraphName(switch_subgraph->GetName()); + case_node->GetOpDesc()->SetSubgraphInstanceName(i, switch_subgraph->GetName()); + + switch_subgraph->SetParentNode(case_node); + switch_subgraph->SetParentGraph(parent_graph); + EXPECT_EQ(parent_graph->AddSubgraph(switch_subgraph->GetName(), switch_subgraph), GRAPH_SUCCESS); + } +} +} // namespace + + +TEST_F(UtestSwitchDeadBranchElimination, switch_dead_branch_elimination_across_case_success) { + auto builder = ParentGraphBuilder(); + auto parent_graph = builder.GetGraph(); + + AddCaseSubgraph(parent_graph, 2); + auto subgraphs = parent_graph->GetAllSubgraphs(); + EXPECT_EQ(subgraphs.size(), 2); + + SwitchDeadBranchElimination switch_pass; + for (auto &subgraph : subgraphs) { + auto switch_node = subgraph->FindFirstNodeMatchType("Switch"); + if (switch_node != nullptr) { + EXPECT_EQ(switch_pass.Run(switch_node), SUCCESS); + } + } + + auto all_nodes = parent_graph->GetAllNodes(); + EXPECT_EQ(all_nodes.size(), 17); + + for (auto &subgraph : subgraphs) { + EXPECT_EQ(subgraph->GetDirectNode().size(), 6); + EXPECT_EQ(subgraph->FindFirstNodeMatchType("Switch"), nullptr); + auto merge_node = subgraph->FindFirstNodeMatchType("Merge"); + EXPECT_NE(merge_node, nullptr); + auto merge_innode = merge_node->GetInDataNodes(); + EXPECT_EQ(merge_innode.size(), 1); + } +} +} // namespace ge