You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_rng.py 3.9 kB

Lines 1–121
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import numpy as np
  10. import megengine
  11. from megengine import tensor
  12. from megengine.core._imperative_rt import CompNode
  13. from megengine.core._imperative_rt.core2 import apply
  14. from megengine.core._imperative_rt.ops import (
  15. delete_rng_handle,
  16. get_global_rng_seed,
  17. new_rng_handle,
  18. )
  19. from megengine.core.ops.builtin import GaussianRNG, UniformRNG
  20. from megengine.random import RNG
  21. from megengine.random.rng import _normal, _uniform
  22. def test_gaussian_op():
  23. shape = (
  24. 8,
  25. 9,
  26. 11,
  27. 12,
  28. )
  29. shape = tensor(shape, dtype="int32")
  30. op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0)
  31. (output,) = apply(op, shape)
  32. assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
  33. assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1
  34. assert str(output.device) == str(CompNode("xpux"))
  35. cn = CompNode("xpu2")
  36. seed = 233333
  37. h = new_rng_handle(cn, seed)
  38. op = GaussianRNG(seed=seed, mean=3.0, std=1.0, handle=h)
  39. (output,) = apply(op, shape)
  40. delete_rng_handle(h)
  41. assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
  42. assert np.sqrt(output.numpy().var()) - 1.0 < 1e-1
  43. assert str(output.device) == str(cn)
  44. def test_uniform_op():
  45. shape = (
  46. 8,
  47. 9,
  48. 11,
  49. 12,
  50. )
  51. shape = tensor(shape, dtype="int32")
  52. op = UniformRNG(seed=get_global_rng_seed())
  53. (output,) = apply(op, shape)
  54. assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
  55. assert str(output.device) == str(CompNode("xpux"))
  56. cn = CompNode("xpu2")
  57. seed = 233333
  58. h = new_rng_handle(cn, seed)
  59. op = UniformRNG(seed=seed, handle=h)
  60. (output,) = apply(op, shape)
  61. delete_rng_handle(h)
  62. assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
  63. assert str(output.device) == str(cn)
  64. def test_UniformRNG():
  65. m1 = RNG(seed=111, device="xpu0")
  66. m2 = RNG(seed=111, device="xpu1")
  67. m3 = RNG(seed=222, device="xpu0")
  68. out1 = m1.uniform(size=(100,))
  69. out1_ = m1.uniform(size=(100,))
  70. out2 = m2.uniform(size=(100,))
  71. out3 = m3.uniform(size=(100,))
  72. np.testing.assert_equal(out1.numpy(), out2.numpy())
  73. assert out1.device == "xpu0" and out2.device == "xpu1"
  74. assert not (out1.numpy() == out3.numpy()).all()
  75. assert not (out1.numpy() == out1_.numpy()).all()
  76. low = -234
  77. high = 123
  78. out = m1.uniform(low=low, high=high, size=(20, 30, 40))
  79. out_shp = out.shape
  80. if isinstance(out_shp, tuple):
  81. assert out_shp == (20, 30, 40)
  82. else:
  83. assert all(out.shape.numpy() == np.array([20, 30, 40]))
  84. assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1
  85. def test_NormalRNG():
  86. m1 = RNG(seed=111, device="xpu0")
  87. m2 = RNG(seed=111, device="xpu1")
  88. m3 = RNG(seed=222, device="xpu0")
  89. out1 = m1.normal(size=(100,))
  90. out1_ = m1.uniform(size=(100,))
  91. out2 = m2.normal(size=(100,))
  92. out3 = m3.normal(size=(100,))
  93. np.testing.assert_equal(out1.numpy(), out2.numpy())
  94. assert out1.device == "xpu0" and out2.device == "xpu1"
  95. assert not (out1.numpy() == out3.numpy()).all()
  96. assert not (out1.numpy() == out1_.numpy()).all()
  97. mean = -1
  98. std = 2
  99. out = m1.normal(mean=mean, std=std, size=(20, 30, 40))
  100. out_shp = out.shape
  101. if isinstance(out_shp, tuple):
  102. assert out_shp == (20, 30, 40)
  103. else:
  104. assert all(out.shape.numpy() == np.array([20, 30, 40]))
  105. assert np.abs(out.mean().numpy() - mean) / std < 0.1
  106. assert np.abs(np.std(out.numpy()) - std) < 0.1

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。