diff --git a/ge/graph/partition/dynamic_shape_partition.cc b/ge/graph/partition/dynamic_shape_partition.cc
index 8fee1eb5..5a441b86 100755
--- a/ge/graph/partition/dynamic_shape_partition.cc
+++ b/ge/graph/partition/dynamic_shape_partition.cc
@@ -387,6 +387,9 @@ void DynamicShapePartitioner::MergeClustersUnknownShape() {
       if (!in_cluster->IsUnknownShape()) {
         continue;
       }
+      if (!cluster->IsAdjoinNodes(in_cluster)) {
+        continue;
+      }
       auto merged_clusters = cluster->MergeAllPathFrom(in_cluster);
       GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(),
              ToString(merged_clusters).c_str());
diff --git a/ge/graph/partition/dynamic_shape_partition.h b/ge/graph/partition/dynamic_shape_partition.h
index f1d711eb..a17c4e4b 100644
--- a/ge/graph/partition/dynamic_shape_partition.h
+++ b/ge/graph/partition/dynamic_shape_partition.h
@@ -80,6 +80,10 @@ class DynamicShapePartitioner {
     Status BuildPartitionSubgraph();
     // Clear resource and break circular dependency
    void Clear();
+    bool IsAdjoinNodes(const std::shared_ptr<Cluster> &other) const {
+      const auto &out_clusters = other->out_clusters_;
+      return std::find(out_clusters.begin(), out_clusters.end(), shared_from_this()) != out_clusters.end();
+    }
 
    private:
     static thread_local size_t unique_id_;
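The two partition hunks above gate MergeAllPathFrom on direct adjacency: an unknown-shape input cluster is merged only when the current cluster appears in that input's successor list, so the merge can no longer pull in clusters that reach it only through an indirect path. A minimal standalone sketch of the membership test follows; the Cluster type here is an illustrative stand-in for the partitioner's nested Cluster class, reduced to the one field the check needs:

    #include <algorithm>
    #include <memory>
    #include <vector>

    // Illustrative stand-in for the partitioner's Cluster type.
    struct Cluster : std::enable_shared_from_this<Cluster> {
      std::vector<std::shared_ptr<Cluster>> out_clusters_;  // direct successors

      // True when `other` is a direct predecessor of this cluster,
      // i.e. this cluster is listed among other's successors.
      bool IsAdjoinNodes(const std::shared_ptr<Cluster> &other) const {
        const auto &outs = other->out_clusters_;
        return std::find(outs.begin(), outs.end(), shared_from_this()) != outs.end();
      }
    };

    int main() {
      auto a = std::make_shared<Cluster>();
      auto b = std::make_shared<Cluster>();
      auto c = std::make_shared<Cluster>();
      a->out_clusters_ = {b};  // a -> b
      b->out_clusters_ = {c};  // b -> c
      // b may merge paths from a (adjacent); c may not (only indirectly reachable).
      return (b->IsAdjoinNodes(a) && !c->IsAdjoinNodes(a)) ? 0 : 1;
    }

Note that std::shared_ptr comparison is by stored pointer, so the std::find is a pure identity check against the successor list, which is why shared_from_this() works here even in a const member.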
diff --git a/ge/hybrid/executor/subgraph_executor.cc b/ge/hybrid/executor/subgraph_executor.cc
index d7c2c4b7..f40be058 100644
--- a/ge/hybrid/executor/subgraph_executor.cc
+++ b/ge/hybrid/executor/subgraph_executor.cc
@@ -537,6 +537,7 @@ Status SubgraphExecutor::LaunchTasks() {
 
 Status SubgraphExecutor::ScheduleTasks(int group) {
   GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str());
+  subgraph_context_->SetGroup(group);
   auto prepare_future = std::async(std::launch::async, [&]() -> Status {
     GetContext().SetSessionId(context_->session_id);
     GetContext().SetContextId(context_->context_id);
diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc
index 07c8038b..952df9a0 100644
--- a/ge/hybrid/model/node_item.cc
+++ b/ge/hybrid/model/node_item.cc
@@ -401,6 +401,11 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {
   if (is_root_node_) {
     node_item->root_data_.emplace(this);
   }
+  // If Enter feeds a node other than Merge, treat that node as a root node.
+  if ((kEnterOpTypes.count(node_type) > 0) && (node_item->node_type != STREAMMERGE)) {
+    node_item->root_data_.emplace(this);
+    node_item->enter_inside_.emplace(anchor_index);
+  }
   GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
 }
 
diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h
index af796753..67f92868 100644
--- a/ge/hybrid/model/node_item.h
+++ b/ge/hybrid/model/node_item.h
@@ -142,6 +142,7 @@ struct NodeItem {
   std::set<const NodeItem *> ctrl_send_;                      // Send ctrl notify to
   std::set<const NodeItem *> ctrl_recv_;                      // Recv ctrl notify from
   std::vector<std::vector<const NodeItem *>> switch_groups_;  // Send ctrl notify to
+  std::set<int> enter_inside_;                                // Input anchors fed by Enter inside the loop, not crossing Merge
 
   std::shared_ptr<NodeTask> kernel_task;
   std::unique_ptr<FusedSubgraph> fused_subgraph;
diff --git a/ge/hybrid/node_executor/task_context.cc b/ge/hybrid/node_executor/task_context.cc
index 59250d8c..14eb1222 100644
--- a/ge/hybrid/node_executor/task_context.cc
+++ b/ge/hybrid/node_executor/task_context.cc
@@ -489,6 +489,11 @@ void TaskContext::ReleaseInputsAndOutputs() {
 }
 
 void TaskContext::ReleaseInput(int index) {
+  if (node_item_->enter_inside_.count(index) > 0) {
+    GELOGD("[%s] Tensor of input[%d] is from Enter, keep it", GetNodeName(), index);
+    return;
+  }
+
   auto input_tensor = MutableInput(index);
   if (input_tensor != nullptr) {
     input_tensor->Destroy();
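The node_item.cc and task_context.cc hunks work as a pair: when an Enter node feeds a consumer other than StreamMerge, the consumer records that input anchor in enter_inside_, and ReleaseInput later refuses to destroy the tensor at such an index, so the Enter output stays valid across every iteration of the loop body. Below is a minimal sketch of the release-time guard; TensorValue, NodeItemLite, and TaskContextLite are simplified illustrative stand-ins, not the hybrid executor's real classes:

    #include <cstdio>
    #include <map>
    #include <set>

    // Illustrative stand-ins for the hybrid executor's types.
    struct TensorValue {
      bool alive = true;
      void Destroy() { alive = false; }
    };

    struct NodeItemLite {
      std::set<int> enter_inside_;  // input indices fed directly by Enter
    };

    struct TaskContextLite {
      NodeItemLite *node_item_;
      std::map<int, TensorValue> inputs_;

      void ReleaseInput(int index) {
        // Mirror of the patched guard: Enter-fed inputs must survive the
        // whole loop, so they are never destroyed on release.
        if (node_item_->enter_inside_.count(index) > 0) {
          return;
        }
        auto it = inputs_.find(index);
        if (it != inputs_.end()) {
          it->second.Destroy();
        }
      }
    };

    int main() {
      NodeItemLite item;
      item.enter_inside_.insert(0);  // input 0 comes from an Enter node
      TaskContextLite ctx{&item, {{0, TensorValue{}}, {1, TensorValue{}}}};
      ctx.ReleaseInput(0);  // kept
      ctx.ReleaseInput(1);  // destroyed
      std::printf("input0 alive=%d, input1 alive=%d\n",
                  static_cast<int>(ctx.inputs_[0].alive),
                  static_cast<int>(ctx.inputs_[1].alive));
      return 0;
    }

Running the sketch prints input0 alive=1, input1 alive=0: the Enter-fed tensor survives release while an ordinary input is freed, which is exactly the asymmetry the patch introduces.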