You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

rng.py 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import time
from typing import Iterable, Optional, Sequence, Union

from numpy.random import MT19937

from .. import Tensor
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core._imperative_rt.ops import (
    get_rng_handle_compnode as _get_rng_handle_compnode,
)
from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle
from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed
from ..core.ops.builtin import (
    BetaRNG,
    GammaRNG,
    GaussianRNG,
    PermutationRNG,
    PoissonRNG,
    ShuffleRNG,
    UniformRNG,
)
from ..core.tensor import utils
from ..device import get_default_device
# Public API of this module. Apart from ``seed`` and ``RNG``, the names are
# bound methods of the module's default RNG, assigned further down the file.
__all__ = [
    "seed",
    "RNG",
    "uniform",
    "normal",
    "gamma",
    "beta",
    "poisson",
    "permutation",
    "shuffle",
]
# numpy ``MT19937`` bit generator, (re)created by ``seed()``. It is None only
# until ``seed()`` runs, which happens at import time at the end of this module.
_rng = None
  45. def _infer_broadcasted_shape(inps: Iterable[Tensor]) -> tuple:
  46. broadcasted_ndim = inps[0].ndim
  47. broadcasted_shape = list(inps[0]._tuple_shape)
  48. for i in range(1, len(inps)):
  49. cur_ndim = inps[i].ndim
  50. cur_shape = list(inps[i]._tuple_shape)
  51. n_dim = max(cur_ndim, broadcasted_ndim)
  52. for j in range(n_dim - 1, -1, -1):
  53. cur_dim = cur_ndim + j - n_dim
  54. broad_dim = broadcasted_ndim + j - n_dim
  55. cur_size = cur_shape[cur_dim] if cur_dim >= 0 else 1
  56. broad_size = broadcasted_shape[broad_dim] if broad_dim >= 0 else 1
  57. assert cur_size == broad_size or cur_size == 1 or broad_size == 1, (
  58. "The size of inps[{}] ({}) must match the size ({}) at "
  59. "dim {}".format(i, cur_size, broad_size, j)
  60. )
  61. broad_size = max(cur_size, broad_size)
  62. if broad_dim < 0:
  63. broadcasted_shape = [broad_size] + broadcasted_shape
  64. broadcasted_ndim += 1
  65. else:
  66. broadcasted_shape[broad_dim] = broad_size
  67. return tuple(broadcasted_shape)
  68. def _broadcast_tensors_with_size(
  69. inps: Iterable[Tensor], size: Iterable[int]
  70. ) -> Iterable[Tensor]:
  71. assert inps, "The inps cloud not be empty"
  72. target_shape = _infer_broadcasted_shape(inps)
  73. if isinstance(size, collections.abc.Iterable):
  74. target_shape = tuple(size) + target_shape
  75. target_ndim = len(target_shape)
  76. for i in range(len(inps)):
  77. if inps[i]._tuple_shape != target_shape:
  78. inps[i] = (
  79. inps[i]
  80. .reshape((1,) * (target_ndim - inps[i].ndim) + inps[i]._tuple_shape)
  81. ._broadcast(target_shape)
  82. )
  83. return inps
  84. def _uniform(
  85. low: float,
  86. high: float,
  87. size: Optional[Iterable[int]],
  88. seed: int,
  89. device: str,
  90. handle: int,
  91. ) -> Tensor:
  92. assert low < high, "Uniform is not defined when low >= high"
  93. if size is None:
  94. size = (1,)
  95. op = UniformRNG(seed=seed, handle=handle, dtype="float32")
  96. _ref = Tensor([], dtype="int32", device=device)
  97. shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
  98. (output,) = apply(op, shape)
  99. if low == 0 and high == 1:
  100. return output
  101. return low + (high - low) * output
  102. def _normal(
  103. mean: float,
  104. std: float,
  105. size: Optional[Iterable[int]],
  106. seed: int,
  107. device: str,
  108. handle: int,
  109. ) -> Tensor:
  110. if size is None:
  111. size = (1,)
  112. op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle, dtype="float32")
  113. _ref = Tensor([], dtype="int32", device=device)
  114. shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
  115. (output,) = apply(op, shape)
  116. return output
  117. def _gamma(
  118. shape: Union[Tensor, float],
  119. scale: Union[Tensor, float],
  120. size: Optional[Iterable[int]],
  121. seed: int,
  122. handle: int,
  123. ) -> Tensor:
  124. handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle)
  125. if not isinstance(shape, Tensor):
  126. assert shape > 0, "Gamma is not defined when shape <= 0"
  127. shape = Tensor(shape, dtype="float32", device=handle_cn)
  128. if not isinstance(scale, Tensor):
  129. assert scale > 0, "Gamma is not defined when scale <= 0"
  130. scale = Tensor(scale, dtype="float32", device=handle_cn)
  131. assert (
  132. handle_cn is None or handle_cn == shape.device
  133. ), "The shape ({}) must be the same device with handle ({})".format(
  134. shape.device, handle_cn
  135. )
  136. assert (
  137. handle_cn is None or handle_cn == scale.device
  138. ), "The scale ({}) must be the same device with handle ({})".format(
  139. scale.device, handle_cn
  140. )
  141. if isinstance(size, int) and size != 0:
  142. size = (size,)
  143. shape, scale = _broadcast_tensors_with_size([shape, scale], size)
  144. op = GammaRNG(seed=seed, handle=handle)
  145. (output,) = apply(op, shape, scale)
  146. return output
  147. def _beta(
  148. alpha: Union[Tensor, float],
  149. beta: Union[Tensor, float],
  150. size: Optional[Iterable[int]],
  151. seed: int,
  152. handle: int,
  153. ) -> Tensor:
  154. handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle)
  155. if not isinstance(alpha, Tensor):
  156. assert alpha > 0, "Beta is not defined when alpha <= 0"
  157. alpha = Tensor(alpha, dtype="float32", device=handle_cn)
  158. if not isinstance(beta, Tensor):
  159. assert beta > 0, "Beta is not defined when beta <= 0"
  160. beta = Tensor(beta, dtype="float32", device=handle_cn)
  161. assert (
  162. handle_cn is None or handle_cn == alpha.device
  163. ), "The alpha ({}) must be the same device with handle ({})".format(
  164. alpha.device, handle_cn
  165. )
  166. assert (
  167. handle_cn is None or handle_cn == beta.device
  168. ), "The beta ({}) must be the same device with handle ({})".format(
  169. beta.device, handle_cn
  170. )
  171. if isinstance(size, int) and size != 0:
  172. size = (size,)
  173. alpha, beta = _broadcast_tensors_with_size([alpha, beta], size)
  174. op = BetaRNG(seed=seed, handle=handle)
  175. (output,) = apply(op, alpha, beta)
  176. return output
  177. def _poisson(
  178. lam: Union[Tensor, float], size: Optional[Iterable[int]], seed: int, handle: int
  179. ) -> Tensor:
  180. handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle)
  181. if not isinstance(lam, Tensor):
  182. assert lam > 0, "Poisson is not defined when lam <= 0"
  183. lam = Tensor(lam, dtype="float32", device=handle_cn)
  184. if isinstance(size, int) and size != 0:
  185. size = (size,)
  186. assert (
  187. handle_cn is None or handle_cn == lam.device
  188. ), "The lam ({}) must be the same device with handle ({})".format(
  189. lam.device, handle_cn
  190. )
  191. (lam,) = _broadcast_tensors_with_size([lam], size)
  192. op = PoissonRNG(seed=seed, handle=handle)
  193. (output,) = apply(op, lam)
  194. return output
  195. def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Tensor:
  196. assert isinstance(n, int)
  197. assert n >= 0, "Permutation is not defined when n < 0"
  198. size = (n,)
  199. op = PermutationRNG(seed=seed, handle=handle, dtype=dtype)
  200. _ref = Tensor([], dtype="int32", device=device)
  201. shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
  202. (output,) = apply(op, shape)
  203. return output
  204. def _shuffle(inp: Tensor, seed: int, handle: int) -> Tensor:
  205. assert inp.size > 0, "size needs to be greater than 0"
  206. op = ShuffleRNG(seed=seed, handle=handle)
  207. output, _ = apply(op, inp)
  208. return output
  209. class RNG:
  210. r""":class:`RNG` exposes a number of methods for generating random numbers.
  211. Args:
  212. seed: random seed used to initialize the pseudo-random number generator. Default: None
  213. device: the device of generated tensor. Default: None
  214. Examples:
  215. .. testcode::
  216. import megengine.random as rand
  217. rng = rand.RNG(seed=100)
  218. x = rng.uniform(size=(2, 2))
  219. print(x.numpy())
  220. Outputs:
  221. .. testoutput::
  222. :options: +SKIP
  223. [[0.84811664 0.6147553 ]
  224. [0.59429836 0.64727545]]
  225. """
  226. def __init__(self, seed: int = None, device: str = None):
  227. self._device = device if device else get_default_device()
  228. if seed is not None:
  229. self._seed = seed
  230. self._handle = _new_rng_handle(self._device, self._seed)
  231. else:
  232. self._seed = _get_global_rng_seed
  233. self._handle = 0
  234. self._device = None
  235. def uniform(
  236. self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
  237. ):
  238. r"""Random variable with uniform distribution $U(0, 1)$.
  239. Args:
  240. low: lower range. Default: 0
  241. high: upper range. Default: 1
  242. size: the size of output tensor. Default: None
  243. Returns:
  244. the output tensor.
  245. Examples:
  246. .. testcode::
  247. import megengine as mge
  248. import megengine.random as rand
  249. x = rand.uniform(size=(2, 2))
  250. print(x.numpy())
  251. Outputs:
  252. .. testoutput::
  253. :options: +SKIP
  254. [[0.91600335 0.6680226 ]
  255. [0.2046729 0.2769141 ]]
  256. """
  257. _seed = self._seed() if callable(self._seed) else self._seed
  258. return _uniform(
  259. low=low,
  260. high=high,
  261. size=size,
  262. seed=_seed,
  263. device=self._device,
  264. handle=self._handle,
  265. )
  266. def normal(
  267. self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
  268. ):
  269. r"""Random variable with Gaussian distribution :math:`N(\mu, \sigma)`.
  270. Args:
  271. mean: the mean or expectation of the distribution. Default: 0
  272. std: the standard deviation of the distribution (variance = :math:`\sigma ^ 2`).
  273. Default: 1
  274. size: the size of output tensor. Default: None
  275. Returns:
  276. the output tensor.
  277. Examples:
  278. .. testcode::
  279. import megengine as mge
  280. import megengine.random as rand
  281. x = rand.normal(mean=0, std=1, size=(2, 2))
  282. print(x.numpy())
  283. Outputs:
  284. .. testoutput::
  285. :options: +SKIP
  286. [[-1.4010863 -0.9874344 ]
  287. [ 0.56373274 0.79656655]]
  288. """
  289. _seed = self._seed() if callable(self._seed) else self._seed
  290. return _normal(
  291. mean=mean,
  292. std=std,
  293. size=size,
  294. seed=_seed,
  295. device=self._device,
  296. handle=self._handle,
  297. )
  298. def gamma(
  299. self,
  300. shape: Union[Tensor, float],
  301. scale: Union[Tensor, float] = 1,
  302. size: Optional[Iterable[int]] = None,
  303. ):
  304. r"""Random variable with Gamma distribution :math:`\Gamma(k, \theta)`.
  305. The corresponding probability density function is
  306. .. math::
  307. p(x)=x^{k-1} \frac{e^{-x / \theta}}{\theta^{k} \Gamma(k)}
  308. \quad \text { for } x>0 \quad k, \theta>0,
  309. where :math:`\Gamma(k)` is the gamma function,
  310. .. math::
  311. \Gamma(k)=(k-1) ! \quad \text { for } \quad k>0.
  312. Args:
  313. shape: the shape parameter (sometimes designated "k") of the distribution.
  314. Must be non-negative.
  315. scale: the scale parameter (sometimes designated "theta") of the distribution.
  316. Must be non-negative. Default: 1
  317. size: the size of output tensor. If shape and scale are scalars and given size is, e.g.,
  318. `(m, n)`, then the output shape is `(m, n)`. If shape or scale is a Tensor and given size
  319. is, e.g., `(m, n)`, then the output shape is `(m, n) + broadcast(shape, scale).shape`.
  320. The broadcast rules are consistent with `numpy.broadcast`. Default: None
  321. Returns:
  322. the output tensor.
  323. Examples:
  324. .. testcode::
  325. import megengine as mge
  326. import megengine.random as rand
  327. x = rand.gamma(shape=2, scale=1, size=(2, 2))
  328. print(x.numpy())
  329. shape = mge.Tensor([[ 1],
  330. [10]], dtype="float32")
  331. scale = mge.Tensor([1,5], dtype="float32")
  332. x = rand.gamma(shape=shape, scale=scale)
  333. print(x.numpy())
  334. x = rand.gamma(shape=shape, scale=scale, size=2)
  335. print(x.numpy())
  336. Outputs:
  337. .. testoutput::
  338. :options: +SKIP
  339. [[1.5064533 4.0689363 ]
  340. [0.71639484 1.4551026 ]]
  341. [[ 0.4352188 11.399335 ]
  342. [ 9.1888 52.009277 ]]
  343. [[[ 1.1726005 3.9654975 ]
  344. [13.656933 36.559006 ]]
  345. [[ 0.25848487 2.5540342 ]
  346. [11.960409 21.031536 ]]]
  347. """
  348. _seed = self._seed() if callable(self._seed) else self._seed
  349. return _gamma(
  350. shape=shape, scale=scale, size=size, seed=_seed, handle=self._handle
  351. )
  352. def beta(
  353. self,
  354. alpha: Union[Tensor, float],
  355. beta: Union[Tensor, float],
  356. size: Optional[Iterable[int]] = None,
  357. ):
  358. r"""Random variable with Beta distribution :math:`\operatorname{Beta}(\alpha, \beta)`.
  359. The corresponding probability density function is
  360. .. math::
  361. p(x)=\frac{1}{\mathrm{~B}(\alpha, \beta)} x^{\alpha-1}(1-x)^{\beta-1}
  362. \quad \text { for } \alpha, \beta>0,
  363. where :math:`\mathrm{~B}(\alpha, \beta)` is the beta function,
  364. .. math::
  365. \mathrm{~B}(\alpha, \beta)=\int_{0}^{1} t^{\alpha-1}(1-t)^{\beta-1} d t.
  366. Args:
  367. alpha: the alpha parameter of the distribution. Must be non-negative.
  368. beta: the beta parameter of the distribution. Must be non-negative.
  369. size: the size of output tensor. If alpha and beta are scalars and given size is, e.g.,
  370. `(m, n)`, then the output shape is `(m, n)`. If alpha or beta is a Tensor and given size
  371. is, e.g., `(m, n)`, then the output shape is `(m, n) + broadcast(alpha, beta).shape`.
  372. Returns:
  373. the output tensor.
  374. Examples:
  375. .. testcode::
  376. import megengine as mge
  377. import megengine.random as rand
  378. x = rand.beta(alpha=2, beta=1, size=(2, 2))
  379. print(x.numpy())
  380. alpha = mge.Tensor([[0.5],
  381. [ 3]], dtype="float32")
  382. beta = mge.Tensor([0.5,5], dtype="float32")
  383. x = rand.beta(alpha=alpha, beta=beta)
  384. print(x.numpy())
  385. x = rand.beta(alpha=alpha, beta=beta, size=2)
  386. print(x.numpy())
  387. Outputs:
  388. .. testoutput::
  389. :options: +SKIP
  390. [[0.582565 0.91763186]
  391. [0.86963767 0.6088103 ]]
  392. [[0.41503012 0.16438372]
  393. [0.90159506 0.47588003]]
  394. [[[0.55195075 0.01111084]
  395. [0.95298755 0.25048104]]
  396. [[0.11680304 0.13859665]
  397. [0.997879 0.43259275]]]
  398. """
  399. _seed = self._seed() if callable(self._seed) else self._seed
  400. return _beta(alpha=alpha, beta=beta, size=size, seed=_seed, handle=self._handle)
  401. def poisson(self, lam: Union[float, Tensor], size: Optional[Iterable[int]] = None):
  402. r"""Random variable with poisson distribution :math:`\operatorname{Poisson}(\lambda)`.
  403. The corresponding probability density function is
  404. .. math::
  405. f(k ; \lambda)=\frac{\lambda^{k} e^{-\lambda}}{k !},
  406. where k is the number of occurrences :math:`({\displaystyle k=0,1,2...})`.
  407. Args:
  408. lam: the lambda parameter of the distribution. Must be non-negative.
  409. size: the size of output tensor. If lam is a scalar and given size is, e.g., `(m, n)`,
  410. then the output shape is `(m, n)`. If lam is a Tensor with shape `(k, v)` and given
  411. size is, e.g., `(m, n)`, then the output shape is `(m, n, k, v)`. Default: None.
  412. Returns:
  413. the output tensor.
  414. Examples:
  415. .. testcode::
  416. import megengine as mge
  417. import megengine.random as rand
  418. x = rand.poisson(lam=2., size=(1, 3))
  419. print(x.numpy())
  420. lam = mge.Tensor([[1.,1.],
  421. [10,10]], dtype="float32")
  422. x = rand.poisson(lam=lam)
  423. print(x.numpy())
  424. x = rand.poisson(lam=lam, size=(1,3))
  425. print(x.numpy())
  426. Outputs:
  427. .. testoutput::
  428. :options: +SKIP
  429. [[3. 1. 3.]]
  430. [[ 2. 2.]
  431. [12. 11.]]
  432. [[[[ 1. 1.]
  433. [11. 4.]]
  434. [[ 0. 0.]
  435. [ 9. 13.]]
  436. [[ 0. 1.]
  437. [ 7. 12.]]]]
  438. """
  439. _seed = self._seed() if callable(self._seed) else self._seed
  440. return _poisson(lam=lam, size=size, seed=_seed, handle=self._handle)
  441. def permutation(self, n: Union[int, Tensor], *, dtype: str = "int32"):
  442. r"""Randomly permute a sequence, or return a permuted range.
  443. If ``n`` is a multi-dimensional tensor, it is only shuffled along its first index.
  444. Args:
  445. n: If ``n`` is an integer, random permutation of integers from :math:`0` to :math:`n - 1`.
  446. If ``n`` is an tensor, make a copy and shuffle the elements randomly.
  447. dtype: the output data type when ``n`` is an integer.
  448. int32, int16 and float32 are supported. Default: int32
  449. Returns:
  450. the output tensor.
  451. Examples:
  452. .. testcode::
  453. import numpy as np
  454. import megengine as mge
  455. import megengine.random as rand
  456. x = rand.permutation(10, dtype="int32")
  457. print(x.numpy())
  458. x = rand.permutation(10, dtype="float32")
  459. print(x.numpy())
  460. x = mge.tensor(np.arange(18)).reshape(6,3)
  461. x = rand.permutation(x)
  462. print(x.numpy())
  463. Outputs:
  464. .. testoutput::
  465. :options: +SKIP
  466. [4 5 0 7 3 8 6 1 9 2]
  467. [3. 4. 9. 0. 6. 8. 7. 1. 5. 2.]
  468. [[12 13 14]
  469. [ 3 4 5]
  470. [15 16 17]
  471. [ 0 1 2]
  472. [ 9 10 11]
  473. [ 6 7 8]]
  474. """
  475. _seed = self._seed() if callable(self._seed) else self._seed
  476. if isinstance(n, int):
  477. return _permutation(
  478. n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype
  479. )
  480. assert isinstance(n, Tensor)
  481. return _shuffle(inp=n, seed=_seed, handle=self._handle)
  482. def shuffle(self, inp: Tensor):
  483. r"""Modify a sequence in-place by shuffling its contents.
  484. This function only shuffles the Tensor along the first axis of a multi-dimensional Tensor.
  485. The order of sub-Tensors is changed but their contents remains the same.
  486. Args:
  487. inp: input tensor.
  488. Examples:
  489. .. testcode::
  490. import numpy as np
  491. import megengine as mge
  492. import megengine.random as rand
  493. x = mge.tensor(np.arange(10))
  494. rand.shuffle(x)
  495. print(x.numpy())
  496. y = mge.tensor(np.arange(18)).reshape(6,3)
  497. rand.shuffle(y)
  498. print(y.numpy())
  499. Outputs:
  500. .. testoutput::
  501. :options: +SKIP
  502. [7 9 3 0 8 2 4 5 6 1]
  503. [[12. 13. 14.]
  504. [ 3. 4. 5.]
  505. [15. 16. 17.]
  506. [ 0. 1. 2.]
  507. [ 9. 10. 11.]
  508. [ 6. 7. 8.]]
  509. """
  510. _seed = self._seed() if callable(self._seed) else self._seed
  511. inp._reset(_shuffle(inp=inp, seed=_seed, handle=self._handle))
  512. def __del__(self):
  513. if self._handle != 0:
  514. _delete_rng_handle(self._handle)
  515. def _default_rng():
  516. r"""Default constructor for :class:`RNG`."""
  517. return RNG(seed=None, device=None)
  518. _default_handle = _default_rng()
  519. uniform = _default_handle.uniform
  520. normal = _default_handle.normal
  521. gamma = _default_handle.gamma
  522. beta = _default_handle.beta
  523. poisson = _default_handle.poisson
  524. permutation = _default_handle.permutation
  525. shuffle = _default_handle.shuffle
  526. def _random_seed_generator():
  527. assert _rng
  528. while True:
  529. yield _rng.random_raw()
  530. def seed(seed: int):
  531. global _rng # pylint: disable=global-statement
  532. _rng = MT19937(seed=seed)
  533. _set_global_rng_seed(seed)
  534. seed(int(time.time()))