
fix(mge/imperative): fix syncbn in symbolic mode

GitOrigin-RevId: a9794318a7
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent · commit a226f02e7a
2 changed files with 20 additions and 20 deletions
  1. imperative/python/megengine/functional/nn.py  (+20, -15)
  2. imperative/python/test/unit/module/test_batchnorm.py  (+0, -5)
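A note on what changed, inferred from the diff and from the skip markers removed in the tests (a reading of the patch, not an official description): in training mode the channel-wise reduction previously went through `apply(builtin.Reduce(mode="SUM"), input, Tensor(_param_shape))`, and the distributed branch built `reduce_size` with `full(...)`; both need concrete shape values and therefore break when tensor shapes are handled symbolically. The replacement reduces over fixed axes with `keepdims=True` and broadcasts a scalar `Tensor` instead, which works in either mode. A minimal NumPy sketch of the keepdims reduction; NumPy stands in for MegEngine tensors and the shapes are invented for illustration:

```python
import numpy as np

# Hypothetical NCHW activation; per-channel statistics reduce over N, H, W.
inp = np.random.randn(4, 8, 16, 16).astype("float32")    # (N, C, H, W)

_ndim = inp.ndim
_channels = inp.shape[1]
_param_shape = (1, _channels) + (1,) * (_ndim - 2)        # (1, C, 1, 1)
_reduce_axis = tuple([0] + [i for i in range(2, _ndim)])  # (0, 2, 3)

# keepdims=True already yields the per-channel parameter shape, so no
# concrete target-shape tensor has to be materialised.
channel_x1s = inp.sum(axis=_reduce_axis, keepdims=True)
channel_x2s = (inp ** 2).sum(axis=_reduce_axis, keepdims=True)

assert channel_x1s.shape == _param_shape  # (1, 8, 1, 1)
assert channel_x2s.shape == _param_shape
```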

imperative/python/megengine/functional/nn.py  (+20, -15)

@@ -22,7 +22,7 @@ from .debug_param import get_conv_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
 from .math import argsort, max, sum
-from .tensor import add_axis, broadcast, concat, full, remove_axis, reshape
+from .tensor import add_axis, broadcast, concat, remove_axis, reshape
 from .types import _pair, _pair_nonzero
 
 __all__ = [
@@ -692,7 +692,7 @@ def batch_norm2d(
 
 
 def sync_batch_norm(
-    input: Tensor,
+    inp: Tensor,
     running_mean: Tensor,
     running_var: Tensor,
     weight: Optional[Tensor] = None,
@@ -723,25 +723,30 @@ def sync_batch_norm(
         Default: 1e-5.
     """
     assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode)
-    _channels = input.shape[1]
-    _ndim = input.ndim
+    _channels = inp.shape[1]
+    _ndim = inp.ndim
+    _device = inp.device
+    _dtype = inp.dtype
     _param_shape = (1, _channels) + (1,) * (_ndim - 2)
+    _reduce_axis = [0] + [i for i in range(2, _ndim)]
 
     if training:
 
-        def _sum_on_channel(input):
-            return apply(builtin.Reduce(mode="SUM"), input, Tensor(_param_shape))[0]
+        def _sum_on_channel(inp):
+            return inp.sum(axis=_reduce_axis, keepdims=True)
 
-        reduce_size = input.shape[0]
+        reduce_size = inp.shape[0]
         for i in range(2, _ndim):
-            reduce_size = reduce_size * input.shape[i]
-        channel_x1s = _sum_on_channel(input)
-        channel_x2s = _sum_on_channel(input ** 2)
+            reduce_size = reduce_size * inp.shape[i]
+        channel_x1s = _sum_on_channel(inp)
+        channel_x2s = _sum_on_channel(inp ** 2)
 
         if is_distributed():
             # reduce all nodes' data to calculate mean and variance
-            reduce_size = full([1 for _ in range(_ndim)], reduce_size)
-            stat = concat([reduce_size, channel_x1s, channel_x2s], axis=1)
+            reduce_size = broadcast(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
+            stat = concat(
+                [reduce_size.astype(_dtype), channel_x1s, channel_x2s], axis=1
+            )
             stat = all_reduce_sum(stat, group)
             reduce_size = stat[:, :1].reshape(1)
             channel_x1s = stat[:, 1 : 1 + _channels]
@@ -775,11 +780,11 @@ def sync_batch_norm(
         inv_var_wt = invsqrt_channel_variance * weight
         neg_channel_mean = -channel_mean
         if bias is not None:
-            outvar = input * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
+            outvar = inp * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
         else:
-            outvar = input * inv_var_wt + neg_channel_mean * inv_var_wt
+            outvar = inp * inv_var_wt + neg_channel_mean * inv_var_wt
     else:
-        outvar = input * invsqrt_channel_variance + (
+        outvar = inp * invsqrt_channel_variance + (
             -channel_mean * invsqrt_channel_variance
         )
         if bias is not None:
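The distributed branch packs `reduce_size`, `channel_x1s` and `channel_x2s` into a single tensor along the channel axis, all-reduces it once, and slices the pieces back out before computing the global mean and variance; the patch also casts the packed statistic to one dtype before the reduce. Below is a NumPy sketch of that pack / all-reduce / unpack round trip, with two in-process arrays standing in for workers and a plain sum standing in for `all_reduce_sum` (shapes and worker count are assumptions made for the example):

```python
import numpy as np

C = 8
rng = np.random.default_rng(0)

def local_stats(x):
    """Per-worker element count and channel sums, packed as (1, 1 + 2C, 1, 1)."""
    n_elem = x.shape[0] * x.shape[2] * x.shape[3]
    reduce_size = np.full((1, 1, 1, 1), n_elem, dtype=x.dtype)
    x1s = x.sum(axis=(0, 2, 3), keepdims=True)
    x2s = (x ** 2).sum(axis=(0, 2, 3), keepdims=True)
    return np.concatenate([reduce_size, x1s, x2s], axis=1)

x_a = rng.standard_normal((4, C, 16, 16))   # "worker A" batch
x_b = rng.standard_normal((2, C, 16, 16))   # "worker B" batch

# all_reduce_sum across the pretend workers is just an element-wise sum here.
stat = local_stats(x_a) + local_stats(x_b)

# Unpack, mirroring the slicing in the patch.
reduce_size = stat[:, :1].reshape(1)
channel_x1s = stat[:, 1 : 1 + C]
channel_x2s = stat[:, 1 + C :]

channel_mean = channel_x1s / reduce_size
channel_var = channel_x2s / reduce_size - channel_mean ** 2

# Same statistics as computing them over all the data at once.
ref = np.concatenate([x_a, x_b], axis=0)
np.testing.assert_allclose(
    channel_mean.reshape(C), ref.mean(axis=(0, 2, 3)), rtol=1e-6, atol=1e-12
)
np.testing.assert_allclose(
    channel_var.reshape(C), ref.var(axis=(0, 2, 3)), rtol=1e-6, atol=1e-12
)
```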

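The last hunk in nn.py only renames `input` to `inp`, but the expression it touches is the usual affine-folded form of batch normalisation: rather than evaluating `(x - mean) / sqrt(var + eps) * weight + bias` directly, a per-channel scale (`inv_var_wt`) and shift (`neg_channel_mean * inv_var_wt + bias`) are prepared so each activation is touched by a single multiply-add. A small NumPy check of that equivalence, using an additive eps and invented shapes:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((4, 8, 16, 16))       # (N, C, H, W)
weight = rng.standard_normal((1, 8, 1, 1))
bias = rng.standard_normal((1, 8, 1, 1))
eps = 1e-5

channel_mean = x.mean(axis=(0, 2, 3), keepdims=True)
channel_var = x.var(axis=(0, 2, 3), keepdims=True)
invsqrt_channel_variance = 1.0 / np.sqrt(channel_var + eps)

# Folded form, as in the patch: one per-channel scale and one per-channel shift.
inv_var_wt = invsqrt_channel_variance * weight
neg_channel_mean = -channel_mean
outvar = x * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)

# Unfolded reference: normalise first, then apply the affine parameters.
ref = (x - channel_mean) * invsqrt_channel_variance * weight + bias
np.testing.assert_allclose(outvar, ref, rtol=1e-10, atol=1e-12)
```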

imperative/python/test/unit/module/test_batchnorm.py  (+0, -5)

@@ -27,7 +27,6 @@ from megengine.test import assertTensorClose
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
 )
-@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
 @pytest.mark.isolated_distributed
 def test_syncbn():
     nr_chan = 8
@@ -154,7 +153,6 @@ def test_batchnorm():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
 )
-@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
 @pytest.mark.isolated_distributed
 def test_syncbn1d():
     nr_chan = 8
@@ -257,7 +255,6 @@ def test_batchnorm2d():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
 )
-@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
 @pytest.mark.isolated_distributed
 def test_syncbn2d():
     nr_chan = 8
@@ -336,7 +333,6 @@ def test_batchnorm_no_stats():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
 )
-@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
 @pytest.mark.isolated_distributed
 def test_syncbn_no_stats():
     nr_chan = 8
@@ -393,7 +389,6 @@ def test_batchnorm2d_no_stats():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
 )
-@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
 @pytest.mark.isolated_distributed
 def test_syncbn2d_no_stats():
     nr_chan = 8

