|
|
@@ -36,6 +36,7 @@ struct DuringPassNodeSets { |
|
|
|
std::unordered_set<NodePtr> nodes_re_pass; |
|
|
|
std::unordered_set<NodePtr> nodes_re_pass_immediately; |
|
|
|
std::unordered_set<NodePtr> nodes_last; |
|
|
|
std::unordered_set<NodePtr> nodes_suspend; |
|
|
|
}; |
|
|
|
|
|
|
|
void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes, |
|
|
@@ -55,8 +56,25 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &i |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool HasIterationSuspendInputs(const NodePtr &node, const std::unordered_set<NodePtr> &nodes_suspend) { |
|
|
|
if (nodes_suspend.empty()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
const auto &nodes = node->GetInAllNodes(); |
|
|
|
if (std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { return nodes_suspend.count(n) > 0; })) { |
|
|
|
GELOGI("The node %s has suspended input node, the iteration of it will be suspend.", node->GetName().c_str()); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass, |
|
|
|
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_last) { |
|
|
|
DuringPassNodeSets &during_pass_node_set) { |
|
|
|
auto &nodes_seen = during_pass_node_set.nodes_seen; |
|
|
|
const auto &nodes_last = during_pass_node_set.nodes_last; |
|
|
|
const auto &nodes_suspend = during_pass_node_set.nodes_suspend; |
|
|
|
for (auto &node : nodes) { |
|
|
|
if (node == nullptr) { |
|
|
|
continue; |
|
|
@@ -64,16 +82,58 @@ void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &n |
|
|
|
if (nodes_last.count(node) != 0) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (nodes_suspend.count(node) > 0) { |
|
|
|
GELOGD("The node %s has suspend by pass, skip it.", node->GetName().c_str()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); |
|
|
|
if (all_in_nodes_seen && nodes_seen.insert(node.get()).second) { |
|
|
|
bool has_suspend_inputs = HasIterationSuspendInputs(node, nodes_suspend); |
|
|
|
if (all_in_nodes_seen && !has_suspend_inputs && nodes_seen.insert(node.get()).second) { |
|
|
|
nodes_to_pass.push_back(node); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PushToSuspendNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name, |
|
|
|
const std::unordered_set<NodePtr> &nodes_suspend, |
|
|
|
const std::unordered_set<NodePtr> &nodes_resume) { |
|
|
|
for (const auto &node : nodes_suspend) { |
|
|
|
GELOGD("The iteration suspend of node %s has been set by pass %s", node->GetName().c_str(), pass_name.c_str()); |
|
|
|
during_pass_node_set.nodes_suspend.emplace(node); |
|
|
|
} |
|
|
|
|
|
|
|
for (const auto &node : nodes_resume) { |
|
|
|
GELOGD("The iteration suspend of node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str()); |
|
|
|
const auto it = during_pass_node_set.nodes_suspend.find(node); |
|
|
|
if (it != during_pass_node_set.nodes_suspend.end()) { |
|
|
|
during_pass_node_set.nodes_suspend.erase(node); |
|
|
|
} else { |
|
|
|
GELOGW("The iteration resumed node %s not suspend by pass %s", node->GetName().c_str(), pass_name.c_str()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void AddNodesRepassToQueue(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) { |
|
|
|
for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) { |
|
|
|
GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str()); |
|
|
|
nodes.push_front(node); |
|
|
|
} |
|
|
|
during_pass_node_set.nodes_re_pass_immediately.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
void AddNodesLastToQueue(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) { |
|
|
|
for (auto &node : during_pass_node_set.nodes_last) { |
|
|
|
bool all_in_nodes_seen = node->IsAllInNodesSeen(during_pass_node_set.nodes_seen); |
|
|
|
if (all_in_nodes_seen && during_pass_node_set.nodes_seen.insert(node.get()).second) { |
|
|
|
nodes.push_back(node); |
|
|
|
} |
|
|
|
} |
|
|
|
during_pass_node_set.nodes_last.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass, |
|
|
|
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_to_re_pass, |
|
|
|
std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass, |
|
|
|
std::unordered_set<NodePtr> &nodes_re_pass) { |
|
|
|
for (const auto &node_to_re_pass : nodes_to_re_pass) { |
|
|
|
if (node_to_re_pass == nullptr) { |
|
|
@@ -113,15 +173,18 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo |
|
|
|
return result; |
|
|
|
} |
|
|
|
|
|
|
|
auto nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); |
|
|
|
const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); |
|
|
|
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, |
|
|
|
during_pass_node_set.nodes_re_pass); |
|
|
|
|
|
|
|
auto nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); |
|
|
|
const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); |
|
|
|
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, |
|
|
|
during_pass_node_set.nodes_re_pass_immediately); |
|
|
|
|
|
|
|
auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); |
|
|
|
PushToSuspendNodes(during_pass_node_set, name_to_pass.first, |
|
|
|
name_to_pass.second->GetNodesSuspend(), name_to_pass.second->GetNodesResume()); |
|
|
|
|
|
|
|
const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); |
|
|
|
during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); |
|
|
|
if (nodes_deleted_by_pass.count(node) > 0) { |
|
|
|
GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), |
|
|
@@ -221,8 +284,13 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { |
|
|
|
GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (during_pass_node_set.nodes_suspend.count(node) > 0) { |
|
|
|
GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", |
|
|
|
node->GetName().c_str()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); |
|
|
|
AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set); |
|
|
|
|
|
|
|
auto ret = RunPasses(node, names_to_passes, during_pass_node_set); |
|
|
|
if (ret != SUCCESS) { |
|
|
@@ -253,22 +321,18 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { |
|
|
|
// should be called each time at the begin of the iteration |
|
|
|
ClearOption(names_to_passes); |
|
|
|
} |
|
|
|
for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) { |
|
|
|
GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str()); |
|
|
|
nodes.push_front(node); |
|
|
|
} |
|
|
|
during_pass_node_set.nodes_re_pass_immediately.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
for (auto &node : during_pass_node_set.nodes_last) { |
|
|
|
bool all_in_nodes_seen = node->IsAllInNodesSeen(during_pass_node_set.nodes_seen); |
|
|
|
if (all_in_nodes_seen && during_pass_node_set.nodes_seen.insert(node.get()).second) { |
|
|
|
nodes.push_back(node); |
|
|
|
} |
|
|
|
AddNodesRepassToQueue(during_pass_node_set, nodes); |
|
|
|
} |
|
|
|
during_pass_node_set.nodes_last.clear(); |
|
|
|
|
|
|
|
AddNodesLastToQueue(during_pass_node_set, nodes); |
|
|
|
} while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes); |
|
|
|
|
|
|
|
if (!during_pass_node_set.nodes_suspend.empty()) { |
|
|
|
GELOGE(FAILED, "After iteration some nodes still suspend, size: %zu", during_pass_node_set.nodes_suspend.size()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (re_pass_times == kMaxRePassTimes) { |
|
|
|
GELOGW("re_pass_times should not come to %d", kMaxRePassTimes); |
|
|
|
} |
|
|
|