|
|
@@ -35,7 +35,6 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
     info->desc.comp_node = value.comp_node();
     info->desc.value = value.proxy_to_default_cpu();
     info->h_value = value;
-    m_valid_handle.insert(info);
     m_buffer.enqueue(Put{info, value, no_cache});
     if (m_async_level == 0) {
         sync();
|
|
@@ -49,20 +48,25 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
     info->desc.layout = data.layout();
     info->desc.comp_node = data.comp_node();
     info->ptr = Tensor::make(data);
-    m_valid_handle.insert(info);
     return info;
 }

 void ChannelImpl::del(Handle handle) {
-    mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
-    m_buffer.enqueue(Del{reinterpret_cast<TensorInfo*>(handle)});
+    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
+    auto* info = reinterpret_cast<TensorInfo*>(handle);
+    detach_users(info);
+    info->detach_producer();
+    m_valid_handle.erase(handle);
+    m_buffer.enqueue(Del{info});
 }

 void ChannelImpl::swap_in(Handle handle) {
     if (m_enable_evict & SWAP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                    "invalid handle: %p", handle);
-        m_buffer.enqueue(SwapIn{reinterpret_cast<TensorInfo*>(handle)});
+        auto* info = reinterpret_cast<TensorInfo*>(handle);
+        m_buffer.enqueue(SwapIn{info});
+        info->evict_type = NONE;
     }
 }
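
del() now tears down the producer/consumer links on the host thread before the worker ever sees the Del command. The sketch below is one plausible model of the ownership scheme implied by detach_users()/detach_producer() and the later "delete before ref_cnt = 0" assert; the real TensorInfo/ComputePath declarations live in interpreter_impl.h and may differ.

    // Illustrative model only -- the field and type names are assumptions,
    // not the real declarations from interpreter_impl.h.
    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <vector>

    struct ComputePath;

    struct TensorInfo {
        ComputePath* producer = nullptr;  // path that can recompute this tensor
        std::vector<ComputePath*> users;  // paths that consume this tensor
        void detach_producer();
    };

    struct ComputePath {
        std::vector<TensorInfo*> inputs;
        std::vector<TensorInfo*> outputs;  // slots are nulled, never erased
        std::size_t ref_cnt = 0;           // number of live (non-null) outputs
    };

    void TensorInfo::detach_producer() {
        if (!producer) return;
        auto it = std::find(producer->outputs.begin(),
                            producer->outputs.end(), this);
        assert(it != producer->outputs.end());
        *it = nullptr;                     // keep output order for recompute()
        if (--producer->ref_cnt == 0) {
            delete producer;               // cf. "delete before ref_cnt = 0"
        }
        producer = nullptr;
    }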
|
|
|
|
|
|
@@ -70,7 +74,9 @@ void ChannelImpl::swap_out(Handle handle) {
     if (m_enable_evict & SWAP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                    "invalid handle: %p", handle);
-        m_buffer.enqueue(SwapOut{reinterpret_cast<TensorInfo*>(handle)});
+        auto* info = reinterpret_cast<TensorInfo*>(handle);
+        m_buffer.enqueue(SwapOut{info});
+        info->evict_type = SWAP;
     }
 }
|
|
|
|
|
|
@@ -78,7 +84,13 @@ void ChannelImpl::drop(Handle handle) {
     if (m_enable_evict & DROP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                    "invalid handle: %p", handle);
-        m_buffer.enqueue(Drop{reinterpret_cast<TensorInfo*>(handle)});
+        auto* info = reinterpret_cast<TensorInfo*>(handle);
+        if (!info->producer) {
+            mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", info);
+            return;
+        }
+        info->evict_type = DROP;
+        m_buffer.enqueue(Drop{info});
     }
 }
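
Note that swap_in(), swap_out() and drop() now flip info->evict_type at enqueue time, on the host thread, rather than when the worker executes the command. Assuming the intent I read from this diff, the host-side state machine is:

    // Sketch: evict_type records the *pending* state, so a later apply_op()
    // or get_value() on the host thread can call regenerate() and schedule
    // the matching restore before the worker has even run the eviction.
    //   NONE --swap_out()--> SWAP --swap_in()---> NONE
    //   NONE --drop()------> DROP --recompute()-> NONE
    enum EvictType { NONE, SWAP, DROP };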
|
|
|
|
|
|
@@ -134,18 +146,8 @@ void ChannelImpl::dispatch_default_cpu(
         output_infos.push_back(info);
         outputs->push_back(info);
     }

     if (m_enable_evict & DROP) {
-        for (auto out : output_infos) {
-            out->path.op = op;
-            for (auto out_ : output_infos) {
-                out->path.outputs.push_back(m_st.at(out_));
-            }
-            for (auto inp : input_infos) {
-                out->path.inputs.push_back(m_st.at(inp));
-                inp->path.dep_outputs.push_back(m_st.at(out));
-            }
-        }
+        TensorInfo::ComputePath::make(op, input_infos, output_infos);
     }
 }
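
The old code wired up out->path by hand through m_st, a table of shared handles with weak back-edges in dep_outputs. TensorInfo::ComputePath::make replaces all of that with one factory call. Its definition is not part of this hunk; the following is a sketch of the shape its call sites imply (raw pointers, nullable output slots), not the actual declaration:

    // Hypothetical reconstruction -- see interpreter_impl.h for the real one.
    struct TensorInfo::ComputePath {
        std::shared_ptr<OpDef> op;
        SmallVector<TensorInfo*> inputs;
        SmallVector<TensorInfo*> outputs;  // entries become nullptr when detached
        size_t ref_cnt;                    // live outputs; freed when it hits 0

        static ComputePath* make(std::shared_ptr<OpDef> op,
                                 SmallVector<TensorInfo*> inputs,
                                 SmallVector<TensorInfo*> outputs) {
            auto* path = new ComputePath;
            path->op = std::move(op);
            path->inputs = std::move(inputs);
            path->outputs = std::move(outputs);
            path->ref_cnt = path->outputs.size();
            for (auto* input : path->inputs) {
                input->users.push_back(path);  // consumer edge
            }
            for (auto* output : path->outputs) {
                output->producer = path;       // producer edge
            }
            return path;
        }
    };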
|
|
|
|
|
|
@@ -168,21 +170,11 @@ void ChannelImpl::dispatch_kernel(
             info->h_value = HostTensorND::make_proxy(desc.value)
                 .proxy_to_comp_node(desc.comp_node);
         }
-        m_valid_handle.insert(info);
         cmd.outputs.push_back(info);
         outputs->push_back(info);
     }
     if (m_enable_evict & DROP) {
-        for (auto out : cmd.outputs) {
-            out->path.op = cmd.op;
-            for (auto out_ : cmd.outputs) {
-                out->path.outputs.push_back(m_st.at(out_));
-            }
-            for (auto inp : cmd.inputs) {
-                out->path.inputs.push_back(m_st.at(inp));
-                inp->path.dep_outputs.push_back(m_st.at(out));
-            }
-        }
+        TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs);
     }
     m_buffer.enqueue(std::move(cmd));
     if (!validated && m_async_level == 1) {
|
|
@@ -215,6 +207,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
             mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
             input_infos.push_back(info);
             input_descs.push_back(info->desc);
+            regenerate(info);
         }
     }
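
apply_op() now restores evicted inputs on the host thread, before the ApplyOp command is queued. This is what lets the on-the-fly regeneration inside the worker (removed further down in this diff) go away: m_buffer is drained in FIFO order, so a restore enqueued here is guaranteed to run before the op that consumes the tensor. Roughly:

    // Illustrative ordering argument, not a literal excerpt:
    regenerate(info);                        // may enqueue SwapIn{info} or an
                                             // ApplyOp that recomputes info
    m_buffer.enqueue(ApplyOp{op, {info}, outputs});
    // FIFO queue => info->ptr is repopulated before this ApplyOp runs, so
    // the worker can simply assert i->ptr instead of regenerating on the fly.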
|
|
|
|
|
|
@@ -233,23 +226,31 @@ SmallVector<Handle> ChannelImpl::apply_op(
 }

 HostTensorND ChannelImpl::get_value(Handle handle) {
     // TODO: maybe get_value should be done on host. i.e. delete GetValue
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
-    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
     mgb_assert(!m_waitee);
-    if (!info->value_fetched) {
-        mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
+    // do not use info->value_fetched, it's unsafe
+    mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
+    TensorPtr tensor_ptr = info->ptr;
+    auto value_fetched = [&]() {
+        return tensor_ptr && tensor_ptr->value_fetched();
+    };
+    if (!value_fetched()) {
+        std::unique_lock<decltype(m_mutex)> lock(m_mutex);
         m_waitee = info;
+        regenerate(info);
         m_buffer.enqueue(GetValue{info});
         m_cv.wait(lock, [&]() {
             check_worker_exc_unsafe();
-            return info->value_fetched;
+            // get tensor ptr in lock to ensure safety
+            tensor_ptr = info->ptr;
+            return value_fetched();
         });
         m_waitee = nullptr;
     }
-    mgb_assert(info->ptr->value_fetched());
-    return info->ptr->get_value();
+    return tensor_ptr->get_value();
 }

 TensorShape ChannelImpl::get_shape(Handle handle) {
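
On the get_value() rewrite above: the old code tested info->value_fetched, a flag written by the worker, without consistently holding m_mutex. The new code keys the wait off a local TensorPtr that is only re-read inside the wait predicate, which std::condition_variable evaluates with the lock re-acquired. A minimal standalone model of that pattern:

    #include <condition_variable>
    #include <memory>
    #include <mutex>

    std::mutex m;
    std::condition_variable cv;
    std::shared_ptr<int> slot;  // written by a worker thread under m

    std::shared_ptr<int> wait_for_slot() {
        std::unique_lock<std::mutex> lock(m);
        std::shared_ptr<int> local;
        cv.wait(lock, [&] {
            local = slot;       // safe: predicate always runs with lock held
            return local != nullptr;
        });
        return local;           // still valid after the lock is released
    }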
|
|
@@ -298,6 +299,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
     std::unique_lock<decltype(m_mutex)> lock(m_mutex);
     mgb_assert(!m_waitee);
     m_waitee = info;
+    regenerate(info);
     m_buffer.enqueue(Flush{info});
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
|
|
@@ -332,17 +334,12 @@ int ChannelImpl::get_async_level() {
 TensorInfo* ChannelImpl::alloc() {
     MGB_LOCK_GUARD(m_mutex);
     auto info = m_pool.alloc();
-    m_st.insert(info);
+    m_valid_handle.insert(info);
     return info;
 }

 void ChannelImpl::free(TensorInfo* ptr) {
     MGB_LOCK_GUARD(m_mutex);
-    if (ptr->path.dep_outputs.size() > 0) {
-        remove_dep(ptr);
-    }
-    m_st.erase(ptr);
+    mgb_assert(ptr->allow_delete, "delete before ref_cnt = 0");
     m_pool.free(ptr);
 }
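
alloc() registering itself in m_valid_handle pairs with the removal of the per-call inserts in put()/dispatch_kernel() earlier in this diff, and free() no longer walks dep_outputs because del() already detached everything on the host thread. The new assert encodes an invariant I read as:

    // Assumption, not the real header: reusing the model sketched after del()
    // above, allow_delete is equivalent to being fully detached, i.e. no
    // ComputePath can still reach this TensorInfo.
    bool allow_delete_equivalent(const TensorInfo& info) {
        return info.producer == nullptr && info.users.empty();
    }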
|
|
|
|
|
|
@@ -350,77 +347,64 @@ ChannelImpl::~ChannelImpl() { |
|
|
|
close(); |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice = true) { |
|
|
|
auto lock = notice ? std::unique_lock<std::mutex>(m_mutex) |
|
|
|
: std::unique_lock<std::mutex>(); |
|
|
|
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { |
|
|
|
MGB_LOCK_GUARD(m_mutex); |
|
|
|
dest->value_fetched = ptr->value_fetched(); |
|
|
|
// update tensor desc for static infer |
|
|
|
dest->desc.layout = ptr->layout(); |
|
|
|
dest->desc.comp_node = ptr->comp_node(); |
|
|
|
dest->ptr = std::move(ptr); |
|
|
|
if (notice && m_waitee == dest) { |
|
|
|
if (m_waitee == dest) { |
|
|
|
m_cv.notify_all(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::do_swap_out(TensorInfo* dest) { |
|
|
|
void ChannelImpl::regenerate(TensorInfo* dest) { |
|
|
|
if (dest->evict_type == DROP) { |
|
|
|
mgb_log_warn("the evict type of tensor %p was set to DROP, this SWAP operation will be ignored", dest); |
|
|
|
return; |
|
|
|
} |
|
|
|
if (!dest->ptr) { |
|
|
|
return; |
|
|
|
recompute(dest->producer); |
|
|
|
} else if (dest->evict_type == SWAP) { |
|
|
|
swap_in(dest); |
|
|
|
} |
|
|
|
dest->evict_type = SWAP; |
|
|
|
dest->value_fetched = false; |
|
|
|
// TODO: swap in parallel |
|
|
|
dest->h_value = dest->ptr->get_value(); |
|
|
|
dest->ptr.reset(); |
|
|
|
mgb_assert(dest->evict_type == NONE); |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::do_swap_in(TensorInfo* dest) { |
|
|
|
if (dest->ptr) { |
|
|
|
return; |
|
|
|
} |
|
|
|
if (dest->h_value.empty()) { |
|
|
|
mgb_log_error("backup of the tensor %p not found", dest); |
|
|
|
return; |
|
|
|
void ChannelImpl::recompute(TensorInfo::ComputePath* path) { |
|
|
|
SmallVector<TensorInfo*> workspaces(path->outputs.size(), nullptr); |
|
|
|
for (auto&& input: path->inputs) { |
|
|
|
regenerate(input); |
|
|
|
} |
|
|
|
produce_tensor(dest, Tensor::make(dest->h_value), false); |
|
|
|
dest->evict_type = NONE; |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::remove_dep(TensorInfo* dest) { |
|
|
|
for (auto i : dest->path.dep_outputs) { |
|
|
|
auto out_ptr = i.lock(); |
|
|
|
if (out_ptr) { |
|
|
|
regenerate(out_ptr.get(), true); |
|
|
|
for (auto&& output: path->outputs) { |
|
|
|
if(output == nullptr) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
output->evict_type = NONE; |
|
|
|
} |
|
|
|
m_buffer.enqueue(ApplyOp{path->op, path->inputs, path->outputs}); |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::do_drop(TensorInfo* dest) { |
|
|
|
if (dest->evict_type == SWAP) { |
|
|
|
mgb_log_warn("the evict type of tensor %p was set to SWAP, this DROP operation will be ignored", dest); |
|
|
|
return; |
|
|
|
} |
|
|
|
if (!dest->path.op) { |
|
|
|
mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", dest); |
|
|
|
return; |
|
|
|
} |
|
|
|
if (dest->recompute_times >= m_max_recompute_time) { |
|
|
|
mgb_log_warn("the recomputation time for tensor %p exceeds the limit, this drop operation will be ignored", dest); |
|
|
|
return; |
|
|
|
} |
|
|
|
if (!dest->ptr) { |
|
|
|
return; |
|
|
|
void ChannelImpl::detach_users(TensorInfo* dest) { |
|
|
|
SmallVector<TensorInfo::ComputePath*> users = dest->users; |
|
|
|
for (auto* user: users) { |
|
|
|
for (auto* output: user->outputs) { |
|
|
|
if (output == nullptr) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
regenerate(output); |
|
|
|
output->detach_producer(); |
|
|
|
} |
|
|
|
} |
|
|
|
dest->evict_type = DROP; |
|
|
|
dest->value_fetched = false; |
|
|
|
dest->ptr.reset(); |
|
|
|
dest->users.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::set_swap_flag(bool flag) { |
|
|
|
if ((!flag) && (m_enable_evict & SWAP)) { |
|
|
|
for (auto handle: m_valid_handle) { |
|
|
|
auto* info = reinterpret_cast<TensorInfo*>(handle); |
|
|
|
if (info->evict_type == SWAP) { |
|
|
|
swap_in(info); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
if (flag) { |
|
|
|
m_enable_evict |= SWAP; |
|
|
|
} else { |
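
Putting the new pieces together, here is how a dropped tensor comes back when it is read, assuming a ChannelImpl* ch, a prepared op and host tensor host_a (API names from this file; the trace comments are my reading of the code above, not program output):

    auto a = ch->put(host_a, /*no_cache=*/false);
    auto b = ch->apply_op(op, {a})[0];  // ComputePath::make() links a -> b
    ch->drop(b);                        // host: b->evict_type = DROP
                                        // worker (Drop): b->ptr.reset()
    auto v = ch->get_value(b);          // host: regenerate(b)
    //   regenerate(b): evict_type == DROP  -> recompute(b->producer)
    //   recompute(path): regenerate() each input (no-op here: a is NONE),
    //       clear the outputs' evict_type, enqueue ApplyOp{path->op, ...}
    //   worker: ApplyOp repopulates b->ptr, GetValue then fetches it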
|
|
@@ -429,6 +413,14 @@ void ChannelImpl::set_swap_flag(bool flag) {
 }

 void ChannelImpl::set_drop_flag(bool flag) {
+    if ((!flag) && (m_enable_evict & DROP)) {
+        for (auto handle: m_valid_handle) {
+            auto* info = reinterpret_cast<TensorInfo*>(handle);
+            if (info->evict_type == DROP) {
+                recompute(info->producer);
+            }
+        }
+    }
     if (flag) {
         m_enable_evict |= DROP;
     } else {
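
Disabling an eviction mode is therefore not just clearing a bit: set_swap_flag(false) swaps every SWAP tensor back in, and set_drop_flag(false) schedules a recompute for every DROP tensor, so nothing is left in an evicted state once the flag is off. For example:

    channel->set_drop_flag(false);  // enqueues a recompute for every dropped
                                    // tensor before clearing the DROP bit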
|
|
@@ -440,46 +432,6 @@ void ChannelImpl::set_buffer_length(int length) {
     m_buffer.set_capacity(length);
 }

-void ChannelImpl::regenerate(TensorInfo* info, bool must_drop = false) {
-    if (!info->ptr && info->evict_type != NONE) {
-        if (info->evict_type == SWAP) {
-            do_swap_in(info);
-        } else {
-            mgb_assert(info->evict_type == DROP);
-            mgb_assert(info->path.op, "recomputation path not found");
-            auto path = info->path;
-            SmallVector<TensorPtr> inputs;
-            inputs.reserve(path.inputs.size());
-            for (auto i : path.inputs) {
-                mgb_assert(i, "invalid history input");
-                if (!i->ptr) {
-                    regenerate(i.get(), must_drop);
-                }
-                inputs.push_back(i->ptr);
-            }
-            auto outputs = OpDef::apply_on_physical_tensor(*path.op, inputs);
-            for (size_t i = 0; i < outputs.size(); i++) {
-                auto out_ptr = path.outputs[i].lock();
-                if (out_ptr) {
-                    out_ptr->recompute_times++;
-                    if (!out_ptr->ptr && out_ptr->evict_type == DROP) {
-                        produce_tensor(out_ptr.get(), std::move(outputs[i]), false);
-                    }
-                }
-            }
-        }
-    }
-    if (must_drop) {
-        if (info->path.op) {
-            info->path.op.reset();
-            info->path.inputs.clear();
-            if (info->evict_type == DROP) {
-                info->evict_type = NONE;
-            }
-        }
-    }
-}
-
 void ChannelImpl::process_one_task(Command& cmd) {
     //TODO: remove std::visit for support osx 10.12
     std::visit([this](auto& cmd) {
|
|
@@ -493,11 +445,6 @@ void ChannelImpl::process_one_task(Command& cmd) {
             tensor_inputs.reserve(cmd.inputs.size());
             // refcnt == 1, owners: [TensorInfo::ptr]
             for (auto i : cmd.inputs) {
-                if (m_enable_evict && i->evict_type != NONE) {
-                    if (!i->ptr) {
-                        regenerate(i);
-                    }
-                }
                 mgb_assert(i->ptr, "Invalid input tensor ptr!");
                 // refcnt ++, owners: [i->ptr, tensor_inputs]
                 tensor_inputs.push_back(i->ptr);
|
|
@@ -515,16 +462,14 @@ void ChannelImpl::process_one_task(Command& cmd) {
                     *cmd.op, std::move(tensor_inputs));
             mgb_assert(tensor_outputs.size() == cmd.outputs.size());
             for (size_t i = 0; i < tensor_outputs.size(); ++i) {
+                if (cmd.outputs[i] == nullptr) {
+                    continue;
+                }
                 produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i]));
             }
         } else if constexpr (std::is_same_v<T, Del>) {
             free(cmd.dest);
         } else if constexpr (std::is_same_v<T, GetValue>) {
-            if (m_enable_evict && cmd.dest->evict_type != NONE) {
-                if (!cmd.dest->ptr) {
-                    regenerate(cmd.dest);
-                }
-            }
             mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!");
             cmd.dest->ptr->fetch_value();
             MGB_LOCK_GUARD(m_mutex);
|
|
@@ -533,11 +478,12 @@ void ChannelImpl::process_one_task(Command& cmd) {
                 m_cv.notify_all();
             }
         } else if constexpr (std::is_same_v<T, SwapIn>) {
-            do_swap_in(cmd.dest);
+            produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
        } else if constexpr (std::is_same_v<T, SwapOut>) {
-            do_swap_out(cmd.dest);
+            cmd.dest->h_value = cmd.dest->ptr->get_value();
+            cmd.dest->ptr.reset();
        } else if constexpr (std::is_same_v<T, Drop>) {
-            do_drop(cmd.dest);
+            cmd.dest->ptr.reset();
        } else if constexpr (std::is_same_v<T, Move>) {
             produce_tensor(cmd.dest, cmd.src->ptr);
             free(cmd.src);
|
|