|
|
@@ -888,22 +888,19 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { |
|
|
|
m_waitee_id = Profiler::next_id(); |
|
|
|
RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop); |
|
|
|
bool require_host = prop == TensorProp::HostValue; |
|
|
|
bool value_fetching = false; |
|
|
|
auto host_available = [&]{ |
|
|
|
return info->ptr && info->ptr->value_fetched(); |
|
|
|
}; |
|
|
|
if (require_host && !host_available()) { |
|
|
|
// avoid dead lock |
|
|
|
lock.unlock(); |
|
|
|
m_buffer.enqueue(GetValue{info}); |
|
|
|
m_buffer.flush(); |
|
|
|
lock.lock(); |
|
|
|
} |
|
|
|
m_cv.wait(lock, [&]() { |
|
|
|
check_worker_exc_unsafe(); |
|
|
|
if (require_host) { |
|
|
|
if (info->ptr && info->ptr->value_fetched()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
if (!value_fetching) { |
|
|
|
m_buffer.enqueue(GetValue{info}); |
|
|
|
m_buffer.flush(); |
|
|
|
value_fetching = true; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} else { |
|
|
|
return static_cast<bool>(info->ptr); |
|
|
|
} |
|
|
|
return require_host ? host_available() : static_cast<bool>(info->ptr); |
|
|
|
}); |
|
|
|
RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop, m_waitee == nullptr); |
|
|
|
m_waitee = nullptr; |
|
|
|