You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

quant_dequant.py 896 B

1234567891011121314151617181920212223242526272829303132
  1. from ..qat import quant_dequant as QAT
  2. from .module import QuantizedModule
  3. class QuantStub(QuantizedModule):
  4. r"""Quantized version of :class:`~.qat.QuantStub`,
  5. will convert input to quantized dtype.
  6. """
  7. def __init__(self, dtype=None, **kwargs):
  8. super().__init__(**kwargs)
  9. self.output_dtype = dtype
  10. def forward(self, inp):
  11. return inp.astype(self.output_dtype)
  12. @classmethod
  13. def from_qat_module(cls, qat_module: QAT.QuantStub):
  14. return cls(qat_module.get_activation_dtype(), name=qat_module.name)
  15. class DequantStub(QuantizedModule):
  16. r"""Quantized version of :class:`~.qat.DequantStub`,
  17. will restore quantized input to float32 dtype.
  18. """
  19. def forward(self, inp):
  20. return inp.astype("float32")
  21. @classmethod
  22. def from_qat_module(cls, qat_module: QAT.DequantStub):
  23. return cls(name=qat_module.name)