From 242afc4e6799a8910328805e8774c83f84e3ef9c Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Sat, 13 Mar 2021 17:30:39 +0800 Subject: [PATCH] modified: ge/graph/passes/base_pass.cc modified: ge/graph/passes/base_pass.h modified: ge/graph/passes/infershape_pass.cc --- ge/graph/passes/base_pass.cc | 43 ++++++++++++++++++++++++++++---------- ge/graph/passes/base_pass.h | 11 ++++++++++ ge/graph/passes/infershape_pass.cc | 16 ++++++++++++++ 3 files changed, 59 insertions(+), 11 deletions(-) diff --git a/ge/graph/passes/base_pass.cc b/ge/graph/passes/base_pass.cc index 3b854c18..64342509 100755 --- a/ge/graph/passes/base_pass.cc +++ b/ge/graph/passes/base_pass.cc @@ -31,7 +31,7 @@ constexpr size_t kMaxOneInNodes = 1000; // Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later constexpr int kMaxRecursiveDepth = 20; -void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::queue &input_edge_nodes, +void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque &input_edge_nodes, std::unordered_set &nodes_seen, std::unordered_set &nodes_last) { nodes_last.clear(); for (auto &node : graph->GetDirectNode()) { @@ -40,7 +40,7 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::queue &i } size_t in_nums = node->GetInNodes().size(); if (in_nums == 0) { - input_edge_nodes.push(node); + input_edge_nodes.push_back(node); nodes_seen.insert(node.get()); } else if (in_nums > kMaxOneInNodes) { nodes_last.insert(node); @@ -48,7 +48,7 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::queue &i } } -void AddNextIterNodes(const Node::Vistor &nodes, std::queue &nodes_to_pass, +void AddNextIterNodes(const Node::Vistor &nodes, std::deque &nodes_to_pass, std::unordered_set &nodes_seen, std::unordered_set &nodes_last) { for (auto &node : nodes) { if (node == nullptr) { @@ -60,13 +60,14 @@ void AddNextIterNodes(const Node::Vistor &nodes, std::queue &n bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); if (all_in_nodes_seen && nodes_seen.insert(node.get()).second) { - nodes_to_pass.push(node); + nodes_to_pass.push_back(node); } } } Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unordered_set &nodes_re_pass, - std::unordered_set &nodes_deleted, std::unordered_set &nodes_seen) { + std::unordered_set &nodes_re_pass_immediately, std::unordered_set &nodes_deleted, + std::unordered_set &nodes_seen) { if (node == nullptr) { GELOGE(FAILED, "parameter is null."); return FAILED; @@ -104,6 +105,21 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unorder } } + auto nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); + for (const auto &node_to_re_pass : nodes_to_re_pass_immediately) { + if (node_to_re_pass == nullptr) { + GELOGW("Found null re-pass node when executing %s on node %s type %s", name_to_pass.first.c_str(), + node->GetName().c_str(), node->GetType().c_str()); + continue; + } + if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { + GELOGD("The node %s will be re-pass immediately.", node_to_re_pass->GetName().c_str()); + nodes_re_pass_immediately.insert(node_to_re_pass); + } else { + GELOGD("The node %s are not all seen, don't set repass this time", node_to_re_pass->GetName().c_str()); + } + } + auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); if (nodes_deleted_by_pass.count(node) > 0) { @@ -181,10 +197,11 @@ Status GEPass::Run(const NamesToPass &names_to_passes) { Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size()); - std::queue nodes; + std::deque nodes; std::unordered_set nodes_seen; std::unordered_set nodes_deleted; std::unordered_set nodes_re_pass; + std::unordered_set nodes_re_pass_immediately; std::unordered_set nodes_last; GetAllNodesNoInputEdge(graph_, nodes, nodes_seen, nodes_last); GELOGD("Start points count %zu", nodes.size()); @@ -192,14 +209,14 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { do { for (auto &node : nodes_re_pass) { - nodes.push(node); + nodes.push_back(node); nodes_seen.insert(node.get()); } nodes_re_pass.clear(); while (!nodes.empty()) { NodePtr node = nodes.front(); - nodes.pop(); + nodes.pop_front(); (void)nodes_re_pass.erase(node); GE_IF_BOOL_EXEC(node == nullptr, GELOGW("node is null"); continue); @@ -210,7 +227,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { AddNextIterNodes(node->GetOutNodes(), nodes, nodes_seen, nodes_last); - auto ret = RunPasses(node, names_to_passes, nodes_re_pass, nodes_deleted, nodes_seen); + auto ret = RunPasses(node, names_to_passes, nodes_re_pass, nodes_re_pass_immediately, nodes_deleted, nodes_seen); if (ret != SUCCESS) { GELOGE(ret, "Failed to process passes on node %s type %s, error code: %u", node->GetName().c_str(), node->GetType().c_str(), ret); @@ -227,7 +244,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { if (has_sub_graph) { GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str()); SetFlagOption(kOptimizeAfterSubGraph, names_to_passes); - ret = RunPasses(node, names_to_passes, nodes_re_pass, nodes_deleted, nodes_seen); + ret = RunPasses(node, names_to_passes, nodes_re_pass, nodes_re_pass_immediately, nodes_deleted, nodes_seen); if (ret != SUCCESS) { GELOGE(ret, "Failed to process passes on node %s type %s, error code: %u", node->GetName().c_str(), node->GetType().c_str(), ret); @@ -239,12 +256,16 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { // should be called each time at the begin of the iteration ClearOption(names_to_passes); } + for(auto &node : nodes_re_pass_immediately){ + nodes.push_front(node); + } + nodes_re_pass_immediately.clear(); } for (auto &node : nodes_last) { bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); if (all_in_nodes_seen && nodes_seen.insert(node.get()).second) { - nodes.push(node); + nodes.push_back(node); } } nodes_last.clear(); diff --git a/ge/graph/passes/base_pass.h b/ge/graph/passes/base_pass.h index bb41691d..89a364a9 100644 --- a/ge/graph/passes/base_pass.h +++ b/ge/graph/passes/base_pass.h @@ -53,6 +53,8 @@ class BaseNodePass { std::unordered_set GetNodesNeedRePass() { return nodes_need_re_pass_; } + std::unordered_set GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } + std::unordered_set GetNodesDeleted() { return nodes_deleted_; } void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } @@ -80,6 +82,14 @@ class BaseNodePass { void AddRePassNode(NodePtr &node) { nodes_need_re_pass_.insert(node); } /// + /// Add a node to be optimized immediately again. If you add a new node to the graph, or + /// change a node connections, and you want to make sure the node will be + /// optimized by other passes, call this function. + /// @param node + /// + void AddImmediateRePassNode(NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); } + + /// /// Add a node and it's input/output data nodes to be optimized again. /// @param node /// @@ -109,6 +119,7 @@ class BaseNodePass { private: std::unordered_set nodes_need_re_pass_; + std::unordered_set nodes_need_re_pass_immediately_; std::unordered_set nodes_deleted_; std::map options_; }; diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index 7b8f7b50..fd943c2d 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -25,6 +25,7 @@ namespace ge { Status InferShapePass::Run(NodePtr &node) { + // kOptimizeAfterSubGraph exist means after subgraph auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph)); if (ret != GRAPH_SUCCESS) { // select INFERSHAPE failed info @@ -41,6 +42,21 @@ Status InferShapePass::Run(NodePtr &node) { GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infershape failed. node: %s", node->GetName().c_str()); return GE_GRAPH_INFERSHAPE_FAILED; } + if(node->GetType() == WHILE){ + bool need_repass = false; + AttrUtils::GetBool(node->GetOpDesc(),"need_infer_again_", need_repass); + if(!OptionExists(kOptimizeAfterSubGraph)){ + return SUCCESS; + } + if(need_repass){ + AddImmediateRePassNode(node); + GELOGD("Node %s need repass immediately.", node->GetName().c_str()); + } + else{ + // clear attr on while + node->GetOpDesc()->DelAttr("need_infer_again_"); + } + } return SUCCESS; } } // namespace ge