
test_rng.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
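
# Tests for MegEngine random number generation: the low-level builtin RNG ops
# (GaussianRNG, UniformRNG, GammaRNG, BetaRNG, PoissonRNG, PermutationRNG)
# applied through the imperative runtime, the high-level megengine.random.RNG
# interface, global seeding, and empty-tensor / traced behaviour.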

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
from megengine import Tensor, jit, random
from megengine.core._imperative_rt import CompNode
from megengine.core._imperative_rt.core2 import apply
from megengine.core._imperative_rt.ops import (
    delete_rng_handle,
    get_global_rng_seed,
    new_rng_handle,
)
from megengine.core.autodiff.grad import Grad
from megengine.core.ops.builtin import (
    BetaRNG,
    GammaRNG,
    GaussianRNG,
    PermutationRNG,
    PoissonRNG,
    UniformRNG,
)
from megengine.device import get_device_count
from megengine.jit import trace
from megengine.random import RNG
from megengine.random import seed as set_global_seed
from megengine.random import uniform
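

# The test_*_op functions below exercise the builtin RNG ops directly via
# apply(): first with the global RNG seed on the default "xpux" device, then
# with an explicitly created handle on "xpu2" (released with delete_rng_handle).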
@pytest.mark.skipif(
    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_gaussian_op():
    # FIXME: remove this sync
    mge.core.set_option("async_level", 0)
    set_global_seed(1024)
    shape = (
        8,
        9,
        11,
        12,
    )
    shape = Tensor(shape, dtype="int32")
    op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0, dtype="float32")
    (output,) = apply(op, shape)
    assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
    assert np.fabs(np.sqrt(output.numpy().var()) - 3.0) < 1e-1
    assert str(output.device) == str(CompNode("xpux"))
    assert output.dtype == np.float32

    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    op = GaussianRNG(seed=seed, mean=3.0, std=1.0, dtype="float32", handle=h)
    (output,) = apply(op, shape)
    delete_rng_handle(h)
    assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
    assert np.fabs(np.sqrt(output.numpy().var()) - 1.0) < 1e-1
    assert str(output.device) == str(cn)
    assert output.dtype == np.float32


@pytest.mark.skipif(
    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_uniform_op():
    set_global_seed(1024)
    shape = (
        8,
        9,
        11,
        12,
    )
    shape = Tensor(shape, dtype="int32")
    op = UniformRNG(seed=get_global_rng_seed(), dtype="float32")
    (output,) = apply(op, shape)
    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
    assert str(output.device) == str(CompNode("xpux"))
    assert output.dtype == np.float32

    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    op = UniformRNG(seed=seed, dtype="float32", handle=h)
    (output,) = apply(op, shape)
    delete_rng_handle(h)
    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
    assert str(output.device) == str(cn)
    assert output.dtype == np.float32
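

# Gamma(shape=k, scale=s) has mean k*s and standard deviation sqrt(k)*s,
# which is what the sample statistics are checked against below.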
@pytest.mark.skipif(
    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_gamma_op():
    set_global_seed(1024)
    _shape, _scale = 2, 0.8
    _expected_mean, _expected_std = _shape * _scale, np.sqrt(_shape) * _scale
    shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32")
    scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32")
    op = GammaRNG(seed=get_global_rng_seed(), handle=0)
    (output,) = apply(op, shape, scale)
    assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
    assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
    assert str(output.device) == str(CompNode("xpux"))

    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32", device="xpu2")
    scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32", device="xpu2")
    op = GammaRNG(seed=seed, handle=h)
    (output,) = apply(op, shape, scale)
    delete_rng_handle(h)
    assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
    assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
    assert str(output.device) == str(cn)
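

# Beta(alpha, beta) has mean alpha / (alpha + beta) and standard deviation
# sqrt(alpha * beta / ((alpha + beta) ** 2 * (alpha + beta + 1))).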
@pytest.mark.skipif(
    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_beta_op():
    set_global_seed(1024)
    _alpha, _beta = 2, 0.8
    _expected_mean = _alpha / (_alpha + _beta)
    _expected_std = np.sqrt(
        _alpha * _beta / ((_alpha + _beta) ** 2 * (_alpha + _beta + 1))
    )
    alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32")
    beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32")
    op = BetaRNG(seed=get_global_rng_seed())
    (output,) = apply(op, alpha, beta)
    assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
    assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
    assert str(output.device) == str(CompNode("xpux"))

    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32", device=cn)
    beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32", device=cn)
    op = BetaRNG(seed=seed, handle=h)
    (output,) = apply(op, alpha, beta)
    delete_rng_handle(h)
    assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
    assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
    assert str(output.device) == str(cn)
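

# Poisson(lam) has mean lam and standard deviation sqrt(lam); here lam = 2.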
@pytest.mark.skipif(
    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_poisson_op():
    set_global_seed(1024)
    lam = F.full([8, 9, 11, 12], value=2, dtype="float32")
    op = PoissonRNG(seed=get_global_rng_seed())
    (output,) = apply(op, lam)
    assert np.fabs(output.numpy().mean() - 2.0) < 1e-1
    assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1
    assert str(output.device) == str(CompNode("xpux"))

    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    lam = F.full([8, 9, 11, 12], value=2, dtype="float32", device=cn)
    op = PoissonRNG(seed=seed, handle=h)
    (output,) = apply(op, lam)
    delete_rng_handle(h)
    assert np.fabs(output.numpy().mean() - 2.0) < 1e-1
    assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1
    assert str(output.device) == str(cn)
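

# A random permutation of arange(n) should leave far fewer than half of the
# positions fixed, and sorting it must recover the identity permutation.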
@pytest.mark.skipif(
    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_permutation_op():
    set_global_seed(1024)
    n = 1000

    def test_permutation_op_dtype(dtype):
        def sum_result(res, fun):
            return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))])

        shape = Tensor((n,), dtype="int32")
        op = PermutationRNG(seed=get_global_rng_seed(), dtype=dtype)
        (output,) = apply(op, shape)
        assert sum_result(output, lambda x: x) < 500
        assert sum_result(output, np.sort) == n
        assert str(output.device) == str(CompNode("xpux"))
        assert output.dtype == dtype

        cn = CompNode("xpu2")
        seed = 233333
        h = new_rng_handle(cn, seed)
        op = PermutationRNG(seed=seed, handle=h, dtype=dtype)
        (output,) = apply(op, shape)
        delete_rng_handle(h)
        assert sum_result(output, lambda x: x) < 500
        assert sum_result(output, np.sort) == n
        assert str(output.device) == str(cn)
        assert output.dtype == dtype

    test_permutation_op_dtype(np.float32)
    test_permutation_op_dtype(np.int32)
    test_permutation_op_dtype(np.int16)
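

# The remaining tests use the high-level megengine.random.RNG interface.
# Generators built with the same seed must agree across devices, generators
# with different seeds (or repeated draws) must differ, and outputs must have
# the requested shape and plausible sample statistics.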
@pytest.mark.skipif(
    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
)
def test_UniformRNG():
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = m1.uniform(size=(100,))
    out1_ = m1.uniform(size=(100,))
    out2 = m2.uniform(size=(100,))
    out3 = m3.uniform(size=(100,))
    np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()
    assert not (out1.numpy() == out1_.numpy()).all()

    low = -234
    high = 123
    out = m1.uniform(low=low, high=high, size=(20, 30, 40))
    out_shp = out.shape
    if isinstance(out_shp, tuple):
        assert out_shp == (20, 30, 40)
    else:
        assert all(out.shape.numpy() == np.array([20, 30, 40]))
    assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1


@pytest.mark.skipif(
    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
)
def test_NormalRNG():
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = m1.normal(size=(100,))
    out1_ = m1.uniform(size=(100,))
    out2 = m2.normal(size=(100,))
    out3 = m3.normal(size=(100,))
    np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()
    assert not (out1.numpy() == out1_.numpy()).all()

    mean = -1
    std = 2
    out = m1.normal(mean=mean, std=std, size=(20, 30, 40))
    out_shp = out.shape
    if isinstance(out_shp, tuple):
        assert out_shp == (20, 30, 40)
    else:
        assert all(out.shape.numpy() == np.array([20, 30, 40]))
    assert np.abs(out.mean().numpy() - mean) / std < 0.1
    assert np.abs(np.std(out.numpy()) - std) < 0.1


@pytest.mark.skipif(
    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
)
def test_GammaRNG():
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = m1.gamma(2, size=(100,))
    out1_ = m1.uniform(size=(100,))
    out2 = m2.gamma(2, size=(100,))
    out3 = m3.gamma(2, size=(100,))
    np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()
    assert not (out1.numpy() == out1_.numpy()).all()

    shape = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0")
    scale = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0")
    expected_mean = (shape * scale).numpy()
    expected_std = (F.sqrt(shape) * scale).numpy()
    out = m1.gamma(shape=shape, scale=scale, size=(20, 30, 40))
    out_shp = out.shape
    if isinstance(out_shp, tuple):
        assert out_shp == (20, 30, 40, 2, 3)
    else:
        assert all(out.shape.numpy() == np.array([20, 30, 40, 2, 3]))
    assert (
        np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std
    ).mean() < 0.1
    assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1


@pytest.mark.skipif(
    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
)
def test_BetaRNG():
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = m1.beta(2, 1, size=(100,))
    out1_ = m1.uniform(size=(100,))
    out2 = m2.beta(2, 1, size=(100,))
    out3 = m3.beta(2, 1, size=(100,))
    np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()
    assert not (out1.numpy() == out1_.numpy()).all()

    alpha = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0")
    beta = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0")
    expected_mean = (alpha / (alpha + beta)).numpy()
    expected_std = (
        F.sqrt(alpha * beta / (F.pow(alpha + beta, 2) * (alpha + beta + 1)))
    ).numpy()
    out = m1.beta(alpha=alpha, beta=beta, size=(20, 30))
    out_shp = out.shape
    if isinstance(out_shp, tuple):
        assert out_shp == (20, 30, 2, 3)
    else:
        assert all(out.shape.numpy() == np.array([20, 30, 2, 3]))
    assert (
        np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std
    ).mean() < 0.1
    assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1


@pytest.mark.skipif(
    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
)
def test_PoissonRNG():
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    lam = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32)
    out1 = m1.poisson(lam.to("xpu0"), size=(100,))
    out2 = m2.poisson(lam.to("xpu1"), size=(100,))
    out3 = m3.poisson(lam.to("xpu0"), size=(100,))
    np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()

    out = m1.poisson(lam.to("xpu0"), size=(20, 30))
    out_shp = out.shape
    expected_shape = (20, 30) + lam._tuple_shape
    if isinstance(out_shp, tuple):
        assert out_shp == expected_shape
    else:
        assert all(out.shape.numpy() == np.array(expected_shape))
    lam = lam.numpy()
    assert (np.abs(out.mean(axis=(0, 1)).numpy() - lam) / np.sqrt(lam)).mean() < 0.1
    assert np.abs(np.std(out.numpy(), axis=(0, 1)) - np.sqrt(lam)).mean() < 0.1
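

# permutation() is also checked under jit.trace in both symbolic and
# non-symbolic modes, including scalar and 2-D tensor inputs.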
@pytest.mark.skipif(
    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
)
@pytest.mark.parametrize("symbolic", [True, False])
def test_PermutationRNG(symbolic):
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = m1.permutation(1000)
    out1_ = m1.uniform(size=(1000,))
    out2 = m2.permutation(1000)
    out3 = m3.permutation(1000)
    np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()
    assert not (out1.numpy() == out1_.numpy()).all()

    out = m1.permutation(1000)
    out_shp = out.shape
    if isinstance(out_shp, tuple):
        assert out_shp == (1000,)
    else:
        assert all(out.shape.numpy() == np.array([1000]))

    def sum_result(res, fun):
        return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))])

    assert sum_result(out, lambda x: x) < 500
    assert sum_result(out, np.sort) == 1000

    def func():
        out = m1.permutation(Tensor(7))
        out_shp = out.shape
        if isinstance(out_shp, tuple):
            assert out_shp == (1,)
        else:
            assert all(out.shape.numpy() == np.array([1]))
        n, m = 6, 3
        out = m1.permutation(Tensor(np.arange(n * m), dtype="float32").reshape(n, m))
        out_shp = out.shape
        if isinstance(out_shp, tuple):
            assert out_shp == (n, m)
        else:
            assert all(out.shape.numpy() == np.array([n, m]))

    func = trace(symbolic=symbolic)(func)
    func()
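

# shuffle() works in place; the Grad block checks that gradients propagate
# through it, and same-seed generators must produce the same shuffle across
# devices.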
@pytest.mark.skipif(
    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
)
def test_ShuffleRNG():
    g = []

    def cb(grad):
        g.append(grad)

    n, m = 6, 3
    arr = np.arange(n * m)
    out0 = Tensor(arr, dtype="float32")
    with Grad() as grad:
        grad.wrt(out0, callback=cb)
        random.shuffle(out0)
        grad(out0, F.ones_like(out0))

    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = Tensor(arr, dtype="float32", device="xpu0")
    out2 = Tensor(arr, dtype="float32", device="xpu1")
    out3 = Tensor(arr, dtype="float32", device="xpu0")
    m1.shuffle(out1)
    m2.shuffle(out2)
    m3.shuffle(out3)
    np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()

    out = Tensor(arr, dtype="float32").reshape(n, m)
    m1.shuffle(out)
    out_shp = out.shape
    if isinstance(out_shp, tuple):
        assert out_shp == (n, m)
    else:
        assert all(out.shape.numpy() == np.array([n, m]))
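

# Re-seeding the global generator with the same value must reproduce the same
# stream; a different seed (or no re-seed) must give different samples.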
def test_seed():
    set_global_seed(10)
    out1 = uniform(size=[10, 10])
    out2 = uniform(size=[10, 10])
    assert not (out1.numpy() == out2.numpy()).all()

    set_global_seed(10)
    out3 = uniform(size=[10, 10])
    np.testing.assert_allclose(out1.numpy(), out3.numpy(), atol=1e-6)

    set_global_seed(11)
    out4 = uniform(size=[10, 10])
    assert not (out1.numpy() == out4.numpy()).all()
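

# Sampling into empty shapes (any zero-sized dimension) must return tensors of
# exactly that shape, both eagerly and under jit.trace; async_level is restored
# to 2 at the end because test_gaussian_op lowered it to 0.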
@pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_rng_empty_tensor(is_symbolic):
    set_global_seed(1024)
    shapes = [
        (0,),
        (0, 0, 0),
        (10, 0, 10),
    ]

    def fn(shape):
        o1 = random.uniform(0, 1, shape)
        o2 = random.normal(0, 1, shape)
        o3 = random.gamma(2, 1, shape)
        o4 = random.beta(2, 1, shape)
        o5 = random.poisson(2, shape)
        return o1, o2, o3, o4, o5

    for shape in shapes:
        if is_symbolic is not None:
            fn_ = jit.trace(symbolic=is_symbolic)(fn)
        else:
            fn_ = fn
        for _ in range(3):
            outs = fn_(shape)
            for out in outs:
                np.testing.assert_equal(out.numpy().shape, shape)
            if is_symbolic is None:
                break

    def fn2(n):
        return random.permutation(n=n)

    if is_symbolic is not None:
        fn2 = jit.trace(symbolic=is_symbolic)(fn2)
    for _ in range(3):
        out = fn2(0)
        np.testing.assert_equal(out.numpy().shape, (0,))
        if is_symbolic is None:
            break

    mge.core.set_option("async_level", 2)