Browse Source

rm member var for clean code

tags/v1.3.0
chenyemeng 4 years ago
parent
commit
6fd0788c1b
2 changed files with 40 additions and 44 deletions
  1. +31
    -30
      ge/graph/passes/attach_stream_label_pass.cc
  2. +9
    -14
      ge/graph/passes/attach_stream_label_pass.h

+ 31
- 30
ge/graph/passes/attach_stream_label_pass.cc View File

@@ -24,34 +24,31 @@ namespace ge {
Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) {
GELOGD("AttachStreamLabelPass Enter.");

FindNodes(graph);
for (const auto &node : need_label_nodes_) {
GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str());
std::vector<NodePtr> need_label_nodes;
std::vector<NodePtr> enter_nodes;
std::map<NodePtr, NodePtr> branch_head_nodes;
FindNodes(graph, need_label_nodes, enter_nodes, branch_head_nodes);
for (const auto &node : need_label_nodes) {
GE_CHK_STATUS_RET(UpdateCondBranch(node, branch_head_nodes), "Update cond branch failed, start node:%s.", node->GetName().c_str());
}
GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed.");
GE_CHK_STATUS_RET(UpdateEnterNode(enter_nodes), "UpdateEnterNode failed.");

GELOGD("AttachStreamLabelPass Leave.");
return SUCCESS;
}

///
/// @brief Clear Status, used for subgraph pass
/// @return
///
Status AttachStreamLabelPass::ClearStatus() {
stream_switch_nodes_.clear();
need_label_nodes_.clear();
enter_nodes_.clear();
branch_head_nodes_.clear();
return SUCCESS;
}

