You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_rng.py 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import numpy as np
  10. import pytest
  11. import megengine as mge
  12. import megengine.functional as F
  13. from megengine import Tensor, jit, random
  14. from megengine.core._imperative_rt import CompNode
  15. from megengine.core._imperative_rt.core2 import apply
  16. from megengine.core._imperative_rt.ops import (
  17. delete_rng_handle,
  18. get_global_rng_seed,
  19. new_rng_handle,
  20. )
  21. from megengine.core.autodiff.grad import Grad
  22. from megengine.core.ops.builtin import (
  23. BetaRNG,
  24. GammaRNG,
  25. GaussianRNG,
  26. PermutationRNG,
  27. PoissonRNG,
  28. UniformRNG,
  29. )
  30. from megengine.device import get_device_count
  31. from megengine.jit import trace
  32. from megengine.random import RNG
  33. from megengine.random import seed as set_global_seed
  34. from megengine.random import uniform
  35. @pytest.mark.skipif(
  36. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  37. )
  38. def test_gaussian_op():
  39. set_global_seed(1024)
  40. shape = (
  41. 8,
  42. 9,
  43. 11,
  44. 12,
  45. )
  46. shape = Tensor(shape, dtype="int32")
  47. op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0, dtype="float32")
  48. (output,) = apply(op, shape)
  49. assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
  50. assert np.fabs(np.sqrt(output.numpy().var()) - 3.0) < 1e-1
  51. assert str(output.device) == str(CompNode("xpux"))
  52. assert output.dtype == np.float32
  53. cn = CompNode("xpu2")
  54. seed = 233333
  55. h = new_rng_handle(cn, seed)
  56. op = GaussianRNG(seed=seed, mean=3.0, std=1.0, dtype="float32", handle=h)
  57. (output,) = apply(op, shape)
  58. delete_rng_handle(h)
  59. assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
  60. assert np.fabs(np.sqrt(output.numpy().var()) - 1.0) < 1e-1
  61. assert str(output.device) == str(cn)
  62. assert output.dtype == np.float32
  63. @pytest.mark.skipif(
  64. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  65. )
  66. def test_uniform_op():
  67. set_global_seed(1024)
  68. shape = (
  69. 8,
  70. 9,
  71. 11,
  72. 12,
  73. )
  74. shape = Tensor(shape, dtype="int32")
  75. op = UniformRNG(seed=get_global_rng_seed(), dtype="float32")
  76. (output,) = apply(op, shape)
  77. assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
  78. assert str(output.device) == str(CompNode("xpux"))
  79. assert output.dtype == np.float32
  80. cn = CompNode("xpu2")
  81. seed = 233333
  82. h = new_rng_handle(cn, seed)
  83. op = UniformRNG(seed=seed, dtype="float32", handle=h)
  84. (output,) = apply(op, shape)
  85. delete_rng_handle(h)
  86. assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
  87. assert str(output.device) == str(cn)
  88. assert output.dtype == np.float32
  89. @pytest.mark.skipif(
  90. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  91. )
  92. def test_gamma_op():
  93. set_global_seed(1024)
  94. _shape, _scale = 2, 0.8
  95. _expected_mean, _expected_std = _shape * _scale, np.sqrt(_shape) * _scale
  96. shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32")
  97. scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32")
  98. op = GammaRNG(seed=get_global_rng_seed(), handle=0)
  99. (output,) = apply(op, shape, scale)
  100. assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
  101. assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
  102. assert str(output.device) == str(CompNode("xpux"))
  103. cn = CompNode("xpu2")
  104. seed = 233333
  105. h = new_rng_handle(cn, seed)
  106. shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32", device="xpu2")
  107. scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32", device="xpu2")
  108. op = GammaRNG(seed=seed, handle=h)
  109. (output,) = apply(op, shape, scale)
  110. delete_rng_handle(h)
  111. assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
  112. assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
  113. assert str(output.device) == str(cn)
  114. @pytest.mark.skipif(
  115. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  116. )
  117. def test_beta_op():
  118. set_global_seed(1024)
  119. _alpha, _beta = 2, 0.8
  120. _expected_mean = _alpha / (_alpha + _beta)
  121. _expected_std = np.sqrt(
  122. _alpha * _beta / ((_alpha + _beta) ** 2 * (_alpha + _beta + 1))
  123. )
  124. alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32")
  125. beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32")
  126. op = BetaRNG(seed=get_global_rng_seed())
  127. (output,) = apply(op, alpha, beta)
  128. assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
  129. assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
  130. assert str(output.device) == str(CompNode("xpux"))
  131. cn = CompNode("xpu2")
  132. seed = 233333
  133. h = new_rng_handle(cn, seed)
  134. alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32", device=cn)
  135. beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32", device=cn)
  136. op = BetaRNG(seed=seed, handle=h)
  137. (output,) = apply(op, alpha, beta)
  138. delete_rng_handle(h)
  139. assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
  140. assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
  141. assert str(output.device) == str(cn)
  142. @pytest.mark.skipif(
  143. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  144. )
  145. def test_poisson_op():
  146. set_global_seed(1024)
  147. lam = F.full([8, 9, 11, 12], value=2, dtype="float32")
  148. op = PoissonRNG(seed=get_global_rng_seed())
  149. (output,) = apply(op, lam)
  150. assert np.fabs(output.numpy().mean() - 2.0) < 1e-1
  151. assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1
  152. assert str(output.device) == str(CompNode("xpux"))
  153. cn = CompNode("xpu2")
  154. seed = 233333
  155. h = new_rng_handle(cn, seed)
  156. lam = F.full([8, 9, 11, 12], value=2, dtype="float32", device=cn)
  157. op = PoissonRNG(seed=seed, handle=h)
  158. (output,) = apply(op, lam)
  159. delete_rng_handle(h)
  160. assert np.fabs(output.numpy().mean() - 2.0) < 1e-1
  161. assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1
  162. assert str(output.device) == str(cn)
  163. @pytest.mark.skipif(
  164. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  165. )
  166. def test_permutation_op():
  167. set_global_seed(1024)
  168. n = 1000
  169. def test_permutation_op_dtype(dtype):
  170. def sum_result(res, fun):
  171. return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))])
  172. shape = Tensor((n,), dtype="int32")
  173. op = PermutationRNG(seed=get_global_rng_seed(), dtype=dtype)
  174. (output,) = apply(op, shape)
  175. assert sum_result(output, lambda x: x) < 500
  176. assert sum_result(output, np.sort) == n
  177. assert str(output.device) == str(CompNode("xpux"))
  178. assert output.dtype == dtype
  179. cn = CompNode("xpu2")
  180. seed = 233333
  181. h = new_rng_handle(cn, seed)
  182. op = PermutationRNG(seed=seed, handle=h, dtype=dtype)
  183. (output,) = apply(op, shape)
  184. delete_rng_handle(h)
  185. assert sum_result(output, lambda x: x) < 500
  186. assert sum_result(output, np.sort) == n
  187. assert str(output.device) == str(cn)
  188. assert output.dtype == dtype
  189. test_permutation_op_dtype(np.float32)
  190. test_permutation_op_dtype(np.int32)
  191. test_permutation_op_dtype(np.int16)
  192. @pytest.mark.skipif(
  193. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  194. )
  195. def test_UniformRNG():
  196. m1 = RNG(seed=111, device="xpu0")
  197. m2 = RNG(seed=111, device="xpu1")
  198. m3 = RNG(seed=222, device="xpu0")
  199. out1 = m1.uniform(size=(100,))
  200. out1_ = m1.uniform(size=(100,))
  201. out2 = m2.uniform(size=(100,))
  202. out3 = m3.uniform(size=(100,))
  203. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  204. assert out1.device == "xpu0" and out2.device == "xpu1"
  205. assert not (out1.numpy() == out3.numpy()).all()
  206. assert not (out1.numpy() == out1_.numpy()).all()
  207. low = -234
  208. high = 123
  209. out = m1.uniform(low=low, high=high, size=(20, 30, 40))
  210. out_shp = out.shape
  211. if isinstance(out_shp, tuple):
  212. assert out_shp == (20, 30, 40)
  213. else:
  214. assert all(out.shape.numpy() == np.array([20, 30, 40]))
  215. assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1
  216. @pytest.mark.skipif(
  217. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  218. )
  219. def test_NormalRNG():
  220. m1 = RNG(seed=111, device="xpu0")
  221. m2 = RNG(seed=111, device="xpu1")
  222. m3 = RNG(seed=222, device="xpu0")
  223. out1 = m1.normal(size=(100,))
  224. out1_ = m1.uniform(size=(100,))
  225. out2 = m2.normal(size=(100,))
  226. out3 = m3.normal(size=(100,))
  227. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  228. assert out1.device == "xpu0" and out2.device == "xpu1"
  229. assert not (out1.numpy() == out3.numpy()).all()
  230. assert not (out1.numpy() == out1_.numpy()).all()
  231. mean = -1
  232. std = 2
  233. out = m1.normal(mean=mean, std=std, size=(20, 30, 40))
  234. out_shp = out.shape
  235. if isinstance(out_shp, tuple):
  236. assert out_shp == (20, 30, 40)
  237. else:
  238. assert all(out.shape.numpy() == np.array([20, 30, 40]))
  239. assert np.abs(out.mean().numpy() - mean) / std < 0.1
  240. assert np.abs(np.std(out.numpy()) - std) < 0.1
  241. @pytest.mark.skipif(
  242. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  243. )
  244. def test_GammaRNG():
  245. m1 = RNG(seed=111, device="xpu0")
  246. m2 = RNG(seed=111, device="xpu1")
  247. m3 = RNG(seed=222, device="xpu0")
  248. out1 = m1.gamma(2, size=(100,))
  249. out1_ = m1.uniform(size=(100,))
  250. out2 = m2.gamma(2, size=(100,))
  251. out3 = m3.gamma(2, size=(100,))
  252. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  253. assert out1.device == "xpu0" and out2.device == "xpu1"
  254. assert not (out1.numpy() == out3.numpy()).all()
  255. assert not (out1.numpy() == out1_.numpy()).all()
  256. shape = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0")
  257. scale = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0")
  258. expected_mean = (shape * scale).numpy()
  259. expected_std = (F.sqrt(shape) * scale).numpy()
  260. out = m1.gamma(shape=shape, scale=scale, size=(20, 30, 40))
  261. out_shp = out.shape
  262. if isinstance(out_shp, tuple):
  263. assert out_shp == (20, 30, 40, 2, 3)
  264. else:
  265. assert all(out.shape.numpy() == np.array([20, 30, 40, 2, 3]))
  266. assert (
  267. np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std
  268. ).mean() < 0.1
  269. assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1
  270. @pytest.mark.skipif(
  271. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  272. )
  273. def test_BetaRNG():
  274. m1 = RNG(seed=111, device="xpu0")
  275. m2 = RNG(seed=111, device="xpu1")
  276. m3 = RNG(seed=222, device="xpu0")
  277. out1 = m1.beta(2, 1, size=(100,))
  278. out1_ = m1.uniform(size=(100,))
  279. out2 = m2.beta(2, 1, size=(100,))
  280. out3 = m3.beta(2, 1, size=(100,))
  281. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  282. assert out1.device == "xpu0" and out2.device == "xpu1"
  283. assert not (out1.numpy() == out3.numpy()).all()
  284. assert not (out1.numpy() == out1_.numpy()).all()
  285. alpha = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0")
  286. beta = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0")
  287. expected_mean = (alpha / (alpha + beta)).numpy()
  288. expected_std = (
  289. F.sqrt(alpha * beta / (F.pow(alpha + beta, 2) * (alpha + beta + 1)))
  290. ).numpy()
  291. out = m1.beta(alpha=alpha, beta=beta, size=(20, 30))
  292. out_shp = out.shape
  293. if isinstance(out_shp, tuple):
  294. assert out_shp == (20, 30, 2, 3)
  295. else:
  296. assert all(out.shape.numpy() == np.array([20, 30, 2, 3]))
  297. assert (
  298. np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std
  299. ).mean() < 0.1
  300. assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1
  301. @pytest.mark.skipif(
  302. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  303. )
  304. def test_PoissonRNG():
  305. m1 = RNG(seed=111, device="xpu0")
  306. m2 = RNG(seed=111, device="xpu1")
  307. m3 = RNG(seed=222, device="xpu0")
  308. lam = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32)
  309. out1 = m1.poisson(lam.to("xpu0"), size=(100,))
  310. out2 = m2.poisson(lam.to("xpu1"), size=(100,))
  311. out3 = m3.poisson(lam.to("xpu0"), size=(100,))
  312. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  313. assert out1.device == "xpu0" and out2.device == "xpu1"
  314. assert not (out1.numpy() == out3.numpy()).all()
  315. out = m1.poisson(lam.to("xpu0"), size=(20, 30))
  316. out_shp = out.shape
  317. expected_shape = (20, 30) + lam._tuple_shape
  318. if isinstance(out_shp, tuple):
  319. assert out_shp == expected_shape
  320. else:
  321. assert all(out.shape.numpy() == np.array(expected_shape))
  322. lam = lam.numpy()
  323. assert (np.abs(out.mean(axis=(0, 1)).numpy() - lam) / np.sqrt(lam)).mean() < 0.1
  324. assert np.abs(np.std(out.numpy(), axis=(0, 1)) - np.sqrt(lam)).mean() < 0.1
  325. @pytest.mark.skipif(
  326. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  327. )
  328. @pytest.mark.parametrize("symbolic", [True, False])
  329. def test_PermutationRNG(symbolic):
  330. m1 = RNG(seed=111, device="xpu0")
  331. m2 = RNG(seed=111, device="xpu1")
  332. m3 = RNG(seed=222, device="xpu0")
  333. out1 = m1.permutation(1000)
  334. out1_ = m1.uniform(size=(1000,))
  335. out2 = m2.permutation(1000)
  336. out3 = m3.permutation(1000)
  337. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  338. assert out1.device == "xpu0" and out2.device == "xpu1"
  339. assert not (out1.numpy() == out3.numpy()).all()
  340. assert not (out1.numpy() == out1_.numpy()).all()
  341. out = m1.permutation(1000)
  342. out_shp = out.shape
  343. if isinstance(out_shp, tuple):
  344. assert out_shp == (1000,)
  345. else:
  346. assert all(out.shape.numpy() == np.array([1000]))
  347. def sum_result(res, fun):
  348. return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))])
  349. assert sum_result(out, lambda x: x) < 500
  350. assert sum_result(out, np.sort) == 1000
  351. def func():
  352. out = m1.permutation(Tensor(7))
  353. out_shp = out.shape
  354. if isinstance(out_shp, tuple):
  355. assert out_shp == (1,)
  356. else:
  357. assert all(out.shape.numpy() == np.array([1]))
  358. n, m = 6, 3
  359. out = m1.permutation(Tensor(np.arange(n * m), dtype="float32").reshape(n, m))
  360. out_shp = out.shape
  361. if isinstance(out_shp, tuple):
  362. assert out_shp == (n, m)
  363. else:
  364. assert all(out.shape.numpy() == np.array([n, m]))
  365. func = trace(symbolic=symbolic)(func)
  366. func()
  367. @pytest.mark.skipif(
  368. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  369. )
  370. def test_ShuffleRNG():
  371. g = []
  372. def cb(grad):
  373. g.append(grad)
  374. n, m = 6, 3
  375. arr = np.arange(n * m)
  376. out0 = Tensor(arr, dtype="float32")
  377. with Grad() as grad:
  378. grad.wrt(out0, callback=cb)
  379. random.shuffle(out0)
  380. grad(out0, F.ones_like(out0))
  381. m1 = RNG(seed=111, device="xpu0")
  382. m2 = RNG(seed=111, device="xpu1")
  383. m3 = RNG(seed=222, device="xpu0")
  384. out1 = Tensor(arr, dtype="float32", device="xpu0")
  385. out2 = Tensor(arr, dtype="float32", device="xpu1")
  386. out3 = Tensor(arr, dtype="float32", device="xpu0")
  387. m1.shuffle(out1)
  388. m2.shuffle(out2)
  389. m3.shuffle(out3)
  390. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  391. assert out1.device == "xpu0" and out2.device == "xpu1"
  392. assert not (out1.numpy() == out3.numpy()).all()
  393. out = Tensor(arr, dtype="float32").reshape(n, m)
  394. m1.shuffle(out)
  395. out_shp = out.shape
  396. if isinstance(out_shp, tuple):
  397. assert out_shp == (n, m)
  398. else:
  399. assert all(out.shape.numpy() == np.array([n, m]))
  400. def test_seed():
  401. set_global_seed(10)
  402. out1 = uniform(size=[10, 10])
  403. out2 = uniform(size=[10, 10])
  404. assert not (out1.numpy() == out2.numpy()).all()
  405. set_global_seed(10)
  406. out3 = uniform(size=[10, 10])
  407. np.testing.assert_allclose(out1.numpy(), out3.numpy(), atol=1e-6)
  408. set_global_seed(11)
  409. out4 = uniform(size=[10, 10])
  410. assert not (out1.numpy() == out4.numpy()).all()
  411. @pytest.mark.parametrize("is_symbolic", [None, False, True])
  412. def test_rng_empty_tensor(is_symbolic):
  413. set_global_seed(1024)
  414. shapes = [
  415. (0,),
  416. (0, 0, 0),
  417. (10, 0, 10),
  418. ]
  419. def fn(shape):
  420. o1 = random.uniform(0, 1, shape)
  421. o2 = random.normal(0, 1, shape)
  422. o3 = random.gamma(2, 1, shape)
  423. o4 = random.beta(2, 1, shape)
  424. o5 = random.poisson(2, shape)
  425. return o1, o2, o3, o4, o5
  426. for shape in shapes:
  427. if is_symbolic is not None:
  428. fn_ = jit.trace(symbolic=is_symbolic)(fn)
  429. else:
  430. fn_ = fn
  431. for _ in range(3):
  432. outs = fn_(shape)
  433. for out in outs:
  434. np.testing.assert_equal(out.numpy().shape, shape)
  435. if is_symbolic is None:
  436. break
  437. def fn2(n):
  438. return random.permutation(n=n)
  439. if is_symbolic is not None:
  440. fn2 = jit.trace(symbolic=is_symbolic)(fn2)
  441. for _ in range(3):
  442. out = fn2(0)
  443. np.testing.assert_equal(out.numpy().shape, (0,))
  444. if is_symbolic is None:
  445. break