
conv_bn.py 7.6 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...core import ones, zeros
from ...functional import add_update, relu, sqrt, sum, zero_grad
from ...quantization.utils import fake_quant_bias
from .. import conv_bn as Float
from .module import QATModule


class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
    def get_batch_mean_var(self, inp):
        def _sum_channel(inp, axis=0, keepdims=True):
            if isinstance(axis, int):
                out = sum(inp, axis=axis, keepdims=keepdims)
            elif isinstance(axis, tuple):
                for idx, elem in enumerate(axis):
                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
            return out

        sum1 = _sum_channel(inp, (0, 2, 3))
        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
        reduce_size = inp.shapeof().prod() / inp.shapeof(1)
        batch_mean = sum1 / reduce_size
        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
        return batch_mean, batch_var

    def fold_weight_bias(self, bn_mean, bn_var):
        # get fold bn conv param
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        # b_fold = gamma * (b - bn_mean) / bn_std + beta
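        # (Derivation note, added for clarity: substituting the conv output
        # y = W * x + b into the BatchNorm transform
        # bn(y) = gamma * (y - bn_mean) / bn_std + beta, with
        # bn_std = sqrt(bn_var + eps), yields an equivalent single convolution
        # whose parameters are exactly w_fold and b_fold above.)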
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)

        if bn_mean is None:
            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
        if bn_var is None:
            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")

        conv_bias = self.conv.bias
        if conv_bias is None:
            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        scale_factor = gamma * bn_istd
        if self.conv.groups == 1:
            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
        else:
            w_fold = self.conv.weight * scale_factor.reshape(
                self.conv.groups, -1, 1, 1, 1
            )
        w_fold = self.apply_quant_weight(w_fold)
        # b_fold = gamma * (b - bn_mean) / bn_std + beta
        b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
        return w_fold, b_fold

    def update_running_mean_and_running_var(
        self, bn_mean, bn_var, num_elements_per_channel
    ):
        # update running mean and running var. no grad, use unbiased bn var
        bn_mean = zero_grad(bn_mean)
        bn_var = (
            zero_grad(bn_var)
            * num_elements_per_channel
            / (num_elements_per_channel - 1)
        )
        exponential_average_factor = 1 - self.bn.momentum
        add_update(
            self.bn.running_mean,
            delta=bn_mean,
            alpha=1 - exponential_average_factor,
            beta=exponential_average_factor,
        )
        add_update(
            self.bn.running_var,
            delta=bn_var,
            alpha=1 - exponential_average_factor,
            beta=exponential_average_factor,
        )

    def calc_conv_bn_qat(self, inp, approx=True):
        if self.training and not approx:
            conv = self.conv(inp)
            bn_mean, bn_var = self.get_batch_mean_var(conv)
            num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
            self.update_running_mean_and_running_var(
                bn_mean, bn_var, num_elements_per_channel
            )
        else:
            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

        # get gamma and beta in BatchNorm
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)
        # conv_bias
        conv_bias = self.conv.bias
        if conv_bias is None:
            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        scale_factor = gamma * bn_istd
        if self.conv.groups == 1:
            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
        else:
            w_fold = self.conv.weight * scale_factor.reshape(
                self.conv.groups, -1, 1, 1, 1
            )
        b_fold = None
        if not (self.training and approx):
            # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
            b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

        w_qat = self.apply_quant_weight(w_fold)
        b_qat = fake_quant_bias(b_fold, inp, w_qat)
        conv = self.conv.calc_conv(inp, w_qat, b_qat)
        if not (self.training and approx):
            return conv

        # rescale conv to get original conv output
        orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
        if self.conv.bias is not None:
            orig_conv = orig_conv + self.conv.bias
        # calculate batch norm
        bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
        num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
        self.update_running_mean_and_running_var(
            bn_mean, bn_var, num_elements_per_channel
        )
        return conv

    @classmethod
    def from_float_module(cls, float_module: Float._ConvBnActivation2d):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
        qat_module = cls(
            float_module.conv.in_channels,
            float_module.conv.out_channels,
            float_module.conv.kernel_size,
            float_module.conv.stride,
            float_module.conv.padding,
            float_module.conv.dilation,
            float_module.conv.groups,
            float_module.conv.bias is not None,
            float_module.conv.conv_mode.name,
            float_module.conv.compute_mode.name,
        )
        qat_module.conv.weight = float_module.conv.weight
        qat_module.conv.bias = float_module.conv.bias
        qat_module.bn = float_module.bn
        return qat_module


class ConvBn2d(_ConvBnActivation2d):
    r"""
    A fused :class:`~.QATModule` including Conv2d, BatchNorm2d with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(self.calc_conv_bn_qat(inp))


class ConvBnRelu2d(_ConvBnActivation2d):
    r"""
    A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(relu(self.calc_conv_bn_qat(inp)))
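For reference, a minimal usage sketch of the QAT module defined above. The import paths and the float ConvBn2d constructor signature are assumptions (they vary across MegEngine releases), and in a real QAT workflow you would normally attach a quantization config (Observer / FakeQuantize) to the converted module before training rather than running it bare.

# Hypothetical usage sketch; module paths and signatures are assumed,
# not taken from a specific MegEngine release.
import numpy as np
import megengine as mge
import megengine.module as M                # float modules (assumed path)
from megengine.module.qat import ConvBn2d   # the QAT module defined above (assumed path)

float_mod = M.ConvBn2d(8, 16, kernel_size=3, padding=1)  # float Conv2d + BatchNorm2d
qat_mod = ConvBn2d.from_float_module(float_mod)          # copies conv weight/bias, shares BN

inp = mge.tensor(np.random.randn(1, 8, 32, 32).astype("float32"))
out = qat_mod(inp)  # conv with BN folded in, followed by the (fake-)quantized activation

The conversion mirrors from_float_module above: the QAT module starts from the float network's conv weights and BatchNorm statistics, so training can continue from the same point with fake quantization inserted.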
