You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

bn_fusion.py 3.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from copy import deepcopy
  2. from ..functional import ones, sqrt, zeros
  3. from ..module import BatchNorm2d, Conv2d, ConvBn2d, ConvBnRelu2d, ConvRelu2d, ReLU
  4. from ..tensor import Parameter
# Lookup table used by ``fuse_conv_bn_relu_module``: the key is the tuple of
# module types being fused and, when a BatchNorm2d is present, the Conv2d's
# ``training`` flag is appended (see ``module_key + (conv.training,)`` below).
_MAP_TO_FUSED_MODULE = {
    (Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d,  # eval: BN folded into conv
    (Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d,  # train: keep a live BN
    (Conv2d, BatchNorm2d, False): Conv2d,
    (Conv2d, BatchNorm2d, True): ConvBn2d,
    (Conv2d, ReLU): ConvRelu2d,  # no BN: no training flag in the key
}
  12. def fold_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5):
  13. # get fold bn conv param
  14. kernel_shape = weight.shape
  15. if len(kernel_shape) == 5:
  16. groups, num_features = kernel_shape[0], kernel_shape[1]
  17. else:
  18. groups, num_features = 1, kernel_shape[0]
  19. if gamma is None:
  20. gamma = ones((num_features), dtype="float32")
  21. gamma = gamma.reshape(1, -1, 1, 1)
  22. if beta is None:
  23. beta = zeros((num_features), dtype="float32")
  24. beta = beta.reshape(1, -1, 1, 1)
  25. if bn_mean is None:
  26. bn_mean = zeros((1, num_features, 1, 1), dtype="float32")
  27. if bn_var is None:
  28. bn_var = ones((1, num_features, 1, 1), dtype="float32")
  29. if bias is None:
  30. bias = zeros((1, num_features, 1, 1), dtype="float32")
  31. bn_istd = 1.0 / sqrt(bn_var + eps)
  32. scale_factor = gamma * bn_istd
  33. if groups == 1:
  34. w_fold = weight * scale_factor.reshape(-1, 1, 1, 1)
  35. else:
  36. w_fold = weight * scale_factor.reshape(groups, -1, 1, 1, 1)
  37. b_fold = beta + gamma * (bias - bn_mean) * bn_istd
  38. return w_fold, b_fold
  39. def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
  40. module_key = tuple([type(m) for m in [conv, bn, relu] if m])
  41. if bn:
  42. assert (
  43. conv.training == bn.training
  44. ), "Conv and BN both must be in the same mode (train or eval)."
  45. assert (
  46. bn.num_features == conv.out_channels
  47. ), "Output channel of Conv2d must match num_features of BatchNorm2d"
  48. module_key = module_key + (conv.training,)
  49. module = _MAP_TO_FUSED_MODULE[module_key](
  50. in_channels=conv.in_channels,
  51. out_channels=conv.out_channels,
  52. kernel_size=conv.kernel_size,
  53. stride=conv.stride,
  54. padding=conv.padding,
  55. dilation=conv.dilation,
  56. groups=conv.groups,
  57. bias=conv.bias is not None,
  58. conv_mode=conv.conv_mode,
  59. compute_mode=conv.compute_mode,
  60. name=conv.name,
  61. )
  62. new_conv = module if bn is None or not conv.training else module.conv
  63. weight, bias = conv.weight, conv.bias
  64. if not conv.training and bn is not None:
  65. weight, bias = fold_weight_bias(
  66. weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps,
  67. )
  68. new_conv.weight = Parameter(weight)
  69. if bias is not None:
  70. new_conv.bias = Parameter(bias)
  71. if bn is not None and conv.training:
  72. module.bn = deepcopy(bn)
  73. new_conv.training = conv.training
  74. return module

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台