Browse Source

Call UpdatePersistTensor from ExecutionEngine

tags/v1.5.1
zhangxiaokun 4 years ago
parent
commit
d4828ea130
5 changed files with 6 additions and 12 deletions
  1. +4
    -0
      ge/hybrid/executor/node_state.cc
  2. +1
    -0
      ge/hybrid/executor/worker/execution_engine.cc
  3. +0
    -1
      ge/hybrid/node_executor/node_executor.cc
  4. +1
    -10
      ge/hybrid/node_executor/task_context.cc
  5. +0
    -1
      ge/hybrid/node_executor/task_context.h

+ 4
- 0
ge/hybrid/executor/node_state.cc View File

@@ -333,6 +333,10 @@ void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) {
return std::any_of(items.begin(), items.end(), is_exist);
};

if (root_tensor_values_.count(input_idx) > 0) {
return;
}

if (is_persist_tensor(node_item_->root_data_, input_idx)) {
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx);
root_tensor_values_[input_idx] = tensor;


+ 1
- 0
ge/hybrid/executor/worker/execution_engine.cc View File

@@ -375,6 +375,7 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state,
RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start");
GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.",
node_state.GetName().c_str());
node_state.UpdatePersistTensor();
RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End");
GELOGD("[%s] Done task preparation successfully.", node_state.GetName().c_str());



+ 0
- 1
ge/hybrid/node_executor/node_executor.cc View File

@@ -39,7 +39,6 @@ const char *const kEngineNameHostCpu = "DNN_VM_HOST_CPU_OP_STORE";
// Prepares a node task before it is executed. In order, this invokes the
// context's output allocation, workspace allocation and persist-tensor update,
// and finally asks the task to update its arguments from the context.
// Every step is wrapped in GE_CHK_STATUS_RET_NOLOG.
// NOTE(review): presumably that macro returns the failing status early without
// logging -- confirm against its definition.
// Returns SUCCESS when control reaches the end of the function.
Status NodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const {
GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs());
GE_CHK_STATUS_RET_NOLOG(context.AllocateWorkspaces());
// NOTE(review): this listing is a diff view with the +/- markers stripped; the
// hunk is reported as "+0 / -1", so the UpdatePersistTensor call below looks
// like the line this commit removes -- confirm against the raw diff.
GE_CHK_STATUS_RET_NOLOG(context.UpdatePersistTensor());
// Argument update comes last, after the allocation steps above.
GE_CHK_STATUS_RET_NOLOG(task.UpdateArgs(context));
return SUCCESS;
}


+ 1
- 10
ge/hybrid/node_executor/task_context.cc View File

@@ -460,22 +460,12 @@ Status TaskContext::PropagateOutputs() {
subgraph_context_->all_inputs_[input_offset].SetName(
node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx));
}

auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item);
GE_CHECK_NOTNULL(dst_node_state);
dst_node_state->SavePersistTensor(dst_input_idx, *tensor);
}
}
(void)guard;
return SUCCESS;
}

// Thin wrapper that forwards to node_state_->UpdatePersistTensor().
// node_state_ is checked with GE_CHECK_NOTNULL before it is dereferenced.
// NOTE(review): presumably GE_CHECK_NOTNULL returns an error status when the
// pointer is null -- confirm against the macro definition.
// NOTE(review): in this diff view (markers stripped, hunk reported as
// "+1 / -10") this whole function appears to be among the removed lines --
// confirm against the raw diff.
Status TaskContext::UpdatePersistTensor() {
GE_CHECK_NOTNULL(node_state_);
// Any result of the forwarded call is not inspected here.
node_state_->UpdatePersistTensor();
return SUCCESS;
}

// Returns the variable-memory base address reported by the model that the
// execution context points to (execution_context_->model->GetVarMemBase()).
const void *TaskContext::GetVarBaseAddr() {
const void *var_mem_base = execution_context_->model->GetVarMemBase();
return var_mem_base;
}
@@ -501,6 +491,7 @@ void TaskContext::ReleaseInputsAndOutputs() {
void TaskContext::ReleaseInput(int index) {
auto input_tensor = MutableInput(index);
if (input_tensor != nullptr) {
node_state_->SavePersistTensor(index, *input_tensor);
input_tensor->Destroy();
GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index);
}


+ 0
- 1
ge/hybrid/node_executor/task_context.h View File

@@ -78,7 +78,6 @@ class TaskContext {
Status AllocateOutputs(AllocationAttr *attr = nullptr);
Status AllocateWorkspaces();
Status AllocateWorkspace(size_t size, void **buffer, void *ori_addr = nullptr);
Status UpdatePersistTensor();

bool IsTraceEnabled() const;



Loading…
Cancel
Save