Browse Source

Fix mark branch force unknown

tags/v1.3.0
zhangxiaokun 4 years ago
parent
commit
e9edaca33f
6 changed files with 399 additions and 1 deletions
  1. +2
    -0
      ge/CMakeLists.txt
  2. +4
    -1
      ge/graph/manager/graph_manager.cc
  3. +125
    -0
      ge/graph/passes/mark_branch_force_unknown_pass.cc
  4. +36
    -0
      ge/graph/passes/mark_branch_force_unknown_pass.h
  5. +2
    -0
      tests/ut/ge/CMakeLists.txt
  6. +230
    -0
      tests/ut/ge/graph/passes/mark_branch_force_unknown_pass_unittest.cc

+ 2
- 0
ge/CMakeLists.txt View File

@@ -307,6 +307,7 @@ set(TRAIN_SRC_LIST
"graph/passes/merge_to_stream_merge_pass.cc"
"graph/passes/merge_input_memcpy_pass.cc"
"graph/passes/switch_to_stream_switch_pass.cc"
"graph/passes/mark_branch_force_unknown_pass.cc"
"graph/passes/attach_stream_label_pass.cc"
"graph/passes/switch_dead_branch_elimination.cc"
"graph/passes/replace_transshape_pass.cc"
@@ -584,6 +585,7 @@ set(INFER_SRC_LIST
"graph/passes/merge_to_stream_merge_pass.cc"
"graph/passes/merge_input_memcpy_pass.cc"
"graph/passes/switch_to_stream_switch_pass.cc"
"graph/passes/mark_branch_force_unknown_pass.cc"
"graph/passes/attach_stream_label_pass.cc"
"graph/passes/multi_batch_pass.cc"
"graph/passes/multi_batch_clone_pass.cc"


+ 4
- 1
ge/graph/manager/graph_manager.cc View File

@@ -65,6 +65,7 @@
#include "graph/passes/merge_pass.h"
#include "graph/passes/merge_input_memcpy_pass.h"
#include "graph/passes/merge_to_stream_merge_pass.h"
#include "graph/passes/mark_branch_force_unknown_pass.h"
#include "graph/passes/multi_batch_pass.h"
#include "graph/passes/next_iteration_pass.h"
#include "graph/passes/permute_pass.h"
@@ -2535,7 +2536,9 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
// the prune pass should be between SwitchPass and SwitchToStreamSwitchPass
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass));
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass));
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass))
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass));
auto mark_force_unknown_pass = new (std::nothrow) MarkBranchForceUnknownPass;
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::MarkBranchForceUnknownPass", mark_force_unknown_pass));
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass))
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass))
GE_CHK_STATUS_RET(


+ 125
- 0
ge/graph/passes/mark_branch_force_unknown_pass.cc View File

@@ -0,0 +1,125 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "mark_branch_force_unknown_pass.h"

#include <queue>

#include "graph/common/omg_util.h"

namespace ge {
namespace {
// Original op types that MarkUnknownForSwitch treats as a Merge: crossing one
// of them while walking upstream increases the search span by one.
const std::set<std::string> kMergeOpTypes{ MERGE, REFMERGE };

// Original op types that MarkUnknownForSwitch treats as a Switch: reaching one
// either decreases the search span or, at span 0, selects it for marking.
const std::set<std::string> kSwitchOpTypes{ SWITCH, REFSWITCH };

// Original op types checked by IsMergeInLoop; Run skips any Merge that has a
// data input of one of these types.
const std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION };

inline bool IsMergeInLoop(const NodePtr &node) {
std::string node_type;
(void)GetOriginalType(node, node_type);
return kLoopMergeInputs.count(node_type) > 0;
}
}

///
/// @brief For every direct Merge/RefMerge node of the graph that is either
///        tagged with ATTR_NAME_FORCE_UNKNOWN_SHAPE or has an unknown-shape
///        first output, mark the Switch nodes of its branch as force unknown.
///        A Merge with a data input accepted by IsMergeInLoop is skipped.
/// @param [in] graph: graph whose direct nodes are scanned
/// @return SUCCESS, or the failing status of GetOriginalType
///
Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) {
  GELOGD("MarkBranchForceUnknownPass Enter");
  for (const auto &node : graph->GetDirectNode()) {
    std::string original_type;
    GE_CHK_STATUS_RET(GetOriginalType(node, original_type), "Get original type failed.");
    const bool is_merge = (original_type == MERGE) || (original_type == REFMERGE);
    if (!is_merge) {
      continue;
    }

    // The attribute is tested first; the output shape is only inspected when
    // the attribute is absent.
    const auto op_desc = node->GetOpDesc();
    const bool needs_marking =
        op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE) || IsUnknownShapeTensor(op_desc->GetOutputDesc(0));
    if (!needs_marking) {
      GELOGI("Merge[%s] has known shape, no need check switch", node->GetName().c_str());
      continue;
    }

    // LoopCond marked in NextIterationPass, so a Merge fed by a loop node is left alone here.
    const auto &merge_inputs = node->GetInDataNodes();
    const bool fed_by_loop = std::any_of(merge_inputs.begin(), merge_inputs.end(), IsMergeInLoop);
    if (!fed_by_loop) {
      MarkUnknownForSwitch(node);
    }
  }

  GELOGD("MarkBranchForceUnknownPass Leave");
  return SUCCESS;
}

