|
@@ -233,18 +233,17 @@ HostTensorND ChannelImpl::get_value(Handle handle) { |
|
|
mgb_assert(!m_waitee); |
|
|
mgb_assert(!m_waitee); |
|
|
// donnot use info->value_fetched, it's unsafe |
|
|
// donnot use info->value_fetched, it's unsafe |
|
|
mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!"); |
|
|
mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!"); |
|
|
|
|
|
std::unique_lock<decltype(m_mutex)> lock(m_mutex); |
|
|
TensorPtr tensor_ptr = info->ptr; |
|
|
TensorPtr tensor_ptr = info->ptr; |
|
|
auto value_fetched = [&]() { |
|
|
auto value_fetched = [&]() { |
|
|
return tensor_ptr && tensor_ptr->value_fetched(); |
|
|
return tensor_ptr && tensor_ptr->value_fetched(); |
|
|
}; |
|
|
}; |
|
|
if (!value_fetched()) { |
|
|
if (!value_fetched()) { |
|
|
std::unique_lock<decltype(m_mutex)> lock(m_mutex); |
|
|
|
|
|
m_waitee = info; |
|
|
m_waitee = info; |
|
|
regenerate(info); |
|
|
regenerate(info); |
|
|
m_buffer.enqueue(GetValue{info}); |
|
|
m_buffer.enqueue(GetValue{info}); |
|
|
m_cv.wait(lock, [&]() { |
|
|
m_cv.wait(lock, [&]() { |
|
|
check_worker_exc_unsafe(); |
|
|
check_worker_exc_unsafe(); |
|
|
// get tensor ptr in lock to ensure safety |
|
|
|
|
|
tensor_ptr = info->ptr; |
|
|
tensor_ptr = info->ptr; |
|
|
return value_fetched(); |
|
|
return value_fetched(); |
|
|
}); |
|
|
}); |
|
@@ -359,6 +358,11 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void ChannelImpl::release_tensor(TensorInfo* dest) { |
|
|
|
|
|
MGB_LOCK_GUARD(m_mutex); |
|
|
|
|
|
dest->ptr.reset(); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
void ChannelImpl::regenerate(TensorInfo* dest) { |
|
|
void ChannelImpl::regenerate(TensorInfo* dest) { |
|
|
if (dest->evict_type == DROP) { |
|
|
if (dest->evict_type == DROP) { |
|
|
recompute(dest->producer); |
|
|
recompute(dest->producer); |
|
@@ -481,9 +485,9 @@ void ChannelImpl::process_one_task(Command& cmd) { |
|
|
produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value)); |
|
|
produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value)); |
|
|
} else if constexpr (std::is_same_v<T, SwapOut>) { |
|
|
} else if constexpr (std::is_same_v<T, SwapOut>) { |
|
|
cmd.dest->h_value = cmd.dest->ptr->get_value(); |
|
|
cmd.dest->h_value = cmd.dest->ptr->get_value(); |
|
|
cmd.dest->ptr.reset(); |
|
|
|
|
|
|
|
|
release_tensor(cmd.dest); |
|
|
} else if constexpr (std::is_same_v<T, Drop>) { |
|
|
} else if constexpr (std::is_same_v<T, Drop>) { |
|
|
cmd.dest->ptr.reset(); |
|
|
|
|
|
|
|
|
release_tensor(cmd.dest); |
|
|
} else if constexpr (std::is_same_v<T, Move>) { |
|
|
} else if constexpr (std::is_same_v<T, Move>) { |
|
|
produce_tensor(cmd.dest, cmd.src->ptr); |
|
|
produce_tensor(cmd.dest, cmd.src->ptr); |
|
|
free(cmd.src); |
|
|
free(cmd.src); |
|
|