From 747ae4bbe11814505d8c0ce840c9fdfa4a26d82a Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Thu, 10 Jun 2021 20:46:28 +0800 Subject: [PATCH] Fix dynamic shape partition --- ge/graph/partition/dynamic_shape_partition.cc | 5 +++++ ge/hybrid/executor/node_state.cc | 10 +++++----- ge/hybrid/executor/node_state.h | 2 +- ge/hybrid/model/node_item.h | 1 - ge/hybrid/node_executor/aicore/aicore_op_task.cc | 6 +++--- 5 files changed, 14 insertions(+), 10 deletions(-) diff --git a/ge/graph/partition/dynamic_shape_partition.cc b/ge/graph/partition/dynamic_shape_partition.cc index a01fa62f..1db47498 100755 --- a/ge/graph/partition/dynamic_shape_partition.cc +++ b/ge/graph/partition/dynamic_shape_partition.cc @@ -381,6 +381,10 @@ void DynamicShapePartitioner::MergeClustersControlFlow() { bool is_unknown_cluster = cluster->IsUnknownShape(); for (++rit; rit != control_cluster.rend(); ++rit) { const auto &cluster_from = *rit; + if (all_merged_clusters.count(cluster_from) > 0) { + continue; + } + auto merged_clusters = cluster->MergeAllPathFrom(cluster_from); GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(), ToString(merged_clusters).c_str()); @@ -393,6 +397,7 @@ void DynamicShapePartitioner::MergeClustersControlFlow() { } if (!is_unknown_cluster && cluster->IsUnknownShape()) { + GELOGD("Add to ordered cluster: %s", cluster->DebugString().c_str()); ordered_cluster_.push_back(cluster); } } diff --git a/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc index ddded35e..c2760f04 100644 --- a/ge/hybrid/executor/node_state.cc +++ b/ge/hybrid/executor/node_state.cc @@ -320,18 +320,18 @@ std::shared_ptr NodeState::GetTaskContext() { void NodeState::SaveRootTensor(int input_idx, const TensorValue &tensor) { if (node_item_->root_data_.count(input_idx) > 0) { GELOGD("[%s] Save Const input tensor: %d", GetName().c_str(), input_idx); - root_tensor_value_[input_idx] = tensor; + root_tensor_values_[input_idx] = tensor; } if (node_item_->enter_data_.count(input_idx) > 0) { GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); - root_tensor_value_[input_idx] = tensor; + root_tensor_values_[input_idx] = tensor; } } void NodeState::UpdateRootTensor(int input_idx) { - const auto it = root_tensor_value_.find(input_idx); - if (it == root_tensor_value_.end()) { + const auto it = root_tensor_values_.find(input_idx); + if (it == root_tensor_values_.end()) { GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx); return; } @@ -343,7 +343,7 @@ void NodeState::UpdateRootTensor(int input_idx) { } *tensor = it->second; - GELOGW("[%s] Update input tensor: %d", GetName().c_str(), input_idx); + GELOGD("[%s] Update input tensor: %d", GetName().c_str(), input_idx); } void NodeState::ResetContext(uint64_t iteration) { diff --git a/ge/hybrid/executor/node_state.h b/ge/hybrid/executor/node_state.h index 98dfbf0b..c6468bbd 100644 --- a/ge/hybrid/executor/node_state.h +++ b/ge/hybrid/executor/node_state.h @@ -202,7 +202,7 @@ struct NodeState { std::future schedule_future_; std::shared_ptr frame_state_; - std::map root_tensor_value_; + std::map root_tensor_values_; uint64_t active_count_ = 0; uint64_t iteration_count_ = 0; uint32_t ctrl_scheduled_ = 0; diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index ae0b1b47..ec66f094 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -156,7 +156,6 @@ struct NodeItem { std::set ctrl_send_; // Send ctrl notify to std::set ctrl_recv_; // Recv ctrl notify from std::vector> switch_groups_; // Send ctrl notify to - std::set enter_inside_; // Enter feed loop inside Node, Not cross Merge. std::shared_ptr kernel_task; std::unique_ptr fused_subgraph; diff --git a/ge/hybrid/node_executor/aicore/aicore_op_task.cc b/ge/hybrid/node_executor/aicore/aicore_op_task.cc index 5ed57621..8cd24bd1 100644 --- a/ge/hybrid/node_executor/aicore/aicore_op_task.cc +++ b/ge/hybrid/node_executor/aicore/aicore_op_task.cc @@ -306,7 +306,7 @@ Status AiCoreOpTask::InitWithKernelDefWithHandle(const OpDesc &op_desc, const do } Status AiCoreOpTask::InitWithTaskDef(const OpDesc &op_desc, const domi::TaskDef &task_def) { - + auto rt_ret = ValidateTaskDef(task_def); if (rt_ret != SUCCESS) { REPORT_CALL_ERROR("E19999", "op:%s(op_type:%s) failed to validate task def:%s", @@ -315,7 +315,7 @@ Status AiCoreOpTask::InitWithTaskDef(const OpDesc &op_desc, const domi::TaskDef op_desc.GetName().c_str(), op_desc.GetType().c_str(), task_def.DebugString().c_str()); return rt_ret; } - + if (task_def.type() != RT_MODEL_TASK_ALL_KERNEL) { GE_CHK_STATUS_RET(InitWithKernelDef(op_desc, task_def)); } else { @@ -474,7 +474,7 @@ Status AiCoreOpTask::UpdateArgs(TaskContext &task_context) { if (task_context.IsTraceEnabled()) { for (int i = 0; i < index; ++i) { - GELOGD("[%s] Arg[%d] = %p", stub_name_.c_str(), i, arg_base_[i]); + GELOGD("[%s] Arg[%d] = %lu", stub_name_.c_str(), i, arg_base_[i]); } }