GitOrigin-RevId: f8511f72ad
tags/v1.3.0
@@ -6,12 +6,8 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from typing import Iterable
import numpy as np
from .. import functional as F
from ..autodiff import Function
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
from ..module import Module
from ..tensor import Parameter, Tensor
@@ -72,20 +68,10 @@ class TQT(_FakeQuantize):
"""
def __init__(
self,
q_dict,
dtype: str,
narrow_range: bool = False,
enable: bool = True,
**kwargs
self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs
):
super().__init__(dtype, narrow_range, enable, **kwargs)
assert (
q_dict["mode"] == QuantMode.SYMMERTIC
), "only symmetric quantization is supported by TQT"
if "scale" not in q_dict or q_dict["scale"] is None:
raise AssertionError("Can not get an initialized scale")
self.scale = Tensor(F.log(q_dict["scale"]) / math.log(2))
self.scale = Parameter(0.0, dtype="float32")
def fake_quant_forward(self, inp, q_dict=None):
# when enabled, TQT will do fake-quant forward and finetune the scale
@@ -93,14 +79,22 @@ class TQT(_FakeQuantize):
def get_qparams(self):
q_dict = get_qparam_dict(QuantMode.SYMMERTIC)
q_dict["scale"] = 2 ** self.scale
q_dict["scale"] = 2 ** self.scale.detach()
return q_dict
def set_qparams(self, q_dict):
assert (
q_dict["mode"] == QuantMode.SYMMERTIC
), "only symmetric quantization is supported by TQT"
if "scale" not in q_dict or q_dict["scale"] is None:
raise AssertionError("Can not get an initialized scale")
self.scale._reset(F.log(q_dict["scale"]) / math.log(2))
def get_dtype(self):
q_dict = self.get_qparams()
scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0]
scale = None if "scale" not in q_dict else q_dict["scale"].numpy()
zero_point = (
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0]
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()
)
return get_quantized_dtype(self.dtype, scale, zero_point)
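Note: a minimal usage sketch of the reworked TQT API above (the import paths and
the 0.5 scale are illustrative assumptions, not part of the diff):

    import megengine as mge
    from megengine.quantization.fake_quant import TQT      # assumed module path
    from megengine.quantization.utils import QuantMode

    tqt = TQT(dtype="qint8")           # scale now starts as Parameter(0.0)
    tqt.set_qparams({"mode": QuantMode.SYMMERTIC, "scale": mge.tensor(0.5)})
    print(tqt.get_qparams()["scale"])  # 2 ** log2(0.5) == 0.5, detached from the graph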
@@ -17,7 +17,7 @@ from ..distributed import WORLD, get_rank, is_distributed
from ..functional.distributed import all_reduce_max, all_reduce_min
from ..module import Module
from ..tensor import Tensor
from .utils import QuantMode, Round, get_qparam_dict
from .utils import QuantMode, get_qparam_dict
class Observer(Module):
@@ -110,7 +110,7 @@ class MinMaxObserver(Observer):
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
)
# calculate zero_point
q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"]))
q_dict["zero_point"] = self.qmin - F.round(min_val / q_dict["scale"])
return q_dict
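For intuition, a worked example with illustrative qint8 numbers (not from the diff):

    # qmin = -128, qmax = 127, observed min_val = 0.0, max_val = 1.0
    # scale      = (1.0 - 0.0) / (127 - (-128)) = 1 / 255 ≈ 0.00392
    # zero_point = qmin - round(min_val / scale) = -128 - 0 = -128
    # i.e. the observed minimum maps onto qmin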
@@ -453,12 +453,10 @@ class PassiveObserver(Observer):
The :attr:`scale` of this class can be set directly.
"""
def __init__(self, q_dict, dtype: str, narrow_range: bool = False, **kwargs):
def __init__(self, dtype: str, narrow_range: bool = False, **kwargs):
super().__init__(dtype, narrow_range, **kwargs)
self.q_dict = deepcopy(q_dict)
if "scale" not in q_dict or q_dict["scale"] is None:
raise AssertionError("Can not get an initialized scale")
self.orig_scale = q_dict["scale"].numpy()
self.q_dict = None
self.orig_scale = None
@property
def scale(self):
@@ -472,6 +470,12 @@ class PassiveObserver(Observer):
def get_qparams(self):
return self.q_dict
def set_qparams(self, q_dict):
self.q_dict = deepcopy(q_dict)
if "scale" not in q_dict or q_dict["scale"] is None:
raise AssertionError("Can not get an initialized scale")
self.orig_scale = q_dict["scale"].numpy()
def forward(self, x):
r"""
Just return the input, because :attr:`q_dict` is set by :func:`~.apply_easy_quant`.
@@ -152,7 +152,10 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
module = deepcopy(module)
def safe_call(func, q_dict):
return func(q_dict=q_dict) if func is not None else None
inst = func() if func is not None else None
if inst is not None and getattr(inst, "set_qparams", None) is not None:
inst.set_qparams(q_dict)
return inst
for m in list(module._flatten(predicate=is_qat)):
if m.with_weight:
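In effect, safe_call now instantiates the class first and pushes q_dict in through
set_qparams when the instance supports it. A standalone sketch of the equivalent
behavior (safe_call itself is nested inside reset_qconfig; the partial/PassiveObserver
usage here is illustrative):

    from functools import partial

    factory = partial(PassiveObserver, dtype="qint8")
    obs = safe_call(factory, {"scale": mge.tensor(1.0)})  # construct, then set_qparams
    assert obs.orig_scale == 1.0                          # q_dict injected post-construction
    # instances without set_qparams are returned as constructed; None stays None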
@@ -41,8 +41,8 @@ def test_exponential_moving_average_observer():
m = ExponentialMovingAverageObserver(momentum=t)
m(mge.tensor(x1, dtype=np.float32))
m(mge.tensor(x2, dtype=np.float32))
np.testing.assert_allclose(m.min_val.numpy(), expected_min)
np.testing.assert_allclose(m.max_val.numpy(), expected_max)
np.testing.assert_allclose(m.min_val.numpy(), expected_min, atol=1e-5)
np.testing.assert_allclose(m.max_val.numpy(), expected_max, atol=1e-5)
def test_histogram_observer():
@@ -57,7 +57,8 @@ def test_histogram_observer():
def test_passive_observer():
q_dict = {"scale": mge.tensor(1.0)}
m = PassiveObserver(q_dict, "qint8")
m = PassiveObserver("qint8")
m.set_qparams(q_dict)
assert m.orig_scale == 1.0
assert m.scale == 1.0
m.scale = 2.0