diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 85256d67..1e4ac9db 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -888,22 +888,19 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { m_waitee_id = Profiler::next_id(); RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop); bool require_host = prop == TensorProp::HostValue; - bool value_fetching = false; + auto host_available = [&]{ + return info->ptr && info->ptr->value_fetched(); + }; + if (require_host && !host_available()) { + // avoid dead lock + lock.unlock(); + m_buffer.enqueue(GetValue{info}); + m_buffer.flush(); + lock.lock(); + } m_cv.wait(lock, [&]() { check_worker_exc_unsafe(); - if (require_host) { - if (info->ptr && info->ptr->value_fetched()) { - return true; - } - if (!value_fetching) { - m_buffer.enqueue(GetValue{info}); - m_buffer.flush(); - value_fetching = true; - } - return false; - } else { - return static_cast(info->ptr); - } + return require_host ? host_available() : static_cast(info->ptr); }); RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop, m_waitee == nullptr); m_waitee = nullptr;