You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import time
  11. from typing import Iterable, Optional, Union
  12. from numpy.random import MT19937
  13. from .. import Tensor
  14. from ..core._imperative_rt.core2 import apply
  15. from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
  16. from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
  17. from ..core._imperative_rt.ops import (
  18. get_rng_handle_compnode as _get_rng_handle_compnode,
  19. )
  20. from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle
  21. from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed
  22. from ..core.ops.builtin import (
  23. BetaRNG,
  24. GammaRNG,
  25. GaussianRNG,
  26. PermutationRNG,
  27. PoissonRNG,
  28. UniformRNG,
  29. )
  30. from ..core.tensor import utils
  31. from ..device import get_default_device
  32. __all__ = [
  33. "seed",
  34. "RNG",
  35. "uniform",
  36. "normal",
  37. "gamma",
  38. "beta",
  39. "poisson",
  40. "permutation",
  41. ]
  42. _rng = None
  43. def _infer_broadcasted_shape(inps: Iterable[Tensor]) -> tuple:
  44. broadcasted_ndim = inps[0].ndim
  45. broadcasted_shape = list(inps[0]._tuple_shape)
  46. for i in range(1, len(inps)):
  47. cur_ndim = inps[i].ndim
  48. cur_shape = list(inps[i]._tuple_shape)
  49. n_dim = max(cur_ndim, broadcasted_ndim)
  50. for j in range(n_dim - 1, -1, -1):
  51. cur_dim = cur_ndim + j - n_dim
  52. broad_dim = broadcasted_ndim + j - n_dim
  53. cur_size = cur_shape[cur_dim] if cur_dim >= 0 else 1
  54. broad_size = broadcasted_shape[broad_dim] if broad_dim >= 0 else 1
  55. assert cur_size == broad_size or cur_size == 1 or broad_size == 1, (
  56. "The size of inps[{}] ({}) must match the size ({}) at "
  57. "dim {}".format(i, cur_size, broad_size, j)
  58. )
  59. broad_size = max(cur_size, broad_size)
  60. if broad_dim < 0:
  61. broadcasted_shape = [broad_size] + broadcasted_shape
  62. broadcasted_ndim += 1
  63. else:
  64. broadcasted_shape[broad_dim] = broad_size
  65. return tuple(broadcasted_shape)
  66. def _broadcast_tensors_with_size(
  67. inps: Iterable[Tensor], size: Iterable[int]
  68. ) -> Iterable[Tensor]:
  69. assert inps, "The inps cloud not be empty"
  70. target_shape = _infer_broadcasted_shape(inps)
  71. if isinstance(size, collections.abc.Iterable):
  72. target_shape = tuple(size) + target_shape
  73. target_ndim = len(target_shape)
  74. for i in range(len(inps)):
  75. if inps[i]._tuple_shape != target_shape:
  76. inps[i] = (
  77. inps[i]
  78. .reshape((1,) * (target_ndim - inps[i].ndim) + inps[i]._tuple_shape)
  79. ._broadcast(target_shape)
  80. )
  81. return inps
  82. def _uniform(
  83. low: float,
  84. high: float,
  85. size: Optional[Iterable[int]],
  86. seed: int,
  87. device: str,
  88. handle: int,
  89. ) -> Tensor:
  90. assert low < high, "Uniform is not defined when low >= high"
  91. if size is None:
  92. size = (1,)
  93. op = UniformRNG(seed=seed, handle=handle, dtype="float32")
  94. _ref = Tensor([], dtype="int32", device=device)
  95. shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
  96. (output,) = apply(op, shape)
  97. return low + (high - low) * output
  98. def _normal(
  99. mean: float,
  100. std: float,
  101. size: Optional[Iterable[int]],
  102. seed: int,
  103. device: str,
  104. handle: int,
  105. ) -> Tensor:
  106. if size is None:
  107. size = (1,)
  108. op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle, dtype="float32")
  109. _ref = Tensor([], dtype="int32", device=device)
  110. shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
  111. (output,) = apply(op, shape)
  112. return output
  113. def _gamma(
  114. shape: Union[Tensor, float],
  115. scale: Union[Tensor, float],
  116. size: Optional[Iterable[int]],
  117. seed: int,
  118. handle: int,
  119. ) -> Tensor:
  120. handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle)
  121. if not isinstance(shape, Tensor):
  122. assert shape > 0, "Gamma is not defined when shape <= 0"
  123. shape = Tensor(shape, dtype="float32", device=handle_cn)
  124. if not isinstance(scale, Tensor):
  125. assert scale > 0, "Gamma is not defined when scale <= 0"
  126. scale = Tensor(scale, dtype="float32", device=handle_cn)
  127. assert (
  128. handle_cn is None or handle_cn == shape.device
  129. ), "The shape ({}) must be the same device with handle ({})".format(
  130. shape.device, handle_cn
  131. )
  132. assert (
  133. handle_cn is None or handle_cn == scale.device
  134. ), "The scale ({}) must be the same device with handle ({})".format(
  135. scale.device, handle_cn
  136. )
  137. if isinstance(size, int) and size != 0:
  138. size = (size,)
  139. shape, scale = _broadcast_tensors_with_size([shape, scale], size)
  140. op = GammaRNG(seed=seed, handle=handle)
  141. (output,) = apply(op, shape, scale)
  142. return output
  143. def _beta(
  144. alpha: Union[Tensor, float],
  145. beta: Union[Tensor, float],
  146. size: Optional[Iterable[int]],
  147. seed: int,
  148. handle: int,
  149. ) -> Tensor:
  150. handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle)
  151. if not isinstance(alpha, Tensor):
  152. assert alpha > 0, "Beta is not defined when alpha <= 0"
  153. alpha = Tensor(alpha, dtype="float32", device=handle_cn)
  154. if not isinstance(beta, Tensor):
  155. assert beta > 0, "Beta is not defined when beta <= 0"
  156. beta = Tensor(beta, dtype="float32", device=handle_cn)
  157. assert (
  158. handle_cn is None or handle_cn == alpha.device
  159. ), "The alpha ({}) must be the same device with handle ({})".format(
  160. alpha.device, handle_cn
  161. )
  162. assert (
  163. handle_cn is None or handle_cn == beta.device
  164. ), "The beta ({}) must be the same device with handle ({})".format(
  165. beta.device, handle_cn
  166. )
  167. if isinstance(size, int) and size != 0:
  168. size = (size,)
  169. alpha, beta = _broadcast_tensors_with_size([alpha, beta], size)
  170. op = BetaRNG(seed=seed, handle=handle)
  171. (output,) = apply(op, alpha, beta)
  172. return output
  173. def _poisson(
  174. lam: Union[Tensor, float], size: Optional[Iterable[int]], seed: int, handle: int
  175. ) -> Tensor:
  176. handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle)
  177. if not isinstance(lam, Tensor):
  178. assert lam > 0, "Poisson is not defined when lam <= 0"
  179. lam = Tensor(lam, dtype="float32", device=handle_cn)
  180. if isinstance(size, int) and size != 0:
  181. size = (size,)
  182. assert (
  183. handle_cn is None or handle_cn == lam.device
  184. ), "The lam ({}) must be the same device with handle ({})".format(
  185. lam.device, handle_cn
  186. )
  187. (lam,) = _broadcast_tensors_with_size([lam], size)
  188. op = PoissonRNG(seed=seed, handle=handle)
  189. (output,) = apply(op, lam)
  190. return output
  191. def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Tensor:
  192. assert isinstance(n, int)
  193. assert n >= 0, "Permutation is not defined when n < 0"
  194. size = (n,)
  195. op = PermutationRNG(seed=seed, handle=handle, dtype=dtype)
  196. _ref = Tensor([], dtype="int32", device=device)
  197. shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
  198. (output,) = apply(op, shape)
  199. return output
  200. class RNG:
  201. r""":class:`RNG` exposes a number of methods for generating random numbers.
  202. Args:
  203. seed: random seed used to initialize the pseudo-random number generator. Default: None
  204. device: the device of generated tensor. Default: None
  205. Examples:
  206. .. testcode::
  207. import megengine.random as rand
  208. rng = rand.RNG(seed=100)
  209. x = rng.uniform(size=(2, 2))
  210. print(x.numpy())
  211. Outputs:
  212. .. testoutput::
  213. :options: +SKIP
  214. [[0.84811664 0.6147553 ]
  215. [0.59429836 0.64727545]]
  216. """
  217. def __init__(self, seed: int = None, device: str = None):
  218. self._device = device if device else get_default_device()
  219. if seed is not None:
  220. self._seed = seed
  221. self._handle = _new_rng_handle(self._device, self._seed)
  222. else:
  223. self._seed = _get_global_rng_seed
  224. self._handle = 0
  225. self._device = None
  226. def uniform(
  227. self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
  228. ):
  229. r"""Random variable with uniform distribution $U(0, 1)$.
  230. Args:
  231. low: lower range. Default: 0
  232. high: upper range. Default: 1
  233. size: the size of output tensor. Default: None
  234. Returns:
  235. the output tensor.
  236. Examples:
  237. .. testcode::
  238. import megengine as mge
  239. import megengine.random as rand
  240. x = rand.uniform(size=(2, 2))
  241. print(x.numpy())
  242. Outputs:
  243. .. testoutput::
  244. :options: +SKIP
  245. [[0.91600335 0.6680226 ]
  246. [0.2046729 0.2769141 ]]
  247. """
  248. _seed = self._seed() if callable(self._seed) else self._seed
  249. return _uniform(
  250. low=low,
  251. high=high,
  252. size=size,
  253. seed=_seed,
  254. device=self._device,
  255. handle=self._handle,
  256. )
  257. def normal(
  258. self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
  259. ):
  260. r"""Random variable with Gaussian distribution :math:`N(\mu, \sigma)`.
  261. Args:
  262. mean: the mean or expectation of the distribution. Default: 0
  263. std: the standard deviation of the distribution (variance = :math:`\sigma ^ 2`).
  264. Default: 1
  265. size: the size of output tensor. Default: None
  266. Returns:
  267. the output tensor.
  268. Examples:
  269. .. testcode::
  270. import megengine as mge
  271. import megengine.random as rand
  272. x = rand.normal(mean=0, std=1, size=(2, 2))
  273. print(x.numpy())
  274. Outputs:
  275. .. testoutput::
  276. :options: +SKIP
  277. [[-1.4010863 -0.9874344 ]
  278. [ 0.56373274 0.79656655]]
  279. """
  280. _seed = self._seed() if callable(self._seed) else self._seed
  281. return _normal(
  282. mean=mean,
  283. std=std,
  284. size=size,
  285. seed=_seed,
  286. device=self._device,
  287. handle=self._handle,
  288. )
  289. def gamma(
  290. self,
  291. shape: Union[Tensor, float],
  292. scale: Union[Tensor, float] = 1,
  293. size: Optional[Iterable[int]] = None,
  294. ):
  295. r"""Random variable with Gamma distribution :math:`\Gamma(k, \theta)`.
  296. The corresponding probability density function is
  297. .. math::
  298. p(x)=x^{k-1} \frac{e^{-x / \theta}}{\theta^{k} \Gamma(k)}
  299. \quad \text { for } x>0 \quad k, \theta>0,
  300. where :math:`\Gamma(k)` is the gamma function,
  301. .. math::
  302. \Gamma(k)=(k-1) ! \quad \text { for } \quad k>0.
  303. Args:
  304. shape: the shape parameter (sometimes designated "k") of the distribution.
  305. Must be non-negative.
  306. scale: the scale parameter (sometimes designated "theta") of the distribution.
  307. Must be non-negative. Default: 1
  308. size: the size of output tensor. If shape and scale are scalars and given size is, e.g.,
  309. `(m, n)`, then the output shape is `(m, n)`. If shape or scale is a Tensor and given size
  310. is, e.g., `(m, n)`, then the output shape is `(m, n) + broadcast(shape, scale).shape`.
  311. The broadcast rules are consistent with `numpy.broadcast`. Default: None
  312. Returns:
  313. the output tensor.
  314. Examples:
  315. .. testcode::
  316. import megengine as mge
  317. import megengine.random as rand
  318. x = rand.gamma(shape=2, scale=1, size=(2, 2))
  319. print(x.numpy())
  320. shape = mge.Tensor([[ 1],
  321. [10]], dtype="float32")
  322. scale = mge.Tensor([1,5], dtype="float32")
  323. x = rand.gamma(shape=shape, scale=scale)
  324. print(x.numpy())
  325. x = rand.gamma(shape=shape, scale=scale, size=2)
  326. print(x.numpy())
  327. Outputs:
  328. .. testoutput::
  329. :options: +SKIP
  330. [[1.5064533 4.0689363 ]
  331. [0.71639484 1.4551026 ]]
  332. [[ 0.4352188 11.399335 ]
  333. [ 9.1888 52.009277 ]]
  334. [[[ 1.1726005 3.9654975 ]
  335. [13.656933 36.559006 ]]
  336. [[ 0.25848487 2.5540342 ]
  337. [11.960409 21.031536 ]]]
  338. """
  339. _seed = self._seed() if callable(self._seed) else self._seed
  340. return _gamma(
  341. shape=shape, scale=scale, size=size, seed=_seed, handle=self._handle
  342. )
  343. def beta(
  344. self,
  345. alpha: Union[Tensor, float],
  346. beta: Union[Tensor, float],
  347. size: Optional[Iterable[int]] = None,
  348. ):
  349. r"""Random variable with Beta distribution :math:`\operatorname{Beta}(\alpha, \beta)`.
  350. The corresponding probability density function is
  351. .. math::
  352. p(x)=\frac{1}{\mathrm{~B}(\alpha, \beta)} x^{\alpha-1}(1-x)^{\beta-1}
  353. \quad \text { for } \alpha, \beta>0,
  354. where :math:`\mathrm{~B}(\alpha, \beta)` is the beta function,
  355. .. math::
  356. \mathrm{~B}(\alpha, \beta)=\int_{0}^{1} t^{\alpha-1}(1-t)^{\beta-1} d t.
  357. Args:
  358. alpha: the alpha parameter of the distribution. Must be non-negative.
  359. beta: the beta parameter of the distribution. Must be non-negative.
  360. size: the size of output tensor. If alpha and beta are scalars and given size is, e.g.,
  361. `(m, n)`, then the output shape is `(m, n)`. If alpha or beta is a Tensor and given size
  362. is, e.g., `(m, n)`, then the output shape is `(m, n) + broadcast(alpha, beta).shape`.
  363. Returns:
  364. the output tensor.
  365. Examples:
  366. .. testcode::
  367. import megengine as mge
  368. import megengine.random as rand
  369. x = rand.beta(alpha=2, beta=1, size=(2, 2))
  370. print(x.numpy())
  371. alpha = mge.Tensor([[0.5],
  372. [ 3]], dtype="float32")
  373. beta = mge.Tensor([0.5,5], dtype="float32")
  374. x = rand.beta(alpha=alpha, beta=beta)
  375. print(x.numpy())
  376. x = rand.beta(alpha=alpha, beta=beta, size=2)
  377. print(x.numpy())
  378. Outputs:
  379. .. testoutput::
  380. :options: +SKIP
  381. [[0.582565 0.91763186]
  382. [0.86963767 0.6088103 ]]
  383. [[0.41503012 0.16438372]
  384. [0.90159506 0.47588003]]
  385. [[[0.55195075 0.01111084]
  386. [0.95298755 0.25048104]]
  387. [[0.11680304 0.13859665]
  388. [0.997879 0.43259275]]]
  389. """
  390. _seed = self._seed() if callable(self._seed) else self._seed
  391. return _beta(alpha=alpha, beta=beta, size=size, seed=_seed, handle=self._handle)
  392. def poisson(self, lam: Union[float, Tensor], size: Optional[Iterable[int]] = None):
  393. r"""Random variable with poisson distribution :math:`\operatorname{Poisson}(\lambda)`.
  394. The corresponding probability density function is
  395. .. math::
  396. f(k ; \lambda)=\frac{\lambda^{k} e^{-\lambda}}{k !},
  397. where k is the number of occurrences :math:`({\displaystyle k=0,1,2...})`.
  398. Args:
  399. lam: the lambda parameter of the distribution. Must be non-negative.
  400. size: the size of output tensor. If lam is a scalar and given size is, e.g., `(m, n)`,
  401. then the output shape is `(m, n)`. If lam is a Tensor with shape `(k, v)` and given
  402. size is, e.g., `(m, n)`, then the output shape is `(m, n, k, v)`. Default: None.
  403. Returns:
  404. the output tensor.
  405. Examples:
  406. .. testcode::
  407. import megengine as mge
  408. import megengine.random as rand
  409. x = rand.poisson(lam=2., size=(1, 3))
  410. print(x.numpy())
  411. lam = mge.Tensor([[1.,1.],
  412. [10,10]], dtype="float32")
  413. x = rand.poisson(lam=lam)
  414. print(x.numpy())
  415. x = rand.poisson(lam=lam, size=(1,3))
  416. print(x.numpy())
  417. Outputs:
  418. .. testoutput::
  419. :options: +SKIP
  420. [[3. 1. 3.]]
  421. [[ 2. 2.]
  422. [12. 11.]]
  423. [[[[ 1. 1.]
  424. [11. 4.]]
  425. [[ 0. 0.]
  426. [ 9. 13.]]
  427. [[ 0. 1.]
  428. [ 7. 12.]]]]
  429. """
  430. _seed = self._seed() if callable(self._seed) else self._seed
  431. return _poisson(lam=lam, size=size, seed=_seed, handle=self._handle)
  432. def permutation(self, n: int, *, dtype: str = "int32"):
  433. r"""Generates a random permutation of integers from :math:`0` to :math:`n - 1`.
  434. Args:
  435. n: the upper bound. Must be larger than 0.
  436. dtype: the output data type. int32, int16 and float32 are supported. Default: int32
  437. Returns:
  438. the output tensor.
  439. Examples:
  440. .. testcode::
  441. import megengine as mge
  442. import megengine.random as rand
  443. x = rand.permutation(n=10, dtype="int32")
  444. print(x.numpy())
  445. x = rand.permutation(n=10, dtype="float32")
  446. print(x.numpy())
  447. Outputs:
  448. .. testoutput::
  449. :options: +SKIP
  450. [4 5 0 7 3 8 6 1 9 2]
  451. [3. 4. 9. 0. 6. 8. 7. 1. 5. 2.]
  452. """
  453. _seed = self._seed() if callable(self._seed) else self._seed
  454. return _permutation(
  455. n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype
  456. )
  457. def __del__(self):
  458. if self._handle != 0:
  459. _delete_rng_handle(self._handle)
  460. def _default_rng():
  461. r"""Default constructor for :class:`RNG`."""
  462. return RNG(seed=None, device=None)
  463. _default_handle = _default_rng()
  464. uniform = _default_handle.uniform
  465. normal = _default_handle.normal
  466. gamma = _default_handle.gamma
  467. beta = _default_handle.beta
  468. poisson = _default_handle.poisson
  469. permutation = _default_handle.permutation
  470. def _random_seed_generator():
  471. assert _rng
  472. while True:
  473. yield _rng.random_raw()
  474. def seed(seed: int):
  475. global _rng # pylint: disable=global-statement
  476. _rng = MT19937(seed=seed)
  477. _set_global_rng_seed(seed)
  478. seed(int(time.time()))

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。