You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_rng.py 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import numpy as np
  10. import pytest
  11. import megengine.functional as F
  12. from megengine import Tensor, jit, random
  13. from megengine.core._imperative_rt import CompNode
  14. from megengine.core._imperative_rt.core2 import apply
  15. from megengine.core._imperative_rt.ops import (
  16. delete_rng_handle,
  17. get_global_rng_seed,
  18. new_rng_handle,
  19. )
  20. from megengine.core.autodiff.grad import Grad
  21. from megengine.core.ops.builtin import (
  22. BetaRNG,
  23. GammaRNG,
  24. GaussianRNG,
  25. PermutationRNG,
  26. PoissonRNG,
  27. UniformRNG,
  28. )
  29. from megengine.device import get_device_count
  30. from megengine.jit import trace
  31. from megengine.random import RNG
  32. from megengine.random import seed as set_global_seed
  33. from megengine.random import uniform
  34. @pytest.mark.skipif(
  35. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  36. )
  37. def test_gaussian_op():
  38. set_global_seed(1024)
  39. shape = (
  40. 8,
  41. 9,
  42. 11,
  43. 12,
  44. )
  45. shape = Tensor(shape, dtype="int32")
  46. op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0, dtype="float32")
  47. (output,) = apply(op, shape)
  48. assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
  49. assert np.fabs(np.sqrt(output.numpy().var()) - 3.0) < 1e-1
  50. assert str(output.device) == str(CompNode("xpux"))
  51. assert output.dtype == np.float32
  52. cn = CompNode("xpu2")
  53. seed = 233333
  54. h = new_rng_handle(cn, seed)
  55. op = GaussianRNG(seed=seed, mean=3.0, std=1.0, dtype="float32", handle=h)
  56. (output,) = apply(op, shape)
  57. delete_rng_handle(h)
  58. assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
  59. assert np.fabs(np.sqrt(output.numpy().var()) - 1.0) < 1e-1
  60. assert str(output.device) == str(cn)
  61. assert output.dtype == np.float32
  62. @pytest.mark.skipif(
  63. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  64. )
  65. def test_uniform_op():
  66. set_global_seed(1024)
  67. shape = (
  68. 8,
  69. 9,
  70. 11,
  71. 12,
  72. )
  73. shape = Tensor(shape, dtype="int32")
  74. op = UniformRNG(seed=get_global_rng_seed(), dtype="float32")
  75. (output,) = apply(op, shape)
  76. assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
  77. assert str(output.device) == str(CompNode("xpux"))
  78. assert output.dtype == np.float32
  79. cn = CompNode("xpu2")
  80. seed = 233333
  81. h = new_rng_handle(cn, seed)
  82. op = UniformRNG(seed=seed, dtype="float32", handle=h)
  83. (output,) = apply(op, shape)
  84. delete_rng_handle(h)
  85. assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
  86. assert str(output.device) == str(cn)
  87. assert output.dtype == np.float32
  88. @pytest.mark.skipif(
  89. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  90. )
  91. def test_gamma_op():
  92. set_global_seed(1024)
  93. _shape, _scale = 2, 0.8
  94. _expected_mean, _expected_std = _shape * _scale, np.sqrt(_shape) * _scale
  95. shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32")
  96. scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32")
  97. op = GammaRNG(seed=get_global_rng_seed(), handle=0)
  98. (output,) = apply(op, shape, scale)
  99. assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
  100. assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
  101. assert str(output.device) == str(CompNode("xpux"))
  102. cn = CompNode("xpu2")
  103. seed = 233333
  104. h = new_rng_handle(cn, seed)
  105. shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32", device="xpu2")
  106. scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32", device="xpu2")
  107. op = GammaRNG(seed=seed, handle=h)
  108. (output,) = apply(op, shape, scale)
  109. delete_rng_handle(h)
  110. assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
  111. assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
  112. assert str(output.device) == str(cn)
  113. @pytest.mark.skipif(
  114. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  115. )
  116. def test_beta_op():
  117. set_global_seed(1024)
  118. _alpha, _beta = 2, 0.8
  119. _expected_mean = _alpha / (_alpha + _beta)
  120. _expected_std = np.sqrt(
  121. _alpha * _beta / ((_alpha + _beta) ** 2 * (_alpha + _beta + 1))
  122. )
  123. alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32")
  124. beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32")
  125. op = BetaRNG(seed=get_global_rng_seed())
  126. (output,) = apply(op, alpha, beta)
  127. assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
  128. assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
  129. assert str(output.device) == str(CompNode("xpux"))
  130. cn = CompNode("xpu2")
  131. seed = 233333
  132. h = new_rng_handle(cn, seed)
  133. alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32", device=cn)
  134. beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32", device=cn)
  135. op = BetaRNG(seed=seed, handle=h)
  136. (output,) = apply(op, alpha, beta)
  137. delete_rng_handle(h)
  138. assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
  139. assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
  140. assert str(output.device) == str(cn)
  141. @pytest.mark.skipif(
  142. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  143. )
  144. def test_poisson_op():
  145. set_global_seed(1024)
  146. lam = F.full([8, 9, 11, 12], value=2, dtype="float32")
  147. op = PoissonRNG(seed=get_global_rng_seed())
  148. (output,) = apply(op, lam)
  149. assert np.fabs(output.numpy().mean() - 2.0) < 1e-1
  150. assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1
  151. assert str(output.device) == str(CompNode("xpux"))
  152. cn = CompNode("xpu2")
  153. seed = 233333
  154. h = new_rng_handle(cn, seed)
  155. lam = F.full([8, 9, 11, 12], value=2, dtype="float32", device=cn)
  156. op = PoissonRNG(seed=seed, handle=h)
  157. (output,) = apply(op, lam)
  158. delete_rng_handle(h)
  159. assert np.fabs(output.numpy().mean() - 2.0) < 1e-1
  160. assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1
  161. assert str(output.device) == str(cn)
  162. @pytest.mark.skipif(
  163. get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  164. )
  165. def test_permutation_op():
  166. set_global_seed(1024)
  167. n = 1000
  168. def test_permutation_op_dtype(dtype):
  169. def sum_result(res, fun):
  170. return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))])
  171. shape = Tensor((n,), dtype="int32")
  172. op = PermutationRNG(seed=get_global_rng_seed(), dtype=dtype)
  173. (output,) = apply(op, shape)
  174. assert sum_result(output, lambda x: x) < 500
  175. assert sum_result(output, np.sort) == n
  176. assert str(output.device) == str(CompNode("xpux"))
  177. assert output.dtype == dtype
  178. cn = CompNode("xpu2")
  179. seed = 233333
  180. h = new_rng_handle(cn, seed)
  181. op = PermutationRNG(seed=seed, handle=h, dtype=dtype)
  182. (output,) = apply(op, shape)
  183. delete_rng_handle(h)
  184. assert sum_result(output, lambda x: x) < 500
  185. assert sum_result(output, np.sort) == n
  186. assert str(output.device) == str(cn)
  187. assert output.dtype == dtype
  188. test_permutation_op_dtype(np.float32)
  189. test_permutation_op_dtype(np.int32)
  190. test_permutation_op_dtype(np.int16)
  191. @pytest.mark.skipif(
  192. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  193. )
  194. def test_UniformRNG():
  195. m1 = RNG(seed=111, device="xpu0")
  196. m2 = RNG(seed=111, device="xpu1")
  197. m3 = RNG(seed=222, device="xpu0")
  198. out1 = m1.uniform(size=(100,))
  199. out1_ = m1.uniform(size=(100,))
  200. out2 = m2.uniform(size=(100,))
  201. out3 = m3.uniform(size=(100,))
  202. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  203. assert out1.device == "xpu0" and out2.device == "xpu1"
  204. assert not (out1.numpy() == out3.numpy()).all()
  205. assert not (out1.numpy() == out1_.numpy()).all()
  206. low = -234
  207. high = 123
  208. out = m1.uniform(low=low, high=high, size=(20, 30, 40))
  209. out_shp = out.shape
  210. if isinstance(out_shp, tuple):
  211. assert out_shp == (20, 30, 40)
  212. else:
  213. assert all(out.shape.numpy() == np.array([20, 30, 40]))
  214. assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1
  215. @pytest.mark.skipif(
  216. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  217. )
  218. def test_NormalRNG():
  219. m1 = RNG(seed=111, device="xpu0")
  220. m2 = RNG(seed=111, device="xpu1")
  221. m3 = RNG(seed=222, device="xpu0")
  222. out1 = m1.normal(size=(100,))
  223. out1_ = m1.uniform(size=(100,))
  224. out2 = m2.normal(size=(100,))
  225. out3 = m3.normal(size=(100,))
  226. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  227. assert out1.device == "xpu0" and out2.device == "xpu1"
  228. assert not (out1.numpy() == out3.numpy()).all()
  229. assert not (out1.numpy() == out1_.numpy()).all()
  230. mean = -1
  231. std = 2
  232. out = m1.normal(mean=mean, std=std, size=(20, 30, 40))
  233. out_shp = out.shape
  234. if isinstance(out_shp, tuple):
  235. assert out_shp == (20, 30, 40)
  236. else:
  237. assert all(out.shape.numpy() == np.array([20, 30, 40]))
  238. assert np.abs(out.mean().numpy() - mean) / std < 0.1
  239. assert np.abs(np.std(out.numpy()) - std) < 0.1
  240. @pytest.mark.skipif(
  241. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  242. )
  243. def test_GammaRNG():
  244. m1 = RNG(seed=111, device="xpu0")
  245. m2 = RNG(seed=111, device="xpu1")
  246. m3 = RNG(seed=222, device="xpu0")
  247. out1 = m1.gamma(2, size=(100,))
  248. out1_ = m1.uniform(size=(100,))
  249. out2 = m2.gamma(2, size=(100,))
  250. out3 = m3.gamma(2, size=(100,))
  251. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  252. assert out1.device == "xpu0" and out2.device == "xpu1"
  253. assert not (out1.numpy() == out3.numpy()).all()
  254. assert not (out1.numpy() == out1_.numpy()).all()
  255. shape = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0")
  256. scale = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0")
  257. expected_mean = (shape * scale).numpy()
  258. expected_std = (F.sqrt(shape) * scale).numpy()
  259. out = m1.gamma(shape=shape, scale=scale, size=(20, 30, 40))
  260. out_shp = out.shape
  261. if isinstance(out_shp, tuple):
  262. assert out_shp == (20, 30, 40, 2, 3)
  263. else:
  264. assert all(out.shape.numpy() == np.array([20, 30, 40, 2, 3]))
  265. assert (
  266. np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std
  267. ).mean() < 0.1
  268. assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1
  269. @pytest.mark.skipif(
  270. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  271. )
  272. def test_BetaRNG():
  273. m1 = RNG(seed=111, device="xpu0")
  274. m2 = RNG(seed=111, device="xpu1")
  275. m3 = RNG(seed=222, device="xpu0")
  276. out1 = m1.beta(2, 1, size=(100,))
  277. out1_ = m1.uniform(size=(100,))
  278. out2 = m2.beta(2, 1, size=(100,))
  279. out3 = m3.beta(2, 1, size=(100,))
  280. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  281. assert out1.device == "xpu0" and out2.device == "xpu1"
  282. assert not (out1.numpy() == out3.numpy()).all()
  283. assert not (out1.numpy() == out1_.numpy()).all()
  284. alpha = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0")
  285. beta = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0")
  286. expected_mean = (alpha / (alpha + beta)).numpy()
  287. expected_std = (
  288. F.sqrt(alpha * beta / (F.pow(alpha + beta, 2) * (alpha + beta + 1)))
  289. ).numpy()
  290. out = m1.beta(alpha=alpha, beta=beta, size=(20, 30))
  291. out_shp = out.shape
  292. if isinstance(out_shp, tuple):
  293. assert out_shp == (20, 30, 2, 3)
  294. else:
  295. assert all(out.shape.numpy() == np.array([20, 30, 2, 3]))
  296. assert (
  297. np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std
  298. ).mean() < 0.1
  299. assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1
  300. @pytest.mark.skipif(
  301. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  302. )
  303. def test_PoissonRNG():
  304. m1 = RNG(seed=111, device="xpu0")
  305. m2 = RNG(seed=111, device="xpu1")
  306. m3 = RNG(seed=222, device="xpu0")
  307. lam = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32)
  308. out1 = m1.poisson(lam.to("xpu0"), size=(100,))
  309. out2 = m2.poisson(lam.to("xpu1"), size=(100,))
  310. out3 = m3.poisson(lam.to("xpu0"), size=(100,))
  311. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  312. assert out1.device == "xpu0" and out2.device == "xpu1"
  313. assert not (out1.numpy() == out3.numpy()).all()
  314. out = m1.poisson(lam.to("xpu0"), size=(20, 30))
  315. out_shp = out.shape
  316. expected_shape = (20, 30) + lam._tuple_shape
  317. if isinstance(out_shp, tuple):
  318. assert out_shp == expected_shape
  319. else:
  320. assert all(out.shape.numpy() == np.array(expected_shape))
  321. lam = lam.numpy()
  322. assert (np.abs(out.mean(axis=(0, 1)).numpy() - lam) / np.sqrt(lam)).mean() < 0.1
  323. assert np.abs(np.std(out.numpy(), axis=(0, 1)) - np.sqrt(lam)).mean() < 0.1
  324. @pytest.mark.skipif(
  325. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  326. )
  327. @pytest.mark.parametrize("symbolic", [True, False])
  328. def test_PermutationRNG(symbolic):
  329. m1 = RNG(seed=111, device="xpu0")
  330. m2 = RNG(seed=111, device="xpu1")
  331. m3 = RNG(seed=222, device="xpu0")
  332. out1 = m1.permutation(1000)
  333. out1_ = m1.uniform(size=(1000,))
  334. out2 = m2.permutation(1000)
  335. out3 = m3.permutation(1000)
  336. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  337. assert out1.device == "xpu0" and out2.device == "xpu1"
  338. assert not (out1.numpy() == out3.numpy()).all()
  339. assert not (out1.numpy() == out1_.numpy()).all()
  340. out = m1.permutation(1000)
  341. out_shp = out.shape
  342. if isinstance(out_shp, tuple):
  343. assert out_shp == (1000,)
  344. else:
  345. assert all(out.shape.numpy() == np.array([1000]))
  346. def sum_result(res, fun):
  347. return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))])
  348. assert sum_result(out, lambda x: x) < 500
  349. assert sum_result(out, np.sort) == 1000
  350. def func():
  351. out = m1.permutation(Tensor(7))
  352. out_shp = out.shape
  353. if isinstance(out_shp, tuple):
  354. assert out_shp == (1,)
  355. else:
  356. assert all(out.shape.numpy() == np.array([1]))
  357. n, m = 6, 3
  358. out = m1.permutation(Tensor(np.arange(n * m), dtype="float32").reshape(n, m))
  359. out_shp = out.shape
  360. if isinstance(out_shp, tuple):
  361. assert out_shp == (n, m)
  362. else:
  363. assert all(out.shape.numpy() == np.array([n, m]))
  364. func = trace(symbolic=symbolic)(func)
  365. func()
  366. @pytest.mark.skipif(
  367. get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  368. )
  369. def test_ShuffleRNG():
  370. g = []
  371. def cb(grad):
  372. g.append(grad)
  373. n, m = 6, 3
  374. arr = np.arange(n * m)
  375. out0 = Tensor(arr, dtype="float32")
  376. grad = Grad().wrt(out0, callback=cb)
  377. random.shuffle(out0)
  378. grad(out0, F.ones_like(out0))
  379. m1 = RNG(seed=111, device="xpu0")
  380. m2 = RNG(seed=111, device="xpu1")
  381. m3 = RNG(seed=222, device="xpu0")
  382. out1 = Tensor(arr, dtype="float32", device="xpu0")
  383. out2 = Tensor(arr, dtype="float32", device="xpu1")
  384. out3 = Tensor(arr, dtype="float32", device="xpu0")
  385. m1.shuffle(out1)
  386. m2.shuffle(out2)
  387. m3.shuffle(out3)
  388. np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
  389. assert out1.device == "xpu0" and out2.device == "xpu1"
  390. assert not (out1.numpy() == out3.numpy()).all()
  391. out = Tensor(arr, dtype="float32").reshape(n, m)
  392. m1.shuffle(out)
  393. out_shp = out.shape
  394. if isinstance(out_shp, tuple):
  395. assert out_shp == (n, m)
  396. else:
  397. assert all(out.shape.numpy() == np.array([n, m]))
  398. def test_seed():
  399. set_global_seed(10)
  400. out1 = uniform(size=[10, 10])
  401. out2 = uniform(size=[10, 10])
  402. assert not (out1.numpy() == out2.numpy()).all()
  403. set_global_seed(10)
  404. out3 = uniform(size=[10, 10])
  405. np.testing.assert_allclose(out1.numpy(), out3.numpy(), atol=1e-6)
  406. set_global_seed(11)
  407. out4 = uniform(size=[10, 10])
  408. assert not (out1.numpy() == out4.numpy()).all()
  409. @pytest.mark.parametrize("is_symbolic", [None, False, True])
  410. def test_rng_empty_tensor(is_symbolic):
  411. set_global_seed(1024)
  412. shapes = [
  413. (0,),
  414. (0, 0, 0),
  415. (10, 0, 10),
  416. ]
  417. def fn(shape):
  418. o1 = random.uniform(0, 1, shape)
  419. o2 = random.normal(0, 1, shape)
  420. o3 = random.gamma(2, 1, shape)
  421. o4 = random.beta(2, 1, shape)
  422. o5 = random.poisson(2, shape)
  423. return o1, o2, o3, o4, o5
  424. for shape in shapes:
  425. if is_symbolic is not None:
  426. fn_ = jit.trace(symbolic=is_symbolic)(fn)
  427. else:
  428. fn_ = fn
  429. for _ in range(3):
  430. outs = fn_(shape)
  431. for out in outs:
  432. np.testing.assert_equal(out.numpy().shape, shape)
  433. if is_symbolic is None:
  434. break
  435. def fn2(n):
  436. return random.permutation(n=n)
  437. if is_symbolic is not None:
  438. fn2 = jit.trace(symbolic=is_symbolic)(fn2)
  439. for _ in range(3):
  440. out = fn2(0)
  441. np.testing.assert_equal(out.numpy().shape, (0,))
  442. if is_symbolic is None:
  443. break