///
/// @brief Mark force unknown shape for Switch node
/// @param [in] merge node
/// @return
///
void MarkBranchForceUnknownPass::MarkUnknownForSwitch(const NodePtr &node) {
// Switch --> {Switch --> Merge} --> Merge
std::vector<NodePtr> switch_group;
std::unordered_set<NodePtr> nodes_seen;

std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}});
while (!search_queue.empty()) {
const auto dst_node = search_queue.front().first;
const auto dst_span = search_queue.front().second;
search_queue.pop();

// Switch --> Identity --> Constant
for (const auto &in_node : dst_node->GetInControlNodes()) {
if (nodes_seen.count(in_node) > 0) {
GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str());
continue;
}
nodes_seen.insert(in_node);

if (in_node->GetType() == IDENTITY) {
GELOGD("Travel node: %s, In control: %s, span is: %u", dst_node->GetName().c_str(),
in_node->GetName().c_str(), dst_span);
search_queue.push({in_node, dst_span});
}
}

for (const auto &in_node : dst_node->GetInDataNodes()) {
if (nodes_seen.count(in_node) > 0) {
GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str());
continue;
}
nodes_seen.insert(in_node);

std::string node_type;
(void)GetOriginalType(in_node, node_type);
GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(),
in_node->GetName().c_str(), dst_span);
if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node.
if (dst_span > 0) {
search_queue.push({in_node, dst_span - 1});
} else {
switch_group.emplace_back(in_node);
}
} else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node.
search_queue.push({in_node, dst_span + 1});
} else {
search_queue.push({in_node, dst_span});
}
}
}

for (const auto &n : switch_group) {
MarkForceUnknownShape(n, true);
}
}
} // namespace ge

+ 36
- 0
ge/graph/passes/mark_branch_force_unknown_pass.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_
#define GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_

#include "inc/graph_pass.h"

namespace ge {
///
/// @brief Graph pass that marks the Switch nodes feeding a Merge as force
///        unknown shape when that Merge is tagged force unknown or has an
///        unknown-shape output.
///
class MarkBranchForceUnknownPass : public GraphPass {
 public:
  ///
  /// @brief Scan the direct nodes of the graph and mark the Switch nodes of
  ///        every qualifying Merge.
  /// @param [in] graph: graph to process
  /// @return SUCCESS or the first failing status
  ///
  // `override` makes the compiler reject this declaration if it ever stops
  // matching the GraphPass entry point instead of silently hiding it.
  Status Run(ComputeGraphPtr graph) override;

