GitOrigin-RevId: 6fbffc4845
release-1.5
@@ -117,6 +117,7 @@ def _atexit(handler):
 # subpackages
+import megengine.amp
 import megengine.autodiff
 import megengine.data
 import megengine.distributed
@@ -0,0 +1,14 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import mprop
+
+from ..core.tensor.amp import *
+from .autocast import autocast
+
+mprop.init()
@@ -0,0 +1,79 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import functools
+
+from ..core.tensor import amp
+
+
+class autocast:
+    r"""
+    A class to control autocast mode for amp, usable as a context manager or a decorator.
+
+    :param enabled: whether autocast mode is enabled.
+    :param low_prec_dtype: the lower-precision dtype autocast uses. It changes the
+        target dtype in tensor casting for better speed and lower memory usage. Default: float16.
+    :param high_prec_dtype: the higher-precision dtype autocast uses. It changes the
+        target dtype in tensor casting for better precision. Default: float32.
+
+    Examples:
+
+    .. code-block::
+
+       # used as a decorator
+       @autocast()
+       def train_step(image, label):
+           with gm:
+               logits = model(image)
+               loss = F.nn.cross_entropy(logits, label)
+               gm.backward(loss)
+           opt.step().clear_grad()
+           return loss
+
+       # used as a context manager
+       def train_step(image, label):
+           with autocast():
+               with gm:
+                   logits = model(image)
+                   loss = F.nn.cross_entropy(logits, label)
+                   gm.backward(loss)
+           opt.step().clear_grad()
+           return loss
+    """
+
+    def __init__(
+        self,
+        enabled: bool = True,
+        low_prec_dtype: str = "float16",
+        high_prec_dtype: str = "float32",
+    ):
+        self.enabled = enabled
+        self.high_prec_dtype = high_prec_dtype
+        self.low_prec_dtype = low_prec_dtype
+        self._origin_enabled = None
+        self._origin_high = None
+        self._origin_low = None
+
+    def __enter__(self):
+        self._origin_enabled, amp._enabled = amp._enabled, self.enabled
+        self._origin_high = amp._high_prec_dtype
+        amp._high_prec_dtype = self.high_prec_dtype
+        self._origin_low = amp._low_prec_dtype
+        amp._low_prec_dtype = self.low_prec_dtype
+
+    def __exit__(self, *args):
+        amp._enabled = self._origin_enabled
+        amp._high_prec_dtype = self._origin_high
+        amp._low_prec_dtype = self._origin_low
+
+    def __call__(self, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            with self:
+                return func(*args, **kwargs)
+
+        return wrapper
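The `__enter__`/`__exit__` pair saves the previous global state and restores it on exit, which is what makes `autocast` safe to use as either a context manager or a decorator. A minimal sketch of that bookkeeping (the dtype shown is the documented default):

    from megengine import amp

    assert not amp.enabled                      # autocast is off by default
    with amp.autocast():
        assert amp.enabled                      # flipped on entry
        assert amp.low_prec_dtype == "float16"  # default low-precision dtype
    assert not amp.enabled                      # previous state restored on exit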
@@ -49,7 +49,7 @@ class Device:
         return self._cn == rhs._cn
 
 
-def device(obj):
+def as_device(obj):
     if isinstance(obj, Device):
         return obj
     return Device(obj)
@@ -0,0 +1,78 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+_enabled = False
+_high_prec_dtype = "float32"
+_low_prec_dtype = "float16"
+
+
+@property
+def enabled(mod):
+    r"""
+    Get or set whether amp autocast mode is enabled.
+
+    Examples:
+
+    .. code-block::
+
+       import megengine as mge
+       mge.amp.enabled = True
+    """
+    return _enabled
+
+
+@enabled.setter
+def enabled(mod, enabled: bool):
+    global _enabled
+    _enabled = enabled
+
+
+@property
+def high_prec_dtype(mod):
+    r"""
+    Get or set amp autocast mode's higher-precision dtype. It changes the
+    target dtype in tensor casting for better precision. Default: float32.
+
+    Examples:
+
+    .. code-block::
+
+       import megengine as mge
+       mge.amp.high_prec_dtype = "float32"
+    """
+    return _high_prec_dtype
+
+
+@high_prec_dtype.setter
+def high_prec_dtype(mod, dtype: str):
+    global _high_prec_dtype
+    _high_prec_dtype = dtype
+
+
+@property
+def low_prec_dtype(mod):
+    r"""
+    Get or set amp autocast mode's lower-precision dtype. It changes the
+    target dtype in tensor casting for better speed and lower memory usage. Default: float16.
+
+    Examples:
+
+    .. code-block::
+
+       import megengine as mge
+       mge.amp.low_prec_dtype = "float16"
+    """
+    return _low_prec_dtype
+
+
+@low_prec_dtype.setter
+def low_prec_dtype(mod, dtype: str):
+    global _low_prec_dtype
+    _low_prec_dtype = dtype
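These module-level `@property` definitions only take effect because `megengine/amp/__init__.py` star-imports them and then calls `mprop.init()`, which swaps the module object for a property-aware proxy so that plain attribute access such as `mge.amp.enabled = True` routes through the getters and setters, with the module itself passed in as `mod`. A standalone sketch of the same pattern (the module name `config.py` is hypothetical):

    # config.py -- illustrative only, showing the mprop pattern
    import mprop

    _verbose = False

    @property
    def verbose(mod):
        # `mod` is the module instance, playing the role of `self`
        return _verbose

    @verbose.setter
    def verbose(mod, value: bool):
        global _verbose
        _verbose = value

    mprop.init()  # replace this module with the property-aware proxy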
@@ -15,15 +15,20 @@ import numpy as np
 from .._imperative_rt.common import CompNode
 from .._imperative_rt.core2 import SymbolVar, Tensor, apply
 from ..ops import builtin
-from ..ops.builtin import Elemwise, GetVarShape
-from . import utils
-from .indexing import getitem as _getitem
-from .indexing import setitem as _setitem
-from .utils import isscalar
-from .utils import make_shape_tuple as _make_shape_tuple
-from .utils import setscalar
-
-_ElwMod = Elemwise.Mode
+from . import amp
+from .indexing import getitem, setitem
+from .utils import (
+    _normalize_axis,
+    astensor1d,
+    astype,
+    cast_tensors,
+    convert_inputs,
+    isscalar,
+    make_shape_tuple,
+    setscalar,
+)
+
+_ElwMod = builtin.Elemwise.Mode
 
 
 def _elwise_apply(args, mode):
@@ -40,47 +45,59 @@ def _elwise_apply(args, mode):
 
 def _elwise(*args, mode):
+    args = convert_inputs(*args)
     if mode in (
         _ElwMod.TRUE_DIV,
+        _ElwMod.EXP,
         _ElwMod.POW,
+        _ElwMod.CEIL,
+        _ElwMod.FLOOR,
+        _ElwMod.ROUND,
+        _ElwMod.LOG,
+        _ElwMod.EXPM1,
+        _ElwMod.LOG1P,
+        _ElwMod.TANH,
+        _ElwMod.ACOS,
+        _ElwMod.ASIN,
+        _ElwMod.ATAN2,
+        _ElwMod.COS,
+        _ElwMod.H_SWISH,
+        _ElwMod.SIGMOID,
+        _ElwMod.SIN,
+    ) and (
+        amp._enabled or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
     ):
-        if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND) and np.issubdtype(
-            args[0].dtype, np.integer
-        ):
-            return args[0]
-        args = tuple(
-            map(
-                lambda x: x.astype("float32")
-                if hasattr(x, "dtype") and x.dtype != np.float32
-                else x,
-                args,
-            )
-        )
-    args = utils.convert_inputs(*args)
+        # autocast to FP32 to maintain precision,
+        # or to avoid ops that do not support all-int args
+        args = cast_tensors(*args, promote=True)
+    if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND,) and np.issubdtype(
+        args[0].dtype, np.integer
+    ):
+        return args[0]
     return _elwise_apply(args, mode)
 
 
 def _matmul(inp1, inp2):
+    if amp._enabled:
+        compute_mode = "float32"
+        inp1, inp2 = cast_tensors(inp1, inp2)
+    else:
+        compute_mode = "default"
+        inp1, inp2 = convert_inputs(inp1, inp2)
     op = builtin.MatrixMul(
-        transposeA=False, transposeB=False, compute_mode="default", format="default"
+        transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
     )
-    inp1, inp2 = utils.convert_inputs(inp1, inp2)
     (result,) = apply(op, inp1, inp2)
     return result
 
 
 def _transpose(data, axes):
     op = builtin.Dimshuffle(axes)
-    (data,) = utils.convert_inputs(data)
+    (data,) = convert_inputs(data)
     (result,) = apply(op, data)
     return result
 
 
 def _broadcast(inp, shape):
-    shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device)
+    shape = astensor1d(shape, inp, dtype="int32", device=inp.device)
     (result,) = apply(builtin.Broadcast(), inp, shape)
     return result
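With this change, `_elwise` promotes inputs in one place: when amp is enabled, or when every input is an integer tensor and the op needs floating point, `cast_tensors(*args, promote=True)` lifts everything to the high-precision dtype (float32 by default). A sketch of the visible effect with amp disabled:

    import numpy as np
    from megengine import tensor

    x = tensor(np.arange(1, 5, dtype="int32"))
    y = x / 2  # TRUE_DIV on all-integer inputs is promoted before the op runs
    assert y.dtype == np.float32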
@@ -88,7 +105,7 @@ def _broadcast(inp, shape):
 def _reshape(x, shape):
     unspec_axis = None
     try:
-        shape_tuple = _make_shape_tuple(shape)
+        shape_tuple = make_shape_tuple(shape)
     except ValueError:
         pass
     else:
@@ -102,7 +119,7 @@ def _reshape(x, shape):
                     "multiple -1 in shape: {} & {}".format(unspec_axis, i)
                 )
             unspec_axis = i
-    shape = utils.astensor1d(shape, x, dtype="int32", device=x.device)
+    shape = astensor1d(shape, x, dtype="int32", device=x.device)
     if unspec_axis is None:
         op = builtin.Reshape()
     else:
@@ -171,7 +188,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
         return list(map(int, axis))
 
     axis = get_axes()
-    axis = utils._normalize_axis(inp.ndim, axis)
+    axis = _normalize_axis(inp.ndim, axis)
     axis = [a - i for i, a in enumerate(axis)]
 
     op = builtin.RemoveAxis(axis=axis)
@@ -184,7 +201,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
 def _reduce(mode):
     def f(self, axis=None, keepdims: bool = False):
         data = self
-        (data,) = utils.convert_inputs(data)
+        (data,) = convert_inputs(data)
         if mode == "mean":
             data = data.astype("float32")
         elif self.dtype == np.bool_:
@@ -196,7 +213,7 @@ def _reduce(mode):
             op = builtin.Reduce(mode=mode, axis=0)
             (result,) = apply(op, data)
         elif isinstance(axis, collections.abc.Iterable):
-            axis = utils._normalize_axis(self.ndim, axis, reverse=True)
+            axis = _normalize_axis(self.ndim, axis, reverse=True)
             for ai in axis:
                 op = builtin.Reduce(mode=mode, axis=ai)
                 (data,) = apply(op, data)
@@ -359,11 +376,11 @@ class ArrayMethodMixin(abc.ABC):
             yield self[i]
 
     def __getitem__(self, index):
-        return _getitem(self, index)
+        return getitem(self, index)
 
     def __setitem__(self, index, value):
         if index is not Ellipsis:
-            value = _setitem(self, index, value)
+            value = setitem(self, index, value)
         self._reset(value)
 
     __contains__ = _todo
@@ -422,7 +439,7 @@ class ArrayMethodMixin(abc.ABC):
         Returns a :class:`Tensor` with the same data and number of elements
         with the specified :attr:`~.Tensor.dtype`.
         """
-        return utils.astype(self, dtype)
+        return astype(self, dtype)
 
     def reshape(self, *args):
         r"""
@@ -18,7 +18,7 @@ import numpy as np
 from .. import _imperative_rt
 from .._imperative_rt import GraphOptimizeOptions
 from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
-from .._wrap import device as as_device
+from .._wrap import as_device
 from ..ops.builtin import OpDef
 from .core import TensorBase
@@ -13,9 +13,10 @@ import numpy as np
 from .._imperative_rt import make_const
 from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device
-from .._wrap import device as as_device
+from .._wrap import as_device
 from ..ops import builtin
 from ..ops.special import Const
+from .amp import _high_prec_dtype, _low_prec_dtype
 from .dtype import is_dtype_equal, is_quantize
 
 _enable_convert_inputs = True
@@ -98,6 +99,14 @@ def convert_inputs(*args, device=None):
     return tuple(map(convert, args))
 
 
+def cast_tensors(*args, promote=False):
+    if promote:
+        dtype = _high_prec_dtype
+    else:
+        dtype = _low_prec_dtype
+    return tuple(arg.astype(dtype) if arg is not None else None for arg in args)
+
+
 def result_type(*args):
     dtypes = []
     for i in args:
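`cast_tensors` is the small helper the rest of the patch leans on: the default casts to the low-precision dtype (float16 unless reconfigured), `promote=True` casts to the high-precision dtype, and `None` entries (e.g. an absent bias) pass through untouched. A sketch under the default settings:

    from megengine import tensor
    from megengine.core.tensor.utils import cast_tensors

    a = tensor([1.0, 2.0])                     # float32
    low, none = cast_tensors(a, None)          # float16 tensor, None preserved
    (high,) = cast_tensors(a, promote=True)    # stays float32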
@@ -12,10 +12,8 @@ import numpy as np
 from ..core._imperative_rt.core2 import SymbolVar, apply
 from ..core.ops import builtin
 from ..core.ops.builtin import Elemwise
-from ..core.tensor import utils
-from ..core.tensor.array_method import _elwise_apply
-from ..core.tensor.utils import astype
-from ..device import get_default_device
+from ..core.tensor.array_method import _elwise
+from ..core.tensor.utils import astype, convert_inputs
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func
@@ -69,46 +67,9 @@ __all__ = [
 ]
 
 
-def _elwise(*args, mode):
-    tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args))
-    if len(tensor_args) == 0:
-        dtype = utils.dtype_promotion(args)
-        first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
-        args = utils.convert_inputs(first_arg, *args[1:])
-    else:
-        args = utils.convert_inputs(*args)
-    if mode in (
-        Elemwise.Mode.TRUE_DIV,
-        Elemwise.Mode.EXP,
-        Elemwise.Mode.POW,
-        Elemwise.Mode.LOG,
-        Elemwise.Mode.EXPM1,
-        Elemwise.Mode.LOG1P,
-        Elemwise.Mode.TANH,
-        Elemwise.Mode.ACOS,
-        Elemwise.Mode.ASIN,
-        Elemwise.Mode.ATAN2,
-        Elemwise.Mode.CEIL,
-        Elemwise.Mode.COS,
-        Elemwise.Mode.FLOOR,
-        Elemwise.Mode.H_SWISH,
-        Elemwise.Mode.ROUND,
-        Elemwise.Mode.SIGMOID,
-        Elemwise.Mode.SIN,
-    ):
-        if mode in (
-            Elemwise.Mode.CEIL,
-            Elemwise.Mode.FLOOR,
-            Elemwise.Mode.ROUND,
-        ) and np.issubdtype(args[0].dtype, np.integer):
-            return args[0]
-        args = tuple(map(lambda x: astype(x, "float32"), args))
-    return _elwise_apply(args, mode)
-
-
 def _elemwise_multi_type(*args, mode, **kwargs):
     op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
-    args = utils.convert_inputs(*args)
+    args = convert_inputs(*args)
     (result,) = apply(op, *args)
     return result
@@ -14,7 +14,8 @@ from ..core._imperative_rt.core2 import apply
 from ..core._trace_option import use_symbolic_shape
 from ..core.ops import builtin
 from ..core.ops.special import Const
-from ..core.tensor import utils
+from ..core.tensor import amp
+from ..core.tensor.utils import _normalize_axis, cast_tensors, convert_inputs, setscalar
 from ..tensor import Tensor
 from .debug_param import get_execution_strategy
 from .elemwise import clip, exp, log, log1p
@@ -471,7 +472,7 @@ def argmin(
         inp = inp.flatten()
         axis = 0
 
-    axis = utils._normalize_axis(inp.ndim, axis, reverse=True)
+    axis = _normalize_axis(inp.ndim, axis, reverse=True)
     if isinstance(axis, collections.abc.Iterable):
         for ai in axis:
@@ -528,7 +529,7 @@ def argmax(
         assert not keepdims, "can not set axis=None and keepdims=True"
         inp = inp.flatten()
         axis = 0
 
-    axis = utils._normalize_axis(inp.ndim, axis, reverse=True)
+    axis = _normalize_axis(inp.ndim, axis, reverse=True)
     if isinstance(axis, collections.abc.Iterable):
@@ -807,8 +808,13 @@ def matmul(
          [28. 40.]]
 
     """
+    if amp._enabled:
+        compute_mode = "float32"
+        inp1, inp2 = cast_tensors(inp1, inp2)
+    else:
+        inp1, inp2 = convert_inputs(inp1, inp2)
+
     remove_row, remove_col = False, False
-    inp1, inp2 = utils.convert_inputs(inp1, inp2)
 
     dim1, dim2 = inp1.ndim, inp2.ndim
     # handle dim=1 cases, dot and matrix-vector multiplication
@@ -921,12 +927,12 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
 
     """
     op = builtin.Dot()
-    inp1, inp2 = utils.convert_inputs(inp1, inp2)
+    inp1, inp2 = convert_inputs(inp1, inp2)
     assert (
         inp1.ndim <= 1 and inp2.ndim <= 1
     ), "Input tensors for dot must be 1-dimensional or scalar"
     (result,) = apply(op, inp1, inp2)
-    utils.setscalar(result)
+    setscalar(result)
     return result
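Under amp, `matmul` therefore computes with float16 operands but accumulates in float32 via `compute_mode="float32"`, and the result comes back as float16. A sketch of the expected dtypes:

    import numpy as np
    import megengine.functional as F
    from megengine import amp, tensor

    a = tensor(np.random.randn(4, 8), dtype=np.float32)
    b = tensor(np.random.randn(8, 2), dtype=np.float32)
    with amp.autocast():
        out = F.matmul(a, b)  # operands cast to float16, fp32 accumulation
    assert out.dtype == np.float16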
@@ -15,9 +15,16 @@ from ..core._trace_option import use_symbolic_shape
 from ..core.ops import builtin
 from ..core.ops.builtin import BatchNorm, Elemwise
 from ..core.ops.special import Const
-from ..core.tensor import megbrain_graph, utils
+from ..core.tensor import amp, megbrain_graph
 from ..core.tensor.array_method import _elwise_apply
-from ..core.tensor.utils import astensor1d, astype, setscalar
+from ..core.tensor.utils import (
+    astensor1d,
+    astype,
+    cast_tensors,
+    convert_inputs,
+    convert_single_value,
+    setscalar,
+)
 from ..device import get_default_device
 from ..distributed import WORLD, is_distributed
 from ..random import uniform
@@ -91,7 +98,9 @@ def expand_hw(x):
     return int(h), int(w)
 
 
-def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
+def linear(
+    inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None, compute_mode="default",
+) -> Tensor:
     """
     Applies a linear transformation to the input tensor.
 
@@ -102,8 +111,10 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor
     :param bias: bias with shape `(out_features,)`.
         Default: None
     """
-    ret = matmul(inp, weight, transpose_b=True)
+    ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
     if bias is not None:
+        if amp._enabled:
+            bias = bias.astype("float16")
         ret += bias
     return ret
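The extra `bias.astype("float16")` keeps the elementwise add consistent with the float16 output of the autocast matmul (note the dtype is hardcoded rather than read from `amp._low_prec_dtype`). A sketch of the resulting contract, assuming the default dtypes:

    import numpy as np
    import megengine.functional as F
    from megengine import amp, tensor

    x = tensor(np.random.randn(2, 8), dtype=np.float32)
    w = tensor(np.random.randn(4, 8), dtype=np.float32)
    b = tensor(np.zeros(4), dtype=np.float32)
    with amp.autocast():
        y = F.nn.linear(x, w, b)  # matmul in fp16, bias cast to fp16 before the add
    assert y.dtype == np.float16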
@@ -153,6 +164,11 @@ def conv1d(
     assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
     assert inp.ndim == 3, "the input dimension of conv1d should be 3"
     assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
+    if amp._enabled:
+        compute_mode = "float32"
+        inp, weight, bias = cast_tensors(inp, weight, bias)
+    else:
+        inp, weight = convert_inputs(inp, weight)
 
     inp = expand_dims(inp, 3)
     weight = expand_dims(weight, 3)
@@ -177,7 +193,6 @@ def conv1d(
         compute_mode=compute_mode,
         sparse=sparse_type,
     )
-    inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias
@@ -228,7 +243,11 @@ def conv2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
+    if amp._enabled:
+        compute_mode = "float32"
+        inp, weight, bias = cast_tensors(inp, weight, bias)
+    else:
+        inp, weight = convert_inputs(inp, weight)
 
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -247,7 +266,6 @@ def conv2d(
         compute_mode=compute_mode,
         sparse=sparse_type,
     )
-    inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias
@@ -286,6 +304,7 @@ def conv3d(
     :return: output tensor.
     """
     assert conv_mode.lower() == "cross_correlation"
+    inp, weight = convert_inputs(inp, weight)
 
     D, H, W = 0, 1, 2
 
@@ -308,7 +327,6 @@ def conv3d(
         mode=conv_mode,
         sparse=sparse_type,
     )
-    inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias
@@ -358,7 +376,11 @@ def conv_transpose2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
+    if amp._enabled:
+        compute_mode = "float32"
+        inp, weight, bias = cast_tensors(inp, weight, bias)
+    else:
+        inp, weight = convert_inputs(inp, weight)
 
     if groups != 1:
         raise NotImplementedError("group transposed conv2d is not supported yet.")
@@ -375,8 +397,8 @@ def conv_transpose2d(
         dilate_h=dilate_h,
         dilate_w=dilate_w,
         strategy=get_execution_strategy(),
+        compute_mode=compute_mode,
     )
-    weight, inp = utils.convert_inputs(weight, inp)
     (output,) = apply(op, weight, inp)
     if bias is not None:
         output += bias
@@ -428,7 +450,11 @@ def deformable_conv2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
+    if amp._enabled:
+        compute_mode = "float32"
+        inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
+    else:
+        inp, weight, offset, mask = convert_inputs(inp, weight, offset, mask)
 
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -447,7 +473,6 @@ def deformable_conv2d(
         compute_mode=compute_mode,
         sparse=sparse_type,
     )
-    inp, weight, offset, mask = utils.convert_inputs(inp, weight, offset, mask)
     (output,) = apply(op, inp, weight, offset, mask)
     if bias is not None:
         output += bias
@@ -468,6 +493,7 @@ def local_conv2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
+    inp, weight = convert_inputs(inp, weight)
 
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -481,10 +507,8 @@ def local_conv2d(
         dilate_h=dilate_h,
         dilate_w=dilate_w,
         mode=conv_mode,
-        compute_mode="default",
         sparse="dense",
     )
-    inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias
@@ -515,8 +539,9 @@ def conv_transpose3d(
     :param dilation: dilation of the 3D convolution operation. Default: 1
     :return: output tensor.
     """
+    inp, weight = convert_inputs(inp, weight)
 
-    D, H, W = 0, 1, 2
+    D, H, W = 0, 1, 2
     pad = _triple(padding)
     stride = _triple_nonzero(stride)
     dilate = _triple_nonzero(dilation)
@@ -533,7 +558,6 @@ def conv_transpose3d(
         dilate_w=dilate[W],
         strategy=get_execution_strategy(),
     )
-    weight, inp = utils.convert_inputs(weight, inp)
     (output,) = apply(op, weight, inp)
     if bias is not None:
         output += bias
@@ -994,7 +1018,8 @@ def batch_norm(
     training: bool = False,
     momentum: float = 0.9,
     eps: float = 1e-5,
-    inplace: bool = True
+    inplace: bool = True,
+    compute_mode="default",
 ):
     r"""
     Applies batch normalization to the input.
@@ -1027,15 +1052,11 @@ def batch_norm(
     def make_full_if_none(x, value):
         if x is None:
             (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
-            shape = utils.astensor1d(
-                (1, C, 1, 1), inp, dtype="int32", device=inp.device
-            )
+            shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Broadcast(), x, shape)
             return result
         elif x.ndim == 1:
-            shape = utils.astensor1d(
-                (1, C, 1, 1), inp, dtype="int32", device=inp.device
-            )
+            shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Reshape(), x, shape)
             return result
         return x
@@ -1052,10 +1073,15 @@ def batch_norm(
     if has_var and running_var.ndim != 4:
         raise ValueError
 
-    inp, weight, bias, running_mean, running_var = utils.convert_inputs(
-        inp, weight, bias, running_mean, running_var
-    )
+    if amp._enabled:
+        inp = inp.astype("float16")
+        weight, bias, running_mean, running_var = cast_tensors(
+            weight, bias, running_mean, running_var, promote=True
+        )
+    elif compute_mode != "float32":
+        inp, weight, bias, running_mean, running_var = convert_inputs(
+            inp, weight, bias, running_mean, running_var
+        )
 
     weight = make_full_if_none(weight, 1)
     bias = make_full_if_none(bias, 0)
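Note the asymmetry: under amp the input drops to float16 while weight, bias, and the running statistics are promoted to float32 (`promote=True`), following the usual mixed-precision recipe of keeping normalization statistics in high precision. A sketch of the expected dtypes, mirroring the test added later in this patch:

    import numpy as np
    import megengine.functional as F
    from megengine import amp, tensor

    inp = tensor(np.random.randn(2, 3, 8, 8), dtype=np.float32)
    weight = tensor(np.ones((1, 3, 1, 1)), dtype=np.float32)
    bias = tensor(np.zeros((1, 3, 1, 1)), dtype=np.float32)
    amp.enabled = True
    out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)
    amp.enabled = False
    assert out.dtype == np.float16  # fp16 activations, fp32 statistics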
@@ -1352,7 +1378,7 @@ def indexing_one_hot(
     """
     assert isinstance(src, Tensor), "src must be of Tensor type"
    op = builtin.IndexingOneHot(axis=axis)
-    index = utils.convert_single_value(index, dtype="int32", device=src.device)
+    index = convert_single_value(index, dtype="int32", device=src.device)
     (result,) = apply(op, src, index)
     if not keepdims:
         result = squeeze(result, axis)
@@ -13,7 +13,7 @@ import numpy as np
 from ..core._imperative_rt import CompNode
 from ..core._imperative_rt.core2 import SymbolVar, apply
-from ..core._wrap import device as as_device
+from ..core._wrap import as_device
 from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
 from ..core.ops.special import Const
@@ -33,7 +33,7 @@ from ..core._imperative_rt.ops import (
     RemoteSend,
 )
 from ..core._trace_option import set_symbolic_shape
-from ..core._wrap import device as as_device
+from ..core._wrap import as_device
 from ..core.ops.builtin import BatchNorm, OpDef
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G
@@ -26,6 +26,7 @@ class _BatchNorm(Module):
         affine=True,
         track_running_stats=True,
         freeze=False,
+        compute_mode="default",
         **kwargs
     ):
         super(_BatchNorm, self).__init__(**kwargs)
@@ -36,6 +37,7 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
+        self.compute_mode = compute_mode
         if self.freeze:
             assert (
                 self._track_running_stats_saved
@@ -123,6 +125,7 @@ class _BatchNorm(Module):
             or ((self.running_mean is None) and (self.running_var is None)),
             momentum=exponential_average_factor,
             eps=self.eps,
+            compute_mode=self.compute_mode,
         )
 
         if _ndims != 4:
@@ -51,7 +51,12 @@ class Linear(Module):
     """
 
     def __init__(
-        self, in_features: int, out_features: int, bias: bool = True, **kwargs
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        compute_mode: str = "default",
+        **kwargs
     ):
         super().__init__(**kwargs)
         self.out_features = out_features
@@ -62,6 +67,7 @@ class Linear(Module):
         if bias:
             b_shape = (out_features,)
             self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
+        self.compute_mode = compute_mode
         self.reset_parameters()
 
     def _get_fanin(self):
@@ -75,7 +81,7 @@ class Linear(Module):
             init.zeros_(self.bias)
 
     def _calc_linear(self, x, weight, bias):
-        return linear(x, weight, bias)
+        return linear(x, weight, bias, compute_mode=self.compute_mode)
 
     def forward(self, x):
         return self._calc_linear(x, self.weight, self.bias)
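At the module level, the new `compute_mode` argument simply threads through to `F.nn.linear`. A sketch of opting a single layer into float32 accumulation (meaningful when its inputs are float16, e.g. under autocast):

    import megengine.module as M

    fc = M.Linear(128, 64, compute_mode="float32")  # fp32 accumulation in the matmul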
@@ -5,8 +5,6 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import copy
-import math
 from functools import partial
 
 from .. import functional as F
@@ -14,7 +14,7 @@ from .core._imperative_rt import CompNode
 from .core._imperative_rt.core2 import Tensor as _Tensor
 from .core._imperative_rt.core2 import apply
 from .core._trace_option import use_symbolic_shape
-from .core._wrap import device as as_device
+from .core._wrap import as_device
 from .core.ops.builtin import Copy, GetVarShape
 from .core.tensor.array_method import ArrayMethodMixin
 from .device import _valid_device, get_default_device
@@ -0,0 +1,34 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from megengine import amp
+from megengine.core.tensor import amp as origin_amp
+
+
+def test_autocast():
+    def check(enabled, low, high):
+        assert amp.enabled == enabled
+        assert origin_amp._enabled == enabled
+        assert amp.low_prec_dtype == low
+        assert origin_amp._low_prec_dtype == low
+        assert amp.high_prec_dtype == high
+        assert origin_amp._high_prec_dtype == high
+
+    origin_enabled = amp.enabled
+    origin_high = amp.high_prec_dtype
+    origin_low = amp.low_prec_dtype
+    with amp.autocast(low_prec_dtype="low", high_prec_dtype="high"):
+        check(True, "low", "high")
+    check(origin_enabled, origin_low, origin_high)
+
+    amp.enabled = True
+    amp.high_prec_dtype = "high"
+    amp.low_prec_dtype = "low"
+    check(True, "low", "high")
+    amp.enabled = origin_enabled
+    amp.high_prec_dtype = origin_high
+    amp.low_prec_dtype = origin_low
+    check(origin_enabled, origin_low, origin_high)
@@ -14,6 +14,7 @@ import numpy as np
 import pytest
 from utils import opr_test
 
+import megengine.amp as amp
 import megengine.core.ops.builtin as builtin
 import megengine.core.tensor.dtype as dtype
 import megengine.functional as F
@@ -767,6 +768,27 @@ def test_batch_conv_bias():
     run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)
 
 
+def test_conv2d_io16c32():
+    amp.enabled = True
+    inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
+    weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float32)
+    out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
+    amp.enabled = False
+    expected = F.conv2d(
+        inp.astype("float16"),
+        weight.astype("float16"),
+        None,
+        (2, 2),
+        (3, 3),
+        (1, 1),
+        1,
+        compute_mode="float32",
+    )
+    assert out.dtype == np.float16
+    assert expected.dtype == np.float16
+    np.testing.assert_allclose(out.numpy(), expected.numpy())
+
+
 def test_conv2d_zero_stride_numpy_array():
     inp = np.random.randn(3, 224, 224).astype(np.float32)
     inp = inp[np.newaxis, :]
@@ -787,8 +809,8 @@ def test_conv3d_zero_stride_numpy_array():
 
 def test_conv1d():
-    inp = tensor(np.ones((16,), dtype=np.float32).reshape(2, 2, 4))
-    weight = tensor(np.ones((12,), dtype=np.float32).reshape(3, 2, 2))
+    inp = tensor(np.ones((2, 2, 4), dtype=np.float32))
+    weight = tensor(np.ones((3, 2, 2), dtype=np.float32))
     out = F.conv1d(inp, weight, None, 2, 0, 1, 1)
     np.testing.assert_equal(
         out.numpy(),
@@ -798,9 +820,31 @@ def test_conv1d():
     )
 
 
+def test_batchnorm2d_io16c32():
+    amp.enabled = True
+    inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
+    weight = tensor(np.ones((1, 3, 1, 1)), dtype=np.float32)
+    bias = tensor(np.zeros((1, 3, 1, 1)), dtype=np.float32)
+    out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)
+    amp.enabled = False
+    expected = F.batch_norm(
+        inp.astype("float16"),
+        weight=weight,
+        bias=bias,
+        training=True,
+        inplace=False,
+        compute_mode="float32",
+    )
+    assert out.dtype == np.float16
+    assert expected.dtype == np.float16
+    np.testing.assert_allclose(out.numpy(), expected.numpy())
+
+
 def test_conv3d():
-    inp = tensor(np.ones((256,), dtype=np.float32).reshape(2, 2, 4, 4, 4))
-    weight = tensor(np.ones((48,), dtype=np.float32).reshape(3, 2, 2, 2, 2))
+    inp = tensor(np.ones((2, 2, 4, 4, 4), dtype=np.float32))
+    weight = tensor(np.ones((3, 2, 2, 2, 2), dtype=np.float32))
     out = F.conv3d(inp, weight, None, 2, 0, 1, 1)
     print(out.numpy().shape)
     np.testing.assert_equal(
@@ -473,39 +473,6 @@ def test_pickle_module():
     np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)
 
 
-def test_load_quantized():
-    from megengine.core.tensor import dtype
-
-    data_shape = (2, 28)
-    data = tensor(np.random.random(data_shape), dtype="float32")
-    data = data.astype(dtype.qint8(0.1))
-    mlp = MLP()
-    quantize_qat(mlp)
-    quantize(mlp)
-    mlp.dense0.weight = Parameter(mlp.dense0.weight.astype(dtype.qint8(0.001)).numpy())
-    mlp.dense1.weight = Parameter(mlp.dense1.weight.astype(dtype.qint8(0.0002)).numpy())
-    mlp.eval()
-    pred0 = mlp(data)
-
-    with BytesIO() as fout:
-        mge.save(mlp.state_dict(), fout)
-        fout.seek(0)
-        checkpoint = mge.load(fout)
-        # change mlp weight.
-        mlp.dense0.weight = Parameter(
-            mlp.dense0.weight.astype(dtype.qint8(0.00001)).numpy()
-        )
-        mlp.dense1.weight = Parameter(
-            mlp.dense1.weight.astype(dtype.qint8(0.2)).numpy()
-        )
-        mlp.load_state_dict(checkpoint)
-        pred1 = mlp(data)
-        np.testing.assert_allclose(
-            pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), atol=5e-6
-        )
-
-
 def test_repr_basic():
     # test whether __repr__ can output correct information
     class ConvModel(Module):