From d04b4bc006022f8c78f95ed43cab8125da162f61 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 21 Jan 2021 13:45:56 +0800 Subject: [PATCH] fix(interp): thread safety for drop and swapout GitOrigin-RevId: 7684f160bf1ca239c92c977c7238cac2b51ab4a2 --- imperative/src/impl/interpreter_impl.cpp | 12 ++++++++---- imperative/src/impl/interpreter_impl.h | 2 ++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/imperative/src/impl/interpreter_impl.cpp b/imperative/src/impl/interpreter_impl.cpp index 8ab5e0aa..7be91f18 100644 --- a/imperative/src/impl/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter_impl.cpp @@ -233,18 +233,17 @@ HostTensorND ChannelImpl::get_value(Handle handle) { mgb_assert(!m_waitee); // donnot use info->value_fetched, it's unsafe mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!"); + std::unique_lock lock(m_mutex); TensorPtr tensor_ptr = info->ptr; auto value_fetched = [&]() { return tensor_ptr && tensor_ptr->value_fetched(); }; if (!value_fetched()) { - std::unique_lock lock(m_mutex); m_waitee = info; regenerate(info); m_buffer.enqueue(GetValue{info}); m_cv.wait(lock, [&]() { check_worker_exc_unsafe(); - // get tensor ptr in lock to ensure safety tensor_ptr = info->ptr; return value_fetched(); }); @@ -359,6 +358,11 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { } } +void ChannelImpl::release_tensor(TensorInfo* dest) { + MGB_LOCK_GUARD(m_mutex); + dest->ptr.reset(); +} + void ChannelImpl::regenerate(TensorInfo* dest) { if (dest->evict_type == DROP) { recompute(dest->producer); @@ -481,9 +485,9 @@ void ChannelImpl::process_one_task(Command& cmd) { produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value)); } else if constexpr (std::is_same_v) { cmd.dest->h_value = cmd.dest->ptr->get_value(); - cmd.dest->ptr.reset(); + release_tensor(cmd.dest); } else if constexpr (std::is_same_v) { - cmd.dest->ptr.reset(); + release_tensor(cmd.dest); } else if constexpr (std::is_same_v) { produce_tensor(cmd.dest, cmd.src->ptr); free(cmd.src); diff --git a/imperative/src/impl/interpreter_impl.h b/imperative/src/impl/interpreter_impl.h index c9c070c8..979201dd 100644 --- a/imperative/src/impl/interpreter_impl.h +++ b/imperative/src/impl/interpreter_impl.h @@ -249,6 +249,8 @@ private: void produce_tensor(TensorInfo* dest, TensorPtr ptr); + void release_tensor(TensorInfo* dest); + void regenerate(TensorInfo* dest); void recompute(TensorInfo::ComputePath* path);