 private:
  ///
  /// @brief Mark force unknown shape for Switch node
  /// @param [in] node: merge node the upstream search starts from
  /// @return
  ///
  void MarkUnknownForSwitch(const NodePtr &node);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_

+ 2
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -239,6 +239,7 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/merge_to_stream_merge_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/merge_input_memcpy_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/switch_to_stream_switch_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/mark_branch_force_unknown_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/attach_stream_label_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/multi_batch_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/multi_batch_clone_pass.cc"
@@ -703,6 +704,7 @@ set(PASS_TEST_FILES
"graph/passes/net_output_pass_unittest.cc"
"graph/passes/no_use_reshape_remove_pass_unittest.cc"
"graph/passes/infershape_pass_unittest.cc"
"graph/passes/mark_branch_force_unknown_pass_unittest.cc"
"graph/passes/multi_batch_clone_pass_unittest.cc"
"graph/passes/replace_with_empty_const_pass_unittest.cc"
"graph/passes/link_gen_mask_nodes_pass_unittest.cc"


+ 230
- 0
tests/ut/ge/graph/passes/mark_branch_force_unknown_pass_unittest.cc View File

@@ -0,0 +1,230 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>

#define protected public
#define private public
#include "graph/passes/mark_branch_force_unknown_pass.h"

#include "graph/utils/tensor_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/operator_factory.h"
#include "graph/operator_reg.h"
#include "graph_builder_utils.h"

using namespace std;
using namespace testing;
namespace ge {
// Test fixture for MarkBranchForceUnknownPass. It holds no shared state: every
// test case builds its own graph, so both hooks are intentionally empty.
class UtestMarkBranchForceUnknownPass : public testing::Test {
protected:
// Nothing to prepare before a test case.
void SetUp() {}
// Nothing to release after a test case.
void TearDown() {}
};

///
/// @brief Build an OpDesc with the requested number of inputs/outputs, attach
///        stub infer/verify functions and add it to the graph as a node.
/// @param [in] graph: graph that receives the new node
/// @param [in] name: node name
/// @param [in] type: node type
/// @param [in] in_num: number of input descriptors to add
/// @param [in] out_num: number of output descriptors to add
/// @return the node returned by ComputeGraph::AddNode
///
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
  // Shared across calls so that each created op receives the next id.
  static int32_t index = 0;

  OpDescPtr desc = std::make_shared<OpDesc>(name, type);
  desc->SetStreamId(0);
  desc->SetId(index++);

  // One tensor description (default shape, NCHW, float, size 512) is reused
  // for every input and output.
  GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT);
  TensorUtils::SetSize(tensor, 512);

  // Produces `count` offsets, each equal to 1024.
  const auto build_offsets = [](int count) {
    vector<int64_t> offsets;
    for (int i = 0; i < count; ++i) {
      offsets.emplace_back(1024);
    }
    return offsets;
  };

  for (int i = 0; i < in_num; ++i) {
    desc->AddInputDesc(tensor);
  }
  desc->SetInputOffset(build_offsets(in_num));

  for (int i = 0; i < out_num; ++i) {
    desc->AddOutputDesc(tensor);
  }
  desc->SetOutputOffset(build_offsets(out_num));

  desc->SetWorkspace({});
  desc->SetWorkspaceBytes({});
  desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE");

  // The same always-successful stub serves as infer, infer-format and verifier function.
  const auto always_success = [](Operator &op) { return GRAPH_SUCCESS; };
  desc->AddInferFunc(always_success);
  desc->AddInferFormatFunc(always_success);
  desc->AddVerifierFunc(always_success);

  return graph.AddNode(desc);
}

///
/// @brief Build a while-loop shaped graph and hand back its Merge node.
/// @param [in] graph: graph that receives the nodes
/// @param [out] merge: the Merge node of the loop
///
static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge) {
  /*******************************************************************************
   * Data edges (output index -> input index):
   *   Data -> Enter -> Merge(0)
   *   Merge -> Less(0),   Const -> Less(1),   Less -> LoopCond
   *   LoopCond -> Switch(0),   Merge -> Switch(1)
   *   Switch(0) -> Exit -> NetOutput
   *   Switch(1) -> Identity -> Add(0),   Const -> Add(1)
   *   Add -> NextIteration -> Merge(1)      (loop back edge)
   ******************************************************************************/
  auto data_node = CreateNode(*graph, "data", DATA, 1, 1);
  auto enter_node = CreateNode(*graph, "enter", ENTER, 1, 1);
  auto merge_node = CreateNode(*graph, "merge", MERGE, 2, 2);
  auto less_node = CreateNode(*graph, "less", LESS, 2, 1);
  auto loop_cond = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1);
  auto switch_node = CreateNode(*graph, "switch", SWITCH, 2, 2);
  auto identity_node = CreateNode(*graph, "identity", IDENTITY, 1, 1);
  auto add_node = CreateNode(*graph, "add", ADD, 2, 1);
  auto next_node = CreateNode(*graph, "next", NEXTITERATION, 1, 1);
  auto exit_node = CreateNode(*graph, "exit", EXIT, 1, 1);
  // This constant is not wired to any edge; it is still created so that the
  // node set of the graph stays the same.
  auto unused_const = CreateNode(*graph, "const", CONSTANT, 0, 1);
  auto const_node = CreateNode(*graph, "const", CONSTANT, 0, 1);
  auto net_output = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1);

  // Connects output `src_idx` of `src` to input `dst_idx` of `dst`.
  const auto link = [](const NodePtr &src, int src_idx, const NodePtr &dst, int dst_idx) {
    GraphUtils::AddEdge(src->GetOutDataAnchor(src_idx), dst->GetInDataAnchor(dst_idx));
  };

  link(data_node, 0, enter_node, 0);
  link(enter_node, 0, merge_node, 0);
  link(merge_node, 0, less_node, 0);
  link(const_node, 0, less_node, 1);
  link(less_node, 0, loop_cond, 0);

  link(loop_cond, 0, switch_node, 0);
  link(merge_node, 0, switch_node, 1);

  link(switch_node, 0, exit_node, 0);
  link(switch_node, 1, identity_node, 0);

  link(identity_node, 0, add_node, 0);
  link(const_node, 0, add_node, 1);
  link(add_node, 0, next_node, 0);

  link(next_node, 0, merge_node, 1);
  link(exit_node, 0, net_output, 0);

  merge = merge_node;
}

