
feat(mge/amp): add mixed precision autocast support

GitOrigin-RevId: 6fbffc4845
Branch: release-1.5
Megvii Engine Team huangxinda, 4 years ago
parent commit 1bf18252c4
20 changed files with 403 additions and 160 deletions
  1. +1 -0 imperative/python/megengine/__init__.py
  2. +14 -0 imperative/python/megengine/amp/__init__.py
  3. +79 -0 imperative/python/megengine/amp/autocast.py
  4. +1 -1 imperative/python/megengine/core/_wrap.py
  5. +78 -0 imperative/python/megengine/core/tensor/amp.py
  6. +54 -37 imperative/python/megengine/core/tensor/array_method.py
  7. +1 -1 imperative/python/megengine/core/tensor/megbrain_graph.py
  8. +10 -1 imperative/python/megengine/core/tensor/utils.py
  9. +3 -42 imperative/python/megengine/functional/elemwise.py
  10. +12 -6 imperative/python/megengine/functional/math.py
  11. +54 -28 imperative/python/megengine/functional/nn.py
  12. +1 -1 imperative/python/megengine/functional/tensor.py
  13. +1 -1 imperative/python/megengine/jit/tracing.py
  14. +3 -0 imperative/python/megengine/module/batchnorm.py
  15. +8 -2 imperative/python/megengine/module/linear.py
  16. +0 -2 imperative/python/megengine/quantization/internal_fake_quant.py
  17. +1 -1 imperative/python/megengine/tensor.py
  18. +34 -0 imperative/python/test/unit/amp/test_autocast.py
  19. +48 -4 imperative/python/test/unit/functional/test_functional.py
  20. +0 -33 imperative/python/test/unit/module/test_module.py

+1 -0 imperative/python/megengine/__init__.py

@@ -117,6 +117,7 @@ def _atexit(handler):


# subpackages
import megengine.amp
import megengine.autodiff
import megengine.data
import megengine.distributed


+14 -0 imperative/python/megengine/amp/__init__.py

@@ -0,0 +1,14 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import mprop

from ..core.tensor.amp import *
from .autocast import autocast

mprop.init()
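
(Note: mprop is a small PyPI helper that replaces this module in sys.modules with a property-capable proxy, so the module-level @property definitions re-exported from core.tensor.amp behave like plain attributes. A minimal sketch of the resulting API, assuming the defaults defined in core/tensor/amp.py below:)

    import megengine as mge

    print(mge.amp.enabled)         # False: read through the `enabled` property
    print(mge.amp.low_prec_dtype)  # "float16"
    mge.amp.enabled = True         # property setter mutates the module-level global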

+79 -0 imperative/python/megengine/amp/autocast.py

@@ -0,0 +1,79 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools

from ..core.tensor import amp


class autocast:
r"""
A class to control autocast mode for amp as a context manager or a decorator.

:param enabled: Whether autocast mode is enabled.
:param low_prec_dtype: Set amp autocast mode's lower precision dtype. It will change the
target dtype in tensor casting for better speed and memory. Default: float16.
:param high_prec_dtype: Set amp autocast mode's higher precision dtype. It will change the
target dtype in tensor casting for better precision. Default: float32.

Examples:

.. code-block::

# used as decorator
@autocast()
def train_step(image, label):
with gm:
logits = model(image)
loss = F.nn.cross_entropy(logits, label)
gm.backward(loss)
opt.step().clear_grad()
return loss

# used as context manager
def train_step(image, label):
with autocast():
with gm:
logits = model(image)
loss = F.nn.cross_entropy(logits, label)
gm.backward(loss)
opt.step().clear_grad()
return loss
"""

def __init__(
self,
enabled: bool = True,
low_prec_dtype: str = "float16",
high_prec_dtype: str = "float32",
):
self.enabled = enabled
self.high_prec_dtype = high_prec_dtype
self.low_prec_dtype = low_prec_dtype
self._origin_enabled = None
self._origin_high = None
self._origin_low = None

def __enter__(self):
self._origin_enabled, amp._enabled = amp._enabled, self.enabled
self._origin_high = amp._high_prec_dtype
amp._high_prec_dtype = self.high_prec_dtype
self._origin_low = amp._low_prec_dtype
amp._low_prec_dtype = self.low_prec_dtype

def __exit__(self, *args):
amp._enabled = self._origin_enabled
amp._high_prec_dtype = self._origin_high
amp._low_prec_dtype = self._origin_low

def __call__(self, func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
with self:
return func(*args, **kwargs)

return wrapper
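
Because __enter__ saves the previous global flags and __exit__ restores them, autocast blocks nest cleanly. A minimal sketch of this save/restore behavior (flag values assume the library defaults):

    from megengine import amp

    assert not amp.enabled               # default: autocast off
    with amp.autocast():
        assert amp.enabled               # outer block turns it on
        with amp.autocast(enabled=False):
            assert not amp.enabled       # inner block overrides the flag...
        assert amp.enabled               # ...and is undone on exit
    assert not amp.enabled               # original state restored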

+1 -1 imperative/python/megengine/core/_wrap.py

@@ -49,7 +49,7 @@ class Device:
return self._cn == rhs._cn


def device(obj):
def as_device(obj):
if isinstance(obj, Device):
return obj
return Device(obj)

+78 -0 imperative/python/megengine/core/tensor/amp.py

@@ -0,0 +1,78 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
_enabled = False
_high_prec_dtype = "float32"
_low_prec_dtype = "float16"


@property
def enabled(mod):
r"""
Get or set amp autocast mode enabled or not.

Examples:

.. code-block::

import megengine as mge
mge.amp.enabled = True

"""
return _enabled


@enabled.setter
def enabled(mod, enabled: bool):
global _enabled
_enabled = enabled


@property
def high_prec_dtype(mod):
r"""
Get or set amp autocast mode's higher precision dtype. It will change the
target dtype in tensor casting for better precision. Default: float32.

Examples:

.. code-block::

import megengine as mge
mge.amp.high_prec_dtype = "float32"

"""
return _high_prec_dtype


@high_prec_dtype.setter
def high_prec_dtype(mod, dtype: str):
global _high_prec_dtype
_high_prec_dtype = dtype


@property
def low_prec_dtype(mod):
r"""
Get or set amp autocast mode's lower precision dtype. It will change the
target dtype in tensor casting for better speed and memory. Default: float16.

Examples:

.. code-block::

import megengine as mge
mge.amp.low_prec_dtype = "float16"

"""
return _low_prec_dtype


@low_prec_dtype.setter
def low_prec_dtype(mod, dtype: str):
global _low_prec_dtype
_low_prec_dtype = dtype

+54 -37 imperative/python/megengine/core/tensor/array_method.py

@@ -15,15 +15,20 @@ import numpy as np
from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import SymbolVar, Tensor, apply
from ..ops import builtin
from ..ops.builtin import Elemwise, GetVarShape
from . import utils
from .indexing import getitem as _getitem
from .indexing import setitem as _setitem
from .utils import isscalar
from .utils import make_shape_tuple as _make_shape_tuple
from .utils import setscalar

_ElwMod = Elemwise.Mode
from . import amp
from .indexing import getitem, setitem
from .utils import (
_normalize_axis,
astensor1d,
astype,
cast_tensors,
convert_inputs,
isscalar,
make_shape_tuple,
setscalar,
)

_ElwMod = builtin.Elemwise.Mode


def _elwise_apply(args, mode):
@@ -40,47 +45,59 @@ def _elwise_apply(args, mode):


def _elwise(*args, mode):
args = convert_inputs(*args)
if mode in (
_ElwMod.TRUE_DIV,
_ElwMod.EXP,
_ElwMod.POW,
_ElwMod.CEIL,
_ElwMod.FLOOR,
_ElwMod.ROUND,
_ElwMod.LOG,
_ElwMod.EXPM1,
_ElwMod.LOG1P,
_ElwMod.TANH,
_ElwMod.ACOS,
_ElwMod.ASIN,
_ElwMod.ATAN2,
_ElwMod.COS,
_ElwMod.H_SWISH,
_ElwMod.SIGMOID,
_ElwMod.SIN,
) and (
amp._enabled or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
):
if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND) and np.issubdtype(
args[0].dtype, np.integer
):
return args[0]
args = tuple(
map(
lambda x: x.astype("float32")
if hasattr(x, "dtype") and x.dtype != np.float32
else x,
args,
)
)
args = utils.convert_inputs(*args)
# autocast to FP32 to maintain precision
# or to avoid op's not supporting all int args
args = cast_tensors(*args, promote=True)

if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND,) and np.issubdtype(
args[0].dtype, np.integer
):
return args[0]
return _elwise_apply(args, mode)


def _matmul(inp1, inp2):
if amp._enabled:
compute_mode = "float32"
inp1, inp2 = cast_tensors(inp1, inp2)
else:
compute_mode = "default"
inp1, inp2 = convert_inputs(inp1, inp2)
op = builtin.MatrixMul(
transposeA=False, transposeB=False, compute_mode="default", format="default"
transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
)
inp1, inp2 = utils.convert_inputs(inp1, inp2)
(result,) = apply(op, inp1, inp2)
return result


def _transpose(data, axes):
op = builtin.Dimshuffle(axes)
(data,) = utils.convert_inputs(data)
(data,) = convert_inputs(data)
(result,) = apply(op, data)
return result


def _broadcast(inp, shape):
shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device)
shape = astensor1d(shape, inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Broadcast(), inp, shape)
return result

@@ -88,7 +105,7 @@ def _broadcast(inp, shape):
def _reshape(x, shape):
unspec_axis = None
try:
shape_tuple = _make_shape_tuple(shape)
shape_tuple = make_shape_tuple(shape)
except ValueError:
pass
else:
@@ -102,7 +119,7 @@ def _reshape(x, shape):
"multiple -1 in shape: {} & {}".format(unspec_axis, i)
)
unspec_axis = i
shape = utils.astensor1d(shape, x, dtype="int32", device=x.device)
shape = astensor1d(shape, x, dtype="int32", device=x.device)
if unspec_axis is None:
op = builtin.Reshape()
else:
@@ -171,7 +188,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
return list(map(int, axis))

axis = get_axes()
axis = utils._normalize_axis(inp.ndim, axis)
axis = _normalize_axis(inp.ndim, axis)
axis = [a - i for i, a in enumerate(axis)]

op = builtin.RemoveAxis(axis=axis)
@@ -184,7 +201,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
def _reduce(mode):
def f(self, axis=None, keepdims: bool = False):
data = self
(data,) = utils.convert_inputs(data)
(data,) = convert_inputs(data)
if mode == "mean":
data = data.astype("float32")
elif self.dtype == np.bool_:
@@ -196,7 +213,7 @@ def _reduce(mode):
op = builtin.Reduce(mode=mode, axis=0)
(result,) = apply(op, data)
elif isinstance(axis, collections.abc.Iterable):
axis = utils._normalize_axis(self.ndim, axis, reverse=True)
axis = _normalize_axis(self.ndim, axis, reverse=True)
for ai in axis:
op = builtin.Reduce(mode=mode, axis=ai)
(data,) = apply(op, data)
@@ -359,11 +376,11 @@ class ArrayMethodMixin(abc.ABC):
yield self[i]

def __getitem__(self, index):
return _getitem(self, index)
return getitem(self, index)

def __setitem__(self, index, value):
if index is not Ellipsis:
value = _setitem(self, index, value)
value = setitem(self, index, value)
self._reset(value)

__contains__ = _todo
@@ -422,7 +439,7 @@ class ArrayMethodMixin(abc.ABC):
Returns a :class:`Tensor` with the same data and number of elements
with the specified :attr:`~.Tensor.dtype`.
"""
return utils.astype(self, dtype)
return astype(self, dtype)

def reshape(self, *args):
r"""

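The casting policy in _elwise above, restated in isolation: precision-sensitive elemwise modes are promoted to the high-precision dtype whenever autocast is enabled or every input is integral, while CEIL/FLOOR/ROUND applied to integer inputs are the identity. A simplified numpy-only sketch of the rule (the abbreviated mode sets here are illustrative, not the full lists):

    import numpy as np

    PRECISION_SENSITIVE = {"TRUE_DIV", "EXP", "LOG", "SIN"}  # abbreviated
    INT_IDENTITY = {"CEIL", "FLOOR", "ROUND"}

    def promote(mode, *args, amp_enabled=False, high_prec="float32"):
        all_int = all(np.issubdtype(a.dtype, np.integer) for a in args)
        if mode in PRECISION_SENSITIVE and (amp_enabled or all_int):
            args = tuple(a.astype(high_prec) for a in args)
        if mode in INT_IDENTITY and np.issubdtype(args[0].dtype, np.integer):
            return args[0]                # rounding an integer is a no-op
        return args                       # would be handed to the elemwise kernel

    x = np.array([1, 2], dtype=np.int32)
    print(promote("EXP", x)[0].dtype)     # float32: all-int inputs promoted
    print(promote("FLOOR", x).dtype)      # int32: returned unchanged
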

+1 -1 imperative/python/megengine/core/tensor/megbrain_graph.py

@@ -18,7 +18,7 @@ import numpy as np
from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
from .._wrap import device as as_device
from .._wrap import as_device
from ..ops.builtin import OpDef
from .core import TensorBase



+10 -1 imperative/python/megengine/core/tensor/utils.py

@@ -13,9 +13,10 @@ import numpy as np

from .._imperative_rt import make_const
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device
from .._wrap import device as as_device
from .._wrap import as_device
from ..ops import builtin
from ..ops.special import Const
from .amp import _high_prec_dtype, _low_prec_dtype
from .dtype import is_dtype_equal, is_quantize

_enable_convert_inputs = True
@@ -98,6 +99,14 @@ def convert_inputs(*args, device=None):
return tuple(map(convert, args))


def cast_tensors(*args, promote=False):
if promote:
dtype = _high_prec_dtype
else:
dtype = _low_prec_dtype
return tuple(arg.astype(dtype) if arg is not None else None for arg in args)


def result_type(*args):
dtypes = []
for i in args:
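
A quick illustration of the cast_tensors contract added above, with the default amp dtypes assumed: promote=False targets the low-precision dtype, promote=True the high-precision one, and None placeholders (e.g. an absent bias) pass through untouched:

    import numpy as np
    import megengine as mge
    from megengine.core.tensor.utils import cast_tensors  # internal helper defined above

    x = mge.tensor(np.zeros((2, 2)), dtype="float32")
    w = mge.tensor(np.zeros((2, 2)), dtype="float32")

    x16, w16, b = cast_tensors(x, w, None)       # -> float16, float16, None
    x32, w32 = cast_tensors(x, w, promote=True)  # -> float32, float32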


+3 -42 imperative/python/megengine/functional/elemwise.py

@@ -12,10 +12,8 @@ import numpy as np
from ..core._imperative_rt.core2 import SymbolVar, apply
from ..core.ops import builtin
from ..core.ops.builtin import Elemwise
from ..core.tensor import utils
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import astype
from ..device import get_default_device
from ..core.tensor.array_method import _elwise
from ..core.tensor.utils import astype, convert_inputs
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func

@@ -69,46 +67,9 @@ __all__ = [
]


def _elwise(*args, mode):
tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args))
if len(tensor_args) == 0:
dtype = utils.dtype_promotion(args)
first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
args = utils.convert_inputs(first_arg, *args[1:])
else:
args = utils.convert_inputs(*args)
if mode in (
Elemwise.Mode.TRUE_DIV,
Elemwise.Mode.EXP,
Elemwise.Mode.POW,
Elemwise.Mode.LOG,
Elemwise.Mode.EXPM1,
Elemwise.Mode.LOG1P,
Elemwise.Mode.TANH,
Elemwise.Mode.ACOS,
Elemwise.Mode.ASIN,
Elemwise.Mode.ATAN2,
Elemwise.Mode.CEIL,
Elemwise.Mode.COS,
Elemwise.Mode.FLOOR,
Elemwise.Mode.H_SWISH,
Elemwise.Mode.ROUND,
Elemwise.Mode.SIGMOID,
Elemwise.Mode.SIN,
):
if mode in (
Elemwise.Mode.CEIL,
Elemwise.Mode.FLOOR,
Elemwise.Mode.ROUND,
) and np.issubdtype(args[0].dtype, np.integer):
return args[0]
args = tuple(map(lambda x: astype(x, "float32"), args))
return _elwise_apply(args, mode)


def _elemwise_multi_type(*args, mode, **kwargs):
op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
args = utils.convert_inputs(*args)
args = convert_inputs(*args)
(result,) = apply(op, *args)
return result



+12 -6 imperative/python/megengine/functional/math.py

@@ -14,7 +14,8 @@ from ..core._imperative_rt.core2 import apply
from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin
from ..core.ops.special import Const
from ..core.tensor import utils
from ..core.tensor import amp
from ..core.tensor.utils import _normalize_axis, cast_tensors, convert_inputs, setscalar
from ..tensor import Tensor
from .debug_param import get_execution_strategy
from .elemwise import clip, exp, log, log1p
@@ -471,7 +472,7 @@ def argmin(
inp = inp.flatten()
axis = 0

axis = utils._normalize_axis(inp.ndim, axis, reverse=True)
axis = _normalize_axis(inp.ndim, axis, reverse=True)
if isinstance(axis, collections.abc.Iterable):

for ai in axis:
@@ -528,7 +529,7 @@ def argmax(
assert not keepdims, "can not set axis=None and keepdims=True"
inp = inp.flatten()
axis = 0
axis = utils._normalize_axis(inp.ndim, axis, reverse=True)
axis = _normalize_axis(inp.ndim, axis, reverse=True)

if isinstance(axis, collections.abc.Iterable):

@@ -807,8 +808,13 @@ def matmul(
[28. 40.]]

"""
if amp._enabled:
compute_mode = "float32"
inp1, inp2 = cast_tensors(inp1, inp2)
else:
inp1, inp2 = convert_inputs(inp1, inp2)

remove_row, remove_col = False, False
inp1, inp2 = utils.convert_inputs(inp1, inp2)

dim1, dim2 = inp1.ndim, inp2.ndim
# handle dim=1 cases, dot and matrix-vector multiplication
@@ -921,12 +927,12 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:

"""
op = builtin.Dot()
inp1, inp2 = utils.convert_inputs(inp1, inp2)
inp1, inp2 = convert_inputs(inp1, inp2)
assert (
inp1.ndim <= 1 and inp2.ndim <= 1
), "Input tensors for dot must be 1-dimensional or scalar"
(result,) = apply(op, inp1, inp2)
utils.setscalar(result)
setscalar(result)
return result
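
The practical effect of the branch added to matmul: under autocast the inputs are cast to the low-precision dtype, while compute_mode="float32" keeps the accumulation in float32. A hedged usage sketch (the float16 output dtype is the expected behavior, matching the io16c32 tests added later in this commit):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    a = mge.tensor(np.random.randn(3, 4), dtype="float32")
    b = mge.tensor(np.random.randn(4, 5), dtype="float32")
    with mge.amp.autocast():
        c = F.matmul(a, b)  # inputs cast to float16, accumulated in float32
    print(c.dtype)          # float16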




+54 -28 imperative/python/megengine/functional/nn.py

@@ -15,9 +15,16 @@ from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph, utils
from ..core.tensor import amp, megbrain_graph
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import astensor1d, astype, setscalar
from ..core.tensor.utils import (
astensor1d,
astype,
cast_tensors,
convert_inputs,
convert_single_value,
setscalar,
)
from ..device import get_default_device
from ..distributed import WORLD, is_distributed
from ..random import uniform
@@ -91,7 +98,9 @@ def expand_hw(x):
return int(h), int(w)


def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
def linear(
inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None, compute_mode="default",
) -> Tensor:
"""
Applies a linear transformation to the input tensor.

@@ -102,8 +111,10 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor
:param bias: bias with shape `(out_features,)`.
Default: None
"""
ret = matmul(inp, weight, transpose_b=True)
ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
if bias is not None:
if amp._enabled:
bias = bias.astype("float16")
ret += bias
return ret

@@ -153,6 +164,11 @@ def conv1d(
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
assert inp.ndim == 3, "the input dimension of conv1d should be 3"
assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
if amp._enabled:
compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias)
else:
inp, weight = convert_inputs(inp, weight)

inp = expand_dims(inp, 3)
weight = expand_dims(weight, 3)
@@ -177,7 +193,6 @@ def conv1d(
compute_mode=compute_mode,
sparse=sparse_type,
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
@@ -228,7 +243,11 @@ def conv2d(
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
if amp._enabled:
compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias)
else:
inp, weight = convert_inputs(inp, weight)

stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
@@ -247,7 +266,6 @@ def conv2d(
compute_mode=compute_mode,
sparse=sparse_type,
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
@@ -286,6 +304,7 @@ def conv3d(
:return: output tensor.
"""
assert conv_mode.lower() == "cross_correlation"
inp, weight = convert_inputs(inp, weight)

D, H, W = 0, 1, 2

@@ -308,7 +327,6 @@ def conv3d(
mode=conv_mode,
sparse=sparse_type,
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
@@ -358,7 +376,11 @@ def conv_transpose2d(
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
if amp._enabled:
compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias)
else:
inp, weight = convert_inputs(inp, weight)

if groups != 1:
raise NotImplementedError("group transposed conv2d is not supported yet.")
@@ -375,8 +397,8 @@ def conv_transpose2d(
dilate_h=dilate_h,
dilate_w=dilate_w,
strategy=get_execution_strategy(),
compute_mode=compute_mode,
)
weight, inp = utils.convert_inputs(weight, inp)
(output,) = apply(op, weight, inp)
if bias is not None:
output += bias
@@ -428,7 +450,11 @@ def deformable_conv2d(
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
if amp._enabled:
compute_mode = "float32"
inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
else:
inp, weight, offset, mask = convert_inputs(inp, weight, offset, mask)

stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
@@ -447,7 +473,6 @@ def deformable_conv2d(
compute_mode=compute_mode,
sparse=sparse_type,
)
inp, weight, offset, mask = utils.convert_inputs(inp, weight, offset, mask)
(output,) = apply(op, inp, weight, offset, mask)
if bias is not None:
output += bias
@@ -468,6 +493,7 @@ def local_conv2d(
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
inp, weight = convert_inputs(inp, weight)

stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
@@ -481,10 +507,8 @@ def local_conv2d(
dilate_h=dilate_h,
dilate_w=dilate_w,
mode=conv_mode,
compute_mode="default",
sparse="dense",
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
@@ -515,8 +539,9 @@ def conv_transpose3d(
:param dilation: dilation of the 3D convolution operation. Default: 1
:return: output tensor.
"""
D, H, W = 0, 1, 2
inp, weight = convert_inputs(inp, weight)

D, H, W = 0, 1, 2
pad = _triple(padding)
stride = _triple_nonzero(stride)
dilate = _triple_nonzero(dilation)
@@ -533,7 +558,6 @@ def conv_transpose3d(
dilate_w=dilate[W],
strategy=get_execution_strategy(),
)
weight, inp = utils.convert_inputs(weight, inp)
(output,) = apply(op, weight, inp)
if bias is not None:
output += bias
@@ -994,7 +1018,8 @@ def batch_norm(
training: bool = False,
momentum: float = 0.9,
eps: float = 1e-5,
inplace: bool = True
inplace: bool = True,
compute_mode="default",
):
r"""
Applies batch normalization to the input.
@@ -1027,15 +1052,11 @@ def batch_norm(
def make_full_if_none(x, value):
if x is None:
(x,) = Const(value, dtype=inp.dtype, device=inp.device)()
shape = utils.astensor1d(
(1, C, 1, 1), inp, dtype="int32", device=inp.device
)
shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Broadcast(), x, shape)
return result
elif x.ndim == 1:
shape = utils.astensor1d(
(1, C, 1, 1), inp, dtype="int32", device=inp.device
)
shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Reshape(), x, shape)
return result
return x
@@ -1052,10 +1073,15 @@ def batch_norm(
if has_var and running_var.ndim != 4:
raise ValueError

inp, weight, bias, running_mean, running_var = utils.convert_inputs(
inp, weight, bias, running_mean, running_var
)

if amp._enabled:
inp = inp.astype("float16")
weight, bias, running_mean, running_var = cast_tensors(
weight, bias, running_mean, running_var, promote=True
)
elif compute_mode != "float32":
inp, weight, bias, running_mean, running_var = convert_inputs(
inp, weight, bias, running_mean, running_var
)
weight = make_full_if_none(weight, 1)
bias = make_full_if_none(bias, 0)

@@ -1352,7 +1378,7 @@ def indexing_one_hot(
"""
assert isinstance(src, Tensor), "src must be of Tensor type"
op = builtin.IndexingOneHot(axis=axis)
index = utils.convert_single_value(index, dtype="int32", device=src.device)
index = convert_single_value(index, dtype="int32", device=src.device)
(result,) = apply(op, src, index)
if not keepdims:
result = squeeze(result, axis)


+1 -1 imperative/python/megengine/functional/tensor.py

@@ -13,7 +13,7 @@ import numpy as np

from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import SymbolVar, apply
from ..core._wrap import device as as_device
from ..core._wrap import as_device
from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity
from ..core.ops.special import Const


+1 -1 imperative/python/megengine/jit/tracing.py

@@ -33,7 +33,7 @@ from ..core._imperative_rt.ops import (
RemoteSend,
)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
from ..core._wrap import as_device
from ..core.ops.builtin import BatchNorm, OpDef
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G


+3 -0 imperative/python/megengine/module/batchnorm.py

@@ -26,6 +26,7 @@ class _BatchNorm(Module):
affine=True,
track_running_stats=True,
freeze=False,
compute_mode="default",
**kwargs
):
super(_BatchNorm, self).__init__(**kwargs)
@@ -36,6 +37,7 @@ class _BatchNorm(Module):
self.track_running_stats = track_running_stats
self._track_running_stats_saved = track_running_stats
self.freeze = freeze
self.compute_mode = compute_mode
if self.freeze:
assert (
self._track_running_stats_saved
@@ -123,6 +125,7 @@ class _BatchNorm(Module):
or ((self.running_mean is None) and (self.running_var is None)),
momentum=exponential_average_factor,
eps=self.eps,
compute_mode=self.compute_mode,
)

if _ndims != 4:


+8 -2 imperative/python/megengine/module/linear.py

@@ -51,7 +51,12 @@ class Linear(Module):
"""

def __init__(
self, in_features: int, out_features: int, bias: bool = True, **kwargs
self,
in_features: int,
out_features: int,
bias: bool = True,
compute_mode: str = "default",
**kwargs
):
super().__init__(**kwargs)
self.out_features = out_features
@@ -62,6 +67,7 @@ class Linear(Module):
if bias:
b_shape = (out_features,)
self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
self.compute_mode = compute_mode
self.reset_parameters()

def _get_fanin(self):
@@ -75,7 +81,7 @@ class Linear(Module):
init.zeros_(self.bias)

def _calc_linear(self, x, weight, bias):
return linear(x, weight, bias)
return linear(x, weight, bias, compute_mode=self.compute_mode)

def forward(self, x):
return self._calc_linear(x, self.weight, self.bias)
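
The new compute_mode argument threads through _calc_linear to F.matmul, enabling the "io16c32" pattern (float16 inputs and outputs, float32 accumulation) that the functional tests below exercise for conv2d and batch_norm. A sketch of direct use, assuming float16 operands:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.random.randn(2, 4), dtype="float16")
    w = mge.tensor(np.random.randn(8, 4), dtype="float16")
    y = F.nn.linear(x, w, compute_mode="float32")  # fp16 in/out, fp32 accumulation
    print(y.dtype)  # float16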


+0 -2 imperative/python/megengine/quantization/internal_fake_quant.py

@@ -5,8 +5,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
import math
from functools import partial

from .. import functional as F


+1 -1 imperative/python/megengine/tensor.py

@@ -14,7 +14,7 @@ from .core._imperative_rt import CompNode
from .core._imperative_rt.core2 import Tensor as _Tensor
from .core._imperative_rt.core2 import apply
from .core._trace_option import use_symbolic_shape
from .core._wrap import device as as_device
from .core._wrap import as_device
from .core.ops.builtin import Copy, GetVarShape
from .core.tensor.array_method import ArrayMethodMixin
from .device import _valid_device, get_default_device


+34 -0 imperative/python/test/unit/amp/test_autocast.py

@@ -0,0 +1,34 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import amp
from megengine.core.tensor import amp as origin_amp


def test_autocast():
def check(enabled, low, high):
assert amp.enabled == enabled
assert origin_amp._enabled == enabled
assert amp.low_prec_dtype == low
assert origin_amp._low_prec_dtype == low
assert amp.high_prec_dtype == high
assert origin_amp._high_prec_dtype == high

origin_enabled = amp.enabled
origin_high = amp.high_prec_dtype
origin_low = amp.low_prec_dtype
with amp.autocast(low_prec_dtype="low", high_prec_dtype="high"):
check(True, "low", "high")
check(origin_enabled, origin_low, origin_high)
amp.enabled = True
amp.high_prec_dtype = "high"
amp.low_prec_dtype = "low"
check(True, "low", "high")
amp.enabled = origin_enabled
amp.high_prec_dtype = origin_high
amp.low_prec_dtype = origin_low
check(origin_enabled, origin_low, origin_high)

+48 -4 imperative/python/test/unit/functional/test_functional.py

@@ -14,6 +14,7 @@ import numpy as np
import pytest
from utils import opr_test

import megengine.amp as amp
import megengine.core.ops.builtin as builtin
import megengine.core.tensor.dtype as dtype
import megengine.functional as F
@@ -767,6 +768,27 @@ def test_batch_conv_bias():
run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)


def test_conv2d_io16c32():
amp.enabled = True
inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float32)
out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
amp.enabled = False
expected = F.conv2d(
inp.astype("float16"),
weight.astype("float16"),
None,
(2, 2),
(3, 3),
(1, 1),
1,
compute_mode="float32",
)
assert out.dtype == np.float16
assert expected.dtype == np.float16
np.testing.assert_allclose(out.numpy(), expected.numpy())


def test_conv2d_zero_stride_numpy_array():
inp = np.random.randn(3, 224, 224).astype(np.float32)
inp = inp[np.newaxis, :]
@@ -787,8 +809,8 @@ def test_conv3d_zero_stride_numpy_array():


def test_conv1d():
inp = tensor(np.ones((16,), dtype=np.float32).reshape(2, 2, 4))
weight = tensor(np.ones((12,), dtype=np.float32).reshape(3, 2, 2))
inp = tensor(np.ones((2, 2, 4), dtype=np.float32))
weight = tensor(np.ones((3, 2, 2), dtype=np.float32))
out = F.conv1d(inp, weight, None, 2, 0, 1, 1)
np.testing.assert_equal(
out.numpy(),
@@ -798,9 +820,31 @@ def test_conv1d():
)


def test_batchnorm2d_io16c32():
amp.enabled = True
inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
weight = tensor(np.ones((1, 3, 1, 1)), dtype=np.float32)
bias = tensor(np.zeros((1, 3, 1, 1)), dtype=np.float32)

out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)

amp.enabled = False
expected = F.batch_norm(
inp.astype("float16"),
weight=weight,
bias=bias,
training=True,
inplace=False,
compute_mode="float32",
)
assert out.dtype == np.float16
assert expected.dtype == np.float16
np.testing.assert_allclose(out.numpy(), expected.numpy())


def test_conv3d():
inp = tensor(np.ones((256,), dtype=np.float32).reshape(2, 2, 4, 4, 4))
weight = tensor(np.ones((48,), dtype=np.float32).reshape(3, 2, 2, 2, 2))
inp = tensor(np.ones((2, 2, 4, 4, 4), dtype=np.float32))
weight = tensor(np.ones((3, 2, 2, 2, 2), dtype=np.float32))
out = F.conv3d(inp, weight, None, 2, 0, 1, 1)
print(out.numpy().shape)
np.testing.assert_equal(


+0 -33 imperative/python/test/unit/module/test_module.py

@@ -473,39 +473,6 @@ def test_pickle_module():
np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)


def test_load_quantized():
from megengine.core.tensor import dtype

data_shape = (2, 28)
data = tensor(np.random.random(data_shape), dtype="float32")
data = data.astype(dtype.qint8(0.1))
mlp = MLP()
quantize_qat(mlp)
quantize(mlp)
mlp.dense0.weight = Parameter(mlp.dense0.weight.astype(dtype.qint8(0.001)).numpy())
mlp.dense1.weight = Parameter(mlp.dense1.weight.astype(dtype.qint8(0.0002)).numpy())
mlp.eval()
pred0 = mlp(data)

with BytesIO() as fout:
mge.save(mlp.state_dict(), fout)
fout.seek(0)
checkpoint = mge.load(fout)
# change mlp weight.
mlp.dense0.weight = Parameter(
mlp.dense0.weight.astype(dtype.qint8(0.00001)).numpy()
)
mlp.dense1.weight = Parameter(
mlp.dense1.weight.astype(dtype.qint8(0.2)).numpy()
)
mlp.load_state_dict(checkpoint)
pred1 = mlp(data)

np.testing.assert_allclose(
pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), atol=5e-6
)


def test_repr_basic():
# test whether __repr__ can output correct information
class ConvModel(Module):

