|
|
@@ -806,6 +806,8 @@ def test_tile(shape, reps, is_varnode): |
|
|
|
[ |
|
|
|
((2, 3), 0, None), |
|
|
|
((2, 3), 1, 0), |
|
|
|
((2, 3), 100, 0), |
|
|
|
((2, 3), -100, 0), |
|
|
|
((2, 3, 4, 5), (-1, 1), (0, 1)), |
|
|
|
((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)), |
|
|
|
], |
|
|
@@ -829,3 +831,24 @@ def test_roll(shape, shifts, axis, is_varnode): |
|
|
|
opr_test( |
|
|
|
cases, func, ref_fn=lambda inp: np.roll(inp, shifts, axis), network=network |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
|
"shape, shifts, axis", [((10, 0), 5, 1), ((10, 0), -10, 1),], |
|
|
|
) |
|
|
|
@pytest.mark.parametrize("is_symbolic", [None, True, False]) |
|
|
|
def test_roll_empty_tensor(shape, shifts, axis, is_symbolic): |
|
|
|
inp = tensor(np.random.randn(*shape).astype("float32")) |
|
|
|
|
|
|
|
def func(inp): |
|
|
|
return F.roll(inp, shifts, axis) |
|
|
|
|
|
|
|
if is_symbolic is not None: |
|
|
|
func = trace(symbolic=is_symbolic)(func) |
|
|
|
|
|
|
|
out_ref = np.roll(inp.numpy(), shifts, axis) |
|
|
|
for _ in range(3): |
|
|
|
out = F.roll(inp, shifts, axis) |
|
|
|
np.testing.assert_equal(out.numpy(), out_ref) |
|
|
|
if is_symbolic is None: |
|
|
|
break |