@@ -0,0 +1,13 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import mprop | |||||
from .dtr import * | |||||
mprop.init() |
@@ -9,13 +9,12 @@ | |||||
import re | import re | ||||
from typing import Union | from typing import Union | ||||
from mprop import mproperty | |||||
from .core._imperative_rt.core2 import set_option as _set_option | |||||
from .core._imperative_rt.utils import _set_defrag | |||||
from ..core._imperative_rt.core2 import set_option as _set_option | |||||
from ..core._imperative_rt.utils import _set_defrag | |||||
_eviction_threshold = 0 | _eviction_threshold = 0 | ||||
_evictee_minimum_size = 1024 ** 2 | _evictee_minimum_size = 1024 ** 2 | ||||
_enable_sqrt_sampling = False | |||||
def _str2bytes(text: str) -> int: | def _str2bytes(text: str) -> int: | ||||
@@ -29,7 +28,7 @@ def _str2bytes(text: str) -> int: | |||||
return int(float(result[0][0]) * 1024 ** order.index(result[0][1].lower())) | return int(float(result[0][0]) * 1024 ** order.index(result[0][1].lower())) | ||||
@mproperty | |||||
@property | |||||
def eviction_threshold(mod): | def eviction_threshold(mod): | ||||
r""" | r""" | ||||
Get or set the eviction threshold in bytes. It can also be set to a string, | Get or set the eviction threshold in bytes. It can also be set to a string, | ||||
@@ -50,21 +49,22 @@ def eviction_threshold(mod): | |||||
mge.dtr.eviction_threshold = "2GB" | mge.dtr.eviction_threshold = "2GB" | ||||
""" | """ | ||||
return mod._eviction_threshold | |||||
return _eviction_threshold | |||||
@eviction_threshold.setter | @eviction_threshold.setter | ||||
def eviction_threshold(mod, value: Union[int, str]): | def eviction_threshold(mod, value: Union[int, str]): | ||||
global _eviction_threshold | |||||
if isinstance(value, str): | if isinstance(value, str): | ||||
mod._eviction_threshold = mod._str2bytes(value) | |||||
_eviction_threshold = _str2bytes(value) | |||||
elif isinstance(value, int): | elif isinstance(value, int): | ||||
mod._eviction_threshold = value | |||||
_eviction_threshold = value | |||||
else: | else: | ||||
raise TypeError("`value` should be a str or an int") | raise TypeError("`value` should be a str or an int") | ||||
_set_option("dtr_eviction_threshold", mod._eviction_threshold) | |||||
_set_option("dtr_eviction_threshold", _eviction_threshold) | |||||
@mproperty | |||||
@property | |||||
def evictee_minimum_size(mod): | def evictee_minimum_size(mod): | ||||
r""" | r""" | ||||
Get or set the memory threshold of tensors in bytes. It can also be set to a | Get or set the memory threshold of tensors in bytes. It can also be set to a | ||||
@@ -85,18 +85,45 @@ def evictee_minimum_size(mod): | |||||
mge.dtr.evictee_minimum_size = "2MB" | mge.dtr.evictee_minimum_size = "2MB" | ||||
""" | """ | ||||
return mod._evictee_minimum_size | |||||
return _evictee_minimum_size | |||||
@evictee_minimum_size.setter | @evictee_minimum_size.setter | ||||
def evictee_minimum_size(mod, value: Union[int, str]): | def evictee_minimum_size(mod, value: Union[int, str]): | ||||
global _evictee_minimum_size | |||||
if isinstance(value, str): | if isinstance(value, str): | ||||
mod._evictee_minimum_size = mod._str2bytes(value) | |||||
_evictee_minimum_size = _str2bytes(value) | |||||
elif isinstance(value, int): | elif isinstance(value, int): | ||||
mod._evictee_minimum_size = value | |||||
_evictee_minimum_size = value | |||||
else: | else: | ||||
raise TypeError("`value` should be a str or an int") | raise TypeError("`value` should be a str or an int") | ||||
_set_option("dtr_evictee_minimum_size", mod._evictee_minimum_size) | |||||
_set_option("dtr_evictee_minimum_size", _evictee_minimum_size) | |||||
@property | |||||
def enable_sqrt_sampling(mod): | |||||
r""" | |||||
Get or set whether sqrt sampling is allowed. Sqrt sampling means that given | |||||
the size of the candidate set is N, only enumerate sqrt(N) tensors. When | |||||
the number of tensors is very high, enabling this optimization will speed | |||||
up the training. | |||||
Examples: | |||||
.. code-block:: | |||||
import megengine as mge | |||||
mge.dtr.enable_sqrt_sampling = True | |||||
""" | |||||
return _enable_sqrt_sampling | |||||
@enable_sqrt_sampling.setter | |||||
def enable_sqrt_sampling(mod, value: bool): | |||||
global _enable_sqrt_sampling | |||||
_enable_sqrt_sampling = value | |||||
_set_option("enable_dtr_sqrt_sampling", _enable_sqrt_sampling) | |||||
def enable(): | def enable(): |
@@ -761,7 +761,7 @@ bool ChannelImpl::auto_evict(size_t force_num=0) { | |||||
while ((state.options.dtr_eviction_threshold > 0 && current_memory > state.options.dtr_eviction_threshold) || force_num > 0) { | while ((state.options.dtr_eviction_threshold > 0 && current_memory > state.options.dtr_eviction_threshold) || force_num > 0) { | ||||
RECORD_EVENT(AutoEvictEvent); | RECORD_EVENT(AutoEvictEvent); | ||||
sample_on_device(m_dtr.comp_node, false); | sample_on_device(m_dtr.comp_node, false); | ||||
auto best = m_dtr.find_best_tensor(); | |||||
auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling && !force_num); | |||||
if (!best) { | if (!best) { | ||||
break; | break; | ||||
} | } | ||||
@@ -988,8 +988,15 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { | |||||
if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) { | if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) { | ||||
TensorInfo::ComputePath::make(cmd.id, cmd.op, cmd.inputs, cmd.outputs); | TensorInfo::ComputePath::make(cmd.id, cmd.op, cmd.inputs, cmd.outputs); | ||||
size_t detach_cnt = 0; | size_t detach_cnt = 0; | ||||
if (!strcmp(get_name(*cmd.op), "BatchNorm") && cmd.outputs.size() == 5) { | |||||
cmd.outputs[0]->detach_producer(); // detach running_mean | |||||
cmd.outputs[1]->detach_producer(); // detach running_var | |||||
for (auto input : cmd.inputs) { | |||||
input->ref_cnt -= 2; | |||||
} | |||||
} | |||||
for (auto output : cmd.outputs) { | for (auto output : cmd.outputs) { | ||||
if (!output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) { | |||||
if (output->producer && !output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) { | |||||
output->detach_producer(); | output->detach_producer(); | ||||
detach_cnt ++; | detach_cnt ++; | ||||
} | } | ||||
@@ -1339,9 +1346,15 @@ double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) { | |||||
return cost; | return cost; | ||||
} | } | ||||
TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor() { | |||||
TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(bool enable_dtr_sqrt_sampling=false) { | |||||
double min_msps = -1; | double min_msps = -1; | ||||
TensorInfo* best = nullptr; | TensorInfo* best = nullptr; | ||||
size_t sz = 1; | |||||
if (enable_dtr_sqrt_sampling) { | |||||
while (sz * sz <= candidates.size()) sz ++; | |||||
} else { | |||||
sz = candidates.size(); | |||||
} | |||||
for (auto i : candidates) { | for (auto i : candidates) { | ||||
if (i->producer && i->ptr && !i->pinned && i->evict_type == EvictType::NONE) { | if (i->producer && i->ptr && !i->pinned && i->evict_type == EvictType::NONE) { | ||||
double neighbor_cost = estimate_neighbor_cost(i); | double neighbor_cost = estimate_neighbor_cost(i); | ||||
@@ -1354,6 +1367,7 @@ TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor() { | |||||
best = i; | best = i; | ||||
} | } | ||||
} | } | ||||
if (--sz == 0) break; | |||||
} | } | ||||
return best; | return best; | ||||
} | } | ||||
@@ -323,7 +323,7 @@ private: | |||||
* \return the pointer of the best tensor; nullptr is returned if no | * \return the pointer of the best tensor; nullptr is returned if no | ||||
* available tensor is found | * available tensor is found | ||||
*/ | */ | ||||
TensorInfo* find_best_tensor(); | |||||
TensorInfo* find_best_tensor(bool); | |||||
/*! | /*! | ||||
* \brief estimate the cost of recomputing tensor ptr | * \brief estimate the cost of recomputing tensor ptr | ||||
@@ -41,6 +41,7 @@ public: | |||||
DEF_OPTION(enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1, | DEF_OPTION(enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1, | ||||
"enable host compute, thus computation may be done in host event if it's device is gpu."); | "enable host compute, thus computation may be done in host event if it's device is gpu."); | ||||
DEF_OPTION(enable_dtr_auto_drop, "MEGENGINE_DTR_AUTO_DROP", 0, ""); | DEF_OPTION(enable_dtr_auto_drop, "MEGENGINE_DTR_AUTO_DROP", 0, ""); | ||||
DEF_OPTION(enable_dtr_sqrt_sampling, "MEGENGINE_DTR_SQRT_SAMPLING", 0, ""); | |||||
DEF_OPTION(dtr_eviction_threshold, "MEGENGINE_DTR_EVICTION_THRESHOLD", 0, | DEF_OPTION(dtr_eviction_threshold, "MEGENGINE_DTR_EVICTION_THRESHOLD", 0, | ||||
"auto drop will start whenever gpu memory usage exceeds this value."); | "auto drop will start whenever gpu memory usage exceeds this value."); | ||||
DEF_OPTION(dtr_evictee_minimum_size, "MEGENGINE_DTR_EVICTEE_MINIMUM_SIZE", 1048576, | DEF_OPTION(dtr_evictee_minimum_size, "MEGENGINE_DTR_EVICTEE_MINIMUM_SIZE", 1048576, | ||||