You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

conv.py 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. from typing import Tuple, Union
  2. import numpy as np
  3. from ... import module as Float
  4. from ...core.tensor import dtype
  5. from ...functional.nn import conv_bias_activation, pad
  6. from ...functional.quantized import conv_transpose2d
  7. from ...tensor import Parameter
  8. from ..qat import conv as QAT
  9. from .module import QuantizedModule
  10. class Conv2d(Float.Conv2d, QuantizedModule):
  11. r"""Quantized version of :class:`~.qat.Conv2d`.
  12. Applies a 2D convolution over a quantized input tensor, used for inference only.
  13. The parameter is same with :class:`~.module.Conv2d`.
  14. """
  15. def __init__(
  16. self,
  17. in_channels: int,
  18. out_channels: int,
  19. kernel_size: Union[int, Tuple[int, int]],
  20. stride: Union[int, Tuple[int, int]] = 1,
  21. padding: Union[int, Tuple[int, int]] = 0,
  22. dilation: Union[int, Tuple[int, int]] = 1,
  23. groups: int = 1,
  24. conv_mode: str = "cross_correlation",
  25. compute_mode: str = "default",
  26. dtype=None,
  27. padding_mode: str = "zeros",
  28. **kwargs
  29. ):
  30. super().__init__(
  31. in_channels,
  32. out_channels,
  33. kernel_size,
  34. stride,
  35. padding,
  36. dilation,
  37. groups,
  38. True,
  39. conv_mode,
  40. compute_mode,
  41. padding_mode,
  42. )
  43. self.output_dtype = dtype
  44. def calc_conv_quantized(self, inp, nonlinear_mode="identity"):
  45. assert self.padding_mode in [
  46. "zeros",
  47. "reflect",
  48. "replicate",
  49. ]
  50. inp_scale = dtype.get_scale(inp.dtype)
  51. w_scale = dtype.get_scale(self.weight.dtype)
  52. bias_scale = inp_scale * w_scale
  53. if self.padding_mode != "zeros":
  54. return conv_bias_activation(
  55. pad(inp, self.get_pad_witdth(), self.padding_mode),
  56. self.weight,
  57. self.bias.astype(dtype.qint32(bias_scale)),
  58. self.output_dtype,
  59. self.stride,
  60. 0,
  61. self.dilation,
  62. self.groups,
  63. conv_mode=self.conv_mode,
  64. compute_mode=self.compute_mode,
  65. nonlinear_mode=nonlinear_mode,
  66. )
  67. return conv_bias_activation(
  68. inp,
  69. self.weight,
  70. self.bias.astype(dtype.qint32(bias_scale)),
  71. self.output_dtype,
  72. self.stride,
  73. self.padding,
  74. self.dilation,
  75. self.groups,
  76. conv_mode=self.conv_mode,
  77. compute_mode=self.compute_mode,
  78. nonlinear_mode=nonlinear_mode,
  79. )
  80. @classmethod
  81. def from_qat_module(cls, qat_module: QAT.Conv2d):
  82. r"""
  83. Return a :class:`~.QuantizedModule` instance converted from a
  84. :class:`~.QATModule` instance.
  85. """
  86. output_dtype = qat_module.get_activation_dtype()
  87. qconv = cls(
  88. qat_module.in_channels,
  89. qat_module.out_channels,
  90. qat_module.kernel_size,
  91. qat_module.stride,
  92. qat_module.padding,
  93. qat_module.dilation,
  94. qat_module.groups,
  95. dtype=output_dtype,
  96. padding_mode=qat_module.padding_mode,
  97. name=qat_module.name,
  98. )
  99. weight = qat_module.weight.astype(qat_module.get_weight_dtype())
  100. qconv.weight = Parameter(weight.numpy(), name=qat_module.weight.name)
  101. if qat_module.bias is not None:
  102. qconv.bias = Parameter(qat_module.bias.numpy(), name=qat_module.bias.name)
  103. else:
  104. qconv.bias = Parameter(
  105. np.zeros(qat_module._infer_bias_shape(), dtype=np.float32)
  106. )
  107. return qconv
  108. def forward(self, inp):
  109. return self.calc_conv_quantized(inp, nonlinear_mode="identity")
  110. class ConvRelu2d(Conv2d):
  111. r"""Quantized version of :class:`~.qat.ConvRelu2d`."""
  112. def forward(self, inp):
  113. return self.calc_conv_quantized(inp, nonlinear_mode="relu")
  114. class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
  115. r"""Quantized version of :class:`~.qat.ConvTranspose2d`.
  116. Applies a 2D transposed convolution over a quantized input tensor, used
  117. for inference only.
  118. The parameter is same with :class:`~.module.ConvTranspose2d` but dtype.
  119. Args:
  120. dtype: data type of the output, should be qint8.
  121. """
  122. def __init__(
  123. self,
  124. in_channels: int,
  125. out_channels: int,
  126. kernel_size: Union[int, Tuple[int, int]],
  127. stride: Union[int, Tuple[int, int]] = 1,
  128. padding: Union[int, Tuple[int, int]] = 0,
  129. dilation: Union[int, Tuple[int, int]] = 1,
  130. groups: int = 1,
  131. bias: bool = True,
  132. conv_mode: str = "cross_correlation",
  133. compute_mode: str = "default",
  134. dtype=None,
  135. **kwargs
  136. ):
  137. super().__init__(
  138. in_channels=in_channels,
  139. out_channels=out_channels,
  140. kernel_size=kernel_size,
  141. stride=stride,
  142. padding=padding,
  143. dilation=dilation,
  144. groups=groups,
  145. bias=bias,
  146. conv_mode=conv_mode,
  147. compute_mode=compute_mode,
  148. )
  149. self.output_dtype = dtype
  150. @classmethod
  151. def from_qat_module(cls, qat_module: QAT.ConvTranspose2d):
  152. r"""
  153. return a :class:`~.QuantizedModule` instance converted from a
  154. :class:`~.QATModule` instance.
  155. """
  156. output_dtype = qat_module.get_activation_dtype()
  157. qconv = cls(
  158. qat_module.in_channels,
  159. qat_module.out_channels,
  160. qat_module.kernel_size,
  161. qat_module.stride,
  162. qat_module.padding,
  163. qat_module.dilation,
  164. qat_module.groups,
  165. qat_module.bias is not None,
  166. qat_module.conv_mode,
  167. qat_module.compute_mode,
  168. dtype=output_dtype,
  169. name=qat_module.name,
  170. )
  171. weight = qat_module.weight.astype(qat_module.get_weight_dtype())
  172. qconv.weight = Parameter(weight.numpy(), name=qat_module.weight.name)
  173. qconv.bias = (
  174. Parameter(qat_module.bias.numpy(), name=qat_module.bias.name)
  175. if qat_module.bias is not None
  176. else None
  177. )
  178. return qconv
  179. def calc_conv_transpose2d_quantized(self, inp):
  180. if self.bias is not None:
  181. inp_scale = dtype.get_scale(inp.dtype)
  182. w_scale = dtype.get_scale(self.weight.dtype)
  183. bias_scale = inp_scale * w_scale
  184. return conv_transpose2d(
  185. inp=inp,
  186. weight=self.weight,
  187. bias=self.bias.astype(dtype.qint32(bias_scale))
  188. if self.bias is not None
  189. else None,
  190. dtype=self.output_dtype,
  191. stride=self.stride,
  192. padding=self.padding,
  193. dilation=self.dilation,
  194. groups=self.groups,
  195. conv_mode=self.conv_mode,
  196. compute_mode=self.compute_mode,
  197. )
  198. def forward(self, inp):
  199. return self.calc_conv_transpose2d_quantized(inp)