Browse Source

fix(interp): thread safety for drop and swapout

GitOrigin-RevId: 7684f160bf
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
d04b4bc006
2 changed files with 10 additions and 4 deletions
  1. +8
    -4
      imperative/src/impl/interpreter_impl.cpp
  2. +2
    -0
      imperative/src/impl/interpreter_impl.h

+ 8
- 4
imperative/src/impl/interpreter_impl.cpp View File

@@ -233,18 +233,17 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
mgb_assert(!m_waitee);
// donnot use info->value_fetched, it's unsafe
mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
TensorPtr tensor_ptr = info->ptr;
auto value_fetched = [&]() {
return tensor_ptr && tensor_ptr->value_fetched();
};
if (!value_fetched()) {
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
m_waitee = info;
regenerate(info);
m_buffer.enqueue(GetValue{info});
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
// get tensor ptr in lock to ensure safety
tensor_ptr = info->ptr;
return value_fetched();
});
@@ -359,6 +358,11 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
}
}

void ChannelImpl::release_tensor(TensorInfo* dest) {
MGB_LOCK_GUARD(m_mutex);
dest->ptr.reset();
}

void ChannelImpl::regenerate(TensorInfo* dest) {
if (dest->evict_type == DROP) {
recompute(dest->producer);
@@ -481,9 +485,9 @@ void ChannelImpl::process_one_task(Command& cmd) {
produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
} else if constexpr (std::is_same_v<T, SwapOut>) {
cmd.dest->h_value = cmd.dest->ptr->get_value();
cmd.dest->ptr.reset();
release_tensor(cmd.dest);
} else if constexpr (std::is_same_v<T, Drop>) {
cmd.dest->ptr.reset();
release_tensor(cmd.dest);
} else if constexpr (std::is_same_v<T, Move>) {
produce_tensor(cmd.dest, cmd.src->ptr);
free(cmd.src);


+ 2
- 0
imperative/src/impl/interpreter_impl.h View File

@@ -249,6 +249,8 @@ private:

void produce_tensor(TensorInfo* dest, TensorPtr ptr);

void release_tensor(TensorInfo* dest);

void regenerate(TensorInfo* dest);
void recompute(TensorInfo::ComputePath* path);



Loading…
Cancel
Save