@@ -387,6 +387,9 @@ void DynamicShapePartitioner::MergeClustersUnknownShape() { | |||||
if (!in_cluster->IsUnknownShape()) { | if (!in_cluster->IsUnknownShape()) { | ||||
continue; | continue; | ||||
} | } | ||||
if (!cluster->IsAdjoinNodes(in_cluster)) { | |||||
continue; | |||||
} | |||||
auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); | auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); | ||||
GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(), | GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(), | ||||
ToString(merged_clusters).c_str()); | ToString(merged_clusters).c_str()); | ||||
@@ -80,6 +80,10 @@ class DynamicShapePartitioner { | |||||
Status BuildPartitionSubgraph(); | Status BuildPartitionSubgraph(); | ||||
// Clear resource and break circular dependency | // Clear resource and break circular dependency | ||||
void Clear(); | void Clear(); | ||||
bool IsAdjoinNodes(const std::shared_ptr<Cluster> &other) const { | |||||
const auto &out_clusters = other->out_clusters_; | |||||
return std::find(out_clusters.begin(), out_clusters.end(), shared_from_this()) != out_clusters.end(); | |||||
} | |||||
private: | private: | ||||
static thread_local size_t unique_id_; | static thread_local size_t unique_id_; | ||||
@@ -537,6 +537,7 @@ Status SubgraphExecutor::LaunchTasks() { | |||||
Status SubgraphExecutor::ScheduleTasks(int group) { | Status SubgraphExecutor::ScheduleTasks(int group) { | ||||
GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str()); | GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str()); | ||||
subgraph_context_->SetGroup(group); | |||||
auto prepare_future = std::async(std::launch::async, [&]() -> Status { | auto prepare_future = std::async(std::launch::async, [&]() -> Status { | ||||
GetContext().SetSessionId(context_->session_id); | GetContext().SetSessionId(context_->session_id); | ||||
GetContext().SetContextId(context_->context_id); | GetContext().SetContextId(context_->context_id); | ||||
@@ -401,6 +401,11 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||||
if (is_root_node_) { | if (is_root_node_) { | ||||
node_item->root_data_.emplace(this); | node_item->root_data_.emplace(this); | ||||
} | } | ||||
// If Enter feed Not Merge, take as root Node. | |||||
if ((kEnterOpTypes.count(node_type) > 0) && (node_item->node_type != STREAMMERGE)) { | |||||
node_item->root_data_.emplace(this); | |||||
node_item->enter_inside_.emplace(anchor_index); | |||||
} | |||||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | ||||
} | } | ||||
@@ -142,6 +142,7 @@ struct NodeItem { | |||||
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | ||||
std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | ||||
std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to | std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to | ||||
std::set<int> enter_inside_; // Enter feed loop inside Node, Not cross Merge. | |||||
std::shared_ptr<NodeTask> kernel_task; | std::shared_ptr<NodeTask> kernel_task; | ||||
std::unique_ptr<FusedSubgraph> fused_subgraph; | std::unique_ptr<FusedSubgraph> fused_subgraph; | ||||
@@ -489,6 +489,11 @@ void TaskContext::ReleaseInputsAndOutputs() { | |||||
} | } | ||||
void TaskContext::ReleaseInput(int index) { | void TaskContext::ReleaseInput(int index) { | ||||
if (node_item_->enter_inside_.count(index) > 0) { | |||||
GELOGD("[%s] Tensor of input[%d] is enter, keep it", GetNodeName(), index); | |||||
return; | |||||
} | |||||
auto input_tensor = MutableInput(index); | auto input_tensor = MutableInput(index); | ||||
if (input_tensor != nullptr) { | if (input_tensor != nullptr) { | ||||
input_tensor->Destroy(); | input_tensor->Destroy(); | ||||