Browse Source

fix(mge/trace): fix random op in symbolic trace

GitOrigin-RevId: 9a851cd177
release-1.1
Megvii Engine Team 4 years ago
parent
commit
13272eaa63
2 changed files with 25 additions and 2 deletions
  1. +4
    -2
      imperative/python/megengine/random/distribution.py
  2. +21
    -0
      imperative/python/test/unit/test_tracing.py

+ 4
- 2
imperative/python/megengine/random/distribution.py View File

@@ -52,7 +52,8 @@ def normal(
size = (1,)
seed = _random_seed_generator().__next__()
op = GaussianRNG(seed=seed, mean=mean, std=std)
size = Tensor(size, dtype="int32")
_ref = Tensor([], dtype="int32")
size = utils.astensor1d(size, _ref, dtype="int32")
(output,) = apply(op, size)
return output

@@ -93,7 +94,8 @@ def uniform(
size = (1,)
seed = _random_seed_generator().__next__()
op = UniformRNG(seed=seed)
size = Tensor(size, dtype="int32")
_ref = Tensor([], dtype="int32")
size = utils.astensor1d(size, _ref, dtype="int32")
(output,) = apply(op, size)

return low + (high - low) * output

+ 21
- 0
imperative/python/test/unit/test_tracing.py View File

@@ -23,6 +23,7 @@ from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.functional import exp, log
from megengine.jit import exclude_from_trace, trace
from megengine.random import normal, uniform


def test_trace():
@@ -431,3 +432,23 @@ def test_slice():
y = f(x)
np.testing.assert_array_equal(y.numpy(), x.numpy()[:, 1::2])
y + y


def test_random():
def run_test(op):
for symbolic_shape in [True, False]:

@trace(symbolic=True, symbolic_shape=symbolic_shape)
def f():
out = op(size=[10, 10])
out_shape = out.shape
assert out_shape is not None
if not isinstance(out_shape, tuple):
assert out.shape.numpy() is not None
return out

for _ in range(3):
f()

run_test(uniform)
run_test(normal)

Loading…
Cancel
Save