GitOrigin-RevId: f8511f72ad
tags/v1.3.0
@@ -6,12 +6,8 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import math | import math | ||||
from typing import Iterable | |||||
import numpy as np | |||||
from .. import functional as F | from .. import functional as F | ||||
from ..autodiff import Function | |||||
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype | from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype | ||||
from ..module import Module | from ..module import Module | ||||
from ..tensor import Parameter, Tensor | from ..tensor import Parameter, Tensor | ||||
@@ -72,20 +68,10 @@ class TQT(_FakeQuantize): | |||||
""" | """ | ||||
def __init__( | def __init__( | ||||
self, | |||||
q_dict, | |||||
dtype: str, | |||||
narrow_range: bool = False, | |||||
enable: bool = True, | |||||
**kwargs | |||||
self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs | |||||
): | ): | ||||
super().__init__(dtype, narrow_range, enable, **kwargs) | super().__init__(dtype, narrow_range, enable, **kwargs) | ||||
assert ( | |||||
q_dict["mode"] == QuantMode.SYMMERTIC | |||||
), "only symmetric quantization is supported by TQT" | |||||
if "scale" not in q_dict or q_dict["scale"] is None: | |||||
raise AssertionError("Can not get an initialized scale") | |||||
self.scale = Tensor(F.log(q_dict["scale"]) / math.log(2)) | |||||
self.scale = Parameter(0.0, dtype="float32") | |||||
def fake_quant_forward(self, inp, q_dict=None): | def fake_quant_forward(self, inp, q_dict=None): | ||||
# when enable, TQT will do fakequant forward, finetune the scale | # when enable, TQT will do fakequant forward, finetune the scale | ||||
@@ -93,14 +79,22 @@ class TQT(_FakeQuantize): | |||||
def get_qparams(self): | def get_qparams(self): | ||||
q_dict = get_qparam_dict(QuantMode.SYMMERTIC) | q_dict = get_qparam_dict(QuantMode.SYMMERTIC) | ||||
q_dict["scale"] = 2 ** self.scale | |||||
q_dict["scale"] = 2 ** self.scale.detach() | |||||
return q_dict | return q_dict | ||||
def set_qparams(self, q_dict): | |||||
assert ( | |||||
q_dict["mode"] == QuantMode.SYMMERTIC | |||||
), "only symmetric quantization is supported by TQT" | |||||
if "scale" not in q_dict or q_dict["scale"] is None: | |||||
raise AssertionError("Can not get an initialized scale") | |||||
self.scale._reset(F.log(q_dict["scale"]) / math.log(2)) | |||||
def get_dtype(self): | def get_dtype(self): | ||||
q_dict = self.get_qparams() | q_dict = self.get_qparams() | ||||
scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0] | |||||
scale = None if "scale" not in q_dict else q_dict["scale"].numpy() | |||||
zero_point = ( | zero_point = ( | ||||
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0] | |||||
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy() | |||||
) | ) | ||||
return get_quantized_dtype(self.dtype, scale, zero_point) | return get_quantized_dtype(self.dtype, scale, zero_point) | ||||
@@ -17,7 +17,7 @@ from ..distributed import WORLD, get_rank, is_distributed | |||||
from ..functional.distributed import all_reduce_max, all_reduce_min | from ..functional.distributed import all_reduce_max, all_reduce_min | ||||
from ..module import Module | from ..module import Module | ||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
from .utils import QuantMode, Round, get_qparam_dict | |||||
from .utils import QuantMode, get_qparam_dict | |||||
class Observer(Module): | class Observer(Module): | ||||
@@ -110,7 +110,7 @@ class MinMaxObserver(Observer): | |||||
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit | (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit | ||||
) | ) | ||||
# calculate zero_point | # calculate zero_point | ||||
q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"])) | |||||
q_dict["zero_point"] = self.qmin - F.round(min_val / q_dict["scale"]) | |||||
return q_dict | return q_dict | ||||
@@ -453,12 +453,10 @@ class PassiveObserver(Observer): | |||||
This class allows :attr:`scale` to be set directly. | This class allows :attr:`scale` to be set directly. | ||||
""" | """ | ||||
def __init__(self, q_dict, dtype: str, narrow_range: bool = False, **kwargs): | |||||
def __init__(self, dtype: str, narrow_range: bool = False, **kwargs): | |||||
super().__init__(dtype, narrow_range, **kwargs) | super().__init__(dtype, narrow_range, **kwargs) | ||||
self.q_dict = deepcopy(q_dict) | |||||
if "scale" not in q_dict or q_dict["scale"] is None: | |||||
raise AssertionError("Can not get an initialized scale") | |||||
self.orig_scale = q_dict["scale"].numpy() | |||||
self.q_dict = None | |||||
self.orig_scale = None | |||||
@property | @property | ||||
def scale(self): | def scale(self): | ||||
@@ -472,6 +470,12 @@ class PassiveObserver(Observer): | |||||
def get_qparams(self): | def get_qparams(self): | ||||
return self.q_dict | return self.q_dict | ||||
def set_qparams(self, q_dict): | |||||
self.q_dict = deepcopy(q_dict) | |||||
if "scale" not in q_dict or q_dict["scale"] is None: | |||||
raise AssertionError("Can not get an initialized scale") | |||||
self.orig_scale = q_dict["scale"].numpy() | |||||
def forward(self, x): | def forward(self, x): | ||||
r""" | r""" | ||||
Just return input because :attr:`q_dict` is set by :func:`~.apply_easy_quant`. | Just return input because :attr:`q_dict` is set by :func:`~.apply_easy_quant`. | ||||
@@ -152,7 +152,10 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True): | |||||
module = deepcopy(module) | module = deepcopy(module) | ||||
def safe_call(func, q_dict): | def safe_call(func, q_dict): | ||||
return func(q_dict=q_dict) if func is not None else None | |||||
inst = func() if func is not None else None | |||||
if inst is not None and getattr(inst, "set_qparams", None) is not None: | |||||
inst.set_qparams(q_dict) | |||||
return inst | |||||
for m in list(module._flatten(predicate=is_qat)): | for m in list(module._flatten(predicate=is_qat)): | ||||
if m.with_weight: | if m.with_weight: | ||||
@@ -41,8 +41,8 @@ def test_exponential_moving_average_observer(): | |||||
m = ExponentialMovingAverageObserver(momentum=t) | m = ExponentialMovingAverageObserver(momentum=t) | ||||
m(mge.tensor(x1, dtype=np.float32)) | m(mge.tensor(x1, dtype=np.float32)) | ||||
m(mge.tensor(x2, dtype=np.float32)) | m(mge.tensor(x2, dtype=np.float32)) | ||||
np.testing.assert_allclose(m.min_val.numpy(), expected_min) | |||||
np.testing.assert_allclose(m.max_val.numpy(), expected_max) | |||||
np.testing.assert_allclose(m.min_val.numpy(), expected_min, atol=1e-5) | |||||
np.testing.assert_allclose(m.max_val.numpy(), expected_max, atol=1e-5) | |||||
def test_histogram_observer(): | def test_histogram_observer(): | ||||
@@ -57,7 +57,8 @@ def test_histogram_observer(): | |||||
def test_passive_observer(): | def test_passive_observer(): | ||||
q_dict = {"scale": mge.tensor(1.0)} | q_dict = {"scale": mge.tensor(1.0)} | ||||
m = PassiveObserver(q_dict, "qint8") | |||||
m = PassiveObserver("qint8") | |||||
m.set_qparams(q_dict) | |||||
assert m.orig_scale == 1.0 | assert m.orig_scale == 1.0 | ||||
assert m.scale == 1.0 | assert m.scale == 1.0 | ||||
m.scale = 2.0 | m.scale = 2.0 | ||||