|
|
@@ -141,6 +141,35 @@ class _BatchNorm(Module): |
|
|
|
class SyncBatchNorm(_BatchNorm): |
|
|
|
r""" |
|
|
|
Applies Synchronized Batch Normalization for distributed training. |
|
|
|
|
|
|
|
:type num_features: int |
|
|
|
:param num_features: usually :math:`C` from an input of shape |
|
|
|
:math:`(N, C, H, W)` or the highest ranked dimension of an input |
|
|
|
less than 4D. |
|
|
|
:type eps: float |
|
|
|
:param eps: a value added to the denominator for numerical stability. |
|
|
|
Default: 1e-5 |
|
|
|
:type momentum: float |
|
|
|
:param momentum: the value used for the ``running_mean`` and ``running_var`` computation. |
|
|
|
Default: 0.9 |
|
|
|
:type affine: bool |
|
|
|
:param affine: a boolean value that when set to True, this module has |
|
|
|
learnable affine parameters. Default: True |
|
|
|
:type track_running_stats: bool |
|
|
|
:param track_running_stats: when set to True, this module tracks the |
|
|
|
running mean and variance. When set to False, this module does not |
|
|
|
track such statistics and always uses batch statistics in both training |
|
|
|
and eval modes. Default: True |
|
|
|
:type freeze: bool |
|
|
|
:param freeze: when set to True, this module does not update the |
|
|
|
running mean and variance, and uses the running mean and variance instead of |
|
|
|
the batch mean and batch variance to normalize the input. The parameter takes effect |
|
|
|
only when the module is initilized with track_running_stats as True. |
|
|
|
Default: False |
|
|
|
:type group: :class:`~megengine.distributed.Group` |
|
|
|
:param group: communication group, caculate mean and variance between this group. |
|
|
|
Default: :obj:`~megengine.distributed.WORLD` |
|
|
|
:return: output tensor. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__( |
|
|
|