Browse Source

Fix dynamic shape partition

pull/1782/head
zhangxiaokun 4 years ago
parent
commit
747ae4bbe1
5 changed files with 14 additions and 10 deletions
  1. +5
    -0
      ge/graph/partition/dynamic_shape_partition.cc
  2. +5
    -5
      ge/hybrid/executor/node_state.cc
  3. +1
    -1
      ge/hybrid/executor/node_state.h
  4. +0
    -1
      ge/hybrid/model/node_item.h
  5. +3
    -3
      ge/hybrid/node_executor/aicore/aicore_op_task.cc

+ 5
- 0
ge/graph/partition/dynamic_shape_partition.cc View File

@@ -381,6 +381,10 @@ void DynamicShapePartitioner::MergeClustersControlFlow() {
bool is_unknown_cluster = cluster->IsUnknownShape();
for (++rit; rit != control_cluster.rend(); ++rit) {
const auto &cluster_from = *rit;
if (all_merged_clusters.count(cluster_from) > 0) {
continue;
}

auto merged_clusters = cluster->MergeAllPathFrom(cluster_from);
GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(),
ToString(merged_clusters).c_str());
@@ -393,6 +397,7 @@ void DynamicShapePartitioner::MergeClustersControlFlow() {
}

if (!is_unknown_cluster && cluster->IsUnknownShape()) {
GELOGD("Add to ordered cluster: %s", cluster->DebugString().c_str());
ordered_cluster_.push_back(cluster);
}
}


+ 5
- 5
ge/hybrid/executor/node_state.cc View File

@@ -320,18 +320,18 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() {
void NodeState::SaveRootTensor(int input_idx, const TensorValue &tensor) {
if (node_item_->root_data_.count(input_idx) > 0) {
GELOGD("[%s] Save Const input tensor: %d", GetName().c_str(), input_idx);
root_tensor_value_[input_idx] = tensor;
root_tensor_values_[input_idx] = tensor;
}

if (node_item_->enter_data_.count(input_idx) > 0) {
GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx);
root_tensor_value_[input_idx] = tensor;
root_tensor_values_[input_idx] = tensor;
}
}

void NodeState::UpdateRootTensor(int input_idx) {
const auto it = root_tensor_value_.find(input_idx);
if (it == root_tensor_value_.end()) {
const auto it = root_tensor_values_.find(input_idx);
if (it == root_tensor_values_.end()) {
GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx);
return;
}
@@ -343,7 +343,7 @@ void NodeState::UpdateRootTensor(int input_idx) {
}

*tensor = it->second;
GELOGW("[%s] Update input tensor: %d", GetName().c_str(), input_idx);
GELOGD("[%s] Update input tensor: %d", GetName().c_str(), input_idx);
}

void NodeState::ResetContext(uint64_t iteration) {


+ 1
- 1
ge/hybrid/executor/node_state.h View File

@@ -202,7 +202,7 @@ struct NodeState {

std::future<Status> schedule_future_;
std::shared_ptr<FrameState> frame_state_;
std::map<int, TensorValue> root_tensor_value_;
std::map<int, TensorValue> root_tensor_values_;
uint64_t active_count_ = 0;
uint64_t iteration_count_ = 0;
uint32_t ctrl_scheduled_ = 0;


+ 0
- 1
ge/hybrid/model/node_item.h View File

@@ -156,7 +156,6 @@ struct NodeItem {
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to
std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from
std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to
std::set<int> enter_inside_; // Enter feed loop inside Node, Not cross Merge.

std::shared_ptr<NodeTask> kernel_task;
std::unique_ptr<FusedSubgraph> fused_subgraph;


+ 3
- 3
ge/hybrid/node_executor/aicore/aicore_op_task.cc View File

@@ -306,7 +306,7 @@ Status AiCoreOpTask::InitWithKernelDefWithHandle(const OpDesc &op_desc, const do
}

Status AiCoreOpTask::InitWithTaskDef(const OpDesc &op_desc, const domi::TaskDef &task_def) {
auto rt_ret = ValidateTaskDef(task_def);
if (rt_ret != SUCCESS) {
REPORT_CALL_ERROR("E19999", "op:%s(op_type:%s) failed to validate task def:%s",
@@ -315,7 +315,7 @@ Status AiCoreOpTask::InitWithTaskDef(const OpDesc &op_desc, const domi::TaskDef
op_desc.GetName().c_str(), op_desc.GetType().c_str(), task_def.DebugString().c_str());
return rt_ret;
}
if (task_def.type() != RT_MODEL_TASK_ALL_KERNEL) {
GE_CHK_STATUS_RET(InitWithKernelDef(op_desc, task_def));
} else {
@@ -474,7 +474,7 @@ Status AiCoreOpTask::UpdateArgs(TaskContext &task_context) {

if (task_context.IsTraceEnabled()) {
for (int i = 0; i < index; ++i) {
GELOGD("[%s] Arg[%d] = %p", stub_name_.c_str(), i, arg_base_[i]);
GELOGD("[%s] Arg[%d] = %lu", stub_name_.c_str(), i, arg_base_[i]);
}
}



Loading…
Cancel
Save