GitOrigin-RevId: 9a851cd177
release-1.1
@@ -52,7 +52,8 @@ def normal( | |||||
size = (1,) | size = (1,) | ||||
seed = _random_seed_generator().__next__() | seed = _random_seed_generator().__next__() | ||||
op = GaussianRNG(seed=seed, mean=mean, std=std) | op = GaussianRNG(seed=seed, mean=mean, std=std) | ||||
size = Tensor(size, dtype="int32") | |||||
_ref = Tensor([], dtype="int32") | |||||
size = utils.astensor1d(size, _ref, dtype="int32") | |||||
(output,) = apply(op, size) | (output,) = apply(op, size) | ||||
return output | return output | ||||
@@ -93,7 +94,8 @@ def uniform( | |||||
size = (1,) | size = (1,) | ||||
seed = _random_seed_generator().__next__() | seed = _random_seed_generator().__next__() | ||||
op = UniformRNG(seed=seed) | op = UniformRNG(seed=seed) | ||||
size = Tensor(size, dtype="int32") | |||||
_ref = Tensor([], dtype="int32") | |||||
size = utils.astensor1d(size, _ref, dtype="int32") | |||||
(output,) = apply(op, size) | (output,) = apply(op, size) | ||||
return low + (high - low) * output | return low + (high - low) * output |
@@ -23,6 +23,7 @@ from megengine.core.tensor.core import apply | |||||
from megengine.core.tensor.raw_tensor import as_raw_tensor | from megengine.core.tensor.raw_tensor import as_raw_tensor | ||||
from megengine.functional import exp, log | from megengine.functional import exp, log | ||||
from megengine.jit import exclude_from_trace, trace | from megengine.jit import exclude_from_trace, trace | ||||
from megengine.random import normal, uniform | |||||
def test_trace(): | def test_trace(): | ||||
@@ -431,3 +432,23 @@ def test_slice(): | |||||
y = f(x) | y = f(x) | ||||
np.testing.assert_array_equal(y.numpy(), x.numpy()[:, 1::2]) | np.testing.assert_array_equal(y.numpy(), x.numpy()[:, 1::2]) | ||||
y + y | y + y | ||||
def test_random(): | |||||
def run_test(op): | |||||
for symbolic_shape in [True, False]: | |||||
@trace(symbolic=True, symbolic_shape=symbolic_shape) | |||||
def f(): | |||||
out = op(size=[10, 10]) | |||||
out_shape = out.shape | |||||
assert out_shape is not None | |||||
if not isinstance(out_shape, tuple): | |||||
assert out.shape.numpy() is not None | |||||
return out | |||||
for _ in range(3): | |||||
f() | |||||
run_test(uniform) | |||||
run_test(normal) |