diff --git a/imperative/python/megengine/module/qat/conv_bn.py b/imperative/python/megengine/module/qat/conv_bn.py index bb7414d9..89e270f1 100644 --- a/imperative/python/megengine/module/qat/conv_bn.py +++ b/imperative/python/megengine/module/qat/conv_bn.py @@ -132,13 +132,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule): if self.conv.bias is not None: orig_conv = orig_conv + self.conv.bias # calculate batch norm - bn_mean, bn_var = self.get_batch_mean_var(orig_conv) - bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) - conv = gamma * bn_istd * (orig_conv - bn_mean) + beta - num_elements_per_channel = conv.size / conv.shape[1] - self.update_running_mean_and_running_var( - bn_mean, bn_var, num_elements_per_channel - ) + conv = self.bn(orig_conv) return conv @classmethod