diff --git a/imperative/src/impl/interpreter_impl.cpp b/imperative/src/impl/interpreter_impl.cpp index 8248adf3..8ab5e0aa 100644 --- a/imperative/src/impl/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter_impl.cpp @@ -35,7 +35,6 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) { info->desc.comp_node = value.comp_node(); info->desc.value = value.proxy_to_default_cpu(); info->h_value = value; - m_valid_handle.insert(info); m_buffer.enqueue(Put{info, value, no_cache}); if (m_async_level == 0) { sync(); @@ -49,20 +48,25 @@ Handle ChannelImpl::put(const DeviceTensorND& data) { info->desc.layout = data.layout(); info->desc.comp_node = data.comp_node(); info->ptr = Tensor::make(data); - m_valid_handle.insert(info); return info; } void ChannelImpl::del(Handle handle) { - mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle); - m_buffer.enqueue(Del{reinterpret_cast(handle)}); + mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle); + auto* info = reinterpret_cast(handle); + detach_users(info); + info->detach_producer(); + m_valid_handle.erase(handle); + m_buffer.enqueue(Del{info}); } void ChannelImpl::swap_in(Handle handle) { if (m_enable_evict & SWAP) { mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); - m_buffer.enqueue(SwapIn{reinterpret_cast(handle)}); + auto* info = reinterpret_cast(handle); + m_buffer.enqueue(SwapIn{info}); + info->evict_type = NONE; } } @@ -70,7 +74,9 @@ void ChannelImpl::swap_out(Handle handle) { if (m_enable_evict & SWAP) { mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); - m_buffer.enqueue(SwapOut{reinterpret_cast(handle)}); + auto* info = reinterpret_cast(handle); + m_buffer.enqueue(SwapOut{info}); + info->evict_type = SWAP; } } @@ -78,7 +84,13 @@ void ChannelImpl::drop(Handle handle) { if (m_enable_evict & DROP) { mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); - m_buffer.enqueue(Drop{reinterpret_cast(handle)}); + auto* info = reinterpret_cast(handle); + if (!info->producer) { + mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", info); + return; + } + info->evict_type = DROP; + m_buffer.enqueue(Drop{info}); } } @@ -134,18 +146,8 @@ void ChannelImpl::dispatch_default_cpu( output_infos.push_back(info); outputs->push_back(info); } - if (m_enable_evict & DROP) { - for (auto out : output_infos) { - out->path.op = op; - for (auto out_ : output_infos) { - out->path.outputs.push_back(m_st.at(out_)); - } - for (auto inp : input_infos) { - out->path.inputs.push_back(m_st.at(inp)); - inp->path.dep_outputs.push_back(m_st.at(out)); - } - } + TensorInfo::ComputePath::make(op, input_infos, output_infos); } } @@ -168,21 +170,11 @@ void ChannelImpl::dispatch_kernel( info->h_value = HostTensorND::make_proxy(desc.value) .proxy_to_comp_node(desc.comp_node); } - m_valid_handle.insert(info); cmd.outputs.push_back(info); outputs->push_back(info); } if (m_enable_evict & DROP) { - for (auto out : cmd.outputs) { - out->path.op = cmd.op; - for (auto out_ : cmd.outputs) { - out->path.outputs.push_back(m_st.at(out_)); - } - for (auto inp : cmd.inputs) { - out->path.inputs.push_back(m_st.at(inp)); - inp->path.dep_outputs.push_back(m_st.at(out)); - } - } + TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs); } m_buffer.enqueue(std::move(cmd)); if (!validated && m_async_level == 1) { @@ -215,6 +207,7 @@ SmallVector ChannelImpl::apply_op( mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!"); input_infos.push_back(info); input_descs.push_back(info->desc); + regenerate(info); } } @@ -233,23 +226,31 @@ SmallVector ChannelImpl::apply_op( } HostTensorND ChannelImpl::get_value(Handle handle) { + // TODO: maybe get_value should be done on host. i.e. delete GetValue mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); auto info = reinterpret_cast(handle); - std::unique_lock lock(m_mutex); mgb_assert(!m_waitee); - if (!info->value_fetched) { - mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!"); + // donnot use info->value_fetched, it's unsafe + mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!"); + TensorPtr tensor_ptr = info->ptr; + auto value_fetched = [&]() { + return tensor_ptr && tensor_ptr->value_fetched(); + }; + if (!value_fetched()) { + std::unique_lock lock(m_mutex); m_waitee = info; + regenerate(info); m_buffer.enqueue(GetValue{info}); m_cv.wait(lock, [&]() { check_worker_exc_unsafe(); - return info->value_fetched; + // get tensor ptr in lock to ensure safety + tensor_ptr = info->ptr; + return value_fetched(); }); m_waitee = nullptr; } - mgb_assert(info->ptr->value_fetched()); - return info->ptr->get_value(); + return tensor_ptr->get_value(); } TensorShape ChannelImpl::get_shape(Handle handle) { @@ -298,6 +299,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { std::unique_lock lock(m_mutex); mgb_assert(!m_waitee); m_waitee = info; + regenerate(info); m_buffer.enqueue(Flush{info}); m_cv.wait(lock, [&]() { check_worker_exc_unsafe(); @@ -332,17 +334,12 @@ int ChannelImpl::get_async_level() { TensorInfo* ChannelImpl::alloc() { MGB_LOCK_GUARD(m_mutex); auto info = m_pool.alloc(); - m_st.insert(info); + m_valid_handle.insert(info); return info; } void ChannelImpl::free(TensorInfo* ptr) { MGB_LOCK_GUARD(m_mutex); - if (ptr->path.dep_outputs.size() > 0) { - remove_dep(ptr); - } - m_st.erase(ptr); - mgb_assert(ptr->allow_delete, "delete before ref_cnt = 0"); m_pool.free(ptr); } @@ -350,77 +347,64 @@ ChannelImpl::~ChannelImpl() { close(); } -void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice = true) { - auto lock = notice ? std::unique_lock(m_mutex) - : std::unique_lock(); +void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { + MGB_LOCK_GUARD(m_mutex); dest->value_fetched = ptr->value_fetched(); // update tensor desc for static infer dest->desc.layout = ptr->layout(); dest->desc.comp_node = ptr->comp_node(); dest->ptr = std::move(ptr); - if (notice && m_waitee == dest) { + if (m_waitee == dest) { m_cv.notify_all(); } } -void ChannelImpl::do_swap_out(TensorInfo* dest) { +void ChannelImpl::regenerate(TensorInfo* dest) { if (dest->evict_type == DROP) { - mgb_log_warn("the evict type of tensor %p was set to DROP, this SWAP operation will be ignored", dest); - return; - } - if (!dest->ptr) { - return; + recompute(dest->producer); + } else if (dest->evict_type == SWAP) { + swap_in(dest); } - dest->evict_type = SWAP; - dest->value_fetched = false; - // TODO: swap in parallel - dest->h_value = dest->ptr->get_value(); - dest->ptr.reset(); + mgb_assert(dest->evict_type == NONE); } -void ChannelImpl::do_swap_in(TensorInfo* dest) { - if (dest->ptr) { - return; - } - if (dest->h_value.empty()) { - mgb_log_error("backup of the tensor %p not found", dest); - return; +void ChannelImpl::recompute(TensorInfo::ComputePath* path) { + SmallVector workspaces(path->outputs.size(), nullptr); + for (auto&& input: path->inputs) { + regenerate(input); } - produce_tensor(dest, Tensor::make(dest->h_value), false); - dest->evict_type = NONE; -} - -void ChannelImpl::remove_dep(TensorInfo* dest) { - for (auto i : dest->path.dep_outputs) { - auto out_ptr = i.lock(); - if (out_ptr) { - regenerate(out_ptr.get(), true); + for (auto&& output: path->outputs) { + if(output == nullptr) { + continue; } + output->evict_type = NONE; } + m_buffer.enqueue(ApplyOp{path->op, path->inputs, path->outputs}); } -void ChannelImpl::do_drop(TensorInfo* dest) { - if (dest->evict_type == SWAP) { - mgb_log_warn("the evict type of tensor %p was set to SWAP, this DROP operation will be ignored", dest); - return; - } - if (!dest->path.op) { - mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", dest); - return; - } - if (dest->recompute_times >= m_max_recompute_time) { - mgb_log_warn("the recomputation time for tensor %p exceeds the limit, this drop operation will be ignored", dest); - return; - } - if (!dest->ptr) { - return; +void ChannelImpl::detach_users(TensorInfo* dest) { + SmallVector users = dest->users; + for (auto* user: users) { + for (auto* output: user->outputs) { + if (output == nullptr) { + continue; + } + regenerate(output); + output->detach_producer(); + } } - dest->evict_type = DROP; - dest->value_fetched = false; - dest->ptr.reset(); + dest->users.clear(); } void ChannelImpl::set_swap_flag(bool flag) { + if ((!flag) && (m_enable_evict & SWAP)) { + for (auto handle: m_valid_handle) { + auto* info = reinterpret_cast(handle); + if (info->evict_type == SWAP) { + swap_in(info); + } + } + } if (flag) { m_enable_evict |= SWAP; } else { @@ -429,6 +413,14 @@ void ChannelImpl::set_swap_flag(bool flag) { } void ChannelImpl::set_drop_flag(bool flag) { + if ((!flag) && (m_enable_evict & DROP)) { + for (auto handle: m_valid_handle) { + auto* info = reinterpret_cast(handle); + if (info->evict_type == DROP) { + recompute(info->producer); + } + } + } if (flag) { m_enable_evict |= DROP; } else { @@ -440,46 +432,6 @@ void ChannelImpl::set_buffer_length(int length) { m_buffer.set_capacity(length); } -void ChannelImpl::regenerate(TensorInfo* info, bool must_drop = false) { - if (!info->ptr && info->evict_type != NONE) { - if (info->evict_type == SWAP) { - do_swap_in(info); - } else { - mgb_assert(info->evict_type == DROP); - mgb_assert(info->path.op, "recomputation path not found"); - auto path = info->path; - SmallVector inputs; - inputs.reserve(path.inputs.size()); - for (auto i : path.inputs) { - mgb_assert(i, "invalid history input"); - if (!i->ptr) { - regenerate(i.get(), must_drop); - } - inputs.push_back(i->ptr); - } - auto outputs = OpDef::apply_on_physical_tensor(*path.op, inputs); - for (size_t i = 0; i < outputs.size(); i ++) { - auto out_ptr = path.outputs[i].lock(); - if (out_ptr) { - out_ptr->recompute_times ++; - if (!out_ptr->ptr && out_ptr->evict_type == DROP) { - produce_tensor(out_ptr.get(), std::move(outputs[i]), false); - } - } - } - } - } - if (must_drop) { - if (info->path.op) { - info->path.op.reset(); - info->path.inputs.clear(); - if (info->evict_type == DROP) { - info->evict_type = NONE; - } - } - } -} - void ChannelImpl::process_one_task(Command& cmd) { //TODO: remove std::visit for support osx 10.12 std::visit([this](auto& cmd) { @@ -493,11 +445,6 @@ void ChannelImpl::process_one_task(Command& cmd) { tensor_inputs.reserve(cmd.inputs.size()); // refcnt == 1, owners: [TensorInfo::ptr] for (auto i : cmd.inputs) { - if (m_enable_evict && i->evict_type != NONE) { - if (!i->ptr) { - regenerate(i); - } - } mgb_assert(i->ptr, "Invalid input tensor ptr!"); // refcnt ++, owners: [i->ptr, tensor_inputs] tensor_inputs.push_back(i->ptr); @@ -515,16 +462,14 @@ void ChannelImpl::process_one_task(Command& cmd) { *cmd.op, std::move(tensor_inputs)); mgb_assert(tensor_outputs.size() == cmd.outputs.size()); for (size_t i = 0; i < tensor_outputs.size(); ++i) { + if (cmd.outputs[i] == nullptr) { + continue; + } produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i])); } } else if constexpr (std::is_same_v) { free(cmd.dest); } else if constexpr (std::is_same_v) { - if (m_enable_evict && cmd.dest->evict_type != NONE) { - if (!cmd.dest->ptr) { - regenerate(cmd.dest); - } - } mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!"); cmd.dest->ptr->fetch_value(); MGB_LOCK_GUARD(m_mutex); @@ -533,11 +478,12 @@ void ChannelImpl::process_one_task(Command& cmd) { m_cv.notify_all(); } } else if constexpr (std::is_same_v) { - do_swap_in(cmd.dest); + produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value)); } else if constexpr (std::is_same_v) { - do_swap_out(cmd.dest); + cmd.dest->h_value = cmd.dest->ptr->get_value(); + cmd.dest->ptr.reset(); } else if constexpr (std::is_same_v) { - do_drop(cmd.dest); + cmd.dest->ptr.reset(); } else if constexpr (std::is_same_v) { produce_tensor(cmd.dest, cmd.src->ptr); free(cmd.src); diff --git a/imperative/src/impl/interpreter_impl.h b/imperative/src/impl/interpreter_impl.h index 328fe3eb..c9c070c8 100644 --- a/imperative/src/impl/interpreter_impl.h +++ b/imperative/src/impl/interpreter_impl.h @@ -38,22 +38,77 @@ using TensorInfoPtr = std::shared_ptr; struct TensorInfo { TensorPtr ptr; LogicalTensorDesc desc; + + // FIXME: broken by drop bool value_fetched = false; bool invalid = false; - bool allow_delete = false; EvictType evict_type = NONE; HostTensorND h_value; - size_t locked = 0; + + // reserved for auto drop + size_t pinned = 0; size_t recompute_times = 0; struct ComputePath { std::shared_ptr op; - SmallVector inputs; - SmallVector> outputs; - SmallVector> dep_outputs; - } path; + SmallVector inputs; + SmallVector unique_inputs; + SmallVector outputs; + + size_t ref_cnt() { + return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr); + } + + static ComputePath* make(std::shared_ptr op, SmallVector inputs, SmallVector outputs) { + auto* path = new TensorInfo::ComputePath(); + path->op = op; + path->inputs = inputs; + path->outputs = outputs; + // dedup + SmallVector unique_inputs = inputs; + std::sort(unique_inputs.begin(), unique_inputs.end()); + unique_inputs.erase(std::unique(unique_inputs.begin(), unique_inputs.end()), unique_inputs.end()); + path->unique_inputs = unique_inputs; + // attach users + for (auto input: unique_inputs) { + input->users.push_back(path); + } + // attach producer + for (auto output: outputs) { + output->producer = path; + } + return path; + } + }* producer = nullptr; + + void pin() { + ++pinned; + } + + void unpin() { + --pinned; + } + + void detach_producer() { + if (!producer) { + return; + } + auto output = std::find(producer->outputs.begin(), producer->outputs.end(), this); + mgb_assert(output != producer->outputs.end()); + *output = nullptr; + if (producer->ref_cnt() == 0) { + for (auto* input: producer->unique_inputs) { + input->users.erase(std::find(input->users.begin(), input->users.end(), producer)); + } + delete producer; + } + producer = nullptr; + } + + SmallVector users; + }; struct Put { @@ -186,17 +241,16 @@ struct ChannelImpl : Interpreter::Channel { private: TensorInfo* alloc(); void free(TensorInfo*); - void remove_dep(TensorInfo*); + void detach_users(TensorInfo*); void process_one_task(Command&); void check_worker_exc_unsafe(); - void produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice); - void do_swap_out(TensorInfo* dest); - void do_swap_in(TensorInfo* dest); - void do_drop(TensorInfo* dest); - void regenerate(TensorInfo* dest, bool must_drop); + void produce_tensor(TensorInfo* dest, TensorPtr ptr); + + void regenerate(TensorInfo* dest); + void recompute(TensorInfo::ComputePath* path); void dispatch_default_cpu( std::shared_ptr op, @@ -235,24 +289,6 @@ private: ChannelImpl* m_owner; } m_worker; - struct SharedTensorInfoMap { - void insert(TensorInfo* info) { - MGB_LOCK_GUARD(mtx); - tmap.emplace(info, TensorInfoPtr{info, [](TensorInfo* ptr){ ptr->allow_delete = true;}}); - } - void erase(TensorInfo* info) { - MGB_LOCK_GUARD(mtx); - tmap.erase(info); - } - TensorInfoPtr at(TensorInfo* info) { - MGB_LOCK_GUARD(mtx); - return tmap.at(info); - } - private: - std::mutex mtx; - std::unordered_map tmap; - }m_st; - /** * Buf a command window for following fuse * example: