diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py
index cf3c9883..ef453fc7 100644
--- a/imperative/python/megengine/__init__.py
+++ b/imperative/python/megengine/__init__.py
@@ -117,6 +117,7 @@ def _atexit(handler):
 
 
 # subpackages
+import megengine.amp
 import megengine.autodiff
 import megengine.data
 import megengine.distributed
diff --git a/imperative/python/megengine/amp/__init__.py b/imperative/python/megengine/amp/__init__.py
new file mode 100644
index 00000000..29be0ddd
--- /dev/null
+++ b/imperative/python/megengine/amp/__init__.py
@@ -0,0 +1,14 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+import mprop
+
+from ..core.tensor.amp import *
+from .autocast import autocast
+
+mprop.init()
diff --git a/imperative/python/megengine/amp/autocast.py b/imperative/python/megengine/amp/autocast.py
new file mode 100644
index 00000000..03163034
--- /dev/null
+++ b/imperative/python/megengine/amp/autocast.py
@@ -0,0 +1,79 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import functools
+
+from ..core.tensor import amp
+
+
+class autocast:
+    r"""
+    A class to control autocast mode for amp as a context manager or a decorator.
+
+    :param enabled: Whether autocast mode is enabled.
+    :param low_prec_dtype: Set amp autocast mode's lower precision dtype. It will
+        change the target dtype in tensor casting for better speed and memory. Default: float16.
+    :param high_prec_dtype: Set amp autocast mode's higher precision dtype. It will
+        change the target dtype in tensor casting for better precision. Default: float32.
+
+    Examples:
+
+    .. code-block::
+
+       # used as decorator
+       @autocast()
+       def train_step(image, label):
+           with gm:
+               logits = model(image)
+               loss = F.nn.cross_entropy(logits, label)
+               gm.backward(loss)
+           opt.step().clear_grad()
+           return loss
+
+       # used as context manager
+       def train_step(image, label):
+           with autocast():
+               with gm:
+                   logits = model(image)
+                   loss = F.nn.cross_entropy(logits, label)
+                   gm.backward(loss)
+               opt.step().clear_grad()
+           return loss
+    """
+
+    def __init__(
+        self,
+        enabled: bool = True,
+        low_prec_dtype: str = "float16",
+        high_prec_dtype: str = "float32",
+    ):
+        self.enabled = enabled
+        self.high_prec_dtype = high_prec_dtype
+        self.low_prec_dtype = low_prec_dtype
+        self._origin_enabled = None
+        self._origin_high = None
+        self._origin_low = None
+
+    def __enter__(self):
+        self._origin_enabled, amp._enabled = amp._enabled, self.enabled
+        self._origin_high = amp._high_prec_dtype
+        amp._high_prec_dtype = self.high_prec_dtype
+        self._origin_low = amp._low_prec_dtype
+        amp._low_prec_dtype = self.low_prec_dtype
+
+    def __exit__(self, *args):
+        amp._enabled = self._origin_enabled
+        amp._high_prec_dtype = self._origin_high
+        amp._low_prec_dtype = self._origin_low
+
+    def __call__(self, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            with self:
+                return func(*args, **kwargs)
+
+        return wrapper
diff --git a/imperative/python/megengine/core/_wrap.py b/imperative/python/megengine/core/_wrap.py
index 7f57120c..f1650509 100644
--- a/imperative/python/megengine/core/_wrap.py
+++ b/imperative/python/megengine/core/_wrap.py
@@ -49,7 +49,7 @@ class Device:
         return self._cn == rhs._cn
 
 
-def device(obj):
+def as_device(obj):
     if isinstance(obj, Device):
         return obj
     return Device(obj)
diff --git a/imperative/python/megengine/core/tensor/amp.py b/imperative/python/megengine/core/tensor/amp.py
new file mode 100644
index 00000000..4f8e49a3
--- /dev/null
+++ b/imperative/python/megengine/core/tensor/amp.py
@@ -0,0 +1,78 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+_enabled = False
+_high_prec_dtype = "float32"
+_low_prec_dtype = "float16"
+
+
+@property
+def enabled(mod):
+    r"""
+    Get or set whether amp autocast mode is enabled.
+
+    Examples:
+
+    .. code-block::
+
+       import megengine as mge
+       mge.amp.enabled = True
+
+    """
+    return _enabled
+
+
+@enabled.setter
+def enabled(mod, enabled: bool):
+    global _enabled
+    _enabled = enabled
+
+
+@property
+def high_prec_dtype(mod):
+    r"""
+    Get or set amp autocast mode's higher precision dtype. It will change the
+    target dtype in tensor casting for better precision. Default: float32.
+
+    Examples:
+
+    .. code-block::
+
+       import megengine as mge
+       mge.amp.high_prec_dtype = "float32"
+
+    """
+    return _high_prec_dtype
+
+
+@high_prec_dtype.setter
+def high_prec_dtype(mod, dtype: str):
+    global _high_prec_dtype
+    _high_prec_dtype = dtype
+
+
+@property
+def low_prec_dtype(mod):
+    r"""
+    Get or set amp autocast mode's lower precision dtype. It will change the
+    target dtype in tensor casting for better speed and memory. Default: float16.
+
+    Examples:
+
+    .. code-block::
+
+       import megengine as mge
+       mge.amp.low_prec_dtype = "float16"
+
+    """
+    return _low_prec_dtype
+
+
+@low_prec_dtype.setter
+def low_prec_dtype(mod, dtype: str):
+    global _low_prec_dtype
+    _low_prec_dtype = dtype
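Editorial aside, not part of the patch: a minimal usage sketch of the API added above, assuming the diff is applied. mprop.init() in megengine/amp/__init__.py promotes the @property functions of core/tensor/amp.py to module-level properties, so the flags can be flipped by plain attribute assignment, while autocast scopes the same state and restores it on exit.

    import megengine as mge

    # module-level properties, courtesy of mprop
    mge.amp.enabled = True
    assert mge.amp.low_prec_dtype == "float16"
    mge.amp.enabled = False

    # scoped form: __enter__ saves the flags, __exit__ restores them
    with mge.amp.autocast(low_prec_dtype="float16", high_prec_dtype="float32"):
        assert mge.amp.enabled
    assert not mge.amp.enabled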
diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py
index 496d71fd..aed4cb16 100644
--- a/imperative/python/megengine/core/tensor/array_method.py
+++ b/imperative/python/megengine/core/tensor/array_method.py
@@ -15,15 +15,20 @@ import numpy as np
 from .._imperative_rt.common import CompNode
 from .._imperative_rt.core2 import SymbolVar, Tensor, apply
 from ..ops import builtin
-from ..ops.builtin import Elemwise, GetVarShape
-from . import utils
-from .indexing import getitem as _getitem
-from .indexing import setitem as _setitem
-from .utils import isscalar
-from .utils import make_shape_tuple as _make_shape_tuple
-from .utils import setscalar
-
-_ElwMod = Elemwise.Mode
+from . import amp
+from .indexing import getitem, setitem
+from .utils import (
+    _normalize_axis,
+    astensor1d,
+    astype,
+    cast_tensors,
+    convert_inputs,
+    isscalar,
+    make_shape_tuple,
+    setscalar,
+)
+
+_ElwMod = builtin.Elemwise.Mode
 
 
 def _elwise_apply(args, mode):
@@ -40,47 +45,59 @@ def _elwise_apply(args, mode):
 
 
 def _elwise(*args, mode):
+    args = convert_inputs(*args)
     if mode in (
         _ElwMod.TRUE_DIV,
+        _ElwMod.EXP,
         _ElwMod.POW,
-        _ElwMod.CEIL,
-        _ElwMod.FLOOR,
-        _ElwMod.ROUND,
+        _ElwMod.LOG,
+        _ElwMod.EXPM1,
+        _ElwMod.LOG1P,
+        _ElwMod.TANH,
+        _ElwMod.ACOS,
+        _ElwMod.ASIN,
+        _ElwMod.ATAN2,
+        _ElwMod.COS,
+        _ElwMod.H_SWISH,
+        _ElwMod.SIGMOID,
+        _ElwMod.SIN,
+    ) and (
+        amp._enabled or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
     ):
-        if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND) and np.issubdtype(
-            args[0].dtype, np.integer
-        ):
-            return args[0]
-        args = tuple(
-            map(
-                lambda x: x.astype("float32")
-                if hasattr(x, "dtype") and x.dtype != np.float32
-                else x,
-                args,
-            )
-        )
-    args = utils.convert_inputs(*args)
+        # autocast to FP32 to maintain precision,
+        # or to avoid ops that do not support int arguments
+        args = cast_tensors(*args, promote=True)
+
+    if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND,) and np.issubdtype(
+        args[0].dtype, np.integer
+    ):
+        return args[0]
     return _elwise_apply(args, mode)
 
 
 def _matmul(inp1, inp2):
+    if amp._enabled:
+        compute_mode = "float32"
+        inp1, inp2 = cast_tensors(inp1, inp2)
+    else:
+        compute_mode = "default"
+        inp1, inp2 = convert_inputs(inp1, inp2)
     op = builtin.MatrixMul(
-        transposeA=False, transposeB=False, compute_mode="default", format="default"
+        transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
     )
-    inp1, inp2 = utils.convert_inputs(inp1, inp2)
     (result,) = apply(op, inp1, inp2)
     return result
 
 
 def _transpose(data, axes):
     op = builtin.Dimshuffle(axes)
-    (data,) = utils.convert_inputs(data)
+    (data,) = convert_inputs(data)
     (result,) = apply(op, data)
     return result
 
 
 def _broadcast(inp, shape):
-    shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device)
+    shape = astensor1d(shape, inp, dtype="int32", device=inp.device)
     (result,) = apply(builtin.Broadcast(), inp, shape)
     return result
 
@@ -88,7 +105,7 @@ def _broadcast(inp, shape):
 def _reshape(x, shape):
     unspec_axis = None
     try:
-        shape_tuple = _make_shape_tuple(shape)
+        shape_tuple = make_shape_tuple(shape)
     except ValueError:
         pass
     else:
@@ -102,7 +119,7 @@ def _reshape(x, shape):
                     "multiple -1 in shape: {} & {}".format(unspec_axis, i)
                 )
                 unspec_axis = i
-    shape = utils.astensor1d(shape, x, dtype="int32", device=x.device)
+    shape = astensor1d(shape, x, dtype="int32", device=x.device)
     if unspec_axis is None:
         op = builtin.Reshape()
     else:
@@ -171,7 +188,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
             return list(map(int, axis))
 
     axis = get_axes()
-    axis = utils._normalize_axis(inp.ndim, axis)
+    axis = _normalize_axis(inp.ndim, axis)
    axis = [a - i for i, a in enumerate(axis)]
     op = builtin.RemoveAxis(axis=axis)
 
@@ -184,7 +201,7 @@
 def _reduce(mode):
     def f(self, axis=None, keepdims: bool = False):
         data = self
-        (data,) = utils.convert_inputs(data)
+        (data,) = convert_inputs(data)
         if mode == "mean":
             data = data.astype("float32")
         elif self.dtype == np.bool_:
@@ -196,7 +213,7 @@ def _reduce(mode):
             op = builtin.Reduce(mode=mode, axis=0)
             (result,) = apply(op, data)
         elif isinstance(axis, collections.abc.Iterable):
-            axis = utils._normalize_axis(self.ndim, axis, reverse=True)
+            axis = _normalize_axis(self.ndim, axis, reverse=True)
             for ai in axis:
                 op = builtin.Reduce(mode=mode, axis=ai)
                 (data,) = apply(op, data)
@@ -359,11 +376,11 @@ class ArrayMethodMixin(abc.ABC):
             yield self[i]
 
     def __getitem__(self, index):
-        return _getitem(self, index)
+        return getitem(self, index)
 
     def __setitem__(self, index, value):
         if index is not Ellipsis:
-            value = _setitem(self, index, value)
+            value = setitem(self, index, value)
             self._reset(value)
 
     __contains__ = _todo
@@ -422,7 +439,7 @@ class ArrayMethodMixin(abc.ABC):
         Returns a :class:`Tensor` with the same data and number of elements
         with the specified :attr:`~.Tensor.dtype`.
         """
-        return utils.astype(self, dtype)
+        return astype(self, dtype)
 
     def reshape(self, *args):
         r"""
diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py
index ea86c46e..02efb4b4 100644
--- a/imperative/python/megengine/core/tensor/megbrain_graph.py
+++ b/imperative/python/megengine/core/tensor/megbrain_graph.py
@@ -18,7 +18,7 @@ import numpy as np
 from .. import _imperative_rt
 from .._imperative_rt import GraphOptimizeOptions
 from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
-from .._wrap import device as as_device
+from .._wrap import as_device
 from ..ops.builtin import OpDef
 from .core import TensorBase
 
diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py
index b52b01f0..201366a1 100644
--- a/imperative/python/megengine/core/tensor/utils.py
+++ b/imperative/python/megengine/core/tensor/utils.py
@@ -13,9 +13,10 @@ import numpy as np
 
 from .._imperative_rt import make_const
 from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device
-from .._wrap import device as as_device
+from .._wrap import as_device
 from ..ops import builtin
 from ..ops.special import Const
+from . import amp
 from .dtype import is_dtype_equal, is_quantize
 
 _enable_convert_inputs = True
@@ -98,6 +99,14 @@ def convert_inputs(*args, device=None):
     return tuple(map(convert, args))
 
 
+def cast_tensors(*args, promote=False):
+    # read the dtypes through the module so later changes made by
+    # autocast/setters are picked up, not the import-time values
+    if promote:
+        dtype = amp._high_prec_dtype
+    else:
+        dtype = amp._low_prec_dtype
+    return tuple(arg.astype(dtype) if arg is not None else None for arg in args)
+
+
 def result_type(*args):
     dtypes = []
     for i in args:
diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py
index 4a8c4d7b..a89ce790 100644
--- a/imperative/python/megengine/functional/elemwise.py
+++ b/imperative/python/megengine/functional/elemwise.py
@@ -12,10 +12,8 @@ import numpy as np
 
 from ..core._imperative_rt.core2 import SymbolVar, apply
 from ..core.ops import builtin
 from ..core.ops.builtin import Elemwise
-from ..core.tensor import utils
-from ..core.tensor.array_method import _elwise_apply
-from ..core.tensor.utils import astype
-from ..device import get_default_device
+from ..core.tensor.array_method import _elwise
+from ..core.tensor.utils import astype, convert_inputs
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func
 
@@ -69,46 +67,9 @@ __all__ = [
 ]
 
 
-def _elwise(*args, mode):
-    tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args))
-    if len(tensor_args) == 0:
-        dtype = utils.dtype_promotion(args)
-        first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
-        args = utils.convert_inputs(first_arg, *args[1:])
-    else:
-        args = utils.convert_inputs(*args)
-    if mode in (
-        Elemwise.Mode.TRUE_DIV,
-        Elemwise.Mode.EXP,
-        Elemwise.Mode.POW,
-        Elemwise.Mode.LOG,
-        Elemwise.Mode.EXPM1,
-        Elemwise.Mode.LOG1P,
-        Elemwise.Mode.TANH,
-        Elemwise.Mode.ACOS,
-        Elemwise.Mode.ASIN,
-        Elemwise.Mode.ATAN2,
-        Elemwise.Mode.CEIL,
-        Elemwise.Mode.COS,
-        Elemwise.Mode.FLOOR,
-        Elemwise.Mode.H_SWISH,
-        Elemwise.Mode.ROUND,
-        Elemwise.Mode.SIGMOID,
-        Elemwise.Mode.SIN,
-    ):
-        if mode in (
-            Elemwise.Mode.CEIL,
-            Elemwise.Mode.FLOOR,
-            Elemwise.Mode.ROUND,
-        ) and np.issubdtype(args[0].dtype, np.integer):
-            return args[0]
-        args = tuple(map(lambda x: astype(x, "float32"), args))
-    return _elwise_apply(args, mode)
-
-
 def _elemwise_multi_type(*args, mode, **kwargs):
     op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
-    args = utils.convert_inputs(*args)
+    args = convert_inputs(*args)
     (result,) = apply(op, *args)
     return result
 
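For reference, the observable behavior of the cast_tensors helper added to core/tensor/utils.py above (an illustrative sketch under the default dtype settings, not part of the patch): every non-None argument is cast to the low-precision dtype, or to the high-precision dtype when promote=True, and None passes through so optional tensors such as bias need no special-casing.

    import numpy as np

    from megengine import tensor
    from megengine.core.tensor.utils import cast_tensors

    a = tensor(np.ones((2, 2), dtype=np.float32))
    low, none = cast_tensors(a, None)        # float16 tensor, None untouched
    (high,) = cast_tensors(a, promote=True)  # float32 tensor
    assert low.dtype == np.float16 and none is None and high.dtype == np.float32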
diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py
index fbc021a6..9c796296 100644
--- a/imperative/python/megengine/functional/math.py
+++ b/imperative/python/megengine/functional/math.py
@@ -14,7 +14,8 @@ from ..core._imperative_rt.core2 import apply
 from ..core._trace_option import use_symbolic_shape
 from ..core.ops import builtin
 from ..core.ops.special import Const
-from ..core.tensor import utils
+from ..core.tensor import amp
+from ..core.tensor.utils import _normalize_axis, cast_tensors, convert_inputs, setscalar
 from ..tensor import Tensor
 from .debug_param import get_execution_strategy
 from .elemwise import clip, exp, log, log1p
@@ -471,7 +472,7 @@ def argmin(
         inp = inp.flatten()
         axis = 0
 
-    axis = utils._normalize_axis(inp.ndim, axis, reverse=True)
+    axis = _normalize_axis(inp.ndim, axis, reverse=True)
     if isinstance(axis, collections.abc.Iterable):
 
         for ai in axis:
@@ -528,7 +529,7 @@ def argmax(
         assert not keepdims, "can not set axis=None and keepdims=True"
         inp = inp.flatten()
         axis = 0
-    axis = utils._normalize_axis(inp.ndim, axis, reverse=True)
+    axis = _normalize_axis(inp.ndim, axis, reverse=True)
 
     if isinstance(axis, collections.abc.Iterable):
 
@@ -807,8 +808,13 @@ def matmul(
         [28. 40.]]
 
     """
+    if amp._enabled:
+        compute_mode = "float32"
+        inp1, inp2 = cast_tensors(inp1, inp2)
+    else:
+        inp1, inp2 = convert_inputs(inp1, inp2)
+
     remove_row, remove_col = False, False
-    inp1, inp2 = utils.convert_inputs(inp1, inp2)
 
     dim1, dim2 = inp1.ndim, inp2.ndim
     # handle dim=1 cases, dot and matrix-vector multiplication
@@ -921,12 +927,12 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
 
     """
     op = builtin.Dot()
-    inp1, inp2 = utils.convert_inputs(inp1, inp2)
+    inp1, inp2 = convert_inputs(inp1, inp2)
     assert (
         inp1.ndim <= 1 and inp2.ndim <= 1
     ), "Input tensors for dot must be 1-dimensional or scalar"
     (result,) = apply(op, inp1, inp2)
-    utils.setscalar(result)
+    setscalar(result)
     return result
 
 
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index ffe7b7c7..5dab1d87 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -15,9 +15,16 @@ from ..core._trace_option import use_symbolic_shape
 from ..core.ops import builtin
 from ..core.ops.builtin import BatchNorm, Elemwise
 from ..core.ops.special import Const
-from ..core.tensor import megbrain_graph, utils
+from ..core.tensor import amp, megbrain_graph
 from ..core.tensor.array_method import _elwise_apply
-from ..core.tensor.utils import astensor1d, astype, setscalar
+from ..core.tensor.utils import (
+    astensor1d,
+    astype,
+    cast_tensors,
+    convert_inputs,
+    convert_single_value,
+    setscalar,
+)
 from ..device import get_default_device
 from ..distributed import WORLD, is_distributed
 from ..random import uniform
@@ -91,7 +98,9 @@ def expand_hw(x):
     return int(h), int(w)
 
 
-def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
+def linear(
+    inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None, compute_mode="default",
+) -> Tensor:
     """
     Applies a linear transformation to the input tensor.
 
@@ -102,8 +111,10 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
     :param bias: bias with shape `(out_features,)`. Default: None
     """
-    ret = matmul(inp, weight, transpose_b=True)
+    ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
     if bias is not None:
+        if amp._enabled:
+            bias = bias.astype("float16")
         ret += bias
     return ret
 
 
@@ -153,6 +164,11 @@ def conv1d(
     assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
     assert inp.ndim == 3, "the input dimension of conv1d should be 3"
     assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
+    if amp._enabled:
+        compute_mode = "float32"
+        inp, weight, bias = cast_tensors(inp, weight, bias)
+    else:
+        inp, weight = convert_inputs(inp, weight)
 
     inp = expand_dims(inp, 3)
     weight = expand_dims(weight, 3)
@@ -177,7 +193,6 @@ def conv1d(
         compute_mode=compute_mode,
         sparse=sparse_type,
     )
-    inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias
@@ -228,7 +243,11 @@ def conv2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
+    if amp._enabled:
+        compute_mode = "float32"
+        inp, weight, bias = cast_tensors(inp, weight, bias)
+    else:
+        inp, weight = convert_inputs(inp, weight)
 
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -247,7 +266,6 @@ def conv2d(
         compute_mode=compute_mode,
         sparse=sparse_type,
     )
-    inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias
@@ -286,6 +304,7 @@ def conv3d(
     :return: output tensor.
     """
     assert conv_mode.lower() == "cross_correlation"
+    inp, weight = convert_inputs(inp, weight)
 
     D, H, W = 0, 1, 2
 
@@ -308,7 +327,6 @@ def conv3d(
         mode=conv_mode,
         sparse=sparse_type,
     )
-    inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias
@@ -358,7 +376,11 @@ def conv_transpose2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
+    if amp._enabled:
+        compute_mode = "float32"
+        inp, weight, bias = cast_tensors(inp, weight, bias)
+    else:
+        inp, weight = convert_inputs(inp, weight)
 
     if groups != 1:
         raise NotImplementedError("group transposed conv2d is not supported yet.")
@@ -375,8 +397,8 @@ def conv_transpose2d(
         dilate_h=dilate_h,
         dilate_w=dilate_w,
         strategy=get_execution_strategy(),
+        compute_mode=compute_mode,
     )
-    weight, inp = utils.convert_inputs(weight, inp)
     (output,) = apply(op, weight, inp)
     if bias is not None:
         output += bias
@@ -428,7 +450,11 @@ def deformable_conv2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
+    if amp._enabled:
+        compute_mode = "float32"
+        inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
+    else:
+        inp, weight, offset, mask = convert_inputs(inp, weight, offset, mask)
 
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -447,7 +473,6 @@ def deformable_conv2d(
         compute_mode=compute_mode,
         sparse=sparse_type,
     )
-    inp, weight, offset, mask = utils.convert_inputs(inp, weight, offset, mask)
     (output,) = apply(op, inp, weight, offset, mask)
     if bias is not None:
         output += bias
@@ -468,6 +493,7 @@ def local_conv2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
    )
+    inp, weight = convert_inputs(inp, weight)
 
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -481,10 +507,8 @@ def local_conv2d(
         dilate_h=dilate_h,
         dilate_w=dilate_w,
         mode=conv_mode,
-        compute_mode="default",
         sparse="dense",
     )
-    inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias
@@ -515,8 +539,9 @@ def conv_transpose3d(
     :param dilation: dilation of the 3D convolution operation. Default: 1
     :return: output tensor.
     """
-    D, H, W = 0, 1, 2
+    inp, weight = convert_inputs(inp, weight)
 
+    D, H, W = 0, 1, 2
     pad = _triple(padding)
     stride = _triple_nonzero(stride)
     dilate = _triple_nonzero(dilation)
@@ -533,7 +558,6 @@ def conv_transpose3d(
         dilate_w=dilate[W],
         strategy=get_execution_strategy(),
     )
-    weight, inp = utils.convert_inputs(weight, inp)
     (output,) = apply(op, weight, inp)
     if bias is not None:
         output += bias
@@ -994,7 +1018,8 @@ def batch_norm(
     training: bool = False,
     momentum: float = 0.9,
     eps: float = 1e-5,
-    inplace: bool = True
+    inplace: bool = True,
+    compute_mode="default",
 ):
     r"""
     Applies batch normalization to the input.
@@ -1027,15 +1052,11 @@ def batch_norm(
     def make_full_if_none(x, value):
         if x is None:
             (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
-            shape = utils.astensor1d(
-                (1, C, 1, 1), inp, dtype="int32", device=inp.device
-            )
+            shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Broadcast(), x, shape)
             return result
         elif x.ndim == 1:
-            shape = utils.astensor1d(
-                (1, C, 1, 1), inp, dtype="int32", device=inp.device
-            )
+            shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Reshape(), x, shape)
             return result
         return x
@@ -1052,10 +1073,15 @@ def batch_norm(
     if has_var and running_var.ndim != 4:
         raise ValueError
 
-    inp, weight, bias, running_mean, running_var = utils.convert_inputs(
-        inp, weight, bias, running_mean, running_var
-    )
-
+    if amp._enabled:
+        inp = inp.astype("float16")
+        weight, bias, running_mean, running_var = cast_tensors(
+            weight, bias, running_mean, running_var, promote=True
+        )
+    elif compute_mode != "float32":
+        inp, weight, bias, running_mean, running_var = convert_inputs(
+            inp, weight, bias, running_mean, running_var
+        )
     weight = make_full_if_none(weight, 1)
     bias = make_full_if_none(bias, 0)
 
@@ -1352,7 +1378,7 @@ def indexing_one_hot(
     """
     assert isinstance(src, Tensor), "src must be of Tensor type"
     op = builtin.IndexingOneHot(axis=axis)
-    index = utils.convert_single_value(index, dtype="int32", device=src.device)
+    index = convert_single_value(index, dtype="int32", device=src.device)
     (result,) = apply(op, src, index)
     if not keepdims:
         result = squeeze(result, axis)
diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py
index 269f206b..cab093a9 100644
--- a/imperative/python/megengine/functional/tensor.py
+++ b/imperative/python/megengine/functional/tensor.py
@@ -13,7 +13,7 @@ import numpy as np
 
 from ..core._imperative_rt import CompNode
 from ..core._imperative_rt.core2 import SymbolVar, apply
-from ..core._wrap import device as as_device
+from ..core._wrap import as_device
 from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
 from ..core.ops.special import Const
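The conv and batch-norm changes above implement an io16c32 scheme: with amp enabled, the inputs are cast to float16 while the op accumulates with compute_mode="float32". A behavioral sketch, not part of the patch, mirroring test_conv2d_io16c32 later in this diff:

    import numpy as np

    import megengine.amp as amp
    import megengine.functional as F
    from megengine import tensor

    inp = tensor(np.random.randn(1, 3, 32, 32), dtype=np.float32)
    weight = tensor(np.random.randn(8, 3, 3, 3), dtype=np.float32)

    amp.enabled = True
    out = F.conv2d(inp, weight)  # inputs cast to float16, float32 accumulation
    amp.enabled = False

    assert out.dtype == np.float16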
diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index 28bfc7e1..43310e6a 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -33,7 +33,7 @@ from ..core._imperative_rt.ops import (
     RemoteSend,
 )
 from ..core._trace_option import set_symbolic_shape
-from ..core._wrap import device as as_device
+from ..core._wrap import as_device
 from ..core.ops.builtin import BatchNorm, OpDef
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G
diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py
index 7428d0c5..ad188d1a 100644
--- a/imperative/python/megengine/module/batchnorm.py
+++ b/imperative/python/megengine/module/batchnorm.py
@@ -26,6 +26,7 @@ class _BatchNorm(Module):
         affine=True,
         track_running_stats=True,
         freeze=False,
+        compute_mode="default",
         **kwargs
     ):
         super(_BatchNorm, self).__init__(**kwargs)
@@ -36,6 +37,7 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
+        self.compute_mode = compute_mode
         if self.freeze:
             assert (
                 self._track_running_stats_saved
@@ -123,6 +125,7 @@ class _BatchNorm(Module):
             or ((self.running_mean is None) and (self.running_var is None)),
             momentum=exponential_average_factor,
             eps=self.eps,
+            compute_mode=self.compute_mode,
         )
 
         if _ndims != 4:
diff --git a/imperative/python/megengine/module/linear.py b/imperative/python/megengine/module/linear.py
index a8544547..136040d8 100644
--- a/imperative/python/megengine/module/linear.py
+++ b/imperative/python/megengine/module/linear.py
@@ -51,7 +51,12 @@ class Linear(Module):
     """
 
     def __init__(
-        self, in_features: int, out_features: int, bias: bool = True, **kwargs
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        compute_mode: str = "default",
+        **kwargs
     ):
         super().__init__(**kwargs)
         self.out_features = out_features
@@ -62,6 +67,7 @@ class Linear(Module):
         if bias:
             b_shape = (out_features,)
             self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
+        self.compute_mode = compute_mode
         self.reset_parameters()
 
     def _get_fanin(self):
@@ -75,7 +81,7 @@ class Linear(Module):
             init.zeros_(self.bias)
 
     def _calc_linear(self, x, weight, bias):
-        return linear(x, weight, bias)
+        return linear(x, weight, bias, compute_mode=self.compute_mode)
 
     def forward(self, x):
         return self._calc_linear(x, self.weight, self.bias)
diff --git a/imperative/python/megengine/quantization/internal_fake_quant.py b/imperative/python/megengine/quantization/internal_fake_quant.py
index 040a70af..7972c622 100644
--- a/imperative/python/megengine/quantization/internal_fake_quant.py
+++ b/imperative/python/megengine/quantization/internal_fake_quant.py
@@ -5,8 +5,6 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import copy
-import math
 from functools import partial
 
 from .. import functional as F
diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py
index cc17d924..8b09a1d1 100644
--- a/imperative/python/megengine/tensor.py
+++ b/imperative/python/megengine/tensor.py
@@ -14,7 +14,7 @@ from .core._imperative_rt import CompNode
 from .core._imperative_rt.core2 import Tensor as _Tensor
 from .core._imperative_rt.core2 import apply
 from .core._trace_option import use_symbolic_shape
-from .core._wrap import device as as_device
+from .core._wrap import as_device
 from .core.ops.builtin import Copy, GetVarShape
 from .core.tensor.array_method import ArrayMethodMixin
 from .device import _valid_device, get_default_device
diff --git a/imperative/python/test/unit/amp/test_autocast.py b/imperative/python/test/unit/amp/test_autocast.py
new file mode 100644
index 00000000..0aff6aa0
--- /dev/null
+++ b/imperative/python/test/unit/amp/test_autocast.py
@@ -0,0 +1,34 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from megengine import amp
+from megengine.core.tensor import amp as origin_amp
+
+
+def test_autocast():
+    def check(enabled, low, high):
+        assert amp.enabled == enabled
+        assert origin_amp._enabled == enabled
+        assert amp.low_prec_dtype == low
+        assert origin_amp._low_prec_dtype == low
+        assert amp.high_prec_dtype == high
+        assert origin_amp._high_prec_dtype == high
+
+    origin_enabled = amp.enabled
+    origin_high = amp.high_prec_dtype
+    origin_low = amp.low_prec_dtype
+    with amp.autocast(low_prec_dtype="low", high_prec_dtype="high"):
+        check(True, "low", "high")
+    check(origin_enabled, origin_low, origin_high)
+    amp.enabled = True
+    amp.high_prec_dtype = "high"
+    amp.low_prec_dtype = "low"
+    check(True, "low", "high")
+    amp.enabled = origin_enabled
+    amp.high_prec_dtype = origin_high
+    amp.low_prec_dtype = origin_low
+    check(origin_enabled, origin_low, origin_high)
diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py
index 61f2adc2..977feb6b 100644
--- a/imperative/python/test/unit/functional/test_functional.py
+++ b/imperative/python/test/unit/functional/test_functional.py
@@ -14,6 +14,7 @@ import numpy as np
 import pytest
 from utils import opr_test
 
+import megengine.amp as amp
 import megengine.core.ops.builtin as builtin
 import megengine.core.tensor.dtype as dtype
 import megengine.functional as F
@@ -767,6 +768,27 @@ def test_batch_conv_bias():
     run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)
 
 
+def test_conv2d_io16c32():
+    amp.enabled = True
+    inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
+    weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float32)
+    out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
+    amp.enabled = False
+    expected = F.conv2d(
+        inp.astype("float16"),
+        weight.astype("float16"),
+        None,
+        (2, 2),
+        (3, 3),
+        (1, 1),
+        1,
+        compute_mode="float32",
+    )
+    assert out.dtype == np.float16
+    assert expected.dtype == np.float16
+    np.testing.assert_allclose(out.numpy(), expected.numpy())
+
+
 def test_conv2d_zero_stride_numpy_array():
     inp = np.random.randn(3, 224, 224).astype(np.float32)
     inp = inp[np.newaxis, :]
@@ -787,8 +809,8 @@ def test_conv3d_zero_stride_numpy_array():
 
 
 def test_conv1d():
-    inp = tensor(np.ones((16,), dtype=np.float32).reshape(2, 2, 4))
-    weight = tensor(np.ones((12,), dtype=np.float32).reshape(3, 2, 2))
+    inp = tensor(np.ones((2, 2, 4), dtype=np.float32))
+    weight = tensor(np.ones((3, 2, 2), dtype=np.float32))
     out = F.conv1d(inp, weight, None, 2, 0, 1, 1)
     np.testing.assert_equal(
         out.numpy(),
@@ -798,9 +820,31 @@ def test_conv1d():
     )
 
 
+def test_batchnorm2d_io16c32():
+    amp.enabled = True
+    inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
+    weight = tensor(np.ones((1, 3, 1, 1)), dtype=np.float32)
+    bias = tensor(np.zeros((1, 3, 1, 1)), dtype=np.float32)
+
+    out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)
+
+    amp.enabled = False
+    expected = F.batch_norm(
+        inp.astype("float16"),
+        weight=weight,
+        bias=bias,
+        training=True,
+        inplace=False,
+        compute_mode="float32",
+    )
+    assert out.dtype == np.float16
+    assert expected.dtype == np.float16
+    np.testing.assert_allclose(out.numpy(), expected.numpy())
+
+
 def test_conv3d():
-    inp = tensor(np.ones((256,), dtype=np.float32).reshape(2, 2, 4, 4, 4))
-    weight = tensor(np.ones((48,), dtype=np.float32).reshape(3, 2, 2, 2, 2))
+    inp = tensor(np.ones((2, 2, 4, 4, 4), dtype=np.float32))
+    weight = tensor(np.ones((3, 2, 2, 2, 2), dtype=np.float32))
     out = F.conv3d(inp, weight, None, 2, 0, 1, 1)
     print(out.numpy().shape)
     np.testing.assert_equal(
diff --git a/imperative/python/test/unit/module/test_module.py b/imperative/python/test/unit/module/test_module.py
index 01e35471..0ff93a36 100644
--- a/imperative/python/test/unit/module/test_module.py
+++ b/imperative/python/test/unit/module/test_module.py
@@ -473,39 +473,6 @@ def test_pickle_module():
     np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)
 
 
-def test_load_quantized():
-    from megengine.core.tensor import dtype
-
-    data_shape = (2, 28)
-    data = tensor(np.random.random(data_shape), dtype="float32")
-    data = data.astype(dtype.qint8(0.1))
-    mlp = MLP()
-    quantize_qat(mlp)
-    quantize(mlp)
-    mlp.dense0.weight = Parameter(mlp.dense0.weight.astype(dtype.qint8(0.001)).numpy())
-    mlp.dense1.weight = Parameter(mlp.dense1.weight.astype(dtype.qint8(0.0002)).numpy())
-    mlp.eval()
-    pred0 = mlp(data)
-
-    with BytesIO() as fout:
-        mge.save(mlp.state_dict(), fout)
-        fout.seek(0)
-        checkpoint = mge.load(fout)
-        # change mlp weight.
-        mlp.dense0.weight = Parameter(
-            mlp.dense0.weight.astype(dtype.qint8(0.00001)).numpy()
-        )
-        mlp.dense1.weight = Parameter(
-            mlp.dense1.weight.astype(dtype.qint8(0.2)).numpy()
-        )
-        mlp.load_state_dict(checkpoint)
-        pred1 = mlp(data)
-
-    np.testing.assert_allclose(
-        pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), atol=5e-6
-    )
-
-
 def test_repr_basic():
     # test whether __repr__ can output correct information
     class ConvModel(Module):