You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_rng.py 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509
  1. # -*- coding: utf-8 -*-
  2. import numpy as np
  3. import pytest
  4. import megengine as mge
  5. import megengine.functional as F
  6. from megengine import Tensor, jit, random
  7. from megengine.core._imperative_rt import CompNode
  8. from megengine.core._imperative_rt.core2 import apply
  9. from megengine.core._imperative_rt.ops import (
  10. delete_rng_handle,
  11. get_global_rng_seed,
  12. new_rng_handle,
  13. )
  14. from megengine.core.autodiff.grad import Grad
  15. from megengine.core.ops.builtin import (
  16. BetaRNG,
  17. GammaRNG,
  18. GaussianRNG,
  19. PermutationRNG,
  20. PoissonRNG,
  21. UniformRNG,
  22. )
  23. from megengine.device import get_device_count
  24. from megengine.jit import trace
  25. from megengine.random import RNG
  26. from megengine.random import seed as set_global_seed
  27. from megengine.random import uniform
  28. @pytest.mark.skipif(
  29. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  30. )
  31. def test_gaussian_op():
  32. set_global_seed(1024)
  33. shape = (
  34. 8,
  35. 9,
  36. 11,
  37. 12,
  38. )
  39. shape = Tensor(shape, dtype="int32")
  40. op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0, dtype="float32")
  41. (output,) = apply(op, shape)
  42. assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
  43. assert np.fabs(np.sqrt(output.numpy().var()) - 3.0) < 1e-1
  44. assert str(output.device) == str(CompNode("xpux"))
  45. assert output.dtype == np.float32
  46. cn = CompNode("xpu2")
  47. seed = 233333
  48. h = new_rng_handle(cn, seed)
  49. op = GaussianRNG(seed=seed, mean=3.0, std=1.0, dtype="float32", handle=h)
  50. (output,) = apply(op, shape)
  51. delete_rng_handle(h)
  52. assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
  53. assert np.fabs(np.sqrt(output.numpy().var()) - 1.0) < 1e-1
  54. assert str(output.device) == str(cn)
  55. assert output.dtype == np.float32
  56. @pytest.mark.skipif(
  57. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  58. )
  59. def test_uniform_op():
  60. set_global_seed(1024)
  61. shape = (
  62. 8,
  63. 9,
  64. 11,
  65. 12,
  66. )
  67. shape = Tensor(shape, dtype="int32")
  68. op = UniformRNG(seed=get_global_rng_seed(), dtype="float32")
  69. (output,) = apply(op, shape)
  70. assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
  71. assert str(output.device) == str(CompNode("xpux"))
  72. assert output.dtype == np.float32
  73. cn = CompNode("xpu2")
  74. seed = 233333
  75. h = new_rng_handle(cn, seed)
  76. op = UniformRNG(seed=seed, dtype="float32", handle=h)
  77. (output,) = apply(op, shape)
  78. delete_rng_handle(h)
  79. assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
  80. assert str(output.device) == str(cn)
  81. assert output.dtype == np.float32
  82. @pytest.mark.skipif(
  83. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  84. )
  85. def test_gamma_op():
  86. set_global_seed(1024)
  87. _shape, _scale = 2, 0.8
  88. _expected_mean, _expected_std = _shape * _scale, np.sqrt(_shape) * _scale
  89. shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32")
  90. scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32")
  91. op = GammaRNG(seed=get_global_rng_seed(), handle=0)
  92. (output,) = apply(op, shape, scale)
  93. assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
  94. assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
  95. assert str(output.device) == str(CompNode("xpux"))
  96. cn = CompNode("xpu2")
  97. seed = 233333
  98. h = new_rng_handle(cn, seed)
  99. shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32", device="xpu2")
  100. scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32", device="xpu2")
  101. op = GammaRNG(seed=seed, handle=h)
  102. (output,) = apply(op, shape, scale)
  103. delete_rng_handle(h)
  104. assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
  105. assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
  106. assert str(output.device) == str(cn)
  107. @pytest.mark.skipif(
  108. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  109. )
  110. def test_beta_op():
  111. set_global_seed(1024)
  112. _alpha, _beta = 2, 0.8
  113. _expected_mean = _alpha / (_alpha + _beta)
  114. _expected_std = np.sqrt(
  115. _alpha * _beta / ((_alpha + _beta) ** 2 * (_alpha + _beta + 1))
  116. )
  117. alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32")
  118. beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32")
  119. op = BetaRNG(seed=get_global_rng_seed())
  120. (output,) = apply(op, alpha, beta)
  121. assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
  122. assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
  123. assert str(output.device) == str(CompNode("xpux"))
  124. cn = CompNode("xpu2")
  125. seed = 233333
  126. h = new_rng_handle(cn, seed)
  127. alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32", device=cn)
  128. beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32", device=cn)
  129. op = BetaRNG(seed=seed, handle=h)
  130. (output,) = apply(op, alpha, beta)
  131. delete_rng_handle(h)
  132. assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
  133. assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
  134. assert str(output.device) == str(cn)
  135. @pytest.mark.skipif(
  136. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  137. )
  138. def test_poisson_op():
  139. set_global_seed(1024)
  140. lam = F.full([8, 9, 11, 12], value=2, dtype="float32")
  141. op = PoissonRNG(seed=get_global_rng_seed())
  142. (output,) = apply(op, lam)
  143. assert np.fabs(output.numpy().mean() - 2.0) < 1e-1
  144. assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1
  145. assert str(output.device) == str(CompNode("xpux"))
  146. cn = CompNode("xpu2")
  147. seed = 233333
  148. h = new_rng_handle(cn, seed)
  149. lam = F.full([8, 9, 11, 12], value=2, dtype="float32", device=cn)
  150. op = PoissonRNG(seed=seed, handle=h)
  151. (output,) = apply(op, lam)
  152. delete_rng_handle(h)
  153. assert np.fabs(output.numpy().mean() - 2.0) < 1e-1
  154. assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1
  155. assert str(output.device) == str(cn)
  156. @pytest.mark.skipif(
  157. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  158. )
  159. def test_permutation_op():
  160. set_global_seed(1024)
  161. n = 1000
  162. def test_permutation_op_dtype(dtype):
  163. def sum_result(res, fun):
  164. return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))])
  165. shape = Tensor((n,), dtype="int32")
  166. op = PermutationRNG(seed=get_global_rng_seed(), dtype=dtype)
  167. (output,) = apply(op, shape)
  168. assert sum_result(output, lambda x: x) < 500
  169. assert sum_result(output, np.sort) == n
  170. assert str(output.device) == str(CompNode("xpux"))
  171. assert output.dtype == dtype
  172. cn = CompNode("xpu2")
  173. seed = 233333
  174. h = new_rng_handle(cn, seed)
  175. op = PermutationRNG(seed=seed, handle=h, dtype=dtype)
  176. (output,) = apply(op, shape)
  177. delete_rng_handle(h)
  178. assert sum_result(output, lambda x: x) < 500
  179. assert sum_result(output, np.sort) == n
  180. assert str(output.device) == str(cn)
  181. assert output.dtype == dtype
  182. test_permutation_op_dtype(np.float32)
  183. test_permutation_op_dtype(np.int32)
  184. test_permutation_op_dtype(np.int16)
  185. @pytest.mark.skipif(
  186. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  187. )
  188. def test_UniformRNG():
  189. m1 = RNG(seed=111, device="xpu0")
  190. m2 = RNG(seed=111, device="xpu1")
  191. m3 = RNG(seed=222, device="xpu0")
  192. out1 = m1.uniform(size=(100,))
  193. out1_ = m1.uniform(size=(100,))
  194. out2 = m2.uniform(size=(100,))
  195. out3 = m3.uniform(size=(100,))
  196. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  197. assert out1.device == "xpu0" and out2.device == "xpu1"
  198. assert not (out1.numpy() == out3.numpy()).all()
  199. assert not (out1.numpy() == out1_.numpy()).all()
  200. low = -234
  201. high = 123
  202. out = m1.uniform(low=low, high=high, size=(20, 30, 40))
  203. out_shp = out.shape
  204. if isinstance(out_shp, tuple):
  205. assert out_shp == (20, 30, 40)
  206. else:
  207. assert all(out.shape.numpy() == np.array([20, 30, 40]))
  208. assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1
  209. @pytest.mark.skipif(
  210. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  211. )
  212. def test_NormalRNG():
  213. m1 = RNG(seed=111, device="xpu0")
  214. m2 = RNG(seed=111, device="xpu1")
  215. m3 = RNG(seed=222, device="xpu0")
  216. out1 = m1.normal(size=(100,))
  217. out1_ = m1.uniform(size=(100,))
  218. out2 = m2.normal(size=(100,))
  219. out3 = m3.normal(size=(100,))
  220. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  221. assert out1.device == "xpu0" and out2.device == "xpu1"
  222. assert not (out1.numpy() == out3.numpy()).all()
  223. assert not (out1.numpy() == out1_.numpy()).all()
  224. mean = -1
  225. std = 2
  226. out = m1.normal(mean=mean, std=std, size=(20, 30, 40))
  227. out_shp = out.shape
  228. if isinstance(out_shp, tuple):
  229. assert out_shp == (20, 30, 40)
  230. else:
  231. assert all(out.shape.numpy() == np.array([20, 30, 40]))
  232. assert np.abs(out.mean().numpy() - mean) / std < 0.1
  233. assert np.abs(np.std(out.numpy()) - std) < 0.1
  234. @pytest.mark.skipif(
  235. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  236. )
  237. def test_GammaRNG():
  238. m1 = RNG(seed=111, device="xpu0")
  239. m2 = RNG(seed=111, device="xpu1")
  240. m3 = RNG(seed=222, device="xpu0")
  241. out1 = m1.gamma(2, size=(100,))
  242. out1_ = m1.uniform(size=(100,))
  243. out2 = m2.gamma(2, size=(100,))
  244. out3 = m3.gamma(2, size=(100,))
  245. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  246. assert out1.device == "xpu0" and out2.device == "xpu1"
  247. assert not (out1.numpy() == out3.numpy()).all()
  248. assert not (out1.numpy() == out1_.numpy()).all()
  249. shape = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0")
  250. scale = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0")
  251. expected_mean = (shape * scale).numpy()
  252. expected_std = (F.sqrt(shape) * scale).numpy()
  253. out = m1.gamma(shape=shape, scale=scale, size=(20, 30, 40))
  254. out_shp = out.shape
  255. if isinstance(out_shp, tuple):
  256. assert out_shp == (20, 30, 40, 2, 3)
  257. else:
  258. assert all(out.shape.numpy() == np.array([20, 30, 40, 2, 3]))
  259. assert (
  260. np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std
  261. ).mean() < 0.1
  262. assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1
  263. @pytest.mark.skipif(
  264. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  265. )
  266. def test_BetaRNG():
  267. m1 = RNG(seed=111, device="xpu0")
  268. m2 = RNG(seed=111, device="xpu1")
  269. m3 = RNG(seed=222, device="xpu0")
  270. out1 = m1.beta(2, 1, size=(100,))
  271. out1_ = m1.uniform(size=(100,))
  272. out2 = m2.beta(2, 1, size=(100,))
  273. out3 = m3.beta(2, 1, size=(100,))
  274. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  275. assert out1.device == "xpu0" and out2.device == "xpu1"
  276. assert not (out1.numpy() == out3.numpy()).all()
  277. assert not (out1.numpy() == out1_.numpy()).all()
  278. alpha = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0")
  279. beta = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0")
  280. expected_mean = (alpha / (alpha + beta)).numpy()
  281. expected_std = (
  282. F.sqrt(alpha * beta / (F.pow(alpha + beta, 2) * (alpha + beta + 1)))
  283. ).numpy()
  284. out = m1.beta(alpha=alpha, beta=beta, size=(20, 30))
  285. out_shp = out.shape
  286. if isinstance(out_shp, tuple):
  287. assert out_shp == (20, 30, 2, 3)
  288. else:
  289. assert all(out.shape.numpy() == np.array([20, 30, 2, 3]))
  290. assert (
  291. np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std
  292. ).mean() < 0.1
  293. assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1
  294. @pytest.mark.skipif(
  295. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  296. )
  297. def test_PoissonRNG():
  298. m1 = RNG(seed=111, device="xpu0")
  299. m2 = RNG(seed=111, device="xpu1")
  300. m3 = RNG(seed=222, device="xpu0")
  301. lam = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32)
  302. out1 = m1.poisson(lam.to("xpu0"), size=(100,))
  303. out2 = m2.poisson(lam.to("xpu1"), size=(100,))
  304. out3 = m3.poisson(lam.to("xpu0"), size=(100,))
  305. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  306. assert out1.device == "xpu0" and out2.device == "xpu1"
  307. assert not (out1.numpy() == out3.numpy()).all()
  308. out = m1.poisson(lam.to("xpu0"), size=(20, 30))
  309. out_shp = out.shape
  310. expected_shape = (20, 30) + lam._tuple_shape
  311. if isinstance(out_shp, tuple):
  312. assert out_shp == expected_shape
  313. else:
  314. assert all(out.shape.numpy() == np.array(expected_shape))
  315. lam = lam.numpy()
  316. assert (np.abs(out.mean(axis=(0, 1)).numpy() - lam) / np.sqrt(lam)).mean() < 0.1
  317. assert np.abs(np.std(out.numpy(), axis=(0, 1)) - np.sqrt(lam)).mean() < 0.1
  318. @pytest.mark.skipif(
  319. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  320. )
  321. @pytest.mark.parametrize("symbolic", [True, False])
  322. def test_PermutationRNG(symbolic):
  323. m1 = RNG(seed=111, device="xpu0")
  324. m2 = RNG(seed=111, device="xpu1")
  325. m3 = RNG(seed=222, device="xpu0")
  326. out1 = m1.permutation(1000)
  327. out1_ = m1.uniform(size=(1000,))
  328. out2 = m2.permutation(1000)
  329. out3 = m3.permutation(1000)
  330. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  331. assert out1.device == "xpu0" and out2.device == "xpu1"
  332. assert not (out1.numpy() == out3.numpy()).all()
  333. assert not (out1.numpy() == out1_.numpy()).all()
  334. out = m1.permutation(1000)
  335. out_shp = out.shape
  336. if isinstance(out_shp, tuple):
  337. assert out_shp == (1000,)
  338. else:
  339. assert all(out.shape.numpy() == np.array([1000]))
  340. def sum_result(res, fun):
  341. return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))])
  342. assert sum_result(out, lambda x: x) < 500
  343. assert sum_result(out, np.sort) == 1000
  344. def func():
  345. out = m1.permutation(Tensor(7))
  346. out_shp = out.shape
  347. if isinstance(out_shp, tuple):
  348. assert out_shp == (1,)
  349. else:
  350. assert all(out.shape.numpy() == np.array([1]))
  351. n, m = 6, 3
  352. out = m1.permutation(Tensor(np.arange(n * m), dtype="float32").reshape(n, m))
  353. out_shp = out.shape
  354. if isinstance(out_shp, tuple):
  355. assert out_shp == (n, m)
  356. else:
  357. assert all(out.shape.numpy() == np.array([n, m]))
  358. func = trace(symbolic=symbolic)(func)
  359. func()
  360. @pytest.mark.skipif(
  361. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  362. )
  363. def test_ShuffleRNG():
  364. g = []
  365. def cb(grad):
  366. g.append(grad)
  367. n, m = 6, 3
  368. arr = np.arange(n * m)
  369. out0 = Tensor(arr, dtype="float32")
  370. with Grad() as grad:
  371. grad.wrt(out0, callback=cb)
  372. random.shuffle(out0)
  373. grad(out0, F.ones_like(out0))
  374. m1 = RNG(seed=111, device="xpu0")
  375. m2 = RNG(seed=111, device="xpu1")
  376. m3 = RNG(seed=222, device="xpu0")
  377. out1 = Tensor(arr, dtype="float32", device="xpu0")
  378. out2 = Tensor(arr, dtype="float32", device="xpu1")
  379. out3 = Tensor(arr, dtype="float32", device="xpu0")
  380. m1.shuffle(out1)
  381. m2.shuffle(out2)
  382. m3.shuffle(out3)
  383. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  384. assert out1.device == "xpu0" and out2.device == "xpu1"
  385. assert not (out1.numpy() == out3.numpy()).all()
  386. out = Tensor(arr, dtype="float32").reshape(n, m)
  387. m1.shuffle(out)
  388. out_shp = out.shape
  389. if isinstance(out_shp, tuple):
  390. assert out_shp == (n, m)
  391. else:
  392. assert all(out.shape.numpy() == np.array([n, m]))
  393. def test_seed():
  394. set_global_seed(10)
  395. out1 = uniform(size=[10, 10])
  396. out2 = uniform(size=[10, 10])
  397. assert not (out1.numpy() == out2.numpy()).all()
  398. set_global_seed(10)
  399. out3 = uniform(size=[10, 10])
  400. np.testing.assert_allclose(out1.numpy(), out3.numpy(), atol=1e-6)
  401. set_global_seed(11)
  402. out4 = uniform(size=[10, 10])
  403. assert not (out1.numpy() == out4.numpy()).all()
  404. @pytest.mark.parametrize("is_symbolic", [None, False, True])
  405. def test_rng_empty_tensor(is_symbolic):
  406. set_global_seed(1024)
  407. shapes = [
  408. (0,),
  409. (0, 0, 0),
  410. (10, 0, 10),
  411. ]
  412. def fn(shape):
  413. o1 = random.uniform(0, 1, shape)
  414. o2 = random.normal(0, 1, shape)
  415. o3 = random.gamma(2, 1, shape)
  416. o4 = random.beta(2, 1, shape)
  417. o5 = random.poisson(2, shape)
  418. return o1, o2, o3, o4, o5
  419. for shape in shapes:
  420. if is_symbolic is not None:
  421. fn_ = jit.trace(symbolic=is_symbolic)(fn)
  422. else:
  423. fn_ = fn
  424. for _ in range(3):
  425. outs = fn_(shape)
  426. for out in outs:
  427. np.testing.assert_equal(out.numpy().shape, shape)
  428. if is_symbolic is None:
  429. break
  430. def fn2(n):
  431. return random.permutation(n=n)
  432. if is_symbolic is not None:
  433. fn2 = jit.trace(symbolic=is_symbolic)(fn2)
  434. for _ in range(3):
  435. out = fn2(0)
  436. np.testing.assert_equal(out.numpy().shape, (0,))
  437. if is_symbolic is None:
  438. break