
test_bn.py

# -*- coding: utf-8 -*-
import numpy as np
import pytest

import megengine
import megengine.autodiff as ad
import megengine.distributed as dist
import megengine.functional as F
import megengine.optimizer as optimizer
from megengine import tensor
from megengine.jit import trace
from megengine.module import BatchNorm2d, Conv2d, Module, Sequential, SyncBatchNorm
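
# These tests exercise BatchNorm2d / SyncBatchNorm behavior around three
# knobs: `freeze`, `track_running_stats`, and interaction with `jit.trace`.
# The `run_*` helpers are shared between the single-device tests and the
# two-GPU distributed variants.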


def run_frozen_bn(BNModule, is_training, use_trace, use_symbolic):
    nchannel = 3
    m = BNModule(nchannel, freeze=True)
    if is_training:
        m.train()
    else:
        m.eval()
    var = 4.0
    bias = 1.0
    shape = (1, nchannel, 1, 1)
    m.running_var[...] = var * F.ones(shape)
    m.running_mean[...] = bias * F.ones(shape)

    saved_var = m.running_var.numpy()
    saved_mean = m.running_mean.numpy()
    saved_wt = m.weight.numpy()
    saved_bias = m.bias.numpy()

    gm = ad.GradManager().attach(m.parameters())
    optim = optimizer.SGD(m.parameters(), lr=1.0)
    optim.clear_grad()

    data = np.random.random((6, nchannel, 2, 2)).astype("float32")

    def train_fn(d):
        for _ in range(3):
            with gm:
                loss = m(d).mean()
                gm.backward(loss)
            optim.step()
        return loss

    if use_trace:
        train_fn = trace(train_fn, symbolic=use_symbolic)

    for _ in range(3):
        loss = train_fn(megengine.tensor(data))
        if not is_training:
            np.testing.assert_equal(m.running_var.numpy(), saved_var)
            np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
            np.testing.assert_almost_equal(
                loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5
            )
    np.testing.assert_equal(m.weight.numpy(), saved_wt)
    np.testing.assert_equal(m.bias.numpy(), saved_bias)
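
# A minimal reference for what `freeze=True` appears to guarantee, judging by
# the assertions above: normalization uses the stored running statistics, and
# neither the statistics nor the affine parameters move during training.
# `_frozen_bn_reference` is a hypothetical helper added here for illustration
# only; it is not part of the original test suite.
def _frozen_bn_reference(x, running_mean, running_var, weight=1.0, bias=0.0, eps=1e-5):
    # Normalize with the frozen statistics; no batch statistics are computed.
    return (x - running_mean) / np.sqrt(running_var + eps) * weight + bias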


@pytest.mark.parametrize("is_training", [False, True])
@pytest.mark.parametrize("use_trace", [False, True])
@pytest.mark.parametrize("use_symbolic", [False, True])
def test_frozen_bn(is_training, use_trace, use_symbolic):
    run_frozen_bn(BatchNorm2d, is_training, use_trace, use_symbolic)


@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
@pytest.mark.parametrize("is_training", [False, True])
@pytest.mark.parametrize("use_trace", [False, True])
@pytest.mark.parametrize("use_symbolic", [False, True])
def test_frozen_synced_bn(is_training, use_trace, use_symbolic):
    @dist.launcher(n_gpus=2)
    def worker():
        run_frozen_bn(SyncBatchNorm, is_training, use_trace, use_symbolic)

    worker()
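
# `dist.launcher(n_gpus=2)` spawns one worker process per GPU; every worker
# runs the same frozen-BN check with SyncBatchNorm. Cross-rank statistics
# synchronization should be a no-op here, since `freeze=True` pins the stats.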


def test_bn_no_track_stat():
    nchannel = 3
    m = BatchNorm2d(nchannel, track_running_stats=False)

    gm = ad.GradManager().attach(m.parameters())
    optim = optimizer.SGD(m.parameters(), lr=1.0)
    optim.clear_grad()

    data = tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
    with gm:
        loss = m(data).sum()
        gm.backward(loss)
    optim.step()
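
# With track_running_stats=False the module keeps no running statistics and
# always normalizes with per-batch statistics; this test only checks that a
# forward/backward/step cycle runs without error in that configuration.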


def test_bn_no_track_stat2():
    nchannel = 3
    m = BatchNorm2d(nchannel)  # init with track_running_stats=True
    m.track_running_stats = False

    # m.running_var and m.running_mean were created at init time
    saved_var = m.running_var.numpy()
    assert saved_var is not None
    saved_mean = m.running_mean.numpy()
    assert saved_mean is not None

    gm = ad.GradManager().attach(m.parameters())
    optim = optimizer.SGD(m.parameters(), lr=1.0)
    optim.clear_grad()

    data = tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
    with gm:
        loss = m(data).sum()
        gm.backward(loss)
    optim.step()
    np.testing.assert_equal(m.running_var.numpy(), saved_var)
    np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
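
# Toggling track_running_stats off after construction keeps the buffers that
# were allocated in __init__, but the forward pass must leave them untouched;
# that is what the two assert_equal checks above verify.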


def test_bn_no_track_stat3():
    nchannel = 3
    m = BatchNorm2d(nchannel, track_running_stats=False)
    m.track_running_stats = True
    data = np.random.random((6, nchannel, 2, 2)).astype("float32")
    with pytest.raises(Exception):
        m(data)
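
# The reverse toggle is not supported: a module built with
# track_running_stats=False never allocated running statistics, so flipping
# the flag to True leaves the forward pass with no buffers to read, and it
# raises. The exact exception type is unspecified, hence the broad
# pytest.raises(Exception).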


def test_trace_bn_forward_twice():
    class Simple(Module):
        def __init__(self):
            super().__init__()
            self.bn = BatchNorm2d(1)

        def forward(self, inp):
            x = self.bn(inp)
            x = self.bn(x)
            return x

    @trace(symbolic=True)
    def train_bn(inp, net=None):
        net.train()
        pred = net(inp)
        return pred

    x = tensor(np.ones((1, 1, 32, 32), dtype=np.float32))
    y = train_bn(x, net=Simple())
    np.testing.assert_equal(y.numpy(), 0)
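
# Why the output is exactly zero: the input is constant, so the first BN pass
# maps it to zero (zero-centered, with default affine weight 1 and bias 0),
# and the second pass normalizes an all-zero tensor, which stays zero. The
# point of the test is that trace(symbolic=True) tolerates the same BN module
# running twice in one traced graph.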


def run_syncbn(trace_mode):
    x = F.ones([2, 16, 4, 4], dtype="float32")

    net = Sequential(
        Conv2d(16, 16, 1), SyncBatchNorm(16), Conv2d(16, 16, 1), SyncBatchNorm(16),
    )

    gm = ad.GradManager().attach(
        net.parameters(), callbacks=dist.make_allreduce_cb("MEAN")
    )
    opt = optimizer.SGD(net.parameters(), 1e-3)

    def train_func(x):
        with gm:
            y = net(x)
            loss = y.mean()
            gm.backward(loss)
            opt.step().clear_grad()
        return loss

    if trace_mode is not None:
        train_func = trace(train_func, symbolic=trace_mode)

    for _ in range(3):
        loss = train_func(x)
        loss.numpy()
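
# dist.make_allreduce_cb("MEAN") registers a GradManager callback that
# all-reduces (averages) each parameter gradient across ranks before the
# optimizer step, keeping the two workers in lockstep. Running the (possibly
# traced) step three times checks that tracing copes with several
# SyncBatchNorm layers in one network across repeated calls.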


@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
@pytest.mark.parametrize("trace_mode", [None, True, False])
def test_trace_several_syncbn(trace_mode):
    @dist.launcher(n_gpus=2)
    def worker():
        run_syncbn(trace_mode)

    worker()


# https://github.com/MegEngine/MegEngine/issues/145
@pytest.mark.parametrize("is_training", [False, True])
def test_frozen_bn_no_affine(is_training):
    nchannel = 3
    m = BatchNorm2d(nchannel, freeze=True, affine=False)
    if is_training:
        m.train()
    else:
        m.eval()
    data = megengine.tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
    m(data).numpy()
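
# With affine=False the module has no weight/bias parameters at all; this
# regression test for the issue linked above only asserts that the frozen
# forward pass still runs in both train and eval mode.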