///
/// @brief Find StreamSwitch / StreamMerge / Enter node
/// @param [in] graph
/// @param [out] need_label_nodes
/// @param [out] enter_nodes
/// @param [out] branch_head_nodes
/// @return void
///
void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) {
void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph, std::vector<NodePtr> &need_label_nodes,
std::vector<NodePtr> &enter_nodes,
std::map<NodePtr, NodePtr> &branch_head_nodes) {
std::vector<NodePtr> stream_switch_nodes;
for (const NodePtr &node : graph->GetDirectNode()) {
const auto &op_desc = node->GetOpDesc();
if (op_desc == nullptr) {
@@ -59,29 +56,31 @@ void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) {
}
const std::string &type = op_desc->GetType();
if ((type == STREAMSWITCH) && op_desc->HasAttr(ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG)) {
stream_switch_nodes_.emplace_back(node);
stream_switch_nodes.emplace_back(node);
} else if ((type == STREAMMERGE) && !op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) {
need_label_nodes_.emplace_back(node);
need_label_nodes.emplace_back(node);
} else if ((type == ENTER) || (type == REFENTER)) {
enter_nodes_.emplace_back(node);
enter_nodes.emplace_back(node);
}
}

for (const auto &node : stream_switch_nodes_) {
for (const auto &node : stream_switch_nodes) {
for (const auto &out_ctrl_node : node->GetOutControlNodes()) {
GELOGD("branch_head_node %s of stream_switch %s.", out_ctrl_node->GetName().c_str(), node->GetName().c_str());
branch_head_nodes_[out_ctrl_node] = node;
branch_head_nodes[out_ctrl_node] = node;
}
need_label_nodes_.emplace_back(node);
need_label_nodes.emplace_back(node);
}
}

///
/// @brief update cond branch
/// @param [in] node
/// @param [in] branch_head_nodes
/// @return Status
///
Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) {
Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node,
const std::map<NodePtr, NodePtr> &branch_head_nodes) {
std::string stream_label;
if (AttachFlag(node, stream_label) != SUCCESS) {
GELOGE(FAILED, "Attach flag for node %s failed.", node->GetName().c_str());
@@ -103,8 +102,9 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) {
const std::string &type = cur_node->GetType();
for (const auto &out_node : cur_node->GetOutAllNodes()) {
const std::string &out_type = out_node->GetType();
const auto &iter = branch_head_nodes.find(node);
bool stop_flag = (end_type_set.count(out_type) > 0) ||
((branch_head_nodes_.count(out_node) > 0) && (branch_head_nodes_[out_node] != node)) ||
((iter != branch_head_nodes.end()) && (iter->second != node)) ||
(((type == ENTER) || (type == REFENTER)) && (out_type != STREAMACTIVE));
if (!stop_flag) {
nodes.push(out_node);
@@ -178,11 +178,12 @@ Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &strea

///
/// @brief Update stream_label start with enter nodes
/// @param [in] enter_nodes
/// @return Status
///
Status AttachStreamLabelPass::UpdateEnterNode() {
Status AttachStreamLabelPass::UpdateEnterNode(const std::vector<NodePtr> &enter_nodes) {
std::unordered_map<NodePtr, std::vector<NodePtr>> enter_active_map;
for (const auto &enter_node : enter_nodes_) {
for (const auto &enter_node : enter_nodes) {
for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) {
if (out_ctrl_node->GetType() != STREAMACTIVE) {
continue;
@@ -214,11 +215,11 @@ Status AttachStreamLabelPass::UpdateEnterNode() {
return INTERNAL_ERROR;
}

std::stack<NodePtr> enter_nodes;
std::stack<NodePtr> nodes;
for (const auto &enter_node : pair.second) {
enter_nodes.emplace(enter_node);
nodes.emplace(enter_node);
}
if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) {
if (UpdateLoopBranch(nodes, active_label_list[0]) != SUCCESS) {
GELOGE(FAILED, "Update stream_label for loop_branch failed.");
return FAILED;
}


+ 9
- 14
ge/graph/passes/attach_stream_label_pass.h View File

@@ -25,26 +25,25 @@ class AttachStreamLabelPass : public GraphPass {
public:
Status Run(ComputeGraphPtr graph);

///
/// @brief Clear Status, used for subgraph pass
/// @return
///
Status ClearStatus() override;

private:
///
/// @brief Find StreamSwitch / StreamMerge / Enter node
/// @param [in] graph
/// @param [out] need_label_nodes
/// @param [out] enter_nodes
/// @param [out] branch_head_nodes
/// @return void
///
void FindNodes(const ComputeGraphPtr &graph);
void FindNodes(const ComputeGraphPtr &graph, std::vector<NodePtr> &need_label_nodes,
std::vector<NodePtr> &enter_nodes, std::map<NodePtr, NodePtr> &branch_head_nodes);

///
/// @brief update cond branch
/// @param [in] node
/// @param [in] branch_head_nodes
/// @return Status
///
Status UpdateCondBranch(const NodePtr &node);
Status UpdateCondBranch(const NodePtr &node, const std::map<NodePtr, NodePtr> &branch_head_nodes);

///
/// @brief attach flag
@@ -64,9 +63,10 @@ class AttachStreamLabelPass : public GraphPass {

///
/// @brief Update stream_label start with enter nodes
/// @param [in] enter_nodes
/// @return Status
///
Status UpdateEnterNode();
Status UpdateEnterNode(const std::vector<NodePtr> &enter_nodes);

///
/// @brief Set stream_label for enter_nodes
@@ -75,11 +75,6 @@ class AttachStreamLabelPass : public GraphPass {
/// @return Status
///
static Status SetEnterLabel(const std::vector<NodePtr> &enter_nodes, const NodePtr &active_node);

std::vector<NodePtr> stream_switch_nodes_;
std::vector<NodePtr> need_label_nodes_;
std::vector<NodePtr> enter_nodes_;
std::unordered_map<NodePtr, NodePtr> branch_head_nodes_;
};
} // namespace ge
#endif // GE_GRAPH_PASSES_ATTACH_STREAM_LABEL_PASS_H_

Loading…
Cancel
Save