@@ -13,6 +13,7 @@ import numpy as np
 import pytest
 
 import megengine as mge
+import megengine.amp as amp
 import megengine.distributed as dist
 from megengine import Tensor, jit
 from megengine.autodiff.grad_manager import GradManager
@@ -24,7 +25,8 @@ _assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol
 
 @pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
-def test_syncbn():
+@pytest.mark.parametrize("enable_amp", [False, True])
+def test_syncbn(enable_amp):
     nr_chan = 8
     data_shape = (3, nr_chan, 4, 16)
     momentum = 0.9
@@ -38,12 +40,17 @@ def test_syncbn():
 
     @dist.launcher(n_gpus=2)
     def worker(data, yv_expect, running_mean, running_var):
-        rank = dist.get_rank()
-        bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
-        for i in range(steps):
-            yv = bn(Tensor(data[rank][i]))
-
-        _assert_allclose(yv.numpy(), yv_expect[rank])
+        with amp.autocast(enabled=enable_amp):
+            rank = dist.get_rank()
+            bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
+            for i in range(steps):
+                yv = bn(Tensor(data[rank][i]))
+        if enable_amp:
+            np.testing.assert_allclose(
+                yv.numpy(), yv_expect[rank], atol=5e-4, rtol=5e-4
+            )
+        else:
+            _assert_allclose(yv.numpy(), yv_expect[rank])
         _assert_allclose(bn.running_mean.numpy(), running_mean)
         _assert_allclose(bn.running_var.numpy(), running_var)
 
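For context, here is a minimal single-device sketch of the pattern this diff introduces: run the forward pass under `amp.autocast(enabled=...)` and loosen the comparison tolerances when AMP is enabled. `BatchNorm2d` and the test name are stand-ins (the real test exercises `SyncBatchNorm` across 2 GPUs via `dist.launcher`); the 5e-4 AMP tolerances come from the diff itself, while the tight non-AMP tolerance is an assumption, since the `_assert_allclose` partial is truncated in the hunk header above.

```python
# A minimal, single-device sketch of the AMP-parametrized test pattern.
# BatchNorm2d stands in for SyncBatchNorm here (an assumption; the real
# test runs SyncBatchNorm on 2 GPUs).
import numpy as np
import pytest

import megengine.amp as amp
import megengine.module as M
from megengine import Tensor


@pytest.mark.parametrize("enable_amp", [False, True])
def test_bn_amp_sketch(enable_amp):
    nr_chan = 8
    bn = M.BatchNorm2d(nr_chan)
    data = np.random.normal(size=(3, nr_chan, 4, 16)).astype("float32")

    # Run the forward pass under autocast, exactly as the diff does.
    with amp.autocast(enabled=enable_amp):
        yv = bn(Tensor(data))

    # Reference result from batch statistics (module is in training mode).
    mean = data.mean(axis=(0, 2, 3), keepdims=True)
    var = data.var(axis=(0, 2, 3), keepdims=True)
    expect = (data - mean) / np.sqrt(var + bn.eps)

    # float16 accumulation is less precise, so the diff widens tolerances
    # to 5e-4 under AMP; the tight non-AMP tolerance is assumed here.
    tol = dict(atol=5e-4, rtol=5e-4) if enable_amp else dict(atol=5e-6, rtol=5e-6)
    np.testing.assert_allclose(yv.numpy(), expect, **tol)
```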