|
@@ -22,7 +22,7 @@ from .debug_param import get_conv_execution_strategy |
|
|
from .distributed import all_reduce_sum |
|
|
from .distributed import all_reduce_sum |
|
|
from .elemwise import exp, floor, log, log1p, maximum, minimum, relu |
|
|
from .elemwise import exp, floor, log, log1p, maximum, minimum, relu |
|
|
from .math import argsort, max, sum |
|
|
from .math import argsort, max, sum |
|
|
from .tensor import add_axis, broadcast, concat, full, remove_axis, reshape |
|
|
|
|
|
|
|
|
from .tensor import add_axis, broadcast, concat, remove_axis, reshape |
|
|
from .types import _pair, _pair_nonzero |
|
|
from .types import _pair, _pair_nonzero |
|
|
|
|
|
|
|
|
__all__ = [ |
|
|
__all__ = [ |
|
@@ -692,7 +692,7 @@ def batch_norm2d( |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sync_batch_norm( |
|
|
def sync_batch_norm( |
|
|
input: Tensor, |
|
|
|
|
|
|
|
|
inp: Tensor, |
|
|
running_mean: Tensor, |
|
|
running_mean: Tensor, |
|
|
running_var: Tensor, |
|
|
running_var: Tensor, |
|
|
weight: Optional[Tensor] = None, |
|
|
weight: Optional[Tensor] = None, |
|
@@ -723,25 +723,30 @@ def sync_batch_norm( |
|
|
Default: 1e-5. |
|
|
Default: 1e-5. |
|
|
""" |
|
|
""" |
|
|
assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode) |
|
|
assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode) |
|
|
_channels = input.shape[1] |
|
|
|
|
|
_ndim = input.ndim |
|
|
|
|
|
|
|
|
_channels = inp.shape[1] |
|
|
|
|
|
_ndim = inp.ndim |
|
|
|
|
|
_device = inp.device |
|
|
|
|
|
_dtype = inp.dtype |
|
|
_param_shape = (1, _channels) + (1,) * (_ndim - 2) |
|
|
_param_shape = (1, _channels) + (1,) * (_ndim - 2) |
|
|
|
|
|
_reduce_axis = [0] + [i for i in range(2, _ndim)] |
|
|
|
|
|
|
|
|
if training: |
|
|
if training: |
|
|
|
|
|
|
|
|
def _sum_on_channel(input): |
|
|
|
|
|
return apply(builtin.Reduce(mode="SUM"), input, Tensor(_param_shape))[0] |
|
|
|
|
|
|
|
|
def _sum_on_channel(inp): |
|
|
|
|
|
return inp.sum(axis=_reduce_axis, keepdims=True) |
|
|
|
|
|
|
|
|
reduce_size = input.shape[0] |
|
|
|
|
|
|
|
|
reduce_size = inp.shape[0] |
|
|
for i in range(2, _ndim): |
|
|
for i in range(2, _ndim): |
|
|
reduce_size = reduce_size * input.shape[i] |
|
|
|
|
|
channel_x1s = _sum_on_channel(input) |
|
|
|
|
|
channel_x2s = _sum_on_channel(input ** 2) |
|
|
|
|
|
|
|
|
reduce_size = reduce_size * inp.shape[i] |
|
|
|
|
|
channel_x1s = _sum_on_channel(inp) |
|
|
|
|
|
channel_x2s = _sum_on_channel(inp ** 2) |
|
|
|
|
|
|
|
|
if is_distributed(): |
|
|
if is_distributed(): |
|
|
# reduce all nodes' data to calculate mean and variance |
|
|
# reduce all nodes' data to calculate mean and variance |
|
|
reduce_size = full([1 for _ in range(_ndim)], reduce_size) |
|
|
|
|
|
stat = concat([reduce_size, channel_x1s, channel_x2s], axis=1) |
|
|
|
|
|
|
|
|
reduce_size = broadcast(Tensor(reduce_size, dtype=_dtype), [1] * _ndim) |
|
|
|
|
|
stat = concat( |
|
|
|
|
|
[reduce_size.astype(_dtype), channel_x1s, channel_x2s], axis=1 |
|
|
|
|
|
) |
|
|
stat = all_reduce_sum(stat, group) |
|
|
stat = all_reduce_sum(stat, group) |
|
|
reduce_size = stat[:, :1].reshape(1) |
|
|
reduce_size = stat[:, :1].reshape(1) |
|
|
channel_x1s = stat[:, 1 : 1 + _channels] |
|
|
channel_x1s = stat[:, 1 : 1 + _channels] |
|
@@ -775,11 +780,11 @@ def sync_batch_norm( |
|
|
inv_var_wt = invsqrt_channel_variance * weight |
|
|
inv_var_wt = invsqrt_channel_variance * weight |
|
|
neg_channel_mean = -channel_mean |
|
|
neg_channel_mean = -channel_mean |
|
|
if bias is not None: |
|
|
if bias is not None: |
|
|
outvar = input * inv_var_wt + (neg_channel_mean * inv_var_wt + bias) |
|
|
|
|
|
|
|
|
outvar = inp * inv_var_wt + (neg_channel_mean * inv_var_wt + bias) |
|
|
else: |
|
|
else: |
|
|
outvar = input * inv_var_wt + neg_channel_mean * inv_var_wt |
|
|
|
|
|
|
|
|
outvar = inp * inv_var_wt + neg_channel_mean * inv_var_wt |
|
|
else: |
|
|
else: |
|
|
outvar = input * invsqrt_channel_variance + ( |
|
|
|
|
|
|
|
|
outvar = inp * invsqrt_channel_variance + ( |
|
|
-channel_mean * invsqrt_channel_variance |
|
|
-channel_mean * invsqrt_channel_variance |
|
|
) |
|
|
) |
|
|
if bias is not None: |
|
|
if bias is not None: |
|
|