///
/// @brief Build a conditional (Switch/Merge) graph with two branches and hand
///        back its Merge node. One branch input reaches the Merge through a
///        control edge: Switch --> Identity --(control)--> Const --> Sub.
/// @param [in] graph: graph that receives the nodes
/// @param [out] merge: the Merge node joining both branches
///
static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) {
  /*******************************************************************************
   *                      NetOutput
   *                          |
   *                        Merge
   *                       /     \
   *                    Add       Sub
   *                   /   \     /   \
   *                  |     |   |     Const
   *                  |     |   |       ^ (control edge)
   *                  |     |   |     Identity
   *                  |     |   |       |
   *           Switch_x Switch_y Switch_z Switch_i
   *              |        |        |        |
   *              x        y        z       Less
   *
   * Less(x, y) is the predicate (input 1) of all four Switch nodes and also
   * the data input (input 0) of Switch_i.
   ******************************************************************************/
  auto data1 = CreateNode(*graph, "data_x", DATA, 1, 1);
  auto data2 = CreateNode(*graph, "data_y", DATA, 1, 1);
  auto data3 = CreateNode(*graph, "data_z", DATA, 1, 1);

  auto less1 = CreateNode(*graph, "less", LESS, 2, 1);

  auto switch1 = CreateNode(*graph, "switch_x", SWITCH, 2, 2);
  auto switch2 = CreateNode(*graph, "switch_y", SWITCH, 2, 2);
  auto switch3 = CreateNode(*graph, "switch_z", SWITCH, 2, 2);
  auto switch4 = CreateNode(*graph, "switch_i", SWITCH, 2, 2);

  auto add1 = CreateNode(*graph, "add", ADD, 2, 1);
  // Named "sub" (was a copy-pasted "add") so the two arithmetic nodes can be
  // told apart in logs.
  auto sub1 = CreateNode(*graph, "sub", SUB, 2, 1);
  auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1);
  auto const1 = CreateNode(*graph, "const", CONSTANT, 0, 1);

  auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2);
  auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1);

  GraphUtils::AddEdge(data1->GetOutDataAnchor(0), less1->GetInDataAnchor(0));
  GraphUtils::AddEdge(data2->GetOutDataAnchor(0), less1->GetInDataAnchor(1));

  GraphUtils::AddEdge(data1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0));
  GraphUtils::AddEdge(data2->GetOutDataAnchor(0), switch2->GetInDataAnchor(0));
  GraphUtils::AddEdge(data3->GetOutDataAnchor(0), switch3->GetInDataAnchor(0));
  GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch4->GetInDataAnchor(0));

  GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1));
  GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch2->GetInDataAnchor(1));
  GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch3->GetInDataAnchor(1));
  GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch4->GetInDataAnchor(1));

  GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), add1->GetInDataAnchor(0));
  GraphUtils::AddEdge(switch2->GetOutDataAnchor(0), add1->GetInDataAnchor(1));
  GraphUtils::AddEdge(switch3->GetOutDataAnchor(0), sub1->GetInDataAnchor(0));
  // Identity is created with a single input, so its only data anchor is index 0.
  // Index 1 does not exist and would leave Switch_i disconnected from the branch.
  GraphUtils::AddEdge(switch4->GetOutDataAnchor(0), ident1->GetInDataAnchor(0));
  GraphUtils::AddEdge(ident1->GetOutControlAnchor(), const1->GetInControlAnchor());
  GraphUtils::AddEdge(const1->GetOutDataAnchor(0), sub1->GetInDataAnchor(1));

  GraphUtils::AddEdge(add1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0));
  GraphUtils::AddEdge(sub1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1));
  GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0));

  merge = merge1;
}

