You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_rng.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import numpy as np
  10. import pytest
  11. import megengine
  12. from megengine import is_cuda_available, tensor
  13. from megengine.core._imperative_rt import CompNode
  14. from megengine.core._imperative_rt.core2 import apply
  15. from megengine.core._imperative_rt.ops import (
  16. delete_rng_handle,
  17. get_global_rng_seed,
  18. new_rng_handle,
  19. )
  20. from megengine.core.ops.builtin import GaussianRNG, UniformRNG
  21. from megengine.distributed.helper import get_device_count_by_fork
  22. from megengine.random import RNG
  23. from megengine.random.rng import _normal, _uniform
  24. @pytest.mark.skipif(
  25. get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
  26. )
  27. def test_gaussian_op():
  28. shape = (
  29. 8,
  30. 9,
  31. 11,
  32. 12,
  33. )
  34. shape = tensor(shape, dtype="int32")
  35. op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0)
  36. (output,) = apply(op, shape)
  37. assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
  38. assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1
  39. assert str(output.device) == str(CompNode("xpux"))
  40. cn = CompNode("xpu2")
  41. seed = 233333
  42. h = new_rng_handle(cn, seed)
  43. op = GaussianRNG(seed=seed, mean=3.0, std=1.0, handle=h)
  44. (output,) = apply(op, shape)
  45. delete_rng_handle(h)
  46. assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
  47. assert np.sqrt(output.numpy().var()) - 1.0 < 1e-1
  48. assert str(output.device) == str(cn)
  49. @pytest.mark.skipif(
  50. get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
  51. )
  52. def test_uniform_op():
  53. shape = (
  54. 8,
  55. 9,
  56. 11,
  57. 12,
  58. )
  59. shape = tensor(shape, dtype="int32")
  60. op = UniformRNG(seed=get_global_rng_seed())
  61. (output,) = apply(op, shape)
  62. assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
  63. assert str(output.device) == str(CompNode("xpux"))
  64. cn = CompNode("xpu2")
  65. seed = 233333
  66. h = new_rng_handle(cn, seed)
  67. op = UniformRNG(seed=seed, handle=h)
  68. (output,) = apply(op, shape)
  69. delete_rng_handle(h)
  70. assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
  71. assert str(output.device) == str(cn)
  72. @pytest.mark.skipif(
  73. get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
  74. )
  75. def test_UniformRNG():
  76. m1 = RNG(seed=111, device="xpu0")
  77. m2 = RNG(seed=111, device="xpu1")
  78. m3 = RNG(seed=222, device="xpu0")
  79. out1 = m1.uniform(size=(100,))
  80. out1_ = m1.uniform(size=(100,))
  81. out2 = m2.uniform(size=(100,))
  82. out3 = m3.uniform(size=(100,))
  83. np.testing.assert_equal(out1.numpy(), out2.numpy())
  84. assert out1.device == "xpu0" and out2.device == "xpu1"
  85. assert not (out1.numpy() == out3.numpy()).all()
  86. assert not (out1.numpy() == out1_.numpy()).all()
  87. low = -234
  88. high = 123
  89. out = m1.uniform(low=low, high=high, size=(20, 30, 40))
  90. out_shp = out.shape
  91. if isinstance(out_shp, tuple):
  92. assert out_shp == (20, 30, 40)
  93. else:
  94. assert all(out.shape.numpy() == np.array([20, 30, 40]))
  95. assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1
  96. @pytest.mark.skipif(
  97. get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
  98. )
  99. def test_NormalRNG():
  100. m1 = RNG(seed=111, device="xpu0")
  101. m2 = RNG(seed=111, device="xpu1")
  102. m3 = RNG(seed=222, device="xpu0")
  103. out1 = m1.normal(size=(100,))
  104. out1_ = m1.uniform(size=(100,))
  105. out2 = m2.normal(size=(100,))
  106. out3 = m3.normal(size=(100,))
  107. np.testing.assert_equal(out1.numpy(), out2.numpy())
  108. assert out1.device == "xpu0" and out2.device == "xpu1"
  109. assert not (out1.numpy() == out3.numpy()).all()
  110. assert not (out1.numpy() == out1_.numpy()).all()
  111. mean = -1
  112. std = 2
  113. out = m1.normal(mean=mean, std=std, size=(20, 30, 40))
  114. out_shp = out.shape
  115. if isinstance(out_shp, tuple):
  116. assert out_shp == (20, 30, 40)
  117. else:
  118. assert all(out.shape.numpy() == np.array([20, 30, 40]))
  119. assert np.abs(out.mean().numpy() - mean) / std < 0.1
  120. assert np.abs(np.std(out.numpy()) - std) < 0.1

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。