You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rng.py 3.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import time
  10. from typing import Iterable, Optional
  11. from numpy.random import MT19937
  12. from .. import Tensor
  13. from ..core._imperative_rt.core2 import apply
  14. from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
  15. from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
  16. from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle
  17. from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed
  18. from ..core.ops.builtin import GaussianRNG, UniformRNG
  19. from ..core.tensor import utils
  20. from ..device import get_default_device
  # Module-level MT19937 bit generator used by _random_seed_generator().
  # It starts out as None and is (re)bound by seed(), which this module
  # calls once at import time.
  _rng: Optional[MT19937] = None
  def _normal(
      mean: float,
      std: float,
      size: Optional[Iterable[int]],
      seed: int,
      device: str,
      handle: int,
  ) -> Tensor:
      """Sample a tensor by applying a ``GaussianRNG`` op to a shape tensor.

      :param mean: ``mean`` attribute of the ``GaussianRNG`` op.
      :param std: ``std`` attribute of the ``GaussianRNG`` op.
      :param size: requested output shape; ``None`` is treated as ``(1,)``.
      :param seed: ``seed`` attribute of the ``GaussianRNG`` op.
      :param device: device on which the int32 shape tensor is built.
      :param handle: runtime RNG handle passed to the ``GaussianRNG`` op.
      :return: the single output tensor produced by the op.
      """
      gaussian_op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle)
      out_shape = (1,) if size is None else size
      # Empty int32 tensor on the target device, passed as the reference
      # argument of astensor1d (presumably used for device/dtype placement
      # of the shape tensor -- TODO confirm against utils.astensor1d).
      reference = Tensor([], dtype="int32", device=device)
      shape_tensor = utils.astensor1d(
          out_shape, reference, dtype="int32", device=device
      )
      # The op is expected to yield exactly one output; unpacking enforces it.
      [samples] = apply(gaussian_op, shape_tensor)
      return samples
  def _uniform(
      low: float,
      high: float,
      size: Optional[Iterable[int]],
      seed: int,
      device: str,
      handle: int,
  ) -> Tensor:
      """Sample a tensor with a ``UniformRNG`` op and rescale it to ``[low, high]``.

      :param low: lower bound of the interval; must be strictly less than ``high``.
      :param high: upper bound of the interval.
      :param size: requested output shape; ``None`` is treated as ``(1,)``.
      :param seed: ``seed`` attribute of the ``UniformRNG`` op.
      :param device: device on which the int32 shape tensor is built.
      :param handle: runtime RNG handle passed to the ``UniformRNG`` op.
      :return: ``low + (high - low) * u`` where ``u`` is the op's single output.
      :raises AssertionError: if ``low < high`` does not hold.
      """
      # Explicit raise instead of a bare ``assert``: asserts are stripped under
      # ``python -O``, which would let an empty or inverted interval through and
      # silently produce out-of-range samples. AssertionError (and its message)
      # is kept so existing callers see the same exception type.
      if not low < high:
          raise AssertionError("Uniform is not defined when low >= high")
      if size is None:
          size = (1,)
      op = UniformRNG(seed=seed, handle=handle)
      # Empty int32 tensor on the target device, passed as the reference
      # argument of astensor1d (presumably for device/dtype placement of the
      # shape tensor -- TODO confirm against utils.astensor1d).
      _ref = Tensor([], dtype="int32", device=device)
      shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
      (output,) = apply(op, shape)
      # Affine map of the op's output onto [low, high]; assumes UniformRNG
      # samples from the unit interval -- TODO confirm op semantics.
      return low + (high - low) * output
  class RNG:
      """Random number generator that owns its own runtime RNG handle.

      :param seed: seed stored on the instance; it is passed to
          ``_new_rng_handle`` and forwarded as ``seed`` to every sampling call.
      :param device: device for the handle and the sampling ops. When falsy
          (e.g. ``None``), ``get_default_device()`` is used instead.
      """

      def __init__(self, seed=0, device=None):
          self.seed = seed
          self.device = device if device else get_default_device()
          # self.handle only exists once _new_rng_handle has returned;
          # __del__ relies on that to decide whether anything must be released.
          self.handle = _new_rng_handle(self.device, self.seed)

      def uniform(
          self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
      ):
          """Sample uniformly between ``low`` and ``high`` with this generator.

          :param low: lower bound; must be strictly less than ``high``.
          :param high: upper bound.
          :param size: output shape; ``None`` produces shape ``(1,)``.
          :return: tensor computed as ``low + (high - low) * u``.
          """
          return _uniform(
              low=low,
              high=high,
              size=size,
              seed=self.seed,
              device=self.device,
              handle=self.handle,
          )

      def normal(
          self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
      ):
          """Sample from a Gaussian with the given ``mean`` and ``std``.

          :param mean: mean passed to the ``GaussianRNG`` op.
          :param std: standard deviation passed to the ``GaussianRNG`` op.
          :param size: output shape; ``None`` produces shape ``(1,)``.
          :return: tensor produced by the ``GaussianRNG`` op.
          """
          return _normal(
              mean=mean,
              std=std,
              size=size,
              seed=self.seed,
              device=self.device,
              handle=self.handle,
          )

      def __del__(self):
          """Release the runtime RNG handle owned by this instance, if any."""
          # __del__ also runs on objects whose __init__ raised before
          # self.handle was assigned (e.g. _new_rng_handle failed). Without
          # this guard, finalization reports a spurious AttributeError.
          handle = getattr(self, "handle", None)
          if handle is not None:
              _delete_rng_handle(handle)
  def _random_seed_generator():
      """Yield an endless stream of raw draws from the module-level ``_rng``.

      The global ``_rng`` is looked up again on every step, so a later call to
      ``seed()`` affects the values yielded afterwards. Because this is a
      generator, the truthiness check on ``_rng`` runs on the first ``next()``,
      not when the function is called.
      """
      assert _rng
      while 1:
          raw_value = _rng.random_raw()
          yield raw_value
  def seed(seed: int):
      """Reseed the module-level generator and the runtime's global RNG seed.

      :param seed: integer seed. It is used to build a fresh ``MT19937`` that
          replaces the module-level ``_rng``, and is then passed unchanged to
          ``_set_global_rng_seed``.
      """
      global _rng  # pylint: disable=global-statement
      # Build the generator first: if MT19937 rejects the seed, neither the
      # module state nor the runtime seed is touched.
      fresh_state = MT19937(seed=seed)
      _rng = fresh_state
      _set_global_rng_seed(seed)
  # Seed from the current wall-clock second at import time so that separate
  # runs start from different random state (imports within the same second
  # receive the same seed).
  seed(seed=int(time.time()))

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。