
test_batchnorm.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import platform

import numpy as np
import pytest

import megengine as mge
import megengine.distributed as dist
from megengine import Tensor, jit
from megengine.autodiff.grad_manager import GradManager
from megengine.core._trace_option import use_symbolic_shape
from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm

# Shared tolerance for all floating-point comparisons in this file.
_assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
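

# Distributed SyncBatchNorm across 2 GPUs: each of the 4 steps' batches is
# split along the last (width) axis, one half per rank. Every rank's final
# output and running statistics should match the NumPy reference computed
# below over the full, unsplit batch.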
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_syncbn():
    nr_chan = 8
    data_shape = (3, nr_chan, 4, 16)
    momentum = 0.9
    eps = 1e-5
    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
    steps = 4
    nr_ranks = 2
    server = dist.Server()
    port = server.py_server_port

    @dist.launcher(n_gpus=2)
    def worker(data, yv_expect, running_mean, running_var):
        rank = dist.get_rank()
        bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
        for i in range(steps):
            yv = bn(Tensor(data[rank][i]))

        _assert_allclose(yv.numpy(), yv_expect[rank])
        _assert_allclose(bn.running_mean.numpy(), running_mean)
        _assert_allclose(bn.running_var.numpy(), running_var)

    xv = []
    for i in range(steps):
        xv.append(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
        xv_transposed = np.transpose(xv[i], [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )

        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)

        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var_biased + eps)

        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)

        yv_expect = (xv[i] - mean) / sd

    data = []
    for i in range(nr_ranks):
        data.append([])
        for j in range(steps):
            data[i].append(xv[j][:, :, :, i * 8 : i * 8 + 8])

    yv_expect = [yv_expect[:, :, :, i * 8 : i * 8 + 8] for i in range(nr_ranks)]

    worker(data, yv_expect, running_mean, running_var)
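

# BatchNorm1d forward in training mode: reference statistics are computed in
# NumPy by flattening (N, C, L) to (N*L, C). Normalization uses the biased
# batch variance, while the running variance is updated with the unbiased
# (ddof=1) estimate. Eval mode is then checked to be deterministic and to
# leave the running statistics untouched.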
def test_batchnorm():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
    momentum = 0.9
    bn = BatchNorm1d(nr_chan, momentum=momentum)
    running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1), dtype=np.float32)
    for i in range(3):
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)

        xv_transposed = np.transpose(xv, [0, 2, 1]).reshape(
            (data_shape[0] * data_shape[2], nr_chan)
        )

        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1))
        sd = np.sqrt(var_biased + bn.eps)

        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)

        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)
        _assert_allclose(bn.running_mean.numpy().reshape(-1), running_mean.reshape(-1))
        _assert_allclose(bn.running_var.numpy().reshape(-1), running_var.reshape(-1))

    # test set 'training' flag to False
    mean_backup = bn.running_mean.numpy()
    var_backup = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
    data = Tensor(xv)
    yv1 = bn(data)
    yv2 = bn(data)
    np.testing.assert_equal(yv1.numpy(), yv2.numpy())
    np.testing.assert_equal(mean_backup, bn.running_mean.numpy())
    np.testing.assert_equal(var_backup, bn.running_var.numpy())
    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
    _assert_allclose(yv1.numpy(), yv_expect)
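

# Same checks as test_batchnorm, but SyncBatchNorm run in a single process
# (no process group) should behave exactly like BatchNorm1d.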
def test_syncbn1d():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
    momentum = 0.9
    bn = SyncBatchNorm(nr_chan, momentum=momentum)
    running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1), dtype=np.float32)
    for i in range(3):
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)

        xv_transposed = np.transpose(xv, [0, 2, 1]).reshape(
            (data_shape[0] * data_shape[2], nr_chan)
        )

        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1))
        sd = np.sqrt(var_biased + bn.eps)

        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)

        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)
        _assert_allclose(bn.running_mean.numpy().reshape(-1), running_mean.reshape(-1))
        _assert_allclose(bn.running_var.numpy().reshape(-1), running_var.reshape(-1))

    # test set 'training' flag to False
    mean_backup = bn.running_mean.numpy()
    var_backup = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
    data = Tensor(xv)
    yv1 = bn(data)
    yv2 = bn(data)
    np.testing.assert_equal(yv1.numpy(), yv2.numpy())
    np.testing.assert_equal(mean_backup, bn.running_mean.numpy())
    np.testing.assert_equal(var_backup, bn.running_var.numpy())
    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
    _assert_allclose(yv1.numpy(), yv_expect)
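

# BatchNorm2d forward: (N, C, H, W) is flattened to (N*H*W, C) to compute the
# per-channel reference statistics; otherwise the same checks as the 1d case.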
def test_batchnorm2d():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    momentum = 0.9
    bn = BatchNorm2d(nr_chan, momentum=momentum)
    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
    for i in range(3):
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )

        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var_biased + bn.eps)

        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)

        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)
        _assert_allclose(bn.running_mean.numpy(), running_mean)
        _assert_allclose(bn.running_var.numpy(), running_var)

    # test set 'training' flag to False
    mean_backup = bn.running_mean.numpy()
    var_backup = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
    data = Tensor(xv)
    yv1 = bn(data)
    yv2 = bn(data)
    np.testing.assert_equal(yv1.numpy(), yv2.numpy())
    np.testing.assert_equal(mean_backup, bn.running_mean.numpy())
    np.testing.assert_equal(var_backup, bn.running_var.numpy())
    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
    _assert_allclose(yv1.numpy(), yv_expect)
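

# Single-process SyncBatchNorm on 4-d input should match the BatchNorm2d
# reference above.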
def test_syncbn2d():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    momentum = 0.9
    bn = SyncBatchNorm(nr_chan, momentum=momentum)
    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
    for i in range(3):
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )

        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var_biased + bn.eps)

        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)

        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)
        _assert_allclose(bn.running_mean.numpy(), running_mean)
        _assert_allclose(bn.running_var.numpy(), running_var)

    # test set 'training' flag to False
    mean_backup = bn.running_mean.numpy()
    var_backup = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
    data = Tensor(xv)
    yv1 = bn(data)
    yv2 = bn(data)
    np.testing.assert_equal(yv1.numpy(), yv2.numpy())
    np.testing.assert_equal(mean_backup, bn.running_mean.numpy())
    np.testing.assert_equal(var_backup, bn.running_var.numpy())
    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
    _assert_allclose(yv1.numpy(), yv_expect)
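

# With track_running_stats=False, BatchNorm1d normalizes with the statistics
# of the current batch even after switching to eval mode (iterations 2 and 3),
# and keeps no running mean/variance.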
def test_batchnorm_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
    bn = BatchNorm1d(8, track_running_stats=False)
    for i in range(4):
        if i == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
        var = np.var(
            np.transpose(xv, [0, 2, 1]).reshape(
                (data_shape[0] * data_shape[2], nr_chan)
            ),
            axis=0,
        ).reshape((1, nr_chan, 1))
        sd = np.sqrt(var + bn.eps)

        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)
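

# Same no-running-stats check for single-process SyncBatchNorm.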
def test_syncbn_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
    bn = SyncBatchNorm(8, track_running_stats=False)
    for i in range(4):
        if i == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
        var = np.var(
            np.transpose(xv, [0, 2, 1]).reshape(
                (data_shape[0] * data_shape[2], nr_chan)
            ),
            axis=0,
        ).reshape((1, nr_chan, 1))
        sd = np.sqrt(var + bn.eps)

        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)
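

# No-running-stats check for BatchNorm2d on 4-d input.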
def test_batchnorm2d_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    bn = BatchNorm2d(8, track_running_stats=False)
    for i in range(4):
        if i == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )

        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
        var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var + bn.eps)

        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)
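

# No-running-stats check for single-process SyncBatchNorm on 4-d input.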
def test_syncbn2d_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    bn = SyncBatchNorm(8, track_running_stats=False)
    for i in range(4):
        if i == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )

        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
        var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var + bn.eps)

        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)
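

# Backward check: SyncBatchNorm and BatchNorm2d should produce the same output
# and the same input gradient. gm.backward(oup, diff) back-propagates with
# `diff` as the incoming gradient of `oup`, in both training and eval mode.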
def test_syncbn2d_grad():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    syncbn = SyncBatchNorm(8, track_running_stats=False)
    bn = BatchNorm2d(8, track_running_stats=False)
    for i in range(4):
        if i == 2:
            syncbn.training = False
            bn.training = False

        inp = Tensor(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
        diff = Tensor(np.random.normal(size=data_shape).astype(np.float32))

        with GradManager().attach(inp) as gm:
            oup = syncbn(inp)
            gm.backward(oup, diff)

        grad = inp.grad
        inp.grad = None

        with GradManager().attach(inp) as gm:
            oup_expect = bn(inp)
            gm.backward(oup_expect, diff)

        grad_expect = inp.grad
        inp.grad = None

        _assert_allclose(oup.numpy(), oup_expect.numpy())
        _assert_allclose(grad.numpy(), grad_expect.numpy())
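

# Zero-sized inputs must pass through BatchNorm without error and produce an
# empty output of the same shape, both eagerly and under jit.trace (symbolic
# and non-symbolic). Traced variants are run three times to exercise replay.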
@pytest.mark.parametrize("dim", [1, 2])
@pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_batchnorm_empty_tensor(dim, is_symbolic):
    if dim == 1:
        m = BatchNorm1d(4, affine=True)
        inp = mge.tensor(np.random.randn(0, 4, 0).astype("float32"))
    elif dim == 2:
        m = BatchNorm2d(4, affine=True)
        inp = mge.tensor(np.random.randn(0, 4, 0, 0).astype("float32"))
    else:
        raise NotImplementedError
    m.train()

    def fn(inp):
        return m(inp)

    if is_symbolic is not None:
        fn = jit.trace(symbolic=is_symbolic)(fn)
    for _ in range(3):
        out = fn(inp)
        np.testing.assert_equal(out.numpy(), inp)
        if is_symbolic is None:
            break