You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

bn_fusion.py 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. from copy import deepcopy
  2. from ..functional import ones, sqrt, zeros
  3. from ..module import (
  4. BatchNorm2d,
  5. Conv2d,
  6. ConvBn2d,
  7. ConvBnRelu2d,
  8. ConvRelu2d,
  9. ConvTranspose2d,
  10. ConvTransposeBn2d,
  11. ConvTransposeBnRelu2d,
  12. ConvTransposeRelu2d,
  13. ReLU,
  14. )
  15. from ..tensor import Parameter
# Dispatch table from a pattern of modules to the fused module class.
# Keys are tuples of module types in (conv, bn, relu) order with absent
# (None) members dropped; when a BatchNorm2d is present, the key gains a
# trailing bool equal to ``conv.training`` (see fuse_conv_bn_relu_module):
# True keeps an explicit BN inside a ConvBn*/ConvTransposeBn* module,
# False selects a plain conv module into which BN is folded.
_MAP_TO_FUSED_MODULE = {
    (Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d,
    (Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d,
    (ConvTranspose2d, BatchNorm2d, ReLU, False): ConvTransposeRelu2d,
    (ConvTranspose2d, BatchNorm2d, ReLU, True): ConvTransposeBnRelu2d,
    (Conv2d, BatchNorm2d, False): Conv2d,
    (Conv2d, BatchNorm2d, True): ConvBn2d,
    (Conv2d, ReLU): ConvRelu2d,
    (ConvTranspose2d, BatchNorm2d, False): ConvTranspose2d,
    (ConvTranspose2d, BatchNorm2d, True): ConvTransposeBn2d,
    (ConvTranspose2d, ReLU): ConvTransposeRelu2d,
}
  28. def fold_weight_bias(
  29. weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False
  30. ):
  31. shape = (-1, 1, 1, 1)
  32. if transpose:
  33. shape = (1, -1, 1, 1)
  34. kernel_shape = weight.shape
  35. if len(kernel_shape) == 5:
  36. if transpose:
  37. groups, num_features = kernel_shape[0], kernel_shape[2]
  38. else:
  39. groups, num_features = kernel_shape[0], kernel_shape[1]
  40. else:
  41. if transpose:
  42. groups, num_features = 1, kernel_shape[1]
  43. else:
  44. groups, num_features = 1, kernel_shape[0]
  45. out_channels = groups * num_features
  46. if gamma is None:
  47. gamma = ones((out_channels,), dtype="float32")
  48. gamma = gamma.reshape(1, -1, 1, 1)
  49. if beta is None:
  50. beta = zeros((out_channels,), dtype="float32")
  51. beta = beta.reshape(1, -1, 1, 1)
  52. if bn_mean is None:
  53. bn_mean = zeros((1, out_channels, 1, 1), dtype="float32")
  54. if bn_var is None:
  55. bn_var = ones((1, out_channels, 1, 1), dtype="float32")
  56. if bias is None:
  57. bias = zeros((1, out_channels, 1, 1), dtype="float32")
  58. bn_istd = 1.0 / sqrt(bn_var + eps)
  59. scale_factor = gamma * bn_istd
  60. if groups == 1:
  61. w_fold = weight * scale_factor.reshape(*shape)
  62. else:
  63. w_fold = weight * scale_factor.reshape(groups, *shape)
  64. b_fold = beta + gamma * (bias - bn_mean) * bn_istd
  65. return w_fold, b_fold
  66. def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
  67. module_key = tuple([type(m) for m in [conv, bn, relu] if m])
  68. if bn:
  69. assert (
  70. conv.training == bn.training
  71. ), "Conv and BN both must be in the same mode (train or eval)."
  72. assert (
  73. bn.num_features == conv.out_channels
  74. ), "Output channel of Conv2d must match num_features of BatchNorm2d"
  75. module_key = module_key + (conv.training,)
  76. module = _MAP_TO_FUSED_MODULE[module_key](
  77. in_channels=conv.in_channels,
  78. out_channels=conv.out_channels,
  79. kernel_size=conv.kernel_size,
  80. stride=conv.stride,
  81. padding=conv.padding,
  82. dilation=conv.dilation,
  83. groups=conv.groups,
  84. bias=conv.bias is not None,
  85. conv_mode=conv.conv_mode,
  86. compute_mode=conv.compute_mode,
  87. name=conv.name,
  88. )
  89. if isinstance(conv, ConvTranspose2d):
  90. module.output_padding = conv.output_padding
  91. new_conv = (
  92. module if bn is None or not conv.training else module.conv_transpose2d
  93. )
  94. else:
  95. new_conv = module if bn is None or not conv.training else module.conv
  96. weight, bias = conv.weight, conv.bias
  97. if not conv.training and bn is not None:
  98. if isinstance(conv, ConvTranspose2d):
  99. weight, bias = fold_weight_bias(
  100. weight,
  101. bias,
  102. bn.weight,
  103. bn.bias,
  104. bn.running_mean,
  105. bn.running_var,
  106. bn.eps,
  107. transpose=True,
  108. )
  109. else:
  110. weight, bias = fold_weight_bias(
  111. weight,
  112. bias,
  113. bn.weight,
  114. bn.bias,
  115. bn.running_mean,
  116. bn.running_var,
  117. bn.eps,
  118. )
  119. new_conv.weight = Parameter(weight)
  120. if bias is not None:
  121. new_conv.bias = Parameter(bias)
  122. if bn is not None and conv.training:
  123. module.bn = deepcopy(bn)
  124. new_conv.training = conv.training
  125. return module