
linear.py 1.1 kB

from .. import linear as Float
from .module import QATModule


class Linear(Float.Linear, QATModule):
    r"""A :class:`~.QATModule` version of :class:`~.module.Linear`.

    Can be applied with :class:`~.Observer` and
    :class:`~.quantization.fake_quant.FakeQuantize`.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if set to ``False``, the layer will not learn an additive bias.
            Default: True
    """

    def forward(self, inp):
        # Fake-quantize the weight first, then the bias (whose quantization
        # depends on the input and the quantized weight), and finally the
        # output activation of the underlying linear computation.
        w_qat = self.apply_quant_weight(self.weight)
        b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
        return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))

    @classmethod
    def from_float_module(cls, float_module: Float.Linear):
        r"""Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
        # Build a QAT module of the same shape, then reuse the float module's
        # parameters so no training state is lost in the conversion.
        qmod = cls(
            float_module.in_features, float_module.out_features, name=float_module.name
        )
        qmod.weight = float_module.weight
        qmod.bias = float_module.bias
        return qmod
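
A minimal usage sketch, assuming this file lives at MegEngine's `megengine/module/qat/linear.py` and that the public import paths `megengine.module` and `megengine.module.qat` are available (both are assumptions drawn from the relative imports above, not stated in the file itself):

    # Hypothetical example: convert a float Linear to its QAT counterpart.
    import megengine.module as M
    import megengine.module.qat as QAT

    float_linear = M.Linear(4, 8)                             # plain float module
    qat_linear = QAT.Linear.from_float_module(float_linear)   # shares weight/bias

    # qat_linear.forward now routes through apply_quant_weight /
    # apply_quant_bias / apply_quant_activation, so attaching observers or
    # FakeQuantize (e.g. via a quantization config) enables QAT.

In practice the conversion is usually driven by a whole-network quantize-aware-training pass rather than called per-module; `from_float_module` exists so such a pass can swap each `Float.Linear` for this class while preserving its learned parameters.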