From ad079009e13069389b59ec41e138805b43bdd6d3 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 10 Sep 2021 18:27:16 +0800 Subject: [PATCH] fix(interpreter): avoid deadlock in GetValue GitOrigin-RevId: fd438731674095a28b79beb45a39af78082b0829 --- .../src/impl/interpreter/interpreter_impl.cpp | 25 ++++++++++------------ 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 85256d67..1e4ac9db 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -888,22 +888,19 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { m_waitee_id = Profiler::next_id(); RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop); bool require_host = prop == TensorProp::HostValue; - bool value_fetching = false; + auto host_available = [&]{ + return info->ptr && info->ptr->value_fetched(); + }; + if (require_host && !host_available()) { + // avoid dead lock + lock.unlock(); + m_buffer.enqueue(GetValue{info}); + m_buffer.flush(); + lock.lock(); + } m_cv.wait(lock, [&]() { check_worker_exc_unsafe(); - if (require_host) { - if (info->ptr && info->ptr->value_fetched()) { - return true; - } - if (!value_fetching) { - m_buffer.enqueue(GetValue{info}); - m_buffer.flush(); - value_fetching = true; - } - return false; - } else { - return static_cast(info->ptr); - } + return require_host ? host_available() : static_cast(info->ptr); }); RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop, m_waitee == nullptr); m_waitee = nullptr;