test_batchnorm.py

# -*- coding: utf-8 -*-
import functools
import platform

import numpy as np
import pytest

import megengine as mge
import megengine.amp as amp
import megengine.distributed as dist
from megengine import Tensor, jit
from megengine.autodiff.grad_manager import GradManager
from megengine.core._trace_option import use_symbolic_shape
from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm

_assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
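

# test_syncbn builds a NumPy reference from the *full* batch, then feeds each
# of two GPU ranks only half of that batch. Matching outputs therefore show
# that SyncBatchNorm reduced its statistics across ranks instead of using
# per-rank batch statistics. Under AMP the tolerance is relaxed, since parts
# of the computation may run in reduced precision.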
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
@pytest.mark.parametrize("enable_amp", [False, True])
def test_syncbn(enable_amp):
    nr_chan = 8
    data_shape = (3, nr_chan, 4, 16)
    momentum = 0.9
    eps = 1e-5
    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
    steps = 4
    nr_ranks = 2
    server = dist.Server()
    port = server.py_server_port

    @dist.launcher(n_gpus=2)
    def worker(data, yv_expect, running_mean, running_var):
        with amp.autocast(enabled=enable_amp):
            rank = dist.get_rank()
            bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
            for i in range(steps):
                yv = bn(Tensor(data[rank][i]))
            if enable_amp:
                np.testing.assert_allclose(
                    yv.numpy(), yv_expect[rank], atol=5e-4, rtol=5e-4
                )
            else:
                _assert_allclose(yv.numpy(), yv_expect[rank])
            _assert_allclose(bn.running_mean.numpy(), running_mean)
            _assert_allclose(bn.running_var.numpy(), running_var)
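
    # NumPy reference for each step: flatten NCHW to (N*H*W, C) so per-channel
    # statistics are reductions over axis 0. Normalization uses the biased
    # variance (ddof=0); the running variance is updated from the unbiased
    # estimate (ddof=1) via the exponential moving average
    #   running_stat = momentum * running_stat + (1 - momentum) * batch_stat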
    xv = []
    for i in range(steps):
        xv.append(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
        xv_transposed = np.transpose(xv[i], [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )

        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)

        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var_biased + eps)

        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)

        yv_expect = (xv[i] - mean) / sd
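
    # Split every step's batch along the width axis: rank i receives columns
    # [8 * i, 8 * i + 8) of the 16, i.e. half of the spatial positions. Only
    # the final step's expected output is kept, matching the worker, which
    # asserts after its loop.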
    data = []
    for i in range(nr_ranks):
        data.append([])
        for j in range(steps):
            data[i].append(xv[j][:, :, :, i * 8 : i * 8 + 8])

    yv_expect = [yv_expect[:, :, :, i * 8 : i * 8 + 8] for i in range(nr_ranks)]

    worker(data, yv_expect, running_mean, running_var)


def test_batchnorm():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
    momentum = 0.9
    bn = BatchNorm1d(nr_chan, momentum=momentum)
    running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1), dtype=np.float32)
    for i in range(3):
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)

        xv_transposed = np.transpose(xv, [0, 2, 1]).reshape(
            (data_shape[0] * data_shape[2], nr_chan)
        )

        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1))
        sd = np.sqrt(var_biased + bn.eps)

        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)

        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)
        _assert_allclose(bn.running_mean.numpy().reshape(-1), running_mean.reshape(-1))
        _assert_allclose(bn.running_var.numpy().reshape(-1), running_var.reshape(-1))

    # test set 'training' flag to False
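    # (in inference mode the module must reuse the frozen running statistics,
    # leave them unchanged, and return identical results on repeated forwards)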
    mean_backup = bn.running_mean.numpy()
    var_backup = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
    data = Tensor(xv)
    yv1 = bn(data)
    yv2 = bn(data)
    np.testing.assert_equal(yv1.numpy(), yv2.numpy())
    np.testing.assert_equal(mean_backup, bn.running_mean.numpy())
    np.testing.assert_equal(var_backup, bn.running_var.numpy())
    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
    _assert_allclose(yv1.numpy(), yv_expect)
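

# test_syncbn1d and test_syncbn2d below repeat the same checks with
# SyncBatchNorm in a single process, where it should behave exactly like the
# plain module; test_batchnorm2d is the NCHW analogue of test_batchnorm.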
def test_syncbn1d():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
    momentum = 0.9
    bn = SyncBatchNorm(nr_chan, momentum=momentum)
    running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1), dtype=np.float32)
    for i in range(3):
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)

        xv_transposed = np.transpose(xv, [0, 2, 1]).reshape(
            (data_shape[0] * data_shape[2], nr_chan)
        )

        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1))
        sd = np.sqrt(var_biased + bn.eps)

        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)

        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)
        _assert_allclose(bn.running_mean.numpy().reshape(-1), running_mean.reshape(-1))
        _assert_allclose(bn.running_var.numpy().reshape(-1), running_var.reshape(-1))

    # test set 'training' flag to False
    mean_backup = bn.running_mean.numpy()
    var_backup = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
    data = Tensor(xv)
    yv1 = bn(data)
    yv2 = bn(data)
    np.testing.assert_equal(yv1.numpy(), yv2.numpy())
    np.testing.assert_equal(mean_backup, bn.running_mean.numpy())
    np.testing.assert_equal(var_backup, bn.running_var.numpy())
    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
    _assert_allclose(yv1.numpy(), yv_expect)


def test_batchnorm2d():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    momentum = 0.9
    bn = BatchNorm2d(nr_chan, momentum=momentum)
    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
    for i in range(3):
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )

        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var_biased + bn.eps)

        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)

        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)
        _assert_allclose(bn.running_mean.numpy(), running_mean)
        _assert_allclose(bn.running_var.numpy(), running_var)

    # test set 'training' flag to False
    mean_backup = bn.running_mean.numpy()
    var_backup = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
    data = Tensor(xv)
    yv1 = bn(data)
    yv2 = bn(data)
    np.testing.assert_equal(yv1.numpy(), yv2.numpy())
    np.testing.assert_equal(mean_backup, bn.running_mean.numpy())
    np.testing.assert_equal(var_backup, bn.running_var.numpy())
    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
    _assert_allclose(yv1.numpy(), yv_expect)


def test_syncbn2d():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    momentum = 0.9
    bn = SyncBatchNorm(nr_chan, momentum=momentum)
    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
    for i in range(3):
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )

        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var_biased + bn.eps)

        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)

        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)
        _assert_allclose(bn.running_mean.numpy(), running_mean)
        _assert_allclose(bn.running_var.numpy(), running_var)

    # test set 'training' flag to False
    mean_backup = bn.running_mean.numpy()
    var_backup = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
    data = Tensor(xv)
    yv1 = bn(data)
    yv2 = bn(data)
    np.testing.assert_equal(yv1.numpy(), yv2.numpy())
    np.testing.assert_equal(mean_backup, bn.running_mean.numpy())
    np.testing.assert_equal(var_backup, bn.running_var.numpy())
    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
    _assert_allclose(yv1.numpy(), yv_expect)
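

# With track_running_stats=False the module keeps no running statistics and
# normalizes with the statistics of the current batch in both training and
# inference mode, which is why flipping bn.training halfway through the loop
# does not change the expected output in the four tests below.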
def test_batchnorm_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
    bn = BatchNorm1d(8, track_running_stats=False)
    for i in range(4):
        if i == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
        var = np.var(
            np.transpose(xv, [0, 2, 1]).reshape(
                (data_shape[0] * data_shape[2], nr_chan)
            ),
            axis=0,
        ).reshape((1, nr_chan, 1))
        sd = np.sqrt(var + bn.eps)

        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)


def test_syncbn_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
    bn = SyncBatchNorm(8, track_running_stats=False)
    for i in range(4):
        if i == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
        var = np.var(
            np.transpose(xv, [0, 2, 1]).reshape(
                (data_shape[0] * data_shape[2], nr_chan)
            ),
            axis=0,
        ).reshape((1, nr_chan, 1))
        sd = np.sqrt(var + bn.eps)

        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)


def test_batchnorm2d_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    bn = BatchNorm2d(8, track_running_stats=False)
    for i in range(4):
        if i == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )

        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
        var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var + bn.eps)

        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)


def test_syncbn2d_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    bn = SyncBatchNorm(8, track_running_stats=False)
    for i in range(4):
        if i == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )

        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
        var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var + bn.eps)

        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)
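

# Gradient check: in a single process SyncBatchNorm should match BatchNorm2d
# in both the forward output and the input gradient. gm.backward(oup, diff)
# seeds the output gradient with diff, so the comparison covers the
# vector-Jacobian product d(oup)/d(inp) applied to diff.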
def test_syncbn2d_grad():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    syncbn = SyncBatchNorm(8, track_running_stats=False)
    bn = BatchNorm2d(8, track_running_stats=False)
    for i in range(4):
        if i == 2:
            syncbn.training = False
            bn.training = False
        inp = Tensor(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
        diff = Tensor(np.random.normal(size=data_shape).astype(np.float32))

        with GradManager().attach(inp) as gm:
            oup = syncbn(inp)
            gm.backward(oup, diff)
        grad = inp.grad
        inp.grad = None

        with GradManager().attach(inp) as gm:
            oup_expect = bn(inp)
            gm.backward(oup_expect, diff)
        grad_expect = inp.grad
        inp.grad = None

        _assert_allclose(oup.numpy(), oup_expect.numpy())
        _assert_allclose(grad.numpy(), grad_expect.numpy())
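

# Empty-tensor check: batch norm on zero-sized inputs should return an
# equally-shaped empty tensor, both eagerly (is_symbolic=None) and under
# jit.trace with symbolic=False/True. The loop calls the traced function
# several times (presumably to exercise the compiled path after the first
# trace) and breaks after one iteration in eager mode.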
@pytest.mark.parametrize("dim", [1, 2])
@pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_batchnorm_empty_tensor(dim, is_symbolic):
    if dim == 1:
        m = BatchNorm1d(4, affine=True)
        inp = mge.tensor(np.random.randn(0, 4, 0).astype("float32"))
    elif dim == 2:
        m = BatchNorm2d(4, affine=True)
        inp = mge.tensor(np.random.randn(0, 4, 0, 0).astype("float32"))
    else:
        raise NotImplementedError
    m.train()

    def fn(inp):
        return m(inp)

    if is_symbolic is not None:
        fn = jit.trace(symbolic=is_symbolic)(fn)
    for _ in range(3):
        out = fn(inp)
        np.testing.assert_equal(out.numpy(), inp)
        if is_symbolic is None:
            break