# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import time
from typing import Iterable, Optional

from numpy.random import MT19937

from .. import Tensor
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle
from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed
from ..core.ops.builtin import GaussianRNG, UniformRNG
from ..core.tensor import utils
from ..device import get_default_device

# Module-level MT19937 bit generator. It is (re)created by ``seed()``, which is
# called once at import time at the bottom of this module, and is consumed by
# ``_random_seed_generator()``.
_rng = None


def _shape_tensor(size: Optional[Iterable[int]], device: str) -> Tensor:
    """Convert ``size`` into the 1-D int32 shape tensor fed to the RNG ops.

    :param size: desired output shape; ``None`` is treated as ``(1,)``.
    :param device: device on which the shape tensor is built.
    :return: 1-D int32 tensor produced by ``utils.astensor1d``.
    """
    if size is None:
        size = (1,)
    # Empty int32 tensor on the target device, passed to ``astensor1d`` as its
    # reference tensor.
    ref = Tensor([], dtype="int32", device=device)
    return utils.astensor1d(size, ref, dtype="int32", device=device)


def _normal(
    mean: float,
    std: float,
    size: Optional[Iterable[int]],
    seed: int,
    device: str,
    handle: int,
) -> Tensor:
    """Sample from a normal distribution using the ``GaussianRNG`` op.

    :param mean: mean passed to ``GaussianRNG``.
    :param std: standard deviation passed to ``GaussianRNG``.
    :param size: output shape; ``None`` means ``(1,)``.
    :param seed: seed passed to ``GaussianRNG``.
    :param device: device on which the shape tensor is created.
    :param handle: RNG handle passed to ``GaussianRNG``.
    :return: the single output tensor of the op (presumably of shape ``size``).
    """
    op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle)
    (output,) = apply(op, _shape_tensor(size, device))
    return output


def _uniform(
    low: float,
    high: float,
    size: Optional[Iterable[int]],
    seed: int,
    device: str,
    handle: int,
) -> Tensor:
    """Sample from a uniform distribution between ``low`` and ``high``.

    The raw ``UniformRNG`` output is affinely mapped with
    ``low + (high - low) * output``. This assumes the op samples in [0, 1);
    TODO confirm against the op definition.

    :param low: lower bound; must be strictly less than ``high``.
    :param high: upper bound.
    :param size: output shape; ``None`` means ``(1,)``.
    :param seed: seed passed to ``UniformRNG``.
    :param device: device on which the shape tensor is created.
    :param handle: RNG handle passed to ``UniformRNG``.
    :return: the rescaled output tensor.
    :raises AssertionError: if ``low >= high`` (only when assertions are enabled).
    """
    assert low < high, "Uniform is not defined when low >= high"
    op = UniformRNG(seed=seed, handle=handle)
    (output,) = apply(op, _shape_tensor(size, device))
    return low + (high - low) * output


class RNG:
    """Random number generator bound to its own device-side RNG handle.

    The handle is created from ``seed`` on ``device`` when the object is
    constructed. It is released in ``__del__``.
    """

    def __init__(self, seed=0, device=None):
        self.seed = seed
        # Fall back to the framework's default device when none is given.
        self.device = device if device else get_default_device()
        self.handle = _new_rng_handle(self.device, self.seed)

    def uniform(
        self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
    ):
        """Draw uniform samples between ``low`` and ``high`` using this RNG's handle."""
        return _uniform(
            low=low,
            high=high,
            size=size,
            seed=self.seed,
            device=self.device,
            handle=self.handle,
        )

    def normal(
        self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
    ):
        """Draw normal samples with the given ``mean``/``std`` using this RNG's handle."""
        return _normal(
            mean=mean,
            std=std,
            size=size,
            seed=self.seed,
            device=self.device,
            handle=self.handle,
        )

    def __del__(self):
        # ``handle`` is absent if ``__init__`` raised before assigning it
        # (e.g. ``_new_rng_handle`` failed). In that case there is nothing to
        # release, and an AttributeError here would only be reported as an
        # ignored exception during garbage collection.
        handle = getattr(self, "handle", None)
        if handle is not None:
            _delete_rng_handle(handle)


def _random_seed_generator():
    """Yield raw outputs of the module-level MT19937 forever.

    Requires ``seed()`` to have run, which happens at import time.
    """
    assert _rng
    while True:
        yield _rng.random_raw()


def seed(seed: int):
    """Reseed both the Python-side MT19937 and the runtime's global RNG seed.

    :param seed: seed used for ``MT19937`` and passed to ``_set_global_rng_seed``.
    """
    global _rng  # pylint: disable=global-statement
    _rng = MT19937(seed=seed)
    _set_global_rng_seed(seed)


# Seed from wall-clock time (whole seconds) at import so every process starts
# with some seed. Call ``seed()`` explicitly for reproducible results.
seed(int(time.time()))