You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rng.py 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import time
  11. from typing import Iterable, Optional, Union
  12. from numpy.random import MT19937
  13. from .. import Tensor
  14. from ..core._imperative_rt.core2 import apply
  15. from ..core._imperative_rt.core2 import sync as _sync
  16. from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
  17. from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
  18. from ..core._imperative_rt.ops import (
  19. get_rng_handle_compnode as _get_rng_handle_compnode,
  20. )
  21. from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle
  22. from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed
  23. from ..core.ops.builtin import (
  24. BetaRNG,
  25. GammaRNG,
  26. GaussianRNG,
  27. PermutationRNG,
  28. PoissonRNG,
  29. ShuffleRNG,
  30. UniformRNG,
  31. )
  32. from ..core.tensor import utils
  33. from ..device import get_default_device
  34. __all__ = [
  35. "seed",
  36. "RNG",
  37. "uniform",
  38. "normal",
  39. "gamma",
  40. "beta",
  41. "poisson",
  42. "permutation",
  43. "shuffle",
  44. ]
  45. _rng = None
  46. def _infer_broadcasted_shape(inps: Iterable[Tensor]) -> tuple:
  47. broadcasted_ndim = inps[0].ndim
  48. broadcasted_shape = list(inps[0]._tuple_shape)
  49. for i in range(1, len(inps)):
  50. cur_ndim = inps[i].ndim
  51. cur_shape = list(inps[i]._tuple_shape)
  52. n_dim = max(cur_ndim, broadcasted_ndim)
  53. for j in range(n_dim - 1, -1, -1):
  54. cur_dim = cur_ndim + j - n_dim
  55. broad_dim = broadcasted_ndim + j - n_dim
  56. cur_size = cur_shape[cur_dim] if cur_dim >= 0 else 1
  57. broad_size = broadcasted_shape[broad_dim] if broad_dim >= 0 else 1
  58. assert cur_size == broad_size or cur_size == 1 or broad_size == 1, (
  59. "The size of inps[{}] ({}) must match the size ({}) at "
  60. "dim {}".format(i, cur_size, broad_size, j)
  61. )
  62. broad_size = max(cur_size, broad_size)
  63. if broad_dim < 0:
  64. broadcasted_shape = [broad_size] + broadcasted_shape
  65. broadcasted_ndim += 1
  66. else:
  67. broadcasted_shape[broad_dim] = broad_size
  68. return tuple(broadcasted_shape)
  69. def _broadcast_tensors_with_size(
  70. inps: Iterable[Tensor], size: Iterable[int]
  71. ) -> Iterable[Tensor]:
  72. assert inps, "The inps cloud not be empty"
  73. target_shape = _infer_broadcasted_shape(inps)
  74. if isinstance(size, collections.abc.Iterable):
  75. target_shape = tuple(size) + target_shape
  76. target_ndim = len(target_shape)
  77. for i in range(len(inps)):
  78. if inps[i]._tuple_shape != target_shape:
  79. inps[i] = (
  80. inps[i]
  81. .reshape((1,) * (target_ndim - inps[i].ndim) + inps[i]._tuple_shape)
  82. ._broadcast(target_shape)
  83. )
  84. return inps
  85. def _uniform(
  86. low: float,
  87. high: float,
  88. size: Optional[Iterable[int]],
  89. seed: int,
  90. device: str,
  91. handle: int,
  92. ) -> Tensor:
  93. assert low < high, "Uniform is not defined when low >= high"
  94. if size is None:
  95. size = (1,)
  96. op = UniformRNG(seed=seed, handle=handle, dtype="float32")
  97. _ref = Tensor([], dtype="int32", device=device)
  98. shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
  99. (output,) = apply(op, shape)
  100. if low == 0 and high == 1:
  101. return output
  102. return low + (high - low) * output
  103. def _normal(
  104. mean: float,
  105. std: float,
  106. size: Optional[Iterable[int]],
  107. seed: int,
  108. device: str,
  109. handle: int,
  110. ) -> Tensor:
  111. if size is None:
  112. size = (1,)
  113. op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle, dtype="float32")
  114. _ref = Tensor([], dtype="int32", device=device)
  115. shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
  116. (output,) = apply(op, shape)
  117. return output
  118. def _gamma(
  119. shape: Union[Tensor, float],
  120. scale: Union[Tensor, float],
  121. size: Optional[Iterable[int]],
  122. seed: int,
  123. handle: int,
  124. ) -> Tensor:
  125. handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle)
  126. if not isinstance(shape, Tensor):
  127. assert shape > 0, "Gamma is not defined when shape <= 0"
  128. shape = Tensor(shape, dtype="float32", device=handle_cn)
  129. if not isinstance(scale, Tensor):
  130. assert scale > 0, "Gamma is not defined when scale <= 0"
  131. scale = Tensor(scale, dtype="float32", device=handle_cn)
  132. assert (
  133. handle_cn is None or handle_cn == shape.device
  134. ), "The shape ({}) must be the same device with handle ({})".format(
  135. shape.device, handle_cn
  136. )
  137. assert (
  138. handle_cn is None or handle_cn == scale.device
  139. ), "The scale ({}) must be the same device with handle ({})".format(
  140. scale.device, handle_cn
  141. )
  142. if isinstance(size, int) and size != 0:
  143. size = (size,)
  144. shape, scale = _broadcast_tensors_with_size([shape, scale], size)
  145. op = GammaRNG(seed=seed, handle=handle)
  146. (output,) = apply(op, shape, scale)
  147. return output
  148. def _beta(
  149. alpha: Union[Tensor, float],
  150. beta: Union[Tensor, float],
  151. size: Optional[Iterable[int]],
  152. seed: int,
  153. handle: int,
  154. ) -> Tensor:
  155. handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle)
  156. if not isinstance(alpha, Tensor):
  157. assert alpha > 0, "Beta is not defined when alpha <= 0"
  158. alpha = Tensor(alpha, dtype="float32", device=handle_cn)
  159. if not isinstance(beta, Tensor):
  160. assert beta > 0, "Beta is not defined when beta <= 0"
  161. beta = Tensor(beta, dtype="float32", device=handle_cn)
  162. assert (
  163. handle_cn is None or handle_cn == alpha.device
  164. ), "The alpha ({}) must be the same device with handle ({})".format(
  165. alpha.device, handle_cn
  166. )
  167. assert (
  168. handle_cn is None or handle_cn == beta.device
  169. ), "The beta ({}) must be the same device with handle ({})".format(
  170. beta.device, handle_cn
  171. )
  172. if isinstance(size, int) and size != 0:
  173. size = (size,)
  174. alpha, beta = _broadcast_tensors_with_size([alpha, beta], size)
  175. op = BetaRNG(seed=seed, handle=handle)
  176. (output,) = apply(op, alpha, beta)
  177. return output
  178. def _poisson(
  179. lam: Union[Tensor, float], size: Optional[Iterable[int]], seed: int, handle: int
  180. ) -> Tensor:
  181. handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle)
  182. if not isinstance(lam, Tensor):
  183. assert lam > 0, "Poisson is not defined when lam <= 0"
  184. lam = Tensor(lam, dtype="float32", device=handle_cn)
  185. if isinstance(size, int) and size != 0:
  186. size = (size,)
  187. assert (
  188. handle_cn is None or handle_cn == lam.device
  189. ), "The lam ({}) must be the same device with handle ({})".format(
  190. lam.device, handle_cn
  191. )
  192. (lam,) = _broadcast_tensors_with_size([lam], size)
  193. op = PoissonRNG(seed=seed, handle=handle)
  194. (output,) = apply(op, lam)
  195. return output
  196. def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Tensor:
  197. assert isinstance(n, int)
  198. assert n >= 0, "Permutation is not defined when n < 0"
  199. size = (n,)
  200. op = PermutationRNG(seed=seed, handle=handle, dtype=dtype)
  201. _ref = Tensor([], dtype="int32", device=device)
  202. shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
  203. (output,) = apply(op, shape)
  204. return output
  205. def _shuffle(inp: Tensor, seed: int, handle: int) -> Tensor:
  206. assert inp.size > 0, "size needs to be greater than 0"
  207. op = ShuffleRNG(seed=seed, handle=handle)
  208. output, _ = apply(op, inp)
  209. return output
  210. class RNG:
  211. r""":class:`RNG` exposes a number of methods for generating random numbers.
  212. Args:
  213. seed: random seed used to initialize the pseudo-random number generator. Default: None
  214. device: the device of generated tensor. Default: None
  215. Examples:
  216. >>> import megengine.random as rand
  217. >>> rng = rand.RNG(seed=100)
  218. >>> x = rng.uniform(size=(2, 2))
  219. >>> x.numpy() # doctest: +SKIP
  220. array([[0.84811664, 0.6147553 ],
  221. [0.59429836, 0.64727545]], dtype=float32)
  222. """
  223. def __init__(self, seed: int = None, device: str = None):
  224. self._device = device if device else get_default_device()
  225. if seed is not None:
  226. self._seed = seed
  227. self._handle = _new_rng_handle(self._device, self._seed)
  228. else:
  229. self._seed = _get_global_rng_seed
  230. self._handle = 0
  231. self._device = None
  232. def uniform(
  233. self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
  234. ):
  235. r"""Random variable with uniform distribution $U(0, 1)$.
  236. Args:
  237. low: lower range. Default: 0
  238. high: upper range. Default: 1
  239. size: the size of output tensor. Default: None
  240. Returns:
  241. the output tensor.
  242. Examples:
  243. >>> import megengine.random as rand
  244. >>> x = rand.uniform(size=(2, 2))
  245. >>> x.numpy() # doctest: +SKIP
  246. array([[0.28603864, 0.3156649 ],
  247. [0.42066026, 0.9805052 ]], dtype=float32)
  248. """
  249. _seed = self._seed() if callable(self._seed) else self._seed
  250. return _uniform(
  251. low=low,
  252. high=high,
  253. size=size,
  254. seed=_seed,
  255. device=self._device,
  256. handle=self._handle,
  257. )
  258. def normal(
  259. self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
  260. ):
  261. r"""Random variable with Gaussian distribution :math:`N(\mu, \sigma)`.
  262. Args:
  263. mean: the mean or expectation of the distribution. Default: 0
  264. std: the standard deviation of the distribution (variance = :math:`\sigma ^ 2`).
  265. Default: 1
  266. size: the size of output tensor. Default: None
  267. Returns:
  268. the output tensor.
  269. Examples:
  270. >>> import megengine.random as rand
  271. >>> x = rand.normal(mean=0, std=1, size=(2, 2))
  272. >>> x.numpy() # doctest: +SKIP
  273. array([[ 1.5534291 , -0.28356555],
  274. [ 2.2230418 , -0.92425716]], dtype=float32)
  275. """
  276. _seed = self._seed() if callable(self._seed) else self._seed
  277. return _normal(
  278. mean=mean,
  279. std=std,
  280. size=size,
  281. seed=_seed,
  282. device=self._device,
  283. handle=self._handle,
  284. )
  285. def gamma(
  286. self,
  287. shape: Union[Tensor, float],
  288. scale: Union[Tensor, float] = 1,
  289. size: Optional[Iterable[int]] = None,
  290. ):
  291. r"""Random variable with Gamma distribution :math:`\Gamma(k, \theta)`.
  292. The corresponding probability density function is
  293. .. math::
  294. p(x)=x^{k-1} \frac{e^{-x / \theta}}{\theta^{k} \Gamma(k)}
  295. \quad \text { for } x>0 \quad k, \theta>0,
  296. where :math:`\Gamma(k)` is the gamma function,
  297. .. math::
  298. \Gamma(k)=(k-1) ! \quad \text { for } \quad k>0.
  299. Args:
  300. shape: the shape parameter (sometimes designated "k") of the distribution.
  301. Must be non-negative.
  302. scale: the scale parameter (sometimes designated "theta") of the distribution.
  303. Must be non-negative. Default: 1
  304. size: the size of output tensor. If shape and scale are scalars and given size is, e.g.,
  305. `(m, n)`, then the output shape is `(m, n)`. If shape or scale is a Tensor and given size
  306. is, e.g., `(m, n)`, then the output shape is `(m, n) + broadcast(shape, scale).shape`.
  307. The broadcast rules are consistent with `numpy.broadcast`. Default: None
  308. Returns:
  309. the output tensor.
  310. Examples:
  311. >>> import megengine.random as rand
  312. >>> x = rand.gamma(shape=2, scale=1, size=(2, 2))
  313. >>> x.numpy() # doctest: +SKIP
  314. array([[0.97447544, 1.5668875 ],
  315. [1.0069491 , 0.3078318 ]], dtype=float32)
  316. >>> shape = mge.Tensor([[ 1],
  317. ... [10]], dtype="float32")
  318. >>> scale = mge.Tensor([1,5], dtype="float32")
  319. >>> x = rand.gamma(shape=shape, scale=scale)
  320. >>> x.numpy() # doctest: +SKIP
  321. array([[ 0.11312152, 3.0799196 ],
  322. [10.973469 , 29.596972 ]], dtype=float32)
  323. >>> x = rand.gamma(shape=shape, scale=scale, size=2)
  324. >>> x.numpy() # doctest: +SKIP
  325. array([[[4.35868073e+00, 1.22415285e+01],
  326. [1.02696848e+01, 4.19773598e+01]],
  327. [[7.73875117e-02, 6.06766164e-01],
  328. [1.22881927e+01, 8.13445740e+01]]], dtype=float32)
  329. """
  330. _seed = self._seed() if callable(self._seed) else self._seed
  331. return _gamma(
  332. shape=shape, scale=scale, size=size, seed=_seed, handle=self._handle
  333. )
  334. def beta(
  335. self,
  336. alpha: Union[Tensor, float],
  337. beta: Union[Tensor, float],
  338. size: Optional[Iterable[int]] = None,
  339. ):
  340. r"""Random variable with Beta distribution :math:`\operatorname{Beta}(\alpha, \beta)`.
  341. The corresponding probability density function is
  342. .. math::
  343. p(x)=\frac{1}{\mathrm{~B}(\alpha, \beta)} x^{\alpha-1}(1-x)^{\beta-1}
  344. \quad \text { for } \alpha, \beta>0,
  345. where :math:`\mathrm{~B}(\alpha, \beta)` is the beta function,
  346. .. math::
  347. \mathrm{~B}(\alpha, \beta)=\int_{0}^{1} t^{\alpha-1}(1-t)^{\beta-1} d t.
  348. Args:
  349. alpha: the alpha parameter of the distribution. Must be non-negative.
  350. beta: the beta parameter of the distribution. Must be non-negative.
  351. size: the size of output tensor. If alpha and beta are scalars and given size is, e.g.,
  352. `(m, n)`, then the output shape is `(m, n)`. If alpha or beta is a Tensor and given size
  353. is, e.g., `(m, n)`, then the output shape is `(m, n) + broadcast(alpha, beta).shape`.
  354. Returns:
  355. the output tensor.
  356. Examples:
  357. >>> import megengine.random as rand
  358. >>> x = rand.beta(alpha=2, beta=1, size=(2, 2))
  359. >>> x.numpy() # doctest: +SKIP
  360. array([[0.6172312 , 0.9789006 ],
  361. [0.50004643, 0.9775796 ]], dtype=float32)
  362. >>> alpha = mge.Tensor([[0.5],
  363. ... [ 3]], dtype="float32")
  364. >>> beta = mge.Tensor([0.5,5], dtype="float32")
  365. >>> x = rand.beta(alpha=alpha, beta=beta)
  366. >>> x.numpy() # doctest: +SKIP
  367. array([[0.0075407 , 0.1275094 ],
  368. [0.96331763, 0.22299217]], dtype=float32)
  369. >>> x = rand.beta(alpha=alpha, beta=beta, size=2)
  370. >>> x.numpy() # doctest: +SKIP
  371. array([[[0.46863747, 0.13819647],
  372. [0.8646759 , 0.16014215]],
  373. [[0.0682759 , 0.04448463],
  374. [0.97733796, 0.19206746]]], dtype=float32)
  375. """
  376. _seed = self._seed() if callable(self._seed) else self._seed
  377. return _beta(alpha=alpha, beta=beta, size=size, seed=_seed, handle=self._handle)
  378. def poisson(self, lam: Union[float, Tensor], size: Optional[Iterable[int]] = None):
  379. r"""Random variable with poisson distribution :math:`\operatorname{Poisson}(\lambda)`.
  380. The corresponding probability density function is
  381. .. math::
  382. f(k ; \lambda)=\frac{\lambda^{k} e^{-\lambda}}{k !},
  383. where k is the number of occurrences :math:`({\displaystyle k=0,1,2...})`.
  384. Args:
  385. lam: the lambda parameter of the distribution. Must be non-negative.
  386. size: the size of output tensor. If lam is a scalar and given size is, e.g., `(m, n)`,
  387. then the output shape is `(m, n)`. If lam is a Tensor with shape `(k, v)` and given
  388. size is, e.g., `(m, n)`, then the output shape is `(m, n, k, v)`. Default: None.
  389. Returns:
  390. the output tensor.
  391. Examples:
  392. >>> import megengine.random as rand
  393. >>> x = rand.poisson(lam=2., size=(1, 3))
  394. >>> x.numpy() # doctest: +SKIP
  395. array([[1., 2., 2.]], dtype=float32)
  396. >>> lam = mge.Tensor([[1.,1.],
  397. ... [10,10]], dtype="float32")
  398. >>> x = rand.poisson(lam=lam)
  399. >>> x.numpy() # doctest: +SKIP
  400. array([[ 1., 2.],
  401. [11., 11.]], dtype=float32)
  402. >>> x = rand.poisson(lam=lam, size=(1,3))
  403. >>> x.numpy() # doctest: +SKIP
  404. array([[[[ 2., 1.],
  405. [10., 8.]],
  406. [[ 5., 2.],
  407. [10., 10.]],
  408. [[ 1., 2.],
  409. [ 8., 10.]]]], dtype=float32)
  410. """
  411. _seed = self._seed() if callable(self._seed) else self._seed
  412. return _poisson(lam=lam, size=size, seed=_seed, handle=self._handle)
  413. def permutation(self, n: Union[int, Tensor], *, dtype: str = "int32"):
  414. r"""Randomly permute a sequence, or return a permuted range.
  415. If ``n`` is a multi-dimensional tensor, it is only shuffled along its first index.
  416. Args:
  417. n: If ``n`` is an integer, random permutation of integers from :math:`0` to :math:`n - 1`.
  418. If ``n`` is an tensor, make a copy and shuffle the elements randomly.
  419. dtype: the output data type when ``n`` is an integer.
  420. int32, int16 and float32 are supported. Default: int32
  421. Returns:
  422. the output tensor.
  423. Examples:
  424. >>> import numpy as np
  425. >>> import megengine.random as rand
  426. >>> x = rand.permutation(10, dtype="int32")
  427. >>> x.numpy() # doctest: +SKIP
  428. array([8, 4, 0, 3, 5, 6, 2, 1, 7, 9], dtype=int32)
  429. >>> x = rand.permutation(10, dtype="float32")
  430. >>> x.numpy() # doctest: +SKIP
  431. array([1., 3., 0., 2., 4., 8., 7., 9., 6., 5.], dtype=float32)
  432. >>> x = mge.tensor(np.arange(18)).reshape(6,3)
  433. >>> x = rand.permutation(x)
  434. >>> x.numpy() # doctest: +SKIP
  435. array([[15, 16, 17],
  436. [ 6, 7, 8],
  437. [ 0, 1, 2],
  438. [ 3, 4, 5],
  439. [12, 13, 14],
  440. [ 9, 10, 11]], dtype=int32)
  441. """
  442. _seed = self._seed() if callable(self._seed) else self._seed
  443. if isinstance(n, int):
  444. return _permutation(
  445. n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype
  446. )
  447. assert isinstance(n, Tensor)
  448. return _shuffle(inp=n, seed=_seed, handle=self._handle)
  449. def shuffle(self, inp: Tensor):
  450. r"""Modify a sequence in-place by shuffling its contents.
  451. This function only shuffles the Tensor along the first axis of a multi-dimensional Tensor.
  452. The order of sub-Tensors is changed but their contents remains the same.
  453. Args:
  454. inp: input tensor.
  455. Examples:
  456. >>> import numpy as np
  457. >>> import megengine.random as rand
  458. >>> x = mge.tensor(np.arange(10))
  459. >>> rand.shuffle(x)
  460. >>> x.numpy() # doctest: +SKIP
  461. array([4, 5, 9, 6, 2, 8, 1, 0, 3, 7], dtype=int32)
  462. >>> y = mge.tensor(np.arange(18)).reshape(6,3)
  463. >>> rand.shuffle(y)
  464. >>> y.numpy() # doctest: +SKIP
  465. array([[ 3, 4, 5],
  466. [ 6, 7, 8],
  467. [15, 16, 17],
  468. [ 0, 1, 2],
  469. [12, 13, 14],
  470. [ 9, 10, 11]], dtype=int32)
  471. """
  472. _seed = self._seed() if callable(self._seed) else self._seed
  473. inp._reset(_shuffle(inp=inp, seed=_seed, handle=self._handle))
  474. def __del__(self):
  475. if self._handle != 0:
  476. # RNG op might execute after handle released due to async dispatch, so
  477. # we need sync before delete a handle to avoid memory leak or
  478. # use-after-free
  479. _sync()
  480. _delete_rng_handle(self._handle)
  481. def _default_rng():
  482. r"""Default constructor for :class:`RNG`."""
  483. return RNG(seed=None, device=None)
  484. _default_handle = _default_rng()
  485. uniform = _default_handle.uniform
  486. normal = _default_handle.normal
  487. gamma = _default_handle.gamma
  488. beta = _default_handle.beta
  489. poisson = _default_handle.poisson
  490. permutation = _default_handle.permutation
  491. shuffle = _default_handle.shuffle
  492. def _random_seed_generator():
  493. assert _rng
  494. while True:
  495. yield _rng.random_raw()
  496. def seed(seed: int):
  497. global _rng # pylint: disable=global-statement
  498. _rng = MT19937(seed=seed)
  499. _set_global_rng_seed(seed)
  500. seed(int(time.time()))