Browse Source

perf(interpreter): don't check host value if unnecessary

GitOrigin-RevId: 5306c71328
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
462364eb06
1 changed files with 4 additions and 3 deletions
  1. +4
    -3
      imperative/src/impl/interpreter/interpreter_impl.cpp

+ 4
- 3
imperative/src/impl/interpreter/interpreter_impl.cpp View File

@@ -935,13 +935,14 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop); MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
bool require_host = prop == TensorProp::HostValue; bool require_host = prop == TensorProp::HostValue;
auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); }; auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
bool wait_host = !host_available();
if (require_host && wait_host) {
bool wait_host = false;
if (require_host && !host_available()) {
// avoid dead lock // avoid dead lock
lock.unlock(); lock.unlock();
m_buffer.enqueue(GetValue{info}); m_buffer.enqueue(GetValue{info});
m_buffer.flush(); m_buffer.flush();
lock.lock(); lock.lock();
wait_host = true;
} }
m_cv.wait(lock, [&]() { m_cv.wait(lock, [&]() {
check_worker_exc_unsafe(); check_worker_exc_unsafe();
@@ -949,7 +950,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
}); });
MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop); MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
m_waitee = nullptr; m_waitee = nullptr;
if (require_host && wait_host) {
if (wait_host) {
auto err = info->ptr->comp_node().check_async_error(); auto err = info->ptr->comp_node().check_async_error();
mgb_assert(!err, "%s", err->what()); mgb_assert(!err, "%s", err->what());
} }


Loading…
Cancel
Save