From 13272eaa63ef7160d332f8d746ff940c4529bb52 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 18 Nov 2020 21:48:26 +0800 Subject: [PATCH] fix(mge/trace): fix random op in symbolic trace GitOrigin-RevId: 9a851cd177119ee43155b831e9622ce342423090 --- imperative/python/megengine/random/distribution.py | 6 ++++-- imperative/python/test/unit/test_tracing.py | 21 +++++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/imperative/python/megengine/random/distribution.py b/imperative/python/megengine/random/distribution.py index f3a91662..82852200 100644 --- a/imperative/python/megengine/random/distribution.py +++ b/imperative/python/megengine/random/distribution.py @@ -52,7 +52,8 @@ def normal( size = (1,) seed = _random_seed_generator().__next__() op = GaussianRNG(seed=seed, mean=mean, std=std) - size = Tensor(size, dtype="int32") + _ref = Tensor([], dtype="int32") + size = utils.astensor1d(size, _ref, dtype="int32") (output,) = apply(op, size) return output @@ -93,7 +94,8 @@ def uniform( size = (1,) seed = _random_seed_generator().__next__() op = UniformRNG(seed=seed) - size = Tensor(size, dtype="int32") + _ref = Tensor([], dtype="int32") + size = utils.astensor1d(size, _ref, dtype="int32") (output,) = apply(op, size) return low + (high - low) * output diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index e0ce660f..8646becc 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -23,6 +23,7 @@ from megengine.core.tensor.core import apply from megengine.core.tensor.raw_tensor import as_raw_tensor from megengine.functional import exp, log from megengine.jit import exclude_from_trace, trace +from megengine.random import normal, uniform def test_trace(): @@ -431,3 +432,23 @@ def test_slice(): y = f(x) np.testing.assert_array_equal(y.numpy(), x.numpy()[:, 1::2]) y + y + + +def test_random(): + def run_test(op): + for symbolic_shape in [True, False]: + + @trace(symbolic=True, symbolic_shape=symbolic_shape) + def f(): + out = op(size=[10, 10]) + out_shape = out.shape + assert out_shape is not None + if not isinstance(out_shape, tuple): + assert out.shape.numpy() is not None + return out + + for _ in range(3): + f() + + run_test(uniform) + run_test(normal)