bn_fusion.py

from copy import deepcopy

from ..functional import ones, sqrt, zeros
from ..module import (
    BatchNorm2d,
    Conv2d,
    ConvBn2d,
    ConvBnRelu2d,
    ConvRelu2d,
    ConvTranspose2d,
    ConvTransposeBn2d,
    ConvTransposeBnRelu2d,
    ConvTransposeRelu2d,
    ReLU,
)
from ..tensor import Parameter

_MAP_TO_FUSED_MODULE = {
    (Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d,
    (Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d,
    (ConvTranspose2d, BatchNorm2d, ReLU, False): ConvTransposeRelu2d,
    (ConvTranspose2d, BatchNorm2d, ReLU, True): ConvTransposeBnRelu2d,
    (Conv2d, BatchNorm2d, False): Conv2d,
    (Conv2d, BatchNorm2d, True): ConvBn2d,
    (Conv2d, ReLU): ConvRelu2d,
    (ConvTranspose2d, BatchNorm2d, False): ConvTranspose2d,
    (ConvTranspose2d, BatchNorm2d, True): ConvTransposeBn2d,
    (ConvTranspose2d, ReLU): ConvTransposeRelu2d,
}
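
# The dictionary key is the tuple of module types found in the pattern,
# extended with the conv's ``training`` flag whenever BatchNorm2d is involved:
# in train mode the BN statistics must keep updating, so the fused type retains
# a BN submodule (ConvBn2d / ConvBnRelu2d); in eval mode BN is folded into the
# conv weights and the fused type drops it. For example:
#
#     _MAP_TO_FUSED_MODULE[(Conv2d, BatchNorm2d, ReLU, True)]   # ConvBnRelu2d
#     _MAP_TO_FUSED_MODULE[(Conv2d, BatchNorm2d, ReLU, False)]  # ConvRelu2d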


def fold_weight_bias(
    weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False
):
    # The per-channel BN scale must be applied along the output-channel axis of
    # the kernel: axis 0 for Conv2d (OIHW layout), axis 1 for ConvTranspose2d
    # (IOHW layout).
    shape = (-1, 1, 1, 1)
    if transpose:
        shape = (1, -1, 1, 1)

    kernel_shape = weight.shape
    # A 5-D kernel means a grouped convolution, laid out as (groups, O, I, H, W).
    if len(kernel_shape) == 5:
        groups, num_features = kernel_shape[0], kernel_shape[1]
    else:
        groups, num_features = 1, kernel_shape[0]
    out_channels = groups * num_features

    # Missing affine parameters or statistics default to the identity transform.
    if gamma is None:
        gamma = ones((out_channels,), dtype="float32")
    gamma = gamma.reshape(1, -1, 1, 1)
    if beta is None:
        beta = zeros((out_channels,), dtype="float32")
    beta = beta.reshape(1, -1, 1, 1)
    if bn_mean is None:
        bn_mean = zeros((1, out_channels, 1, 1), dtype="float32")
    if bn_var is None:
        bn_var = ones((1, out_channels, 1, 1), dtype="float32")
    if bias is None:
        bias = zeros((1, out_channels, 1, 1), dtype="float32")

    bn_istd = 1.0 / sqrt(bn_var + eps)
    scale_factor = gamma * bn_istd
    if groups == 1:
        w_fold = weight * scale_factor.reshape(*shape)
    else:
        w_fold = weight * scale_factor.reshape(groups, *shape)
    b_fold = beta + gamma * (bias - bn_mean) * bn_istd
    return w_fold, b_fold
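
# Why folding works (a sketch of the algebra the code above implements): in
# eval mode BN applies a fixed affine transform per output channel c,
#
#     y_c = gamma_c * (conv(x)_c - mean_c) / sqrt(var_c + eps) + beta_c
#
# and since the convolution is linear in its weight, the scale can be pushed
# into the kernel:
#
#     w_fold_c = weight_c * gamma_c / sqrt(var_c + eps)
#     b_fold_c = beta_c + gamma_c * (bias_c - mean_c) / sqrt(var_c + eps)
#
# so a single convolution with (w_fold, b_fold) reproduces conv followed by BN.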


def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
    module_key = tuple([type(m) for m in [conv, bn, relu] if m])
    if bn:
        assert (
            conv.training == bn.training
        ), "Conv and BN both must be in the same mode (train or eval)."
        assert (
            bn.num_features == conv.out_channels
        ), "Output channels of Conv2d must match num_features of BatchNorm2d."
        module_key = module_key + (conv.training,)
    module = _MAP_TO_FUSED_MODULE[module_key](
        in_channels=conv.in_channels,
        out_channels=conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=conv.bias is not None,
        conv_mode=conv.conv_mode,
        compute_mode=conv.compute_mode,
        name=conv.name,
    )
    # In train mode with BN, the fused module keeps a conv submodule; otherwise
    # the fused module itself plays the role of the conv.
    new_conv = module if bn is None or not conv.training else module.conv

    weight, bias = conv.weight, conv.bias
    # In eval mode the BN statistics are constants, so they can be folded into
    # the conv parameters.
    if not conv.training and bn is not None:
        weight, bias = fold_weight_bias(
            weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps,
        )
    new_conv.weight = Parameter(weight)
    if bias is not None:
        new_conv.bias = Parameter(bias)
    # In train mode BN must keep tracking statistics, so carry over a copy.
    if bn is not None and conv.training:
        module.bn = deepcopy(bn)
    new_conv.training = conv.training
    return module
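
# A minimal usage sketch (assuming the usual (in_channels, out_channels,
# kernel_size) constructor order of the sibling modules; fresh modules start
# in train mode):
#
#     conv, bn = Conv2d(3, 8, 3), BatchNorm2d(8)
#     fused_train = fuse_conv_bn_relu_module(conv, bn, None)  # -> ConvBn2d
#     conv.eval(); bn.eval()
#     fused_eval = fuse_conv_bn_relu_module(conv, bn, None)   # -> Conv2d, BN folded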


def fuse_conv_transpose2d_bn_relu_module(
    conv_transpose2d: ConvTranspose2d, bn: BatchNorm2d, relu: ReLU
):
    module_key = tuple([type(m) for m in [conv_transpose2d, bn, relu] if m])
    if bn:
        assert (
            conv_transpose2d.training == bn.training
        ), "ConvTranspose2d and BN both must be in the same mode (train or eval)."
        assert (
            bn.num_features == conv_transpose2d.out_channels
        ), "Output channels of ConvTranspose2d must match num_features of BatchNorm2d."
        module_key = module_key + (conv_transpose2d.training,)
    module = _MAP_TO_FUSED_MODULE[module_key](
        in_channels=conv_transpose2d.in_channels,
        out_channels=conv_transpose2d.out_channels,
        kernel_size=conv_transpose2d.kernel_size,
        stride=conv_transpose2d.stride,
        padding=conv_transpose2d.padding,
        output_padding=conv_transpose2d.output_padding,
        dilation=conv_transpose2d.dilation,
        groups=conv_transpose2d.groups,
        bias=conv_transpose2d.bias is not None,
        conv_mode=conv_transpose2d.conv_mode,
        compute_mode=conv_transpose2d.compute_mode,
        name=conv_transpose2d.name,
    )
    new_conv_transpose2d = (
        module
        if bn is None or not conv_transpose2d.training
        else module.conv_transpose2d
    )

    weight, bias = conv_transpose2d.weight, conv_transpose2d.bias
    if not conv_transpose2d.training and bn is not None:
        # ConvTranspose2d puts output channels on axis 1 of the kernel, so the
        # fold must use the transposed scale shape.
        weight, bias = fold_weight_bias(
            weight,
            bias,
            bn.weight,
            bn.bias,
            bn.running_mean,
            bn.running_var,
            bn.eps,
            transpose=True,
        )
    new_conv_transpose2d.weight = Parameter(weight)
    if bias is not None:
        new_conv_transpose2d.bias = Parameter(bias)
    if bn is not None and conv_transpose2d.training:
        module.bn = deepcopy(bn)
    new_conv_transpose2d.training = conv_transpose2d.training
    return module
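
# An end-to-end numerical check (a sketch, assuming this file lives at
# megengine/utils/bn_fusion.py inside an installed megengine package; run it
# from a separate script, since the relative imports above prevent executing
# this module directly):
#
#     import numpy as np
#     import megengine as mge
#     from megengine.module import BatchNorm2d, Conv2d, ReLU
#     from megengine.utils.bn_fusion import fuse_conv_bn_relu_module
#
#     conv, bn, relu = Conv2d(3, 8, 3), BatchNorm2d(8), ReLU()
#     conv.eval(); bn.eval()
#     fused = fuse_conv_bn_relu_module(conv, bn, relu)  # -> ConvRelu2d
#     x = mge.tensor(np.random.randn(2, 3, 16, 16).astype("float32"))
#     np.testing.assert_allclose(
#         fused(x).numpy(), relu(bn(conv(x))).numpy(), atol=1e-5
#     )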