|
|
@@ -24,34 +24,31 @@ namespace ge { |
|
|
|
Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { |
|
|
|
GELOGD("AttachStreamLabelPass Enter."); |
|
|
|
|
|
|
|
FindNodes(graph); |
|
|
|
for (const auto &node : need_label_nodes_) { |
|
|
|
GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); |
|
|
|
std::vector<NodePtr> need_label_nodes; |
|
|
|
std::vector<NodePtr> enter_nodes; |
|
|
|
std::map<NodePtr, NodePtr> branch_head_nodes; |
|
|
|
FindNodes(graph, need_label_nodes, enter_nodes, branch_head_nodes); |
|
|
|
for (const auto &node : need_label_nodes) { |
|
|
|
GE_CHK_STATUS_RET(UpdateCondBranch(node, branch_head_nodes), "Update cond branch failed, start node:%s.", node->GetName().c_str()); |
|
|
|
} |
|
|
|
GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed."); |
|
|
|
GE_CHK_STATUS_RET(UpdateEnterNode(enter_nodes), "UpdateEnterNode failed."); |
|
|
|
|
|
|
|
GELOGD("AttachStreamLabelPass Leave."); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
/// |
|
|
|
/// @brief Clear Status, used for subgraph pass |
|
|
|
/// @return |
|
|
|
/// |
|
|
|
Status AttachStreamLabelPass::ClearStatus() { |
|
|
|
stream_switch_nodes_.clear(); |
|
|
|
need_label_nodes_.clear(); |
|
|
|
enter_nodes_.clear(); |
|
|
|
branch_head_nodes_.clear(); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
/// |
|
|
|
/// @brief Find StreamSwitch / StreamMerge / Enter node |
|
|
|
/// @param [in] graph |
|
|
|
/// @param [out] need_label_nodes |
|
|
|
/// @param [out] enter_nodes |
|
|
|
/// @param [out] branch_head_nodes |
|
|
|
/// @return void |
|
|
|
/// |
|
|
|
void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { |
|
|
|
void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph, std::vector<NodePtr> &need_label_nodes, |
|
|
|
std::vector<NodePtr> &enter_nodes, |
|
|
|
std::map<NodePtr, NodePtr> &branch_head_nodes) { |
|
|
|
std::vector<NodePtr> stream_switch_nodes; |
|
|
|
for (const NodePtr &node : graph->GetDirectNode()) { |
|
|
|
const auto &op_desc = node->GetOpDesc(); |
|
|
|
if (op_desc == nullptr) { |
|
|
@@ -59,29 +56,31 @@ void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { |
|
|
|
} |
|
|
|
const std::string &type = op_desc->GetType(); |
|
|
|
if ((type == STREAMSWITCH) && op_desc->HasAttr(ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG)) { |
|
|
|
stream_switch_nodes_.emplace_back(node); |
|
|
|
stream_switch_nodes.emplace_back(node); |
|
|
|
} else if ((type == STREAMMERGE) && !op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) { |
|
|
|
need_label_nodes_.emplace_back(node); |
|
|
|
need_label_nodes.emplace_back(node); |
|
|
|
} else if ((type == ENTER) || (type == REFENTER)) { |
|
|
|
enter_nodes_.emplace_back(node); |
|
|
|
enter_nodes.emplace_back(node); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (const auto &node : stream_switch_nodes_) { |
|
|
|
for (const auto &node : stream_switch_nodes) { |
|
|
|
for (const auto &out_ctrl_node : node->GetOutControlNodes()) { |
|
|
|
GELOGD("branch_head_node %s of stream_switch %s.", out_ctrl_node->GetName().c_str(), node->GetName().c_str()); |
|
|
|
branch_head_nodes_[out_ctrl_node] = node; |
|
|
|
branch_head_nodes[out_ctrl_node] = node; |
|
|
|
} |
|
|
|
need_label_nodes_.emplace_back(node); |
|
|
|
need_label_nodes.emplace_back(node); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/// |
|
|
|
/// @brief update cond branch |
|
|
|
/// @param [in] node |
|
|
|
/// @param [in] branch_head_nodes |
|
|
|
/// @return Status |
|
|
|
/// |
|
|
|
Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { |
|
|
|
Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node, |
|
|
|
const std::map<NodePtr, NodePtr> &branch_head_nodes) { |
|
|
|
std::string stream_label; |
|
|
|
if (AttachFlag(node, stream_label) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "Attach flag for node %s failed.", node->GetName().c_str()); |
|
|
@@ -103,8 +102,9 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { |
|
|
|
const std::string &type = cur_node->GetType(); |
|
|
|
for (const auto &out_node : cur_node->GetOutAllNodes()) { |
|
|
|
const std::string &out_type = out_node->GetType(); |
|
|
|
const auto &iter = branch_head_nodes.find(node); |
|
|
|
bool stop_flag = (end_type_set.count(out_type) > 0) || |
|
|
|
((branch_head_nodes_.count(out_node) > 0) && (branch_head_nodes_[out_node] != node)) || |
|
|
|
((iter != branch_head_nodes.end()) && (iter->second != node)) || |
|
|
|
(((type == ENTER) || (type == REFENTER)) && (out_type != STREAMACTIVE)); |
|
|
|
if (!stop_flag) { |
|
|
|
nodes.push(out_node); |
|
|
@@ -178,11 +178,12 @@ Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &strea |
|
|
|
|
|
|
|
/// |
|
|
|
/// @brief Update stream_label start with enter nodes |
|
|
|
/// @param [in] enter_nodes |
|
|
|
/// @return Status |
|
|
|
/// |
|
|
|
Status AttachStreamLabelPass::UpdateEnterNode() { |
|
|
|
Status AttachStreamLabelPass::UpdateEnterNode(const std::vector<NodePtr> &enter_nodes) { |
|
|
|
std::unordered_map<NodePtr, std::vector<NodePtr>> enter_active_map; |
|
|
|
for (const auto &enter_node : enter_nodes_) { |
|
|
|
for (const auto &enter_node : enter_nodes) { |
|
|
|
for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) { |
|
|
|
if (out_ctrl_node->GetType() != STREAMACTIVE) { |
|
|
|
continue; |
|
|
@@ -214,11 +215,11 @@ Status AttachStreamLabelPass::UpdateEnterNode() { |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
std::stack<NodePtr> enter_nodes; |
|
|
|
std::stack<NodePtr> nodes; |
|
|
|
for (const auto &enter_node : pair.second) { |
|
|
|
enter_nodes.emplace(enter_node); |
|
|
|
nodes.emplace(enter_node); |
|
|
|
} |
|
|
|
if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) { |
|
|
|
if (UpdateLoopBranch(nodes, active_label_list[0]) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "Update stream_label for loop_branch failed."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|