GitOrigin-RevId: 6fbffc4845
release-1.5
@@ -117,6 +117,7 @@ def _atexit(handler):
# subpackages
import megengine.amp
import megengine.autodiff
import megengine.data
import megengine.distributed
@@ -0,0 +1,14 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import mprop
from ..core.tensor.amp import *
from .autocast import autocast
mprop.init()
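Editor's note (a sketch, not part of the diff): mprop promotes the module-level @property definitions in megengine/core/tensor/amp.py into real attribute accessors, so after mprop.init() the amp flags can be read and set as plain attributes. Assuming only the names introduced in this diff:

    import megengine as mge

    mge.amp.enabled = True               # forwards to megengine.core.tensor.amp._enabled
    mge.amp.low_prec_dtype = "float16"   # dtype used for speed/memory-oriented casts
    mge.amp.high_prec_dtype = "float32"  # dtype used for precision-oriented casts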
@@ -0,0 +1,79 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
from ..core.tensor import amp
class autocast:
r"""
A class to control autocast mode for amp, usable as a context manager or a decorator.
:param enabled: Whether autocast mode is enabled.
:param low_prec_dtype: Set amp autocast mode's lower precision dtype. It will change the
target dtype in tensor casting for better speed and memory. Default: float16.
:param high_prec_dtype: Set amp autocast mode's higher precision dtype. It will change the
target dtype in tensor casting for better precision. Default: float32.
Examples:
.. code-block::
# used as decorator
@autocast()
def train_step(image, label):
with gm:
logits = model(image)
loss = F.nn.cross_entropy(logits, label)
gm.backward(loss)
opt.step().clear_grad()
return loss
# used as context manager
def train_step(image, label):
with autocast():
with gm:
logits = model(image)
loss = F.nn.cross_entropy(logits, label)
gm.backward(loss)
opt.step().clear_grad()
return loss
"""
def __init__(
self,
enabled: bool = True,
low_prec_dtype: str = "float16",
high_prec_dtype: str = "float32",
):
self.enabled = enabled
self.high_prec_dtype = high_prec_dtype
self.low_prec_dtype = low_prec_dtype
self._origin_enabled = None
self._origin_high = None
self._origin_low = None
def __enter__(self):
self._origin_enabled, amp._enabled = amp._enabled, self.enabled
self._origin_high = amp._high_prec_dtype
amp._high_prec_dtype = self.high_prec_dtype
self._origin_low = amp._low_prec_dtype
amp._low_prec_dtype = self.low_prec_dtype
def __exit__(self, *args):
amp._enabled = self._origin_enabled
amp._high_prec_dtype = self._origin_high
amp._low_prec_dtype = self._origin_low
def __call__(self, func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
with self:
return func(*args, **kwargs)
return wrapper
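Editor's note (a minimal usage sketch consistent with the __enter__/__exit__ logic above, not part of the diff): the previous global amp state is saved on entry and restored on exit, so the context manager can be used without leaking settings.

    from megengine import amp

    assert amp.enabled is False          # module default from core.tensor.amp
    with amp.autocast(low_prec_dtype="float16"):
        assert amp.enabled is True       # set by __enter__
    assert amp.enabled is False          # restored by __exit__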
@@ -49,7 +49,7 @@ class Device:
return self._cn == rhs._cn
def device(obj):
def as_device(obj):
if isinstance(obj, Device):
return obj
return Device(obj)
@@ -0,0 +1,78 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
_enabled = False
_high_prec_dtype = "float32"
_low_prec_dtype = "float16"
@property
def enabled(mod):
r"""
Get or set whether amp autocast mode is enabled.
Examples:
.. code-block::
import megengine as mge
mge.amp.enabled = True
"""
return _enabled
@enabled.setter
def enabled(mod, enabled: bool):
global _enabled
_enabled = enabled
@property
def high_prec_dtype(mod):
r"""
Get or set amp autocast mode's higher precision dtype. It will change the
target dtype in tensor casting for better precision. Default: float32.
Examples:
.. code-block::
import megengine as mge
mge.amp.high_prec_dtype = "float32"
"""
return _high_prec_dtype
@high_prec_dtype.setter
def high_prec_dtype(mod, dtype: str):
global _high_prec_dtype
_high_prec_dtype = dtype
@property
def low_prec_dtype(mod):
r"""
Get or set amp autocast mode's lower precision dtype. It will change the
target dtype in tensor casting for better speed and memory. Default: float16.
Examples:
.. code-block::
import megengine as mge
mge.amp.low_prec_dtype = "float16"
"""
return _low_prec_dtype
@low_prec_dtype.setter
def low_prec_dtype(mod, dtype: str):
global _low_prec_dtype
_low_prec_dtype = dtype
@@ -15,15 +15,20 @@ import numpy as np
from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import SymbolVar, Tensor, apply
from ..ops import builtin
from ..ops.builtin import Elemwise, GetVarShape
from . import utils
from .indexing import getitem as _getitem
from .indexing import setitem as _setitem
from .utils import isscalar
from .utils import make_shape_tuple as _make_shape_tuple
from .utils import setscalar
_ElwMod = Elemwise.Mode
from . import amp
from .indexing import getitem, setitem
from .utils import (
_normalize_axis,
astensor1d,
astype,
cast_tensors,
convert_inputs,
isscalar,
make_shape_tuple,
setscalar,
)
_ElwMod = builtin.Elemwise.Mode
def _elwise_apply(args, mode):
@@ -40,47 +45,59 @@ def _elwise_apply(args, mode):
def _elwise(*args, mode):
args = convert_inputs(*args)
if mode in (
_ElwMod.TRUE_DIV,
_ElwMod.EXP,
_ElwMod.POW,
_ElwMod.CEIL,
_ElwMod.FLOOR,
_ElwMod.ROUND,
_ElwMod.LOG,
_ElwMod.EXPM1,
_ElwMod.LOG1P,
_ElwMod.TANH,
_ElwMod.ACOS,
_ElwMod.ASIN,
_ElwMod.ATAN2,
_ElwMod.COS,
_ElwMod.H_SWISH,
_ElwMod.SIGMOID,
_ElwMod.SIN,
) and (
amp._enabled or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
):
if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND) and np.issubdtype(
args[0].dtype, np.integer
):
return args[0]
args = tuple(
map(
lambda x: x.astype("float32")
if hasattr(x, "dtype") and x.dtype != np.float32
else x,
args,
)
)
args = utils.convert_inputs(*args)
# autocast to FP32 to maintain precision,
# or to avoid ops that do not support int arguments
args = cast_tensors(*args, promote=True)
if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND,) and np.issubdtype(
args[0].dtype, np.integer
):
return args[0]
return _elwise_apply(args, mode)
def _matmul(inp1, inp2):
if amp._enabled:
compute_mode = "float32"
inp1, inp2 = cast_tensors(inp1, inp2)
else:
compute_mode = "default"
inp1, inp2 = convert_inputs(inp1, inp2)
op = builtin.MatrixMul(
transposeA=False, transposeB=False, compute_mode="default", format="default"
transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
)
inp1, inp2 = utils.convert_inputs(inp1, inp2)
(result,) = apply(op, inp1, inp2)
return result
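Editor's sketch (not part of the diff) of what the amp branch in _matmul amounts to at the functional level; names follow this diff, the output dtype is an expectation rather than something this hunk asserts:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.random.randn(4, 8), dtype="float32")
    w = mge.tensor(np.random.randn(8, 2), dtype="float32")
    with mge.amp.autocast():
        y = F.matmul(x, w)  # inputs cast to float16, accumulation requested in float32
    # y is expected to come back as float16 (the io16c32 convention used by the tests below)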
def _transpose(data, axes):
op = builtin.Dimshuffle(axes)
(data,) = utils.convert_inputs(data)
(data,) = convert_inputs(data)
(result,) = apply(op, data)
return result
def _broadcast(inp, shape):
shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device)
shape = astensor1d(shape, inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Broadcast(), inp, shape)
return result
@@ -88,7 +105,7 @@ def _broadcast(inp, shape):
def _reshape(x, shape):
unspec_axis = None
try:
shape_tuple = _make_shape_tuple(shape)
shape_tuple = make_shape_tuple(shape)
except ValueError:
pass
else:
@@ -102,7 +119,7 @@ def _reshape(x, shape):
"multiple -1 in shape: {} & {}".format(unspec_axis, i)
)
unspec_axis = i
shape = utils.astensor1d(shape, x, dtype="int32", device=x.device)
shape = astensor1d(shape, x, dtype="int32", device=x.device)
if unspec_axis is None:
op = builtin.Reshape()
else:
@@ -171,7 +188,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
return list(map(int, axis))
axis = get_axes()
axis = utils._normalize_axis(inp.ndim, axis)
axis = _normalize_axis(inp.ndim, axis)
axis = [a - i for i, a in enumerate(axis)]
op = builtin.RemoveAxis(axis=axis)
@@ -184,7 +201,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
def _reduce(mode):
def f(self, axis=None, keepdims: bool = False):
data = self
(data,) = utils.convert_inputs(data)
(data,) = convert_inputs(data)
if mode == "mean":
data = data.astype("float32")
elif self.dtype == np.bool_:
@@ -196,7 +213,7 @@ def _reduce(mode):
op = builtin.Reduce(mode=mode, axis=0)
(result,) = apply(op, data)
elif isinstance(axis, collections.abc.Iterable):
axis = utils._normalize_axis(self.ndim, axis, reverse=True)
axis = _normalize_axis(self.ndim, axis, reverse=True)
for ai in axis:
op = builtin.Reduce(mode=mode, axis=ai)
(data,) = apply(op, data)
@@ -359,11 +376,11 @@ class ArrayMethodMixin(abc.ABC):
yield self[i]
def __getitem__(self, index):
return _getitem(self, index)
return getitem(self, index)
def __setitem__(self, index, value):
if index is not Ellipsis:
value = _setitem(self, index, value)
value = setitem(self, index, value)
self._reset(value)
__contains__ = _todo
@@ -422,7 +439,7 @@ class ArrayMethodMixin(abc.ABC):
Returns a :class:`Tensor` with the same data and number of elements
with the specified :attr:`~.Tensor.dtype`.
"""
return utils.astype(self, dtype)
return astype(self, dtype)
def reshape(self, *args):
r"""
@@ -18,7 +18,7 @@ import numpy as np
from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
from .._wrap import device as as_device
from .._wrap import as_device
from ..ops.builtin import OpDef
from .core import TensorBase
@@ -13,9 +13,10 @@ import numpy as np
from .._imperative_rt import make_const
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device
from .._wrap import device as as_device
from .._wrap import as_device
from ..ops import builtin
from ..ops.special import Const
from .amp import _high_prec_dtype, _low_prec_dtype
from .dtype import is_dtype_equal, is_quantize
_enable_convert_inputs = True
@@ -98,6 +99,14 @@ def convert_inputs(*args, device=None):
return tuple(map(convert, args))
def cast_tensors(*args, promote=False):
if promote:
dtype = _high_prec_dtype
else:
dtype = _low_prec_dtype
return tuple(arg.astype(dtype) if arg is not None else None for arg in args)
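Editor's sketch of cast_tensors semantics as defined above (not part of the diff; assumes the module defaults _low_prec_dtype="float16" and _high_prec_dtype="float32"):

    import numpy as np
    import megengine as mge
    from megengine.core.tensor.utils import cast_tensors

    a = mge.tensor(np.ones((2, 2), dtype=np.float32))
    b = mge.tensor(np.ones((2, 2), dtype=np.float32))
    a16, b16 = cast_tensors(a, b)                    # both cast to float16
    a32, none = cast_tensors(a, None, promote=True)  # cast to float32; None passes through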
def result_type(*args):
dtypes = []
for i in args:
@@ -12,10 +12,8 @@ import numpy as np
from ..core._imperative_rt.core2 import SymbolVar, apply
from ..core.ops import builtin
from ..core.ops.builtin import Elemwise
from ..core.tensor import utils
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import astype
from ..device import get_default_device
from ..core.tensor.array_method import _elwise
from ..core.tensor.utils import astype, convert_inputs
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
@@ -69,46 +67,9 @@ __all__ = [
]
def _elwise(*args, mode):
tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args))
if len(tensor_args) == 0:
dtype = utils.dtype_promotion(args)
first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
args = utils.convert_inputs(first_arg, *args[1:])
else:
args = utils.convert_inputs(*args)
if mode in (
Elemwise.Mode.TRUE_DIV,
Elemwise.Mode.EXP,
Elemwise.Mode.POW,
Elemwise.Mode.LOG,
Elemwise.Mode.EXPM1,
Elemwise.Mode.LOG1P,
Elemwise.Mode.TANH,
Elemwise.Mode.ACOS,
Elemwise.Mode.ASIN,
Elemwise.Mode.ATAN2,
Elemwise.Mode.CEIL,
Elemwise.Mode.COS,
Elemwise.Mode.FLOOR,
Elemwise.Mode.H_SWISH,
Elemwise.Mode.ROUND,
Elemwise.Mode.SIGMOID,
Elemwise.Mode.SIN,
):
if mode in (
Elemwise.Mode.CEIL,
Elemwise.Mode.FLOOR,
Elemwise.Mode.ROUND,
) and np.issubdtype(args[0].dtype, np.integer):
return args[0]
args = tuple(map(lambda x: astype(x, "float32"), args))
return _elwise_apply(args, mode)
def _elemwise_multi_type(*args, mode, **kwargs):
op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
args = utils.convert_inputs(*args)
args = convert_inputs(*args)
(result,) = apply(op, *args)
return result
@@ -14,7 +14,8 @@ from ..core._imperative_rt.core2 import apply
from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin
from ..core.ops.special import Const
from ..core.tensor import utils
from ..core.tensor import amp
from ..core.tensor.utils import _normalize_axis, cast_tensors, convert_inputs, setscalar
from ..tensor import Tensor
from .debug_param import get_execution_strategy
from .elemwise import clip, exp, log, log1p
@@ -471,7 +472,7 @@ def argmin(
inp = inp.flatten()
axis = 0
axis = utils._normalize_axis(inp.ndim, axis, reverse=True)
axis = _normalize_axis(inp.ndim, axis, reverse=True)
if isinstance(axis, collections.abc.Iterable):
for ai in axis:
@@ -528,7 +529,7 @@ def argmax(
assert not keepdims, "can not set axis=None and keepdims=True"
inp = inp.flatten()
axis = 0
axis = utils._normalize_axis(inp.ndim, axis, reverse=True)
axis = _normalize_axis(inp.ndim, axis, reverse=True)
if isinstance(axis, collections.abc.Iterable):
@@ -807,8 +808,13 @@ def matmul(
[28. 40.]]
"""
if amp._enabled:
compute_mode = "float32"
inp1, inp2 = cast_tensors(inp1, inp2)
else:
inp1, inp2 = convert_inputs(inp1, inp2)
remove_row, remove_col = False, False
inp1, inp2 = utils.convert_inputs(inp1, inp2)
dim1, dim2 = inp1.ndim, inp2.ndim
# handle dim=1 cases, dot and matrix-vector multiplication
@@ -921,12 +927,12 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
"""
op = builtin.Dot()
inp1, inp2 = utils.convert_inputs(inp1, inp2)
inp1, inp2 = convert_inputs(inp1, inp2)
assert (
inp1.ndim <= 1 and inp2.ndim <= 1
), "Input tensors for dot must be 1-dimensional or scalar"
(result,) = apply(op, inp1, inp2)
utils.setscalar(result)
setscalar(result)
return result
@@ -15,9 +15,16 @@ from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph, utils
from ..core.tensor import amp, megbrain_graph
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import astensor1d, astype, setscalar
from ..core.tensor.utils import (
astensor1d,
astype,
cast_tensors,
convert_inputs,
convert_single_value,
setscalar,
)
from ..device import get_default_device
from ..distributed import WORLD, is_distributed
from ..random import uniform
@@ -91,7 +98,9 @@ def expand_hw(x):
return int(h), int(w)
def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
def linear(
inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None, compute_mode="default",
) -> Tensor:
"""
Applies a linear transformation to the input tensor.
@@ -102,8 +111,10 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor
:param bias: bias with shape `(out_features,)`.
Default: None
"""
ret = matmul(inp, weight, transpose_b=True)
ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
if bias is not None:
if amp._enabled:
bias = bias.astype("float16")
ret += bias
return ret
@@ -153,6 +164,11 @@ def conv1d(
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
assert inp.ndim == 3, "the input dimension of conv1d should be 3"
assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
if amp._enabled:
compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias)
else:
inp, weight = convert_inputs(inp, weight)
inp = expand_dims(inp, 3)
weight = expand_dims(weight, 3)
@@ -177,7 +193,6 @@ def conv1d(
compute_mode=compute_mode,
sparse=sparse_type,
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
@@ -228,7 +243,11 @@ def conv2d(
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
if amp._enabled:
compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias)
else:
inp, weight = convert_inputs(inp, weight)
stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
@@ -247,7 +266,6 @@ def conv2d(
compute_mode=compute_mode,
sparse=sparse_type,
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
@@ -286,6 +304,7 @@ def conv3d(
:return: output tensor.
"""
assert conv_mode.lower() == "cross_correlation"
inp, weight = convert_inputs(inp, weight)
D, H, W = 0, 1, 2
@@ -308,7 +327,6 @@ def conv3d(
mode=conv_mode,
sparse=sparse_type,
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
@@ -358,7 +376,11 @@ def conv_transpose2d(
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
if amp._enabled:
compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias)
else:
inp, weight = convert_inputs(inp, weight)
if groups != 1:
raise NotImplementedError("group transposed conv2d is not supported yet.")
@@ -375,8 +397,8 @@ def conv_transpose2d(
dilate_h=dilate_h,
dilate_w=dilate_w,
strategy=get_execution_strategy(),
compute_mode=compute_mode,
)
weight, inp = utils.convert_inputs(weight, inp)
(output,) = apply(op, weight, inp)
if bias is not None:
output += bias
@@ -428,7 +450,11 @@ def deformable_conv2d(
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
if amp._enabled:
compute_mode = "float32"
inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
else:
inp, weight, offset, mask = convert_inputs(inp, weight, offset, mask)
stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
@@ -447,7 +473,6 @@ def deformable_conv2d(
compute_mode=compute_mode,
sparse=sparse_type,
)
inp, weight, offset, mask = utils.convert_inputs(inp, weight, offset, mask)
(output,) = apply(op, inp, weight, offset, mask)
if bias is not None:
output += bias
@@ -468,6 +493,7 @@ def local_conv2d(
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
inp, weight = convert_inputs(inp, weight)
stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
@@ -481,10 +507,8 @@ def local_conv2d(
dilate_h=dilate_h,
dilate_w=dilate_w,
mode=conv_mode,
compute_mode="default",
sparse="dense",
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
@@ -515,8 +539,9 @@ def conv_transpose3d(
:param dilation: dilation of the 3D convolution operation. Default: 1
:return: output tensor.
"""
D, H, W = 0, 1, 2
inp, weight = convert_inputs(inp, weight)
D, H, W = 0, 1, 2
pad = _triple(padding)
stride = _triple_nonzero(stride)
dilate = _triple_nonzero(dilation)
@@ -533,7 +558,6 @@ def conv_transpose3d(
dilate_w=dilate[W],
strategy=get_execution_strategy(),
)
weight, inp = utils.convert_inputs(weight, inp)
(output,) = apply(op, weight, inp)
if bias is not None:
output += bias
@@ -994,7 +1018,8 @@ def batch_norm(
training: bool = False,
momentum: float = 0.9,
eps: float = 1e-5,
inplace: bool = True
inplace: bool = True,
compute_mode="default",
):
r"""
Applies batch normalization to the input.
@@ -1027,15 +1052,11 @@ def batch_norm(
def make_full_if_none(x, value):
if x is None:
(x,) = Const(value, dtype=inp.dtype, device=inp.device)()
shape = utils.astensor1d(
(1, C, 1, 1), inp, dtype="int32", device=inp.device
)
shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Broadcast(), x, shape)
return result
elif x.ndim == 1:
shape = utils.astensor1d(
(1, C, 1, 1), inp, dtype="int32", device=inp.device
)
shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Reshape(), x, shape)
return result
return x
@@ -1052,10 +1073,15 @@ def batch_norm(
if has_var and running_var.ndim != 4:
raise ValueError
inp, weight, bias, running_mean, running_var = utils.convert_inputs(
inp, weight, bias, running_mean, running_var
)
if amp._enabled:
inp = inp.astype("float16")
weight, bias, running_mean, running_var = cast_tensors(
weight, bias, running_mean, running_var, promote=True
)
elif compute_mode != "float32":
inp, weight, bias, running_mean, running_var = convert_inputs(
inp, weight, bias, running_mean, running_var
)
weight = make_full_if_none(weight, 1)
bias = make_full_if_none(bias, 0)
@@ -1352,7 +1378,7 @@ def indexing_one_hot(
"""
assert isinstance(src, Tensor), "src must be of Tensor type"
op = builtin.IndexingOneHot(axis=axis)
index = utils.convert_single_value(index, dtype="int32", device=src.device)
index = convert_single_value(index, dtype="int32", device=src.device)
(result,) = apply(op, src, index)
if not keepdims:
result = squeeze(result, axis)
@@ -13,7 +13,7 @@ import numpy as np
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import SymbolVar, apply
from ..core._wrap import device as as_device
from ..core._wrap import as_device
from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity
from ..core.ops.special import Const
@@ -33,7 +33,7 @@ from ..core._imperative_rt.ops import (
RemoteSend,
)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
from ..core._wrap import as_device
from ..core.ops.builtin import BatchNorm, OpDef
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
@@ -26,6 +26,7 @@ class _BatchNorm(Module):
affine=True,
track_running_stats=True,
freeze=False,
compute_mode="default",
**kwargs
):
super(_BatchNorm, self).__init__(**kwargs)
@@ -36,6 +37,7 @@ class _BatchNorm(Module):
self.track_running_stats = track_running_stats
self._track_running_stats_saved = track_running_stats
self.freeze = freeze
self.compute_mode = compute_mode
if self.freeze:
assert (
self._track_running_stats_saved
@@ -123,6 +125,7 @@ class _BatchNorm(Module):
or ((self.running_mean is None) and (self.running_var is None)),
momentum=exponential_average_factor,
eps=self.eps,
compute_mode=self.compute_mode,
)
if _ndims != 4:
@@ -51,7 +51,12 @@ class Linear(Module):
"""
def __init__(
self, in_features: int, out_features: int, bias: bool = True, **kwargs
self,
in_features: int,
out_features: int,
bias: bool = True,
compute_mode: str = "default",
**kwargs
):
super().__init__(**kwargs)
self.out_features = out_features
@@ -62,6 +67,7 @@ class Linear(Module):
if bias:
b_shape = (out_features,)
self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
self.compute_mode = compute_mode
self.reset_parameters()
def _get_fanin(self):
@@ -75,7 +81,7 @@ class Linear(Module):
init.zeros_(self.bias)
def _calc_linear(self, x, weight, bias):
return linear(x, weight, bias)
return linear(x, weight, bias, compute_mode=self.compute_mode)
def forward(self, x):
return self._calc_linear(x, self.weight, self.bias)
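Editor's sketch of the new constructor argument (hedged; Linear and compute_mode are taken from this diff, the pairing with float16 inputs is the io16c32 convention rather than something this hunk enforces):

    from megengine.module import Linear

    fc = Linear(8, 2, compute_mode="float32")  # stored on the module, forwarded to F.linear
    # _calc_linear passes it through: linear(x, weight, bias, compute_mode=self.compute_mode),
    # so the underlying matmul accumulates in float32, typically paired with float16 inputs
    # (for example under amp.autocast).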
@@ -5,8 +5,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
import math
from functools import partial
from .. import functional as F
@@ -14,7 +14,7 @@ from .core._imperative_rt import CompNode
from .core._imperative_rt.core2 import Tensor as _Tensor
from .core._imperative_rt.core2 import apply
from .core._trace_option import use_symbolic_shape
from .core._wrap import device as as_device
from .core._wrap import as_device
from .core.ops.builtin import Copy, GetVarShape
from .core.tensor.array_method import ArrayMethodMixin
from .device import _valid_device, get_default_device
@@ -0,0 +1,34 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import amp
from megengine.core.tensor import amp as origin_amp
def test_grad_scaler():
def check(enabled, low, high):
assert amp.enabled == enabled
assert origin_amp._enabled == enabled
assert amp.low_prec_dtype == low
assert origin_amp._low_prec_dtype == low
assert amp.high_prec_dtype == high
assert origin_amp._high_prec_dtype == high
origin_enabled = amp.enabled
origin_high = amp.high_prec_dtype
origin_low = amp.low_prec_dtype
with amp.autocast(low_prec_dtype="low", high_prec_dtype="high"):
check(True, "low", "high")
check(origin_enabled, origin_low, origin_high)
amp.enabled = True
amp.high_prec_dtype = "high"
amp.low_prec_dtype = "low"
check(True, "low", "high")
amp.enabled = origin_enabled
amp.high_prec_dtype = origin_high
amp.low_prec_dtype = origin_low
check(origin_enabled, origin_low, origin_high)
@@ -14,6 +14,7 @@ import numpy as np
import pytest
from utils import opr_test
import megengine.amp as amp
import megengine.core.ops.builtin as builtin
import megengine.core.tensor.dtype as dtype
import megengine.functional as F
@@ -767,6 +768,27 @@ def test_batch_conv_bias():
run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)
def test_conv2d_io16c32():
amp.enabled = True
inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float32)
out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
amp.enabled = False
expected = F.conv2d(
inp.astype("float16"),
weight.astype("float16"),
None,
(2, 2),
(3, 3),
(1, 1),
1,
compute_mode="float32",
)
assert out.dtype == np.float16
assert expected.dtype == np.float16
np.testing.assert_allclose(out.numpy(), expected.numpy())
def test_conv2d_zero_stride_numpy_array():
inp = np.random.randn(3, 224, 224).astype(np.float32)
inp = inp[np.newaxis, :]
@@ -787,8 +809,8 @@ def test_conv3d_zero_stride_numpy_array():
def test_conv1d():
inp = tensor(np.ones((16,), dtype=np.float32).reshape(2, 2, 4))
weight = tensor(np.ones((12,), dtype=np.float32).reshape(3, 2, 2))
inp = tensor(np.ones((2, 2, 4), dtype=np.float32))
weight = tensor(np.ones((3, 2, 2), dtype=np.float32))
out = F.conv1d(inp, weight, None, 2, 0, 1, 1)
np.testing.assert_equal(
out.numpy(),
@@ -798,9 +820,31 @@ def test_conv1d():
)
def test_batchnorm2d_io16c32():
amp.enabled = True
inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
weight = tensor(np.ones((1, 3, 1, 1)), dtype=np.float32)
bias = tensor(np.zeros((1, 3, 1, 1)), dtype=np.float32)
out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)
amp.enabled = False
expected = F.batch_norm(
inp.astype("float16"),
weight=weight,
bias=bias,
training=True,
inplace=False,
compute_mode="float32",
)
assert out.dtype == np.float16
assert expected.dtype == np.float16
np.testing.assert_allclose(out.numpy(), expected.numpy())
def test_conv3d():
inp = tensor(np.ones((256,), dtype=np.float32).reshape(2, 2, 4, 4, 4))
weight = tensor(np.ones((48,), dtype=np.float32).reshape(3, 2, 2, 2, 2))
inp = tensor(np.ones((2, 2, 4, 4, 4), dtype=np.float32))
weight = tensor(np.ones((3, 2, 2, 2, 2), dtype=np.float32))
out = F.conv3d(inp, weight, None, 2, 0, 1, 1)
print(out.numpy().shape)
np.testing.assert_equal(
@@ -473,39 +473,6 @@ def test_pickle_module():
np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)
def test_load_quantized():
from megengine.core.tensor import dtype
data_shape = (2, 28)
data = tensor(np.random.random(data_shape), dtype="float32")
data = data.astype(dtype.qint8(0.1))
mlp = MLP()
quantize_qat(mlp)
quantize(mlp)
mlp.dense0.weight = Parameter(mlp.dense0.weight.astype(dtype.qint8(0.001)).numpy())
mlp.dense1.weight = Parameter(mlp.dense1.weight.astype(dtype.qint8(0.0002)).numpy())
mlp.eval()
pred0 = mlp(data)
with BytesIO() as fout:
mge.save(mlp.state_dict(), fout)
fout.seek(0)
checkpoint = mge.load(fout)
# change mlp weight.
mlp.dense0.weight = Parameter(
mlp.dense0.weight.astype(dtype.qint8(0.00001)).numpy()
)
mlp.dense1.weight = Parameter(
mlp.dense1.weight.astype(dtype.qint8(0.2)).numpy()
)
mlp.load_state_dict(checkpoint)
pred1 = mlp(data)
np.testing.assert_allclose(
pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), atol=5e-6
)
def test_repr_basic():
# test whether __repr__ can output correct information
class ConvModel(Module): | |||