You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_rng.py 2.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import numpy as np
  10. from megengine import tensor
  11. from megengine.core._imperative_rt import CompNode
  12. from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
  13. from megengine.core.ops.builtin import GaussianRNG, UniformRNG
  14. from megengine.core.tensor.core import apply
  15. def test_gaussian_rng():
  16. shape = (
  17. 8,
  18. 9,
  19. 11,
  20. 12,
  21. )
  22. shape = tensor(shape, dtype="int32")
  23. op = GaussianRNG(1.0, 3.0)
  24. (output,) = apply(op, shape)
  25. assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
  26. assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1
  27. assert str(output.device) == str(CompNode("xpux"))
  28. cn = CompNode("xpu1")
  29. op = GaussianRNG(-1.0, 2.0, cn)
  30. (output,) = apply(op, shape)
  31. assert np.fabs(output.numpy().mean() - (-1.0)) < 1e-1
  32. assert np.sqrt(output.numpy().var()) - 2.0 < 1e-1
  33. assert str(output.device) == str(cn)
  34. cn = CompNode("xpu2")
  35. seed = 233333
  36. h = new_rng_handle(cn, seed)
  37. op = GaussianRNG(3.0, 1.0, h)
  38. (output,) = apply(op, shape)
  39. delete_rng_handle(h)
  40. assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
  41. assert np.sqrt(output.numpy().var()) - 1.0 < 1e-1
  42. assert str(output.device) == str(cn)
  43. def test_uniform_rng():
  44. shape = (
  45. 8,
  46. 9,
  47. 11,
  48. 12,
  49. )
  50. shape = tensor(shape, dtype="int32")
  51. op = UniformRNG()
  52. (output,) = apply(op, shape)
  53. assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
  54. assert str(output.device) == str(CompNode("xpux"))
  55. cn = CompNode("xpu1")
  56. op = UniformRNG(cn)
  57. (output,) = apply(op, shape)
  58. assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
  59. assert str(output.device) == str(cn)
  60. cn = CompNode("xpu2")
  61. seed = 233333
  62. h = new_rng_handle(cn, seed)
  63. op = UniformRNG(h)
  64. (output,) = apply(op, shape)
  65. delete_rng_handle(h)
  66. assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
  67. assert str(output.device) == str(cn)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。