Browse Source

feat(mge/module): add quantize dtype load support for module load_state_dict

GitOrigin-RevId: 0a94cb6b17
release-0.5
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
3f2770fe03
4 changed files with 96 additions and 0 deletions
  1. +8
    -0
      python_module/megengine/core/tensor.py
  2. +5
    -0
      python_module/megengine/module/module.py
  3. +46
    -0
      python_module/test/unit/core/test_tensor.py
  4. +37
    -0
      python_module/test/unit/module/test_module.py

+ 8
- 0
python_module/megengine/core/tensor.py View File

@@ -235,6 +235,14 @@ class Tensor:
return self.__val.dtype
return self._symvar.dtype

def set_dtype(self, dtype: str = None):
    r"""Convert the tensor's underlying storage to ``dtype`` in place.

    Handles both backings of a Tensor: an eager value is re-materialized
    as a new shared tensor on the same device after conversion, while a
    symbolic var is simply cast with ``astype``.

    :param dtype: target data type (may be a quantized dtype carrying
        scale/zero-point metadata).
    """
    if self.__val is not None:
        # Eager path: rebuild the shared storage from the converted
        # numpy value so the new dtype (and any quantization params)
        # sticks to the underlying buffer.
        converted = self.astype(dtype).numpy()
        self.__val = mgb.make_shared(self.device, value=converted)
    elif self.__sym is not None:
        # Symbolic path: cast the symvar directly.
        self.__sym = self.__sym.astype(dtype)

@property
def _comp_node(self):
if self.__val is not None:


+ 5
- 0
python_module/megengine/module/module.py View File

@@ -11,6 +11,7 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

import numpy as np

from .._internal.dtype import is_quantize
from ..core import Buffer, Parameter, Tensor
from ..logger import get_logger

@@ -460,6 +461,10 @@ class Module(metaclass=ABCMeta):
), "param `{}` shape mismatch, should be {}, get {}".format(
k, var.shape, to_be_load.shape
)
# For quantized dtypes, the freshly initialized dtype's
# scale/zero_point may be invalid, so use the pretrained dtype instead.
if is_quantize(to_be_load.dtype) and is_quantize(var.dtype):
var.set_dtype(to_be_load.dtype)
var.set_value(to_be_load)
loaded.append(k)



+ 46
- 0
python_module/test/unit/core/test_tensor.py View File

@@ -10,6 +10,7 @@ import numpy as np
import pytest

import megengine as mge
import megengine._internal as mgb


def test_wrong_dtype():
@@ -26,3 +27,48 @@ def test_tensor_routine():
mge.tensor([1])

mge.tensor(1.5)


def test_tensor_set_dtype():
    """``set_dtype`` must rescale stored values when converting between
    float and quantized dtypes, for Parameter, Buffer and derived symvars.
    """

    def check_dtype_value(tensor, dtype_scale, value):
        # Verify the quantization scale (when quantized) and the first
        # element's raw stored value, within a small float tolerance.
        if mgb.dtype.is_quantize(tensor.dtype):
            if np.abs(mgb.dtype.get_scale(tensor.dtype) - dtype_scale) > 1e-5:
                raise AssertionError(
                    "compare scale failed expect {} got {}".format(
                        dtype_scale, mgb.dtype.get_scale(tensor.dtype)
                    )
                )
        if np.abs(tensor.numpy()[0][0] - value) > 1e-5:
            # BUGFIX: the original message swapped the arguments,
            # reporting the actual value as "expect" and vice versa.
            raise AssertionError(
                "compare value failed expect {} got {}".format(
                    value, tensor.numpy()[0][0]
                )
            )

    # float32 -> qint8(0.1): stored int value becomes 1 / 0.1 = 10.
    t = mge.Parameter(np.ones((3, 4), dtype="float32"))
    t.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(t, 0.1, 10)

    # qint8(1) -> qint8(0.3): stored value rescales from 1 to ~3.
    t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
    t.set_dtype(mgb.dtype.qint8(0.3))
    check_dtype_value(t, 0.3, 3)

    # Same conversions must work on Buffer.
    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
    t.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(t, 0.1, 10)

    t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
    t.set_dtype(mgb.dtype.qint8(0.3))
    check_dtype_value(t, 0.3, 3)

    # set_dtype on a symbolic result (t + 1 == 2.0): 2 / 0.2 = 10.
    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
    s = t + 1
    s.set_dtype(mgb.dtype.qint8(0.2))
    check_dtype_value(s, 0.2, 10)

    # Quantize the source first: t becomes qint8(0.3) holding ~0.9,
    # so s = t + 1 == 1.9... -> under qint8(0.1) stored value is ~18,
    # and converting back to float32 yields ~1.8.
    t.set_dtype(mgb.dtype.qint8(0.3))
    s = t + 1
    s.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(s, 0.1, 18)
    s.set_dtype("float32")
    check_dtype_value(s, 0, 1.8)

+ 37
- 0
python_module/test/unit/module/test_module.py View File

@@ -14,8 +14,10 @@ import pytest
from helpers import MLP

import megengine as mge
import megengine._internal as mgb
from megengine.core import Buffer, Parameter, Tensor, tensor
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential
from megengine.quantization.quantize import quantize, quantize_qat
from megengine.test import assertTensorClose


@@ -347,3 +349,38 @@ def test_dump_model():
pred = mlp(data)
with tempfile.NamedTemporaryFile() as f:
mge.dump(pred, f.name)


def test_load_quantized():
    """load_state_dict must adopt the checkpoint's quantized dtypes
    (scales) so that predictions match the saved model even after the
    in-memory weights were re-quantized with different scales.
    """
    input_shape = (2, 28)
    inp = tensor(np.random.random(input_shape), dtype="float32")
    inp = inp.astype(mgb.dtype.qint8(0.1))

    # Build a quantized MLP and force specific weight scales.
    net = MLP()
    quantize_qat(net)
    quantize(net)
    net.dense0.weight = Parameter(
        net.dense0.weight.astype(mgb.dtype.qint8(0.001)).numpy()
    )
    net.dense1.weight = Parameter(
        net.dense1.weight.astype(mgb.dtype.qint8(0.0002)).numpy()
    )
    net.eval()
    expected = net(inp)

    with BytesIO() as fout:
        mge.save(net.state_dict(), fout)
        fout.seek(0)
        checkpoint = mge.load(fout)
        # Perturb the weight scales so only a dtype-aware load can
        # restore the original behavior.
        net.dense0.weight = Parameter(
            net.dense0.weight.astype(mgb.dtype.qint8(0.00001)).numpy()
        )
        net.dense1.weight = Parameter(
            net.dense1.weight.astype(mgb.dtype.qint8(0.2)).numpy()
        )
        net.load_state_dict(checkpoint)
        restored = net(inp)

        assertTensorClose(
            expected.astype("float32").numpy(),
            restored.astype("float32").numpy(),
            max_err=5e-6,
        )

Loading…
Cancel
Save