Browse Source

fix(interp): thread safety for drop and swapout

GitOrigin-RevId: 7684f160bf
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
d04b4bc006
2 changed files with 10 additions and 4 deletions
  1. +8
    -4
      imperative/src/impl/interpreter_impl.cpp
  2. +2
    -0
      imperative/src/impl/interpreter_impl.h

+ 8
- 4
imperative/src/impl/interpreter_impl.cpp View File

@@ -233,18 +233,17 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
mgb_assert(!m_waitee); mgb_assert(!m_waitee);
// donnot use info->value_fetched, it's unsafe // donnot use info->value_fetched, it's unsafe
mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!"); mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
TensorPtr tensor_ptr = info->ptr; TensorPtr tensor_ptr = info->ptr;
auto value_fetched = [&]() { auto value_fetched = [&]() {
return tensor_ptr && tensor_ptr->value_fetched(); return tensor_ptr && tensor_ptr->value_fetched();
}; };
if (!value_fetched()) { if (!value_fetched()) {
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
m_waitee = info; m_waitee = info;
regenerate(info); regenerate(info);
m_buffer.enqueue(GetValue{info}); m_buffer.enqueue(GetValue{info});
m_cv.wait(lock, [&]() { m_cv.wait(lock, [&]() {
check_worker_exc_unsafe(); check_worker_exc_unsafe();
// get tensor ptr in lock to ensure safety
tensor_ptr = info->ptr; tensor_ptr = info->ptr;
return value_fetched(); return value_fetched();
}); });
@@ -359,6 +358,11 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
} }
} }


// Drops the tensor currently referenced by `dest`.
// The reset is performed while m_mutex is held because get_value() reads
// TensorInfo::ptr under the same mutex.
void ChannelImpl::release_tensor(TensorInfo* dest) {
    std::lock_guard<decltype(m_mutex)> guard(m_mutex);
    dest->ptr.reset();
}

void ChannelImpl::regenerate(TensorInfo* dest) { void ChannelImpl::regenerate(TensorInfo* dest) {
if (dest->evict_type == DROP) { if (dest->evict_type == DROP) {
recompute(dest->producer); recompute(dest->producer);
@@ -481,9 +485,9 @@ void ChannelImpl::process_one_task(Command& cmd) {
produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value)); produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
} else if constexpr (std::is_same_v<T, SwapOut>) { } else if constexpr (std::is_same_v<T, SwapOut>) {
cmd.dest->h_value = cmd.dest->ptr->get_value(); cmd.dest->h_value = cmd.dest->ptr->get_value();
cmd.dest->ptr.reset();
release_tensor(cmd.dest);
} else if constexpr (std::is_same_v<T, Drop>) { } else if constexpr (std::is_same_v<T, Drop>) {
cmd.dest->ptr.reset();
release_tensor(cmd.dest);
} else if constexpr (std::is_same_v<T, Move>) { } else if constexpr (std::is_same_v<T, Move>) {
produce_tensor(cmd.dest, cmd.src->ptr); produce_tensor(cmd.dest, cmd.src->ptr);
free(cmd.src); free(cmd.src);


+ 2
- 0
imperative/src/impl/interpreter_impl.h View File

@@ -249,6 +249,8 @@ private:


void produce_tensor(TensorInfo* dest, TensorPtr ptr); void produce_tensor(TensorInfo* dest, TensorPtr ptr);


// Resets dest->ptr while holding m_mutex.
void release_tensor(TensorInfo* dest);

void regenerate(TensorInfo* dest); void regenerate(TensorInfo* dest);
void recompute(TensorInfo::ComputePath* path); void recompute(TensorInfo::ComputePath* path);




Loading…
Cancel
Save