
feat(mge/amp): add convert_format module for NHWC training

GitOrigin-RevId: 1b41e1042c
release-1.10
Megvii Engine Team, 3 years ago
parent commit e393d1cf65
13 changed files with 207 additions and 33 deletions
  1. imperative/python/megengine/amp/__init__.py (+1, -0)
  2. imperative/python/megengine/amp/autocast.py (+16, -2)
  3. imperative/python/megengine/amp/convert_format.py (+45, -0)
  4. imperative/python/megengine/autodiff/grad_manager.py (+10, -1)
  5. imperative/python/megengine/core/_config.py (+47, -7)
  6. imperative/python/megengine/core/tensor/amp.py (+12, -6)
  7. imperative/python/megengine/core/tensor/array_method.py (+1, -0)
  8. imperative/python/megengine/functional/nn.py (+0, -2)
  9. imperative/python/megengine/module/batchnorm.py (+0, -3)
  10. imperative/python/megengine/optimizer/optimizer.py (+11, -2)
  11. imperative/python/test/unit/amp/test_convert_format.py (+44, -0)
  12. imperative/python/test/unit/core/test_formatted_tensor.py (+19, -4)
  13. imperative/python/test/unit/functional/test_functional.py (+1, -6)

imperative/python/megengine/amp/__init__.py (+1, -0)

@@ -2,6 +2,7 @@ import mprop

from ..core.tensor.amp import *
from .autocast import autocast
from .convert_format import convert_module_format, convert_tensor_format
from .grad_scaler import GradScaler

mprop.init()

imperative/python/megengine/amp/autocast.py (+16, -2)

@@ -1,5 +1,6 @@
import functools

from ..core import _config
from ..core.tensor import amp


@@ -50,24 +51,37 @@ class autocast:
self._origin_high = None
self._origin_low = None

self._origin_configs = None

def __enter__(self):
self._origin_enabled = amp._enabled
self._origin_high = amp._get_amp_high_prec_dtype()
self._origin_low = amp._get_amp_low_prec_dtype()
amp._enabled = self.enabled
amp._set_amp_dtype_autocast(self.enabled)
if not self.enabled:
return

self._origin_high = amp._get_amp_high_prec_dtype()
self._origin_low = amp._get_amp_low_prec_dtype()
amp._set_amp_high_prec_dtype(self.high_prec_dtype)
amp._set_amp_low_prec_dtype(self.low_prec_dtype)

self._origin_configs = _config._reset_execution_config(compute_mode="float32")

def __exit__(self, *args):
amp._enabled = self._origin_enabled
amp._set_amp_dtype_autocast(self._origin_enabled)
if not self.enabled:
return
amp._set_amp_high_prec_dtype(self._origin_high)
amp._set_amp_low_prec_dtype(self._origin_low)

_config._reset_execution_config(*self._origin_configs)

def __call__(self, func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not self.enabled:
return func(*args, **kwargs)
with self:
return func(*args, **kwargs)
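For reference, the updated autocast still works both as a context manager and as a decorator. A minimal sketch of the intended usage (assuming the constructor keeps its enabled / high_prec_dtype / low_prec_dtype keyword arguments):

    import numpy as np

    import megengine.functional as F
    from megengine import amp, tensor

    inp = tensor(np.random.randn(1, 3, 8, 8), dtype="float32")
    w = tensor(np.random.randn(4, 3, 1, 1), dtype="float32")

    # Context-manager form: on enter, amp is enabled and compute_mode is forced
    # to "float32"; __exit__ restores the previous settings.
    with amp.autocast(enabled=True, low_prec_dtype="float16"):
        out = F.conv2d(inp, w)

    # Decorator form: the wrapped function runs under the same autocast state.
    @amp.autocast(enabled=True)
    def forward(x, w):
        return F.conv2d(x, w)

    out2 = forward(inp, w)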



imperative/python/megengine/amp/convert_format.py (+45, -0)

@@ -0,0 +1,45 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy

from .. import functional as F
from ..module import Module
from ..tensor import Tensor


def _is_nchw_format(param: Tensor):
    # TODO: use better condition
    return (len(param.shape) == 4 or len(param.shape) == 5) and param.format != "nhwc"


def convert_tensor_format(x: Tensor, inplace: bool = True):
    """Convert NCHW Tensor to NHWC Tensor."""
    if x.ndim == 4:
        pattern = (0, 2, 3, 1)
    elif x.ndim == 5:
        pattern = (0, 1, 3, 4, 2)
    else:
        raise ValueError("Unsupported tensor ndim {}".format(x.ndim))
    # TODO: use initialization from tensor after fixing format setting
    if inplace:
        x[...] = Tensor(x.numpy().transpose(*pattern), format="nhwc")
    else:
        x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
    return x


def convert_module_format(module: Module, inplace: bool = True):
    """Convert NCHW Module to NHWC Module."""
    if not inplace:
        module = deepcopy(module)

    for name, param in module.named_tensors():
        if _is_nchw_format(param):
            # hostvalue should still be valid, so no d2h cost.
            convert_tensor_format(param, inplace=True)
    return module
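A short usage sketch for the two new helpers (the module and shapes below are illustrative; the API names are the ones defined above):

    import numpy as np

    import megengine.module as M
    from megengine import amp, tensor

    net = M.Conv2d(4, 4, 3)

    # Convert every 4D/5D parameter and buffer of the module to NHWC in place.
    net = amp.convert_module_format(net, inplace=True)
    assert all(
        t.format == "nhwc" for _, t in net.named_tensors() if t.ndim in (4, 5)
    )

    # A single tensor can be converted as well; inplace=False leaves the original
    # NCHW tensor untouched and returns a new NHWC tensor.
    x = tensor(np.random.randn(2, 4, 8, 8), dtype="float32")
    x_nhwc = amp.convert_tensor_format(x, inplace=False)
    assert x_nhwc.format == "nhwc"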

imperative/python/megengine/autodiff/grad_manager.py (+10, -1)

@@ -1,7 +1,13 @@
import weakref
from typing import Callable, Iterable, List, Union

from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
from ..core._imperative_rt.core2 import (
get_auto_format_convert,
pop_scope,
push_scope,
set_auto_format_convert,
set_option,
)
from ..core.autodiff.grad import Grad
from ..core.tensor.dtype import is_differentible_dtype
from ..logger import get_logger
@@ -253,6 +259,8 @@ class GradManager:
"""
push_scope("backward")
set_option("record_computing_path", 0)
_origin_auto_format = get_auto_format_convert()
set_auto_format_convert(False)
from ..functional import ones_like

global backwarding_grad_manager
@@ -296,6 +304,7 @@ class GradManager:
self.release()
backwarding_grad_manager = cache
set_option("record_computing_path", 1)
set_auto_format_convert(_origin_auto_format)
pop_scope("backward")

def record(self):
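backward() now switches auto format conversion off for the duration of gradient computation and restores the previous value afterwards. A minimal sketch modeled on the new test_backward case (shapes are illustrative and a backend that supports NHWC convolution is assumed):

    import numpy as np

    import megengine as mge
    import megengine.functional as F
    from megengine import tensor
    from megengine.autodiff import GradManager

    x = tensor(np.random.randn(1, 2, 4, 4), dtype="float32", format="nhwc")
    w = tensor(np.random.randn(3, 2, 1, 1), dtype="float32", format="nhwc")

    gm = GradManager().attach([w])
    with gm:
        with mge.config._override(auto_format_convert=True, conv_format="NHWC"):
            y = F.conv2d(x, w)
        # Inside backward(), auto_format_convert is temporarily disabled so that
        # implicit layout conversion does not interfere with grad computation.
        gm.backward(y)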


imperative/python/megengine/core/_config.py (+47, -7)

@@ -10,8 +10,10 @@ from ._imperative_rt.core2 import (
set_option,
)

# use "default" to distinguish it from None in _reset_execution_config
__compute_mode = "default"
__conv_format = "default"
__bn_format = "default"
_benchmark_kernel = False
_deterministic_kernel = False

@@ -22,6 +24,8 @@ __all__ = [
"disable_memory_forwarding",
"_compute_mode",
"_conv_format",
"_bn_format",
"_auto_format_convert",
"_override",
]

@@ -32,6 +36,7 @@ def benchmark_kernel(mod):
which means use heuristic to choose the fastest algorithm.

Examples:

.. code-block::

import megengine as mge
@@ -55,6 +60,7 @@ def deterministic_kernel(mod):
which means the algorithm is not reproducible.

Examples:

.. code-block::

import megengine as mge
@@ -75,6 +81,7 @@ def async_level(mod) -> int:
which means both device and user side errors are async.

Examples:

.. code-block::

import megengine as mge
@@ -110,16 +117,17 @@ def disable_memory_forwarding(mod, disable: bool):

@property
def _compute_mode(mod):
r"""Get or set the precision of intermediate results. The default option is "default",
which means that no special requirements will be placed on. When set to 'float32', it
would be used for accumulator and intermediate result, but only effective when input and
r"""Get or set the precision of intermediate results for conv, matmul. The default
option is None and will fall back to "default". When set to "float32", it will
trigger mixed precision computation on TensorCore, but only effective when input and
output are of float16 dtype.

Examples:

.. code-block::

import megengine as mge
mge.config._compute_mode = "default"
mge.config._compute_mode = "float32"
"""
return __compute_mode

@@ -132,7 +140,7 @@ def _compute_mode(mod, _compute_mode: str):

@property
def _conv_format(mod):
r"""Get or set convolution data/filter/output layout format. The default option is "default",
r"""Get or set convolution data/filter/output layout format. The default option is None,
which means that no special format will be placed on. There are all layout definitions

``NCHW`` layout: ``{N, C, H, W}``
@@ -145,6 +153,7 @@ def _conv_format(mod):
``NCHW64`` layout: ``{N, C/64, H, W, 64}``

Examples:

.. code-block::

import megengine as mge
@@ -160,11 +169,34 @@ def _conv_format(mod, format: str):


@property
def _bn_format(mod):
r"""Get or set batchnorm param layout format. The default option is None and will
fallback to "dim_1c11" which corresponds to NCHW format. When set to "dim_111c",
param format of batchnorm will be changed to NHWC.

Examples:

.. code-block::

import megengine as mge
mge.config._bn_format = "dim_111c"
"""
return __bn_format


@_bn_format.setter
def _bn_format(mod, format: str):
global __bn_format
__bn_format = format


@property
def _auto_format_convert(mod):
r"""Automatically convert indexing params' order for NCHW Tensor to NHWC order.
The default value is False, which means no convert.

Examples:

.. code-block::

import megengine as mge
@@ -184,15 +216,17 @@ def _reset_execution_config(
async_level=None,
compute_mode=None,
conv_format=None,
bn_format=None,
auto_format_convert=None,
):
global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format
global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format, __bn_format
orig_flags = (
_benchmark_kernel,
_deterministic_kernel,
get_option("async_level"),
__compute_mode,
__conv_format,
__bn_format,
get_auto_format_convert(),
)
if benchmark_kernel is not None:
@@ -205,6 +239,8 @@ def _reset_execution_config(
__compute_mode = compute_mode
if conv_format is not None:
__conv_format = conv_format
if bn_format is not None:
__bn_format = bn_format
if auto_format_convert is not None:
set_auto_format_convert(auto_format_convert)

@@ -218,12 +254,14 @@ def _override(
async_level=None,
compute_mode=None,
conv_format=None,
bn_format=None,
auto_format_convert=None,
):
r"""A context manager that users can opt in by attaching the decorator to set
the config of the global variable.

Examples:

.. code-block::

import megengine as mge
@@ -234,6 +272,7 @@ def _override(
async_level=2,
compute_mode="float32",
conv_format="NHWC",
bn_format="dim_111c",
auto_format_convert=True,
)
def train():
@@ -244,6 +283,7 @@ def _override(
async_level,
compute_mode,
conv_format,
bn_format,
auto_format_convert,
)
try:
@@ -254,4 +294,4 @@ def _override(


def _get_actual_op_param(function_param, config_param):
return function_param if config_param == "default" else config_param
return function_param if config_param is "default" else config_param
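Taken together, the new knobs look like this in use; a condensed sketch that mirrors the docstring examples above:

    import megengine as mge

    # Per-function override: the settings are applied on entry and restored on exit.
    @mge.config._override(
        compute_mode="float32",
        conv_format="NHWC",
        bn_format="dim_111c",
        auto_format_convert=True,
    )
    def train_step():
        ...

    # The same options are also exposed as module-level properties.
    mge.config._bn_format = "dim_111c"
    mge.config._auto_format_convert = True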

imperative/python/megengine/core/tensor/amp.py (+12, -6)

@@ -10,13 +10,19 @@ from .._imperative_rt.core2 import (
_enabled = False
_set_amp_dtype_autocast(_enabled)

__all__ = [
"enabled",
"high_prec_dtype",
"low_prec_dtype",
]


@property
def enabled(mod):
r"""Get or set amp autocast mode enabled or not.
Examples:
.. code-block::

import megengine as mge
@@ -36,9 +42,9 @@ def enabled(mod, enabled: bool):
def high_prec_dtype(mod):
r"""Get or set amp autocast mode's higher precision dtype. It will change the
target dtype in tensor casting for better precision. Default: float32.
Examples:
.. code-block::

import megengine as mge
@@ -56,9 +62,9 @@ def high_prec_dtype(mod, dtype: str):
def low_prec_dtype(mod):
r"""Get or set amp autocast mode's lower precision dtype. It will change the
target dtype in tensor casting for better speed and memory. Default: float16.
Examples:
.. code-block::

import megengine as mge
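The three entries in __all__ are exposed as module-level properties (via mprop), so they can be read and assigned directly; a minimal sketch:

    import megengine as mge

    mge.amp.enabled = True               # globally enable autocast
    mge.amp.high_prec_dtype = "float32"  # dtype used where precision matters
    mge.amp.low_prec_dtype = "float16"   # dtype used for speed and memory

    print(mge.amp.enabled, mge.amp.high_prec_dtype, mge.amp.low_prec_dtype)

    mge.amp.enabled = False              # restore the default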


imperative/python/megengine/core/tensor/array_method.py (+1, -0)

@@ -63,6 +63,7 @@ def _matmul(
assert dim1 > 0 and dim2 > 0
maxdim = dim1 if dim1 > dim2 else dim2
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)

if dim1 == 1 and dim2 == 1: # dispatch to Dot
(result,) = apply(builtin.Dot(), inp1, inp2)
return result


imperative/python/megengine/functional/nn.py (+0, -2)

@@ -441,7 +441,6 @@ def deformable_conv2d(
or conv_mode.name == "CROSS_CORRELATION"
)
if amp._enabled:
compute_mode = "float32"
inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
else:
offset = offset.astype("float32")
@@ -1182,7 +1181,6 @@ def batch_norm(
momentum: float = 0.9,
eps: float = 1e-5,
inplace: bool = True,
compute_mode="default",
param_dim="dim_1c11"
):
r"""Applies batch normalization to the input.


imperative/python/megengine/module/batchnorm.py (+0, -3)

@@ -19,7 +19,6 @@ class _BatchNorm(Module):
affine=True,
track_running_stats=True,
freeze=False,
compute_mode="default",
param_dim="dim_1c11",
**kwargs
):
@@ -31,7 +30,6 @@ class _BatchNorm(Module):
self.track_running_stats = track_running_stats
self._track_running_stats_saved = track_running_stats
self.freeze = freeze
self.compute_mode = compute_mode
self.param_dim = param_dim
if self.freeze:
assert (
@@ -106,7 +104,6 @@ class _BatchNorm(Module):
or ((self.running_mean is None) and (self.running_var is None)),
momentum=exponential_average_factor,
eps=self.eps,
compute_mode=self.compute_mode,
param_dim=self.param_dim,
)



imperative/python/megengine/optimizer/optimizer.py (+11, -2)

@@ -8,7 +8,13 @@ from typing import Union

import numpy as np

from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
from ..core._imperative_rt.core2 import (
get_auto_format_convert,
pop_scope,
push_scope,
set_auto_format_convert,
set_option,
)
from ..core.tensor.utils import set_convert_inputs
from ..tensor import Parameter, Tensor
from ..utils.deprecation import deprecated
@@ -90,7 +96,7 @@ class Optimizer(metaclass=ABCMeta):
"optimizer can only optimize Parameters, but one of the params is "
+ str(type(param))
)
param._reset(Tensor(param.numpy(), no_cache=True))
param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))

for name, default in self._defaults.items():
if default is required and name not in param_group:
@@ -139,6 +145,8 @@ class Optimizer(metaclass=ABCMeta):
# set the globle state `_enable_convert_inputs` to `False` to disable
# the `convert_inputs` for param updates
set_option("record_computing_path", 0)
_origin_auto_format = get_auto_format_convert()
set_auto_format_convert(False)
if self._disable_type_convert:
backup = set_convert_inputs(False)
for group in self.param_groups:
@@ -155,6 +163,7 @@ class Optimizer(metaclass=ABCMeta):
# restore the globle state `_enable_convert_inputs`
set_convert_inputs(backup)
set_option("record_computing_path", 1)
set_auto_format_convert(_origin_auto_format)
return self

@deprecated(version="1.0", reason="use clear_grad instead")
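With the optimizer now preserving param.format when parameters are registered and disabling auto format conversion during step(), an NHWC-converted module keeps its parameter layout across updates. A rough sketch (a backend that supports NHWC convolution is assumed):

    import numpy as np

    import megengine as mge
    import megengine.module as M
    from megengine import amp, tensor
    from megengine.autodiff import GradManager
    from megengine.optimizer import SGD

    net = amp.convert_module_format(M.Conv2d(4, 4, 3))
    gm = GradManager().attach(net.parameters())
    opt = SGD(net.parameters(), lr=0.1)

    x = tensor(np.random.randn(2, 4, 8, 8), dtype="float32")
    with gm:
        with mge.config._override(auto_format_convert=True, conv_format="NHWC"):
            loss = net(x).mean()
        gm.backward(loss)

    # Parameters were re-created with format=param.format when registered, and
    # step() runs with auto format conversion temporarily disabled, so NHWC
    # parameters stay NHWC after the update.
    opt.step().clear_grad()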


imperative/python/test/unit/amp/test_convert_format.py (+44, -0)

@@ -0,0 +1,44 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest

import megengine.functional as F
import megengine.module as M
from megengine import Parameter, Tensor, amp, tensor


class MyModule(M.Module):
    class InnerModule(M.Module):
        def __init__(self):
            super().__init__()
            self.bn = M.BatchNorm2d(4)

        def forward(self, x):
            return self.bn(x)

    def __init__(self):
        super().__init__()
        self.i = self.InnerModule()
        self.conv = M.Conv2d(4, 4, 4, groups=2)
        self.bn = M.BatchNorm2d(4)
        self.param = Parameter(np.ones((1, 3, 1, 1), dtype=np.float32))
        self.buff = Tensor(np.ones((1, 3, 1, 1), dtype=np.float32))

    def forward(self, x):
        x = self.i(x)
        x = self.bn(x)
        return x


@pytest.mark.parametrize("is_inplace", [False, True])
def test_convert_module(is_inplace):
    m = MyModule()
    m = amp.convert_module_format(m, is_inplace)
    for name, param in m.named_tensors():
        assert param.format == "nhwc"

imperative/python/test/unit/core/test_formatted_tensor.py (+19, -4)

@@ -8,14 +8,27 @@ from megengine.autodiff import GradManager


def test_basic():
a = tensor(np.arange(0, 24).reshape((1, 2, 3, 4)), dtype="float32", format="nhwc")
data = np.arange(0, 24).reshape((1, 2, 3, 4))
# init from numpy
a = tensor(data, format="nhwc")
assert a.format == "nhwc"

# init from tensor
b = tensor(a)
assert b.format == "nhwc"
# TODO: fix Tensor init bug for another Tensor

# TODO: init from tensor with new format
# c = tensor(a, format="nchw")
# assert c.format == "nchw"

# TODO: reset from numpy
# b[...] = data
# assert b.format == "nhwc"

# reset from tensor
b[...] = tensor(data, format="nchw")
assert b.format == "nchw"


def _compare_nchw_nhwc(data, func):
x1 = tensor(data, format="nchw")
@@ -23,7 +36,7 @@ def _compare_nchw_nhwc(data, func):
out1 = func(x1)
with mge.config._override(auto_format_convert=True):
out2 = func(x2)
np.testing.assert_equal(out1, out2)
np.testing.assert_almost_equal(out1, out2, decimal=5)


def test_dimshuffle():
@@ -296,8 +309,10 @@ def test_backward():
with gm:
with mge.config._override(auto_format_convert=True, conv_format="NHWC"):
x = F.conv2d(x, w, b)
# TODO: fix manually convert to NHWC, usually used in detection head
# x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2)
gm.backward(x)
# TODO: backward grad has no format yet
# backward grad has no format
np.testing.assert_equal(
w.grad.numpy(),
np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),


imperative/python/test/unit/functional/test_functional.py (+1, -6)

@@ -921,12 +921,7 @@ def test_batchnorm2d_autocast():

amp.enabled = False
expected = F.batch_norm(
inp.astype("float16"),
weight=weight,
bias=bias,
training=True,
inplace=False,
compute_mode="float32",
inp.astype("float16"), weight=weight, bias=bias, training=True, inplace=False,
)
assert out.dtype == np.float16
assert expected.dtype == np.float16

