diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py index 2e95f25a..efeaba17 100644 --- a/imperative/python/megengine/module/batchnorm.py +++ b/imperative/python/megengine/module/batchnorm.py @@ -100,6 +100,15 @@ class _BatchNorm(Module): if _bias is not None: _bias = _bias.detach() + # fastpath excution for freeze + scale = (self.running_var + self.eps) ** (-0.5) + if _weight is not None: + scale *= _weight + bias = -self.running_mean * scale + if _bias is not None: + bias += _bias + return inp * scale + bias + if self.training and self.track_running_stats: exponential_average_factor = self.momentum else: