Browse Source

fix(mge/random): add lower bound and higher bound for uniform sampling

GitOrigin-RevId: 2a2c56fd17
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
fc4dedc0bc
2 changed files with 49 additions and 1 deletions
  1. +5
    -1
      python_module/megengine/random/distribution.py
  2. +44
    -0
      python_module/test/unit/random/test_random.py

+ 5
- 1
python_module/megengine/random/distribution.py View File

@@ -62,12 +62,16 @@ def gaussian(
@wrap_io_tensor @wrap_io_tensor
def uniform( def uniform(
shape: Iterable[int], shape: Iterable[int],
low: float = 0,
high: float = 1,
comp_node: Optional[CompNode] = None, comp_node: Optional[CompNode] = None,
comp_graph: Optional[CompGraph] = None, comp_graph: Optional[CompGraph] = None,
) -> Tensor: ) -> Tensor:
r"""Random variable with uniform distribution $U(0, 1)$ r"""Random variable with uniform distribution $U(0, 1)$


:param shape: Output tensor shape :param shape: Output tensor shape
:param low: Lower range
:param high: Upper range
:param comp_node: The comp node output on, default to None :param comp_node: The comp node output on, default to None
:param comp_graph: The graph in which output is, default to None :param comp_graph: The graph in which output is, default to None
:return: The output tensor :return: The output tensor
@@ -91,6 +95,6 @@ def uniform(
""" """
comp_node, comp_graph = _use_default_if_none(comp_node, comp_graph) comp_node, comp_graph = _use_default_if_none(comp_node, comp_graph)
seed = _random_seed_generator().__next__() seed = _random_seed_generator().__next__()
return mgb.opr.uniform_rng(
return low + (high - low) * mgb.opr.uniform_rng(
shape, seed=seed, comp_node=comp_node, comp_graph=comp_graph shape, seed=seed, comp_node=comp_node, comp_graph=comp_graph
) )

+ 44
- 0
python_module/test/unit/random/test_random.py View File

@@ -59,6 +59,50 @@ def test_random_dynamic_same_result():
assert np.all(a.numpy() == b.numpy()) assert np.all(a.numpy() == b.numpy())




def test_range_uniform_static_diff_result():
@jit.trace(symbolic=True)
def graph_a():
return R.uniform(5, low=-2, high=2)

@jit.trace(symbolic=True)
def graph_b():
return R.uniform(5, low=-2, high=2)

a = graph_a()
b = graph_b()
assert np.any(a.numpy() != b.numpy())


def test_range_uniform_static_same_result():
@jit.trace(symbolic=True)
def graph_a():
R.manual_seed(731)
return R.uniform(5, low=-2, high=2)

@jit.trace(symbolic=True)
def graph_b():
R.manual_seed(731)
return R.uniform(5, low=-2, high=2)

a = graph_a()
b = graph_b()
assert np.all(a.numpy() == b.numpy())


def test_range_uniform_dynamic_diff_result():
a = R.uniform(5, low=-2, high=2)
b = R.uniform(5, low=-2, high=2)
assert np.any(a.numpy() != b.numpy())


def test_range_uniform_dynamic_same_result():
R.manual_seed(0)
a = R.uniform(5, low=-2, high=2)
R.manual_seed(0)
b = R.uniform(5, low=-2, high=2)
assert np.all(a.numpy() == b.numpy())


def test_dropout_dynamic_diff_result(): def test_dropout_dynamic_diff_result():
x = mge.ones(10) x = mge.ones(10)
a = F.dropout(x, 0.5) a = F.dropout(x, 0.5)


Loading…
Cancel
Save