@@ -35,6 +35,10 @@ class _BatchNorm(Module):
        self.track_running_stats = track_running_stats
        self._track_running_stats_saved = track_running_stats
        self.freeze = freeze
        if self.freeze:
            assert (
                self._track_running_stats_saved
            ), "track_running_stats must be initialized to True if freeze is True"
        tshape = (1, self.num_features, 1, 1)
        if self.affine:
            self.weight = Parameter(np.ones(tshape, dtype=np.float32))
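The assertion added above turns the unsupported freeze-without-running-stats combination into an explicit construction-time error instead of a silent misconfiguration. A minimal sketch of what that guard does (not part of the diff; the printed message is the assertion text above):

from megengine.module import BatchNorm2d

# Supported: freeze requires running statistics to exist.
bn = BatchNorm2d(3, freeze=True)

# Now rejected at construction time.
try:
    BatchNorm2d(3, freeze=True, track_running_stats=False)
except AssertionError as exc:
    print(exc)  # track_running_stats must be initialized to True if freeze is True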
@@ -84,10 +88,24 @@ class _BatchNorm(Module):
            inp = inp.reshape(new_shape)
        if self.freeze and self.training and self._track_running_stats_saved:
            scale = self.weight * (self.running_var + self.eps) ** (-0.5)
            bias = self.bias - self.running_mean * scale
            return inp * scale.detach() + bias.detach()
        _weight = self.weight
        _bias = self.bias
        if self.freeze:
            if _weight is not None:
                _weight = _weight.detach()
            if _bias is not None:
                _bias = _bias.detach()
            # Need to expand to elementwise operations here
            # see MGB_IMPL_OPR_GRAD(BatchNormForward) in src/opr/impl/dnn/batch_norm.cpp
            scale = (self.running_var + self.eps) ** (-0.5)
            if _weight is not None:
                scale *= _weight
            bias = -self.running_mean * scale
            if _bias is not None:
                bias += _bias
            return inp * scale + bias
        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
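The expanded frozen branch is the usual batch-norm formula rewritten as one affine map, y = x * scale + bias with scale = weight / sqrt(running_var + eps) and bias = beta - running_mean * scale, so the detached scale and bias flow through plain elementwise ops. A minimal numpy sketch of that equivalence (shapes, eps, and random data are illustrative only, not taken from the PR):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4)).astype("float32")
mean = rng.standard_normal((1, 3, 1, 1)).astype("float32")
var = rng.random((1, 3, 1, 1)).astype("float32") + 0.5
weight = rng.standard_normal((1, 3, 1, 1)).astype("float32")
beta = rng.standard_normal((1, 3, 1, 1)).astype("float32")
eps = 1e-5

# elementwise form used by the frozen path
scale = (var + eps) ** -0.5 * weight
bias = beta - mean * scale
expanded = x * scale + bias

# textbook normalization
reference = (x - mean) / np.sqrt(var + eps) * weight + beta
np.testing.assert_allclose(expanded, reference, rtol=1e-5, atol=1e-5)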
@@ -98,8 +116,8 @@ class _BatchNorm(Module):
            inp,
            self.running_mean if self.track_running_stats else None,
            self.running_var if self.track_running_stats else None,
            self.weight,
            self.bias,
            _weight,
            _bias,
            training=self.training
            or ((self.running_mean is None) and (self.running_var is None)),
            momentum=exponential_average_factor,
@@ -121,7 +139,7 @@ class _BatchNorm(Module):

class SyncBatchNorm(_BatchNorm):
    r"""
    Applies Synchronization Batch Normalization.
    Applies Synchronized Batch Normalization for distributed training.
    """

    def __init__(
@@ -169,15 +187,25 @@ class SyncBatchNorm(_BatchNorm):
        else:
            exponential_average_factor = 0.0  # useless
        _weight = self.weight
        _bias = self.bias
        if self.freeze:
            if _weight is not None:
                _weight = _weight.detach()
            if _bias is not None:
                _bias = _bias.detach()
        output = sync_batch_norm(
            inp,
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor,
            self.eps,
            _weight,
            _bias,
            training=(self.training and not self.freeze)
            or ((self.running_mean is None) and (self.running_var is None)),
            momentum=exponential_average_factor,
            eps=self.eps,
            group=self.group,
        )
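For the synchronized variant there is no separate fast path; freezing is folded into the `training` flag passed to sync_batch_norm, so a frozen layer normalizes with the shared running statistics and never updates them. A small stand-alone sketch of how that flag resolves (plain Python; the helper name is illustrative only):

def resolve_training(training, freeze, running_mean, running_var):
    """Mirror of the expression passed as `training=` to sync_batch_norm above."""
    return (training and not freeze) or (running_mean is None and running_var is None)

# frozen layer in training mode: inference-style normalization, statistics untouched
assert not resolve_training(training=True, freeze=True, running_mean=0.0, running_var=1.0)
# unfrozen layer in training mode: batch statistics are computed and accumulated
assert resolve_training(training=True, freeze=False, running_mean=0.0, running_var=1.0)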
@@ -257,8 +285,7 @@ class BatchNorm2d(_BatchNorm):
    :param freeze: when set to True, this module does not update the
        running mean and variance, and uses the running mean and variance instead of
        the batch mean and batch variance to normalize the input. The parameter takes effect
        only when the module is initialized with track_running_stats as True and
        the module is in training mode.
        only when the module is initialized with track_running_stats as True.
        Default: False

    Examples:
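The docstring's own runnable example sits outside this hunk. As a hedged illustration of the documented freeze behaviour (made-up data, default eps), a forward pass in training mode leaves the running statistics untouched:

import numpy as np
import megengine
from megengine.module import BatchNorm2d

m = BatchNorm2d(3, freeze=True)  # modules are in training mode by default
before_mean = m.running_mean.numpy().copy()
before_var = m.running_var.numpy().copy()
_ = m(megengine.Tensor(np.random.random((4, 3, 8, 8)).astype("float32")))
np.testing.assert_equal(m.running_mean.numpy(), before_mean)
np.testing.assert_equal(m.running_var.numpy(), before_var)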
@@ -11,15 +11,23 @@ import pytest

import megengine
import megengine.autodiff as ad
import megengine.distributed as dist
import megengine.functional as F
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.distributed.helper import get_device_count_by_fork
from megengine.jit import trace
from megengine.module import BatchNorm2d, Module
from megengine.module import BatchNorm2d, Module, SyncBatchNorm


def test_frozen_bn():
def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
    nchannel = 3
    m = BatchNorm2d(nchannel, freeze=True)
    m = BNModule(nchannel, freeze=True)
    var = 4.0
    bias = 1.0
    shape = (1, nchannel, 1, 1)
    m.running_var[...] = var * F.ones(shape)
    m.running_mean[...] = bias * F.ones(shape)
    saved_var = m.running_var.numpy()
    saved_mean = m.running_mean.numpy()
@@ -31,16 +39,45 @@ def test_frozen_bn():
    optim.clear_grad()
    data = np.random.random((6, nchannel, 2, 2)).astype("float32")
    with gm:
        loss = m(data).mean()
        gm.backward(loss)
    optim.step()
    np.testing.assert_equal(m.running_var.numpy(), saved_var)
    np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
    np.testing.assert_equal(m.weight.numpy(), saved_wt)
    np.testing.assert_equal(m.bias.numpy(), saved_bias)
    np.testing.assert_almost_equal(loss.numpy(), data.mean(), 5)

    def train_fn(d):
        for _ in range(3):
            with gm:
                loss = m(d).mean()
                gm.backward(loss)
            optim.step()
        return loss

    if use_trace:
        train_fn = trace(train_fn, symbolic=use_symbolic)

    for _ in range(3):
        loss = train_fn(megengine.Tensor(data))
        np.testing.assert_equal(m.running_var.numpy(), saved_var)
        np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
        np.testing.assert_equal(m.weight.numpy(), saved_wt)
        np.testing.assert_equal(m.bias.numpy(), saved_bias)
        np.testing.assert_almost_equal(
            loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5
        )


def test_frozen_bn():
    run_frozen_bn(BatchNorm2d)
    run_frozen_bn(BatchNorm2d, True, False)
    run_frozen_bn(BatchNorm2d, True, True)


@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_frozen_synced_bn():
    @dist.launcher(n_gpus=2)
    def worker():
        run_frozen_bn(SyncBatchNorm)
        run_frozen_bn(SyncBatchNorm, True, False)
        run_frozen_bn(SyncBatchNorm, True, True)

    worker()


def test_bn_no_track_stat():
@@ -112,3 +149,11 @@ def test_trace_bn_forward_twice():
    x = np.ones((1, 1, 32, 32), dtype=np.float32)
    y = train_bn(x, net=Simple())
    np.testing.assert_equal(y.numpy(), 0)


# https://github.com/MegEngine/MegEngine/issues/145
def test_frozen_bn_no_affine():
    nchannel = 3
    m = BatchNorm2d(nchannel, freeze=True, affine=False)
    data = megengine.Tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
    m(data).numpy()
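The regression test above only checks that the no-affine frozen forward no longer crashes (issue #145). A possible numeric follow-up, not part of the PR: with affine=False the frozen path has no weight or bias, so the output should reduce to plain running-stat normalization. This sketch assumes the imports already present in this test module (numpy as np, megengine, megengine.functional as F, BatchNorm2d); values and tolerances are illustrative:

def check_frozen_bn_no_affine_values():
    nchannel = 3
    m = BatchNorm2d(nchannel, freeze=True, affine=False)
    m.running_mean[...] = 1.0 * F.ones((1, nchannel, 1, 1))
    m.running_var[...] = 4.0 * F.ones((1, nchannel, 1, 1))
    data = np.random.random((6, nchannel, 2, 2)).astype("float32")
    out = m(megengine.Tensor(data)).numpy()
    # (x - mean) / sqrt(var + eps), with no affine rescaling afterwards
    np.testing.assert_allclose(out, (data - 1.0) / np.sqrt(4.0 + m.eps), atol=1e-6)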