|
|
@@ -1002,8 +1002,11 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { |
|
|
|
m_waitee_id = Profiler::next_id(); |
|
|
|
MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop); |
|
|
|
bool require_host = prop == TensorProp::HostValue; |
|
|
|
bool require_dev = prop == TensorProp::DevValue; |
|
|
|
auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); }; |
|
|
|
auto dev_available = [&] { return info->ptr; }; |
|
|
|
bool wait_host = false; |
|
|
|
bool wait_regen = false; |
|
|
|
if (require_host && !host_available()) { |
|
|
|
// avoid dead lock |
|
|
|
lock.unlock(); |
|
|
@@ -1020,16 +1023,52 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { |
|
|
|
lock.lock(); |
|
|
|
wait_host = true; |
|
|
|
} |
|
|
|
m_cv.wait(lock, [&]() { |
|
|
|
check_worker_exc_unsafe(); |
|
|
|
return require_host ? host_available() : static_cast<bool>(info->ptr); |
|
|
|
}); |
|
|
|
if (require_dev && !dev_available()) { |
|
|
|
lock.unlock(); |
|
|
|
if (Profiler::is_profiling()) { |
|
|
|
m_worker.add_task( |
|
|
|
{Profiler::next_id(), StartRegen{info}, |
|
|
|
get_channel_state().stack_manager.dump()}); |
|
|
|
} else { |
|
|
|
m_worker.add_task({ |
|
|
|
Profiler::next_id(), |
|
|
|
StartRegen{info}, |
|
|
|
}); |
|
|
|
} |
|
|
|
lock.lock(); |
|
|
|
wait_regen = true; |
|
|
|
} |
|
|
|
if (require_dev) { |
|
|
|
m_cv.wait(lock, [&]() { |
|
|
|
check_worker_exc_unsafe(); |
|
|
|
return dev_available(); |
|
|
|
}); |
|
|
|
} else { |
|
|
|
m_cv.wait(lock, [&]() { |
|
|
|
check_worker_exc_unsafe(); |
|
|
|
return require_host ? host_available() : static_cast<bool>(info->ptr); |
|
|
|
}); |
|
|
|
} |
|
|
|
MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop); |
|
|
|
m_waitee = nullptr; |
|
|
|
if (wait_host) { |
|
|
|
auto err = info->ptr->comp_node().check_async_error(); |
|
|
|
mgb_assert(!err, "%s", err->what()); |
|
|
|
} |
|
|
|
if (wait_regen) { |
|
|
|
lock.unlock(); |
|
|
|
if (Profiler::is_profiling()) { |
|
|
|
m_worker.add_task( |
|
|
|
{Profiler::next_id(), StopRegen{info}, |
|
|
|
get_channel_state().stack_manager.dump()}); |
|
|
|
} else { |
|
|
|
m_worker.add_task({ |
|
|
|
Profiler::next_id(), |
|
|
|
StopRegen{info}, |
|
|
|
}); |
|
|
|
} |
|
|
|
lock.lock(); |
|
|
|
} |
|
|
|
return info->ptr; |
|
|
|
} |
|
|
|
|
|
|
@@ -1254,6 +1293,17 @@ void ChannelImpl::process_one_task(Command& icmd) { |
|
|
|
MGB_RECORD_EVENT(ScopeEvent, cmd.scope_name); |
|
|
|
} else if constexpr (std::is_same_v<T, PopScope>) { |
|
|
|
MGB_RECORD_EVENT(ScopeFinishEvent, cmd.scope_name); |
|
|
|
} else if constexpr (std::is_same_v<T, StartRegen>) { |
|
|
|
if (cmd.dest->invalid) |
|
|
|
return; |
|
|
|
cmd.dest->pin(); |
|
|
|
if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) { |
|
|
|
regenerate(cmd.dest); |
|
|
|
} |
|
|
|
MGB_LOCK_GUARD(m_mutex); |
|
|
|
notify_tensor_unsafe(cmd.dest); |
|
|
|
} else if constexpr (std::is_same_v<T, StopRegen>) { |
|
|
|
cmd.dest->unpin(); |
|
|
|
} else { |
|
|
|
static_assert(!std::is_same_v<T, T>); |
|
|
|
} |
|
|
|