GitOrigin-RevId: 1b41e1042c
release-1.10
@@ -2,6 +2,7 @@ import mprop | |||
from ..core.tensor.amp import * | |||
from .autocast import autocast | |||
from .convert_format import convert_module_format, convert_tensor_format | |||
from .grad_scaler import GradScaler | |||
mprop.init() |
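For context, a minimal sketch of how the exported pieces are meant to be combined in a mixed-precision training step; the model, loss, and the GradScaler.backward call are illustrative assumptions, not part of this diff:

    import numpy as np
    import megengine as mge
    import megengine.module as M
    from megengine.amp import GradScaler, autocast
    from megengine.autodiff import GradManager
    from megengine.optimizer import SGD

    net = M.Conv2d(4, 4, 3)
    gm = GradManager().attach(net.parameters())
    opt = SGD(net.parameters(), lr=0.01)
    scaler = GradScaler()

    x = mge.tensor(np.random.randn(2, 4, 8, 8).astype("float32"))
    with gm:
        with autocast():              # cast eligible ops to float16 inside this block
            loss = net(x).mean()
        scaler.backward(gm, loss)     # scale the loss before backward to avoid fp16 underflow
    opt.step().clear_grad()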
@@ -1,5 +1,6 @@ | |||
import functools | |||
from ..core import _config | |||
from ..core.tensor import amp | |||
@@ -50,24 +51,37 @@ class autocast: | |||
self._origin_high = None | |||
self._origin_low = None | |||
self._origin_configs = None | |||
def __enter__(self): | |||
self._origin_enabled = amp._enabled | |||
self._origin_high = amp._get_amp_high_prec_dtype() | |||
self._origin_low = amp._get_amp_low_prec_dtype() | |||
amp._enabled = self.enabled | |||
amp._set_amp_dtype_autocast(self.enabled) | |||
if not self.enabled: | |||
return | |||
self._origin_high = amp._get_amp_high_prec_dtype() | |||
self._origin_low = amp._get_amp_low_prec_dtype() | |||
amp._set_amp_high_prec_dtype(self.high_prec_dtype) | |||
amp._set_amp_low_prec_dtype(self.low_prec_dtype) | |||
self._origin_configs = _config._reset_execution_config(compute_mode="float32") | |||
def __exit__(self, *args): | |||
amp._enabled = self._origin_enabled | |||
amp._set_amp_dtype_autocast(self._origin_enabled) | |||
if not self.enabled: | |||
return | |||
amp._set_amp_high_prec_dtype(self._origin_high) | |||
amp._set_amp_low_prec_dtype(self._origin_low) | |||
_config._reset_execution_config(*self._origin_configs) | |||
def __call__(self, func): | |||
@functools.wraps(func) | |||
def wrapper(*args, **kwargs): | |||
if not self.enabled: | |||
return func(*args, **kwargs) | |||
with self: | |||
return func(*args, **kwargs) | |||
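autocast works both as a context manager and as a decorator, as the code above implements; a short usage sketch (keyword names follow the attributes used above, the input data is made up). Inside the scope, compute_mode is pinned to "float32" and the previous amp state is restored on exit:

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine.amp import autocast

    x = mge.tensor(np.random.randn(2, 4, 8, 8).astype("float32"))
    w = mge.tensor(np.random.randn(4, 4, 3, 3).astype("float32"))

    # as a context manager: casts apply only inside the block
    with autocast(enabled=True, low_prec_dtype="float16", high_prec_dtype="float32"):
        y = F.conv2d(x, w)

    # as a decorator: wrapper() enters the same context around every call
    @autocast(enabled=True)
    def forward(inp, weight):
        return F.conv2d(inp, weight)

    z = forward(x, w)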
@@ -0,0 +1,45 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from copy import deepcopy | |||
from .. import functional as F | |||
from ..module import Module | |||
from ..tensor import Tensor | |||
def _is_nchw_format(param: Tensor): | |||
# TODO: use better condition | |||
return (len(param.shape) == 4 or len(param.shape) == 5) and param.format != "nhwc" | |||
def convert_tensor_format(x: Tensor, inplace: bool = True): | |||
"""Convert NCHW Tensor to NHWC Tensor.""" | |||
if x.ndim == 4: | |||
pattern = (0, 2, 3, 1) | |||
elif x.ndim == 5: | |||
pattern = (0, 1, 3, 4, 2) | |||
else: | |||
raise ValueError("Unsupport tensor ndim {}".format(x.ndim)) | |||
# TODO: use initialization from tensor after fixing format setting | |||
if inplace: | |||
x[...] = Tensor(x.numpy().transpose(*pattern), format="nhwc") | |||
else: | |||
x = Tensor(x.numpy().transpose(*pattern), format="nhwc") | |||
return x | |||
def convert_module_format(module: Module, inplace: bool = True): | |||
"""Convert NCHW Module to NHWC Module.""" | |||
if not inplace: | |||
module = deepcopy(module) | |||
for name, param in module.named_tensors(): | |||
if _is_nchw_format(param): | |||
# hostvalue should still be valid, so no d2h cost. | |||
convert_tensor_format(param, inplace=True) | |||
return module |
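Typical use, mirroring the new test further down: converting a module rewrites every 4-D/5-D parameter and buffer into NHWC format (on a deep copy when inplace=False):

    import megengine.amp as amp
    import megengine.module as M

    m = M.Conv2d(4, 4, 3)
    m_nhwc = amp.convert_module_format(m, inplace=False)   # the original module stays NCHW
    for name, param in m_nhwc.named_tensors():
        if param.ndim in (4, 5):
            assert param.format == "nhwc"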
@@ -1,7 +1,13 @@ | |||
import weakref | |||
from typing import Callable, Iterable, List, Union | |||
from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option | |||
from ..core._imperative_rt.core2 import ( | |||
get_auto_format_convert, | |||
pop_scope, | |||
push_scope, | |||
set_auto_format_convert, | |||
set_option, | |||
) | |||
from ..core.autodiff.grad import Grad | |||
from ..core.tensor.dtype import is_differentible_dtype | |||
from ..logger import get_logger | |||
@@ -253,6 +259,8 @@ class GradManager: | |||
""" | |||
push_scope("backward") | |||
set_option("record_computing_path", 0) | |||
_origin_auto_format = get_auto_format_convert() | |||
set_auto_format_convert(False) | |||
from ..functional import ones_like | |||
global backwarding_grad_manager | |||
@@ -296,6 +304,7 @@ class GradManager: | |||
self.release() | |||
backwarding_grad_manager = cache | |||
set_option("record_computing_path", 1) | |||
set_auto_format_convert(_origin_auto_format) | |||
pop_scope("backward") | |||
def record(self): | |||
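backward() now saves the auto-format-convert flag, switches it off for the duration of the backward pass, and restores it afterwards; nothing changes on the user's side, e.g. (a minimal sketch with made-up data):

    import numpy as np
    import megengine as mge
    from megengine.autodiff import GradManager

    x = mge.tensor(np.random.randn(1, 2, 3, 4).astype("float32"), format="nhwc")
    gm = GradManager().attach([x])
    with gm:
        with mge.config._override(auto_format_convert=True):
            y = (x * 2.0).sum()
        gm.backward(y)          # format conversion is disabled internally for this call
    print(x.grad.shape)         # grads are produced as usual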
@@ -10,8 +10,10 @@ from ._imperative_rt.core2 import ( | |||
set_option, | |||
) | |||
# use "default" to distinguish it from None in _reset_execution_config | |||
__compute_mode = "default" | |||
__conv_format = "default" | |||
__bn_format = "default" | |||
_benchmark_kernel = False | |||
_deterministic_kernel = False | |||
@@ -22,6 +24,8 @@ __all__ = [ | |||
"disable_memory_forwarding", | |||
"_compute_mode", | |||
"_conv_format", | |||
"_bn_format", | |||
"_auto_format_convert", | |||
"_override", | |||
] | |||
@@ -32,6 +36,7 @@ def benchmark_kernel(mod): | |||
which means a heuristic is used to choose the fastest algorithm. | |||
Examples: | |||
.. code-block:: | |||
import megengine as mge | |||
@@ -55,6 +60,7 @@ def deterministic_kernel(mod): | |||
which means the algorithm is not reproducible. | |||
Examples: | |||
.. code-block:: | |||
import megengine as mge | |||
@@ -75,6 +81,7 @@ def async_level(mod) -> int: | |||
which means both device and user side errors are async. | |||
Examples: | |||
.. code-block:: | |||
import megengine as mge | |||
@@ -110,16 +117,17 @@ def disable_memory_forwarding(mod, disable: bool): | |||
@property | |||
def _compute_mode(mod): | |||
r"""Get or set the precision of intermediate results. The default option is "default", | |||
which means that no special requirements will be placed on. When set to 'float32', it | |||
would be used for accumulator and intermediate result, but only effective when input and | |||
r"""Get or set the precision of intermediate results for conv, matmul. The default | |||
option is None and will fallback to "default". When set to "float32", it will | |||
trigger mixed precision computation on TensorCore, but only effective when input and | |||
output are of float16 dtype. | |||
Examples: | |||
.. code-block:: | |||
import megengine as mge | |||
mge.config._compute_mode = "default" | |||
mge.config._compute_mode = "float32" | |||
""" | |||
return __compute_mode | |||
@@ -132,7 +140,7 @@ def _compute_mode(mod, _compute_mode: str): | |||
@property | |||
def _conv_format(mod): | |||
r"""Get or set convolution data/filter/output layout format. The default option is "default", | |||
r"""Get or set convolution data/filter/output layout format. The default option is None, | |||
which means that no particular format is enforced. All layout definitions are listed below: | |||
``NCHW`` layout: ``{N, C, H, W}`` | |||
@@ -145,6 +153,7 @@ def _conv_format(mod): | |||
``NCHW64`` layout: ``{N, C/64, H, W, 64}`` | |||
Examples: | |||
.. code-block:: | |||
import megengine as mge | |||
@@ -160,11 +169,34 @@ def _conv_format(mod, format: str): | |||
@property | |||
def _bn_format(mod): | |||
r"""Get or set batchnorm param layout format. The default option is None and will | |||
fallback to "dim_1c11" which corresponds to NCHW format. When set to "dim_111c", | |||
param format of batchnorm will be changed to NHWC. | |||
Examples: | |||
.. code-block:: | |||
import megengine as mge | |||
mge.config._bn_format = "dim_111c" | |||
""" | |||
return __bn_format | |||
@_bn_format.setter | |||
def _bn_format(mod, format: str): | |||
global __bn_format | |||
__bn_format = format | |||
@property | |||
def _auto_format_convert(mod): | |||
r"""Automatically convert indexing params' order for NCHW Tensor to NHWC order. | |||
The default value is False, which means no conversion is performed. | |||
Examples: | |||
.. code-block:: | |||
import megengine as mge | |||
@@ -184,15 +216,17 @@ def _reset_execution_config( | |||
async_level=None, | |||
compute_mode=None, | |||
conv_format=None, | |||
bn_format=None, | |||
auto_format_convert=None, | |||
): | |||
global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format | |||
global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format, __bn_format | |||
orig_flags = ( | |||
_benchmark_kernel, | |||
_deterministic_kernel, | |||
get_option("async_level"), | |||
__compute_mode, | |||
__conv_format, | |||
__bn_format, | |||
get_auto_format_convert(), | |||
) | |||
if benchmark_kernel is not None: | |||
@@ -205,6 +239,8 @@ def _reset_execution_config( | |||
__compute_mode = compute_mode | |||
if conv_format is not None: | |||
__conv_format = conv_format | |||
if bn_format is not None: | |||
__bn_format = bn_format | |||
if auto_format_convert is not None: | |||
set_auto_format_convert(auto_format_convert) | |||
@@ -218,12 +254,14 @@ def _override( | |||
async_level=None, | |||
compute_mode=None, | |||
conv_format=None, | |||
bn_format=None, | |||
auto_format_convert=None, | |||
): | |||
r"""A context manager that users can opt in by attaching the decorator to set | |||
the config of the global variable. | |||
Examples: | |||
.. code-block:: | |||
import megengine as mge | |||
@@ -234,6 +272,7 @@ def _override( | |||
async_level=2, | |||
compute_mode="float32", | |||
conv_format="NHWC", | |||
bn_format="dim_111c", | |||
auto_format_convert=True, | |||
) | |||
def train(): | |||
@@ -244,6 +283,7 @@ def _override( | |||
async_level, | |||
compute_mode, | |||
conv_format, | |||
bn_format, | |||
auto_format_convert, | |||
) | |||
try: | |||
@@ -254,4 +294,4 @@ def _override( | |||
def _get_actual_op_param(function_param, config_param): | |||
return function_param if config_param == "default" else config_param | |||
return function_param if config_param is "default" else config_param |
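The resolution rule in _get_actual_op_param: a per-op argument is only honored while the corresponding global config is still "default"; once the global config is set, it wins. A couple of illustrative checks:

    from megengine.core import _config

    # the per-op value survives while the global config stays at "default"
    assert _config._get_actual_op_param("float32", "default") == "float32"
    # a configured global value overrides whatever the call site passed
    assert _config._get_actual_op_param("default", "float32") == "float32"
    assert _config._get_actual_op_param("NCHW", "NHWC") == "NHWC"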
@@ -10,13 +10,19 @@ from .._imperative_rt.core2 import ( | |||
_enabled = False | |||
_set_amp_dtype_autocast(_enabled) | |||
__all__ = [ | |||
"enabled", | |||
"high_prec_dtype", | |||
"low_prec_dtype", | |||
] | |||
@property | |||
def enabled(mod): | |||
r"""Get or set amp autocast mode enabled or not. | |||
Examples: | |||
.. code-block:: | |||
import megengine as mge | |||
@@ -36,9 +42,9 @@ def enabled(mod, enabled: bool): | |||
def high_prec_dtype(mod): | |||
r"""Get or set amp autocast mode's higher precision dtype. It will change the | |||
target dtype in tensor casting for better precision. Default: float32. | |||
Examples: | |||
.. code-block:: | |||
import megengine as mge | |||
@@ -56,9 +62,9 @@ def high_prec_dtype(mod, dtype: str): | |||
def low_prec_dtype(mod): | |||
r"""Get or set amp autocast mode's lower precision dtype. It will change the | |||
target dtype in tensor casting for better speed and memory. Default: float16. | |||
Examples: | |||
.. code-block:: | |||
import megengine as mge | |||
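These module-level properties read and write like ordinary attributes (the amp package initializes mprop for that); a quick sketch of the documented defaults:

    import megengine as mge

    mge.amp.enabled = True
    print(mge.amp.high_prec_dtype)   # "float32" by default
    print(mge.amp.low_prec_dtype)    # "float16" by default
    mge.amp.low_prec_dtype = "float16"
    mge.amp.enabled = False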
@@ -63,6 +63,7 @@ def _matmul( | |||
assert dim1 > 0 and dim2 > 0 | |||
maxdim = dim1 if dim1 > dim2 else dim2 | |||
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) | |||
if dim1 == 1 and dim2 == 1: # dispatch to Dot | |||
(result,) = apply(builtin.Dot(), inp1, inp2) | |||
return result | |||
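With this change matmul resolves its compute_mode from the global config whenever the call site leaves it at "default"; a sketch with float16 inputs, where the float32 accumulation actually matters:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    a = mge.tensor(np.random.randn(32, 64).astype("float16"))
    b = mge.tensor(np.random.randn(64, 32).astype("float16"))

    # no compute_mode at the call site: the override below supplies it
    with mge.config._override(compute_mode="float32"):
        out = F.matmul(a, b)     # accumulates in float32, output dtype stays float16
    print(out.dtype)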
@@ -441,7 +441,6 @@ def deformable_conv2d( | |||
or conv_mode.name == "CROSS_CORRELATION" | |||
) | |||
if amp._enabled: | |||
compute_mode = "float32" | |||
inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias) | |||
else: | |||
offset = offset.astype("float32") | |||
@@ -1182,7 +1181,6 @@ def batch_norm( | |||
momentum: float = 0.9, | |||
eps: float = 1e-5, | |||
inplace: bool = True, | |||
compute_mode="default", | |||
param_dim="dim_1c11" | |||
): | |||
r"""Applies batch normalization to the input. | |||
@@ -19,7 +19,6 @@ class _BatchNorm(Module): | |||
affine=True, | |||
track_running_stats=True, | |||
freeze=False, | |||
compute_mode="default", | |||
param_dim="dim_1c11", | |||
**kwargs | |||
): | |||
@@ -31,7 +30,6 @@ class _BatchNorm(Module): | |||
self.track_running_stats = track_running_stats | |||
self._track_running_stats_saved = track_running_stats | |||
self.freeze = freeze | |||
self.compute_mode = compute_mode | |||
self.param_dim = param_dim | |||
if self.freeze: | |||
assert ( | |||
@@ -106,7 +104,6 @@ class _BatchNorm(Module): | |||
or ((self.running_mean is None) and (self.running_var is None)), | |||
momentum=exponential_average_factor, | |||
eps=self.eps, | |||
compute_mode=self.compute_mode, | |||
param_dim=self.param_dim, | |||
) | |||
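With compute_mode removed from the batchnorm path, parameter layout is the remaining knob (param_dim on the module/functional, or _bn_format globally). A hypothetical sketch of the channel-last case; the (1, 1, 1, C) parameter shape for "dim_111c" is an assumption based on the _bn_format docstring above:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    # channel-last activations and "dim_111c" (NHWC-style) batchnorm parameters
    x = mge.tensor(np.random.randn(2, 8, 8, 4).astype("float32"), format="nhwc")
    weight = mge.tensor(np.ones((1, 1, 1, 4), dtype=np.float32))
    bias = mge.tensor(np.zeros((1, 1, 1, 4), dtype=np.float32))

    out = F.batch_norm(
        x, weight=weight, bias=bias, training=True, inplace=False, param_dim="dim_111c"
    )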
@@ -8,7 +8,13 @@ from typing import Union | |||
import numpy as np | |||
from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option | |||
from ..core._imperative_rt.core2 import ( | |||
get_auto_format_convert, | |||
pop_scope, | |||
push_scope, | |||
set_auto_format_convert, | |||
set_option, | |||
) | |||
from ..core.tensor.utils import set_convert_inputs | |||
from ..tensor import Parameter, Tensor | |||
from ..utils.deprecation import deprecated | |||
@@ -90,7 +96,7 @@ class Optimizer(metaclass=ABCMeta): | |||
"optimizer can only optimize Parameters, but one of the params is " | |||
+ str(type(param)) | |||
) | |||
param._reset(Tensor(param.numpy(), no_cache=True)) | |||
param._reset(Tensor(param.numpy(), no_cache=True, format=param.format)) | |||
for name, default in self._defaults.items(): | |||
if default is required and name not in param_group: | |||
@@ -139,6 +145,8 @@ class Optimizer(metaclass=ABCMeta): | |||
# set the global state `_enable_convert_inputs` to `False` to disable | |||
# the `convert_inputs` for param updates | |||
set_option("record_computing_path", 0) | |||
_origin_auto_format = get_auto_format_convert() | |||
set_auto_format_convert(False) | |||
if self._disable_type_convert: | |||
backup = set_convert_inputs(False) | |||
for group in self.param_groups: | |||
@@ -155,6 +163,7 @@ class Optimizer(metaclass=ABCMeta): | |||
# restore the global state `_enable_convert_inputs` | |||
set_convert_inputs(backup) | |||
set_option("record_computing_path", 1) | |||
set_auto_format_convert(_origin_auto_format) | |||
return self | |||
@deprecated(version="1.0", reason="use clear_grad instead") | |||
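Parameters added to an optimizer are reset with no_cache=True; passing format=param.format keeps the NHWC tag alive through that reset. A quick check, assuming SGD:

    import megengine.amp as amp
    import megengine.module as M
    from megengine.optimizer import SGD

    net = amp.convert_module_format(M.Conv2d(4, 4, 3), inplace=False)
    opt = SGD(net.parameters(), lr=0.1)    # parameters get re-created here

    # the format tag survives the reset performed in add_param_group
    for param in net.parameters():
        if param.ndim in (4, 5):
            assert param.format == "nhwc"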
@@ -0,0 +1,44 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
import pytest | |||
import megengine.functional as F | |||
import megengine.module as M | |||
from megengine import Parameter, Tensor, amp, tensor | |||
class MyModule(M.Module): | |||
class InnerModule(M.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.bn = M.BatchNorm2d(4) | |||
def forward(self, x): | |||
return self.bn(x) | |||
def __init__(self): | |||
super().__init__() | |||
self.i = self.InnerModule() | |||
self.conv = M.Conv2d(4, 4, 4, groups=2) | |||
self.bn = M.BatchNorm2d(4) | |||
self.param = Parameter(np.ones((1, 3, 1, 1), dtype=np.float32)) | |||
self.buff = Tensor(np.ones((1, 3, 1, 1), dtype=np.float32)) | |||
def forward(self, x): | |||
x = self.i(x) | |||
x = self.bn(x) | |||
return x | |||
@pytest.mark.parametrize("is_inplace", [False, True]) | |||
def test_convert_module(is_inplace): | |||
m = MyModule() | |||
m = amp.convert_module_format(m, is_inplace) | |||
for name, param in m.named_tensors(): | |||
assert param.format == "nhwc" |
@@ -8,14 +8,27 @@ from megengine.autodiff import GradManager | |||
def test_basic(): | |||
a = tensor(np.arange(0, 24).reshape((1, 2, 3, 4)), dtype="float32", format="nhwc") | |||
data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||
# init from numpy | |||
a = tensor(data, format="nhwc") | |||
assert a.format == "nhwc" | |||
# init from tensor | |||
b = tensor(a) | |||
assert b.format == "nhwc" | |||
# TODO: fix Tensor init bug when initializing from another Tensor | |||
# TODO: init from tensor with new format | |||
# c = tensor(a, format="nchw") | |||
# assert c.format == "nchw" | |||
# TODO: reset from numpy | |||
# b[...] = data | |||
# assert b.format == "nhwc" | |||
# reset from tensor | |||
b[...] = tensor(data, format="nchw") | |||
assert b.format == "nchw" | |||
def _compare_nchw_nhwc(data, func): | |||
x1 = tensor(data, format="nchw") | |||
@@ -23,7 +36,7 @@ def _compare_nchw_nhwc(data, func): | |||
out1 = func(x1) | |||
with mge.config._override(auto_format_convert=True): | |||
out2 = func(x2) | |||
np.testing.assert_equal(out1, out2) | |||
np.testing.assert_almost_equal(out1, out2, decimal=5) | |||
def test_dimshuffle(): | |||
@@ -296,8 +309,10 @@ def test_backward(): | |||
with gm: | |||
with mge.config._override(auto_format_convert=True, conv_format="NHWC"): | |||
x = F.conv2d(x, w, b) | |||
# TODO: fix manual conversion to NHWC, commonly used in detection heads | |||
# x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2) | |||
gm.backward(x) | |||
# TODO: backward grad has no format yet | |||
# backward grad has no format | |||
np.testing.assert_equal( | |||
w.grad.numpy(), | |||
np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)), | |||
@@ -921,12 +921,7 @@ def test_batchnorm2d_autocast(): | |||
amp.enabled = False | |||
expected = F.batch_norm( | |||
inp.astype("float16"), | |||
weight=weight, | |||
bias=bias, | |||
training=True, | |||
inplace=False, | |||
compute_mode="float32", | |||
inp.astype("float16"), weight=weight, bias=bias, training=True, inplace=False, | |||
) | |||
assert out.dtype == np.float16 | |||
assert expected.dtype == np.float16 | |||