|
@@ -1269,15 +1269,27 @@ def sync_batch_norm( |
|
|
group: communication group, calculate mean and variance across this group. |
|
|
group: communication group, calculate mean and variance across this group. |
|
|
Default: :obj:`~megengine.distributed.WORLD` |
|
|
Default: :obj:`~megengine.distributed.WORLD` |
|
|
""" |
|
|
""" |
|
|
assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format( |
|
|
|
|
|
eps_mode |
|
|
|
|
|
) |
|
|
|
|
|
# TODO: cudnnBn fastpath |
|
|
|
|
|
|
|
|
_eps_mode = eps_mode.lower() |
|
|
|
|
|
assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode) |
|
|
|
|
|
if _eps_mode == "additive" and not (is_distributed() and training): |
|
|
|
|
|
return batch_norm( |
|
|
|
|
|
inp, |
|
|
|
|
|
running_mean, |
|
|
|
|
|
running_var, |
|
|
|
|
|
weight, |
|
|
|
|
|
bias, |
|
|
|
|
|
training=training, |
|
|
|
|
|
momentum=momentum, |
|
|
|
|
|
eps=eps, |
|
|
|
|
|
) |
|
|
_channels = make_shape_tuple(inp.shape)[1] |
|
|
_channels = make_shape_tuple(inp.shape)[1] |
|
|
_ndim = inp.ndim |
|
|
_ndim = inp.ndim |
|
|
_device = inp.device |
|
|
_device = inp.device |
|
|
_dtype = inp.dtype |
|
|
_dtype = inp.dtype |
|
|
|
|
|
|
|
|
|
|
|
if _ndim != 4: |
|
|
|
|
|
raise NotImplementedError("sync_batch_norm for ndim != 4") |
|
|
|
|
|
|
|
|
def _make_full_if_none(x, value): |
|
|
def _make_full_if_none(x, value): |
|
|
if x is None: |
|
|
if x is None: |
|
|
(x,) = Const(value, dtype=inp.dtype, device=_device)() |
|
|
(x,) = Const(value, dtype=inp.dtype, device=_device)() |
|
|