
conv_bn_relu.py 7.8 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union

from ..core import ones, zeros
from ..functional import add_update, flatten, relu, sqrt, sum, zero_grad
from .batchnorm import BatchNorm2d
from .conv import Conv2d
from .module import QATModule


class _ConvBn2d(QATModule):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
    ):
        super().__init__()
        self.conv = Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            conv_mode,
            compute_mode,
        )
        self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)

    def get_batch_mean_var(self, inp):
        # reduce over the (N, H, W) axes one at a time, keeping dims so the
        # result broadcasts against the (N, C, H, W) input
        def _sum_channel(inp, axis=0, keepdims=True):
            if isinstance(axis, int):
                out = sum(inp, axis=axis, keepdims=keepdims)
            elif isinstance(axis, tuple):
                for idx, elem in enumerate(axis):
                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
            return out

        sum1 = _sum_channel(inp, (0, 2, 3))
        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
        reduce_size = inp.shapeof().prod() / inp.shapeof(1)
        batch_mean = sum1 / reduce_size
        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
        return batch_mean, batch_var
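    # Note: batch_var uses the one-pass identity Var(x) = E[x ** 2] - E[x] ** 2,
    # so the mean and the variance come out of the same two channel-wise sums
    # instead of requiring a second pass over the activations.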
    def fold_weight_bias(self, bn_mean, bn_var):
        # compute the conv parameters with batch norm folded in:
        #   bn_istd = 1 / bn_std
        #   w_fold = gamma / bn_std * W
        #   b_fold = gamma * (b - bn_mean) / bn_std + beta
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)
        if bn_mean is None:
            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
        if bn_var is None:
            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")
        conv_bias = self.conv.bias
        if conv_bias is None:
            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        # w_fold = gamma / bn_std * W
        scale_factor = gamma * bn_istd
        if self.conv.groups == 1:
            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
        else:
            # grouped conv weights carry a leading groups dimension
            w_fold = self.conv.weight * scale_factor.reshape(
                self.conv.groups, -1, 1, 1, 1
            )
        # b_fold = gamma * (b - bn_mean) / bn_std + beta
        b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
        return w_fold, b_fold
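    # Folding identity: applying batch norm to a conv output,
    #   y = gamma * (conv(x) + b - bn_mean) / bn_std + beta,
    # distributes into a single convolution,
    #   y = conv(x; gamma / bn_std * W) + gamma * (b - bn_mean) / bn_std + beta,
    # which is exactly the w_fold and b_fold computed above.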
    def update_running_mean_and_running_var(
        self, bn_mean, bn_var, num_elements_per_channel
    ):
        # update running mean and running var: stop gradients, and apply the
        # n / (n - 1) correction to get the unbiased batch variance
        bn_mean = zero_grad(bn_mean)
        bn_var = (
            zero_grad(bn_var)
            * num_elements_per_channel
            / (num_elements_per_channel - 1)
        )
        exponential_average_factor = 1 - self.bn.momentum
        add_update(
            self.bn.running_mean,
            delta=bn_mean,
            alpha=1 - exponential_average_factor,
            beta=exponential_average_factor,
        )
        add_update(
            self.bn.running_var,
            delta=bn_var,
            alpha=1 - exponential_average_factor,
            beta=exponential_average_factor,
        )
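    # Assuming add_update performs the in-place update
    #   dest = alpha * dest + beta * delta,
    # this is the usual exponential moving average of the batch statistics:
    #   running = momentum * running + (1 - momentum) * batch_stat.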
    def calc_conv_bn_qat(self, inp, approx=True):
        if self.training and not approx:
            # exact mode: run an unfolded conv to get this batch's statistics
            conv = self.conv(inp)
            bn_mean, bn_var = self.get_batch_mean_var(conv)
            num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
            self.update_running_mean_and_running_var(
                bn_mean, bn_var, num_elements_per_channel
            )
        else:
            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

        # get gamma and beta in BatchNorm
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)
        # conv bias
        conv_bias = self.conv.bias
        if conv_bias is None:
            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        # w_fold = gamma / bn_std * W
        scale_factor = gamma * bn_istd
        if self.conv.groups == 1:
            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
        else:
            w_fold = self.conv.weight * scale_factor.reshape(
                self.conv.groups, -1, 1, 1, 1
            )
        b_fold = None
        if not (self.training and approx):
            # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
            b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

        w_qat = self.apply_fakequant_with_observer(
            w_fold, self.weight_fake_quant, self.weight_observer
        )
        conv = self.conv.calc_conv(inp, w_qat, b_fold)
        if not (self.training and approx):
            return conv

        # approx mode: rescale the folded conv back to the original conv
        # output, then apply a real batch norm with this batch's statistics
        orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
        if self.conv.bias is not None:
            orig_conv = orig_conv + self.conv.bias
        bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
        num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
        self.update_running_mean_and_running_var(
            bn_mean, bn_var, num_elements_per_channel
        )
        return conv
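    # Summary of the three paths above:
    # - training, approx=True (default): fold the running statistics into the
    #   fake-quantized weight, then undo the scaling and apply an exact batch
    #   norm using the current batch's statistics.
    # - training, approx=False: an extra unfolded conv computes the batch
    #   statistics, which are folded directly into w_fold / b_fold.
    # - inference: the running statistics are folded in and the folded conv
    #   output is returned as-is.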

class ConvBn2d(_ConvBn2d):
    r"""
    A fused :class:`~.QATModule` combining :class:`~.Conv2d` and
    :class:`~.BatchNorm2d`, supporting both ``qat`` mode and ``normal`` mode.
    """

    def forward_qat(self, inp):
        return self.apply_fakequant_with_observer(
            self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer
        )

    def forward(self, inp):
        return self.bn(self.conv(inp))


class ConvBnRelu2d(_ConvBn2d):
    r"""
    A fused :class:`~.QATModule` combining :class:`~.Conv2d`,
    :class:`~.BatchNorm2d` and ReLU, supporting both ``qat`` mode and
    ``normal`` mode.
    """

    def forward_qat(self, inp):
        return self.apply_fakequant_with_observer(
            relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer
        )

    def forward(self, inp):
        return relu(self.bn(self.conv(inp)))
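
A minimal usage sketch for the fused module in ``normal`` mode. The ``megengine.module`` export path and the tensor helpers below are assumptions based on MegEngine's v0.x-era public API, not taken from this file:

import numpy as np
import megengine as mge
from megengine.module import ConvBnRelu2d  # assumed export path

m = ConvBnRelu2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
x = mge.tensor(np.random.randn(2, 3, 32, 32).astype("float32"))
y = m(x)        # normal mode: conv -> batch norm -> relu
print(y.shape)  # (2, 16, 32, 32)

In ``qat`` mode, forward_qat is used instead, which fake-quantizes the folded weights and activations so training matches the numerics of the eventual quantized deployment.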

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit MegStudio.
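A quick sanity check that the bundled CUDA environment can see a GPU; a minimal sketch, assuming a standard MegEngine install:

import megengine as mge

# True when a CUDA device and driver are visible to MegEngine
print(mge.is_cuda_available())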