|
|
@@ -61,14 +61,14 @@ class _FakeQuantize(Module): |
|
|
|
def fake_quant_forward(self, inp, qparams: QParams = None): |
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
def normal_foward(self, inp, qparams: QParams = None): |
|
|
|
def normal_forward(self, inp, qparams: QParams = None): |
|
|
|
return inp |
|
|
|
|
|
|
|
def forward(self, inp, qparams: QParams = None): |
|
|
|
if self.enabled: |
|
|
|
return self.fake_quant_forward(inp, qparams=qparams) |
|
|
|
else: |
|
|
|
return self.normal_foward(inp, qparams=qparams) |
|
|
|
return self.normal_forward(inp, qparams=qparams) |
|
|
|
|
|
|
|
|
|
|
|
class TQT(_FakeQuantize, QParamsModuleMixin): |
|
|
|