
conv_bn_relu.py 5.4 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union

from ..core import ones, zeros
from ..functional import flatten, relu, sqrt, sum
from .batchnorm import BatchNorm2d
from .conv import Conv2d
from .module import QATModule


class _ConvBn2d(QATModule):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze_bn=False,
    ):
        super().__init__()
        self.conv = Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            conv_mode,
            compute_mode,
        )
        self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
        self.freeze_bn = freeze_bn
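
    # The two helpers below toggle the ``freeze_bn`` flag. The intent, a
    # common QAT fine-tuning pattern, is to let callers freeze or resume
    # the update of BN statistics while the rest of the network trains.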
    def update_bn_stats(self):
        self.freeze_bn = False
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        return self

    def get_bn_gamma_beta(self):
        if self.bn.weight is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        else:
            gamma = self.bn.weight
        if self.bn.bias is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        else:
            beta = self.bn.bias
        return gamma, beta
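
    # get_batch_mean_var reduces over the N, H and W axes to produce
    # per-channel statistics: with reduce_size = N * H * W, the mean is
    # sum(x) / reduce_size and the unbiased (Bessel-corrected) variance is
    # (sum(x ** 2) - sum(x) ** 2 / reduce_size) / (reduce_size - 1).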
    def get_batch_mean_var(self, inp):
        def _sum_channel(inp, axis=0, keepdims=True):
            if isinstance(axis, int):
                out = sum(inp, axis=axis, keepdims=keepdims)
            elif isinstance(axis, tuple):
                for idx, elem in enumerate(axis):
                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
            return out

        sum1 = _sum_channel(inp, (0, 2, 3))
        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
        reduce_size = inp.shapeof().prod() / inp.shapeof(1)
        batch_mean = sum1 / reduce_size
        batch_var = (sum2 - sum1 ** 2 / reduce_size) / (reduce_size - 1)
        return batch_mean, batch_var
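
    # The folding below is standard BN-folding algebra: per output channel,
    #   bn(conv(x)) = gamma * (W * x + b - bn_mean) / bn_std + beta
    #               = (gamma / bn_std * W) * x
    #                 + gamma * (b - bn_mean) / bn_std + beta
    # so one convolution with w_fold and b_fold reproduces conv followed
    # by bn exactly.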
    def fold_weight_bias(self, bn_mean, bn_var):
        # get fold bn conv param
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        # b_fold = gamma * (b - bn_mean) / bn_std + beta
        gamma, beta = self.get_bn_gamma_beta()
        b = self.conv.bias
        if b is None:
            b = zeros(self.conv._infer_bias_shape(), dtype="float32")
        if bn_mean is None:
            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
        if bn_var is None:
            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")
        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        if self.conv.groups == 1:
            w_fold = (
                self.conv.weight
                * gamma.reshape(-1, 1, 1, 1)
                * bn_istd.reshape(-1, 1, 1, 1)
            )
        else:
            w_fold = (
                self.conv.weight
                * gamma.reshape(self.conv.groups, -1, 1, 1, 1)
                * bn_istd.reshape(self.conv.groups, -1, 1, 1, 1)
            )
        b_fold = flatten(beta) + (
            flatten(gamma) * (flatten(b) - flatten(bn_mean)) * flatten(bn_istd)
        )
        b_fold = b_fold.reshape(self.conv._infer_bias_shape())
        return w_fold, b_fold
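
    # calc_conv_bn_qat runs the real convolution once so that self.bn(conv)
    # refreshes the running statistics, picks batch statistics in training
    # mode (running statistics otherwise), folds BN into the conv
    # parameters, fake-quantizes the folded weight, and finally re-runs the
    # convolution with the folded, quantized parameters.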
    def calc_conv_bn_qat(self, inp):
        # TODO: use pytorch method as
        conv = self.conv(inp)
        self.bn(conv)
        if self.training:
            bn_mean, bn_var = self.get_batch_mean_var(conv)
        else:
            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
        w_fold, b_fold = self.fold_weight_bias(bn_mean, bn_var)
        w_qat = self.apply_fakequant_with_observer(
            w_fold, self.weight_fake_quant, self.weight_observer
        )
        return self.conv.calc_conv(inp, w_qat, b_fold)


class ConvBn2d(_ConvBn2d):
    r"""
    A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, supporting ``qat``
    mode and ``normal`` mode.
    """
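
    # In ``qat`` mode the module runs forward_qat (folded conv plus fake
    # quantization of the activation); in ``normal`` mode it runs the
    # plain forward below.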
    def forward_qat(self, inp):
        return self.apply_fakequant_with_observer(
            self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer
        )

    def forward(self, inp):
        return self.bn(self.conv(inp))


class ConvBnRelu2d(_ConvBn2d):
    r"""
    A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu, supporting
    ``qat`` mode and ``normal`` mode.
    """

    def forward_qat(self, inp):
        return self.apply_fakequant_with_observer(
            relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer
        )

    def forward(self, inp):
        return relu(self.bn(self.conv(inp)))
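
To see that the folding in fold_weight_bias is exact, here is a minimal
self-check in plain NumPy. It is a sketch, not part of this file: it uses a
1x1 convolution (which reduces to a matrix multiply) and hypothetical
variable names, but the algebra is the same one the module applies per
output channel.

import numpy as np

rng = np.random.default_rng(0)
c_in, c_out, eps = 4, 8, 1e-5
x = rng.standard_normal((2, c_in))            # 2 samples, one "pixel" each
w = rng.standard_normal((c_out, c_in))        # 1x1 conv weight == matmul
b = rng.standard_normal(c_out)
gamma, beta = rng.standard_normal(c_out), rng.standard_normal(c_out)
mean, var = rng.standard_normal(c_out), rng.random(c_out) + 0.5

# Reference: convolution followed by batch norm.
y_ref = gamma * ((x @ w.T + b) - mean) / np.sqrt(var + eps) + beta

# Folded parameters, mirroring fold_weight_bias:
#   w_fold = gamma / bn_std * W,  b_fold = gamma * (b - bn_mean) / bn_std + beta
istd = 1.0 / np.sqrt(var + eps)
w_fold = w * (gamma * istd)[:, None]
b_fold = beta + gamma * (b - mean) * istd

# A single convolution with the folded parameters matches exactly.
assert np.allclose(y_ref, x @ w_fold.T + b_fold)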
