You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

rng.py 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import time
  11. from typing import Iterable, Optional, Union
  12. from numpy.random import MT19937
  13. from .. import Tensor
  14. from ..core._imperative_rt.core2 import apply
  15. from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
  16. from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
  17. from ..core._imperative_rt.ops import (
  18. get_rng_handle_compnode as _get_rng_handle_compnode,
  19. )
  20. from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle
  21. from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed
  22. from ..core.ops.builtin import (
  23. BetaRNG,
  24. GammaRNG,
  25. GaussianRNG,
  26. PermutationRNG,
  27. PoissonRNG,
  28. ShuffleRNG,
  29. UniformRNG,
  30. )
  31. from ..core.tensor import utils
  32. from ..device import get_default_device
# Names exported by ``from <this module> import *``: the ``seed`` function, the
# ``RNG`` class, and the module-level sampling functions (which are bound
# methods of a default ``RNG`` instance created further down in this file).
__all__ = [
    "seed",
    "RNG",
    "uniform",
    "normal",
    "gamma",
    "beta",
    "poisson",
    "permutation",
    "shuffle",
]

# Module-level bit generator; rebound to an ``MT19937`` instance by ``seed()``
# and read by ``_random_seed_generator()``.
_rng = None
  45. def _infer_broadcasted_shape(inps: Iterable[Tensor]) -> tuple:
  46. broadcasted_ndim = inps[0].ndim
  47. broadcasted_shape = list(inps[0]._tuple_shape)
  48. for i in range(1, len(inps)):
  49. cur_ndim = inps[i].ndim
  50. cur_shape = list(inps[i]._tuple_shape)
  51. n_dim = max(cur_ndim, broadcasted_ndim)
  52. for j in range(n_dim - 1, -1, -1):
  53. cur_dim = cur_ndim + j - n_dim
  54. broad_dim = broadcasted_ndim + j - n_dim
  55. cur_size = cur_shape[cur_dim] if cur_dim >= 0 else 1
  56. broad_size = broadcasted_shape[broad_dim] if broad_dim >= 0 else 1
  57. assert cur_size == broad_size or cur_size == 1 or broad_size == 1, (
  58. "The size of inps[{}] ({}) must match the size ({}) at "
  59. "dim {}".format(i, cur_size, broad_size, j)
  60. )
  61. broad_size = max(cur_size, broad_size)
  62. if broad_dim < 0:
  63. broadcasted_shape = [broad_size] + broadcasted_shape
  64. broadcasted_ndim += 1
  65. else:
  66. broadcasted_shape[broad_dim] = broad_size
  67. return tuple(broadcasted_shape)
  68. def _broadcast_tensors_with_size(
  69. inps: Iterable[Tensor], size: Iterable[int]
  70. ) -> Iterable[Tensor]:
  71. assert inps, "The inps cloud not be empty"
  72. target_shape = _infer_broadcasted_shape(inps)
  73. if isinstance(size, collections.abc.Iterable):
  74. target_shape = tuple(size) + target_shape
  75. target_ndim = len(target_shape)
  76. for i in range(len(inps)):
  77. if inps[i]._tuple_shape != target_shape:
  78. inps[i] = (
  79. inps[i]
  80. .reshape((1,) * (target_ndim - inps[i].ndim) + inps[i]._tuple_shape)
  81. ._broadcast(target_shape)
  82. )
  83. return inps
  84. def _uniform(
  85. low: float,
  86. high: float,
  87. size: Optional[Iterable[int]],
  88. seed: int,
  89. device: str,
  90. handle: int,
  91. ) -> Tensor:
  92. assert low < high, "Uniform is not defined when low >= high"
  93. if size is None:
  94. size = (1,)
  95. op = UniformRNG(seed=seed, handle=handle, dtype="float32")
  96. _ref = Tensor([], dtype="int32", device=device)
  97. shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
  98. (output,) = apply(op, shape)
  99. return low + (high - low) * output
  100. def _normal(
  101. mean: float,
  102. std: float,
  103. size: Optional[Iterable[int]],
  104. seed: int,
  105. device: str,
  106. handle: int,
  107. ) -> Tensor:
  108. if size is None:
  109. size = (1,)
  110. op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle, dtype="float32")
  111. _ref = Tensor([], dtype="int32", device=device)
  112. shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
  113. (output,) = apply(op, shape)
  114. return output
  115. def _gamma(
  116. shape: Union[Tensor, float],
  117. scale: Union[Tensor, float],
  118. size: Optional[Iterable[int]],
  119. seed: int,
  120. handle: int,
  121. ) -> Tensor:
  122. handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle)
  123. if not isinstance(shape, Tensor):
  124. assert shape > 0, "Gamma is not defined when shape <= 0"
  125. shape = Tensor(shape, dtype="float32", device=handle_cn)
  126. if not isinstance(scale, Tensor):
  127. assert scale > 0, "Gamma is not defined when scale <= 0"
  128. scale = Tensor(scale, dtype="float32", device=handle_cn)
  129. assert (
  130. handle_cn is None or handle_cn == shape.device
  131. ), "The shape ({}) must be the same device with handle ({})".format(
  132. shape.device, handle_cn
  133. )
  134. assert (
  135. handle_cn is None or handle_cn == scale.device
  136. ), "The scale ({}) must be the same device with handle ({})".format(
  137. scale.device, handle_cn
  138. )
  139. if isinstance(size, int) and size != 0:
  140. size = (size,)
  141. shape, scale = _broadcast_tensors_with_size([shape, scale], size)
  142. op = GammaRNG(seed=seed, handle=handle)
  143. (output,) = apply(op, shape, scale)
  144. return output
  145. def _beta(
  146. alpha: Union[Tensor, float],
  147. beta: Union[Tensor, float],
  148. size: Optional[Iterable[int]],
  149. seed: int,
  150. handle: int,
  151. ) -> Tensor:
  152. handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle)
  153. if not isinstance(alpha, Tensor):
  154. assert alpha > 0, "Beta is not defined when alpha <= 0"
  155. alpha = Tensor(alpha, dtype="float32", device=handle_cn)
  156. if not isinstance(beta, Tensor):
  157. assert beta > 0, "Beta is not defined when beta <= 0"
  158. beta = Tensor(beta, dtype="float32", device=handle_cn)
  159. assert (
  160. handle_cn is None or handle_cn == alpha.device
  161. ), "The alpha ({}) must be the same device with handle ({})".format(
  162. alpha.device, handle_cn
  163. )
  164. assert (
  165. handle_cn is None or handle_cn == beta.device
  166. ), "The beta ({}) must be the same device with handle ({})".format(
  167. beta.device, handle_cn
  168. )
  169. if isinstance(size, int) and size != 0:
  170. size = (size,)
  171. alpha, beta = _broadcast_tensors_with_size([alpha, beta], size)
  172. op = BetaRNG(seed=seed, handle=handle)
  173. (output,) = apply(op, alpha, beta)
  174. return output
  175. def _poisson(
  176. lam: Union[Tensor, float], size: Optional[Iterable[int]], seed: int, handle: int
  177. ) -> Tensor:
  178. handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle)
  179. if not isinstance(lam, Tensor):
  180. assert lam > 0, "Poisson is not defined when lam <= 0"
  181. lam = Tensor(lam, dtype="float32", device=handle_cn)
  182. if isinstance(size, int) and size != 0:
  183. size = (size,)
  184. assert (
  185. handle_cn is None or handle_cn == lam.device
  186. ), "The lam ({}) must be the same device with handle ({})".format(
  187. lam.device, handle_cn
  188. )
  189. (lam,) = _broadcast_tensors_with_size([lam], size)
  190. op = PoissonRNG(seed=seed, handle=handle)
  191. (output,) = apply(op, lam)
  192. return output
  193. def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Tensor:
  194. assert isinstance(n, int)
  195. assert n >= 0, "Permutation is not defined when n < 0"
  196. size = (n,)
  197. op = PermutationRNG(seed=seed, handle=handle, dtype=dtype)
  198. _ref = Tensor([], dtype="int32", device=device)
  199. shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
  200. (output,) = apply(op, shape)
  201. return output
  202. def _shuffle(inp: Tensor, seed: int, handle: int) -> Tensor:
  203. assert inp.size > 0, "size needs to be greater than 0"
  204. op = ShuffleRNG(seed=seed, handle=handle)
  205. output, _ = apply(op, inp)
  206. return output
  207. class RNG:
  208. r""":class:`RNG` exposes a number of methods for generating random numbers.
  209. Args:
  210. seed: random seed used to initialize the pseudo-random number generator. Default: None
  211. device: the device of generated tensor. Default: None
  212. Examples:
  213. .. testcode::
  214. import megengine.random as rand
  215. rng = rand.RNG(seed=100)
  216. x = rng.uniform(size=(2, 2))
  217. print(x.numpy())
  218. Outputs:
  219. .. testoutput::
  220. :options: +SKIP
  221. [[0.84811664 0.6147553 ]
  222. [0.59429836 0.64727545]]
  223. """
  224. def __init__(self, seed: int = None, device: str = None):
  225. self._device = device if device else get_default_device()
  226. if seed is not None:
  227. self._seed = seed
  228. self._handle = _new_rng_handle(self._device, self._seed)
  229. else:
  230. self._seed = _get_global_rng_seed
  231. self._handle = 0
  232. self._device = None
  233. def uniform(
  234. self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
  235. ):
  236. r"""Random variable with uniform distribution $U(0, 1)$.
  237. Args:
  238. low: lower range. Default: 0
  239. high: upper range. Default: 1
  240. size: the size of output tensor. Default: None
  241. Returns:
  242. the output tensor.
  243. Examples:
  244. .. testcode::
  245. import megengine as mge
  246. import megengine.random as rand
  247. x = rand.uniform(size=(2, 2))
  248. print(x.numpy())
  249. Outputs:
  250. .. testoutput::
  251. :options: +SKIP
  252. [[0.91600335 0.6680226 ]
  253. [0.2046729 0.2769141 ]]
  254. """
  255. _seed = self._seed() if callable(self._seed) else self._seed
  256. return _uniform(
  257. low=low,
  258. high=high,
  259. size=size,
  260. seed=_seed,
  261. device=self._device,
  262. handle=self._handle,
  263. )
  264. def normal(
  265. self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
  266. ):
  267. r"""Random variable with Gaussian distribution :math:`N(\mu, \sigma)`.
  268. Args:
  269. mean: the mean or expectation of the distribution. Default: 0
  270. std: the standard deviation of the distribution (variance = :math:`\sigma ^ 2`).
  271. Default: 1
  272. size: the size of output tensor. Default: None
  273. Returns:
  274. the output tensor.
  275. Examples:
  276. .. testcode::
  277. import megengine as mge
  278. import megengine.random as rand
  279. x = rand.normal(mean=0, std=1, size=(2, 2))
  280. print(x.numpy())
  281. Outputs:
  282. .. testoutput::
  283. :options: +SKIP
  284. [[-1.4010863 -0.9874344 ]
  285. [ 0.56373274 0.79656655]]
  286. """
  287. _seed = self._seed() if callable(self._seed) else self._seed
  288. return _normal(
  289. mean=mean,
  290. std=std,
  291. size=size,
  292. seed=_seed,
  293. device=self._device,
  294. handle=self._handle,
  295. )
  296. def gamma(
  297. self,
  298. shape: Union[Tensor, float],
  299. scale: Union[Tensor, float] = 1,
  300. size: Optional[Iterable[int]] = None,
  301. ):
  302. r"""Random variable with Gamma distribution :math:`\Gamma(k, \theta)`.
  303. The corresponding probability density function is
  304. .. math::
  305. p(x)=x^{k-1} \frac{e^{-x / \theta}}{\theta^{k} \Gamma(k)}
  306. \quad \text { for } x>0 \quad k, \theta>0,
  307. where :math:`\Gamma(k)` is the gamma function,
  308. .. math::
  309. \Gamma(k)=(k-1) ! \quad \text { for } \quad k>0.
  310. Args:
  311. shape: the shape parameter (sometimes designated "k") of the distribution.
  312. Must be non-negative.
  313. scale: the scale parameter (sometimes designated "theta") of the distribution.
  314. Must be non-negative. Default: 1
  315. size: the size of output tensor. If shape and scale are scalars and given size is, e.g.,
  316. `(m, n)`, then the output shape is `(m, n)`. If shape or scale is a Tensor and given size
  317. is, e.g., `(m, n)`, then the output shape is `(m, n) + broadcast(shape, scale).shape`.
  318. The broadcast rules are consistent with `numpy.broadcast`. Default: None
  319. Returns:
  320. the output tensor.
  321. Examples:
  322. .. testcode::
  323. import megengine as mge
  324. import megengine.random as rand
  325. x = rand.gamma(shape=2, scale=1, size=(2, 2))
  326. print(x.numpy())
  327. shape = mge.Tensor([[ 1],
  328. [10]], dtype="float32")
  329. scale = mge.Tensor([1,5], dtype="float32")
  330. x = rand.gamma(shape=shape, scale=scale)
  331. print(x.numpy())
  332. x = rand.gamma(shape=shape, scale=scale, size=2)
  333. print(x.numpy())
  334. Outputs:
  335. .. testoutput::
  336. :options: +SKIP
  337. [[1.5064533 4.0689363 ]
  338. [0.71639484 1.4551026 ]]
  339. [[ 0.4352188 11.399335 ]
  340. [ 9.1888 52.009277 ]]
  341. [[[ 1.1726005 3.9654975 ]
  342. [13.656933 36.559006 ]]
  343. [[ 0.25848487 2.5540342 ]
  344. [11.960409 21.031536 ]]]
  345. """
  346. _seed = self._seed() if callable(self._seed) else self._seed
  347. return _gamma(
  348. shape=shape, scale=scale, size=size, seed=_seed, handle=self._handle
  349. )
  350. def beta(
  351. self,
  352. alpha: Union[Tensor, float],
  353. beta: Union[Tensor, float],
  354. size: Optional[Iterable[int]] = None,
  355. ):
  356. r"""Random variable with Beta distribution :math:`\operatorname{Beta}(\alpha, \beta)`.
  357. The corresponding probability density function is
  358. .. math::
  359. p(x)=\frac{1}{\mathrm{~B}(\alpha, \beta)} x^{\alpha-1}(1-x)^{\beta-1}
  360. \quad \text { for } \alpha, \beta>0,
  361. where :math:`\mathrm{~B}(\alpha, \beta)` is the beta function,
  362. .. math::
  363. \mathrm{~B}(\alpha, \beta)=\int_{0}^{1} t^{\alpha-1}(1-t)^{\beta-1} d t.
  364. Args:
  365. alpha: the alpha parameter of the distribution. Must be non-negative.
  366. beta: the beta parameter of the distribution. Must be non-negative.
  367. size: the size of output tensor. If alpha and beta are scalars and given size is, e.g.,
  368. `(m, n)`, then the output shape is `(m, n)`. If alpha or beta is a Tensor and given size
  369. is, e.g., `(m, n)`, then the output shape is `(m, n) + broadcast(alpha, beta).shape`.
  370. Returns:
  371. the output tensor.
  372. Examples:
  373. .. testcode::
  374. import megengine as mge
  375. import megengine.random as rand
  376. x = rand.beta(alpha=2, beta=1, size=(2, 2))
  377. print(x.numpy())
  378. alpha = mge.Tensor([[0.5],
  379. [ 3]], dtype="float32")
  380. beta = mge.Tensor([0.5,5], dtype="float32")
  381. x = rand.beta(alpha=alpha, beta=beta)
  382. print(x.numpy())
  383. x = rand.beta(alpha=alpha, beta=beta, size=2)
  384. print(x.numpy())
  385. Outputs:
  386. .. testoutput::
  387. :options: +SKIP
  388. [[0.582565 0.91763186]
  389. [0.86963767 0.6088103 ]]
  390. [[0.41503012 0.16438372]
  391. [0.90159506 0.47588003]]
  392. [[[0.55195075 0.01111084]
  393. [0.95298755 0.25048104]]
  394. [[0.11680304 0.13859665]
  395. [0.997879 0.43259275]]]
  396. """
  397. _seed = self._seed() if callable(self._seed) else self._seed
  398. return _beta(alpha=alpha, beta=beta, size=size, seed=_seed, handle=self._handle)
  399. def poisson(self, lam: Union[float, Tensor], size: Optional[Iterable[int]] = None):
  400. r"""Random variable with poisson distribution :math:`\operatorname{Poisson}(\lambda)`.
  401. The corresponding probability density function is
  402. .. math::
  403. f(k ; \lambda)=\frac{\lambda^{k} e^{-\lambda}}{k !},
  404. where k is the number of occurrences :math:`({\displaystyle k=0,1,2...})`.
  405. Args:
  406. lam: the lambda parameter of the distribution. Must be non-negative.
  407. size: the size of output tensor. If lam is a scalar and given size is, e.g., `(m, n)`,
  408. then the output shape is `(m, n)`. If lam is a Tensor with shape `(k, v)` and given
  409. size is, e.g., `(m, n)`, then the output shape is `(m, n, k, v)`. Default: None.
  410. Returns:
  411. the output tensor.
  412. Examples:
  413. .. testcode::
  414. import megengine as mge
  415. import megengine.random as rand
  416. x = rand.poisson(lam=2., size=(1, 3))
  417. print(x.numpy())
  418. lam = mge.Tensor([[1.,1.],
  419. [10,10]], dtype="float32")
  420. x = rand.poisson(lam=lam)
  421. print(x.numpy())
  422. x = rand.poisson(lam=lam, size=(1,3))
  423. print(x.numpy())
  424. Outputs:
  425. .. testoutput::
  426. :options: +SKIP
  427. [[3. 1. 3.]]
  428. [[ 2. 2.]
  429. [12. 11.]]
  430. [[[[ 1. 1.]
  431. [11. 4.]]
  432. [[ 0. 0.]
  433. [ 9. 13.]]
  434. [[ 0. 1.]
  435. [ 7. 12.]]]]
  436. """
  437. _seed = self._seed() if callable(self._seed) else self._seed
  438. return _poisson(lam=lam, size=size, seed=_seed, handle=self._handle)
  439. def permutation(self, n: Union[int, Tensor], *, dtype: str = "int32"):
  440. r"""Randomly permute a sequence, or return a permuted range.
  441. If ``n`` is a multi-dimensional tensor, it is only shuffled along its first index.
  442. Args:
  443. n: If ``n`` is an integer, random permutation of integers from :math:`0` to :math:`n - 1`.
  444. If ``n`` is an tensor, make a copy and shuffle the elements randomly.
  445. dtype: the output data type when ``n`` is an integer.
  446. int32, int16 and float32 are supported. Default: int32
  447. Returns:
  448. the output tensor.
  449. Examples:
  450. .. testcode::
  451. import numpy as np
  452. import megengine as mge
  453. import megengine.random as rand
  454. x = rand.permutation(10, dtype="int32")
  455. print(x.numpy())
  456. x = rand.permutation(10, dtype="float32")
  457. print(x.numpy())
  458. x = mge.tensor(np.arange(18)).reshape(6,3)
  459. x = rand.permutation(x)
  460. print(x.numpy())
  461. Outputs:
  462. .. testoutput::
  463. :options: +SKIP
  464. [4 5 0 7 3 8 6 1 9 2]
  465. [3. 4. 9. 0. 6. 8. 7. 1. 5. 2.]
  466. [[12 13 14]
  467. [ 3 4 5]
  468. [15 16 17]
  469. [ 0 1 2]
  470. [ 9 10 11]
  471. [ 6 7 8]]
  472. """
  473. _seed = self._seed() if callable(self._seed) else self._seed
  474. if isinstance(n, int):
  475. return _permutation(
  476. n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype
  477. )
  478. assert isinstance(n, Tensor)
  479. return _shuffle(inp=n, seed=_seed, handle=self._handle)
  480. def shuffle(self, inp: Tensor):
  481. r"""Modify a sequence in-place by shuffling its contents.
  482. This function only shuffles the Tensor along the first axis of a multi-dimensional Tensor.
  483. The order of sub-Tensors is changed but their contents remains the same.
  484. Args:
  485. inp: input tensor.
  486. Examples:
  487. .. testcode::
  488. import numpy as np
  489. import megengine as mge
  490. import megengine.random as rand
  491. x = mge.tensor(np.arange(10))
  492. rand.shuffle(x)
  493. print(x.numpy())
  494. y = mge.tensor(np.arange(18)).reshape(6,3)
  495. rand.shuffle(y)
  496. print(y.numpy())
  497. Outputs:
  498. .. testoutput::
  499. :options: +SKIP
  500. [7 9 3 0 8 2 4 5 6 1]
  501. [[12. 13. 14.]
  502. [ 3. 4. 5.]
  503. [15. 16. 17.]
  504. [ 0. 1. 2.]
  505. [ 9. 10. 11.]
  506. [ 6. 7. 8.]]
  507. """
  508. _seed = self._seed() if callable(self._seed) else self._seed
  509. inp._reset(_shuffle(inp=inp, seed=_seed, handle=self._handle))
  510. def __del__(self):
  511. if self._handle != 0:
  512. _delete_rng_handle(self._handle)
  513. def _default_rng():
  514. r"""Default constructor for :class:`RNG`."""
  515. return RNG(seed=None, device=None)
# Module-wide default generator, created without an explicit seed or device.
_default_handle = _default_rng()

# The module-level sampling functions listed in ``__all__`` are the bound
# methods of the default generator.
uniform = _default_handle.uniform
normal = _default_handle.normal
gamma = _default_handle.gamma
beta = _default_handle.beta
poisson = _default_handle.poisson
permutation = _default_handle.permutation
shuffle = _default_handle.shuffle
  524. def _random_seed_generator():
  525. assert _rng
  526. while True:
  527. yield _rng.random_raw()
  528. def seed(seed: int):
  529. global _rng # pylint: disable=global-statement
  530. _rng = MT19937(seed=seed)
  531. _set_global_rng_seed(seed)
# Seed once at import time from the current wall-clock time (whole seconds).
seed(int(time.time()))