@@ -12,56 +12,14 @@ import megengine._internal as mgb
from megengine._internal.opr_param_defs import CollectiveComm as CollParam
from ..core import Buffer, Parameter, Tensor, wrap_io_tensor
from ..core.graph import get_default_graph
from ..functional import add_update
from .util import (
    get_backend,
    get_master_ip,
    get_master_port,
    get_rank,
    get_world_size,
    is_distributed,
)
from .helper import collective_comm_symvar
from .util import get_rank, is_distributed


@wrap_io_tensor
def _collective_comm(
    inp: Union[Tensor, mgb.CompGraph],
    key: str,
    op: CollParam.Mode,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    root: Optional[int] = 0,
    dtype: Optional[type] = None,
    device: Optional[mgb.CompNode] = None,
    comp_graph: Optional[mgb.CompGraph] = None,
) -> Tensor:
    """Helper function for creating collective_comm operators.

    :param inp: tensor or comp_graph
    :param key: unique identifier for collective communication
    :param op: mode of collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    :param root: rank of root node, use 0 as default
    :param dtype: output data type, use dtype of inp as default
    :param device: output comp node, use comp node of inp as default
    :param comp_graph: output comp graph, use comp graph of inp as default
    """
    return mgb.opr.collective_comm(
        inp,
        key=str(key),
        nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
        rank=rank if rank is not None else get_rank(),
        root=root,
        server_addr=get_master_ip(),
        port=get_master_port(),
        param=CollParam(mode=op),
        dtype=dtype,
        backend=get_backend(),
        comp_node=device,
        comp_graph=comp_graph,
    )


def _collective_comm(*args, **kargs):
    return collective_comm_symvar(*args, **kargs)


def reduce_sum(
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Union

import megengine._internal as mgb
from megengine._internal.opr_param_defs import CollectiveComm as CollParam

from .util import get_backend, get_master_ip, get_master_port, get_rank, get_world_size


def collective_comm_symvar(
    inp: Union[mgb.SymbolVar, mgb.CompGraph],
    key: str,
    op: CollParam.Mode,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    root: Optional[int] = 0,
    dtype: Optional[type] = None,
    device: Optional[mgb.CompNode] = None,
    comp_graph: Optional[mgb.CompGraph] = None,
) -> mgb.SymbolVar:
    """Helper function for creating collective_comm operators.

    :param inp: tensor or comp_graph
    :param key: unique identifier for collective communication
    :param op: mode of collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    :param root: rank of root node, use 0 as default
    :param dtype: output data type, use dtype of inp as default
    :param device: output comp node, use comp node of inp as default
    :param comp_graph: output comp graph, use comp graph of inp as default
    """
    return mgb.opr.collective_comm(
        inp,
        key=str(key),
        nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
        rank=rank if rank is not None else get_rank(),
        root=root,
        server_addr=get_master_ip(),
        port=get_master_port(),
        param=CollParam(mode=op),
        dtype=dtype,
        backend=get_backend(),
        comp_node=device,
        comp_graph=comp_graph,
    )
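
For orientation, a hedged usage sketch of the helper above (not part of the diff). It assumes every rank has already called dist.init_process_group(...), and the wrapper name below is purely illustrative.

# Illustrative sketch only: wrap collective_comm_symvar for an all-reduce-sum.
# Assumes init_process_group(...) has run on every rank, so the master address,
# port, rank and world size resolved inside the helper are valid.
import megengine._internal as mgb
from megengine._internal.opr_param_defs import CollectiveComm as CollParam
from megengine.distributed.helper import collective_comm_symvar


def all_reduce_sum_symvar(sym: mgb.SymbolVar, key: str) -> mgb.SymbolVar:
    # Every rank must call this with the same key so the operators rendezvous.
    return collective_comm_symvar(sym, key, CollParam.Mode.ALL_REDUCE_SUM)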
@@ -19,6 +19,7 @@ _master_port = 0
_world_size = 0
_rank = 0
_backend = None
_group_id = 0


def init_process_group(
@@ -43,6 +44,7 @@ def init_process_group(
    global _world_size  # pylint: disable=global-statement
    global _rank  # pylint: disable=global-statement
    global _backend  # pylint: disable=global-statement
    global _group_id  # pylint: disable=global-statement

    if not isinstance(master_ip, str):
        raise TypeError("Expect type str but got {}".format(type(master_ip)))
@@ -60,6 +62,7 @@ def init_process_group(
    _world_size = world_size
    _rank = rank
    _backend = backend
    _group_id = 0

    set_default_device(mgb.comp_node("gpu" + str(dev)))
@@ -101,6 +104,13 @@ def get_backend() -> str:
    return str(_backend)


def get_group_id() -> int:
    """Get group id for collective communication."""
    global _group_id
    _group_id += 1
    return _group_id


def group_barrier() -> None:
    """Block until all ranks in the group reach this barrier."""
    mgb.config.group_barrier(_master_ip, _master_port, _world_size, _rank)
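
A brief, illustrative note on these keys (the strings below are only examples): each call to get_group_id() bumps a process-local counter that init_process_group() resets to zero, so ranks that build identical graphs in the same order generate matching key sequences.

# Illustrative only: the key-generation pattern used by sync_batch_norm and the
# optimizer changes later in this diff.
from megengine.distributed.util import get_group_id

first_key = "sync_bn_" + str(get_group_id())  # "sync_bn_1" right after init
second_key = "grad_" + str(get_group_id())    # "grad_2"; the counter only increases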
@@ -76,6 +76,7 @@ from .nn import (
    roi_pooling,
    softmax,
    softplus,
    sync_batch_norm,
    warp_perspective,
)
from .quantized import conv_bias_activation
@@ -11,15 +11,20 @@ from typing import Optional, Tuple, Union
import megengine._internal as mgb
from megengine._internal import CompGraph, CompNode
from megengine._internal.config import add_extra_vardep
from megengine._internal.opr import add_update
from megengine._internal.opr_param_defs import CollectiveComm as CollParam

from .. import distributed as dist
from ..core import Tensor, wrap_io_tensor
from ..core.graph import _use_default_if_none
from ..distributed.util import get_group_id
from ..jit import barrier, mark_impure
from ..random import uniform
from ..utils.types import _pair, _pair_nonzero
from .debug_param import get_conv_execution_strategy
from .elemwise import exp, log
from .tensor import concat, where
from .tensor import where
from .utils import _decide_comp_node_and_comp_graph
@@ -474,6 +479,125 @@ def batch_norm2d(
    return output


@wrap_io_tensor
def sync_batch_norm(
    input: Tensor,
    running_mean: Tensor,
    running_var: Tensor,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    training: bool = False,
    momentum: Union[float, Tensor] = 0.9,
    eps: float = 1e-5,
    eps_mode="ADDITIVE",
) -> Tensor:
    r"""Applies synchronized batch normalization to the input.

    Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.

    :param input: input tensor.
    :param running_mean: tensor to store running mean.
    :param running_var: tensor to store running variance.
    :param weight: scaling tensor in the learnable affine parameters.
        See :math:`\gamma` in :class:`~.BatchNorm2d`.
    :param bias: bias tensor in the learnable affine parameters.
        See :math:`\beta` in :class:`~.BatchNorm2d`.
    :param training: a boolean value to indicate whether batch norm is performed
        in training mode. Default: ``False``
    :param momentum: the value used for the ``running_mean`` and ``running_var``
        computation. Default: 0.9
    :param eps: a value added to the denominator for numerical stability.
        Default: 1e-5
    :param eps_mode: how ``eps`` is applied to the variance, either ``"ADDITIVE"``
        or ``"MAX"``. Default: ``"ADDITIVE"``
    """
    assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode)

    input = mgb.opr.mark_no_broadcast_elemwise(input)
    _channels = input.imm_shape[1]
    _ndim = len(input.imm_shape)
    _param_shape = (1, _channels) + (1,) * (_ndim - 2)

    if training:

        def _sum_on_channel(input):
            return mgb.opr.reduce_general([input, _param_shape], mode="sum")

        def _allreduce(stat, key):
            return dist.helper.collective_comm_symvar(
                stat, key, CollParam.Mode.ALL_REDUCE_SUM
            )

        reduce_size = input.shape[0]
        for i in range(2, _ndim):
            reduce_size = reduce_size * input.shape[i]
        channel_x1s = _sum_on_channel(input)
        channel_x2s = _sum_on_channel(input ** 2)

        if dist.is_distributed():
            # reduce all nodes' data to calculate mean and variance
            reduce_size = reduce_size.reshape(*(1,) * _ndim)
            stat = mgb.opr.concat([reduce_size, channel_x1s, channel_x2s], axis=1)
            stat = _allreduce(stat, key="sync_bn_" + str(get_group_id()))
            reduce_size = stat[:, :1].reshape(1)
            channel_x1s = stat[:, 1 : 1 + _channels]
            channel_x2s = stat[:, 1 + _channels :]

        channel_mean = channel_x1s / reduce_size
        channel_variance = (
            channel_x1s ** 2 / (-reduce_size * reduce_size) + channel_x2s / reduce_size
        )
    else:
        assert running_var is not None and running_mean is not None
        channel_variance = running_var.reshape(*_param_shape)
        channel_mean = running_mean.reshape(*_param_shape)

    invsqrt_channel_variance = (
        mgb.opr.elem.max(channel_variance, eps)
        if eps_mode == "MAX"
        else mgb.opr.elem.add(channel_variance, eps)
    ) ** -0.5

    if weight is not None:
        weight = weight.reshape(*_param_shape)
    if bias is not None:
        bias = bias.reshape(*_param_shape)

    # outvar = output * weight + bias
    # where output = input * invsqrt_channel_variance + (
    #     -channel_mean * invsqrt_channel_variance
    # )
    # Manually expand output for gopt
    if weight is not None:
        inv_var_wt = invsqrt_channel_variance * weight
        neg_channel_mean = -channel_mean
        if bias is not None:
            outvar = input * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
        else:
            outvar = input * inv_var_wt + neg_channel_mean * inv_var_wt
    else:
        outvar = input * invsqrt_channel_variance + (
            -channel_mean * invsqrt_channel_variance
        )
        if bias is not None:
            outvar = outvar + bias

    if training and running_var is not None and running_mean is not None:
        _mean_update = add_update(
            running_mean, channel_mean, alpha=momentum, beta=1 - momentum,
        )
        channel_variance_unbiased = channel_x1s ** 2 / (
            -reduce_size * (reduce_size - 1)
        ) + channel_x2s / (reduce_size - 1)
        _variance_update = add_update(
            running_var, channel_variance_unbiased, alpha=momentum, beta=1 - momentum
        )
        for dep in (_mean_update, _variance_update):
            add_extra_vardep(outvar, dep)

    return outvar


def one_hot(inp: Tensor, num_classes: int) -> Tensor:
    r"""
    Perform one-hot encoding for the input tensor.
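
For reference, a small NumPy sketch (illustrative, not part of this diff) of the statistics trick used by sync_batch_norm above: each rank contributes per-channel [count, sum(x), sum(x**2)], the concatenated vector is summed across ranks (standing in here for ALL_REDUCE_SUM), and the mean and biased variance follow from E[x^2] - E[x]^2.

# Standalone NumPy check of the mean/variance reconstruction in sync_batch_norm.
import numpy as np

C = 4                                     # number of channels
x0 = np.random.randn(6, C)                # rank 0's share of the batch
x1 = np.random.randn(10, C)               # rank 1's share of the batch


def local_stats(x):
    # per-rank statistics: [element count, per-channel sum(x), per-channel sum(x**2)]
    return np.concatenate([[x.shape[0]], x.sum(axis=0), (x ** 2).sum(axis=0)])


stat = local_stats(x0) + local_stats(x1)             # stands in for ALL_REDUCE_SUM
n, x1s, x2s = stat[0], stat[1 : 1 + C], stat[1 + C :]
mean = x1s / n
var_biased = x2s / n - (x1s / n) ** 2                # E[x^2] - E[x]^2

full = np.concatenate([x0, x1])
assert np.allclose(mean, full.mean(axis=0))
assert np.allclose(var_biased, full.var(axis=0))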
@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
from .batchnorm import BatchNorm1d, BatchNorm2d
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
from .concat import Concat
from .conv import Conv2d, ConvTranspose2d, LocalConv2d
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
@@ -9,7 +9,7 @@
import numpy as np

from ..core import Buffer, Parameter
from ..functional import batch_norm2d
from ..functional import batch_norm2d, sync_batch_norm
from . import init
from .module import Module
@@ -74,7 +74,6 @@ class _BatchNorm(Module):
            inp = inp.reshape(new_shape)

        _iter_update = None
        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
@@ -97,6 +96,54 @@ class _BatchNorm(Module):
        return output


class SyncBatchNorm(_BatchNorm):
    r"""
    Applies Synchronized Batch Normalization.
    """

    def _check_input_ndim(self, inp):
        if len(inp.shape) not in {2, 3, 4}:
            raise ValueError(
                "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
            )

    def forward(self, inp):
        self._check_input_ndim(inp)

        _ndims = len(inp.shape)
        if _ndims != 4:
            origin_shape = inp.shapeof()
            if _ndims == 2:
                n, c = inp.shapeof(0), inp.shapeof(1)
                new_shape = (n, c, 1, 1)
            elif _ndims == 3:
                n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2)
                new_shape = (n, c, h, 1)

            inp = inp.reshape(new_shape)

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # useless

        output = sync_batch_norm(
            inp,
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor,
            self.eps,
        )

        if _ndims != 4:
            output = output.reshape(origin_shape)

        return output


class BatchNorm1d(_BatchNorm):
    r"""
    Applies Batch Normalization over a 2D/3D tensor.
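
A minimal usage sketch of the SyncBatchNorm module added above (hedged: shapes and process-group arguments are placeholders). Inside an initialized process group it all-reduces batch statistics across ranks; run standalone, is_distributed() is false and it falls back to per-process statistics.

# Illustrative usage only; mirrors what the tests below do in a multi-process setup.
import numpy as np

import megengine.distributed as dist
from megengine.core import tensor
from megengine.module import SyncBatchNorm

# On each worker (placeholder address/port/world_size/rank/device):
# dist.init_process_group("localhost", 2333, world_size, rank, dev)

bn = SyncBatchNorm(8)  # 8 channels
x = tensor()
x.set_value(np.random.randn(2, 8, 16, 16).astype("float32"))
y = bn(x)              # statistics are all-reduced across ranks when distributed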
@@ -18,6 +18,7 @@ from .._internal.config import opr_priority_scope
from ..core import Buffer, Parameter, Tensor, TensorDict
from ..core.graph import get_default_graph
from ..distributed import all_reduce_sum, bcast_param, get_world_size, is_distributed
from ..distributed.util import get_group_id
from ..functional import add_update
from ..functional import grad as grad_func
from ..jit import sideeffect
@@ -152,7 +153,7 @@ class Optimizer(metaclass=ABCMeta):
        :param loss: The obtained loss tensor
        """
        rst = []
        key = 0
        priority = 0
        params = []
        for group in self.param_groups:
            for param in group["params"]:
@@ -173,11 +174,14 @@ class Optimizer(metaclass=ABCMeta):
        for param, grad in zip(params, grads):
            if is_distributed():
                key += 1
                with opr_priority_scope(cg, -key):
                priority += 1
                with opr_priority_scope(cg, -priority):
                    # all_reduce_mean
                    grad = all_reduce_sum(grad, key) / get_world_size()
                with opr_priority_scope(cg, (1 << 30) - key):
                    grad = (
                        all_reduce_sum(grad, "grad_" + str(get_group_id()))
                        / get_world_size()
                    )
                with opr_priority_scope(cg, (1 << 30) - priority):
                    grad_update = add_update(param.grad, grad)
            else:
                grad_update = add_update(param.grad, grad)
@@ -216,11 +220,9 @@ class Optimizer(metaclass=ABCMeta):
                param.grad.reset_zero()

    def bcast_param(self):
        key = 0
        for group in self.param_groups:
            for param in group["params"]:
                bcast_param(param, key)
                key += 1
                bcast_param(param, "bcast_param_" + str(get_group_id()))

    def state_dict(self) -> Dict:
        r"""Export the optimizer state.
@@ -6,15 +6,86 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp

import numpy as np
import pytest

import megengine as mge
import megengine.distributed as dist
from megengine.core import tensor
from megengine.module import BatchNorm1d, BatchNorm2d
from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
from megengine.test import assertTensorClose


@pytest.mark.isolated_distributed
def test_syncbn():
    nr_chan = 8
    data_shape = (3, nr_chan, 4, 16)
    momentum = 0.9
    eps = 1e-5
    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
    steps = 4

    def worker(rank, data, yv_expect, running_mean, running_var):
        if not mge.is_cuda_available():
            return
        dist.init_process_group("localhost", 2333, 4, rank, rank)
        bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
        data_tensor = tensor()
        for i in range(steps):
            data_tensor.set_value(data[i])
            yv = bn(data_tensor)

        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
        assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
        assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)

    xv = []
    for i in range(steps):
        xv.append(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
        xv_transposed = np.transpose(xv[i], [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )
        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var_biased + eps)
        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)
        yv_expect = (xv[i] - mean) / sd

    data = []
    for i in range(4):
        data.append([])
        for j in range(steps):
            data[i].append(xv[j][:, :, :, i * 4 : i * 4 + 4])

    procs = []
    for rank in range(4):
        p = mp.Process(
            target=worker,
            args=(
                rank,
                data[rank],
                yv_expect[:, :, :, rank * 4 : rank * 4 + 4],
                running_mean,
                running_var,
            ),
        )
        p.start()
        procs.append(p)

    for p in procs:
        p.join()
        assert p.exitcode == 0


def test_batchnorm():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
@@ -64,6 +135,55 @@ def test_batchnorm():
    assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)


def test_syncbn1d():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
    momentum = 0.9
    bn = SyncBatchNorm(nr_chan, momentum=momentum)
    running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1), dtype=np.float32)
    data = tensor()
    for i in range(3):
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
        xv_transposed = np.transpose(xv, [0, 2, 1]).reshape(
            (data_shape[0] * data_shape[2], nr_chan)
        )
        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1))
        sd = np.sqrt(var_biased + bn.eps)
        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)

        data.set_value(xv)
        yv = bn(data)
        yv_expect = (xv - mean) / sd

        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
        assertTensorClose(
            running_mean.reshape(-1), bn.running_mean.numpy().reshape(-1), max_err=5e-6
        )
        assertTensorClose(
            running_var.reshape(-1), bn.running_var.numpy().reshape(-1), max_err=5e-6
        )

    # test set 'training' flag to False
    mean_backup = bn.running_mean.numpy()
    var_backup = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
    data.set_value(xv)
    yv1 = bn(data)
    yv2 = bn(data)
    assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
    assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0)
    assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0)
    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
    assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)


def test_batchnorm2d():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
@@ -110,6 +230,52 @@ def test_batchnorm2d():
    assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)


def test_syncbn2d():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    momentum = 0.9
    bn = SyncBatchNorm(nr_chan, momentum=momentum)
    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
    data = tensor()
    for i in range(3):
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )
        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var_biased + bn.eps)
        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)

        data.set_value(xv)
        yv = bn(data)
        yv_expect = (xv - mean) / sd

        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
        assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
        assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)

    # test set 'training' flag to False
    mean_backup = bn.running_mean.numpy()
    var_backup = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
    data.set_value(xv)
    yv1 = bn(data)
    yv2 = bn(data)
    assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
    assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0)
    assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0)
    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
    assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)


def test_batchnorm_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
@@ -135,6 +301,31 @@ def test_batchnorm_no_stats():
    assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)


def test_syncbn_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
    bn = SyncBatchNorm(8, track_running_stats=False)
    data = tensor()
    for i in range(4):
        if i == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
        var = np.var(
            np.transpose(xv, [0, 2, 1]).reshape(
                (data_shape[0] * data_shape[2], nr_chan)
            ),
            axis=0,
        ).reshape((1, nr_chan, 1))
        sd = np.sqrt(var + bn.eps)

        data.set_value(xv)
        yv = bn(data)
        yv_expect = (xv - mean) / sd

        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)


def test_batchnorm2d_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
@@ -157,3 +348,27 @@ def test_batchnorm2d_no_stats():
    yv_expect = (xv - mean) / sd
    assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)


def test_syncbn2d_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    bn = SyncBatchNorm(8, track_running_stats=False)
    data = tensor()
    for i in range(4):
        if i == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )
        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
        var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var + bn.eps)

        data.set_value(xv)
        yv = bn(data)
        yv_expect = (xv - mean) / sd

        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)