@@ -27,7 +27,7 @@ from megengine.core.ops.builtin import ( | |||||
UniformRNG, | UniformRNG, | ||||
) | ) | ||||
from megengine.distributed.helper import get_device_count_by_fork | from megengine.distributed.helper import get_device_count_by_fork | ||||
from megengine.random import RNG | |||||
from megengine.random import RNG, seed, uniform | |||||
@pytest.mark.skipif( | @pytest.mark.skipif( | ||||
@@ -387,3 +387,18 @@ def test_PermutationRNG(): | |||||
assert sum_result(out, lambda x: x) < 500 | assert sum_result(out, lambda x: x) < 500 | ||||
assert sum_result(out, np.sort) == 1000 | assert sum_result(out, np.sort) == 1000 | ||||
def test_seed(): | |||||
seed(10) | |||||
out1 = uniform(size=[10, 10]) | |||||
out2 = uniform(size=[10, 10]) | |||||
assert not (out1.numpy() == out2.numpy()).all() | |||||
seed(10) | |||||
out3 = uniform(size=[10, 10]) | |||||
np.testing.assert_equal(out1.numpy(), out3.numpy()) | |||||
seed(11) | |||||
out4 = uniform(size=[10, 10]) | |||||
assert not (out1.numpy() == out4.numpy()).all() |
@@ -127,10 +127,8 @@ public: | |||||
auto&& glob_handle = glob_default_handles[comp_node]; | auto&& glob_handle = glob_default_handles[comp_node]; | ||||
if (!glob_handle) { | if (!glob_handle) { | ||||
glob_handle = inst().do_new_handle(comp_node, glob_default_seed); | glob_handle = inst().do_new_handle(comp_node, glob_default_seed); | ||||
} else if (get_seed(glob_handle) != glob_default_seed) { | |||||
inst().DnnOpManagerBase::delete_handle(glob_handle); | |||||
glob_handle = inst().do_new_handle(comp_node, glob_default_seed); | |||||
} | } | ||||
mgb_assert(get_seed(glob_handle) == glob_default_seed); | |||||
return glob_handle; | return glob_handle; | ||||
} | } | ||||
@@ -141,6 +139,13 @@ public: | |||||
static void set_glob_default_seed(uint64_t seed) { | static void set_glob_default_seed(uint64_t seed) { | ||||
MGB_LOCK_GUARD(sm_mtx); | MGB_LOCK_GUARD(sm_mtx); | ||||
for(auto && elem : glob_default_handles){ | |||||
mgb_assert(elem.first.valid()); | |||||
if(elem.second){ | |||||
inst().DnnOpManagerBase::delete_handle(elem.second); | |||||
} | |||||
elem.second = inst().do_new_handle(elem.first, seed); | |||||
} | |||||
glob_default_seed = seed; | glob_default_seed = seed; | ||||
} | } | ||||