|
|
@@ -132,7 +132,7 @@ class LSQ(_FakeQuantize, QParamsModuleMixin): |
|
|
|
:param eps:a small value to avoid division by zero. Default: 1e-5 |
|
|
|
""" |
|
|
|
|
|
|
|
def init( |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
dtype: Union[str, QuantDtypeMeta], |
|
|
|
enable: bool = True, |
|
|
@@ -142,6 +142,9 @@ class LSQ(_FakeQuantize, QParamsModuleMixin): |
|
|
|
super().__init__(dtype=dtype, enable=enable, **kwargs) |
|
|
|
self.eps = Tensor(eps, dtype="float32") |
|
|
|
self.step_size = Parameter(1.0, dtype="float32") |
|
|
|
self.mode = None |
|
|
|
self.zero_point = Tensor(0.0, dtype="float32") |
|
|
|
self.grad_scale = Tensor(1.0, dtype="float32") |
|
|
|
|
|
|
|
def set_qparams(self, qparams: LSQParams): |
|
|
|
self.mode = qparams.mode |
|
|
|