
feat(mge/amp): add convert_format module for NHWC training

GitOrigin-RevId: 1b41e1042c
release-1.10
Megvii Engine Team, 3 years ago
commit e393d1cf65
13 changed files with 207 additions and 33 deletions
  1. imperative/python/megengine/amp/__init__.py (+1, -0)
  2. imperative/python/megengine/amp/autocast.py (+16, -2)
  3. imperative/python/megengine/amp/convert_format.py (+45, -0)
  4. imperative/python/megengine/autodiff/grad_manager.py (+10, -1)
  5. imperative/python/megengine/core/_config.py (+47, -7)
  6. imperative/python/megengine/core/tensor/amp.py (+12, -6)
  7. imperative/python/megengine/core/tensor/array_method.py (+1, -0)
  8. imperative/python/megengine/functional/nn.py (+0, -2)
  9. imperative/python/megengine/module/batchnorm.py (+0, -3)
  10. imperative/python/megengine/optimizer/optimizer.py (+11, -2)
  11. imperative/python/test/unit/amp/test_convert_format.py (+44, -0)
  12. imperative/python/test/unit/core/test_formatted_tensor.py (+19, -4)
  13. imperative/python/test/unit/functional/test_functional.py (+1, -6)

imperative/python/megengine/amp/__init__.py (+1, -0)

@@ -2,6 +2,7 @@ import mprop

 from ..core.tensor.amp import *
 from .autocast import autocast
+from .convert_format import convert_module_format, convert_tensor_format
 from .grad_scaler import GradScaler

 mprop.init()

imperative/python/megengine/amp/autocast.py (+16, -2)

@@ -1,5 +1,6 @@
 import functools

+from ..core import _config
 from ..core.tensor import amp

@@ -50,24 +51,37 @@ class autocast:
         self._origin_high = None
         self._origin_low = None

+        self._origin_configs = None
+
     def __enter__(self):
         self._origin_enabled = amp._enabled
-        self._origin_high = amp._get_amp_high_prec_dtype()
-        self._origin_low = amp._get_amp_low_prec_dtype()
         amp._enabled = self.enabled
         amp._set_amp_dtype_autocast(self.enabled)
+        if not self.enabled:
+            return
+
+        self._origin_high = amp._get_amp_high_prec_dtype()
+        self._origin_low = amp._get_amp_low_prec_dtype()
         amp._set_amp_high_prec_dtype(self.high_prec_dtype)
         amp._set_amp_low_prec_dtype(self.low_prec_dtype)

+        self._origin_configs = _config._reset_execution_config(compute_mode="float32")
+
     def __exit__(self, *args):
         amp._enabled = self._origin_enabled
         amp._set_amp_dtype_autocast(self._origin_enabled)
+        if not self.enabled:
+            return
         amp._set_amp_high_prec_dtype(self._origin_high)
         amp._set_amp_low_prec_dtype(self._origin_low)

+        _config._reset_execution_config(*self._origin_configs)
+
     def __call__(self, func):
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
+            if not self.enabled:
+                return func(*args, **kwargs)
             with self:
                 return func(*args, **kwargs)


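For orientation, a minimal usage sketch of the updated autocast (illustrative, not part of this diff); it assumes the public megengine.amp.autocast API and reflects the change above: while enabled, autocast now also forces compute_mode="float32" via _config._reset_execution_config, and the decorator form skips the context entirely when disabled.

    import numpy as np
    import megengine.functional as F
    from megengine import amp, tensor

    inp = tensor(np.random.randn(2, 16, 8, 8).astype("float32"))
    w = tensor(np.random.randn(32, 16, 3, 3).astype("float32"))

    # context-manager form: conv inputs are cast to float16,
    # while compute_mode="float32" keeps accumulation in float32
    with amp.autocast(enabled=True):
        out = F.conv2d(inp, w)

    # decorator form: with enabled=False the wrapper now returns
    # func(*args, **kwargs) directly, without entering the context
    @amp.autocast(enabled=False)
    def infer(x, weight):
        return F.conv2d(x, weight)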


imperative/python/megengine/amp/convert_format.py (+45, -0)

@@ -0,0 +1,45 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from copy import deepcopy
+
+from .. import functional as F
+from ..module import Module
+from ..tensor import Tensor
+
+
+def _is_nchw_format(param: Tensor):
+    # TODO: use better condition
+    return (len(param.shape) == 4 or len(param.shape) == 5) and param.format != "nhwc"
+
+
+def convert_tensor_format(x: Tensor, inplace: bool = True):
+    """Convert NCHW Tensor to NHWC Tensor."""
+    if x.ndim == 4:
+        pattern = (0, 2, 3, 1)
+    elif x.ndim == 5:
+        pattern = (0, 1, 3, 4, 2)
+    else:
+        raise ValueError("Unsupport tensor ndim {}".format(x.ndim))
+    # TODO: use initialization from tensor after fixing format setting
+    if inplace:
+        x[...] = Tensor(x.numpy().transpose(*pattern), format="nhwc")
+    else:
+        x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
+    return x
+
+
+def convert_module_format(module: Module, inplace: bool = True):
+    """Convert NCHW Module to NHWC Module."""
+    if not inplace:
+        module = deepcopy(module)
+
+    for name, param in module.named_tensors():
+        if _is_nchw_format(param):
+            # hostvalue should still be valid, so no d2h cost.
+            convert_tensor_format(param, inplace=True)
+    return module

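A small usage sketch for the two helpers added above (illustrative; shapes are arbitrary). convert_tensor_format transposes the host data and tags the result as "nhwc"; convert_module_format walks named_tensors() and converts every 4D/5D parameter and buffer.

    import numpy as np
    import megengine.module as M
    from megengine import Tensor, amp

    # convert a single NCHW tensor to NHWC
    x = Tensor(np.random.rand(2, 3, 8, 8).astype("float32"))
    x = amp.convert_tensor_format(x, inplace=False)
    assert x.format == "nhwc"

    # convert a module's parameters and buffers in place
    net = M.Conv2d(3, 8, 3)
    net = amp.convert_module_format(net)
    assert all(t.format == "nhwc" for _, t in net.named_tensors() if t.ndim == 4)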
imperative/python/megengine/autodiff/grad_manager.py (+10, -1)

@@ -1,7 +1,13 @@
 import weakref
 from typing import Callable, Iterable, List, Union

-from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
+from ..core._imperative_rt.core2 import (
+    get_auto_format_convert,
+    pop_scope,
+    push_scope,
+    set_auto_format_convert,
+    set_option,
+)
 from ..core.autodiff.grad import Grad
 from ..core.tensor.dtype import is_differentible_dtype
 from ..logger import get_logger
@@ -253,6 +259,8 @@ class GradManager:
         """
         push_scope("backward")
         set_option("record_computing_path", 0)
+        _origin_auto_format = get_auto_format_convert()
+        set_auto_format_convert(False)
         from ..functional import ones_like

         global backwarding_grad_manager
@@ -296,6 +304,7 @@ class GradManager:
             self.release()
             backwarding_grad_manager = cache
             set_option("record_computing_path", 1)
+            set_auto_format_convert(_origin_auto_format)
             pop_scope("backward")

     def record(self):

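backward() now wraps its body in a save/disable/restore of the global auto-format-convert flag, mirroring what optimizer.step() does further down. A generic sketch of that pattern as a context manager (hypothetical helper, built only on the internal core2 functions imported above):

    from contextlib import contextmanager

    from megengine.core._imperative_rt.core2 import (
        get_auto_format_convert,
        set_auto_format_convert,
    )


    @contextmanager
    def no_auto_format_convert():
        # remember the global flag, switch it off, and always restore it,
        # so raw gradient/parameter updates are not reinterpreted as NHWC indexing
        origin = get_auto_format_convert()
        set_auto_format_convert(False)
        try:
            yield
        finally:
            set_auto_format_convert(origin)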

imperative/python/megengine/core/_config.py (+47, -7)

@@ -10,8 +10,10 @@ from ._imperative_rt.core2 import (
     set_option,
 )

+# use "default" to distinguish it from None in _reset_execution_config
 __compute_mode = "default"
 __conv_format = "default"
+__bn_format = "default"
 _benchmark_kernel = False
 _deterministic_kernel = False

@@ -22,6 +24,8 @@ __all__ = [
     "disable_memory_forwarding",
     "_compute_mode",
     "_conv_format",
+    "_bn_format",
+    "_auto_format_convert",
     "_override",
 ]

@@ -32,6 +36,7 @@ def benchmark_kernel(mod):
     which means use heuristic to choose the fastest algorithm.

     Examples:
+
     .. code-block::

         import megengine as mge
@@ -55,6 +60,7 @@ def deterministic_kernel(mod):
     which means the algorithm is not reproducible.

     Examples:
+
     .. code-block::

         import megengine as mge
@@ -75,6 +81,7 @@ def async_level(mod) -> int:
     which means both device and user side errors are async.

     Examples:
+
     .. code-block::

         import megengine as mge
@@ -110,16 +117,17 @@ def disable_memory_forwarding(mod, disable: bool):


 @property
 def _compute_mode(mod):
-    r"""Get or set the precision of intermediate results. The default option is "default",
-    which means that no special requirements will be placed on. When set to 'float32', it
-    would be used for accumulator and intermediate result, but only effective when input and
+    r"""Get or set the precision of intermediate results for conv, matmul. The default
+    option is None and will fallback to "default". When set to "float32", it will
+    trigger mixed precision computation on TensorCore, but only effective when input and
     output are of float16 dtype.

     Examples:
+
     .. code-block::

         import megengine as mge
-        mge.config._compute_mode = "default"
+        mge.config._compute_mode = "float32"
     """
     return __compute_mode

@@ -132,7 +140,7 @@ def _compute_mode(mod, _compute_mode: str):


 @property
 def _conv_format(mod):
-    r"""Get or set convolution data/filter/output layout format. The default option is "default",
+    r"""Get or set convolution data/filter/output layout format. The default option is None,
     which means that no special format will be placed on. There are all layout definitions

     ``NCHW`` layout: ``{N, C, H, W}``
@@ -145,6 +153,7 @@ def _conv_format(mod):
     ``NCHW64`` layout: ``{N, C/64, H, W, 64}``

     Examples:
+
     .. code-block::

         import megengine as mge
@@ -160,11 +169,34 @@ def _conv_format(mod, format: str):


 @property
+def _bn_format(mod):
+    r"""Get or set batchnorm param layout format. The default option is None and will
+    fallback to "dim_1c11" which corresponds to NCHW format. When set to "dim_111c",
+    param format of batchnorm will be changed to NHWC.
+
+    Examples:
+
+    .. code-block::
+
+        import megengine as mge
+        mge.config._bn_format = "dim_111c"
+    """
+    return __bn_format
+
+
+@_bn_format.setter
+def _bn_format(mod, format: str):
+    global __bn_format
+    __bn_format = format
+
+
+@property
 def _auto_format_convert(mod):
     r"""Automatically convert indexing params' order for NCHW Tensor to NHWC order.
     The default value is False, which means no convert.

     Examples:
+
     .. code-block::

         import megengine as mge
@@ -184,15 +216,17 @@ def _reset_execution_config(
     async_level=None,
     compute_mode=None,
     conv_format=None,
+    bn_format=None,
     auto_format_convert=None,
 ):
-    global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format
+    global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format, __bn_format
     orig_flags = (
         _benchmark_kernel,
         _deterministic_kernel,
         get_option("async_level"),
         __compute_mode,
         __conv_format,
+        __bn_format,
         get_auto_format_convert(),
     )
     if benchmark_kernel is not None:
@@ -205,6 +239,8 @@ def _reset_execution_config(
         __compute_mode = compute_mode
     if conv_format is not None:
         __conv_format = conv_format
+    if bn_format is not None:
+        __bn_format = bn_format
     if auto_format_convert is not None:
         set_auto_format_convert(auto_format_convert)

@@ -218,12 +254,14 @@ def _override(
     async_level=None,
     compute_mode=None,
     conv_format=None,
+    bn_format=None,
     auto_format_convert=None,
 ):
     r"""A context manager that users can opt in by attaching the decorator to set
     the config of the global variable.

     Examples:
+
     .. code-block::

         import megengine as mge
@@ -234,6 +272,7 @@ def _override(
             async_level=2,
             compute_mode="float32",
             conv_format="NHWC",
+            bn_format="dim_111c",
             auto_format_convert=True,
         )
         def train():
@@ -244,6 +283,7 @@ def _override(
         async_level,
         compute_mode,
         conv_format,
+        bn_format,
         auto_format_convert,
     )
     try:
@@ -254,4 +294,4 @@ def _override(


 def _get_actual_op_param(function_param, config_param):
-    return function_param if config_param == "default" else config_param
+    return function_param if config_param is "default" else config_param

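A hedged example of driving the new knobs from user code, mirroring the _override docstring above; the tensors follow the convention used in the tests further down (NCHW-ordered data tagged with format="nhwc"):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.random.rand(2, 3, 16, 16).astype("float32"), format="nhwc")
    w = mge.tensor(np.random.rand(8, 3, 3, 3).astype("float32"), format="nhwc")

    with mge.config._override(
        conv_format="NHWC",          # run conv kernels in NHWC layout
        bn_format="dim_111c",        # batchnorm params laid out along the channel-last axis
        auto_format_convert=True,    # reorder indexing/shape params automatically
    ):
        y = F.conv2d(x, w)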
imperative/python/megengine/core/tensor/amp.py (+12, -6)

@@ -10,13 +10,19 @@ from .._imperative_rt.core2 import (
 _enabled = False
 _set_amp_dtype_autocast(_enabled)

+__all__ = [
+    "enabled",
+    "high_prec_dtype",
+    "low_prec_dtype",
+]
+

 @property
 def enabled(mod):
     r"""Get or set amp autocast mode enabled or not.
     Examples:
     .. code-block::

         import megengine as mge
@@ -36,9 +42,9 @@ def enabled(mod, enabled: bool):
 def high_prec_dtype(mod):
     r"""Get or set amp autocast mode's higher precision dtype. It will change the
     target dtype in tensor casting for better precision. Default: float32.
     Examples:
     .. code-block::

         import megengine as mge
@@ -56,9 +62,9 @@ def high_prec_dtype(mod, dtype: str):
 def low_prec_dtype(mod):
     r"""Get or set amp autocast mode's lower precision dtype. It will change the
     target dtype in tensor casting for better speed and memory. Default: float16.
     Examples:
     .. code-block::

         import megengine as mge

imperative/python/megengine/core/tensor/array_method.py (+1, -0)

@@ -63,6 +63,7 @@ def _matmul(
     assert dim1 > 0 and dim2 > 0
     maxdim = dim1 if dim1 > dim2 else dim2
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
+
     if dim1 == 1 and dim2 == 1:  # dispatch to Dot
         (result,) = apply(builtin.Dot(), inp1, inp2)
         return result


imperative/python/megengine/functional/nn.py (+0, -2)

@@ -441,7 +441,6 @@ def deformable_conv2d(
         or conv_mode.name == "CROSS_CORRELATION"
     )
     if amp._enabled:
-        compute_mode = "float32"
         inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
     else:
         offset = offset.astype("float32")
@@ -1182,7 +1181,6 @@ def batch_norm(
     momentum: float = 0.9,
     eps: float = 1e-5,
     inplace: bool = True,
-    compute_mode="default",
     param_dim="dim_1c11"
 ):
     r"""Applies batch normalization to the input.


imperative/python/megengine/module/batchnorm.py (+0, -3)

@@ -19,7 +19,6 @@ class _BatchNorm(Module):
         affine=True,
         track_running_stats=True,
         freeze=False,
-        compute_mode="default",
         param_dim="dim_1c11",
         **kwargs
     ):
@@ -31,7 +30,6 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
-        self.compute_mode = compute_mode
         self.param_dim = param_dim
         if self.freeze:
             assert (
@@ -106,7 +104,6 @@ class _BatchNorm(Module):
             or ((self.running_mean is None) and (self.running_var is None)),
             momentum=exponential_average_factor,
             eps=self.eps,
-            compute_mode=self.compute_mode,
             param_dim=self.param_dim,
         )




imperative/python/megengine/optimizer/optimizer.py (+11, -2)

@@ -8,7 +8,13 @@ from typing import Union

 import numpy as np

-from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
+from ..core._imperative_rt.core2 import (
+    get_auto_format_convert,
+    pop_scope,
+    push_scope,
+    set_auto_format_convert,
+    set_option,
+)
 from ..core.tensor.utils import set_convert_inputs
 from ..tensor import Parameter, Tensor
 from ..utils.deprecation import deprecated
@@ -90,7 +96,7 @@ class Optimizer(metaclass=ABCMeta):
                     "optimizer can only optimize Parameters, but one of the params is "
                     + str(type(param))
                 )
-            param._reset(Tensor(param.numpy(), no_cache=True))
+            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))

         for name, default in self._defaults.items():
             if default is required and name not in param_group:
@@ -139,6 +145,8 @@ class Optimizer(metaclass=ABCMeta):
         # set the globle state `_enable_convert_inputs` to `False` to disable
         # the `convert_inputs` for param updates
         set_option("record_computing_path", 0)
+        _origin_auto_format = get_auto_format_convert()
+        set_auto_format_convert(False)
         if self._disable_type_convert:
             backup = set_convert_inputs(False)
         for group in self.param_groups:
@@ -155,6 +163,7 @@ class Optimizer(metaclass=ABCMeta):
             # restore the globle state `_enable_convert_inputs`
             set_convert_inputs(backup)
         set_option("record_computing_path", 1)
+        set_auto_format_convert(_origin_auto_format)
         return self

     @deprecated(version="1.0", reason="use clear_grad instead")

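The new format=param.format argument when parameters are re-created matters because rebuilding a Tensor from a bare numpy array would otherwise not carry the NHWC tag; a small illustration (assumes the parameter was already converted with the amp helpers above):

    import numpy as np
    from megengine import Tensor, amp

    p = amp.convert_tensor_format(
        Tensor(np.ones((1, 4, 2, 2), dtype="float32")), inplace=False
    )
    assert p.format == "nhwc"

    # without format=..., the rebuilt tensor would presumably fall back to the
    # default format; passing it through keeps the NHWC tag intact
    rebuilt = Tensor(p.numpy(), no_cache=True, format=p.format)
    assert rebuilt.format == "nhwc"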

imperative/python/test/unit/amp/test_convert_format.py (+44, -0)

@@ -0,0 +1,44 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+import pytest
+
+import megengine.functional as F
+import megengine.module as M
+from megengine import Parameter, Tensor, amp, tensor
+
+
+class MyModule(M.Module):
+    class InnerModule(M.Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = M.BatchNorm2d(4)
+
+        def forward(self, x):
+            return self.bn(x)
+
+    def __init__(self):
+        super().__init__()
+        self.i = self.InnerModule()
+        self.conv = M.Conv2d(4, 4, 4, groups=2)
+        self.bn = M.BatchNorm2d(4)
+        self.param = Parameter(np.ones((1, 3, 1, 1), dtype=np.float32))
+        self.buff = Tensor(np.ones((1, 3, 1, 1), dtype=np.float32))
+
+    def forward(self, x):
+        x = self.i(x)
+        x = self.bn(x)
+        return x
+
+
+@pytest.mark.parametrize("is_inplace", [False, True])
+def test_convert_module(is_inplace):
+    m = MyModule()
+    m = amp.convert_module_format(m, is_inplace)
+    for name, param in m.named_tensors():
+        assert param.format == "nhwc"

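Putting the pieces of this commit together, an end-to-end NHWC training sketch (illustrative only: layer sizes, the loss, and the SGD settings are made up, and GradScaler is left out for brevity):

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    import megengine.module as M
    import megengine.optimizer as optim
    from megengine import amp
    from megengine.autodiff import GradManager


    class Net(M.Module):
        def __init__(self):
            super().__init__()
            self.conv = M.Conv2d(3, 8, 3)
            self.bn = M.BatchNorm2d(8)

        def forward(self, x):
            return F.relu(self.bn(self.conv(x)))


    net = amp.convert_module_format(Net())   # parameters/buffers become NHWC
    gm = GradManager().attach(net.parameters())
    opt = optim.SGD(net.parameters(), lr=0.01)

    data = mge.tensor(np.random.rand(2, 3, 32, 32).astype("float32"), format="nhwc")

    with mge.config._override(
        conv_format="NHWC", bn_format="dim_111c", auto_format_convert=True
    ):
        with gm:
            with amp.autocast():
                loss = net(data).mean()
            gm.backward(loss)
        opt.step().clear_grad()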
imperative/python/test/unit/core/test_formatted_tensor.py (+19, -4)

@@ -8,14 +8,27 @@ from megengine.autodiff import GradManager


 def test_basic():
-    a = tensor(np.arange(0, 24).reshape((1, 2, 3, 4)), dtype="float32", format="nhwc")
+    data = np.arange(0, 24).reshape((1, 2, 3, 4))
+    # init from numpy
+    a = tensor(data, format="nhwc")
     assert a.format == "nhwc"
+
+    # init from tensor
     b = tensor(a)
     assert b.format == "nhwc"
-    # TODO: fix Tensor init bug for another Tensor
+
+    # TODO: init from tensor with new format
     # c = tensor(a, format="nchw")
     # assert c.format == "nchw"

+    # TODO: reset from numpy
+    # b[...] = data
+    # assert b.format == "nhwc"
+
+    # reset from tensor
+    b[...] = tensor(data, format="nchw")
+    assert b.format == "nchw"


 def _compare_nchw_nhwc(data, func):
     x1 = tensor(data, format="nchw")
@@ -23,7 +36,7 @@ def _compare_nchw_nhwc(data, func):
     out1 = func(x1)
     with mge.config._override(auto_format_convert=True):
         out2 = func(x2)
-    np.testing.assert_equal(out1, out2)
+    np.testing.assert_almost_equal(out1, out2, decimal=5)


 def test_dimshuffle():
@@ -296,8 +309,10 @@ def test_backward():
     with gm:
         with mge.config._override(auto_format_convert=True, conv_format="NHWC"):
             x = F.conv2d(x, w, b)
+            # TODO: fix manually convert to NHWC, usually used in detection head
+            # x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2)
         gm.backward(x)
-    # TODO: backward grad has no format yet
+    # backward grad has no format
     np.testing.assert_equal(
         w.grad.numpy(),
         np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),


imperative/python/test/unit/functional/test_functional.py (+1, -6)

@@ -921,12 +921,7 @@ def test_batchnorm2d_autocast():

     amp.enabled = False
     expected = F.batch_norm(
-        inp.astype("float16"),
-        weight=weight,
-        bias=bias,
-        training=True,
-        inplace=False,
-        compute_mode="float32",
+        inp.astype("float16"), weight=weight, bias=bias, training=True, inplace=False,
     )
     assert out.dtype == np.float16
     assert expected.dtype == np.float16

