
linear.py 1.6 kB

import numpy as np

from ... import functional as F
from ...core.tensor import dtype
from ...tensor import Parameter
from ..qat import linear as QAT
from .module import QuantizedModule


class Linear(QuantizedModule):
    r"""Quantized version of :class:`~.qat.Linear`."""

    def __init__(self, dtype: np.dtype = None, **kwargs):
        super().__init__(**kwargs)
        self.weight = None
        self.bias = None
        self.output_dtype = dtype

    def forward(self, inp):
        if self.training:
            raise ValueError("quantized module only support inference.")
        # The int32 bias scale must equal input_scale * weight_scale so the bias
        # can be added directly to the int32 accumulator of the matmul.
        inp_scale = dtype.get_scale(inp.dtype)
        w_scale = dtype.get_scale(self.weight.dtype)
        bias_dtype = dtype.qint32(inp_scale * w_scale)
        ret = F.nn.linear(
            inp,
            self.weight,
            None if self.bias is None else self.bias.astype(bias_dtype),
        )
        # Requantize the result to the activation dtype recorded at conversion time.
        ret = ret if self.output_dtype is None else ret.astype(self.output_dtype)
        return ret

    @classmethod
    def from_qat_module(cls, qat_module: QAT.Linear):
        r"""
        Return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.
        """
        output_dtype = qat_module.get_activation_dtype()
        qmod = cls(dtype=output_dtype, name=qat_module.name)
        # Quantize the QAT module's float weight to its target weight dtype.
        weight = qat_module.weight.astype(qat_module.get_weight_dtype())
        qmod.weight = Parameter(weight.numpy(), name=qat_module.weight.name)
        if qat_module.bias is not None:
            qmod.bias = Parameter(qat_module.bias.numpy(), name=qat_module.bias.name)
        return qmod
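
For context, below is a minimal sketch of how this class is typically reached in MegEngine's quantization workflow; the layer sizes and the ema_fakequant_qconfig preset are illustrative assumptions, not part of this file.

# Usage sketch (assumption: standard MegEngine QAT -> quantized conversion flow).
import megengine.module as M
import megengine.quantization as Q

float_net = M.Linear(128, 10)                                        # ordinary float module
qat_net = Q.quantize_qat(float_net, qconfig=Q.ema_fakequant_qconfig) # insert fake-quant observers
# ... calibrate / fine-tune qat_net here ...
quantized_net = Q.quantize(qat_net)  # conversion step that ends up calling Linear.from_qat_module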