From d4828ea130d310773161d5f1b8ccc313f283cd1a Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Sat, 26 Jun 2021 15:00:53 +0800 Subject: [PATCH] UpdatePersistTensor from ExecutionEngine --- ge/hybrid/executor/node_state.cc | 4 ++++ ge/hybrid/executor/worker/execution_engine.cc | 1 + ge/hybrid/node_executor/node_executor.cc | 1 - ge/hybrid/node_executor/task_context.cc | 11 +---------- ge/hybrid/node_executor/task_context.h | 1 - 5 files changed, 6 insertions(+), 12 deletions(-) diff --git a/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc index 4b0d0c44..7ab7b536 100644 --- a/ge/hybrid/executor/node_state.cc +++ b/ge/hybrid/executor/node_state.cc @@ -333,6 +333,10 @@ void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { return std::any_of(items.begin(), items.end(), is_exist); }; + if (root_tensor_values_.count(input_idx) > 0) { + return; + } + if (is_persist_tensor(node_item_->root_data_, input_idx)) { GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); root_tensor_values_[input_idx] = tensor; diff --git a/ge/hybrid/executor/worker/execution_engine.cc b/ge/hybrid/executor/worker/execution_engine.cc index 8eecbc80..d4c73f58 100755 --- a/ge/hybrid/executor/worker/execution_engine.cc +++ b/ge/hybrid/executor/worker/execution_engine.cc @@ -375,6 +375,7 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state, RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start"); GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.", node_state.GetName().c_str()); + node_state.UpdatePersistTensor(); RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End"); GELOGD("[%s] Done task preparation successfully.", node_state.GetName().c_str()); diff --git a/ge/hybrid/node_executor/node_executor.cc b/ge/hybrid/node_executor/node_executor.cc index eeb5ba20..9e9354d9 100755 --- a/ge/hybrid/node_executor/node_executor.cc +++ b/ge/hybrid/node_executor/node_executor.cc @@ -39,7 +39,6 @@ const char *const kEngineNameHostCpu = "DNN_VM_HOST_CPU_OP_STORE"; Status NodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs()); GE_CHK_STATUS_RET_NOLOG(context.AllocateWorkspaces()); - GE_CHK_STATUS_RET_NOLOG(context.UpdatePersistTensor()); GE_CHK_STATUS_RET_NOLOG(task.UpdateArgs(context)); return SUCCESS; } diff --git a/ge/hybrid/node_executor/task_context.cc b/ge/hybrid/node_executor/task_context.cc index 3c288981..4ecc1558 100644 --- a/ge/hybrid/node_executor/task_context.cc +++ b/ge/hybrid/node_executor/task_context.cc @@ -460,22 +460,12 @@ Status TaskContext::PropagateOutputs() { subgraph_context_->all_inputs_[input_offset].SetName( node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx)); } - - auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); - GE_CHECK_NOTNULL(dst_node_state); - dst_node_state->SavePersistTensor(dst_input_idx, *tensor); } } (void)guard; return SUCCESS; } -Status TaskContext::UpdatePersistTensor() { - GE_CHECK_NOTNULL(node_state_); - node_state_->UpdatePersistTensor(); - return SUCCESS; -} - const void *TaskContext::GetVarBaseAddr() { return execution_context_->model->GetVarMemBase(); } @@ -501,6 +491,7 @@ void TaskContext::ReleaseInputsAndOutputs() { void TaskContext::ReleaseInput(int index) { auto input_tensor = MutableInput(index); if (input_tensor != nullptr) { + node_state_->SavePersistTensor(index, *input_tensor); input_tensor->Destroy(); GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index); } diff --git a/ge/hybrid/node_executor/task_context.h b/ge/hybrid/node_executor/task_context.h index cff5d680..c96e194e 100644 --- a/ge/hybrid/node_executor/task_context.h +++ b/ge/hybrid/node_executor/task_context.h @@ -78,7 +78,6 @@ class TaskContext { Status AllocateOutputs(AllocationAttr *attr = nullptr); Status AllocateWorkspaces(); Status AllocateWorkspace(size_t size, void **buffer, void *ori_addr = nullptr); - Status UpdatePersistTensor(); bool IsTraceEnabled() const;