|
|
@@ -27,7 +27,7 @@ from megengine.core.ops.builtin import ( |
|
|
|
UniformRNG, |
|
|
|
) |
|
|
|
from megengine.distributed.helper import get_device_count_by_fork |
|
|
|
from megengine.random import RNG |
|
|
|
from megengine.random import RNG, seed, uniform |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif( |
|
|
@@ -387,3 +387,18 @@ def test_PermutationRNG(): |
|
|
|
|
|
|
|
assert sum_result(out, lambda x: x) < 500 |
|
|
|
assert sum_result(out, np.sort) == 1000 |
|
|
|
|
|
|
|
|
|
|
|
def test_seed(): |
|
|
|
seed(10) |
|
|
|
out1 = uniform(size=[10, 10]) |
|
|
|
out2 = uniform(size=[10, 10]) |
|
|
|
assert not (out1.numpy() == out2.numpy()).all() |
|
|
|
|
|
|
|
seed(10) |
|
|
|
out3 = uniform(size=[10, 10]) |
|
|
|
np.testing.assert_equal(out1.numpy(), out3.numpy()) |
|
|
|
|
|
|
|
seed(11) |
|
|
|
out4 = uniform(size=[10, 10]) |
|
|
|
assert not (out1.numpy() == out4.numpy()).all() |