Browse Source

feat(mge): rename batch_norm2d -> batch_norm

GitOrigin-RevId: 253e8564ea
release-1.1
Megvii Engine Team 4 years ago
parent
commit
20e93630f2
3 changed files with 8 additions and 6 deletions
  1. +4
    -2
      imperative/python/megengine/functional/nn.py
  2. +2
    -2
      imperative/python/megengine/module/batchnorm.py
  3. +2
    -2
      imperative/python/test/unit/module/test_module.py

+ 4
- 2
imperative/python/megengine/functional/nn.py View File

@@ -39,7 +39,7 @@ __all__ = [
"adaptive_avg_pool2d", "adaptive_avg_pool2d",
"adaptive_max_pool2d", "adaptive_max_pool2d",
"avg_pool2d", "avg_pool2d",
"batch_norm2d",
"batch_norm",
"conv2d", "conv2d",
"conv_transpose2d", "conv_transpose2d",
"dot", "dot",
@@ -605,7 +605,7 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
return cached / down return cached / down




def batch_norm2d(
def batch_norm(
inp: Tensor, inp: Tensor,
running_mean: Tensor = None, running_mean: Tensor = None,
running_var: Tensor = None, running_var: Tensor = None,
@@ -639,6 +639,8 @@ def batch_norm2d(
Default: True Default: True
:return: output tensor. :return: output tensor.
""" """
if inp.ndim != 4:
raise NotImplementedError("batch_norm for ndim != 4")


def full_value(value): def full_value(value):
C = inp.shape[1] C = inp.shape[1]


+ 2
- 2
imperative/python/megengine/module/batchnorm.py View File

@@ -11,7 +11,7 @@ from typing import Optional
import numpy as np import numpy as np


from ..distributed.group import WORLD, Group from ..distributed.group import WORLD, Group
from ..functional.nn import batch_norm2d, sync_batch_norm
from ..functional.nn import batch_norm, sync_batch_norm
from ..tensor import Parameter, Tensor from ..tensor import Parameter, Tensor
from . import init from . import init
from .module import Module from .module import Module
@@ -96,7 +96,7 @@ class _BatchNorm(Module):
else: else:
exponential_average_factor = 0.0 # useless exponential_average_factor = 0.0 # useless


output = batch_norm2d(
output = batch_norm(
inp, inp,
self.running_mean if self.track_running_stats else None, self.running_mean if self.track_running_stats else None,
self.running_var if self.track_running_stats else None, self.running_var if self.track_running_stats else None,


+ 2
- 2
imperative/python/test/unit/module/test_module.py View File

@@ -327,14 +327,14 @@ def test_module_api_hooks():
assert pre_hook_num == 4 assert pre_hook_num == 4
assert post_hook_num == 4 assert post_hook_num == 4
mean1 = Parameter(np.zeros(shape), dtype=np.float32) mean1 = Parameter(np.zeros(shape), dtype=np.float32)
bn1 = F.batch_norm2d(
bn1 = F.batch_norm(
x + 3, mean1, Parameter(np.ones(shape), dtype=np.float32), training=True x + 3, mean1, Parameter(np.ones(shape), dtype=np.float32), training=True
) )
np.testing.assert_allclose( np.testing.assert_allclose(
net.i.bn.running_mean.numpy(), mean1.numpy(), net.i.bn.running_mean.numpy(), mean1.numpy(),
) )
mean2 = Parameter(np.zeros(shape), dtype=np.float32) mean2 = Parameter(np.zeros(shape), dtype=np.float32)
bn2 = F.batch_norm2d(
bn2 = F.batch_norm(
bn1 + 3, mean2, Parameter(np.ones(shape), dtype=np.float32), training=True bn1 + 3, mean2, Parameter(np.ones(shape), dtype=np.float32), training=True
) )
np.testing.assert_allclose( np.testing.assert_allclose(


Loading…
Cancel
Save