diff --git a/python_module/megengine/random/distribution.py b/python_module/megengine/random/distribution.py index 1d3882cc..6565dd08 100644 --- a/python_module/megengine/random/distribution.py +++ b/python_module/megengine/random/distribution.py @@ -62,12 +62,16 @@ def gaussian( @wrap_io_tensor def uniform( shape: Iterable[int], + low: float = 0, + high: float = 1, comp_node: Optional[CompNode] = None, comp_graph: Optional[CompGraph] = None, ) -> Tensor: r"""Random variable with uniform distribution $U(0, 1)$ :param shape: Output tensor shape + :param low: Lower range + :param high: Upper range :param comp_node: The comp node output on, default to None :param comp_graph: The graph in which output is, default to None :return: The output tensor @@ -91,6 +95,6 @@ def uniform( """ comp_node, comp_graph = _use_default_if_none(comp_node, comp_graph) seed = _random_seed_generator().__next__() - return mgb.opr.uniform_rng( + return low + (high - low) * mgb.opr.uniform_rng( shape, seed=seed, comp_node=comp_node, comp_graph=comp_graph ) diff --git a/python_module/test/unit/random/test_random.py b/python_module/test/unit/random/test_random.py index 5d67b868..2e8023e8 100644 --- a/python_module/test/unit/random/test_random.py +++ b/python_module/test/unit/random/test_random.py @@ -59,6 +59,50 @@ def test_random_dynamic_same_result(): assert np.all(a.numpy() == b.numpy()) +def test_range_uniform_static_diff_result(): + @jit.trace(symbolic=True) + def graph_a(): + return R.uniform(5, low=-2, high=2) + + @jit.trace(symbolic=True) + def graph_b(): + return R.uniform(5, low=-2, high=2) + + a = graph_a() + b = graph_b() + assert np.any(a.numpy() != b.numpy()) + + +def test_range_uniform_static_same_result(): + @jit.trace(symbolic=True) + def graph_a(): + R.manual_seed(731) + return R.uniform(5, low=-2, high=2) + + @jit.trace(symbolic=True) + def graph_b(): + R.manual_seed(731) + return R.uniform(5, low=-2, high=2) + + a = graph_a() + b = graph_b() + assert np.all(a.numpy() == b.numpy()) + + +def test_range_uniform_dynamic_diff_result(): + a = R.uniform(5, low=-2, high=2) + b = R.uniform(5, low=-2, high=2) + assert np.any(a.numpy() != b.numpy()) + + +def test_range_uniform_dynamic_same_result(): + R.manual_seed(0) + a = R.uniform(5, low=-2, high=2) + R.manual_seed(0) + b = R.uniform(5, low=-2, high=2) + assert np.all(a.numpy() == b.numpy()) + + def test_dropout_dynamic_diff_result(): x = mge.ones(10) a = F.dropout(x, 0.5)