GitOrigin-RevId: 1b41e1042c
release-1.10

@@ -2,6 +2,7 @@ import mprop
 from ..core.tensor.amp import *
 from .autocast import autocast
+from .convert_format import convert_module_format, convert_tensor_format
 from .grad_scaler import GradScaler

 mprop.init()

@@ -1,5 +1,6 @@
 import functools

+from ..core import _config
 from ..core.tensor import amp
@@ -50,24 +51,37 @@ class autocast:
         self._origin_high = None
         self._origin_low = None
+        self._origin_configs = None

     def __enter__(self):
         self._origin_enabled = amp._enabled
-        self._origin_high = amp._get_amp_high_prec_dtype()
-        self._origin_low = amp._get_amp_low_prec_dtype()
         amp._enabled = self.enabled
         amp._set_amp_dtype_autocast(self.enabled)
+        if not self.enabled:
+            return
+        self._origin_high = amp._get_amp_high_prec_dtype()
+        self._origin_low = amp._get_amp_low_prec_dtype()
         amp._set_amp_high_prec_dtype(self.high_prec_dtype)
         amp._set_amp_low_prec_dtype(self.low_prec_dtype)
+        self._origin_configs = _config._reset_execution_config(compute_mode="float32")

     def __exit__(self, *args):
         amp._enabled = self._origin_enabled
         amp._set_amp_dtype_autocast(self._origin_enabled)
+        if not self.enabled:
+            return
         amp._set_amp_high_prec_dtype(self._origin_high)
         amp._set_amp_low_prec_dtype(self._origin_low)
+        _config._reset_execution_config(*self._origin_configs)

     def __call__(self, func):
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
+            if not self.enabled:
+                return func(*args, **kwargs)
             with self:
                 return func(*args, **kwargs)
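
A quick sketch of the two ways `autocast` is used, both of which now short-circuit when `enabled=False`; the tensors and the matmul are placeholders of my own, not from the patch:

    import numpy as np
    import megengine.functional as F
    from megengine import amp, tensor

    x = tensor(np.random.rand(4, 8).astype("float32"))
    w = tensor(np.random.rand(8, 8).astype("float32"))

    # as a context manager: the matmul should be cast to the low-precision dtype
    with amp.autocast(enabled=True):
        y = F.matmul(x, w)

    # as a decorator: with enabled=False the wrapper now skips the context entirely
    @amp.autocast(enabled=False)
    def forward(a, b):
        return F.matmul(a, b)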

@@ -0,0 +1,45 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from copy import deepcopy
+
+from .. import functional as F
+from ..module import Module
+from ..tensor import Tensor
+
+
+def _is_nchw_format(param: Tensor):
+    # TODO: use better condition
+    return (len(param.shape) == 4 or len(param.shape) == 5) and param.format != "nhwc"
+
+
+def convert_tensor_format(x: Tensor, inplace: bool = True):
+    """Convert NCHW Tensor to NHWC Tensor."""
+    if x.ndim == 4:
+        pattern = (0, 2, 3, 1)
+    elif x.ndim == 5:
+        pattern = (0, 1, 3, 4, 2)
+    else:
+        raise ValueError("Unsupported tensor ndim {}".format(x.ndim))
+    # TODO: use initialization from tensor after fixing format setting
+    if inplace:
+        x[...] = Tensor(x.numpy().transpose(*pattern), format="nhwc")
+    else:
+        x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
+    return x
+
+
+def convert_module_format(module: Module, inplace: bool = True):
+    """Convert NCHW Module to NHWC Module."""
+    if not inplace:
+        module = deepcopy(module)
+    for name, param in module.named_tensors():
+        if _is_nchw_format(param):
+            # hostvalue should still be valid, so no d2h cost.
+            convert_tensor_format(param, inplace=True)
+    return module
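
A minimal usage sketch for the two helpers added above; the module and shapes are illustrative, not taken from the patch:

    import numpy as np
    import megengine.module as M
    from megengine import Tensor, amp

    # convert a single NCHW tensor to an NHWC one
    x = Tensor(np.random.rand(2, 3, 8, 8).astype("float32"))
    x_nhwc = amp.convert_tensor_format(x, inplace=False)
    print(x_nhwc.format)  # expected: "nhwc"

    # convert every 4D/5D parameter and buffer of a module in place
    net = M.Conv2d(3, 8, 3)
    net = amp.convert_module_format(net, inplace=True)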

@@ -1,7 +1,13 @@
 import weakref
 from typing import Callable, Iterable, List, Union

-from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
+from ..core._imperative_rt.core2 import (
+    get_auto_format_convert,
+    pop_scope,
+    push_scope,
+    set_auto_format_convert,
+    set_option,
+)
 from ..core.autodiff.grad import Grad
 from ..core.tensor.dtype import is_differentible_dtype
 from ..logger import get_logger
@@ -253,6 +259,8 @@ class GradManager:
         """
         push_scope("backward")
         set_option("record_computing_path", 0)
+        _origin_auto_format = get_auto_format_convert()
+        set_auto_format_convert(False)
         from ..functional import ones_like

         global backwarding_grad_manager
@@ -296,6 +304,7 @@ class GradManager:
             self.release()
         backwarding_grad_manager = cache
         set_option("record_computing_path", 1)
+        set_auto_format_convert(_origin_auto_format)
         pop_scope("backward")

     def record(self):

@@ -10,8 +10,10 @@ from ._imperative_rt.core2 import (
     set_option,
 )

+# use "default" to distinguish it from None in _reset_execution_config
 __compute_mode = "default"
 __conv_format = "default"
+__bn_format = "default"
 _benchmark_kernel = False
 _deterministic_kernel = False
@@ -22,6 +24,8 @@ __all__ = [
     "disable_memory_forwarding",
     "_compute_mode",
     "_conv_format",
+    "_bn_format",
+    "_auto_format_convert",
     "_override",
 ]
@@ -32,6 +36,7 @@ def benchmark_kernel(mod):
     which means use heuristic to choose the fastest algorithm.

     Examples:
+
         .. code-block::

            import megengine as mge
@@ -55,6 +60,7 @@ def deterministic_kernel(mod):
     which means the algorithm is not reproducible.

     Examples:
+
         .. code-block::

            import megengine as mge
@@ -75,6 +81,7 @@ def async_level(mod) -> int:
     which means both device and user side errors are async.

     Examples:
+
         .. code-block::

            import megengine as mge
@@ -110,16 +117,17 @@ def disable_memory_forwarding(mod, disable: bool):

 @property
 def _compute_mode(mod):
-    r"""Get or set the precision of intermediate results. The default option is "default",
-    which means that no special requirements will be placed on. When set to 'float32', it
-    would be used for accumulator and intermediate result, but only effective when input and
+    r"""Get or set the precision of intermediate results for conv, matmul. The default
+    option is None and will fallback to "default". When set to "float32", it will
+    trigger mixed precision computation on TensorCore, but only effective when input and
     output are of float16 dtype.

     Examples:
+
         .. code-block::

            import megengine as mge
-           mge.config._compute_mode = "default"
+           mge.config._compute_mode = "float32"
     """
     return __compute_mode
@@ -132,7 +140,7 @@ def _compute_mode(mod, _compute_mode: str):

 @property
 def _conv_format(mod):
-    r"""Get or set convolution data/filter/output layout format. The default option is "default",
+    r"""Get or set convolution data/filter/output layout format. The default option is None,
     which means that no special format will be placed on. There are all layout definitions

     ``NCHW`` layout: ``{N, C, H, W}``
@@ -145,6 +153,7 @@ def _conv_format(mod):
     ``NCHW64`` layout: ``{N, C/64, H, W, 64}``

     Examples:
+
         .. code-block::

            import megengine as mge
@@ -160,11 +169,34 @@ def _conv_format(mod, format: str):

 @property
+def _bn_format(mod):
+    r"""Get or set batchnorm param layout format. The default option is None and will
+    fallback to "dim_1c11" which corresponds to NCHW format. When set to "dim_111c",
+    param format of batchnorm will be changed to NHWC.
+
+    Examples:
+
+        .. code-block::
+
+           import megengine as mge
+           mge.config._bn_format = "dim_111c"
+    """
+    return __bn_format
+
+
+@_bn_format.setter
+def _bn_format(mod, format: str):
+    global __bn_format
+    __bn_format = format
+
+
+@property
 def _auto_format_convert(mod):
     r"""Automatically convert indexing params' order for NCHW Tensor to NHWC order.
     The default value is False, which means no convert.

     Examples:
+
         .. code-block::

            import megengine as mge
@@ -184,15 +216,17 @@ def _reset_execution_config(
     async_level=None,
     compute_mode=None,
     conv_format=None,
+    bn_format=None,
     auto_format_convert=None,
 ):
-    global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format
+    global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format, __bn_format
     orig_flags = (
         _benchmark_kernel,
         _deterministic_kernel,
         get_option("async_level"),
         __compute_mode,
         __conv_format,
+        __bn_format,
         get_auto_format_convert(),
     )
     if benchmark_kernel is not None:
@@ -205,6 +239,8 @@ def _reset_execution_config(
         __compute_mode = compute_mode
     if conv_format is not None:
         __conv_format = conv_format
+    if bn_format is not None:
+        __bn_format = bn_format
     if auto_format_convert is not None:
         set_auto_format_convert(auto_format_convert)
@@ -218,12 +254,14 @@ def _override(
     async_level=None,
     compute_mode=None,
     conv_format=None,
+    bn_format=None,
     auto_format_convert=None,
 ):
     r"""A context manager that users can opt in by attaching the decorator to set
     the config of the global variable.

     Examples:
+
         .. code-block::

            import megengine as mge
@@ -234,6 +272,7 @@ def _override(
               async_level=2,
               compute_mode="float32",
               conv_format="NHWC",
+              bn_format="dim_111c",
               auto_format_convert=True,
            )
            def train():
@@ -244,6 +283,7 @@ def _override(
         async_level,
         compute_mode,
         conv_format,
+        bn_format,
         auto_format_convert,
     )
     try:
@@ -254,4 +294,4 @@ def _override(

 def _get_actual_op_param(function_param, config_param):
-    return function_param if config_param is "default" else config_param
+    return function_param if config_param == "default" else config_param
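
A small sketch of how the resolution helper interacts with the new config fields; the printed values are what I would expect rather than captured output:

    import megengine as mge

    # while the global config is still "default", the per-op argument wins
    print(mge.config._get_actual_op_param("NCHW", mge.config._conv_format))  # "NCHW"

    # inside _override, the temporarily-set global value takes precedence
    with mge.config._override(conv_format="NHWC", bn_format="dim_111c"):
        print(mge.config._get_actual_op_param("NCHW", mge.config._conv_format))  # "NHWC"
        print(mge.config._bn_format)  # "dim_111c"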

@@ -10,13 +10,19 @@ from .._imperative_rt.core2 import (
 _enabled = False
 _set_amp_dtype_autocast(_enabled)

+__all__ = [
+    "enabled",
+    "high_prec_dtype",
+    "low_prec_dtype",
+]
+

 @property
 def enabled(mod):
     r"""Get or set amp autocast mode enabled or not.

     Examples:

        .. code-block::

           import megengine as mge
@@ -36,9 +42,9 @@ def enabled(mod, enabled: bool):
 def high_prec_dtype(mod):
     r"""Get or set amp autocast mode's higher precision dtype. It will change the
     target dtype in tensor casting for better precision. Default: float32.
-
     Examples:
+
        .. code-block::

           import megengine as mge
@@ -56,9 +62,9 @@ def high_prec_dtype(mod, dtype: str):
 def low_prec_dtype(mod):
     r"""Get or set amp autocast mode's lower precision dtype. It will change the
     target dtype in tensor casting for better speed and memory. Default: float16.
-
     Examples:
+
        .. code-block::

           import megengine as mge
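
For reference, the three module-level properties now listed in `__all__` are normally reached through `megengine.amp`; a short sketch (the dtype choices are just examples):

    from megengine import amp

    amp.enabled = True                # turn autocast on globally
    amp.high_prec_dtype = "float32"   # dtype used where precision matters
    amp.low_prec_dtype = "float16"    # dtype used for speed / memory
    print(amp.enabled, amp.high_prec_dtype, amp.low_prec_dtype)
    amp.enabled = False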

@@ -63,6 +63,7 @@ def _matmul(
     assert dim1 > 0 and dim2 > 0
     maxdim = dim1 if dim1 > dim2 else dim2
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     if dim1 == 1 and dim2 == 1:  # dispatch to Dot
         (result,) = apply(builtin.Dot(), inp1, inp2)
         return result

@@ -441,7 +441,6 @@ def deformable_conv2d(
         or conv_mode.name == "CROSS_CORRELATION"
     )
     if amp._enabled:
-        compute_mode = "float32"
         inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
     else:
         offset = offset.astype("float32")
@@ -1182,7 +1181,6 @@ def batch_norm(
     momentum: float = 0.9,
     eps: float = 1e-5,
     inplace: bool = True,
-    compute_mode="default",
     param_dim="dim_1c11"
 ):
     r"""Applies batch normalization to the input.

@@ -19,7 +19,6 @@ class _BatchNorm(Module):
         affine=True,
         track_running_stats=True,
         freeze=False,
-        compute_mode="default",
         param_dim="dim_1c11",
         **kwargs
     ):
@@ -31,7 +30,6 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
-        self.compute_mode = compute_mode
         self.param_dim = param_dim
         if self.freeze:
             assert (
@@ -106,7 +104,6 @@ class _BatchNorm(Module):
             or ((self.running_mean is None) and (self.running_var is None)),
             momentum=exponential_average_factor,
             eps=self.eps,
-            compute_mode=self.compute_mode,
             param_dim=self.param_dim,
         )
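
With `compute_mode` removed from both `F.batch_norm` and `_BatchNorm`, a call now looks roughly like this (shapes are illustrative); float32 accumulation is requested globally instead, e.g. by `autocast` through `_reset_execution_config(compute_mode="float32")`:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    inp = tensor(np.random.rand(2, 4, 8, 8).astype("float32"))
    weight = tensor(np.ones((1, 4, 1, 1), dtype="float32"))
    bias = tensor(np.zeros((1, 4, 1, 1), dtype="float32"))

    # no per-call compute_mode anymore; param_dim still defaults to "dim_1c11" (NCHW params)
    out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)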

@@ -8,7 +8,13 @@ from typing import Union

 import numpy as np

-from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
+from ..core._imperative_rt.core2 import (
+    get_auto_format_convert,
+    pop_scope,
+    push_scope,
+    set_auto_format_convert,
+    set_option,
+)
 from ..core.tensor.utils import set_convert_inputs
 from ..tensor import Parameter, Tensor
 from ..utils.deprecation import deprecated
@@ -90,7 +96,7 @@ class Optimizer(metaclass=ABCMeta):
                     "optimizer can only optimize Parameters, but one of the params is "
                     + str(type(param))
                 )
-            param._reset(Tensor(param.numpy(), no_cache=True))
+            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))

         for name, default in self._defaults.items():
             if default is required and name not in param_group:
@@ -139,6 +145,8 @@ class Optimizer(metaclass=ABCMeta):
         # set the global state `_enable_convert_inputs` to `False` to disable
         # the `convert_inputs` for param updates
         set_option("record_computing_path", 0)
+        _origin_auto_format = get_auto_format_convert()
+        set_auto_format_convert(False)
         if self._disable_type_convert:
             backup = set_convert_inputs(False)
         for group in self.param_groups:
@@ -155,6 +163,7 @@ class Optimizer(metaclass=ABCMeta):
             # restore the global state `_enable_convert_inputs`
             set_convert_inputs(backup)
         set_option("record_computing_path", 1)
+        set_auto_format_convert(_origin_auto_format)
         return self

     @deprecated(version="1.0", reason="use clear_grad instead")
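
Putting the pieces together, a hedged end-to-end sketch of the workflow these changes target: convert a module to NHWC, run it under `autocast` with the NHWC-related config overrides, and let the optimizer update parameters while preserving their format. The layer sizes, learning rate, and loss are made up:

    import numpy as np
    import megengine as mge
    import megengine.module as M
    import megengine.optimizer as optim
    from megengine import amp, tensor
    from megengine.autodiff import GradManager

    net = amp.convert_module_format(M.Conv2d(3, 8, 3), inplace=True)
    opt = optim.SGD(net.parameters(), lr=0.01)
    gm = GradManager()
    gm.attach(net.parameters())

    x = tensor(np.random.rand(2, 3, 16, 16).astype("float32"), format="nhwc")
    with gm:
        with mge.config._override(conv_format="NHWC", bn_format="dim_111c", auto_format_convert=True):
            with amp.autocast(enabled=True):
                loss = net(x).mean()
        gm.backward(loss)
    opt.step().clear_grad()

    # the optimizer now rebuilds parameters with format=param.format, so NHWC survives the update
    for _, p in net.named_tensors():
        assert p.format == "nhwc"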

@@ -0,0 +1,44 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+import pytest
+
+import megengine.functional as F
+import megengine.module as M
+from megengine import Parameter, Tensor, amp, tensor
+
+
+class MyModule(M.Module):
+    class InnerModule(M.Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = M.BatchNorm2d(4)
+
+        def forward(self, x):
+            return self.bn(x)
+
+    def __init__(self):
+        super().__init__()
+        self.i = self.InnerModule()
+        self.conv = M.Conv2d(4, 4, 4, groups=2)
+        self.bn = M.BatchNorm2d(4)
+        self.param = Parameter(np.ones((1, 3, 1, 1), dtype=np.float32))
+        self.buff = Tensor(np.ones((1, 3, 1, 1), dtype=np.float32))
+
+    def forward(self, x):
+        x = self.i(x)
+        x = self.bn(x)
+        return x
+
+
+@pytest.mark.parametrize("is_inplace", [False, True])
+def test_convert_module(is_inplace):
+    m = MyModule()
+    m = amp.convert_module_format(m, is_inplace)
+    for name, param in m.named_tensors():
+        assert param.format == "nhwc"

@@ -8,14 +8,27 @@ from megengine.autodiff import GradManager


 def test_basic():
-    a = tensor(np.arange(0, 24).reshape((1, 2, 3, 4)), dtype="float32", format="nhwc")
+    data = np.arange(0, 24).reshape((1, 2, 3, 4))
+    # init from numpy
+    a = tensor(data, format="nhwc")
     assert a.format == "nhwc"
+
+    # init from tensor
     b = tensor(a)
     assert b.format == "nhwc"
-    # TODO: fix Tensor init bug for another Tensor
+
+    # TODO: init from tensor with new format
     # c = tensor(a, format="nchw")
     # assert c.format == "nchw"
+
+    # TODO: reset from numpy
+    # b[...] = data
+    # assert b.format == "nhwc"
+
+    # reset from tensor
+    b[...] = tensor(data, format="nchw")
+    assert b.format == "nchw"


 def _compare_nchw_nhwc(data, func):
     x1 = tensor(data, format="nchw")
@@ -23,7 +36,7 @@ def _compare_nchw_nhwc(data, func):
     out1 = func(x1)
     with mge.config._override(auto_format_convert=True):
         out2 = func(x2)
-    np.testing.assert_equal(out1, out2)
+    np.testing.assert_almost_equal(out1, out2, decimal=5)


 def test_dimshuffle():
@@ -296,8 +309,10 @@ def test_backward():
     with gm:
         with mge.config._override(auto_format_convert=True, conv_format="NHWC"):
             x = F.conv2d(x, w, b)
+            # TODO: fix manually convert to NHWC, usually used in detection head
+            # x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2)
         gm.backward(x)
-        # TODO: backward grad has no format yet
+        # backward grad has no format
         np.testing.assert_equal(
             w.grad.numpy(),
             np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),

@@ -921,12 +921,7 @@ def test_batchnorm2d_autocast():
     amp.enabled = False
     expected = F.batch_norm(
-        inp.astype("float16"),
-        weight=weight,
-        bias=bias,
-        training=True,
-        inplace=False,
-        compute_mode="float32",
+        inp.astype("float16"), weight=weight, bias=bias, training=True, inplace=False,
     )
     assert out.dtype == np.float16
     assert expected.dtype == np.float16