|
|
@@ -320,18 +320,18 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { |
|
|
|
void NodeState::SaveRootTensor(int input_idx, const TensorValue &tensor) { |
|
|
|
if (node_item_->root_data_.count(input_idx) > 0) { |
|
|
|
GELOGD("[%s] Save Const input tensor: %d", GetName().c_str(), input_idx); |
|
|
|
root_tensor_value_[input_idx] = tensor; |
|
|
|
root_tensor_values_[input_idx] = tensor; |
|
|
|
} |
|
|
|
|
|
|
|
if (node_item_->enter_data_.count(input_idx) > 0) { |
|
|
|
GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); |
|
|
|
root_tensor_value_[input_idx] = tensor; |
|
|
|
root_tensor_values_[input_idx] = tensor; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void NodeState::UpdateRootTensor(int input_idx) { |
|
|
|
const auto it = root_tensor_value_.find(input_idx); |
|
|
|
if (it == root_tensor_value_.end()) { |
|
|
|
const auto it = root_tensor_values_.find(input_idx); |
|
|
|
if (it == root_tensor_values_.end()) { |
|
|
|
GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx); |
|
|
|
return; |
|
|
|
} |
|
|
@@ -343,7 +343,7 @@ void NodeState::UpdateRootTensor(int input_idx) { |
|
|
|
} |
|
|
|
|
|
|
|
*tensor = it->second; |
|
|
|
GELOGW("[%s] Update input tensor: %d", GetName().c_str(), input_idx); |
|
|
|
GELOGD("[%s] Update input tensor: %d", GetName().c_str(), input_idx); |
|
|
|
} |
|
|
|
|
|
|
|
void NodeState::ResetContext(uint64_t iteration) { |
|
|
|