|
|
@@ -132,13 +132,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule): |
|
|
|
if self.conv.bias is not None: |
|
|
|
orig_conv = orig_conv + self.conv.bias |
|
|
|
# calculate batch norm |
|
|
|
bn_mean, bn_var = self.get_batch_mean_var(orig_conv) |
|
|
|
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) |
|
|
|
conv = gamma * bn_istd * (orig_conv - bn_mean) + beta |
|
|
|
num_elements_per_channel = conv.size / conv.shape[1] |
|
|
|
self.update_running_mean_and_running_var( |
|
|
|
bn_mean, bn_var, num_elements_per_channel |
|
|
|
) |
|
|
|
conv = self.bn(orig_conv) |
|
|
|
return conv |
|
|
|
|
|
|
|
@classmethod |
|
|
|