GitOrigin-RevId: 0a94cb6b17
release-0.5
@@ -235,6 +235,14 @@ class Tensor:
             return self.__val.dtype
         return self._symvar.dtype
 
+    def set_dtype(self, dtype: str = None):
+        r"""Set the data type of the tensor.
+        """
+        if self.__val is not None:
+            self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy())
+        elif self.__sym is not None:
+            self.__sym = self.__sym.astype(dtype)
+
     @property
     def _comp_node(self):
         if self.__val is not None:
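A rough usage sketch of the new set_dtype (not part of the patch; the values mirror the test_tensor_set_dtype test added below, and mgb.dtype.qint8 / get_scale / is_quantize are the internal helpers that test uses):

    import numpy as np
    import megengine as mge
    import megengine._internal as mgb

    # A float32 parameter of ones, re-interpreted as qint8 with scale 0.1.
    p = mge.Parameter(np.ones((3, 4), dtype="float32"))
    p.set_dtype(mgb.dtype.qint8(0.1))

    assert mgb.dtype.is_quantize(p.dtype)
    assert abs(mgb.dtype.get_scale(p.dtype) - 0.1) < 1e-5
    # numpy() returns the quantized integer representation: round(1.0 / 0.1) == 10.
    assert p.numpy()[0][0] == 10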
@@ -11,6 +11,7 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union
 import numpy as np
 
+from .._internal.dtype import is_quantize
 from ..core import Buffer, Parameter, Tensor
 from ..logger import get_logger
@@ -460,6 +461,10 @@ class Module(metaclass=ABCMeta):
             ), "param `{}` shape mismatch, should be {}, get {}".format(
                 k, var.shape, to_be_load.shape
             )
+            # For a quantized dtype, the scale/zero_point of the initialized
+            # dtype may be invalid, so use the pretrained dtype instead.
+            if is_quantize(to_be_load.dtype) and is_quantize(var.dtype):
+                var.set_dtype(to_be_load.dtype)
             var.set_value(to_be_load)
             loaded.append(k)
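A minimal sketch of the scenario this guards against (not part of the patch; scales are hypothetical and the mgb.dtype helpers are the ones exercised in the tests below): a freshly built quantized parameter may carry a placeholder scale, while the checkpointed tensor was saved with the calibrated one, so the pretrained dtype must be adopted before copying values.

    import numpy as np
    import megengine as mge
    import megengine._internal as mgb

    # Freshly constructed quantized parameter with a placeholder scale of 1.0 ...
    var = mge.Parameter(np.ones((2, 2), dtype=mgb.dtype.qint8(1.0)))
    # ... versus a checkpointed value saved with the calibrated scale 0.001.
    to_be_load = np.ones((2, 2), dtype=mgb.dtype.qint8(0.001))

    if mgb.dtype.is_quantize(to_be_load.dtype) and mgb.dtype.is_quantize(var.dtype):
        var.set_dtype(to_be_load.dtype)  # adopt the pretrained scale/zero_point first
    var.set_value(to_be_load)            # then copy the stored values

    assert abs(mgb.dtype.get_scale(var.dtype) - 0.001) < 1e-5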
@@ -10,6 +10,7 @@ import numpy as np
 import pytest
 
 import megengine as mge
+import megengine._internal as mgb
 
 
 def test_wrong_dtype():
@@ -26,3 +27,48 @@ def test_tensor_routine():
     mge.tensor([1])
 
     mge.tensor(1.5)
+
+
+def test_tensor_set_dtype():
+    def check_dtype_value(tensor, dtype_scale, value):
+        if mgb.dtype.is_quantize(tensor.dtype):
+            if np.abs(mgb.dtype.get_scale(tensor.dtype) - dtype_scale) > 1e-5:
+                raise AssertionError(
+                    "compare scale failed, expect {} got {}".format(
+                        dtype_scale, mgb.dtype.get_scale(tensor.dtype)
+                    )
+                )
+        if np.abs(tensor.numpy()[0][0] - value) > 1e-5:
+            raise AssertionError(
+                "compare value failed, expect {} got {}".format(
+                    value, tensor.numpy()[0][0]
+                )
+            )
+
+    t = mge.Parameter(np.ones((3, 4), dtype="float32"))
+    t.set_dtype(mgb.dtype.qint8(0.1))
+    check_dtype_value(t, 0.1, 10)
+
+    t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
+    t.set_dtype(mgb.dtype.qint8(0.3))
+    check_dtype_value(t, 0.3, 3)
+
+    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
+    t.set_dtype(mgb.dtype.qint8(0.1))
+    check_dtype_value(t, 0.1, 10)
+
+    t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
+    t.set_dtype(mgb.dtype.qint8(0.3))
+    check_dtype_value(t, 0.3, 3)
+
+    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
+    s = t + 1
+    s.set_dtype(mgb.dtype.qint8(0.2))
+    check_dtype_value(s, 0.2, 10)
+
+    t.set_dtype(mgb.dtype.qint8(0.3))
+    s = t + 1
+    s.set_dtype(mgb.dtype.qint8(0.1))
+    check_dtype_value(s, 0.1, 18)
+
+    s.set_dtype("float32")
+    check_dtype_value(s, 0, 1.8)
@@ -14,8 +14,10 @@ import pytest
 from helpers import MLP
 
 import megengine as mge
+import megengine._internal as mgb
 from megengine.core import Buffer, Parameter, Tensor, tensor
 from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential
+from megengine.quantization.quantize import quantize, quantize_qat
 from megengine.test import assertTensorClose
@@ -347,3 +349,38 @@ def test_dump_model():
     pred = mlp(data)
     with tempfile.NamedTemporaryFile() as f:
         mge.dump(pred, f.name)
+
+
+def test_load_quantized():
+    data_shape = (2, 28)
+    data = tensor(np.random.random(data_shape), dtype="float32")
+    data = data.astype(mgb.dtype.qint8(0.1))
+    mlp = MLP()
+    quantize_qat(mlp)
+    quantize(mlp)
+    mlp.dense0.weight = Parameter(
+        mlp.dense0.weight.astype(mgb.dtype.qint8(0.001)).numpy()
+    )
+    mlp.dense1.weight = Parameter(
+        mlp.dense1.weight.astype(mgb.dtype.qint8(0.0002)).numpy()
+    )
+    mlp.eval()
+    pred0 = mlp(data)
+
+    with BytesIO() as fout:
+        mge.save(mlp.state_dict(), fout)
+        fout.seek(0)
+        checkpoint = mge.load(fout)
+        # change mlp weights before loading the checkpoint.
+        mlp.dense0.weight = Parameter(
+            mlp.dense0.weight.astype(mgb.dtype.qint8(0.00001)).numpy()
+        )
+        mlp.dense1.weight = Parameter(
+            mlp.dense1.weight.astype(mgb.dtype.qint8(0.2)).numpy()
+        )
+        mlp.load_state_dict(checkpoint)
+        pred1 = mlp(data)
+        assertTensorClose(
+            pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), max_err=5e-6
+        )