
conv_bn.py 6.7 kB

from ...functional import ones, relu, sqrt, sum, zeros
from .. import conv_bn as Float
from .module import QATModule


class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
    def get_batch_mean_var(self, inp):
        def _sum_channel(inp, axis=0, keepdims=True):
            if isinstance(axis, int):
                out = sum(inp, axis=axis, keepdims=keepdims)
            elif isinstance(axis, tuple):
                for idx, elem in enumerate(axis):
                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
            return out

        sum1 = _sum_channel(inp, (0, 2, 3))
        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
        reduce_size = inp.size / inp.shape[1]
        batch_mean = sum1 / reduce_size
        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
        return batch_mean, batch_var
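
    # Note: batch_var above is the biased per-channel variance
    #   E[x^2] - E[x]^2 = sum2 / reduce_size - (sum1 / reduce_size) ** 2
    # computed over the N, H and W axes; update_running_mean_and_running_var
    # converts it to the unbiased estimate before updating the running statistics.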

    def fold_weight_bias(self, bn_mean, bn_var):
        # get fold bn conv param
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        # b_fold = gamma * (b - bn_mean) / bn_std + beta
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)

        if bn_mean is None:
            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
        if bn_var is None:
            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")

        conv_bias = self.conv.bias
        if conv_bias is None:
            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        scale_factor = gamma * bn_istd
        if self.conv.groups == 1:
            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
        else:
            w_fold = self.conv.weight * scale_factor.reshape(
                self.conv.groups, -1, 1, 1, 1
            )
        w_fold = self.apply_quant_weight(w_fold)
        # b_fold = gamma * (b - bn_mean) / bn_std + beta
        b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
        return w_fold, b_fold
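
    # Why folding works: for a per-channel affine BN,
    #   BN(conv(x, W) + b) = gamma * (conv(x, W) + b - bn_mean) / bn_std + beta
    #                      = conv(x, (gamma / bn_std) * W) + gamma * (b - bn_mean) / bn_std + beta
    # so scaling each output channel's filters by gamma / bn_std (w_fold) and moving
    # the remaining affine term into the bias (b_fold) reproduces conv + BN with a
    # single convolution.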

    def update_running_mean_and_running_var(
        self, bn_mean, bn_var, num_elements_per_channel
    ):
        # update running mean and running var. no grad, use unbiased bn var
        bn_mean = bn_mean.detach()
        bn_var = (
            bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1)
        )
        exponential_average_factor = 1 - self.bn.momentum
        self.bn.running_mean *= self.bn.momentum
        self.bn.running_mean += exponential_average_factor * bn_mean
        self.bn.running_var *= self.bn.momentum
        self.bn.running_var += exponential_average_factor * bn_var
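
    # The two in-place updates above implement the usual exponential moving average
    #   running_stat = momentum * running_stat + (1 - momentum) * batch_stat,
    # with the batch variance first rescaled by n / (n - 1) to make it unbiased.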

    def calc_conv_bn_qat(self, inp, approx=True):
        if self.training and not approx:
            conv = self.conv(inp)
            bn_mean, bn_var = self.get_batch_mean_var(conv)
            num_elements_per_channel = conv.size / conv.shape[1]
            self.update_running_mean_and_running_var(
                bn_mean, bn_var, num_elements_per_channel
            )
        else:
            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

        # get gamma and beta in BatchNorm
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)
        # conv_bias
        conv_bias = self.conv.bias
        if conv_bias is None:
            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        scale_factor = gamma * bn_istd
        if self.conv.groups == 1:
            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
        else:
            w_fold = self.conv.weight * scale_factor.reshape(
                self.conv.groups, -1, 1, 1, 1
            )
        b_fold = None
        if not (self.training and approx):
            # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
            b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

        w_qat = self.apply_quant_weight(w_fold)
        b_qat = self.apply_quant_bias(b_fold, inp, w_qat)
        conv = self.conv.calc_conv(inp, w_qat, b_qat)
        if not (self.training and approx):
            return conv

        # rescale conv to get original conv output
        orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
        if self.conv.bias is not None:
            orig_conv = orig_conv + self.conv.bias
        # calculate batch norm
        conv = self.bn(orig_conv)
        return conv
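
    # Approximate training path (training and approx=True): the convolution runs with
    # the fake-quantized folded weight (scaled by gamma / running_std) and no folded
    # bias; the result is divided by scale_factor and the original conv bias re-added
    # to recover a regular conv output, and a real BatchNorm layer is then applied so
    # batch statistics keep being collected. In eval mode, or when approx=False, the
    # folded bias is used directly and the BN layer is skipped.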

    @classmethod
    def from_float_module(cls, float_module: Float._ConvBnActivation2d):
        qat_module = cls(
            float_module.conv.in_channels,
            float_module.conv.out_channels,
            float_module.conv.kernel_size,
            float_module.conv.stride,
            float_module.conv.padding,
            float_module.conv.dilation,
            float_module.conv.groups,
            float_module.conv.bias is not None,
            float_module.conv.conv_mode,
            float_module.conv.compute_mode,
            padding_mode=float_module.conv.padding_mode,
            name=float_module.name,
        )
        qat_module.conv.weight = float_module.conv.weight
        qat_module.conv.bias = float_module.conv.bias
        qat_module.bn = float_module.bn
        return qat_module


class ConvBn2d(_ConvBnActivation2d):
    r"""A fused :class:`~.QATModule` including :class:`~.module.Conv2d` and
    :class:`~.module.BatchNorm2d` with QAT support.
    Can be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(self.calc_conv_bn_qat(inp))


class ConvBnRelu2d(_ConvBnActivation2d):
    r"""A fused :class:`~.QATModule` including :class:`~.module.Conv2d`,
    :class:`~.module.BatchNorm2d` and :func:`~.relu` with QAT support.
    Can be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(relu(self.calc_conv_bn_qat(inp)))
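
The folding algebra used by fold_weight_bias and calc_conv_bn_qat can be checked in isolation. Below is a minimal NumPy sketch, not part of conv_bn.py, that treats the convolution as a per-channel linear map (a 1x1 convolution written as a matrix product); every name in it is illustrative.

import numpy as np

rng = np.random.default_rng(0)
N, C_in, C_out, eps = 16, 4, 8, 1e-5

x = rng.standard_normal((N, C_in)).astype("float32")
W = rng.standard_normal((C_out, C_in)).astype("float32")  # one "filter" per output channel
b = rng.standard_normal(C_out).astype("float32")
gamma = rng.standard_normal(C_out).astype("float32")      # BN weight
beta = rng.standard_normal(C_out).astype("float32")       # BN bias

# Reference: convolution (here a 1x1 conv, i.e. a matmul) followed by batch norm.
conv = x @ W.T + b
bn_mean, bn_var = conv.mean(axis=0), conv.var(axis=0)
bn_out = gamma * (conv - bn_mean) / np.sqrt(bn_var + eps) + beta

# Folded form, mirroring fold_weight_bias.
bn_istd = 1.0 / np.sqrt(bn_var + eps)
scale_factor = gamma * bn_istd                              # gamma / bn_std
w_fold = W * scale_factor[:, None]                          # w_fold = gamma / bn_std * W
b_fold = beta + gamma * (b - bn_mean) * bn_istd             # b_fold = gamma * (b - bn_mean) / bn_std + beta

# A single convolution with the folded parameters reproduces conv + BN.
assert np.allclose(bn_out, x @ w_fold.T + b_fold, atol=1e-4)

The grouped-convolution branch in the module performs the same per-output-channel scaling; it only reshapes scale_factor to (groups, -1, 1, 1, 1) to match the grouped weight layout.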