@@ -0,0 +1,3 @@ | |||||
from .core._config import * | |||||
__import__("mprop").init() |
@@ -0,0 +1,205 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import os | |||||
from contextlib import contextmanager | |||||
__compute_mode = "default" | |||||
__conv_format = "default" | |||||
_benchmark_kernel = False | |||||
_deterministic_kernel = False | |||||
_async_level = os.getenv("MEGENGINE_INTERP_ASYNC_LEVEL", 2) | |||||
__all__ = [ | |||||
"benchmark_kernel", | |||||
"deterministic_kernel", | |||||
"async_level", | |||||
"_compute_mode", | |||||
"_conv_format", | |||||
"_override", | |||||
] | |||||
@property | |||||
def benchmark_kernel(mod): | |||||
r"""Whether or not run possible algorithms on real device to find the best one. The default option is false, | |||||
which means use heuristic to choose the fastest algorithm. | |||||
Examples: | |||||
.. code-block:: | |||||
import megengine as mge | |||||
mge.config.benchmark_kernel = True | |||||
""" | |||||
return _benchmark_kernel | |||||
@benchmark_kernel.setter | |||||
def benchmark_kernel(mod, option: bool): | |||||
global _benchmark_kernel | |||||
_benchmark_kernel = option | |||||
@property | |||||
def deterministic_kernel(mod): | |||||
r"""Whether or not the fastest algorithm choosed is reproducible. The default option is false, | |||||
which means the algorithm is not reproducible. | |||||
Examples: | |||||
.. code-block:: | |||||
import megengine as mge | |||||
mge.config.deterministic_kernel = True | |||||
""" | |||||
return _deterministic_kernel | |||||
@deterministic_kernel.setter | |||||
def deterministic_kernel(mod, option: bool): | |||||
global _deterministic_kernel | |||||
_deterministic_kernel = option | |||||
@property | |||||
def async_level(mod) -> int: | |||||
r"""Get or set config whether raise error exactly when invoking op. The default level is 2, | |||||
which means both device and user side errors are async. | |||||
Examples: | |||||
.. code-block:: | |||||
import megengine as mge | |||||
mge.config.async_level = 2 | |||||
""" | |||||
return _async_level | |||||
@async_level.setter | |||||
def async_level(mod, level: int): | |||||
global _async_level | |||||
_async_level = level | |||||
@property | |||||
def _compute_mode(mod): | |||||
r"""Get or set the precision of intermediate results. The default option is "default", | |||||
which means that no special requirements will be placed on. When set to 'float32', it | |||||
would be used for accumulator and intermediate result, but only effective when input and | |||||
output are of float16 dtype. | |||||
Examples: | |||||
.. code-block:: | |||||
import megengine as mge | |||||
mge.config._compute_mode = "default" | |||||
""" | |||||
return __compute_mode | |||||
@_compute_mode.setter | |||||
def _compute_mode(mod, _compute_mode: str): | |||||
global __compute_mode | |||||
__compute_mode = _compute_mode | |||||
@property | |||||
def _conv_format(mod): | |||||
r"""Get or set convolution data/filter/output layout format. The default option is "default", | |||||
which means that no special format will be placed on. There are all layout definitions | |||||
``NCHW`` layout: ``{N, C, H, W}`` | |||||
``NHWC`` layout: ``{N, H, W, C}`` | |||||
``NHWCD4`` layout: ``{N, H, (C + 3) / 4, W, 4}`` | |||||
``NHWCD4I`` layout: with ``align_axis = 2`` | |||||
``NCHW4`` layout: ``{N, C/4, H, W, 4}`` | |||||
``NCHW88`` layout: ``{N, C/8, H, W, 8}`` | |||||
``CHWN4`` layout: ``{C/4, H, W, N, 4}`` | |||||
``NCHW64`` layout: ``{N, C/64, H, W, 64}`` | |||||
Examples: | |||||
.. code-block:: | |||||
import megengine as mge | |||||
mge.config._conv_format = "NHWC" | |||||
""" | |||||
return __conv_format | |||||
@_conv_format.setter | |||||
def _conv_format(mod, format: str): | |||||
global __conv_format | |||||
__conv_format = format | |||||
def _reset_execution_config( | |||||
benchmark_kernel=None, | |||||
deterministic_kernel=None, | |||||
async_level=None, | |||||
compute_mode=None, | |||||
conv_format=None, | |||||
): | |||||
global _benchmark_kernel, _deterministic_kernel, _async_level, __compute_mode, __conv_format | |||||
orig_flags = ( | |||||
_benchmark_kernel, | |||||
_deterministic_kernel, | |||||
_async_level, | |||||
__compute_mode, | |||||
__conv_format, | |||||
) | |||||
if benchmark_kernel is not None: | |||||
_benchmark_kernel = benchmark_kernel | |||||
if deterministic_kernel is not None: | |||||
_deterministic_kernel = deterministic_kernel | |||||
if async_level is not None: | |||||
_async_level = async_level | |||||
if compute_mode is not None: | |||||
__compute_mode = compute_mode | |||||
if conv_format is not None: | |||||
__conv_format = conv_format | |||||
return orig_flags | |||||
@contextmanager | |||||
def _override( | |||||
benchmark_kernel=None, | |||||
deterministic_kernel=None, | |||||
async_level=None, | |||||
compute_mode=None, | |||||
conv_format=None, | |||||
): | |||||
r"""A context manager that users can opt in by attaching the decorator to set | |||||
the config of the global variable. | |||||
Examples: | |||||
.. code-block:: | |||||
import megengine as mge | |||||
@mge.config._override( | |||||
benchmark_kernel = True, | |||||
deterministic_kernel = Fasle, | |||||
async_level=2, | |||||
compute_mode="float32", | |||||
conv_format="NHWC", | |||||
) | |||||
def train(): | |||||
""" | |||||
orig_flags = _reset_execution_config( | |||||
benchmark_kernel, deterministic_kernel, async_level, compute_mode, conv_format, | |||||
) | |||||
try: | |||||
yield | |||||
finally: | |||||
# recover the previous values | |||||
_reset_execution_config(*orig_flags) | |||||
def _get_actual_op_param(function_param, config_param): | |||||
return function_param if config_param == "default" else config_param |
@@ -8,26 +8,38 @@ | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import os | import os | ||||
from ..core import _config | |||||
from ..core.ops import builtin | from ..core.ops import builtin | ||||
from ..logger import get_logger | from ..logger import get_logger | ||||
from ..utils.deprecation import deprecated | from ..utils.deprecation import deprecated | ||||
Strategy = builtin.ops.Convolution.Strategy | Strategy = builtin.ops.Convolution.Strategy | ||||
_execution_strategy = os.getenv("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC") | |||||
if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None: | if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None: | ||||
get_logger().warning( | get_logger().warning( | ||||
"Environment variable `MEGENGINE_CONV_EXECUTION_STRATEGY` is deprecated, please use `MEGENGINE_EXECUTION_STRATEGY`" | "Environment variable `MEGENGINE_CONV_EXECUTION_STRATEGY` is deprecated, please use `MEGENGINE_EXECUTION_STRATEGY`" | ||||
) | ) | ||||
_valid_string_option = { | |||||
"REPRODUCIBLE": Strategy.REPRODUCIBLE, | |||||
"HEURISTIC": Strategy.HEURISTIC, | |||||
"PROFILE": Strategy.PROFILE, | |||||
} | |||||
def get_execution_strategy() -> Strategy: | def get_execution_strategy() -> Strategy: | ||||
r"""Returns the execution strategy of :class:`~module..Conv2d` and :func:`~.matmul` | r"""Returns the execution strategy of :class:`~module..Conv2d` and :func:`~.matmul` | ||||
See :func:`~.set_execution_strategy` for possible return values | See :func:`~.set_execution_strategy` for possible return values | ||||
""" | """ | ||||
return _execution_strategy | |||||
strategy = Strategy(0) | |||||
if _config._benchmark_kernel: | |||||
strategy |= Strategy.PROFILE | |||||
else: | |||||
strategy |= Strategy.HEURISTIC | |||||
if _config._deterministic_kernel: | |||||
strategy |= Strategy.REPRODUCIBLE | |||||
return strategy | |||||
def set_execution_strategy(option): | def set_execution_strategy(option): | ||||
@@ -50,7 +62,6 @@ def set_execution_strategy(option): | |||||
* 'HEURISTIC' uses heuristic to choose the fastest algorithm. | * 'HEURISTIC' uses heuristic to choose the fastest algorithm. | ||||
* 'PROFILE' runs possible algorithms on real device to find the best one. | * 'PROFILE' runs possible algorithms on real device to find the best one. | ||||
* 'PROFILE_HEURISTIC' uses profiling result and heuristic to choose the fastest algorithm. | |||||
* 'PROFILE_REPRODUCIBLE' uses the fastest of profiling result that is also reproducible. | * 'PROFILE_REPRODUCIBLE' uses the fastest of profiling result that is also reproducible. | ||||
* 'HEURISTIC_REPRODUCIBLE' uses heuristic to choose the fastest algorithm that is also reproducible. | * 'HEURISTIC_REPRODUCIBLE' uses heuristic to choose the fastest algorithm that is also reproducible. | ||||
@@ -58,29 +69,33 @@ def set_execution_strategy(option): | |||||
It can also be set through the environment variable 'MEGENGINE_EXECUTION_STRATEGY'. | It can also be set through the environment variable 'MEGENGINE_EXECUTION_STRATEGY'. | ||||
""" | """ | ||||
valid_string_option = { | |||||
"REPRODUCIBLE": Strategy.REPRODUCIBLE, | |||||
"HEURISTIC": Strategy.HEURISTIC, | |||||
"PROFILE": Strategy.PROFILE, | |||||
} | |||||
global _execution_strategy # pylint: disable=global-statement | |||||
if isinstance(option, Strategy): | if isinstance(option, Strategy): | ||||
_execution_strategy = option | |||||
_config._benchmark_kernel = ( | |||||
True if option & _valid_string_option["PROFILE"] != Strategy(0) else False | |||||
) | |||||
_config._deterministic_kernel = ( | |||||
True | |||||
if option & _valid_string_option["REPRODUCIBLE"] != Strategy(0) | |||||
else False | |||||
) | |||||
return | return | ||||
assert isinstance(option, str) | assert isinstance(option, str) | ||||
strategy_tmp = Strategy(0) | |||||
_config._benchmark_kernel = False | |||||
_config._deterministic_kernel = False | |||||
for opt in option.split("_"): | for opt in option.split("_"): | ||||
if not opt in valid_string_option: | |||||
if not opt in _valid_string_option: | |||||
raise ValueError( | raise ValueError( | ||||
"Valid option can only be one of {}, or combine them with '_'.".format( | "Valid option can only be one of {}, or combine them with '_'.".format( | ||||
valid_string_option.keys() | |||||
_valid_string_option.keys() | |||||
) | ) | ||||
) | ) | ||||
strategy_tmp = strategy_tmp | valid_string_option[opt] | |||||
_execution_strategy = strategy_tmp | |||||
_config._benchmark_kernel |= _valid_string_option[opt] == Strategy.PROFILE | |||||
_config._deterministic_kernel |= ( | |||||
_valid_string_option[opt] == Strategy.REPRODUCIBLE | |||||
) | |||||
@deprecated(version="1.3", reason="use get_execution_strategy() instead") | @deprecated(version="1.3", reason="use get_execution_strategy() instead") | ||||
@@ -91,3 +106,6 @@ def get_conv_execution_strategy() -> str: | |||||
@deprecated(version="1.3", reason="use set_execution_strategy() instead") | @deprecated(version="1.3", reason="use set_execution_strategy() instead") | ||||
def set_conv_execution_strategy(option: str): | def set_conv_execution_strategy(option: str): | ||||
return set_execution_strategy(option) | return set_execution_strategy(option) | ||||
set_execution_strategy(os.getenv("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC")) |