
test_rng.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
from megengine import Tensor, jit, random
from megengine.core._imperative_rt import CompNode
from megengine.core._imperative_rt.core2 import apply
from megengine.core._imperative_rt.ops import (
    delete_rng_handle,
    get_global_rng_seed,
    new_rng_handle,
)
from megengine.core.autodiff.grad import Grad
from megengine.core.ops.builtin import (
    BetaRNG,
    GammaRNG,
    GaussianRNG,
    PermutationRNG,
    PoissonRNG,
    UniformRNG,
)
from megengine.device import get_device_count
from megengine.jit import trace
from megengine.random import RNG
from megengine.random import seed as set_global_seed
from megengine.random import uniform

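# The *_op tests below drive the low-level RNG operators (GaussianRNG,
# UniformRNG, GammaRNG, BetaRNG, PoissonRNG, PermutationRNG) directly through
# `apply`: first on the default "xpux" device with the global seed, then on an
# explicit device using a handle created via `new_rng_handle(cn, seed)` and
# released with `delete_rng_handle(h)`.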
@pytest.mark.skipif(
    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_gaussian_op():
    set_global_seed(1024)
    shape = (
        8,
        9,
        11,
        12,
    )
    shape = Tensor(shape, dtype="int32")
    # sample on the default device using the global seed
    op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0, dtype="float32")
    (output,) = apply(op, shape)
    assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
    assert np.fabs(np.sqrt(output.numpy().var()) - 3.0) < 1e-1
    assert str(output.device) == str(CompNode("xpux"))
    assert output.dtype == np.float32

    # sample on an explicit device through a dedicated RNG handle
    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    op = GaussianRNG(seed=seed, mean=3.0, std=1.0, dtype="float32", handle=h)
    (output,) = apply(op, shape)
    delete_rng_handle(h)
    assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
    assert np.fabs(np.sqrt(output.numpy().var()) - 1.0) < 1e-1
    assert str(output.device) == str(cn)
    assert output.dtype == np.float32

@pytest.mark.skipif(
    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_uniform_op():
    set_global_seed(1024)
    shape = (
        8,
        9,
        11,
        12,
    )
    shape = Tensor(shape, dtype="int32")
    op = UniformRNG(seed=get_global_rng_seed(), dtype="float32")
    (output,) = apply(op, shape)
    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
    assert str(output.device) == str(CompNode("xpux"))
    assert output.dtype == np.float32

    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    op = UniformRNG(seed=seed, dtype="float32", handle=h)
    (output,) = apply(op, shape)
    delete_rng_handle(h)
    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
    assert str(output.device) == str(cn)
    assert output.dtype == np.float32

@pytest.mark.skipif(
    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_gamma_op():
    set_global_seed(1024)
    _shape, _scale = 2, 0.8
    _expected_mean, _expected_std = _shape * _scale, np.sqrt(_shape) * _scale

    shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32")
    scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32")
    op = GammaRNG(seed=get_global_rng_seed(), handle=0)
    (output,) = apply(op, shape, scale)
    assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
    assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
    assert str(output.device) == str(CompNode("xpux"))

    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32", device="xpu2")
    scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32", device="xpu2")
    op = GammaRNG(seed=seed, handle=h)
    (output,) = apply(op, shape, scale)
    delete_rng_handle(h)
    assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
    assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
    assert str(output.device) == str(cn)

@pytest.mark.skipif(
    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_beta_op():
    set_global_seed(1024)
    _alpha, _beta = 2, 0.8
    _expected_mean = _alpha / (_alpha + _beta)
    _expected_std = np.sqrt(
        _alpha * _beta / ((_alpha + _beta) ** 2 * (_alpha + _beta + 1))
    )

    alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32")
    beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32")
    op = BetaRNG(seed=get_global_rng_seed())
    (output,) = apply(op, alpha, beta)
    assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
    assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
    assert str(output.device) == str(CompNode("xpux"))

    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32", device=cn)
    beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32", device=cn)
    op = BetaRNG(seed=seed, handle=h)
    (output,) = apply(op, alpha, beta)
    delete_rng_handle(h)
    assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
    assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
    assert str(output.device) == str(cn)

@pytest.mark.skipif(
    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_poisson_op():
    set_global_seed(1024)
    lam = F.full([8, 9, 11, 12], value=2, dtype="float32")
    op = PoissonRNG(seed=get_global_rng_seed())
    (output,) = apply(op, lam)
    assert np.fabs(output.numpy().mean() - 2.0) < 1e-1
    assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1
    assert str(output.device) == str(CompNode("xpux"))

    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    lam = F.full([8, 9, 11, 12], value=2, dtype="float32", device=cn)
    op = PoissonRNG(seed=seed, handle=h)
    (output,) = apply(op, lam)
    delete_rng_handle(h)
    assert np.fabs(output.numpy().mean() - 2.0) < 1e-1
    assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1
    assert str(output.device) == str(cn)

@pytest.mark.skipif(
    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_permutation_op():
    set_global_seed(1024)
    n = 1000

    def test_permutation_op_dtype(dtype):
        def sum_result(res, fun):
            return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))])

        shape = Tensor((n,), dtype="int32")
        op = PermutationRNG(seed=get_global_rng_seed(), dtype=dtype)
        (output,) = apply(op, shape)
        assert sum_result(output, lambda x: x) < 500
        assert sum_result(output, np.sort) == n
        assert str(output.device) == str(CompNode("xpux"))
        assert output.dtype == dtype

        cn = CompNode("xpu2")
        seed = 233333
        h = new_rng_handle(cn, seed)
        op = PermutationRNG(seed=seed, handle=h, dtype=dtype)
        (output,) = apply(op, shape)
        delete_rng_handle(h)
        assert sum_result(output, lambda x: x) < 500
        assert sum_result(output, np.sort) == n
        assert str(output.device) == str(cn)
        assert output.dtype == dtype

    # FIXME: remove this sync
    mge.core.set_option("async_level", 0)
    test_permutation_op_dtype(np.float32)
    test_permutation_op_dtype(np.int32)
    test_permutation_op_dtype(np.int16)
    mge.core.set_option("async_level", 2)

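# The tests below exercise the high-level `megengine.random.RNG` interface:
# two generators with the same seed on different devices are expected to
# produce identical streams, while a different seed (or a second draw from the
# same generator) must not reproduce the same values; the shape checks verify
# how distribution parameters broadcast against the requested `size`.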
@pytest.mark.skipif(
    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
)
def test_UniformRNG():
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = m1.uniform(size=(100,))
    out1_ = m1.uniform(size=(100,))
    out2 = m2.uniform(size=(100,))
    out3 = m3.uniform(size=(100,))
    np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()
    assert not (out1.numpy() == out1_.numpy()).all()

    low = -234
    high = 123
    out = m1.uniform(low=low, high=high, size=(20, 30, 40))
    out_shp = out.shape
    if isinstance(out_shp, tuple):
        assert out_shp == (20, 30, 40)
    else:
        assert all(out.shape.numpy() == np.array([20, 30, 40]))
    assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1

@pytest.mark.skipif(
    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
)
def test_NormalRNG():
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = m1.normal(size=(100,))
    out1_ = m1.uniform(size=(100,))
    out2 = m2.normal(size=(100,))
    out3 = m3.normal(size=(100,))
    np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()
    assert not (out1.numpy() == out1_.numpy()).all()

    mean = -1
    std = 2
    out = m1.normal(mean=mean, std=std, size=(20, 30, 40))
    out_shp = out.shape
    if isinstance(out_shp, tuple):
        assert out_shp == (20, 30, 40)
    else:
        assert all(out.shape.numpy() == np.array([20, 30, 40]))
    assert np.abs(out.mean().numpy() - mean) / std < 0.1
    assert np.abs(np.std(out.numpy()) - std) < 0.1

@pytest.mark.skipif(
    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
)
def test_GammaRNG():
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = m1.gamma(2, size=(100,))
    out1_ = m1.uniform(size=(100,))
    out2 = m2.gamma(2, size=(100,))
    out3 = m3.gamma(2, size=(100,))
    np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()
    assert not (out1.numpy() == out1_.numpy()).all()

    shape = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0")
    scale = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0")
    expected_mean = (shape * scale).numpy()
    expected_std = (F.sqrt(shape) * scale).numpy()
    out = m1.gamma(shape=shape, scale=scale, size=(20, 30, 40))
    out_shp = out.shape
    if isinstance(out_shp, tuple):
        assert out_shp == (20, 30, 40, 2, 3)
    else:
        assert all(out.shape.numpy() == np.array([20, 30, 40, 2, 3]))
    assert (
        np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std
    ).mean() < 0.1
    assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1

@pytest.mark.skipif(
    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
)
def test_BetaRNG():
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = m1.beta(2, 1, size=(100,))
    out1_ = m1.uniform(size=(100,))
    out2 = m2.beta(2, 1, size=(100,))
    out3 = m3.beta(2, 1, size=(100,))
    np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()
    assert not (out1.numpy() == out1_.numpy()).all()

    alpha = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0")
    beta = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0")
    expected_mean = (alpha / (alpha + beta)).numpy()
    expected_std = (
        F.sqrt(alpha * beta / (F.pow(alpha + beta, 2) * (alpha + beta + 1)))
    ).numpy()
    out = m1.beta(alpha=alpha, beta=beta, size=(20, 30))
    out_shp = out.shape
    if isinstance(out_shp, tuple):
        assert out_shp == (20, 30, 2, 3)
    else:
        assert all(out.shape.numpy() == np.array([20, 30, 2, 3]))
    assert (
        np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std
    ).mean() < 0.1
    assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1

@pytest.mark.skipif(
    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
)
def test_PoissonRNG():
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    lam = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32)
    out1 = m1.poisson(lam.to("xpu0"), size=(100,))
    out2 = m2.poisson(lam.to("xpu1"), size=(100,))
    out3 = m3.poisson(lam.to("xpu0"), size=(100,))
    np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()

    out = m1.poisson(lam.to("xpu0"), size=(20, 30))
    out_shp = out.shape
    expected_shape = (20, 30) + lam._tuple_shape
    if isinstance(out_shp, tuple):
        assert out_shp == expected_shape
    else:
        assert all(out.shape.numpy() == np.array(expected_shape))
    lam = lam.numpy()
    assert (np.abs(out.mean(axis=(0, 1)).numpy() - lam) / np.sqrt(lam)).mean() < 0.1
    assert np.abs(np.std(out.numpy(), axis=(0, 1)) - np.sqrt(lam)).mean() < 0.1

@pytest.mark.skipif(
    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
)
@pytest.mark.parametrize("symbolic", [True, False])
def test_PermutationRNG(symbolic):
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = m1.permutation(1000)
    out1_ = m1.uniform(size=(1000,))
    out2 = m2.permutation(1000)
    out3 = m3.permutation(1000)
    np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()
    assert not (out1.numpy() == out1_.numpy()).all()

    out = m1.permutation(1000)
    out_shp = out.shape
    if isinstance(out_shp, tuple):
        assert out_shp == (1000,)
    else:
        assert all(out.shape.numpy() == np.array([1000]))

    def sum_result(res, fun):
        return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))])

    assert sum_result(out, lambda x: x) < 500
    assert sum_result(out, np.sort) == 1000

    def func():
        out = m1.permutation(Tensor(7))
        out_shp = out.shape
        if isinstance(out_shp, tuple):
            assert out_shp == (1,)
        else:
            assert all(out.shape.numpy() == np.array([1]))

        n, m = 6, 3
        out = m1.permutation(Tensor(np.arange(n * m), dtype="float32").reshape(n, m))
        out_shp = out.shape
        if isinstance(out_shp, tuple):
            assert out_shp == (n, m)
        else:
            assert all(out.shape.numpy() == np.array([n, m]))

    func = trace(symbolic=symbolic)(func)
    func()

@pytest.mark.skipif(
    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
)
def test_ShuffleRNG():
    g = []

    def cb(grad):
        g.append(grad)

    # backward through random.shuffle should run without error
    n, m = 6, 3
    arr = np.arange(n * m)
    out0 = Tensor(arr, dtype="float32")
    with Grad() as grad:
        grad.wrt(out0, callback=cb)
        random.shuffle(out0)
        grad(out0, F.ones_like(out0))

    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = Tensor(arr, dtype="float32", device="xpu0")
    out2 = Tensor(arr, dtype="float32", device="xpu1")
    out3 = Tensor(arr, dtype="float32", device="xpu0")
    m1.shuffle(out1)
    m2.shuffle(out2)
    m3.shuffle(out3)
    np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()

    out = Tensor(arr, dtype="float32").reshape(n, m)
    m1.shuffle(out)
    out_shp = out.shape
    if isinstance(out_shp, tuple):
        assert out_shp == (n, m)
    else:
        assert all(out.shape.numpy() == np.array([n, m]))

def test_seed():
    set_global_seed(10)
    out1 = uniform(size=[10, 10])
    out2 = uniform(size=[10, 10])
    assert not (out1.numpy() == out2.numpy()).all()

    set_global_seed(10)
    out3 = uniform(size=[10, 10])
    np.testing.assert_allclose(out1.numpy(), out3.numpy(), atol=1e-6)

    set_global_seed(11)
    out4 = uniform(size=[10, 10])
    assert not (out1.numpy() == out4.numpy()).all()

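# Empty output shapes such as (0,) and (10, 0, 10) should be handled by every
# sampler, both eagerly and under `jit.trace` in symbolic and non-symbolic mode.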
@pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_rng_empty_tensor(is_symbolic):
    set_global_seed(1024)
    shapes = [
        (0,),
        (0, 0, 0),
        (10, 0, 10),
    ]

    def fn(shape):
        o1 = random.uniform(0, 1, shape)
        o2 = random.normal(0, 1, shape)
        o3 = random.gamma(2, 1, shape)
        o4 = random.beta(2, 1, shape)
        o5 = random.poisson(2, shape)
        return o1, o2, o3, o4, o5

    for shape in shapes:
        if is_symbolic is not None:
            fn_ = jit.trace(symbolic=is_symbolic)(fn)
        else:
            fn_ = fn
        for _ in range(3):
            outs = fn_(shape)
            for out in outs:
                np.testing.assert_equal(out.numpy().shape, shape)
            if is_symbolic is None:
                break

    def fn2(n):
        return random.permutation(n=n)

    if is_symbolic is not None:
        fn2 = jit.trace(symbolic=is_symbolic)(fn2)
    for _ in range(3):
        out = fn2(0)
        np.testing.assert_equal(out.numpy().shape, (0,))
        if is_symbolic is None:
            break