|
|
@@ -111,10 +111,8 @@ class QParams: |
|
|
|
return "QParams({})".format(content) |
|
|
|
|
|
|
|
|
|
|
|
class LSQParams: |
|
|
|
r"""To standardize LSQ's qparams format. If custom |
|
|
|
qparams is needed, inherit this class and add custom ``__slots__``. |
|
|
|
""" |
|
|
|
class LSQParams(QParams): |
|
|
|
r"""LSQ qparams with extra grad_scale slot.""" |
|
|
|
|
|
|
|
__slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale" |
|
|
|
|
|
|
@@ -126,30 +124,9 @@ class LSQParams: |
|
|
|
zero_point: Tensor, |
|
|
|
grad_scale: Tensor, |
|
|
|
): |
|
|
|
self.mode = mode |
|
|
|
self.dtype_meta = dtype_meta |
|
|
|
self.scale = scale |
|
|
|
self.zero_point = zero_point |
|
|
|
super().__init__(mode, dtype_meta, scale, zero_point) |
|
|
|
self.grad_scale = grad_scale |
|
|
|
|
|
|
|
def update(self, lsqparams: "LSQParams"): |
|
|
|
for key in self.__slots__: |
|
|
|
setattr(self, key, getattr(lsqparams, key)) |
|
|
|
|
|
|
|
def __eq__(self, other): |
|
|
|
if len(self.__slots__) != len(other.__slots__): |
|
|
|
return False |
|
|
|
for key in self.__slots__: |
|
|
|
if not hasattr(other, key) or getattr(self, key) != getattr(other, key): |
|
|
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
def __repr__(self): |
|
|
|
content = ", ".join( |
|
|
|
["{}={}".format(key, getattr(self, key)) for key in self.__slots__] |
|
|
|
) |
|
|
|
return "LSQParams({})".format(content) |
|
|
|
|
|
|
|
|
|
|
|
class QParamsModuleMixin(abc.ABC): |
|
|
|
def get_quantized_dtype(self): |
|
|
|