@@ -0,0 +1,13 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import mprop | |||
from .dtr import * | |||
mprop.init() |
@@ -9,13 +9,12 @@ | |||
import re | |||
from typing import Union | |||
from mprop import mproperty | |||
from .core._imperative_rt.core2 import set_option as _set_option | |||
from .core._imperative_rt.utils import _set_defrag | |||
from ..core._imperative_rt.core2 import set_option as _set_option | |||
from ..core._imperative_rt.utils import _set_defrag | |||
_eviction_threshold = 0 | |||
_evictee_minimum_size = 1024 ** 2 | |||
_enable_sqrt_sampling = False | |||
def _str2bytes(text: str) -> int: | |||
@@ -29,7 +28,7 @@ def _str2bytes(text: str) -> int: | |||
return int(float(result[0][0]) * 1024 ** order.index(result[0][1].lower())) | |||
@mproperty | |||
@property | |||
def eviction_threshold(mod): | |||
r""" | |||
Get or set the eviction threshold in bytes. It can also be set to a string, | |||
@@ -50,21 +49,22 @@ def eviction_threshold(mod): | |||
mge.dtr.eviction_threshold = "2GB" | |||
""" | |||
return mod._eviction_threshold | |||
return _eviction_threshold | |||
@eviction_threshold.setter | |||
def eviction_threshold(mod, value: Union[int, str]): | |||
global _eviction_threshold | |||
if isinstance(value, str): | |||
mod._eviction_threshold = mod._str2bytes(value) | |||
_eviction_threshold = _str2bytes(value) | |||
elif isinstance(value, int): | |||
mod._eviction_threshold = value | |||
_eviction_threshold = value | |||
else: | |||
raise TypeError("`value` should be a str or an int") | |||
_set_option("dtr_eviction_threshold", mod._eviction_threshold) | |||
_set_option("dtr_eviction_threshold", _eviction_threshold) | |||
@mproperty | |||
@property | |||
def evictee_minimum_size(mod): | |||
r""" | |||
Get or set the memory threshold of tensors in bytes. It can also be set to a | |||
@@ -85,18 +85,45 @@ def evictee_minimum_size(mod): | |||
mge.dtr.evictee_minimum_size = "2MB" | |||
""" | |||
return mod._evictee_minimum_size | |||
return _evictee_minimum_size | |||
@evictee_minimum_size.setter | |||
def evictee_minimum_size(mod, value: Union[int, str]): | |||
global _evictee_minimum_size | |||
if isinstance(value, str): | |||
mod._evictee_minimum_size = mod._str2bytes(value) | |||
_evictee_minimum_size = _str2bytes(value) | |||
elif isinstance(value, int): | |||
mod._evictee_minimum_size = value | |||
_evictee_minimum_size = value | |||
else: | |||
raise TypeError("`value` should be a str or an int") | |||
_set_option("dtr_evictee_minimum_size", mod._evictee_minimum_size) | |||
_set_option("dtr_evictee_minimum_size", _evictee_minimum_size) | |||
@property | |||
def enable_sqrt_sampling(mod): | |||
r""" | |||
Get or set whether sqrt sampling is allowed. Sqrt sampling means that given | |||
the size of the candidate set is N, only enumerate sqrt(N) tensors. When | |||
the number of tensors is very high, enabling this optimization will speed | |||
up the training. | |||
Examples: | |||
.. code-block:: | |||
import megengine as mge | |||
mge.dtr.enable_sqrt_sampling = True | |||
""" | |||
return _enable_sqrt_sampling | |||
@enable_sqrt_sampling.setter | |||
def enable_sqrt_sampling(mod, value: bool): | |||
global _enable_sqrt_sampling | |||
_enable_sqrt_sampling = value | |||
_set_option("enable_dtr_sqrt_sampling", _enable_sqrt_sampling) | |||
def enable(): |
@@ -761,7 +761,7 @@ bool ChannelImpl::auto_evict(size_t force_num=0) { | |||
while ((state.options.dtr_eviction_threshold > 0 && current_memory > state.options.dtr_eviction_threshold) || force_num > 0) { | |||
RECORD_EVENT(AutoEvictEvent); | |||
sample_on_device(m_dtr.comp_node, false); | |||
auto best = m_dtr.find_best_tensor(); | |||
auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling && !force_num); | |||
if (!best) { | |||
break; | |||
} | |||
@@ -988,8 +988,15 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { | |||
if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) { | |||
TensorInfo::ComputePath::make(cmd.id, cmd.op, cmd.inputs, cmd.outputs); | |||
size_t detach_cnt = 0; | |||
if (!strcmp(get_name(*cmd.op), "BatchNorm") && cmd.outputs.size() == 5) { | |||
cmd.outputs[0]->detach_producer(); // detach running_mean | |||
cmd.outputs[1]->detach_producer(); // detach running_var | |||
for (auto input : cmd.inputs) { | |||
input->ref_cnt -= 2; | |||
} | |||
} | |||
for (auto output : cmd.outputs) { | |||
if (!output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) { | |||
if (output->producer && !output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) { | |||
output->detach_producer(); | |||
detach_cnt ++; | |||
} | |||
@@ -1339,9 +1346,15 @@ double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) { | |||
return cost; | |||
} | |||
TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor() { | |||
TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(bool enable_dtr_sqrt_sampling=false) { | |||
double min_msps = -1; | |||
TensorInfo* best = nullptr; | |||
size_t sz = 1; | |||
if (enable_dtr_sqrt_sampling) { | |||
while (sz * sz <= candidates.size()) sz ++; | |||
} else { | |||
sz = candidates.size(); | |||
} | |||
for (auto i : candidates) { | |||
if (i->producer && i->ptr && !i->pinned && i->evict_type == EvictType::NONE) { | |||
double neighbor_cost = estimate_neighbor_cost(i); | |||
@@ -1354,6 +1367,7 @@ TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor() { | |||
best = i; | |||
} | |||
} | |||
if (--sz == 0) break; | |||
} | |||
return best; | |||
} | |||
@@ -323,7 +323,7 @@ private: | |||
* \return the pointer of the best tensor; nullptr is returned if no | |||
* available tensor is found | |||
*/ | |||
TensorInfo* find_best_tensor(); | |||
TensorInfo* find_best_tensor(bool); | |||
/*! | |||
* \brief estimate the cost of recomputing tensor ptr | |||
@@ -41,6 +41,7 @@ public: | |||
DEF_OPTION(enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1, | |||
"enable host compute, thus computation may be done in host event if it's device is gpu."); | |||
DEF_OPTION(enable_dtr_auto_drop, "MEGENGINE_DTR_AUTO_DROP", 0, ""); | |||
DEF_OPTION(enable_dtr_sqrt_sampling, "MEGENGINE_DTR_SQRT_SAMPLING", 0, ""); | |||
DEF_OPTION(dtr_eviction_threshold, "MEGENGINE_DTR_EVICTION_THRESHOLD", 0, | |||
"auto drop will start whenever gpu memory usage exceeds this value."); | |||
DEF_OPTION(dtr_evictee_minimum_size, "MEGENGINE_DTR_EVICTEE_MINIMUM_SIZE", 1048576, | |||