
test_batchnorm.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import platform

import numpy as np
import pytest

import megengine as mge
import megengine.amp as amp
import megengine.distributed as dist
from megengine import Tensor, jit
from megengine.autodiff.grad_manager import GradManager
from megengine.core._trace_option import use_symbolic_shape
from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm

_assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
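
# test_syncbn: run SyncBatchNorm across two GPUs, each rank receiving half of
# the spatial width of every batch, and check that the synchronized statistics
# and outputs match a NumPy reference computed over the full, unsplit batch.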
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
@pytest.mark.parametrize("enable_amp", [False, True])
def test_syncbn(enable_amp):
    nr_chan = 8
    data_shape = (3, nr_chan, 4, 16)
    momentum = 0.9
    eps = 1e-5
    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
    steps = 4
    nr_ranks = 2
    server = dist.Server()
    port = server.py_server_port

    @dist.launcher(n_gpus=2)
    def worker(data, yv_expect, running_mean, running_var):
        with amp.autocast(enabled=enable_amp):
            rank = dist.get_rank()
            bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
            for i in range(steps):
                yv = bn(Tensor(data[rank][i]))
            if enable_amp:
                np.testing.assert_allclose(
                    yv.numpy(), yv_expect[rank], atol=5e-4, rtol=5e-4
                )
            else:
                _assert_allclose(yv.numpy(), yv_expect[rank])
            _assert_allclose(bn.running_mean.numpy(), running_mean)
            _assert_allclose(bn.running_var.numpy(), running_var)

    xv = []
    for i in range(steps):
        xv.append(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
        xv_transposed = np.transpose(xv[i], [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )
        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var_biased + eps)
        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)
        yv_expect = (xv[i] - mean) / sd

    data = []
    for i in range(nr_ranks):
        data.append([])
        for j in range(steps):
            data[i].append(xv[j][:, :, :, i * 8 : i * 8 + 8])
    yv_expect = [yv_expect[:, :, :, i * 8 : i * 8 + 8] for i in range(nr_ranks)]

    worker(data, yv_expect, running_mean, running_var)
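
# test_batchnorm: BatchNorm1d in training mode should normalize each batch by
# its own biased variance, while the running stats are updated with the
# unbiased variance; in eval mode the frozen running stats are used instead.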
def test_batchnorm():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
    momentum = 0.9
    bn = BatchNorm1d(nr_chan, momentum=momentum)
    running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1), dtype=np.float32)
    for i in range(3):
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
        xv_transposed = np.transpose(xv, [0, 2, 1]).reshape(
            (data_shape[0] * data_shape[2], nr_chan)
        )
        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1))
        sd = np.sqrt(var_biased + bn.eps)
        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)
        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd
        _assert_allclose(yv.numpy(), yv_expect)
        _assert_allclose(bn.running_mean.numpy().reshape(-1), running_mean.reshape(-1))
        _assert_allclose(bn.running_var.numpy().reshape(-1), running_var.reshape(-1))

    # test set 'training' flag to False
    mean_backup = bn.running_mean.numpy()
    var_backup = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
    data = Tensor(xv)
    yv1 = bn(data)
    yv2 = bn(data)
    np.testing.assert_equal(yv1.numpy(), yv2.numpy())
    np.testing.assert_equal(mean_backup, bn.running_mean.numpy())
    np.testing.assert_equal(var_backup, bn.running_var.numpy())
    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
    _assert_allclose(yv1.numpy(), yv_expect)
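
# test_syncbn1d: SyncBatchNorm on a single process should behave exactly like
# plain BatchNorm1d, so the same NumPy reference applies.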
def test_syncbn1d():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
    momentum = 0.9
    bn = SyncBatchNorm(nr_chan, momentum=momentum)
    running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1), dtype=np.float32)
    for i in range(3):
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
        xv_transposed = np.transpose(xv, [0, 2, 1]).reshape(
            (data_shape[0] * data_shape[2], nr_chan)
        )
        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1))
        sd = np.sqrt(var_biased + bn.eps)
        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)
        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd
        _assert_allclose(yv.numpy(), yv_expect)
        _assert_allclose(bn.running_mean.numpy().reshape(-1), running_mean.reshape(-1))
        _assert_allclose(bn.running_var.numpy().reshape(-1), running_var.reshape(-1))

    # test set 'training' flag to False
    mean_backup = bn.running_mean.numpy()
    var_backup = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
    data = Tensor(xv)
    yv1 = bn(data)
    yv2 = bn(data)
    np.testing.assert_equal(yv1.numpy(), yv2.numpy())
    np.testing.assert_equal(mean_backup, bn.running_mean.numpy())
    np.testing.assert_equal(var_backup, bn.running_var.numpy())
    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
    _assert_allclose(yv1.numpy(), yv_expect)
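
# test_batchnorm2d: the same check for the 4-d (NCHW) case; per-channel
# statistics are taken over the N, H and W axes.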
def test_batchnorm2d():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    momentum = 0.9
    bn = BatchNorm2d(nr_chan, momentum=momentum)
    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
    for i in range(3):
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )
        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var_biased + bn.eps)
        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)
        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd
        _assert_allclose(yv.numpy(), yv_expect)
        _assert_allclose(bn.running_mean.numpy(), running_mean)
        _assert_allclose(bn.running_var.numpy(), running_var)

    # test set 'training' flag to False
    mean_backup = bn.running_mean.numpy()
    var_backup = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
    data = Tensor(xv)
    yv1 = bn(data)
    yv2 = bn(data)
    np.testing.assert_equal(yv1.numpy(), yv2.numpy())
    np.testing.assert_equal(mean_backup, bn.running_mean.numpy())
    np.testing.assert_equal(var_backup, bn.running_var.numpy())
    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
    _assert_allclose(yv1.numpy(), yv_expect)
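
# test_syncbn2d: single-process SyncBatchNorm on NCHW input should match the
# BatchNorm2d reference above.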
def test_syncbn2d():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    momentum = 0.9
    bn = SyncBatchNorm(nr_chan, momentum=momentum)
    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
    for i in range(3):
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )
        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var_biased + bn.eps)
        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)
        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd
        _assert_allclose(yv.numpy(), yv_expect)
        _assert_allclose(bn.running_mean.numpy(), running_mean)
        _assert_allclose(bn.running_var.numpy(), running_var)

    # test set 'training' flag to False
    mean_backup = bn.running_mean.numpy()
    var_backup = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
    data = Tensor(xv)
    yv1 = bn(data)
    yv2 = bn(data)
    np.testing.assert_equal(yv1.numpy(), yv2.numpy())
    np.testing.assert_equal(mean_backup, bn.running_mean.numpy())
    np.testing.assert_equal(var_backup, bn.running_var.numpy())
    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
    _assert_allclose(yv1.numpy(), yv_expect)
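
# test_batchnorm_no_stats: with track_running_stats=False the module always
# normalizes by the current batch statistics, even after switching to eval mode.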
def test_batchnorm_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
    bn = BatchNorm1d(8, track_running_stats=False)
    for i in range(4):
        if i == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
        var = np.var(
            np.transpose(xv, [0, 2, 1]).reshape(
                (data_shape[0] * data_shape[2], nr_chan)
            ),
            axis=0,
        ).reshape((1, nr_chan, 1))
        sd = np.sqrt(var + bn.eps)
        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd
        _assert_allclose(yv.numpy(), yv_expect)
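
# test_syncbn_no_stats: the same no-running-stats behaviour for SyncBatchNorm.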
def test_syncbn_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
    bn = SyncBatchNorm(8, track_running_stats=False)
    for i in range(4):
        if i == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
        var = np.var(
            np.transpose(xv, [0, 2, 1]).reshape(
                (data_shape[0] * data_shape[2], nr_chan)
            ),
            axis=0,
        ).reshape((1, nr_chan, 1))
        sd = np.sqrt(var + bn.eps)
        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd
        _assert_allclose(yv.numpy(), yv_expect)
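
# test_batchnorm2d_no_stats: the no-running-stats behaviour on NCHW input.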
def test_batchnorm2d_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    bn = BatchNorm2d(8, track_running_stats=False)
    for i in range(4):
        if i == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )
        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
        var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var + bn.eps)
        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd
        _assert_allclose(yv.numpy(), yv_expect)
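
# test_syncbn2d_no_stats: the same NCHW no-running-stats check for SyncBatchNorm.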
def test_syncbn2d_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    bn = SyncBatchNorm(8, track_running_stats=False)
    for i in range(4):
        if i == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )
        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
        var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var + bn.eps)
        yv = bn(Tensor(xv))
        yv_expect = (xv - mean) / sd
        _assert_allclose(yv.numpy(), yv_expect)
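
# test_syncbn2d_grad: forward outputs and input gradients of SyncBatchNorm and
# BatchNorm2d should agree in both training and eval mode.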
def test_syncbn2d_grad():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    syncbn = SyncBatchNorm(8, track_running_stats=False)
    bn = BatchNorm2d(8, track_running_stats=False)
    for i in range(4):
        if i == 2:
            syncbn.training = False
            bn.training = False
        inp = Tensor(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
        diff = Tensor(np.random.normal(size=data_shape).astype(np.float32))

        with GradManager().attach(inp) as gm:
            oup = syncbn(inp)
            gm.backward(oup, diff)

        grad = inp.grad
        inp.grad = None

        with GradManager().attach(inp) as gm:
            oup_expect = bn(inp)
            gm.backward(oup_expect, diff)

        grad_expect = inp.grad
        inp.grad = None

        _assert_allclose(oup.numpy(), oup_expect.numpy())
        _assert_allclose(grad.numpy(), grad_expect.numpy())
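
# test_batchnorm_empty_tensor: BatchNorm on a zero-size tensor must not crash
# and must return a tensor of the same (empty) shape, both eagerly and under
# jit tracing in either symbolic mode.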
@pytest.mark.parametrize("dim", [1, 2])
@pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_batchnorm_empty_tensor(dim, is_symbolic):
    if dim == 1:
        m = BatchNorm1d(4, affine=True)
        inp = mge.tensor(np.random.randn(0, 4, 0).astype("float32"))
    elif dim == 2:
        m = BatchNorm2d(4, affine=True)
        inp = mge.tensor(np.random.randn(0, 4, 0, 0).astype("float32"))
    else:
        raise NotImplementedError

    m.train()

    def fn(inp):
        return m(inp)

    if is_symbolic is not None:
        fn = jit.trace(symbolic=is_symbolic)(fn)
    for _ in range(3):
        out = fn(inp)
        np.testing.assert_equal(out.numpy(), inp)
        if is_symbolic is None:
            break