You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_rng.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import numpy as np
  10. import pytest
  11. import megengine.functional as F
  12. from megengine import Tensor
  13. from megengine.core._imperative_rt import CompNode
  14. from megengine.core._imperative_rt.core2 import apply
  15. from megengine.core._imperative_rt.ops import (
  16. delete_rng_handle,
  17. get_global_rng_seed,
  18. new_rng_handle,
  19. )
  20. from megengine.core.ops.builtin import (
  21. BetaRNG,
  22. GammaRNG,
  23. GaussianRNG,
  24. PermutationRNG,
  25. PoissonRNG,
  26. UniformRNG,
  27. )
  28. from megengine.distributed.helper import get_device_count_by_fork
  29. from megengine.random import RNG, seed, uniform
  30. @pytest.mark.skipif(
  31. get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
  32. )
  33. def test_gaussian_op():
  34. shape = (
  35. 8,
  36. 9,
  37. 11,
  38. 12,
  39. )
  40. shape = Tensor(shape, dtype="int32")
  41. op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0, dtype="float32")
  42. (output,) = apply(op, shape)
  43. assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
  44. assert np.fabs(np.sqrt(output.numpy().var()) - 3.0) < 1e-1
  45. assert str(output.device) == str(CompNode("xpux"))
  46. assert output.dtype == np.float32
  47. cn = CompNode("xpu2")
  48. seed = 233333
  49. h = new_rng_handle(cn, seed)
  50. op = GaussianRNG(seed=seed, mean=3.0, std=1.0, dtype="float32", handle=h)
  51. (output,) = apply(op, shape)
  52. delete_rng_handle(h)
  53. assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
  54. assert np.fabs(np.sqrt(output.numpy().var()) - 1.0) < 1e-1
  55. assert str(output.device) == str(cn)
  56. assert output.dtype == np.float32
  57. @pytest.mark.skipif(
  58. get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
  59. )
  60. def test_uniform_op():
  61. shape = (
  62. 8,
  63. 9,
  64. 11,
  65. 12,
  66. )
  67. shape = Tensor(shape, dtype="int32")
  68. op = UniformRNG(seed=get_global_rng_seed(), dtype="float32")
  69. (output,) = apply(op, shape)
  70. assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
  71. assert str(output.device) == str(CompNode("xpux"))
  72. assert output.dtype == np.float32
  73. cn = CompNode("xpu2")
  74. seed = 233333
  75. h = new_rng_handle(cn, seed)
  76. op = UniformRNG(seed=seed, dtype="float32", handle=h)
  77. (output,) = apply(op, shape)
  78. delete_rng_handle(h)
  79. assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
  80. assert str(output.device) == str(cn)
  81. assert output.dtype == np.float32
  82. @pytest.mark.skipif(
  83. get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
  84. )
  85. def test_gamma_op():
  86. _shape, _scale = 2, 0.8
  87. _expected_mean, _expected_std = _shape * _scale, np.sqrt(_shape) * _scale
  88. shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32")
  89. scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32")
  90. op = GammaRNG(seed=get_global_rng_seed(), handle=0)
  91. (output,) = apply(op, shape, scale)
  92. assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
  93. assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
  94. assert str(output.device) == str(CompNode("xpux"))
  95. cn = CompNode("xpu2")
  96. seed = 233333
  97. h = new_rng_handle(cn, seed)
  98. shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32", device="xpu2")
  99. scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32", device="xpu2")
  100. op = GammaRNG(seed=seed, handle=h)
  101. (output,) = apply(op, shape, scale)
  102. delete_rng_handle(h)
  103. assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
  104. assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
  105. assert str(output.device) == str(cn)
  106. @pytest.mark.skipif(
  107. get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
  108. )
  109. def test_beta_op():
  110. _alpha, _beta = 2, 0.8
  111. _expected_mean = _alpha / (_alpha + _beta)
  112. _expected_std = np.sqrt(
  113. _alpha * _beta / ((_alpha + _beta) ** 2 * (_alpha + _beta + 1))
  114. )
  115. alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32")
  116. beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32")
  117. op = BetaRNG(seed=get_global_rng_seed())
  118. (output,) = apply(op, alpha, beta)
  119. assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
  120. assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
  121. assert str(output.device) == str(CompNode("xpux"))
  122. cn = CompNode("xpu2")
  123. seed = 233333
  124. h = new_rng_handle(cn, seed)
  125. alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32", device=cn)
  126. beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32", device=cn)
  127. op = BetaRNG(seed=seed, handle=h)
  128. (output,) = apply(op, alpha, beta)
  129. delete_rng_handle(h)
  130. assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1
  131. assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1
  132. assert str(output.device) == str(cn)
  133. @pytest.mark.skipif(
  134. get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
  135. )
  136. def test_poisson_op():
  137. lam = F.full([8, 9, 11, 12], value=2, dtype="float32")
  138. op = PoissonRNG(seed=get_global_rng_seed())
  139. (output,) = apply(op, lam)
  140. assert np.fabs(output.numpy().mean() - 2.0) < 1e-1
  141. assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1
  142. assert str(output.device) == str(CompNode("xpux"))
  143. cn = CompNode("xpu2")
  144. seed = 233333
  145. h = new_rng_handle(cn, seed)
  146. lam = F.full([8, 9, 11, 12], value=2, dtype="float32", device=cn)
  147. op = PoissonRNG(seed=seed, handle=h)
  148. (output,) = apply(op, lam)
  149. delete_rng_handle(h)
  150. assert np.fabs(output.numpy().mean() - 2.0) < 1e-1
  151. assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1
  152. assert str(output.device) == str(cn)
  153. @pytest.mark.skipif(
  154. get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
  155. )
  156. def test_permutation_op():
  157. n = 1000
  158. def test_permutation_op_dtype(dtype):
  159. def sum_result(res, fun):
  160. return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))])
  161. shape = Tensor((n,), dtype="int32")
  162. op = PermutationRNG(seed=get_global_rng_seed(), dtype=dtype)
  163. (output,) = apply(op, shape)
  164. assert sum_result(output, lambda x: x) < 500
  165. assert sum_result(output, np.sort) == n
  166. assert str(output.device) == str(CompNode("xpux"))
  167. assert output.dtype == dtype
  168. cn = CompNode("xpu2")
  169. seed = 233333
  170. h = new_rng_handle(cn, seed)
  171. op = PermutationRNG(seed=seed, handle=h, dtype=dtype)
  172. (output,) = apply(op, shape)
  173. delete_rng_handle(h)
  174. assert sum_result(output, lambda x: x) < 500
  175. assert sum_result(output, np.sort) == n
  176. assert str(output.device) == str(cn)
  177. assert output.dtype == dtype
  178. test_permutation_op_dtype(np.float32)
  179. test_permutation_op_dtype(np.int32)
  180. test_permutation_op_dtype(np.int16)
  181. @pytest.mark.skipif(
  182. get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
  183. )
  184. def test_UniformRNG():
  185. m1 = RNG(seed=111, device="xpu0")
  186. m2 = RNG(seed=111, device="xpu1")
  187. m3 = RNG(seed=222, device="xpu0")
  188. out1 = m1.uniform(size=(100,))
  189. out1_ = m1.uniform(size=(100,))
  190. out2 = m2.uniform(size=(100,))
  191. out3 = m3.uniform(size=(100,))
  192. np.testing.assert_equal(out1.numpy(), out2.numpy())
  193. assert out1.device == "xpu0" and out2.device == "xpu1"
  194. assert not (out1.numpy() == out3.numpy()).all()
  195. assert not (out1.numpy() == out1_.numpy()).all()
  196. low = -234
  197. high = 123
  198. out = m1.uniform(low=low, high=high, size=(20, 30, 40))
  199. out_shp = out.shape
  200. if isinstance(out_shp, tuple):
  201. assert out_shp == (20, 30, 40)
  202. else:
  203. assert all(out.shape.numpy() == np.array([20, 30, 40]))
  204. assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1
  205. @pytest.mark.skipif(
  206. get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
  207. )
  208. def test_NormalRNG():
  209. m1 = RNG(seed=111, device="xpu0")
  210. m2 = RNG(seed=111, device="xpu1")
  211. m3 = RNG(seed=222, device="xpu0")
  212. out1 = m1.normal(size=(100,))
  213. out1_ = m1.uniform(size=(100,))
  214. out2 = m2.normal(size=(100,))
  215. out3 = m3.normal(size=(100,))
  216. np.testing.assert_equal(out1.numpy(), out2.numpy())
  217. assert out1.device == "xpu0" and out2.device == "xpu1"
  218. assert not (out1.numpy() == out3.numpy()).all()
  219. assert not (out1.numpy() == out1_.numpy()).all()
  220. mean = -1
  221. std = 2
  222. out = m1.normal(mean=mean, std=std, size=(20, 30, 40))
  223. out_shp = out.shape
  224. if isinstance(out_shp, tuple):
  225. assert out_shp == (20, 30, 40)
  226. else:
  227. assert all(out.shape.numpy() == np.array([20, 30, 40]))
  228. assert np.abs(out.mean().numpy() - mean) / std < 0.1
  229. assert np.abs(np.std(out.numpy()) - std) < 0.1
  230. @pytest.mark.skipif(
  231. get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
  232. )
  233. def test_GammaRNG():
  234. m1 = RNG(seed=111, device="xpu0")
  235. m2 = RNG(seed=111, device="xpu1")
  236. m3 = RNG(seed=222, device="xpu0")
  237. out1 = m1.gamma(2, size=(100,))
  238. out1_ = m1.uniform(size=(100,))
  239. out2 = m2.gamma(2, size=(100,))
  240. out3 = m3.gamma(2, size=(100,))
  241. np.testing.assert_equal(out1.numpy(), out2.numpy())
  242. assert out1.device == "xpu0" and out2.device == "xpu1"
  243. assert not (out1.numpy() == out3.numpy()).all()
  244. assert not (out1.numpy() == out1_.numpy()).all()
  245. shape = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0")
  246. scale = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0")
  247. expected_mean = (shape * scale).numpy()
  248. expected_std = (F.sqrt(shape) * scale).numpy()
  249. out = m1.gamma(shape=shape, scale=scale, size=(20, 30, 40))
  250. out_shp = out.shape
  251. if isinstance(out_shp, tuple):
  252. assert out_shp == (20, 30, 40, 2, 3)
  253. else:
  254. assert all(out.shape.numpy() == np.array([20, 30, 40, 2, 3]))
  255. assert (
  256. np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std
  257. ).mean() < 0.1
  258. assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1
  259. @pytest.mark.skipif(
  260. get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
  261. )
  262. def test_BetaRNG():
  263. m1 = RNG(seed=111, device="xpu0")
  264. m2 = RNG(seed=111, device="xpu1")
  265. m3 = RNG(seed=222, device="xpu0")
  266. out1 = m1.beta(2, 1, size=(100,))
  267. out1_ = m1.uniform(size=(100,))
  268. out2 = m2.beta(2, 1, size=(100,))
  269. out3 = m3.beta(2, 1, size=(100,))
  270. np.testing.assert_equal(out1.numpy(), out2.numpy())
  271. assert out1.device == "xpu0" and out2.device == "xpu1"
  272. assert not (out1.numpy() == out3.numpy()).all()
  273. assert not (out1.numpy() == out1_.numpy()).all()
  274. alpha = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0")
  275. beta = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0")
  276. expected_mean = (alpha / (alpha + beta)).numpy()
  277. expected_std = (
  278. F.sqrt(alpha * beta / (F.pow(alpha + beta, 2) * (alpha + beta + 1)))
  279. ).numpy()
  280. out = m1.beta(alpha=alpha, beta=beta, size=(20, 30))
  281. out_shp = out.shape
  282. if isinstance(out_shp, tuple):
  283. assert out_shp == (20, 30, 2, 3)
  284. else:
  285. assert all(out.shape.numpy() == np.array([20, 30, 2, 3]))
  286. assert (
  287. np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std
  288. ).mean() < 0.1
  289. assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1
  290. @pytest.mark.skipif(
  291. get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
  292. )
  293. def test_PoissonRNG():
  294. m1 = RNG(seed=111, device="xpu0")
  295. m2 = RNG(seed=111, device="xpu1")
  296. m3 = RNG(seed=222, device="xpu0")
  297. lam = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32)
  298. out1 = m1.poisson(lam.to("xpu0"), size=(100,))
  299. out2 = m2.poisson(lam.to("xpu1"), size=(100,))
  300. out3 = m3.poisson(lam.to("xpu0"), size=(100,))
  301. np.testing.assert_equal(out1.numpy(), out2.numpy())
  302. assert out1.device == "xpu0" and out2.device == "xpu1"
  303. assert not (out1.numpy() == out3.numpy()).all()
  304. out = m1.poisson(lam.to("xpu0"), size=(20, 30))
  305. out_shp = out.shape
  306. expected_shape = (20, 30) + lam._tuple_shape
  307. if isinstance(out_shp, tuple):
  308. assert out_shp == expected_shape
  309. else:
  310. assert all(out.shape.numpy() == np.array(expected_shape))
  311. lam = lam.numpy()
  312. assert (np.abs(out.mean(axis=(0, 1)).numpy() - lam) / np.sqrt(lam)).mean() < 0.1
  313. assert np.abs(np.std(out.numpy(), axis=(0, 1)) - np.sqrt(lam)).mean() < 0.1
  314. @pytest.mark.skipif(
  315. get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
  316. )
  317. def test_PermutationRNG():
  318. m1 = RNG(seed=111, device="xpu0")
  319. m2 = RNG(seed=111, device="xpu1")
  320. m3 = RNG(seed=222, device="xpu0")
  321. out1 = m1.permutation(n=1000)
  322. out1_ = m1.uniform(size=(1000,))
  323. out2 = m2.permutation(n=1000)
  324. out3 = m3.permutation(n=1000)
  325. np.testing.assert_equal(out1.numpy(), out2.numpy())
  326. assert out1.device == "xpu0" and out2.device == "xpu1"
  327. assert not (out1.numpy() == out3.numpy()).all()
  328. assert not (out1.numpy() == out1_.numpy()).all()
  329. out = m1.permutation(n=1000)
  330. out_shp = out.shape
  331. if isinstance(out_shp, tuple):
  332. assert out_shp == (1000,)
  333. else:
  334. assert all(out.shape.numpy() == np.array([1000]))
  335. def sum_result(res, fun):
  336. return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))])
  337. assert sum_result(out, lambda x: x) < 500
  338. assert sum_result(out, np.sort) == 1000
  339. def test_seed():
  340. seed(10)
  341. out1 = uniform(size=[10, 10])
  342. out2 = uniform(size=[10, 10])
  343. assert not (out1.numpy() == out2.numpy()).all()
  344. seed(10)
  345. out3 = uniform(size=[10, 10])
  346. np.testing.assert_equal(out1.numpy(), out3.numpy())
  347. seed(11)
  348. out4 = uniform(size=[10, 10])
  349. assert not (out1.numpy() == out4.numpy()).all()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台