diff --git a/imperative/python/megengine/dtr/__init__.py b/imperative/python/megengine/dtr/__init__.py new file mode 100644 index 00000000..2dd81406 --- /dev/null +++ b/imperative/python/megengine/dtr/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import mprop + +from .dtr import * + +mprop.init() diff --git a/imperative/python/megengine/dtr.py b/imperative/python/megengine/dtr/dtr.py similarity index 70% rename from imperative/python/megengine/dtr.py rename to imperative/python/megengine/dtr/dtr.py index ef763e97..10435fff 100644 --- a/imperative/python/megengine/dtr.py +++ b/imperative/python/megengine/dtr/dtr.py @@ -9,13 +9,12 @@ import re from typing import Union -from mprop import mproperty - -from .core._imperative_rt.core2 import set_option as _set_option -from .core._imperative_rt.utils import _set_defrag +from ..core._imperative_rt.core2 import set_option as _set_option +from ..core._imperative_rt.utils import _set_defrag _eviction_threshold = 0 _evictee_minimum_size = 1024 ** 2 +_enable_sqrt_sampling = False def _str2bytes(text: str) -> int: @@ -29,7 +28,7 @@ def _str2bytes(text: str) -> int: return int(float(result[0][0]) * 1024 ** order.index(result[0][1].lower())) -@mproperty +@property def eviction_threshold(mod): r""" Get or set the eviction threshold in bytes. It can also be set to a string, @@ -50,21 +49,22 @@ def eviction_threshold(mod): mge.dtr.eviction_threshold = "2GB" """ - return mod._eviction_threshold + return _eviction_threshold @eviction_threshold.setter def eviction_threshold(mod, value: Union[int, str]): + global _eviction_threshold if isinstance(value, str): - mod._eviction_threshold = mod._str2bytes(value) + _eviction_threshold = _str2bytes(value) elif isinstance(value, int): - mod._eviction_threshold = value + _eviction_threshold = value else: raise TypeError("`value` should be a str or an int") - _set_option("dtr_eviction_threshold", mod._eviction_threshold) + _set_option("dtr_eviction_threshold", _eviction_threshold) -@mproperty +@property def evictee_minimum_size(mod): r""" Get or set the memory threshold of tensors in bytes. It can also be set to a @@ -85,18 +85,45 @@ def evictee_minimum_size(mod): mge.dtr.evictee_minimum_size = "2MB" """ - return mod._evictee_minimum_size + return _evictee_minimum_size @evictee_minimum_size.setter def evictee_minimum_size(mod, value: Union[int, str]): + global _evictee_minimum_size if isinstance(value, str): - mod._evictee_minimum_size = mod._str2bytes(value) + _evictee_minimum_size = _str2bytes(value) elif isinstance(value, int): - mod._evictee_minimum_size = value + _evictee_minimum_size = value else: raise TypeError("`value` should be a str or an int") - _set_option("dtr_evictee_minimum_size", mod._evictee_minimum_size) + _set_option("dtr_evictee_minimum_size", _evictee_minimum_size) + + +@property +def enable_sqrt_sampling(mod): + r""" + Get or set whether sqrt sampling is allowed. Sqrt sampling means that given + the size of the candidate set is N, only enumerate sqrt(N) tensors. When + the number of tensors is very high, enabling this optimization will speed + up the training. + + Examples: + + .. code-block:: + + import megengine as mge + mge.dtr.enable_sqrt_sampling = True + + """ + return _enable_sqrt_sampling + + +@enable_sqrt_sampling.setter +def enable_sqrt_sampling(mod, value: bool): + global _enable_sqrt_sampling + _enable_sqrt_sampling = value + _set_option("enable_dtr_sqrt_sampling", _enable_sqrt_sampling) def enable(): diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 313c5a68..89379c8b 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -761,7 +761,7 @@ bool ChannelImpl::auto_evict(size_t force_num=0) { while ((state.options.dtr_eviction_threshold > 0 && current_memory > state.options.dtr_eviction_threshold) || force_num > 0) { RECORD_EVENT(AutoEvictEvent); sample_on_device(m_dtr.comp_node, false); - auto best = m_dtr.find_best_tensor(); + auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling && !force_num); if (!best) { break; } @@ -988,8 +988,15 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) { TensorInfo::ComputePath::make(cmd.id, cmd.op, cmd.inputs, cmd.outputs); size_t detach_cnt = 0; + if (!strcmp(get_name(*cmd.op), "BatchNorm") && cmd.outputs.size() == 5) { + cmd.outputs[0]->detach_producer(); // detach running_mean + cmd.outputs[1]->detach_producer(); // detach running_var + for (auto input : cmd.inputs) { + input->ref_cnt -= 2; + } + } for (auto output : cmd.outputs) { - if (!output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) { + if (output->producer && !output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) { output->detach_producer(); detach_cnt ++; } @@ -1339,9 +1346,15 @@ double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) { return cost; } -TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor() { +TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(bool enable_dtr_sqrt_sampling=false) { double min_msps = -1; TensorInfo* best = nullptr; + size_t sz = 1; + if (enable_dtr_sqrt_sampling) { + while (sz * sz <= candidates.size()) sz ++; + } else { + sz = candidates.size(); + } for (auto i : candidates) { if (i->producer && i->ptr && !i->pinned && i->evict_type == EvictType::NONE) { double neighbor_cost = estimate_neighbor_cost(i); @@ -1354,6 +1367,7 @@ TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor() { best = i; } } + if (--sz == 0) break; } return best; } diff --git a/imperative/src/impl/interpreter/interpreter_impl.h b/imperative/src/impl/interpreter/interpreter_impl.h index 9d869251..a79988ae 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.h +++ b/imperative/src/impl/interpreter/interpreter_impl.h @@ -323,7 +323,7 @@ private: * \return the pointer of the best tensor; nullptr is returned if no * available tensor is found */ - TensorInfo* find_best_tensor(); + TensorInfo* find_best_tensor(bool); /*! * \brief estimate the cost of recomputing tensor ptr diff --git a/imperative/src/impl/interpreter/option_manager.h b/imperative/src/impl/interpreter/option_manager.h index c140b1f4..591ad727 100644 --- a/imperative/src/impl/interpreter/option_manager.h +++ b/imperative/src/impl/interpreter/option_manager.h @@ -41,6 +41,7 @@ public: DEF_OPTION(enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1, "enable host compute, thus computation may be done in host event if it's device is gpu."); DEF_OPTION(enable_dtr_auto_drop, "MEGENGINE_DTR_AUTO_DROP", 0, ""); + DEF_OPTION(enable_dtr_sqrt_sampling, "MEGENGINE_DTR_SQRT_SAMPLING", 0, ""); DEF_OPTION(dtr_eviction_threshold, "MEGENGINE_DTR_EVICTION_THRESHOLD", 0, "auto drop will start whenever gpu memory usage exceeds this value."); DEF_OPTION(dtr_evictee_minimum_size, "MEGENGINE_DTR_EVICTEE_MINIMUM_SIZE", 1048576,