
feat(mge/amp): add mixed precision autocast support

GitOrigin-RevId: 6fbffc4845
release-1.5
Megvii Engine Team (huangxinda) committed 4 years ago
parent commit 1bf18252c4
20 changed files with 403 additions and 160 deletions
1. imperative/python/megengine/__init__.py (+1, -0)
2. imperative/python/megengine/amp/__init__.py (+14, -0)
3. imperative/python/megengine/amp/autocast.py (+79, -0)
4. imperative/python/megengine/core/_wrap.py (+1, -1)
5. imperative/python/megengine/core/tensor/amp.py (+78, -0)
6. imperative/python/megengine/core/tensor/array_method.py (+54, -37)
7. imperative/python/megengine/core/tensor/megbrain_graph.py (+1, -1)
8. imperative/python/megengine/core/tensor/utils.py (+10, -1)
9. imperative/python/megengine/functional/elemwise.py (+3, -42)
10. imperative/python/megengine/functional/math.py (+12, -6)
11. imperative/python/megengine/functional/nn.py (+54, -28)
12. imperative/python/megengine/functional/tensor.py (+1, -1)
13. imperative/python/megengine/jit/tracing.py (+1, -1)
14. imperative/python/megengine/module/batchnorm.py (+3, -0)
15. imperative/python/megengine/module/linear.py (+8, -2)
16. imperative/python/megengine/quantization/internal_fake_quant.py (+0, -2)
17. imperative/python/megengine/tensor.py (+1, -1)
18. imperative/python/test/unit/amp/test_autocast.py (+34, -0)
19. imperative/python/test/unit/functional/test_functional.py (+48, -4)
20. imperative/python/test/unit/module/test_module.py (+0, -33)

+1 -0  imperative/python/megengine/__init__.py

@@ -117,6 +117,7 @@ def _atexit(handler):

 # subpackages
+import megengine.amp
 import megengine.autodiff
 import megengine.data
 import megengine.distributed


+14 -0  imperative/python/megengine/amp/__init__.py

@@ -0,0 +1,14 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+import mprop
+
+from ..core.tensor.amp import *
+from .autocast import autocast
+
+mprop.init()

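Note: mprop is a small PyPI helper that makes module-level @property definitions (as used in core/tensor/amp.py below) behave like real properties when accessed through the module; mprop.init() swaps the module in sys.modules for a property-aware object. A rough sketch of the underlying idea (the _AmpModule class below is hypothetical, not mprop's actual code):

    import sys
    import types

    class _AmpModule(types.ModuleType):
        # a property defined on the module's class is triggered by plain
        # attribute access such as `megengine.amp.enabled`
        @property
        def enabled(self):
            from megengine.core.tensor import amp
            return amp._enabled

        @enabled.setter
        def enabled(self, value):
            from megengine.core.tensor import amp
            amp._enabled = value

    # reassigning a module's __class__ (valid since Python 3.5) is the
    # standard way to get property support without a helper library
    sys.modules[__name__].__class__ = _AmpModule
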
+79 -0  imperative/python/megengine/amp/autocast.py

@@ -0,0 +1,79 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import functools
+
+from ..core.tensor import amp
+
+
+class autocast:
+    r"""
+    A class to control autocast mode for amp, used as a context manager or a decorator.
+
+    :param enabled: whether autocast mode is enabled.
+    :param low_prec_dtype: the lower-precision dtype used by autocast mode. It changes
+        the target dtype in tensor casting for better speed and memory. Default: float16.
+    :param high_prec_dtype: the higher-precision dtype used by autocast mode. It changes
+        the target dtype in tensor casting for better precision. Default: float32.
+
+    Examples:
+
+    .. code-block::
+
+       # used as decorator
+       @autocast()
+       def train_step(image, label):
+           with gm:
+               logits = model(image)
+               loss = F.nn.cross_entropy(logits, label)
+               gm.backward(loss)
+           opt.step().clear_grad()
+           return loss
+
+       # used as context manager
+       def train_step(image, label):
+           with autocast():
+               with gm:
+                   logits = model(image)
+                   loss = F.nn.cross_entropy(logits, label)
+                   gm.backward(loss)
+           opt.step().clear_grad()
+           return loss
+    """
+
+    def __init__(
+        self,
+        enabled: bool = True,
+        low_prec_dtype: str = "float16",
+        high_prec_dtype: str = "float32",
+    ):
+        self.enabled = enabled
+        self.high_prec_dtype = high_prec_dtype
+        self.low_prec_dtype = low_prec_dtype
+        self._origin_enabled = None
+        self._origin_high = None
+        self._origin_low = None
+
+    def __enter__(self):
+        self._origin_enabled, amp._enabled = amp._enabled, self.enabled
+        self._origin_high = amp._high_prec_dtype
+        amp._high_prec_dtype = self.high_prec_dtype
+        self._origin_low = amp._low_prec_dtype
+        amp._low_prec_dtype = self.low_prec_dtype
+
+    def __exit__(self, *args):
+        amp._enabled = self._origin_enabled
+        amp._high_prec_dtype = self._origin_high
+        amp._low_prec_dtype = self._origin_low
+
+    def __call__(self, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            with self:
+                return func(*args, **kwargs)
+
+        return wrapper

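Note: a minimal sketch of the save-and-restore semantics of the context manager (it mirrors the new unit test below and assumes only the public megengine.amp API):

    import megengine.amp as amp

    assert not amp.enabled                    # autocast is off by default
    with amp.autocast(low_prec_dtype="float16", high_prec_dtype="float32"):
        assert amp.enabled                    # toggled on inside the scope
        assert amp.low_prec_dtype == "float16"
    assert not amp.enabled                    # original state restored on exit
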
+1 -1  imperative/python/megengine/core/_wrap.py

@@ -49,7 +49,7 @@ class Device:
         return self._cn == rhs._cn


-def device(obj):
+def as_device(obj):
     if isinstance(obj, Device):
         return obj
     return Device(obj)

+78 -0  imperative/python/megengine/core/tensor/amp.py

@@ -0,0 +1,78 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+_enabled = False
+_high_prec_dtype = "float32"
+_low_prec_dtype = "float16"
+
+
+@property
+def enabled(mod):
+    r"""
+    Get or set whether amp autocast mode is enabled.
+
+    Examples:
+
+    .. code-block::
+
+       import megengine as mge
+       mge.amp.enabled = True
+
+    """
+    return _enabled
+
+
+@enabled.setter
+def enabled(mod, enabled: bool):
+    global _enabled
+    _enabled = enabled
+
+
+@property
+def high_prec_dtype(mod):
+    r"""
+    Get or set the higher-precision dtype used by amp autocast mode. It changes
+    the target dtype in tensor casting for better precision. Default: float32.
+
+    Examples:
+
+    .. code-block::
+
+       import megengine as mge
+       mge.amp.high_prec_dtype = "float32"
+
+    """
+    return _high_prec_dtype
+
+
+@high_prec_dtype.setter
+def high_prec_dtype(mod, dtype: str):
+    global _high_prec_dtype
+    _high_prec_dtype = dtype
+
+
+@property
+def low_prec_dtype(mod):
+    r"""
+    Get or set the lower-precision dtype used by amp autocast mode. It changes
+    the target dtype in tensor casting for better speed and memory. Default: float16.
+
+    Examples:
+
+    .. code-block::
+
+       import megengine as mge
+       mge.amp.low_prec_dtype = "float16"
+
+    """
+    return _low_prec_dtype
+
+
+@low_prec_dtype.setter
+def low_prec_dtype(mod, dtype: str):
+    global _low_prec_dtype
+    _low_prec_dtype = dtype

+54 -37  imperative/python/megengine/core/tensor/array_method.py

@@ -15,15 +15,20 @@ import numpy as np
 from .._imperative_rt.common import CompNode
 from .._imperative_rt.core2 import SymbolVar, Tensor, apply
 from ..ops import builtin
-from ..ops.builtin import Elemwise, GetVarShape
-from . import utils
-from .indexing import getitem as _getitem
-from .indexing import setitem as _setitem
-from .utils import isscalar
-from .utils import make_shape_tuple as _make_shape_tuple
-from .utils import setscalar
-
-_ElwMod = Elemwise.Mode
+from . import amp
+from .indexing import getitem, setitem
+from .utils import (
+    _normalize_axis,
+    astensor1d,
+    astype,
+    cast_tensors,
+    convert_inputs,
+    isscalar,
+    make_shape_tuple,
+    setscalar,
+)
+
+_ElwMod = builtin.Elemwise.Mode

@@ -40,47 +45,59 @@ def _elwise_apply(args, mode):


 def _elwise(*args, mode):
+    args = convert_inputs(*args)
     if mode in (
         _ElwMod.TRUE_DIV,
+        _ElwMod.EXP,
         _ElwMod.POW,
-        _ElwMod.CEIL,
-        _ElwMod.FLOOR,
-        _ElwMod.ROUND,
+        _ElwMod.LOG,
+        _ElwMod.EXPM1,
+        _ElwMod.LOG1P,
+        _ElwMod.TANH,
+        _ElwMod.ACOS,
+        _ElwMod.ASIN,
+        _ElwMod.ATAN2,
+        _ElwMod.COS,
+        _ElwMod.H_SWISH,
+        _ElwMod.SIGMOID,
+        _ElwMod.SIN,
+    ) and (
+        amp._enabled or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
     ):
-        if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND) and np.issubdtype(
-            args[0].dtype, np.integer
-        ):
-            return args[0]
-        args = tuple(
-            map(
-                lambda x: x.astype("float32")
-                if hasattr(x, "dtype") and x.dtype != np.float32
-                else x,
-                args,
-            )
-        )
-    args = utils.convert_inputs(*args)
+        # autocast to FP32 to maintain precision,
+        # or to avoid ops that do not support int arguments
+        args = cast_tensors(*args, promote=True)
+
+    if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND,) and np.issubdtype(
+        args[0].dtype, np.integer
+    ):
+        return args[0]
     return _elwise_apply(args, mode)


 def _matmul(inp1, inp2):
+    if amp._enabled:
+        compute_mode = "float32"
+        inp1, inp2 = cast_tensors(inp1, inp2)
+    else:
+        compute_mode = "default"
+        inp1, inp2 = convert_inputs(inp1, inp2)
     op = builtin.MatrixMul(
-        transposeA=False, transposeB=False, compute_mode="default", format="default"
+        transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
     )
-    inp1, inp2 = utils.convert_inputs(inp1, inp2)
     (result,) = apply(op, inp1, inp2)
     return result


 def _transpose(data, axes):
     op = builtin.Dimshuffle(axes)
-    (data,) = utils.convert_inputs(data)
+    (data,) = convert_inputs(data)
     (result,) = apply(op, data)
     return result


 def _broadcast(inp, shape):
-    shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device)
+    shape = astensor1d(shape, inp, dtype="int32", device=inp.device)
     (result,) = apply(builtin.Broadcast(), inp, shape)
     return result

@@ -88,7 +105,7 @@ def _broadcast(inp, shape):
 def _reshape(x, shape):
     unspec_axis = None
     try:
-        shape_tuple = _make_shape_tuple(shape)
+        shape_tuple = make_shape_tuple(shape)
     except ValueError:
         pass
     else:

@@ -102,7 +119,7 @@ def _reshape(x, shape):
                     "multiple -1 in shape: {} & {}".format(unspec_axis, i)
                 )
             unspec_axis = i
-    shape = utils.astensor1d(shape, x, dtype="int32", device=x.device)
+    shape = astensor1d(shape, x, dtype="int32", device=x.device)
     if unspec_axis is None:
         op = builtin.Reshape()
     else:

@@ -171,7 +188,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
         return list(map(int, axis))

     axis = get_axes()
-    axis = utils._normalize_axis(inp.ndim, axis)
+    axis = _normalize_axis(inp.ndim, axis)
     axis = [a - i for i, a in enumerate(axis)]

     op = builtin.RemoveAxis(axis=axis)

@@ -184,7 +201,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
 def _reduce(mode):
     def f(self, axis=None, keepdims: bool = False):
         data = self
-        (data,) = utils.convert_inputs(data)
+        (data,) = convert_inputs(data)
         if mode == "mean":
             data = data.astype("float32")
         elif self.dtype == np.bool_:

@@ -196,7 +213,7 @@ def _reduce(mode):
             op = builtin.Reduce(mode=mode, axis=0)
             (result,) = apply(op, data)
         elif isinstance(axis, collections.abc.Iterable):
-            axis = utils._normalize_axis(self.ndim, axis, reverse=True)
+            axis = _normalize_axis(self.ndim, axis, reverse=True)
             for ai in axis:
                 op = builtin.Reduce(mode=mode, axis=ai)
                 (data,) = apply(op, data)

@@ -359,11 +376,11 @@ class ArrayMethodMixin(abc.ABC):
             yield self[i]

     def __getitem__(self, index):
-        return _getitem(self, index)
+        return getitem(self, index)

     def __setitem__(self, index, value):
         if index is not Ellipsis:
-            value = _setitem(self, index, value)
+            value = setitem(self, index, value)
         self._reset(value)

     __contains__ = _todo

@@ -422,7 +439,7 @@ class ArrayMethodMixin(abc.ABC):
         Returns a :class:`Tensor` with the same data and number of elements
         with the specified :attr:`~.Tensor.dtype`.
         """
-        return utils.astype(self, dtype)
+        return astype(self, dtype)

     def reshape(self, *args):
         r"""

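Note: with this change, all-integer inputs to float-only elementwise modes are promoted through cast_tensors(promote=True). A small sketch of the observable behavior (assuming the default high-precision dtype of float32; not part of the diff):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.array([1, 4], dtype="int32"))
    y = F.exp(x)            # all-int inputs are promoted before EXP is applied
    assert y.dtype == np.float32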

+1 -1  imperative/python/megengine/core/tensor/megbrain_graph.py

@@ -18,7 +18,7 @@ import numpy as np
 from .. import _imperative_rt
 from .._imperative_rt import GraphOptimizeOptions
 from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
-from .._wrap import device as as_device
+from .._wrap import as_device
 from ..ops.builtin import OpDef
 from .core import TensorBase




+10 -1  imperative/python/megengine/core/tensor/utils.py

@@ -13,9 +13,10 @@ import numpy as np

 from .._imperative_rt import make_const
 from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device
-from .._wrap import device as as_device
+from .._wrap import as_device
 from ..ops import builtin
 from ..ops.special import Const
+from . import amp
 from .dtype import is_dtype_equal, is_quantize

 _enable_convert_inputs = True

@@ -98,6 +99,14 @@ def convert_inputs(*args, device=None):
     return tuple(map(convert, args))


+def cast_tensors(*args, promote=False):
+    if promote:
+        dtype = amp._high_prec_dtype
+    else:
+        dtype = amp._low_prec_dtype
+    return tuple(arg.astype(dtype) if arg is not None else None for arg in args)
+
+
 def result_type(*args):
     dtypes = []
     for i in args:

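Note: a minimal sketch of the new cast_tensors helper; None entries pass through unchanged, so an optional bias can be forwarded without a special case (internal API, path as in the diff):

    from megengine import tensor
    from megengine.core.tensor.utils import cast_tensors

    x = tensor([1.0])
    w = tensor([2.0])
    x16, w16, b16 = cast_tensors(x, w, None)       # low precision: float16, b16 is None
    x32, w32 = cast_tensors(x, w, promote=True)    # high precision: float32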

+3 -42  imperative/python/megengine/functional/elemwise.py

@@ -12,10 +12,8 @@ import numpy as np
 from ..core._imperative_rt.core2 import SymbolVar, apply
 from ..core.ops import builtin
 from ..core.ops.builtin import Elemwise
-from ..core.tensor import utils
-from ..core.tensor.array_method import _elwise_apply
-from ..core.tensor.utils import astype
-from ..device import get_default_device
+from ..core.tensor.array_method import _elwise
+from ..core.tensor.utils import astype, convert_inputs
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func

@@ -69,46 +67,9 @@ __all__ = [
 ]


-def _elwise(*args, mode):
-    tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args))
-    if len(tensor_args) == 0:
-        dtype = utils.dtype_promotion(args)
-        first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
-        args = utils.convert_inputs(first_arg, *args[1:])
-    else:
-        args = utils.convert_inputs(*args)
-    if mode in (
-        Elemwise.Mode.TRUE_DIV,
-        Elemwise.Mode.EXP,
-        Elemwise.Mode.POW,
-        Elemwise.Mode.LOG,
-        Elemwise.Mode.EXPM1,
-        Elemwise.Mode.LOG1P,
-        Elemwise.Mode.TANH,
-        Elemwise.Mode.ACOS,
-        Elemwise.Mode.ASIN,
-        Elemwise.Mode.ATAN2,
-        Elemwise.Mode.CEIL,
-        Elemwise.Mode.COS,
-        Elemwise.Mode.FLOOR,
-        Elemwise.Mode.H_SWISH,
-        Elemwise.Mode.ROUND,
-        Elemwise.Mode.SIGMOID,
-        Elemwise.Mode.SIN,
-    ):
-        if mode in (
-            Elemwise.Mode.CEIL,
-            Elemwise.Mode.FLOOR,
-            Elemwise.Mode.ROUND,
-        ) and np.issubdtype(args[0].dtype, np.integer):
-            return args[0]
-        args = tuple(map(lambda x: astype(x, "float32"), args))
-    return _elwise_apply(args, mode)
-
-
 def _elemwise_multi_type(*args, mode, **kwargs):
     op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
-    args = utils.convert_inputs(*args)
+    args = convert_inputs(*args)
     (result,) = apply(op, *args)
     return result




+12 -6  imperative/python/megengine/functional/math.py

@@ -14,7 +14,8 @@ from ..core._imperative_rt.core2 import apply
 from ..core._trace_option import use_symbolic_shape
 from ..core.ops import builtin
 from ..core.ops.special import Const
-from ..core.tensor import utils
+from ..core.tensor import amp
+from ..core.tensor.utils import _normalize_axis, cast_tensors, convert_inputs, setscalar
 from ..tensor import Tensor
 from .debug_param import get_execution_strategy
 from .elemwise import clip, exp, log, log1p

@@ -471,7 +472,7 @@ def argmin(
         inp = inp.flatten()
         axis = 0

-    axis = utils._normalize_axis(inp.ndim, axis, reverse=True)
+    axis = _normalize_axis(inp.ndim, axis, reverse=True)
     if isinstance(axis, collections.abc.Iterable):
         for ai in axis:

@@ -528,7 +529,7 @@ def argmax(
         assert not keepdims, "can not set axis=None and keepdims=True"
         inp = inp.flatten()
         axis = 0
-    axis = utils._normalize_axis(inp.ndim, axis, reverse=True)
+    axis = _normalize_axis(inp.ndim, axis, reverse=True)

     if isinstance(axis, collections.abc.Iterable):

@@ -807,8 +808,13 @@ def matmul(
         [28. 40.]]

     """
+    if amp._enabled:
+        compute_mode = "float32"
+        inp1, inp2 = cast_tensors(inp1, inp2)
+    else:
+        inp1, inp2 = convert_inputs(inp1, inp2)
+
     remove_row, remove_col = False, False
-    inp1, inp2 = utils.convert_inputs(inp1, inp2)

     dim1, dim2 = inp1.ndim, inp2.ndim
     # handle dim=1 cases, dot and matrix-vector multiplication

@@ -921,12 +927,12 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:

     """
     op = builtin.Dot()
-    inp1, inp2 = utils.convert_inputs(inp1, inp2)
+    inp1, inp2 = convert_inputs(inp1, inp2)
     assert (
         inp1.ndim <= 1 and inp2.ndim <= 1
     ), "Input tensors for dot must be 1-dimensional or scalar"
     (result,) = apply(op, inp1, inp2)
-    utils.setscalar(result)
+    setscalar(result)
     return result





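Note: a minimal sketch of matmul under autocast as wired above: inputs are cast to the low-precision dtype while accumulation runs in float32 via compute_mode (the "io16c32" scheme; assumes default autocast dtypes):

    import numpy as np
    import megengine.amp as amp
    import megengine.functional as F
    from megengine import tensor

    a = tensor(np.random.randn(4, 8), dtype="float32")
    b = tensor(np.random.randn(8, 2), dtype="float32")
    with amp.autocast():
        out = F.matmul(a, b)       # float16 in/out, float32 accumulation
    assert out.dtype == np.float16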

+54 -28  imperative/python/megengine/functional/nn.py

@@ -15,9 +15,16 @@ from ..core._trace_option import use_symbolic_shape
 from ..core.ops import builtin
 from ..core.ops.builtin import BatchNorm, Elemwise
 from ..core.ops.special import Const
-from ..core.tensor import megbrain_graph, utils
+from ..core.tensor import amp, megbrain_graph
 from ..core.tensor.array_method import _elwise_apply
-from ..core.tensor.utils import astensor1d, astype, setscalar
+from ..core.tensor.utils import (
+    astensor1d,
+    astype,
+    cast_tensors,
+    convert_inputs,
+    convert_single_value,
+    setscalar,
+)
 from ..device import get_default_device
 from ..distributed import WORLD, is_distributed
 from ..random import uniform

@@ -91,7 +98,9 @@ def expand_hw(x):
     return int(h), int(w)


-def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
+def linear(
+    inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None, compute_mode="default",
+) -> Tensor:
     """
     Applies a linear transformation to the input tensor.

@@ -102,8 +111,10 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor
     :param bias: bias with shape `(out_features,)`.
         Default: None
     """
-    ret = matmul(inp, weight, transpose_b=True)
+    ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
     if bias is not None:
+        if amp._enabled:
+            bias = bias.astype("float16")
         ret += bias
     return ret

@@ -153,6 +164,11 @@ def conv1d(
     assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
     assert inp.ndim == 3, "the input dimension of conv1d should be 3"
     assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
+    if amp._enabled:
+        compute_mode = "float32"
+        inp, weight, bias = cast_tensors(inp, weight, bias)
+    else:
+        inp, weight = convert_inputs(inp, weight)

     inp = expand_dims(inp, 3)
     weight = expand_dims(weight, 3)

@@ -177,7 +193,6 @@ def conv1d(
         compute_mode=compute_mode,
         sparse=sparse_type,
     )
-    inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias

@@ -228,7 +243,11 @@ def conv2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
+    if amp._enabled:
+        compute_mode = "float32"
+        inp, weight, bias = cast_tensors(inp, weight, bias)
+    else:
+        inp, weight = convert_inputs(inp, weight)

     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)

@@ -247,7 +266,6 @@ def conv2d(
         compute_mode=compute_mode,
         sparse=sparse_type,
     )
-    inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias

@@ -286,6 +304,7 @@ def conv3d(
     :return: output tensor.
     """
     assert conv_mode.lower() == "cross_correlation"
+    inp, weight = convert_inputs(inp, weight)

     D, H, W = 0, 1, 2

@@ -308,7 +327,6 @@ def conv3d(
         mode=conv_mode,
         sparse=sparse_type,
     )
-    inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias

@@ -358,7 +376,11 @@ def conv_transpose2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
+    if amp._enabled:
+        compute_mode = "float32"
+        inp, weight, bias = cast_tensors(inp, weight, bias)
+    else:
+        inp, weight = convert_inputs(inp, weight)

     if groups != 1:
         raise NotImplementedError("group transposed conv2d is not supported yet.")

@@ -375,8 +397,8 @@ def conv_transpose2d(
         dilate_h=dilate_h,
         dilate_w=dilate_w,
         strategy=get_execution_strategy(),
+        compute_mode=compute_mode,
     )
-    weight, inp = utils.convert_inputs(weight, inp)
     (output,) = apply(op, weight, inp)
     if bias is not None:
         output += bias

@@ -428,7 +450,11 @@ def deformable_conv2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
+    if amp._enabled:
+        compute_mode = "float32"
+        inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
+    else:
+        inp, weight, offset, mask = convert_inputs(inp, weight, offset, mask)

     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)

@@ -447,7 +473,6 @@ def deformable_conv2d(
         compute_mode=compute_mode,
         sparse=sparse_type,
     )
-    inp, weight, offset, mask = utils.convert_inputs(inp, weight, offset, mask)
     (output,) = apply(op, inp, weight, offset, mask)
     if bias is not None:
         output += bias

@@ -468,6 +493,7 @@ def local_conv2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
+    inp, weight = convert_inputs(inp, weight)

     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)

@@ -481,10 +507,8 @@ def local_conv2d(
         dilate_h=dilate_h,
         dilate_w=dilate_w,
         mode=conv_mode,
-        compute_mode="default",
         sparse="dense",
     )
-    inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias

@@ -515,8 +539,9 @@ def conv_transpose3d(
     :param dilation: dilation of the 3D convolution operation. Default: 1
     :return: output tensor.
     """
-    D, H, W = 0, 1, 2
+    inp, weight = convert_inputs(inp, weight)

+    D, H, W = 0, 1, 2
     pad = _triple(padding)
     stride = _triple_nonzero(stride)
     dilate = _triple_nonzero(dilation)

@@ -533,7 +558,6 @@ def conv_transpose3d(
         dilate_w=dilate[W],
         strategy=get_execution_strategy(),
     )
-    weight, inp = utils.convert_inputs(weight, inp)
     (output,) = apply(op, weight, inp)
     if bias is not None:
         output += bias

@@ -994,7 +1018,8 @@ def batch_norm(
     training: bool = False,
     momentum: float = 0.9,
     eps: float = 1e-5,
-    inplace: bool = True
+    inplace: bool = True,
+    compute_mode="default",
 ):
     r"""
     Applies batch normalization to the input.

@@ -1027,15 +1052,11 @@ def batch_norm(
     def make_full_if_none(x, value):
         if x is None:
             (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
-            shape = utils.astensor1d(
-                (1, C, 1, 1), inp, dtype="int32", device=inp.device
-            )
+            shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Broadcast(), x, shape)
             return result
         elif x.ndim == 1:
-            shape = utils.astensor1d(
-                (1, C, 1, 1), inp, dtype="int32", device=inp.device
-            )
+            shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Reshape(), x, shape)
             return result
         return x

@@ -1052,10 +1073,15 @@ def batch_norm(
     if has_var and running_var.ndim != 4:
         raise ValueError

-    inp, weight, bias, running_mean, running_var = utils.convert_inputs(
-        inp, weight, bias, running_mean, running_var
-    )
+    if amp._enabled:
+        inp = inp.astype("float16")
+        weight, bias, running_mean, running_var = cast_tensors(
+            weight, bias, running_mean, running_var, promote=True
+        )
+    elif compute_mode != "float32":
+        inp, weight, bias, running_mean, running_var = convert_inputs(
+            inp, weight, bias, running_mean, running_var
+        )
     weight = make_full_if_none(weight, 1)
     bias = make_full_if_none(bias, 0)

@@ -1352,7 +1378,7 @@ def indexing_one_hot(
     """
     assert isinstance(src, Tensor), "src must be of Tensor type"
     op = builtin.IndexingOneHot(axis=axis)
-    index = utils.convert_single_value(index, dtype="int32", device=src.device)
+    index = convert_single_value(index, dtype="int32", device=src.device)
     (result,) = apply(op, src, index)
     if not keepdims:
         result = squeeze(result, axis)

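Note: a minimal sketch of batch_norm under autocast as implemented above: the input is cast to float16 while weight, bias, and running statistics are promoted to float32 (statistics kept in half precision would lose too much accuracy). This mirrors the new functional test and assumes default autocast dtypes:

    import numpy as np
    import megengine.amp as amp
    import megengine.functional as F
    from megengine import tensor

    inp = tensor(np.random.randn(1, 3, 8, 8), dtype="float32")
    weight = tensor(np.ones((1, 3, 1, 1)), dtype="float32")
    bias = tensor(np.zeros((1, 3, 1, 1)), dtype="float32")
    with amp.autocast():
        # input runs in float16; weight/bias/statistics stay in float32
        out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)
    assert out.dtype == np.float16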

+1 -1  imperative/python/megengine/functional/tensor.py

@@ -13,7 +13,7 @@ import numpy as np

 from ..core._imperative_rt import CompNode
 from ..core._imperative_rt.core2 import SymbolVar, apply
-from ..core._wrap import device as as_device
+from ..core._wrap import as_device
 from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
 from ..core.ops.special import Const


+1 -1  imperative/python/megengine/jit/tracing.py

@@ -33,7 +33,7 @@ from ..core._imperative_rt.ops import (
     RemoteSend,
 )
 from ..core._trace_option import set_symbolic_shape
-from ..core._wrap import device as as_device
+from ..core._wrap import as_device
 from ..core.ops.builtin import BatchNorm, OpDef
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G


+3 -0  imperative/python/megengine/module/batchnorm.py

@@ -26,6 +26,7 @@ class _BatchNorm(Module):
         affine=True,
         track_running_stats=True,
         freeze=False,
+        compute_mode="default",
         **kwargs
     ):
         super(_BatchNorm, self).__init__(**kwargs)

@@ -36,6 +37,7 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
+        self.compute_mode = compute_mode
         if self.freeze:
             assert (
                 self._track_running_stats_saved

@@ -123,6 +125,7 @@ class _BatchNorm(Module):
             or ((self.running_mean is None) and (self.running_var is None)),
             momentum=exponential_average_factor,
             eps=self.eps,
+            compute_mode=self.compute_mode,
         )

         if _ndims != 4:


+8 -2  imperative/python/megengine/module/linear.py

@@ -51,7 +51,12 @@ class Linear(Module):
     """

     def __init__(
-        self, in_features: int, out_features: int, bias: bool = True, **kwargs
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        compute_mode: str = "default",
+        **kwargs
     ):
         super().__init__(**kwargs)
         self.out_features = out_features

@@ -62,6 +67,7 @@ class Linear(Module):
         if bias:
             b_shape = (out_features,)
             self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
+        self.compute_mode = compute_mode
         self.reset_parameters()

     def _get_fanin(self):

@@ -75,7 +81,7 @@ class Linear(Module):
             init.zeros_(self.bias)

     def _calc_linear(self, x, weight, bias):
-        return linear(x, weight, bias)
+        return linear(x, weight, bias, compute_mode=self.compute_mode)

     def forward(self, x):
         return self._calc_linear(x, self.weight, self.bias)

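Note: with the plumbing above, modules can request float32 accumulation even outside autocast. A hypothetical construction (module names as exposed by megengine.module):

    import megengine.module as M

    fc = M.Linear(128, 64, compute_mode="float32")
    bn = M.BatchNorm2d(64, compute_mode="float32")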

+0 -2  imperative/python/megengine/quantization/internal_fake_quant.py

@@ -5,8 +5,6 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import copy
-import math
 from functools import partial

 from .. import functional as F


+1 -1  imperative/python/megengine/tensor.py

@@ -14,7 +14,7 @@ from .core._imperative_rt import CompNode
 from .core._imperative_rt.core2 import Tensor as _Tensor
 from .core._imperative_rt.core2 import apply
 from .core._trace_option import use_symbolic_shape
-from .core._wrap import device as as_device
+from .core._wrap import as_device
 from .core.ops.builtin import Copy, GetVarShape
 from .core.tensor.array_method import ArrayMethodMixin
 from .device import _valid_device, get_default_device


+34 -0  imperative/python/test/unit/amp/test_autocast.py

@@ -0,0 +1,34 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from megengine import amp
+from megengine.core.tensor import amp as origin_amp
+
+
+def test_autocast():
+    def check(enabled, low, high):
+        assert amp.enabled == enabled
+        assert origin_amp._enabled == enabled
+        assert amp.low_prec_dtype == low
+        assert origin_amp._low_prec_dtype == low
+        assert amp.high_prec_dtype == high
+        assert origin_amp._high_prec_dtype == high
+
+    origin_enabled = amp.enabled
+    origin_high = amp.high_prec_dtype
+    origin_low = amp.low_prec_dtype
+    with amp.autocast(low_prec_dtype="low", high_prec_dtype="high"):
+        check(True, "low", "high")
+    check(origin_enabled, origin_low, origin_high)
+    amp.enabled = True
+    amp.high_prec_dtype = "high"
+    amp.low_prec_dtype = "low"
+    check(True, "low", "high")
+    amp.enabled = origin_enabled
+    amp.high_prec_dtype = origin_high
+    amp.low_prec_dtype = origin_low
+    check(origin_enabled, origin_low, origin_high)

+48 -4  imperative/python/test/unit/functional/test_functional.py

@@ -14,6 +14,7 @@ import numpy as np
 import pytest
 from utils import opr_test

+import megengine.amp as amp
 import megengine.core.ops.builtin as builtin
 import megengine.core.tensor.dtype as dtype
 import megengine.functional as F

@@ -767,6 +768,27 @@ def test_batch_conv_bias():
     run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)


+def test_conv2d_io16c32():
+    amp.enabled = True
+    inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
+    weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float32)
+    out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
+    amp.enabled = False
+    expected = F.conv2d(
+        inp.astype("float16"),
+        weight.astype("float16"),
+        None,
+        (2, 2),
+        (3, 3),
+        (1, 1),
+        1,
+        compute_mode="float32",
+    )
+    assert out.dtype == np.float16
+    assert expected.dtype == np.float16
+    np.testing.assert_allclose(out.numpy(), expected.numpy())
+
+
 def test_conv2d_zero_stride_numpy_array():
     inp = np.random.randn(3, 224, 224).astype(np.float32)
     inp = inp[np.newaxis, :]

@@ -787,8 +809,8 @@ def test_conv3d_zero_stride_numpy_array():


 def test_conv1d():
-    inp = tensor(np.ones((16,), dtype=np.float32).reshape(2, 2, 4))
-    weight = tensor(np.ones((12,), dtype=np.float32).reshape(3, 2, 2))
+    inp = tensor(np.ones((2, 2, 4), dtype=np.float32))
+    weight = tensor(np.ones((3, 2, 2), dtype=np.float32))
     out = F.conv1d(inp, weight, None, 2, 0, 1, 1)
     np.testing.assert_equal(
         out.numpy(),

@@ -798,9 +820,31 @@ def test_conv1d():
     )


+def test_batchnorm2d_io16c32():
+    amp.enabled = True
+    inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
+    weight = tensor(np.ones((1, 3, 1, 1)), dtype=np.float32)
+    bias = tensor(np.zeros((1, 3, 1, 1)), dtype=np.float32)
+
+    out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)
+
+    amp.enabled = False
+    expected = F.batch_norm(
+        inp.astype("float16"),
+        weight=weight,
+        bias=bias,
+        training=True,
+        inplace=False,
+        compute_mode="float32",
+    )
+    assert out.dtype == np.float16
+    assert expected.dtype == np.float16
+    np.testing.assert_allclose(out.numpy(), expected.numpy())
+
+
 def test_conv3d():
-    inp = tensor(np.ones((256,), dtype=np.float32).reshape(2, 2, 4, 4, 4))
-    weight = tensor(np.ones((48,), dtype=np.float32).reshape(3, 2, 2, 2, 2))
+    inp = tensor(np.ones((2, 2, 4, 4, 4), dtype=np.float32))
+    weight = tensor(np.ones((3, 2, 2, 2, 2), dtype=np.float32))
     out = F.conv3d(inp, weight, None, 2, 0, 1, 1)
     print(out.numpy().shape)
     np.testing.assert_equal(


+0 -33  imperative/python/test/unit/module/test_module.py

@@ -473,39 +473,6 @@ def test_pickle_module():
     np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)


-def test_load_quantized():
-    from megengine.core.tensor import dtype
-
-    data_shape = (2, 28)
-    data = tensor(np.random.random(data_shape), dtype="float32")
-    data = data.astype(dtype.qint8(0.1))
-    mlp = MLP()
-    quantize_qat(mlp)
-    quantize(mlp)
-    mlp.dense0.weight = Parameter(mlp.dense0.weight.astype(dtype.qint8(0.001)).numpy())
-    mlp.dense1.weight = Parameter(mlp.dense1.weight.astype(dtype.qint8(0.0002)).numpy())
-    mlp.eval()
-    pred0 = mlp(data)
-
-    with BytesIO() as fout:
-        mge.save(mlp.state_dict(), fout)
-        fout.seek(0)
-        checkpoint = mge.load(fout)
-        # change mlp weight.
-        mlp.dense0.weight = Parameter(
-            mlp.dense0.weight.astype(dtype.qint8(0.00001)).numpy()
-        )
-        mlp.dense1.weight = Parameter(
-            mlp.dense1.weight.astype(dtype.qint8(0.2)).numpy()
-        )
-        mlp.load_state_dict(checkpoint)
-        pred1 = mlp(data)
-
-    np.testing.assert_allclose(
-        pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), atol=5e-6
-    )
-
-
 def test_repr_basic():
     # test whether __repr__ can output correct information
     class ConvModel(Module):

