diff --git a/ge/graph/passes/attach_stream_label_pass.cc b/ge/graph/passes/attach_stream_label_pass.cc index 75599c45..d8c81e92 100644 --- a/ge/graph/passes/attach_stream_label_pass.cc +++ b/ge/graph/passes/attach_stream_label_pass.cc @@ -24,34 +24,31 @@ namespace ge { Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { GELOGD("AttachStreamLabelPass Enter."); - FindNodes(graph); - for (const auto &node : need_label_nodes_) { - GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); + std::vector need_label_nodes; + std::vector enter_nodes; + std::map branch_head_nodes; + FindNodes(graph, need_label_nodes, enter_nodes, branch_head_nodes); + for (const auto &node : need_label_nodes) { + GE_CHK_STATUS_RET(UpdateCondBranch(node, branch_head_nodes), "Update cond branch failed, start node:%s.", node->GetName().c_str()); } - GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed."); + GE_CHK_STATUS_RET(UpdateEnterNode(enter_nodes), "UpdateEnterNode failed."); GELOGD("AttachStreamLabelPass Leave."); return SUCCESS; } /// -/// @brief Clear Status, used for subgraph pass -/// @return -/// -Status AttachStreamLabelPass::ClearStatus() { - stream_switch_nodes_.clear(); - need_label_nodes_.clear(); - enter_nodes_.clear(); - branch_head_nodes_.clear(); - return SUCCESS; -} - -/// /// @brief Find StreamSwitch / StreamMerge / Enter node /// @param [in] graph +/// @param [out] need_label_nodes +/// @param [out] enter_nodes +/// @param [out] branch_head_nodes /// @return void /// -void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { +void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph, std::vector &need_label_nodes, + std::vector &enter_nodes, + std::map &branch_head_nodes) { + std::vector stream_switch_nodes; for (const NodePtr &node : graph->GetDirectNode()) { const auto &op_desc = node->GetOpDesc(); if (op_desc == nullptr) { @@ -59,29 +56,31 @@ void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { } const std::string &type = op_desc->GetType(); if ((type == STREAMSWITCH) && op_desc->HasAttr(ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG)) { - stream_switch_nodes_.emplace_back(node); + stream_switch_nodes.emplace_back(node); } else if ((type == STREAMMERGE) && !op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) { - need_label_nodes_.emplace_back(node); + need_label_nodes.emplace_back(node); } else if ((type == ENTER) || (type == REFENTER)) { - enter_nodes_.emplace_back(node); + enter_nodes.emplace_back(node); } } - for (const auto &node : stream_switch_nodes_) { + for (const auto &node : stream_switch_nodes) { for (const auto &out_ctrl_node : node->GetOutControlNodes()) { GELOGD("branch_head_node %s of stream_switch %s.", out_ctrl_node->GetName().c_str(), node->GetName().c_str()); - branch_head_nodes_[out_ctrl_node] = node; + branch_head_nodes[out_ctrl_node] = node; } - need_label_nodes_.emplace_back(node); + need_label_nodes.emplace_back(node); } } /// /// @brief update cond branch /// @param [in] node +/// @param [in] branch_head_nodes /// @return Status /// -Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { +Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node, + const std::map &branch_head_nodes) { std::string stream_label; if (AttachFlag(node, stream_label) != SUCCESS) { GELOGE(FAILED, "Attach flag for node %s failed.", node->GetName().c_str()); @@ -103,8 +102,9 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { const std::string &type = cur_node->GetType(); for (const auto &out_node : cur_node->GetOutAllNodes()) { const std::string &out_type = out_node->GetType(); + const auto &iter = branch_head_nodes.find(node); bool stop_flag = (end_type_set.count(out_type) > 0) || - ((branch_head_nodes_.count(out_node) > 0) && (branch_head_nodes_[out_node] != node)) || + ((iter != branch_head_nodes.end()) && (iter->second != node)) || (((type == ENTER) || (type == REFENTER)) && (out_type != STREAMACTIVE)); if (!stop_flag) { nodes.push(out_node); @@ -178,11 +178,12 @@ Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &strea /// /// @brief Update stream_label start with enter nodes +/// @param [in] enter_nodes /// @return Status /// -Status AttachStreamLabelPass::UpdateEnterNode() { +Status AttachStreamLabelPass::UpdateEnterNode(const std::vector &enter_nodes) { std::unordered_map> enter_active_map; - for (const auto &enter_node : enter_nodes_) { + for (const auto &enter_node : enter_nodes) { for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) { if (out_ctrl_node->GetType() != STREAMACTIVE) { continue; @@ -214,11 +215,11 @@ Status AttachStreamLabelPass::UpdateEnterNode() { return INTERNAL_ERROR; } - std::stack enter_nodes; + std::stack nodes; for (const auto &enter_node : pair.second) { - enter_nodes.emplace(enter_node); + nodes.emplace(enter_node); } - if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) { + if (UpdateLoopBranch(nodes, active_label_list[0]) != SUCCESS) { GELOGE(FAILED, "Update stream_label for loop_branch failed."); return FAILED; } diff --git a/ge/graph/passes/attach_stream_label_pass.h b/ge/graph/passes/attach_stream_label_pass.h index ad71d58f..a1600a58 100755 --- a/ge/graph/passes/attach_stream_label_pass.h +++ b/ge/graph/passes/attach_stream_label_pass.h @@ -25,26 +25,25 @@ class AttachStreamLabelPass : public GraphPass { public: Status Run(ComputeGraphPtr graph); - /// - /// @brief Clear Status, used for subgraph pass - /// @return - /// - Status ClearStatus() override; - private: /// /// @brief Find StreamSwitch / StreamMerge / Enter node /// @param [in] graph + /// @param [out] need_label_nodes + /// @param [out] enter_nodes + /// @param [out] branch_head_nodes /// @return void /// - void FindNodes(const ComputeGraphPtr &graph); + void FindNodes(const ComputeGraphPtr &graph, std::vector &need_label_nodes, + std::vector &enter_nodes, std::map &branch_head_nodes); /// /// @brief update cond branch /// @param [in] node + /// @param [in] branch_head_nodes /// @return Status /// - Status UpdateCondBranch(const NodePtr &node); + Status UpdateCondBranch(const NodePtr &node, const std::map &branch_head_nodes); /// /// @brief attach flag @@ -64,9 +63,10 @@ class AttachStreamLabelPass : public GraphPass { /// /// @brief Update stream_label start with enter nodes + /// @param [in] enter_nodes /// @return Status /// - Status UpdateEnterNode(); + Status UpdateEnterNode(const std::vector &enter_nodes); /// /// @brief Set stream_label for enter_nodes @@ -75,11 +75,6 @@ class AttachStreamLabelPass : public GraphPass { /// @return Status /// static Status SetEnterLabel(const std::vector &enter_nodes, const NodePtr &active_node); - - std::vector stream_switch_nodes_; - std::vector need_label_nodes_; - std::vector enter_nodes_; - std::unordered_map branch_head_nodes_; }; } // namespace ge #endif // GE_GRAPH_PASSES_ATTACH_STREAM_LABEL_PASS_H_