diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 11c95972..3fc2e223 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -39,7 +39,7 @@ __all__ = [
     "adaptive_avg_pool2d",
     "adaptive_max_pool2d",
     "avg_pool2d",
-    "batch_norm2d",
+    "batch_norm",
     "conv2d",
     "conv_transpose2d",
     "dot",
@@ -605,7 +605,7 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
     return cached / down
 
 
-def batch_norm2d(
+def batch_norm(
     inp: Tensor,
     running_mean: Tensor = None,
     running_var: Tensor = None,
@@ -639,6 +639,8 @@
         Default: True
     :return: output tensor.
     """
+    if inp.ndim != 4:
+        raise NotImplementedError("batch_norm for ndim != 4")
 
     def full_value(value):
         C = inp.shape[1]
diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py
index 5906a459..9f2d7bd1 100644
--- a/imperative/python/megengine/module/batchnorm.py
+++ b/imperative/python/megengine/module/batchnorm.py
@@ -11,7 +11,7 @@ from typing import Optional
 import numpy as np
 
 from ..distributed.group import WORLD, Group
-from ..functional.nn import batch_norm2d, sync_batch_norm
+from ..functional.nn import batch_norm, sync_batch_norm
 from ..tensor import Parameter, Tensor
 from . import init
 from .module import Module
@@ -96,7 +96,7 @@ class _BatchNorm(Module):
         else:
             exponential_average_factor = 0.0  # useless
 
-        output = batch_norm2d(
+        output = batch_norm(
             inp,
             self.running_mean if self.track_running_stats else None,
             self.running_var if self.track_running_stats else None,
diff --git a/imperative/python/test/unit/module/test_module.py b/imperative/python/test/unit/module/test_module.py
index 333498d0..99712413 100644
--- a/imperative/python/test/unit/module/test_module.py
+++ b/imperative/python/test/unit/module/test_module.py
@@ -327,14 +327,14 @@ def test_module_api_hooks():
     assert pre_hook_num == 4
     assert post_hook_num == 4
     mean1 = Parameter(np.zeros(shape), dtype=np.float32)
-    bn1 = F.batch_norm2d(
+    bn1 = F.batch_norm(
         x + 3, mean1, Parameter(np.ones(shape), dtype=np.float32), training=True
     )
     np.testing.assert_allclose(
         net.i.bn.running_mean.numpy(), mean1.numpy(),
     )
     mean2 = Parameter(np.zeros(shape), dtype=np.float32)
-    bn2 = F.batch_norm2d(
+    bn2 = F.batch_norm(
         bn1 + 3, mean2, Parameter(np.ones(shape), dtype=np.float32), training=True
     )
     np.testing.assert_allclose(
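Usage note (not part of the patch): a minimal sketch of calling the renamed API, assuming a MegEngine build with this change applied. The input tensor, the (1, C, 1, 1) statistics shape, and all variable names below are illustrative assumptions, not taken from the patch itself.

    import numpy as np

    import megengine.functional as F
    from megengine import Parameter, tensor

    # The patch makes batch_norm raise NotImplementedError for non-4D
    # inputs, so the example uses an NCHW tensor.
    x = tensor(np.random.randn(2, 3, 4, 4).astype(np.float32))

    # Per-channel running statistics, broadcastable over (N, C, H, W);
    # the shape here is an assumption consistent with the updated test.
    stat_shape = (1, 3, 1, 1)
    running_mean = Parameter(np.zeros(stat_shape), dtype=np.float32)
    running_var = Parameter(np.ones(stat_shape), dtype=np.float32)

    # Same call pattern as the updated test: with training=True the
    # running statistics are updated in place.
    out = F.batch_norm(x, running_mean, running_var, training=True)
    print(out.shape)  # (2, 3, 4, 4)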