|
|
@@ -327,14 +327,14 @@ def test_module_api_hooks(): |
|
|
|
assert pre_hook_num == 4 |
|
|
|
assert post_hook_num == 4 |
|
|
|
mean1 = Parameter(np.zeros(shape), dtype=np.float32) |
|
|
|
bn1 = F.batch_norm2d( |
|
|
|
bn1 = F.batch_norm( |
|
|
|
x + 3, mean1, Parameter(np.ones(shape), dtype=np.float32), training=True |
|
|
|
) |
|
|
|
np.testing.assert_allclose( |
|
|
|
net.i.bn.running_mean.numpy(), mean1.numpy(), |
|
|
|
) |
|
|
|
mean2 = Parameter(np.zeros(shape), dtype=np.float32) |
|
|
|
bn2 = F.batch_norm2d( |
|
|
|
bn2 = F.batch_norm( |
|
|
|
bn1 + 3, mean2, Parameter(np.ones(shape), dtype=np.float32), training=True |
|
|
|
) |
|
|
|
np.testing.assert_allclose( |
|
|
|