|
@@ -935,13 +935,14 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { |
|
|
MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop); |
|
|
MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop); |
|
|
bool require_host = prop == TensorProp::HostValue; |
|
|
bool require_host = prop == TensorProp::HostValue; |
|
|
auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); }; |
|
|
auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); }; |
|
|
bool wait_host = !host_available(); |
|
|
|
|
|
if (require_host && wait_host) { |
|
|
|
|
|
|
|
|
bool wait_host = false; |
|
|
|
|
|
if (require_host && !host_available()) { |
|
|
// avoid dead lock |
|
|
// avoid dead lock |
|
|
lock.unlock(); |
|
|
lock.unlock(); |
|
|
m_buffer.enqueue(GetValue{info}); |
|
|
m_buffer.enqueue(GetValue{info}); |
|
|
m_buffer.flush(); |
|
|
m_buffer.flush(); |
|
|
lock.lock(); |
|
|
lock.lock(); |
|
|
|
|
|
wait_host = true; |
|
|
} |
|
|
} |
|
|
m_cv.wait(lock, [&]() { |
|
|
m_cv.wait(lock, [&]() { |
|
|
check_worker_exc_unsafe(); |
|
|
check_worker_exc_unsafe(); |
|
@@ -949,7 +950,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { |
|
|
}); |
|
|
}); |
|
|
MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop); |
|
|
MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop); |
|
|
m_waitee = nullptr; |
|
|
m_waitee = nullptr; |
|
|
if (require_host && wait_host) { |
|
|
|
|
|
|
|
|
if (wait_host) { |
|
|
auto err = info->ptr->comp_node().check_async_error(); |
|
|
auto err = info->ptr->comp_node().check_async_error(); |
|
|
mgb_assert(!err, "%s", err->what()); |
|
|
mgb_assert(!err, "%s", err->what()); |
|
|
} |
|
|
} |
|
|