From bcf69d8f87d0eb69b8c52ed18fc862c0d075306a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 4 Nov 2021 19:26:02 +0800 Subject: [PATCH] refactor(imperative): correctly apply sqrt sampling for dtr GitOrigin-RevId: dabd36551765af1d2646789ae9ed57d8eac4a936 --- .../src/impl/interpreter/interpreter_impl.cpp | 43 +++++++++++++++++----- imperative/src/impl/interpreter/interpreter_impl.h | 2 +- imperative/src/impl/interpreter/tensor_info.h | 3 ++ 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index d71d6842..3d84340c 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -646,6 +646,10 @@ void ChannelImpl::release_tensor(TensorInfo* dest) { MGB_RECORD_EVENT(TensorReleaseEvent, dest->id); MGB_LOCK_GUARD(m_mutex); dest->ptr.reset(); + auto& state = get_worker_state(); + if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) { + m_dtr.erase_candidate(dest); + } } void ChannelImpl::regenerate(TensorInfo* dest) { @@ -891,8 +895,7 @@ bool ChannelImpl::auto_evict(size_t force_num) { force_num > 0) { MGB_RECORD_EVENT(AutoEvictEvent); sample_on_device(m_dtr.comp_node, false); - auto best = m_dtr.find_best_tensor( - state.options.enable_dtr_sqrt_sampling && !force_num); + auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling); if (!best) { MGB_RECORD_EVENT(AutoEvictFinishEvent); break; @@ -1300,7 +1303,6 @@ void ChannelImpl::CommandBuffer::enqueue(CommandData cmd) { if (std::get_if(&cmd) && fuse_del(std::get(cmd))) { return; } - // mgb_log_debug("%s Enqueued", to_string(cmd).c_str()); m_commands.push_back( {Profiler::next_id(), std::move(cmd), state.stack_manager.dump()}); auto flush_pos = flush_pos_for(m_commands.back()); @@ -1365,7 +1367,6 @@ bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) { if (apply_iter == end || find_last_usage(dest, {apply_iter + 1, end}) != end) { return false; } - // mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str()); std::get(apply_iter->data).dels.push_back(dest); return true; } @@ -1538,16 +1539,26 @@ double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) { TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor( bool enable_dtr_sqrt_sampling = false) { + if (candidates.empty()) + return nullptr; + double min_msps = -1; TensorInfo* best = nullptr; size_t sz = 1; if (enable_dtr_sqrt_sampling) { while (sz * sz <= candidates.size()) sz++; + sz--; } else { sz = candidates.size(); } - for (auto i : candidates) { + + size_t ti = rand() % sz; + for (size_t vi = 0; vi < sz; vi++) { + if (!enable_dtr_sqrt_sampling) { + ti = vi; + } + auto i = candidates[ti]; if (i->producer && i->ptr && !i->pinned && i->evict_type == EvictType::NONE) { double neighbor_cost = estimate_neighbor_cost(i); size_t begin_ptr = @@ -1562,8 +1573,11 @@ TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor( best = i; } } - if (--sz == 0) - break; + if (enable_dtr_sqrt_sampling) { + ti += rand() % sz; + if (ti > candidates.size()) + break; + } } return best; } @@ -1590,14 +1604,25 @@ std::shared_ptr ChannelImpl::DynamicSublinear::find_father( } void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) { - candidates.insert(ptr); + // tensor to be inserted must be brand new + mgb_assert( + ptr->cand_index == UINT_MAX, "got wrong candidate index : %lu", + ptr->cand_index); + ptr->cand_index = candidates.size(); + candidates.push_back(ptr); if (!comp_node.valid()) { comp_node = ptr->ptr->comp_node(); } } void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) { - candidates.erase(ptr); + // some tensors may be erased already, so just skip them + if (ptr->cand_index != UINT_MAX) { + std::swap(candidates[ptr->cand_index], candidates.back()); + candidates[ptr->cand_index]->cand_index = ptr->cand_index; + candidates.pop_back(); + ptr->cand_index = UINT_MAX; + } } void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) { diff --git a/imperative/src/impl/interpreter/interpreter_impl.h b/imperative/src/impl/interpreter/interpreter_impl.h index 25d360fd..34019db1 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.h +++ b/imperative/src/impl/interpreter/interpreter_impl.h @@ -335,7 +335,7 @@ private: CompNode comp_node; //! store all tensors that may be evicted - std::unordered_set candidates; + SmallVector candidates; bool is_bad_op(std::string op_name) { return std::find(op_blacklist.begin(), op_blacklist.end(), op_name) != diff --git a/imperative/src/impl/interpreter/tensor_info.h b/imperative/src/impl/interpreter/tensor_info.h index 152e7fd3..0de89696 100644 --- a/imperative/src/impl/interpreter/tensor_info.h +++ b/imperative/src/impl/interpreter/tensor_info.h @@ -170,6 +170,9 @@ struct TensorInfo { bool size_exceeds_thd(size_t thd) { return memory > thd; } SmallVector users; + + // UINT_MAX as a magic default value + size_t cand_index = UINT_MAX; }; } // namespace interpreter::intl