|
- from copy import deepcopy
-
- from ..functional import ones, sqrt, zeros
- from ..module import BatchNorm2d, Conv2d, ConvBn2d, ConvBnRelu2d, ConvRelu2d, ReLU
- from ..tensor import Parameter
-
- _MAP_TO_FUSED_MODULE = {  # key: exact types of the non-None modules (+ conv.training when BN is present) -> fused class
- (Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d,  # eval: BN is folded into the conv weight/bias
- (Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d,  # train: BN is kept as a separate (deep-copied) submodule
- (Conv2d, BatchNorm2d, False): Conv2d,  # eval: BN folded, no activation -> plain conv
- (Conv2d, BatchNorm2d, True): ConvBn2d,  # train: BN kept as a submodule, no activation
- (Conv2d, ReLU): ConvRelu2d,  # no BN, so the key carries no training flag
- }
-
-
def fold_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5):
    """Fold BatchNorm statistics and affine parameters into a conv's weight/bias.

    Computes, per output channel::

        istd   = 1 / sqrt(bn_var + eps)
        w_fold = weight * gamma * istd
        b_fold = beta + gamma * (bias - bn_mean) * istd

    Args:
        weight: conv kernel. A 5-D kernel is treated as grouped, with the
            group count on axis 0 and the per-group output channels on axis 1.
            Any other rank is treated as ungrouped, with output channels on
            axis 0.
        bias: conv bias, or ``None`` to use zeros.
        gamma: BN scale, or ``None`` to use ones. Reshaped to ``(1, -1, 1, 1)``.
        beta: BN shift, or ``None`` to use zeros. Reshaped to ``(1, -1, 1, 1)``.
        bn_mean: BN running mean, or ``None`` to use zeros.
        bn_var: BN running variance, or ``None`` to use ones.
        eps: value added to ``bn_var`` before taking the square root.

    Returns:
        ``(w_fold, b_fold)``: the folded kernel, in the same layout as
        ``weight``, and the folded bias. The bias is computed in the
        ``(1, C, 1, 1)`` layout that ``gamma`` and ``beta`` are reshaped to.
        NOTE(review): callers are presumably passing ``bias``, ``bn_mean`` and
        ``bn_var`` in that same layout; confirm this against callers.
    """
    kernel_shape = weight.shape
    if len(kernel_shape) == 5:
        groups, channels_per_group = kernel_shape[0], kernel_shape[1]
    else:
        groups, channels_per_group = 1, kernel_shape[0]
    # BN normalizes every output channel of the conv. The default tensors must
    # therefore span all groups, not just one group's channels; otherwise they
    # fail to broadcast against full-width BN buffers for grouped convs.
    out_channels = groups * channels_per_group

    if gamma is None:
        gamma = ones((out_channels,), dtype="float32")
    gamma = gamma.reshape(1, -1, 1, 1)
    if beta is None:
        beta = zeros((out_channels,), dtype="float32")
    beta = beta.reshape(1, -1, 1, 1)

    if bn_mean is None:
        bn_mean = zeros((1, out_channels, 1, 1), dtype="float32")
    if bn_var is None:
        bn_var = ones((1, out_channels, 1, 1), dtype="float32")
    if bias is None:
        bias = zeros((1, out_channels, 1, 1), dtype="float32")

    bn_istd = 1.0 / sqrt(bn_var + eps)
    scale_factor = gamma * bn_istd

    if groups == 1:
        # One scale per output channel, broadcast over (in, kh, kw).
        w_fold = weight * scale_factor.reshape(-1, 1, 1, 1)
    else:
        # Split the per-channel scale into (groups, channels_per_group) to
        # match the grouped kernel layout.
        w_fold = weight * scale_factor.reshape(groups, -1, 1, 1, 1)

    b_fold = beta + gamma * (bias - bn_mean) * bn_istd
    return w_fold, b_fold
-
-
def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
    """Build a fused module from a ``Conv2d`` plus an optional BN and/or ReLU.

    The fused class is looked up in ``_MAP_TO_FUSED_MODULE``. The key is the
    tuple of exact types of the non-``None`` arguments, with
    ``conv.training`` appended when ``bn`` is given:

    * ``bn`` given, eval mode: BN is folded into the conv weight/bias via
      ``fold_weight_bias``. The result is a plain ``Conv2d`` or a
      ``ConvRelu2d``.
    * ``bn`` given, training mode: the result is a ``ConvBn2d`` or
      ``ConvBnRelu2d``. Its inner ``.conv`` receives ``conv``'s weight/bias
      and its ``.bn`` is a deep copy of ``bn``.
    * no ``bn``: ``conv``'s weight/bias are wrapped into a ``ConvRelu2d``.

    Args:
        conv: the convolution to fuse. Its hyper-parameters and name are
            copied to the new module.
        bn: optional batch norm following ``conv``, or ``None``.
        relu: optional ReLU following ``conv``/``bn``, or ``None``.

    Returns:
        A newly constructed fused module.

    Raises:
        AssertionError: if ``conv`` and ``bn`` are in different train/eval
            modes, or ``bn.num_features != conv.out_channels``.
        KeyError: if the module combination has no fused counterpart.
    """
    module_key = tuple(type(m) for m in (conv, bn, relu) if m is not None)
    if bn is not None:
        assert (
            conv.training == bn.training
        ), "Conv and BN both must be in the same mode (train or eval)."
        assert (
            bn.num_features == conv.out_channels
        ), "Output channel of Conv2d must match num_features of BatchNorm2d"
        # The mode decides whether BN stays a trainable submodule or is folded.
        module_key += (conv.training,)

    try:
        fused_cls = _MAP_TO_FUSED_MODULE[module_key]
    except KeyError:
        raise KeyError(
            "No fused module for combination {}".format(module_key)
        ) from None

    fused = fused_cls(
        in_channels=conv.in_channels,
        out_channels=conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=conv.bias is not None,
        conv_mode=conv.conv_mode,
        compute_mode=conv.compute_mode,
        name=conv.name,
    )

    keep_bn = bn is not None and conv.training
    fold_bn = bn is not None and not conv.training
    # In the ConvBn* variants the convolution itself lives in ``.conv``.
    new_conv = fused.conv if keep_bn else fused

    weight, bias = conv.weight, conv.bias
    if fold_bn:
        weight, bias = fold_weight_bias(
            weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps,
        )
    new_conv.weight = Parameter(weight)
    if bias is not None:
        new_conv.bias = Parameter(bias)
    if keep_bn:
        fused.bn = deepcopy(bn)
    new_conv.training = conv.training
    return fused
|