You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

batch_matmul_activation.py 2.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from typing import Tuple, Union
  9. import numpy as np
  10. from ... import module as Float
  11. from ...core.tensor import dtype
  12. from ...functional import expand_dims, squeeze
  13. from ...functional.quantized import batch_conv_bias_activation
  14. from ...tensor import Parameter
  15. from ..qat import batch_matmul_activation as QAT
  16. from .module import QuantizedModule
  17. class BatchMatMulActivation(Float.BatchMatMulActivation, QuantizedModule):
  18. r"""Quantized version of :class:`~.qat.BatchMatMulActivation`."""
  19. def __init__(
  20. self,
  21. batch: int,
  22. in_features: int,
  23. out_features: int,
  24. bias: bool = True,
  25. nonlinear_mode="identity",
  26. dtype=None,
  27. **kwargs
  28. ):
  29. super().__init__(batch, in_features, out_features, bias, **kwargs)
  30. self.output_dtype = dtype
  31. def calc_bmm_quantized(self, inp):
  32. inp_scale = dtype.get_scale(inp.dtype)
  33. w_scale = dtype.get_scale(self.weight.dtype)
  34. bias_scale = inp_scale * w_scale
  35. inp = expand_dims(inp, [-1])
  36. res = batch_conv_bias_activation(
  37. inp,
  38. self.weight,
  39. self.bias.astype(dtype.qint32(bias_scale)),
  40. dtype=self.output_dtype,
  41. stride=1,
  42. padding=0,
  43. dilation=1,
  44. groups=1,
  45. nonlinear_mode=self.nonlinear_mode,
  46. )
  47. return squeeze(res, -1)
  48. @classmethod
  49. def from_qat_module(cls, qat_module: QAT.BatchMatMulActivation):
  50. output_dtype = qat_module.get_activation_dtype()
  51. qbmm = cls(
  52. qat_module.batch,
  53. qat_module.in_features,
  54. qat_module.out_features,
  55. qat_module.bias is not None,
  56. dtype=output_dtype,
  57. name=qat_module.name,
  58. )
  59. weight = qat_module.weight.astype(qat_module.get_weight_dtype())
  60. weight = expand_dims(weight, [-1, -2])
  61. qbmm.weight = Parameter(weight.numpy(), name=qat_module.weight.name)
  62. if qat_module.bias is not None:
  63. bias = qat_module.bias.reshape((1, qbmm.out_features, 1, 1))
  64. qbmm.bias = Parameter(bias.numpy(), name=qat_module.bias.name)
  65. else:
  66. qbmm.bias = Parameter(
  67. np.zeros((1, qbmm.out_features, 1, 1), dtype=np.float32)
  68. )
  69. return qbmm
  70. def forward(self, inp):
  71. return self.calc_bmm_quantized(inp)