You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

distribution.py 2.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Iterable, Optional
  10. from .. import Tensor
  11. from ..core._imperative_rt import invoke_op
  12. from ..core.ops.builtin import GaussianRNG, UniformRNG
  13. from ..core.tensor import utils
  14. from ..core.tensor.core import apply
  15. from .rng import _random_seed_generator
  16. __all__ = ["gaussian", "uniform"]
  17. def gaussian(shape: Iterable[int], mean: float = 0, std: float = 1,) -> Tensor:
  18. r"""Random variable with Gaussian distribution $N(\mu, \sigma)$
  19. :param shape: Output tensor shape
  20. :param mean: The mean or expectation of the distribution
  21. :param std: The standard deviation of the distribution (variance = $\sigma ^ 2$)
  22. :return: The output tensor
  23. Examples:
  24. .. testcode::
  25. import megengine as mge
  26. import megengine.random as rand
  27. x = rand.gaussian((2, 2), mean=0, std=1)
  28. print(x.numpy())
  29. .. testoutput::
  30. :options: +SKIP
  31. [[-0.20235455 -0.6959438 ]
  32. [-1.4939808 -1.5824696 ]]
  33. """
  34. seed = _random_seed_generator().__next__()
  35. op = GaussianRNG(seed=seed, mean=mean, std=std)
  36. shape = Tensor(shape, dtype="int32")
  37. (output,) = apply(op, shape)
  38. return output
  39. def uniform(shape: Iterable[int], low: float = 0, high: float = 1,) -> Tensor:
  40. r"""Random variable with uniform distribution $U(0, 1)$
  41. :param shape: Output tensor shape
  42. :param low: Lower range
  43. :param high: Upper range
  44. :return: The output tensor
  45. Examples:
  46. .. testcode::
  47. import megengine as mge
  48. import megengine.random as rand
  49. x = rand.uniform((2, 2))
  50. print(x.numpy())
  51. .. testoutput::
  52. :options: +SKIP
  53. [[0.76901674 0.70496535]
  54. [0.09365904 0.62957656]]
  55. """
  56. assert low < high, "Uniform is not defined when low >= high"
  57. seed = _random_seed_generator().__next__()
  58. op = UniformRNG(seed=seed)
  59. shape = Tensor(shape, dtype="int32")
  60. (output,) = apply(op, shape)
  61. return low + (high - low) * output

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发的感觉，欢迎访问 MegStudio 平台。