You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

batch_matmul_activation.py 2.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. from typing import Tuple, Union
  2. import numpy as np
  3. from ... import module as Float
  4. from ...core.tensor import dtype
  5. from ...functional import expand_dims, squeeze
  6. from ...functional.quantized import batch_conv_bias_activation
  7. from ...tensor import Parameter
  8. from ..qat import batch_matmul_activation as QAT
  9. from .module import QuantizedModule
  10. class BatchMatMulActivation(Float.BatchMatMulActivation, QuantizedModule):
  11. r"""Quantized version of :class:`~.qat.BatchMatMulActivation`."""
  12. def __init__(
  13. self,
  14. batch: int,
  15. in_features: int,
  16. out_features: int,
  17. bias: bool = True,
  18. nonlinear_mode="identity",
  19. dtype=None,
  20. **kwargs
  21. ):
  22. super().__init__(batch, in_features, out_features, bias, **kwargs)
  23. self.output_dtype = dtype
  24. def calc_bmm_quantized(self, inp):
  25. inp_scale = dtype.get_scale(inp.dtype)
  26. w_scale = dtype.get_scale(self.weight.dtype)
  27. bias_scale = inp_scale * w_scale
  28. inp = expand_dims(inp, [-1])
  29. res = batch_conv_bias_activation(
  30. inp,
  31. self.weight,
  32. self.bias.astype(dtype.qint32(bias_scale)),
  33. dtype=self.output_dtype,
  34. stride=1,
  35. padding=0,
  36. dilation=1,
  37. groups=1,
  38. nonlinear_mode=self.nonlinear_mode,
  39. )
  40. return squeeze(res, -1)
  41. @classmethod
  42. def from_qat_module(cls, qat_module: QAT.BatchMatMulActivation):
  43. output_dtype = qat_module.get_activation_dtype()
  44. qbmm = cls(
  45. qat_module.batch,
  46. qat_module.in_features,
  47. qat_module.out_features,
  48. qat_module.bias is not None,
  49. dtype=output_dtype,
  50. name=qat_module.name,
  51. )
  52. weight = qat_module.weight.astype(qat_module.get_weight_dtype())
  53. weight = expand_dims(weight, [-1, -2])
  54. qbmm.weight = Parameter(weight.numpy(), name=qat_module.weight.name)
  55. if qat_module.bias is not None:
  56. bias = qat_module.bias.reshape((1, qbmm.out_features, 1, 1))
  57. qbmm.bias = Parameter(bias.numpy(), name=qat_module.bias.name)
  58. else:
  59. qbmm.bias = Parameter(
  60. np.zeros((1, qbmm.out_features, 1, 1), dtype=np.float32)
  61. )
  62. return qbmm
  63. def forward(self, inp):
  64. return self.calc_bmm_quantized(inp)