You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

conv.py 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. from typing import Tuple, Union
  2. import numpy as np
  3. from ... import module as Float
  4. from ...core.tensor import dtype
  5. from ...functional.nn import conv_bias_activation, pad
  6. from ...functional.quantized import conv_transpose2d
  7. from ...tensor import Parameter
  8. from ..qat import conv as QAT
  9. from .module import QuantizedModule
  10. class Conv2d(Float.Conv2d, QuantizedModule):
  11. r"""Quantized version of :class:`~.qat.Conv2d`.
  12. Applies a 2D convolution over a quantized input tensor, used for inference only.
  13. The parameter is same with :class:`~.module.Conv2d`.
  14. """
  15. def __init__(
  16. self,
  17. in_channels: int,
  18. out_channels: int,
  19. kernel_size: Union[int, Tuple[int, int]],
  20. stride: Union[int, Tuple[int, int]] = 1,
  21. padding: Union[int, Tuple[int, int]] = 0,
  22. dilation: Union[int, Tuple[int, int]] = 1,
  23. groups: int = 1,
  24. conv_mode: str = "cross_correlation",
  25. compute_mode: str = "default",
  26. dtype=None,
  27. padding_mode: str = "zeros",
  28. **kwargs
  29. ):
  30. super().__init__(
  31. in_channels,
  32. out_channels,
  33. kernel_size,
  34. stride,
  35. padding,
  36. dilation,
  37. groups,
  38. True,
  39. conv_mode,
  40. compute_mode,
  41. padding_mode,
  42. )
  43. self.output_dtype = dtype
  44. def calc_conv_quantized(self, inp, nonlinear_mode="identity"):
  45. assert self.padding_mode in [
  46. "zeros",
  47. "reflect",
  48. "replicate",
  49. ]
  50. inp_scale = dtype.get_scale(inp.dtype)
  51. w_scale = dtype.get_scale(self.weight.dtype)
  52. bias_scale = inp_scale * w_scale
  53. if self.padding_mode != "zeros":
  54. return conv_bias_activation(
  55. pad(inp, self.get_pad_witdth(), self.padding_mode),
  56. self.weight,
  57. self.bias.astype(dtype.qint32(bias_scale)),
  58. self.output_dtype,
  59. self.stride,
  60. 0,
  61. self.dilation,
  62. self.groups,
  63. conv_mode=self.conv_mode,
  64. compute_mode=self.compute_mode,
  65. nonlinear_mode=nonlinear_mode,
  66. )
  67. return conv_bias_activation(
  68. inp,
  69. self.weight,
  70. self.bias.astype(dtype.qint32(bias_scale)),
  71. self.output_dtype,
  72. self.stride,
  73. self.padding,
  74. self.dilation,
  75. self.groups,
  76. conv_mode=self.conv_mode,
  77. compute_mode=self.compute_mode,
  78. nonlinear_mode=nonlinear_mode,
  79. )
  80. @classmethod
  81. def from_qat_module(cls, qat_module: QAT.Conv2d):
  82. r"""
  83. Return a :class:`~.QuantizedModule` instance converted from a
  84. :class:`~.QATModule` instance.
  85. """
  86. output_dtype = qat_module.get_activation_dtype()
  87. qconv = cls(
  88. qat_module.in_channels,
  89. qat_module.out_channels,
  90. qat_module.kernel_size,
  91. qat_module.stride,
  92. qat_module.padding,
  93. qat_module.dilation,
  94. qat_module.groups,
  95. dtype=output_dtype,
  96. padding_mode=qat_module.padding_mode,
  97. name=qat_module.name,
  98. )
  99. weight = qat_module.weight.astype(qat_module.get_weight_dtype())
  100. qconv.weight = Parameter(weight.numpy(), name=qat_module.weight.name)
  101. if qat_module.bias is not None:
  102. qconv.bias = Parameter(qat_module.bias.numpy(), name=qat_module.bias.name)
  103. else:
  104. qconv.bias = Parameter(
  105. np.zeros(qat_module._infer_bias_shape(), dtype=np.float32)
  106. )
  107. return qconv
  108. def forward(self, inp):
  109. return self.calc_conv_quantized(inp, nonlinear_mode="identity")
  110. class ConvRelu2d(Conv2d):
  111. r"""Quantized version of :class:`~.qat.ConvRelu2d`."""
  112. def forward(self, inp):
  113. return self.calc_conv_quantized(inp, nonlinear_mode="relu")
  114. class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
  115. r"""Quantized version of :class:`~.qat.ConvTranspose2d`.
  116. Applies a 2D transposed convolution over a quantized input tensor, used
  117. for inference only.
  118. The parameter is same with :class:`~.module.ConvTranspose2d` but dtype.
  119. Args:
  120. dtype: data type of the output, should be qint8.
  121. """
  122. output_padding = 0
  123. def __init__(
  124. self,
  125. in_channels: int,
  126. out_channels: int,
  127. kernel_size: Union[int, Tuple[int, int]],
  128. stride: Union[int, Tuple[int, int]] = 1,
  129. padding: Union[int, Tuple[int, int]] = 0,
  130. output_padding: Union[int, Tuple[int, int]] = 0,
  131. dilation: Union[int, Tuple[int, int]] = 1,
  132. groups: int = 1,
  133. bias: bool = True,
  134. conv_mode: str = "cross_correlation",
  135. compute_mode: str = "default",
  136. dtype=None,
  137. **kwargs
  138. ):
  139. super().__init__(
  140. in_channels=in_channels,
  141. out_channels=out_channels,
  142. kernel_size=kernel_size,
  143. stride=stride,
  144. padding=padding,
  145. output_padding=output_padding,
  146. dilation=dilation,
  147. groups=groups,
  148. bias=bias,
  149. conv_mode=conv_mode,
  150. compute_mode=compute_mode,
  151. )
  152. self.output_dtype = dtype
  153. @classmethod
  154. def from_qat_module(cls, qat_module: QAT.ConvTranspose2d):
  155. r"""
  156. return a :class:`~.QuantizedModule` instance converted from a
  157. :class:`~.QATModule` instance.
  158. """
  159. output_dtype = qat_module.get_activation_dtype()
  160. qconv_transpose2d = cls(
  161. qat_module.in_channels,
  162. qat_module.out_channels,
  163. qat_module.kernel_size,
  164. qat_module.stride,
  165. qat_module.padding,
  166. qat_module.output_padding,
  167. qat_module.dilation,
  168. qat_module.groups,
  169. qat_module.bias is not None,
  170. qat_module.conv_mode,
  171. qat_module.compute_mode,
  172. dtype=output_dtype,
  173. name=qat_module.name,
  174. )
  175. weight = qat_module.weight.astype(qat_module.get_weight_dtype())
  176. qconv_transpose2d.weight = Parameter(
  177. weight.numpy(), name=qat_module.weight.name
  178. )
  179. qconv_transpose2d.bias = (
  180. Parameter(qat_module.bias.numpy(), name=qat_module.bias.name)
  181. if qat_module.bias is not None
  182. else None
  183. )
  184. return qconv_transpose2d
  185. def calc_conv_transpose2d_quantized(self, inp, nonlinear_mode):
  186. assert nonlinear_mode == "identity", "nonlinear_mode shoule be 'identity'"
  187. if self.bias is not None:
  188. inp_scale = dtype.get_scale(inp.dtype)
  189. w_scale = dtype.get_scale(self.weight.dtype)
  190. bias_scale = inp_scale * w_scale
  191. return conv_transpose2d(
  192. inp=inp,
  193. weight=self.weight,
  194. bias=self.bias.astype(dtype.qint32(bias_scale))
  195. if self.bias is not None
  196. else None,
  197. dtype=self.output_dtype,
  198. stride=self.stride,
  199. padding=self.padding,
  200. output_padding=self.output_padding,
  201. dilation=self.dilation,
  202. groups=self.groups,
  203. conv_mode=self.conv_mode,
  204. compute_mode=self.compute_mode,
  205. )
  206. def forward(self, inp):
  207. return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="identity")
class ConvTransposeRelu2d(ConvTranspose2d):
    r"""Quantized version of :class:`~.qat.ConvTransposeRelu2d`."""

    def forward(self, inp):
        # NOTE(review): calc_conv_transpose2d_quantized asserts
        # nonlinear_mode == "identity", so passing "relu" here always raises
        # AssertionError — confirm whether the quantized conv_transpose2d
        # kernel supports a fused relu before relying on this class.
        return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="relu")