// A Merge that is tagged force unknown but fed by Enter/NextIteration belongs
// to a while loop: the pass is expected to skip it and still succeed.
TEST_F(UtestMarkBranchForceUnknownPass, skip_while_loop_merge) {
  NodePtr loop_merge;
  auto compute_graph = std::make_shared<ComputeGraph>("test_graph");
  CreateLoopGraph(compute_graph, loop_merge);

  const auto merge_desc = loop_merge->GetOpDesc();
  AttrUtils::SetBool(merge_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, true);

  MarkBranchForceUnknownPass pass;
  EXPECT_EQ(pass.Run(compute_graph), SUCCESS);  // skip LoopCond
}

// The Merge of the conditional graph keeps its default (known) output shape
// and carries no force-unknown attribute, so the pass has nothing to mark.
TEST_F(UtestMarkBranchForceUnknownPass, skip_known_shape_merge) {
  NodePtr cond_merge;
  auto compute_graph = std::make_shared<ComputeGraph>("test_graph");
  CreateCondGraph(compute_graph, cond_merge);

  MarkBranchForceUnknownPass pass;
  EXPECT_EQ(pass.Run(compute_graph), SUCCESS);  // skip known shape merge
}


// Giving the Merge an unknown output shape ({-1}) makes the pass search its
// branch for Switch nodes; the run is expected to succeed.
TEST_F(UtestMarkBranchForceUnknownPass, mark_unknown_shape_merge) {
  NodePtr cond_merge;
  auto compute_graph = std::make_shared<ComputeGraph>("test_graph");
  CreateCondGraph(compute_graph, cond_merge);

  const auto merge_desc = cond_merge->GetOpDesc();
  auto output_desc = merge_desc->GetOutputDesc(0);
  output_desc.SetShape(GeShape({-1}));  // Set for unknown.
  merge_desc->UpdateOutputDesc(0, output_desc);

  MarkBranchForceUnknownPass pass;
  EXPECT_EQ(pass.Run(compute_graph), SUCCESS);
}
} // namespace ge

Loading…
Cancel
Save