
fix(imperative): syncbn fp16 support

GitOrigin-RevId: 6059d5b76b
tags/v1.9.0
Megvii Engine Team 3 years ago
parent commit 406115dba0
2 changed files with 21 additions and 8 deletions
  1. imperative/python/megengine/functional/nn.py (+7, -1)
  2. imperative/python/test/unit/module/test_batchnorm.py (+14, -7)

imperative/python/megengine/functional/nn.py (+7, -1)

@@ -1385,6 +1385,11 @@ def sync_batch_norm(
         momentum=momentum,
         eps=eps,
     )
+    if amp._enabled:
+        inp, weight, bias, running_mean, running_var = cast_tensors(
+            inp, weight, bias, running_mean, running_var, promote=True
+        )
+
     _channels = make_shape_tuple(inp.shape)[1]
     _ndim = inp.ndim
     _device = inp.device
@@ -1464,7 +1469,8 @@ def sync_batch_norm(
         channel_x2s,
         channel_mean,
     )
-
+    if amp._enabled:
+        outvar = outvar.astype("float16")
     return outvar
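
The hunk above keeps the reduction math in float32 even when AMP feeds float16 tensors in, then casts the result back to float16 before returning. Below is a minimal sketch of that promote-then-demote pattern using ordinary MegEngine ops; it does not call the internal cast_tensors helper, and the shapes and eps value are illustrative only.

import numpy as np
import megengine as mge
import megengine.functional as F

# float16 activations, as would be produced under amp.autocast
x_fp16 = mge.Tensor(np.random.randn(2, 8, 4, 4).astype("float16"))

# promote to float32 before computing the per-channel statistics,
# mirroring the cast_tensors(..., promote=True) call in the patch
x = x_fp16.astype("float32")
mean = F.mean(x, axis=(0, 2, 3), keepdims=True)
var = F.mean((x - mean) ** 2, axis=(0, 2, 3), keepdims=True)
y = (x - mean) / F.sqrt(var + 1e-5)

# demote the normalized output back to float16, as the patch does for outvar
y_fp16 = y.astype("float16")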






imperative/python/test/unit/module/test_batchnorm.py (+14, -7)

@@ -13,6 +13,7 @@ import numpy as np
 import pytest

 import megengine as mge
+import megengine.amp as amp
 import megengine.distributed as dist
 from megengine import Tensor, jit
 from megengine.autodiff.grad_manager import GradManager
@@ -24,7 +25,8 @@ _assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol

 @pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
-def test_syncbn():
+@pytest.mark.parametrize("enable_amp", [False, True])
+def test_syncbn(enable_amp):
     nr_chan = 8
     data_shape = (3, nr_chan, 4, 16)
     momentum = 0.9
@@ -38,12 +40,17 @@ def test_syncbn():


     @dist.launcher(n_gpus=2)
     def worker(data, yv_expect, running_mean, running_var):
-        rank = dist.get_rank()
-        bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
-        for i in range(steps):
-            yv = bn(Tensor(data[rank][i]))
-
-        _assert_allclose(yv.numpy(), yv_expect[rank])
+        with amp.autocast(enabled=enable_amp):
+            rank = dist.get_rank()
+            bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
+            for i in range(steps):
+                yv = bn(Tensor(data[rank][i]))
+        if enable_amp:
+            np.testing.assert_allclose(
+                yv.numpy(), yv_expect[rank], atol=5e-4, rtol=5e-4
+            )
+        else:
+            _assert_allclose(yv.numpy(), yv_expect[rank])
         _assert_allclose(bn.running_mean.numpy(), running_mean)
         _assert_allclose(bn.running_var.numpy(), running_var)
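
The test is now parametrized over enable_amp and widens the tolerance on the float16 path. Below is a self-contained sketch of the same pattern that avoids the two-GPU launcher by exercising a plain BatchNorm2d instead of SyncBatchNorm; the test name and the fp32 reference comparison are illustrative, not part of the actual suite.

import numpy as np
import pytest
import megengine as mge
import megengine.amp as amp
from megengine.module import BatchNorm2d


@pytest.mark.parametrize("enable_amp", [False, True])
def test_bn_amp_sketch(enable_amp):
    np.random.seed(0)
    data = np.random.randn(3, 8, 4, 16).astype("float32")

    # fp32 reference run (training mode, so batch statistics are used)
    ref = BatchNorm2d(8)(mge.Tensor(data)).numpy()

    # run again under autocast; with enable_amp the output may be float16
    with amp.autocast(enabled=enable_amp):
        out = BatchNorm2d(8)(mge.Tensor(data)).numpy()

    # float16 results need a wider tolerance, matching the commit's choice
    tol = dict(atol=5e-4, rtol=5e-4) if enable_amp else dict(atol=5e-6, rtol=5e-6)
    np.testing.assert_allclose(out.astype("float32"), ref, **tol)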



