From 20e93630f2835f9e614e96ef5c31957edecc1d02 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 10 Oct 2020 21:29:17 +0800 Subject: [PATCH] feat(mge): rename batch_norm2d -> batch_norm GitOrigin-RevId: 253e8564eab59528c3c08170958e7c0b3fe3b1c3 --- imperative/python/megengine/functional/nn.py | 6 ++++-- imperative/python/megengine/module/batchnorm.py | 4 ++-- imperative/python/test/unit/module/test_module.py | 4 ++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 11c95972..3fc2e223 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -39,7 +39,7 @@ __all__ = [ "adaptive_avg_pool2d", "adaptive_max_pool2d", "avg_pool2d", - "batch_norm2d", + "batch_norm", "conv2d", "conv_transpose2d", "dot", @@ -605,7 +605,7 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor: return cached / down -def batch_norm2d( +def batch_norm( inp: Tensor, running_mean: Tensor = None, running_var: Tensor = None, @@ -639,6 +639,8 @@ def batch_norm2d( Default: True :return: output tensor. """ + if inp.ndim != 4: + raise NotImplementedError("batch_norm for ndim != 4") def full_value(value): C = inp.shape[1] diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py index 5906a459..9f2d7bd1 100644 --- a/imperative/python/megengine/module/batchnorm.py +++ b/imperative/python/megengine/module/batchnorm.py @@ -11,7 +11,7 @@ from typing import Optional import numpy as np from ..distributed.group import WORLD, Group -from ..functional.nn import batch_norm2d, sync_batch_norm +from ..functional.nn import batch_norm, sync_batch_norm from ..tensor import Parameter, Tensor from . import init from .module import Module @@ -96,7 +96,7 @@ class _BatchNorm(Module): else: exponential_average_factor = 0.0 # useless - output = batch_norm2d( + output = batch_norm( inp, self.running_mean if self.track_running_stats else None, self.running_var if self.track_running_stats else None, diff --git a/imperative/python/test/unit/module/test_module.py b/imperative/python/test/unit/module/test_module.py index 333498d0..99712413 100644 --- a/imperative/python/test/unit/module/test_module.py +++ b/imperative/python/test/unit/module/test_module.py @@ -327,14 +327,14 @@ def test_module_api_hooks(): assert pre_hook_num == 4 assert post_hook_num == 4 mean1 = Parameter(np.zeros(shape), dtype=np.float32) - bn1 = F.batch_norm2d( + bn1 = F.batch_norm( x + 3, mean1, Parameter(np.ones(shape), dtype=np.float32), training=True ) np.testing.assert_allclose( net.i.bn.running_mean.numpy(), mean1.numpy(), ) mean2 = Parameter(np.zeros(shape), dtype=np.float32) - bn2 = F.batch_norm2d( + bn2 = F.batch_norm( bn1 + 3, mean2, Parameter(np.ones(shape), dtype=np.float32), training=True ) np.testing.assert_allclose(