Browse Source

fix(mge/random): fix random seed

GitOrigin-RevId: 121f459b1b
release-1.5
Megvii Engine Team huangxinda 3 years ago
parent
commit
7c9569e4e5
2 changed files with 24 additions and 4 deletions
  1. +16
    -1
      imperative/python/test/unit/random/test_rng.py
  2. +8
    -3
      imperative/src/impl/ops/rng.cpp

+ 16
- 1
imperative/python/test/unit/random/test_rng.py View File

@@ -27,7 +27,7 @@ from megengine.core.ops.builtin import (
UniformRNG,
)
from megengine.distributed.helper import get_device_count_by_fork
from megengine.random import RNG
from megengine.random import RNG, seed, uniform


@pytest.mark.skipif(
@@ -387,3 +387,18 @@ def test_PermutationRNG():

assert sum_result(out, lambda x: x) < 500
assert sum_result(out, np.sort) == 1000


def test_seed():
seed(10)
out1 = uniform(size=[10, 10])
out2 = uniform(size=[10, 10])
assert not (out1.numpy() == out2.numpy()).all()

seed(10)
out3 = uniform(size=[10, 10])
np.testing.assert_equal(out1.numpy(), out3.numpy())

seed(11)
out4 = uniform(size=[10, 10])
assert not (out1.numpy() == out4.numpy()).all()

+ 8
- 3
imperative/src/impl/ops/rng.cpp View File

@@ -127,10 +127,8 @@ public:
auto&& glob_handle = glob_default_handles[comp_node];
if (!glob_handle) {
glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
} else if (get_seed(glob_handle) != glob_default_seed) {
inst().DnnOpManagerBase::delete_handle(glob_handle);
glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
}
mgb_assert(get_seed(glob_handle) == glob_default_seed);
return glob_handle;
}

@@ -141,6 +139,13 @@ public:

static void set_glob_default_seed(uint64_t seed) {
MGB_LOCK_GUARD(sm_mtx);
for(auto && elem : glob_default_handles){
mgb_assert(elem.first.valid());
if(elem.second){
inst().DnnOpManagerBase::delete_handle(elem.second);
}
elem.second = inst().do_new_handle(elem.first, seed);
}
glob_default_seed = seed;
}



Loading…
Cancel
Save