|
@@ -13,7 +13,8 @@ import megengine |
|
|
import megengine.autodiff as ad |
|
|
import megengine.autodiff as ad |
|
|
import megengine.optimizer as optimizer |
|
|
import megengine.optimizer as optimizer |
|
|
from megengine import Parameter, tensor |
|
|
from megengine import Parameter, tensor |
|
|
from megengine.module import BatchNorm2d |
|
|
|
|
|
|
|
|
from megengine.jit import trace |
|
|
|
|
|
from megengine.module import BatchNorm2d, Module |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_frozen_bn(): |
|
|
def test_frozen_bn(): |
|
@@ -89,3 +90,25 @@ def test_bn_no_track_stat3(): |
|
|
data = np.random.random((6, nchannel, 2, 2)).astype("float32") |
|
|
data = np.random.random((6, nchannel, 2, 2)).astype("float32") |
|
|
with pytest.raises(Exception): |
|
|
with pytest.raises(Exception): |
|
|
m(data) |
|
|
m(data) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_trace_bn_forward_twice(): |
|
|
|
|
|
class Simple(Module): |
|
|
|
|
|
def __init__(self): |
|
|
|
|
|
super().__init__() |
|
|
|
|
|
self.bn = BatchNorm2d(1) |
|
|
|
|
|
|
|
|
|
|
|
def forward(self, inp): |
|
|
|
|
|
x = self.bn(inp) |
|
|
|
|
|
x = self.bn(x) |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
@trace(symbolic=True) |
|
|
|
|
|
def train_bn(inp, net=None): |
|
|
|
|
|
net.train() |
|
|
|
|
|
pred = net(inp) |
|
|
|
|
|
return pred |
|
|
|
|
|
|
|
|
|
|
|
x = np.ones((1, 1, 32, 32), dtype=np.float32) |
|
|
|
|
|
y = train_bn(x, net=Simple()) |
|
|
|
|
|
np.testing.assert_equal(y.numpy(), 0) |