@@ -39,7 +39,7 @@ __all__ = [
     "adaptive_avg_pool2d",
     "adaptive_max_pool2d",
     "avg_pool2d",
-    "batch_norm2d",
+    "batch_norm",
    "conv2d",
    "conv_transpose2d",
    "dot",
@@ -605,7 +605,7 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
     return cached / down


-def batch_norm2d(
+def batch_norm(
    inp: Tensor,
    running_mean: Tensor = None,
    running_var: Tensor = None,
@@ -639,6 +639,8 @@ def batch_norm2d(
         Default: True
     :return: output tensor.
     """
+    if inp.ndim != 4:
+        raise NotImplementedError("batch_norm for ndim != 4")

    def full_value(value):
        C = inp.shape[1]
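
For context, here is a minimal sketch of how the renamed entry point behaves with the new guard in place. Only `batch_norm`, its `training=` keyword, the `Parameter(value, dtype=...)` call, and the `NotImplementedError` come from this diff; the `megengine` import aliases and the `(1, C, 1, 1)` statistics shape are assumptions for illustration.

```python
import numpy as np
import megengine as mge
import megengine.functional as F
from megengine import Parameter

C = 3
x = mge.tensor(np.random.randn(2, C, 4, 4).astype(np.float32))  # NCHW, ndim == 4
mean = Parameter(np.zeros((1, C, 1, 1)), dtype=np.float32)      # assumed stats shape
var = Parameter(np.ones((1, C, 1, 1)), dtype=np.float32)
out = F.batch_norm(x, mean, var, training=True)  # ok: 4D input

# The new guard rejects anything that is not 4D:
x3d = mge.tensor(np.random.randn(2, C, 4).astype(np.float32))
try:
    F.batch_norm(x3d, mean, var, training=True)
except NotImplementedError as exc:
    print(exc)  # batch_norm for ndim != 4
```
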
@@ -11,7 +11,7 @@ from typing import Optional
 import numpy as np

 from ..distributed.group import WORLD, Group
-from ..functional.nn import batch_norm2d, sync_batch_norm
+from ..functional.nn import batch_norm, sync_batch_norm
 from ..tensor import Parameter, Tensor
 from . import init
 from .module import Module
@@ -96,7 +96,7 @@ class _BatchNorm(Module):
         else:
             exponential_average_factor = 0.0  # useless

-        output = batch_norm2d(
+        output = batch_norm(
            inp,
            self.running_mean if self.track_running_stats else None,
            self.running_var if self.track_running_stats else None,
@@ -327,14 +327,14 @@ def test_module_api_hooks():
     assert pre_hook_num == 4
     assert post_hook_num == 4
     mean1 = Parameter(np.zeros(shape), dtype=np.float32)
-    bn1 = F.batch_norm2d(
+    bn1 = F.batch_norm(
         x + 3, mean1, Parameter(np.ones(shape), dtype=np.float32), training=True
     )
     np.testing.assert_allclose(
         net.i.bn.running_mean.numpy(), mean1.numpy(),
     )
     mean2 = Parameter(np.zeros(shape), dtype=np.float32)
-    bn2 = F.batch_norm2d(
+    bn2 = F.batch_norm(
         bn1 + 3, mean2, Parameter(np.ones(shape), dtype=np.float32), training=True
     )
    np.testing.assert_allclose(
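
The test above recomputes the running statistics by hand: it passes fresh `Parameter` buffers to `F.batch_norm` with `training=True` and then compares them against the module's `running_mean`, which only works if the statistics buffers are updated in place by the call. A hedged sketch of that property follows; only `F.batch_norm`, `Parameter`, and `training=True` are taken from the diff, while the shapes and import names are illustrative.

```python
import numpy as np
import megengine as mge
import megengine.functional as F
from megengine import Parameter

shape = (1, 3, 1, 1)  # assumed per-channel statistics shape
x = mge.tensor(np.random.randn(2, 3, 4, 4).astype(np.float32))
running_mean = Parameter(np.zeros(shape), dtype=np.float32)
running_var = Parameter(np.ones(shape), dtype=np.float32)

F.batch_norm(x, running_mean, running_var, training=True)
# With training=True the buffers are mutated in place, so they now hold
# updated statistics rather than their initial zeros/ones.
print(running_mean.numpy().ravel())
```
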