# -*- coding: utf-8 -*-
import collections
import time
from typing import Iterable, Optional, Union

from numpy.random import MT19937

from .. import Tensor
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.core2 import sync as _sync
from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core._imperative_rt.ops import (
    get_rng_handle_compnode as _get_rng_handle_compnode,
)
from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle
from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed
from ..core.ops.builtin import (
    BetaRNG,
    GammaRNG,
    GaussianRNG,
    PermutationRNG,
    PoissonRNG,
    ShuffleRNG,
    UniformRNG,
)
from ..core.tensor import utils
from ..device import get_default_device

__all__ = [
    "seed",
    "RNG",
    "uniform",
    "normal",
    "gamma",
    "beta",
    "poisson",
    "permutation",
    "shuffle",
]

_rng = None


def _infer_broadcasted_shape(inps: Iterable[Tensor]) -> tuple:
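    """Infer the broadcasted shape of all tensors in ``inps``.

    Shapes are aligned from the trailing dimension, and each pair of sizes must
    either match or contain a 1 (numpy-style broadcasting); for example, shapes
    ``(3, 1)`` and ``(2,)`` broadcast to ``(3, 2)``.
    """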
    broadcasted_ndim = inps[0].ndim
    broadcasted_shape = list(inps[0]._tuple_shape)
    for i in range(1, len(inps)):
        cur_ndim = inps[i].ndim
        cur_shape = list(inps[i]._tuple_shape)
        n_dim = max(cur_ndim, broadcasted_ndim)
        for j in range(n_dim - 1, -1, -1):
            cur_dim = cur_ndim + j - n_dim
            broad_dim = broadcasted_ndim + j - n_dim
            cur_size = cur_shape[cur_dim] if cur_dim >= 0 else 1
            broad_size = broadcasted_shape[broad_dim] if broad_dim >= 0 else 1
            assert cur_size == broad_size or cur_size == 1 or broad_size == 1, (
                "The size of inps[{}] ({}) must match the size ({}) at "
                "dim {}".format(i, cur_size, broad_size, j)
            )
            broad_size = max(cur_size, broad_size)
            if broad_dim < 0:
                broadcasted_shape = [broad_size] + broadcasted_shape
                broadcasted_ndim += 1
            else:
                broadcasted_shape[broad_dim] = broad_size
    return tuple(broadcasted_shape)


def _broadcast_tensors_with_size(
    inps: Iterable[Tensor], size: Iterable[int]
) -> Iterable[Tensor]:
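    """Broadcast every tensor in ``inps`` to a common target shape.

    The target shape is the broadcasted shape of ``inps``; when ``size`` is an
    iterable of ints, it is prepended to that shape. Each tensor is padded with
    leading singleton dimensions and then broadcast to the target shape.
    """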
    assert inps, "The inps should not be empty"
    target_shape = _infer_broadcasted_shape(inps)
    if isinstance(size, collections.abc.Iterable):
        target_shape = tuple(size) + target_shape
    target_ndim = len(target_shape)
    for i in range(len(inps)):
        if inps[i]._tuple_shape != target_shape:
            inps[i] = (
                inps[i]
                .reshape((1,) * (target_ndim - inps[i].ndim) + inps[i]._tuple_shape)
                ._broadcast(target_shape)
            )
    return inps


def _uniform(
    low: float,
    high: float,
    size: Optional[Iterable[int]],
    seed: int,
    device: str,
    handle: int,
) -> Tensor:
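    """Draw samples from :math:`U(low, high)` with the given ``size``.

    ``UniformRNG`` itself produces values in :math:`U(0, 1)`; unless the
    default bounds are used, the result is rescaled as ``low + (high - low) * x``.
    """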
    assert low < high, "Uniform is not defined when low >= high"
    if size is None:
        size = (1,)
    op = UniformRNG(seed=seed, handle=handle, dtype="float32")
    _ref = Tensor([], dtype="int32", device=device)
    shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
    (output,) = apply(op, shape)
    if low == 0 and high == 1:
        return output
    return low + (high - low) * output


def _normal(
    mean: float,
    std: float,
    size: Optional[Iterable[int]],
    seed: int,
    device: str,
    handle: int,
) -> Tensor:
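    """Draw samples from :math:`N(mean, std^2)` with the given ``size`` via ``GaussianRNG``."""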
    if size is None:
        size = (1,)
    op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle, dtype="float32")
    _ref = Tensor([], dtype="int32", device=device)
    shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
    (output,) = apply(op, shape)
    return output


def _gamma(
    shape: Union[Tensor, float],
    scale: Union[Tensor, float],
    size: Optional[Iterable[int]],
    seed: int,
    handle: int,
) -> Tensor:
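    """Draw samples from a Gamma distribution with the given ``shape`` and ``scale``.

    Scalar parameters are first lifted to float32 tensors on the handle's
    device; both parameters are then broadcast against each other (and against
    ``size``, if given) before ``GammaRNG`` is applied.
    """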
    handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle)
    if not isinstance(shape, Tensor):
        assert shape > 0, "Gamma is not defined when shape <= 0"
        shape = Tensor(shape, dtype="float32", device=handle_cn)
    if not isinstance(scale, Tensor):
        assert scale > 0, "Gamma is not defined when scale <= 0"
        scale = Tensor(scale, dtype="float32", device=handle_cn)
    assert (
        handle_cn is None or handle_cn == shape.device
    ), "The shape ({}) must be on the same device as the handle ({})".format(
        shape.device, handle_cn
    )
    assert (
        handle_cn is None or handle_cn == scale.device
    ), "The scale ({}) must be on the same device as the handle ({})".format(
        scale.device, handle_cn
    )
    if isinstance(size, int) and size != 0:
        size = (size,)
    shape, scale = _broadcast_tensors_with_size([shape, scale], size)
    op = GammaRNG(seed=seed, handle=handle)
    (output,) = apply(op, shape, scale)
    return output


def _beta(
    alpha: Union[Tensor, float],
    beta: Union[Tensor, float],
    size: Optional[Iterable[int]],
    seed: int,
    handle: int,
) -> Tensor:
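    """Draw samples from a Beta distribution; parameter handling mirrors ``_gamma``."""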
    handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle)
    if not isinstance(alpha, Tensor):
        assert alpha > 0, "Beta is not defined when alpha <= 0"
        alpha = Tensor(alpha, dtype="float32", device=handle_cn)
    if not isinstance(beta, Tensor):
        assert beta > 0, "Beta is not defined when beta <= 0"
        beta = Tensor(beta, dtype="float32", device=handle_cn)
    assert (
        handle_cn is None or handle_cn == alpha.device
    ), "The alpha ({}) must be on the same device as the handle ({})".format(
        alpha.device, handle_cn
    )
    assert (
        handle_cn is None or handle_cn == beta.device
    ), "The beta ({}) must be on the same device as the handle ({})".format(
        beta.device, handle_cn
    )
    if isinstance(size, int) and size != 0:
        size = (size,)
    alpha, beta = _broadcast_tensors_with_size([alpha, beta], size)
    op = BetaRNG(seed=seed, handle=handle)
    (output,) = apply(op, alpha, beta)
    return output


def _poisson(
    lam: Union[Tensor, float], size: Optional[Iterable[int]], seed: int, handle: int
) -> Tensor:
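    """Draw samples from a Poisson distribution; a scalar ``lam`` is lifted to a tensor first."""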
    handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle)
    if not isinstance(lam, Tensor):
        assert lam > 0, "Poisson is not defined when lam <= 0"
        lam = Tensor(lam, dtype="float32", device=handle_cn)
    if isinstance(size, int) and size != 0:
        size = (size,)
    assert (
        handle_cn is None or handle_cn == lam.device
    ), "The lam ({}) must be on the same device as the handle ({})".format(
        lam.device, handle_cn
    )
    (lam,) = _broadcast_tensors_with_size([lam], size)
    op = PoissonRNG(seed=seed, handle=handle)
    (output,) = apply(op, lam)
    return output


def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Tensor:
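    """Return a random permutation of ``range(n)`` as a 1-D tensor of ``dtype``."""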
    assert isinstance(n, int)
    assert n >= 0, "Permutation is not defined when n < 0"
    size = (n,)
    op = PermutationRNG(seed=seed, handle=handle, dtype=dtype)
    _ref = Tensor([], dtype="int32", device=device)
    shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
    (output,) = apply(op, shape)
    return output


def _shuffle(inp: Tensor, seed: int, handle: int) -> Tensor:
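    """Return a copy of ``inp`` shuffled along its first axis.

    ``ShuffleRNG`` produces a second output as well, which is discarded here.
    """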
    assert inp.size > 0, "size needs to be greater than 0"
    op = ShuffleRNG(seed=seed, handle=handle)
    output, _ = apply(op, inp)
    return output


class RNG:
    r""":class:`RNG` exposes a number of methods for generating random numbers.

    Args:
        seed: random seed used to initialize the pseudo-random number generator. Default: None
        device: the device of generated tensor. Default: None

    Examples:
        >>> import megengine.random as rand
        >>> rng = rand.RNG(seed=100)
        >>> x = rng.uniform(size=(2, 2))
        >>> x.numpy()   # doctest: +SKIP
        array([[0.84811664, 0.6147553 ],
               [0.59429836, 0.64727545]], dtype=float32)
    """

    def __init__(self, seed: int = None, device: str = None):
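        # With an explicit seed, a dedicated RNG handle is created on the chosen
        # device; otherwise the global handle (0) is used and its seed is queried
        # lazily via _get_global_rng_seed at sampling time.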
        self._device = device if device else get_default_device()
        if seed is not None:
            self._seed = seed
            self._handle = _new_rng_handle(self._device, self._seed)
        else:
            self._seed = _get_global_rng_seed
            self._handle = 0
            self._device = None

    def uniform(
        self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
    ):
        r"""Random variable with uniform distribution :math:`U(low, high)`.

        Args:
            low: lower range. Default: 0
            high: upper range. Default: 1
            size: the size of output tensor. Default: None

        Returns:
            the output tensor.

        Examples:
            >>> import megengine.random as rand
            >>> x = rand.uniform(size=(2, 2))
            >>> x.numpy()   # doctest: +SKIP
            array([[0.28603864, 0.3156649 ],
                   [0.42066026, 0.9805052 ]], dtype=float32)
        """
        _seed = self._seed() if callable(self._seed) else self._seed
        return _uniform(
            low=low,
            high=high,
            size=size,
            seed=_seed,
            device=self._device,
            handle=self._handle,
        )

    def normal(
        self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
    ):
        r"""Random variable with Gaussian distribution :math:`N(\mu, \sigma)`.

        Args:
            mean: the mean or expectation of the distribution. Default: 0
            std: the standard deviation of the distribution (variance = :math:`\sigma ^ 2`).
                Default: 1
            size: the size of output tensor. Default: None

        Returns:
            the output tensor.

        Examples:
            >>> import megengine.random as rand
            >>> x = rand.normal(mean=0, std=1, size=(2, 2))
            >>> x.numpy()   # doctest: +SKIP
            array([[ 1.5534291 , -0.28356555],
                   [ 2.2230418 , -0.92425716]], dtype=float32)
        """
        _seed = self._seed() if callable(self._seed) else self._seed
        return _normal(
            mean=mean,
            std=std,
            size=size,
            seed=_seed,
            device=self._device,
            handle=self._handle,
        )

    def gamma(
        self,
        shape: Union[Tensor, float],
        scale: Union[Tensor, float] = 1,
        size: Optional[Iterable[int]] = None,
    ):
        r"""Random variable with Gamma distribution :math:`\Gamma(k, \theta)`.

        The corresponding probability density function is

        .. math::

            p(x) = x^{k-1} \frac{e^{-x / \theta}}{\theta^{k} \Gamma(k)}
            \quad \text{ for } x > 0, \quad k, \theta > 0,

        where :math:`\Gamma(k)` is the gamma function,

        .. math::

            \Gamma(k) = (k-1)! \quad \text{ for } \quad k > 0.

        Args:
            shape: the shape parameter (sometimes designated "k") of the distribution.
                Must be positive.
            scale: the scale parameter (sometimes designated "theta") of the distribution.
                Must be positive. Default: 1
            size: the size of output tensor. If shape and scale are scalars and given size is, e.g.,
                `(m, n)`, then the output shape is `(m, n)`. If shape or scale is a Tensor and given size
                is, e.g., `(m, n)`, then the output shape is `(m, n) + broadcast(shape, scale).shape`.
                The broadcast rules are consistent with `numpy.broadcast`. Default: None

        Returns:
            the output tensor.

        Examples:
            >>> import megengine.random as rand
            >>> x = rand.gamma(shape=2, scale=1, size=(2, 2))
            >>> x.numpy()   # doctest: +SKIP
            array([[0.97447544, 1.5668875 ],
                   [1.0069491 , 0.3078318 ]], dtype=float32)
            >>> shape = mge.Tensor([[ 1],
            ...                     [10]], dtype="float32")
            >>> scale = mge.Tensor([1, 5], dtype="float32")
            >>> x = rand.gamma(shape=shape, scale=scale)
            >>> x.numpy()   # doctest: +SKIP
            array([[ 0.11312152,  3.0799196 ],
                   [10.973469  , 29.596972  ]], dtype=float32)
            >>> x = rand.gamma(shape=shape, scale=scale, size=2)
            >>> x.numpy()   # doctest: +SKIP
            array([[[4.35868073e+00, 1.22415285e+01],
                    [1.02696848e+01, 4.19773598e+01]],
                   [[7.73875117e-02, 6.06766164e-01],
                    [1.22881927e+01, 8.13445740e+01]]], dtype=float32)
        """
        _seed = self._seed() if callable(self._seed) else self._seed
        return _gamma(
            shape=shape, scale=scale, size=size, seed=_seed, handle=self._handle
        )

    def beta(
        self,
        alpha: Union[Tensor, float],
        beta: Union[Tensor, float],
        size: Optional[Iterable[int]] = None,
    ):
        r"""Random variable with Beta distribution :math:`\operatorname{Beta}(\alpha, \beta)`.

        The corresponding probability density function is

        .. math::

            p(x) = \frac{1}{\mathrm{B}(\alpha, \beta)} x^{\alpha-1} (1-x)^{\beta-1}
            \quad \text{ for } \alpha, \beta > 0,

        where :math:`\mathrm{B}(\alpha, \beta)` is the beta function,

        .. math::

            \mathrm{B}(\alpha, \beta) = \int_{0}^{1} t^{\alpha-1} (1-t)^{\beta-1} \, dt.

        Args:
            alpha: the alpha parameter of the distribution. Must be positive.
            beta: the beta parameter of the distribution. Must be positive.
            size: the size of output tensor. If alpha and beta are scalars and given size is, e.g.,
                `(m, n)`, then the output shape is `(m, n)`. If alpha or beta is a Tensor and given size
                is, e.g., `(m, n)`, then the output shape is `(m, n) + broadcast(alpha, beta).shape`.
                Default: None

        Returns:
            the output tensor.

        Examples:
            >>> import megengine.random as rand
            >>> x = rand.beta(alpha=2, beta=1, size=(2, 2))
            >>> x.numpy()   # doctest: +SKIP
            array([[0.6172312 , 0.9789006 ],
                   [0.50004643, 0.9775796 ]], dtype=float32)
            >>> alpha = mge.Tensor([[0.5],
            ...                     [  3]], dtype="float32")
            >>> beta = mge.Tensor([0.5, 5], dtype="float32")
            >>> x = rand.beta(alpha=alpha, beta=beta)
            >>> x.numpy()   # doctest: +SKIP
            array([[0.0075407 , 0.1275094 ],
                   [0.96331763, 0.22299217]], dtype=float32)
            >>> x = rand.beta(alpha=alpha, beta=beta, size=2)
            >>> x.numpy()   # doctest: +SKIP
            array([[[0.46863747, 0.13819647],
                    [0.8646759 , 0.16014215]],
                   [[0.0682759 , 0.04448463],
                    [0.97733796, 0.19206746]]], dtype=float32)
        """
        _seed = self._seed() if callable(self._seed) else self._seed
        return _beta(alpha=alpha, beta=beta, size=size, seed=_seed, handle=self._handle)

    def poisson(self, lam: Union[float, Tensor], size: Optional[Iterable[int]] = None):
        r"""Random variable with Poisson distribution :math:`\operatorname{Poisson}(\lambda)`.

        The corresponding probability mass function is

        .. math::

            f(k; \lambda) = \frac{\lambda^{k} e^{-\lambda}}{k!},

        where :math:`k` is the number of occurrences :math:`(k = 0, 1, 2, \dots)`.

        Args:
            lam: the lambda parameter of the distribution. Must be positive.
            size: the size of output tensor. If lam is a scalar and given size is, e.g., `(m, n)`,
                then the output shape is `(m, n)`. If lam is a Tensor with shape `(k, v)` and given
                size is, e.g., `(m, n)`, then the output shape is `(m, n, k, v)`. Default: None.

        Returns:
            the output tensor.

        Examples:
            >>> import megengine.random as rand
            >>> x = rand.poisson(lam=2., size=(1, 3))
            >>> x.numpy()   # doctest: +SKIP
            array([[1., 2., 2.]], dtype=float32)
            >>> lam = mge.Tensor([[ 1.,  1.],
            ...                   [10., 10.]], dtype="float32")
            >>> x = rand.poisson(lam=lam)
            >>> x.numpy()   # doctest: +SKIP
            array([[ 1.,  2.],
                   [11., 11.]], dtype=float32)
            >>> x = rand.poisson(lam=lam, size=(1, 3))
            >>> x.numpy()   # doctest: +SKIP
            array([[[[ 2.,  1.],
                     [10.,  8.]],
                    [[ 5.,  2.],
                     [10., 10.]],
                    [[ 1.,  2.],
                     [ 8., 10.]]]], dtype=float32)
        """
        _seed = self._seed() if callable(self._seed) else self._seed
        return _poisson(lam=lam, size=size, seed=_seed, handle=self._handle)

    def permutation(self, n: Union[int, Tensor], *, dtype: str = "int32"):
        r"""Randomly permute a sequence, or return a permuted range.
        If ``n`` is a multi-dimensional tensor, it is only shuffled along its first index.

        Args:
            n: If ``n`` is an integer, random permutation of integers from :math:`0` to :math:`n - 1`.
                If ``n`` is a tensor, make a copy and shuffle the elements randomly.
            dtype: the output data type when ``n`` is an integer.
                int32, int16 and float32 are supported. Default: int32

        Returns:
            the output tensor.

        Examples:
            >>> import numpy as np
            >>> import megengine.random as rand
            >>> x = rand.permutation(10, dtype="int32")
            >>> x.numpy()   # doctest: +SKIP
            array([8, 4, 0, 3, 5, 6, 2, 1, 7, 9], dtype=int32)
            >>> x = rand.permutation(10, dtype="float32")
            >>> x.numpy()   # doctest: +SKIP
            array([1., 3., 0., 2., 4., 8., 7., 9., 6., 5.], dtype=float32)
            >>> x = mge.tensor(np.arange(18)).reshape(6, 3)
            >>> x = rand.permutation(x)
            >>> x.numpy()   # doctest: +SKIP
            array([[15, 16, 17],
                   [ 6,  7,  8],
                   [ 0,  1,  2],
                   [ 3,  4,  5],
                   [12, 13, 14],
                   [ 9, 10, 11]], dtype=int32)
        """
        _seed = self._seed() if callable(self._seed) else self._seed
        if isinstance(n, int):
            return _permutation(
                n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype
            )
        assert isinstance(n, Tensor)
        return _shuffle(inp=n, seed=_seed, handle=self._handle)

    def shuffle(self, inp: Tensor):
        r"""Modify a sequence in-place by shuffling its contents.
        This function only shuffles the tensor along the first axis of a multi-dimensional tensor.
        The order of sub-tensors is changed but their contents remain the same.

        Args:
            inp: input tensor.

        Examples:
            >>> import numpy as np
            >>> import megengine.random as rand
            >>> x = mge.tensor(np.arange(10))
            >>> rand.shuffle(x)
            >>> x.numpy()   # doctest: +SKIP
            array([4, 5, 9, 6, 2, 8, 1, 0, 3, 7], dtype=int32)
            >>> y = mge.tensor(np.arange(18)).reshape(6, 3)
            >>> rand.shuffle(y)
            >>> y.numpy()   # doctest: +SKIP
            array([[ 3,  4,  5],
                   [ 6,  7,  8],
                   [15, 16, 17],
                   [ 0,  1,  2],
                   [12, 13, 14],
                   [ 9, 10, 11]], dtype=int32)
        """
        _seed = self._seed() if callable(self._seed) else self._seed
        inp._reset(_shuffle(inp=inp, seed=_seed, handle=self._handle))

    def __del__(self):
        if self._handle != 0:
            # RNG ops might still be executing after the handle is released due to
            # async dispatch, so sync before deleting the handle to avoid a memory
            # leak or use-after-free.
            _sync()
            _delete_rng_handle(self._handle)
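

# Illustrative usage sketch: a dedicated ``RNG`` instance draws from its own
# handle and seed, while the module-level helpers defined below share the
# global default handle (and therefore the global seed set by ``seed()``):
#
#     rng = RNG(seed=100)
#     a = rng.uniform(size=(3,))   # this instance's handle
#     b = uniform(size=(3,))       # global default handle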


def _default_rng():
    r"""Default constructor for :class:`RNG`."""
    return RNG(seed=None, device=None)


_default_handle = _default_rng()

uniform = _default_handle.uniform
normal = _default_handle.normal
gamma = _default_handle.gamma
beta = _default_handle.beta
poisson = _default_handle.poisson
permutation = _default_handle.permutation
shuffle = _default_handle.shuffle


def _random_seed_generator():
    assert _rng
    while True:
        yield _rng.random_raw()


def seed(seed: int):
    global _rng  # pylint: disable=global-statement
    _rng = MT19937(seed=seed)
    _set_global_rng_seed(seed)


seed(int(time.time()))
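
# Note: the module seeds the global generator with the wall-clock time at
# import; call ``seed()`` with a fixed integer when reproducibility matters,
# for example:
#
#     seed(42)                 # illustrative fixed seed
#     x = uniform(size=(2, 2))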