diff --git a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc index 74babadc..a4095c1b 100644 --- a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc @@ -16,8 +16,6 @@ #include "mark_force_unknown_for_cond_pass.h" -#include - #include "graph/utils/node_utils.h" #include "graph/common/omg_util.h" @@ -26,17 +24,7 @@ namespace { inline bool IsMergeInLoop(const NodePtr &node) { const static std::set kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; - std::string node_type; - (void)GetOriginalType(node, node_type); - return kLoopMergeInputs.count(node_type) > 0; -} - -inline bool IsSwitchInLoop(const NodePtr &node) { - const static std::set kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND }; - - std::string node_type; - (void)GetOriginalType(node, node_type); - return kLoopSwitchInputs.count(node_type) > 0; + return kLoopMergeInputs.count(NodeUtils::GetNodeType(node)) > 0; } } @@ -44,10 +32,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { GELOGD("MarkForceUnknownForCondPass Enter"); std::map> switch_groups; for (const auto &node : graph->GetDirectNode()) { - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(node, node_type), - "[Get][OriginalType] of node in graph:%s failed.", graph->GetName().c_str()); - if (kMergeOpTypes.count(node_type) == 0) { + if (kMergeOpTypes.count(NodeUtils::GetNodeType(node)) == 0) { continue; } @@ -65,6 +50,51 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { } /// +/// @brief Deal with Switch node for LoopCond +/// @param [in] Switch node +/// @param [in] dest span +/// @param [out] Search queue +/// @return true: Switch In while loop / false: Not in while Loop. +/// +bool MarkForceUnknownForCondPass::DealWithLoopSwitch(const NodePtr &node, uint32_t dst_span, + std::queue> search_queue) { + /// LoopCond --->\. + /// \. + /// Enter-----------+ \. + /// +--> Merge --> Switch --> Exit + /// NextIteration---+ + const auto is_loop_op = [](const NodePtr &n) { + return NodeUtils::GetNodeType(n) == LOOPCOND; + }; + const auto is_exit_op = [](const NodePtr &n) { + return kExitOpTypes.count(NodeUtils::GetNodeType(n)) > 0; + }; + + const auto src_nodes = node->GetInAllNodes(); + const auto dst_nodes = node->GetOutAllNodes(); + if (std::none_of(src_nodes.begin(), src_nodes.end(), is_loop_op) && + std::none_of(dst_nodes.begin(), dst_nodes.end(), is_exit_op)) { + return false; + } + + for (const auto &m : src_nodes) { + if (kMergeOpTypes.count(NodeUtils::GetNodeType(m)) > 0) { + for (const auto &n : m->GetInAllNodes()) { + if (kNextIterationOpTypes.count(NodeUtils::GetNodeType(n)) > 0) { + continue; + } + + search_queue.push({n, dst_span}); + GELOGD("Travel in Loop: %s <-- %s <-- %s, span is: %u", node->GetName().c_str(), m->GetName().c_str(), + n->GetName().c_str(), dst_span); + } + } + } + + return true; +} + +/// /// @brief Mark force unknown shape for Switch node /// @param [in] merge node /// @param [out] switch group @@ -72,6 +102,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { /// void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector &switch_group) { // Switch --> {Switch --> Merge} --> Merge + GELOGD("Search Switch node for Merge: %s", node->GetName().c_str()); std::unordered_set nodes_seen; std::queue> search_queue({{node, 0}}); while (!search_queue.empty()) { @@ -79,43 +110,25 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: const auto dst_span = search_queue.front().second; search_queue.pop(); - // Switch --> Identity --> Constant - for (const auto &in_node : dst_node->GetInControlNodes()) { - if (nodes_seen.count(in_node) > 0) { - GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); - continue; - } - nodes_seen.insert(in_node); - - if (in_node->GetType() == IDENTITY) { - GELOGD("Travel node: %s, In control: %s, span is: %u", dst_node->GetName().c_str(), - in_node->GetName().c_str(), dst_span); - search_queue.push({in_node, dst_span}); - } - } - - for (const auto &in_node : dst_node->GetInDataNodes()) { + for (const auto &in_node : dst_node->GetInAllNodes()) { if (nodes_seen.count(in_node) > 0) { GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); continue; } nodes_seen.insert(in_node); - std::string node_type; - (void)GetOriginalType(in_node, node_type); + const std::string node_type = NodeUtils::GetNodeType(in_node); GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(), in_node->GetName().c_str(), dst_span); if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node. + if (DealWithLoopSwitch(in_node, dst_span, search_queue)) { + continue; + } + if (dst_span > 0) { search_queue.push({in_node, dst_span - 1}); } else { - const auto &all_in_nodes = in_node->GetInDataNodes(); - if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) { - GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(), - in_node->GetName().c_str()); - } else { - switch_group.emplace_back(in_node); - } + switch_group.emplace_back(in_node); } } else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. search_queue.push({in_node, dst_span + 1}); diff --git a/ge/graph/passes/mark_force_unknown_for_cond_pass.h b/ge/graph/passes/mark_force_unknown_for_cond_pass.h index 528a8fdc..d2be9a9e 100644 --- a/ge/graph/passes/mark_force_unknown_for_cond_pass.h +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.h @@ -19,6 +19,8 @@ #include "inc/graph_pass.h" +#include + namespace ge { class MarkForceUnknownForCondPass : public GraphPass { public: @@ -26,6 +28,15 @@ class MarkForceUnknownForCondPass : public GraphPass { private: /// + /// @brief Deal with Switch node for LoopCond + /// @param [in] Switch node + /// @param [in] dest span + /// @param [out] Search queue + /// @return true: Switch In while loop / false: Not in while Loop. + /// + bool DealWithLoopSwitch(const NodePtr &node, uint32_t dst_span, std::queue> search_queue); + + /// /// @brief Mark force unknown shape for Switch node /// @param [in] merge node /// @param [out] switch group diff --git a/ge/graph/passes/switch_to_stream_switch_pass.cc b/ge/graph/passes/switch_to_stream_switch_pass.cc index e4ab0111..1a47c14b 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.cc +++ b/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -395,8 +395,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); int64_t group_index = -1; - (void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); - SetControlFlowGroup(stream_switch, group_index); + if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { + SetControlFlowGroup(stream_switch, group_index); + } return stream_switch; } diff --git a/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc index 468c84e6..4b0d0c44 100644 --- a/ge/hybrid/executor/node_state.cc +++ b/ge/hybrid/executor/node_state.cc @@ -326,17 +326,37 @@ std::shared_ptr NodeState::GetTaskContext() { } void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { - if (node_item_->root_data_.count(input_idx) > 0) { + const auto is_persist_tensor = [](const std::map> &items, int idx) { + const auto is_exist = [&idx](const std::pair> &items) { + return items.second.count(idx) > 0; + }; + return std::any_of(items.begin(), items.end(), is_exist); + }; + + if (is_persist_tensor(node_item_->root_data_, input_idx)) { GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); root_tensor_values_[input_idx] = tensor; - } - - if (node_item_->enter_data_.count(input_idx) > 0) { + } else if (is_persist_tensor(node_item_->enter_data_, input_idx)) { GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); root_tensor_values_[input_idx] = tensor; } } +void NodeState::UpdatePersistTensor() { + const auto update_tensor = [&](const std::map> &items) { + for (const auto &item : items) { + for (const auto idx : item.second) { + UpdatePersistTensor(idx); + } + } + }; + + update_tensor(node_item_->root_data_); + if (iteration_count_ > 0) { + update_tensor(node_item_->enter_data_); + } +} + void NodeState::UpdatePersistTensor(int input_idx) { const auto it = root_tensor_values_.find(input_idx); if (it == root_tensor_values_.end()) { @@ -363,16 +383,9 @@ void NodeState::ResetContext(uint64_t iteration) { data_scheduled_ = static_cast(node_item_->root_data_.size()); ctrl_scheduled_ = static_cast(node_item_->root_ctrl_.size()); - for (auto item : node_item_->root_data_) { - UpdatePersistTensor(item.first); - } - if (iteration > 0) { data_scheduled_ += static_cast(node_item_->enter_data_.size()); ctrl_scheduled_ += static_cast(node_item_->enter_ctrl_.size()); - for (auto item : node_item_->enter_data_) { - UpdatePersistTensor(item.first); - } } iteration_count_ = iteration; diff --git a/ge/hybrid/executor/node_state.h b/ge/hybrid/executor/node_state.h index b80b60b0..1ec8517e 100644 --- a/ge/hybrid/executor/node_state.h +++ b/ge/hybrid/executor/node_state.h @@ -132,6 +132,7 @@ struct NodeState { void RunNextIteration(); void SavePersistTensor(int input_idx, const TensorValue &tensor); + void UpdatePersistTensor(); Status NodeScheduled(const std::function &ready) const; diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc index 250562ce..8e87c6e2 100644 --- a/ge/hybrid/model/node_item.cc +++ b/ge/hybrid/model/node_item.cc @@ -395,11 +395,13 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { data_send_.emplace(node_item); node_item->data_recv_[this] = anchor_index; if (is_root_node_) { - node_item->root_data_[anchor_index] = this; + auto &data_anchors = node_item->root_data_[this]; + data_anchors.emplace(anchor_index); } // If Enter feed Not Merge, take as root Node. if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { - node_item->enter_data_[anchor_index] = this; + auto &data_anchors = node_item->enter_data_[this]; + data_anchors.emplace(anchor_index); } GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); } diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index 12775b00..f6dcdcf6 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -148,9 +148,9 @@ struct NodeItem { int64_t frame_index_ = -1; int64_t parent_frame_ = -1; std::set root_ctrl_; // Recv ctrl from root node - std::map root_data_; // Recv data from root node + std::map> root_data_; // Recv data from root node std::set enter_ctrl_; // Recv ctrl from Enter node - std::map enter_data_; // Recv data from Enter node + std::map> enter_data_; // Recv data from Enter node std::set data_send_; // Send data notify to std::map data_recv_; // Recv data notify from std::set ctrl_send_; // Send ctrl notify to diff --git a/ge/hybrid/node_executor/node_executor.cc b/ge/hybrid/node_executor/node_executor.cc index 9e9354d9..eeb5ba20 100755 --- a/ge/hybrid/node_executor/node_executor.cc +++ b/ge/hybrid/node_executor/node_executor.cc @@ -39,6 +39,7 @@ const char *const kEngineNameHostCpu = "DNN_VM_HOST_CPU_OP_STORE"; Status NodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs()); GE_CHK_STATUS_RET_NOLOG(context.AllocateWorkspaces()); + GE_CHK_STATUS_RET_NOLOG(context.UpdatePersistTensor()); GE_CHK_STATUS_RET_NOLOG(task.UpdateArgs(context)); return SUCCESS; } diff --git a/ge/hybrid/node_executor/task_context.cc b/ge/hybrid/node_executor/task_context.cc index fe580c1e..7ff83ce0 100644 --- a/ge/hybrid/node_executor/task_context.cc +++ b/ge/hybrid/node_executor/task_context.cc @@ -468,6 +468,12 @@ Status TaskContext::PropagateOutputs() { return SUCCESS; } +Status TaskContext::UpdatePersistTensor() { + GE_CHECK_NOTNULL(node_state_); + node_state_->UpdatePersistTensor(); + return SUCCESS; +} + const void *TaskContext::GetVarBaseAddr() { return execution_context_->model->GetVarMemBase(); } diff --git a/ge/hybrid/node_executor/task_context.h b/ge/hybrid/node_executor/task_context.h index c96e194e..cff5d680 100644 --- a/ge/hybrid/node_executor/task_context.h +++ b/ge/hybrid/node_executor/task_context.h @@ -78,6 +78,7 @@ class TaskContext { Status AllocateOutputs(AllocationAttr *attr = nullptr); Status AllocateWorkspaces(); Status AllocateWorkspace(size_t size, void **buffer, void *ori_addr = nullptr); + Status UpdatePersistTensor(); bool IsTraceEnabled() const;