diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py index a4479077..993ccfa7 100644 --- a/imperative/python/megengine/module/batchnorm.py +++ b/imperative/python/megengine/module/batchnorm.py @@ -280,6 +280,17 @@ class BatchNorm2d(_BatchNorm): statistics on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization. + .. note:: + + The update formula for ``running_mean`` and ``running_var`` (taking ``running_mean`` as an example) is + + .. math:: + + \textrm{running_mean} = \textrm{momentum} \times \textrm{running_mean} + (1 - \textrm{momentum}) \times \textrm{batch_mean} + + which could be defined differently in other frameworks. Most notably, ``momentum`` of 0.1 in PyTorch + is equivalent to ``mementum`` of 0.9 here. + Args: num_features: usually :math:`C` from an input of shape :math:`(N, C, H, W)` or the highest ranked dimension of an input