Browse Source

feat(mge/module): add quantize dtype load support for module load_state_dict

GitOrigin-RevId: 0a94cb6b17
release-0.5
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
3f2770fe03
4 changed files with 96 additions and 0 deletions
  1. +8
    -0
      python_module/megengine/core/tensor.py
  2. +5
    -0
      python_module/megengine/module/module.py
  3. +46
    -0
      python_module/test/unit/core/test_tensor.py
  4. +37
    -0
      python_module/test/unit/module/test_module.py

+ 8
- 0
python_module/megengine/core/tensor.py View File

@@ -235,6 +235,14 @@ class Tensor:
return self.__val.dtype
return self._symvar.dtype


def set_dtype(self, dtype: str = None):
    r"""Convert this tensor to *dtype* in place.

    Handles both backing representations: an eager value is converted
    and re-wrapped as a shared value on the same device, while a
    symbolic var is replaced by its ``astype`` result.
    """
    if self.__val is not None:
        converted = self.astype(dtype).numpy()
        self.__val = mgb.make_shared(self.device, value=converted)
    elif self.__sym is not None:
        self.__sym = self.__sym.astype(dtype)

@property
def _comp_node(self):
if self.__val is not None:


+ 5
- 0
python_module/megengine/module/module.py View File

@@ -11,6 +11,7 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union


import numpy as np


from .._internal.dtype import is_quantize
from ..core import Buffer, Parameter, Tensor
from ..logger import get_logger


@@ -460,6 +461,10 @@ class Module(metaclass=ABCMeta):
), "param `{}` shape mismatch, should be {}, get {}".format(
k, var.shape, to_be_load.shape
)
# For quantized dtype, the initialized dtype
# scale/zero_points maybe invalid, use pretrained dtype instead.
if is_quantize(to_be_load.dtype) and is_quantize(var.dtype):
var.set_dtype(to_be_load.dtype)
var.set_value(to_be_load)
loaded.append(k)




+ 46
- 0
python_module/test/unit/core/test_tensor.py View File

@@ -10,6 +10,7 @@ import numpy as np
import pytest


import megengine as mge
import megengine._internal as mgb




def test_wrong_dtype():
@@ -26,3 +27,48 @@ def test_tensor_routine():
mge.tensor([1])


mge.tensor(1.5)


def test_tensor_set_dtype():
    """Check that ``set_dtype`` converts eager and symbolic tensors in place.

    Covers Parameter/Buffer starting from float32 and from an already
    quantized dtype, plus symbolic results (``t + 1``) and a round-trip
    back to float32.
    """

    def check_dtype_value(tensor, dtype_scale, value):
        # For quantized dtypes, verify the stored per-tensor scale first.
        if mgb.dtype.is_quantize(tensor.dtype):
            if np.abs(mgb.dtype.get_scale(tensor.dtype) - dtype_scale) > 1e-5:
                raise AssertionError(
                    "compare scale failed expect {} got {}".format(
                        dtype_scale, mgb.dtype.get_scale(tensor.dtype)
                    )
                )
        if np.abs(tensor.numpy()[0][0] - value) > 1e-5:
            # BUGFIX: arguments were swapped — the expected value goes in
            # the "expect" slot and the actual tensor value in "got".
            raise AssertionError(
                "compare value failed expect {} got {}".format(
                    value, tensor.numpy()[0][0]
                )
            )

    t = mge.Parameter(np.ones((3, 4), dtype="float32"))
    t.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(t, 0.1, 10)

    t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
    t.set_dtype(mgb.dtype.qint8(0.3))
    check_dtype_value(t, 0.3, 3)

    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
    t.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(t, 0.1, 10)

    t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
    t.set_dtype(mgb.dtype.qint8(0.3))
    check_dtype_value(t, 0.3, 3)

    # Symbolic tensors: set_dtype on the result of an expression.
    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
    s = t + 1
    s.set_dtype(mgb.dtype.qint8(0.2))
    check_dtype_value(s, 0.2, 10)

    t.set_dtype(mgb.dtype.qint8(0.3))
    s = t + 1
    s.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(s, 0.1, 18)
    # Converting back to float32 recovers the dequantized value.
    s.set_dtype("float32")
    check_dtype_value(s, 0, 1.8)

+ 37
- 0
python_module/test/unit/module/test_module.py View File

@@ -14,8 +14,10 @@ import pytest
from helpers import MLP


import megengine as mge
import megengine._internal as mgb
from megengine.core import Buffer, Parameter, Tensor, tensor
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential
from megengine.quantization.quantize import quantize, quantize_qat
from megengine.test import assertTensorClose




@@ -347,3 +349,38 @@ def test_dump_model():
pred = mlp(data)
with tempfile.NamedTemporaryFile() as f:
mge.dump(pred, f.name)


def test_load_quantized():
    """load_state_dict on a quantized net must restore the saved weight
    dtypes (scales), so predictions match the pre-save model."""

    def rescale_weights(net, scale0, scale1):
        # Rebuild the two dense-layer weights under new qint8 scales.
        net.dense0.weight = Parameter(
            net.dense0.weight.astype(mgb.dtype.qint8(scale0)).numpy()
        )
        net.dense1.weight = Parameter(
            net.dense1.weight.astype(mgb.dtype.qint8(scale1)).numpy()
        )

    data_shape = (2, 28)
    data = tensor(np.random.random(data_shape), dtype="float32")
    data = data.astype(mgb.dtype.qint8(0.1))
    mlp = MLP()
    quantize_qat(mlp)
    quantize(mlp)
    rescale_weights(mlp, 0.001, 0.0002)
    mlp.eval()
    pred0 = mlp(data)

    with BytesIO() as fout:
        mge.save(mlp.state_dict(), fout)
        fout.seek(0)
        checkpoint = mge.load(fout)
        # Perturb the scales so that loading must overwrite them.
        rescale_weights(mlp, 0.00001, 0.2)
        mlp.load_state_dict(checkpoint)
        pred1 = mlp(data)

    assertTensorClose(
        pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), max_err=5e-6
    )

Loading…
Cancel
Save