GitOrigin-RevId: 0a94cb6b17
release-0.5
@@ -235,6 +235,14 @@ class Tensor: | |||
return self.__val.dtype | |||
return self._symvar.dtype | |||
def set_dtype(self, dtype: str = None): | |||
r"""Set the data type of the tensor. | |||
""" | |||
if self.__val is not None: | |||
self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy()) | |||
elif self.__sym is not None: | |||
self.__sym = self.__sym.astype(dtype) | |||
@property | |||
def _comp_node(self): | |||
if self.__val is not None: | |||
@@ -11,6 +11,7 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||
import numpy as np | |||
from .._internal.dtype import is_quantize | |||
from ..core import Buffer, Parameter, Tensor | |||
from ..logger import get_logger | |||
@@ -460,6 +461,10 @@ class Module(metaclass=ABCMeta): | |||
), "param `{}` shape mismatch, should be {}, get {}".format( | |||
k, var.shape, to_be_load.shape | |||
) | |||
# For quantized dtype, the initialized dtype | |||
# scale/zero_points maybe invalid, use pretrained dtype instead. | |||
if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): | |||
var.set_dtype(to_be_load.dtype) | |||
var.set_value(to_be_load) | |||
loaded.append(k) | |||
@@ -10,6 +10,7 @@ import numpy as np | |||
import pytest | |||
import megengine as mge | |||
import megengine._internal as mgb | |||
def test_wrong_dtype(): | |||
@@ -26,3 +27,48 @@ def test_tensor_routine(): | |||
mge.tensor([1]) | |||
mge.tensor(1.5) | |||
def test_tensor_set_dtype(): | |||
def check_dtype_value(tensor, dtype_scale, value): | |||
if mgb.dtype.is_quantize(tensor.dtype): | |||
if np.abs(mgb.dtype.get_scale(tensor.dtype) - dtype_scale) > 1e-5: | |||
raise AssertionError( | |||
"compare scale failed expect {} got {}".format( | |||
dtype_scale, mgb.dtype.get_scale(tensor.dtype) | |||
) | |||
) | |||
if np.abs(tensor.numpy()[0][0] - value) > 1e-5: | |||
raise AssertionError( | |||
"compare value failed expect {} got {}".format( | |||
tensor.numpy()[0][0], value | |||
) | |||
) | |||
t = mge.Parameter(np.ones((3, 4), dtype="float32")) | |||
t.set_dtype(mgb.dtype.qint8(0.1)) | |||
check_dtype_value(t, 0.1, 10) | |||
t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1))) | |||
t.set_dtype(mgb.dtype.qint8(0.3)) | |||
check_dtype_value(t, 0.3, 3) | |||
t = mge.Buffer(np.ones((3, 4), dtype="float32")) | |||
t.set_dtype(mgb.dtype.qint8(0.1)) | |||
check_dtype_value(t, 0.1, 10) | |||
t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1))) | |||
t.set_dtype(mgb.dtype.qint8(0.3)) | |||
check_dtype_value(t, 0.3, 3) | |||
t = mge.Buffer(np.ones((3, 4), dtype="float32")) | |||
s = t + 1 | |||
s.set_dtype(mgb.dtype.qint8(0.2)) | |||
check_dtype_value(s, 0.2, 10) | |||
t.set_dtype(mgb.dtype.qint8(0.3)) | |||
s = t + 1 | |||
s.set_dtype(mgb.dtype.qint8(0.1)) | |||
check_dtype_value(s, 0.1, 18) | |||
s.set_dtype("float32") | |||
check_dtype_value(s, 0, 1.8) |
@@ -14,8 +14,10 @@ import pytest | |||
from helpers import MLP | |||
import megengine as mge | |||
import megengine._internal as mgb | |||
from megengine.core import Buffer, Parameter, Tensor, tensor | |||
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential | |||
from megengine.quantization.quantize import quantize, quantize_qat | |||
from megengine.test import assertTensorClose | |||
@@ -347,3 +349,38 @@ def test_dump_model(): | |||
pred = mlp(data) | |||
with tempfile.NamedTemporaryFile() as f: | |||
mge.dump(pred, f.name) | |||
def test_load_quantized(): | |||
data_shape = (2, 28) | |||
data = tensor(np.random.random(data_shape), dtype="float32") | |||
data = data.astype(mgb.dtype.qint8(0.1)) | |||
mlp = MLP() | |||
quantize_qat(mlp) | |||
quantize(mlp) | |||
mlp.dense0.weight = Parameter( | |||
mlp.dense0.weight.astype(mgb.dtype.qint8(0.001)).numpy() | |||
) | |||
mlp.dense1.weight = Parameter( | |||
mlp.dense1.weight.astype(mgb.dtype.qint8(0.0002)).numpy() | |||
) | |||
mlp.eval() | |||
pred0 = mlp(data) | |||
with BytesIO() as fout: | |||
mge.save(mlp.state_dict(), fout) | |||
fout.seek(0) | |||
checkpoint = mge.load(fout) | |||
# change mlp weight. | |||
mlp.dense0.weight = Parameter( | |||
mlp.dense0.weight.astype(mgb.dtype.qint8(0.00001)).numpy() | |||
) | |||
mlp.dense1.weight = Parameter( | |||
mlp.dense1.weight.astype(mgb.dtype.qint8(0.2)).numpy() | |||
) | |||
mlp.load_state_dict(checkpoint) | |||
pred1 = mlp(data) | |||
assertTensorClose( | |||
pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), max_err=5e-6 | |||
) |