|
|
@@ -19,9 +19,13 @@ from megengine.jit import trace
 from megengine.module import BatchNorm2d, Conv2d, Module, Sequential, SyncBatchNorm
 
 
-def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
+def run_frozen_bn(BNModule, is_training, use_trace, use_symbolic):
     nchannel = 3
     m = BNModule(nchannel, freeze=True)
+    if is_training:
+        m.train()
+    else:
+        m.eval()
     var = 4.0
     bias = 1.0
     shape = (1, nchannel, 1, 1)
|
|
@@ -51,30 +55,33 @@ def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
         train_fn = trace(train_fn, symbolic=use_symbolic)
 
     for _ in range(3):
-        loss = train_fn(megengine.Tensor(data))
-        np.testing.assert_equal(m.running_var.numpy(), saved_var)
-        np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
+        loss = train_fn(megengine.tensor(data))
+        if not is_training:
+            np.testing.assert_equal(m.running_var.numpy(), saved_var)
+            np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
+            np.testing.assert_almost_equal(
+                loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5
+            )
     np.testing.assert_equal(m.weight.numpy(), saved_wt)
     np.testing.assert_equal(m.bias.numpy(), saved_bias)
-    np.testing.assert_almost_equal(
-        loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5
-    )
 
 
-def test_frozen_bn():
-    run_frozen_bn(BatchNorm2d)
-    run_frozen_bn(BatchNorm2d, True, False)
-    run_frozen_bn(BatchNorm2d, True, True)
+@pytest.mark.parametrize("is_training", [False, True])
+@pytest.mark.parametrize("use_trace", [False, True])
+@pytest.mark.parametrize("use_symbolic", [False, True])
+def test_frozen_bn(is_training, use_trace, use_symbolic):
+    run_frozen_bn(BatchNorm2d, is_training, use_trace, use_symbolic)
 
 
 @pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
-def test_frozen_synced_bn():
+@pytest.mark.parametrize("is_training", [False, True])
+@pytest.mark.parametrize("use_trace", [False, True])
+@pytest.mark.parametrize("use_symbolic", [False, True])
+def test_frozen_synced_bn(is_training, use_trace, use_symbolic):
     @dist.launcher(n_gpus=2)
     def worker():
-        run_frozen_bn(SyncBatchNorm)
-        run_frozen_bn(SyncBatchNorm, True, False)
-        run_frozen_bn(SyncBatchNorm, True, True)
+        run_frozen_bn(SyncBatchNorm, is_training, use_trace, use_symbolic)
 
     worker()
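Reviewer note (not part of the patch): the assertions above pin down the frozen-BN contract that freeze=True normalizes with the stored running statistics and never updates them, in train() as well as eval() mode. A minimal standalone NumPy sketch of that invariant, reusing the test's var = 4.0 and bias = 1.0; the helper name frozen_bn_forward is hypothetical:

import numpy as np

# Hypothetical helper spelling out the invariant under test: a frozen BN
# layer normalizes with its stored running statistics only.
def frozen_bn_forward(x, running_mean, running_var, eps=0.0):
    return (x - running_mean) / np.sqrt(running_var + eps)

nchannel = 3
data = np.random.random((6, nchannel, 2, 2)).astype("float32")
running_mean = np.full((1, nchannel, 1, 1), 1.0, dtype="float32")  # bias = 1.0
running_var = np.full((1, nchannel, 1, 1), 4.0, dtype="float32")   # var = 4.0

out = frozen_bn_forward(data, running_mean, running_var)
# Same value the test's loss check expects: ((data - bias) / sqrt(var)).mean()
np.testing.assert_almost_equal(out.mean(), ((data - 1.0) / np.sqrt(4.0)).mean(), 5)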
|
|
|
|
|
|
@@ -190,8 +197,13 @@ def test_trace_several_syncbn(trace_mode):
 
 
 # https://github.com/MegEngine/MegEngine/issues/145
-def test_frozen_bn_no_affine():
+@pytest.mark.parametrize("is_training", [False, True])
+def test_frozen_bn_no_affine(is_training):
     nchannel = 3
     m = BatchNorm2d(nchannel, freeze=True, affine=False)
-    data = tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
+    if is_training:
+        m.train()
+    else:
+        m.eval()
+
+    data = megengine.tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
     m(data).numpy()
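One more note on the pattern (also not part of the patch): stacked pytest.mark.parametrize decorators expand to the cross product of their values, so test_frozen_bn now covers all 2 × 2 × 2 = 8 combinations rather than the three hand-picked calls it replaces. A sketch of an equivalent single-decorator spelling, assuming run_frozen_bn and BatchNorm2d from the surrounding test module:

import itertools

import pytest

# Equivalent single decorator: one parametrize over the full cross product,
# producing the same 8 (is_training, use_trace, use_symbolic) cases.
@pytest.mark.parametrize(
    "is_training, use_trace, use_symbolic",
    list(itertools.product([False, True], repeat=3)),
)
def test_frozen_bn_product(is_training, use_trace, use_symbolic):
    run_frozen_bn(BatchNorm2d, is_training, use_trace, use_symbolic)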