@@ -35,6 +35,10 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
+        if self.freeze:
+            assert (
+                self._track_running_stats_saved
+            ), "track_running_stats must be initialized to True if freeze is True"
         tshape = (1, self.num_features, 1, 1)
         if self.affine:
             self.weight = Parameter(np.ones(tshape, dtype=np.float32))
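Illustrative sketch (not part of the diff): the new assertion makes the freeze/track_running_stats contract fail fast at construction time, assuming the BatchNorm2d constructor arguments shown elsewhere in this change.

from megengine.module import BatchNorm2d

BatchNorm2d(3, freeze=True)  # fine: track_running_stats defaults to True

try:
    BatchNorm2d(3, freeze=True, track_running_stats=False)
except AssertionError as exc:
    print(exc)  # track_running_stats must be initialized to True if freeze is True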
@@ -84,10 +88,24 @@ class _BatchNorm(Module):
             inp = inp.reshape(new_shape)

-        if self.freeze and self.training and self._track_running_stats_saved:
-            scale = self.weight * (self.running_var + self.eps) ** (-0.5)
-            bias = self.bias - self.running_mean * scale
-            return inp * scale.detach() + bias.detach()
+        _weight = self.weight
+        _bias = self.bias
+
+        if self.freeze:
+            if _weight is not None:
+                _weight = _weight.detach()
+            if _bias is not None:
+                _bias = _bias.detach()
+            # Need to expand to elementwise operations here
+            # see MGB_IMPL_OPR_GRAD(BatchNormForward) in src/opr/impl/dnn/batch_norm.cpp
+            scale = (self.running_var + self.eps) ** (-0.5)
+            if _weight is not None:
+                scale *= _weight
+            bias = -self.running_mean * scale
+            if _bias is not None:
+                bias += _bias
+            return inp * scale + bias

         if self.training and self.track_running_stats:
             exponential_average_factor = self.momentum
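The in-code comment points at the C++ gradient rule for the fused BatchNormForward op; the frozen branch instead expresses the normalization as plain elementwise operations on the running statistics, so it reduces to a single scale and bias. A small numpy sketch of the algebra this branch implements (variable names here are illustrative, not from the codebase):

import numpy as np

rng = np.random.RandomState(0)
x = rng.rand(6, 3, 2, 2).astype("float32")
mean = rng.rand(1, 3, 1, 1).astype("float32")        # running_mean
var = rng.rand(1, 3, 1, 1).astype("float32") + 0.5   # running_var
w = rng.rand(1, 3, 1, 1).astype("float32")           # affine weight
b = rng.rand(1, 3, 1, 1).astype("float32")           # affine bias
eps = 1e-5

# canonical inference-time batch norm with fixed statistics
y_ref = (x - mean) / np.sqrt(var + eps) * w + b

# elementwise form used by the frozen branch: one scale, one bias
scale = (var + eps) ** (-0.5) * w
bias = -mean * scale + b
y = x * scale + bias

np.testing.assert_allclose(y, y_ref, rtol=1e-5, atol=1e-6)

With affine=False the weight and bias terms simply drop out of the same formula, which is the case the new test referencing issue #145 exercises.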
@@ -98,8 +116,8 @@ class _BatchNorm(Module):
             inp,
             self.running_mean if self.track_running_stats else None,
             self.running_var if self.track_running_stats else None,
-            self.weight,
-            self.bias,
+            _weight,
+            _bias,
             training=self.training
             or ((self.running_mean is None) and (self.running_var is None)),
             momentum=exponential_average_factor,
@@ -121,7 +139,7 @@ class _BatchNorm(Module):
 class SyncBatchNorm(_BatchNorm):
     r"""
-    Applies Synchronization Batch Normalization.
+    Applies Synchronized Batch Normalization for distributed training.
     """

     def __init__(
@@ -169,15 +187,25 @@ class SyncBatchNorm(_BatchNorm):
         else:
             exponential_average_factor = 0.0  # useless

+        _weight = self.weight
+        _bias = self.bias
+
+        if self.freeze:
+            if _weight is not None:
+                _weight = _weight.detach()
+            if _bias is not None:
+                _bias = _bias.detach()
+
         output = sync_batch_norm(
             inp,
             self.running_mean,
             self.running_var,
-            self.weight,
-            self.bias,
-            self.training or not self.track_running_stats,
-            exponential_average_factor,
-            self.eps,
+            _weight,
+            _bias,
+            training=(self.training and not self.freeze)
+            or ((self.running_mean is None) and (self.running_var is None)),
+            momentum=exponential_average_factor,
+            eps=self.eps,
             group=self.group,
         )
@@ -257,8 +285,7 @@ class BatchNorm2d(_BatchNorm):
     :param freeze: when set to True, this module does not update the
         running mean and variance, and uses the running mean and variance instead of
         the batch mean and batch variance to normalize the input. The parameter takes effect
-        only when the module is initialized with track_running_stats as True and
-        the module is in training mode.
+        only when the module is initialized with track_running_stats as True.
         Default: False

     Examples:
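Illustrative usage of the freeze flag described in the docstring above (a sketch, not part of the BatchNorm2d docstring); the remaining hunks below add the corresponding tests:

import numpy as np
import megengine
from megengine.module import BatchNorm2d

m = BatchNorm2d(3, freeze=True)  # modules start out in training mode
before = m.running_mean.numpy().copy()
data = megengine.Tensor(np.random.random((4, 3, 2, 2)).astype("float32"))
m(data)
# frozen: the forward pass leaves the running statistics untouched
np.testing.assert_equal(m.running_mean.numpy(), before)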
@@ -11,15 +11,23 @@ import pytest
 import megengine
 import megengine.autodiff as ad
+import megengine.distributed as dist
+import megengine.functional as F
 import megengine.optimizer as optimizer
 from megengine import Parameter, tensor
+from megengine.distributed.helper import get_device_count_by_fork
 from megengine.jit import trace
-from megengine.module import BatchNorm2d, Module
+from megengine.module import BatchNorm2d, Module, SyncBatchNorm


-def test_frozen_bn():
+def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
     nchannel = 3
-    m = BatchNorm2d(nchannel, freeze=True)
+    m = BNModule(nchannel, freeze=True)
+    var = 4.0
+    bias = 1.0
+    shape = (1, nchannel, 1, 1)
+    m.running_var[...] = var * F.ones(shape)
+    m.running_mean[...] = bias * F.ones(shape)

     saved_var = m.running_var.numpy()
     saved_mean = m.running_mean.numpy()
@@ -31,16 +39,45 @@ def test_frozen_bn():
     optim.clear_grad()

     data = np.random.random((6, nchannel, 2, 2)).astype("float32")
-    with gm:
-        loss = m(data).mean()
-        gm.backward(loss)
-    optim.step()
-
-    np.testing.assert_equal(m.running_var.numpy(), saved_var)
-    np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
-    np.testing.assert_equal(m.weight.numpy(), saved_wt)
-    np.testing.assert_equal(m.bias.numpy(), saved_bias)
-    np.testing.assert_almost_equal(loss.numpy(), data.mean(), 5)
+
+    def train_fn(d):
+        for _ in range(3):
+            with gm:
+                loss = m(d).mean()
+                gm.backward(loss)
+            optim.step()
+        return loss
+
+    if use_trace:
+        train_fn = trace(train_fn, symbolic=use_symbolic)
+
+    for _ in range(3):
+        loss = train_fn(megengine.Tensor(data))
+        np.testing.assert_equal(m.running_var.numpy(), saved_var)
+        np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
+        np.testing.assert_equal(m.weight.numpy(), saved_wt)
+        np.testing.assert_equal(m.bias.numpy(), saved_bias)
+        np.testing.assert_almost_equal(
+            loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5
+        )
+
+
+def test_frozen_bn():
+    run_frozen_bn(BatchNorm2d)
+    run_frozen_bn(BatchNorm2d, True, False)
+    run_frozen_bn(BatchNorm2d, True, True)
+
+
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
+@pytest.mark.isolated_distributed
+def test_frozen_synced_bn():
+    @dist.launcher(n_gpus=2)
+    def worker():
+        run_frozen_bn(SyncBatchNorm)
+        run_frozen_bn(SyncBatchNorm, True, False)
+        run_frozen_bn(SyncBatchNorm, True, True)
+
+    worker()


 def test_bn_no_track_stat():
@@ -112,3 +149,11 @@ def test_trace_bn_forward_twice():
     x = np.ones((1, 1, 32, 32), dtype=np.float32)
     y = train_bn(x, net=Simple())
     np.testing.assert_equal(y.numpy(), 0)
+
+
+# https://github.com/MegEngine/MegEngine/issues/145
+def test_frozen_bn_no_affine():
+    nchannel = 3
+    m = BatchNorm2d(nchannel, freeze=True, affine=False)
+    data = megengine.Tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
+    m(data).numpy()