Browse Source

feat(opr): let roll support empty IO

GitOrigin-RevId: b9a59b623a
release-1.6
Megvii Engine Team 3 years ago
parent
commit
43fb2c4afb
3 changed files with 31 additions and 5 deletions
  1. +4
    -3
      imperative/python/megengine/functional/tensor.py
  2. +23
    -0
      imperative/python/test/unit/functional/test_tensor.py
  3. +4
    -2
      src/opr/impl/tensor_manip.cpp

+ 4
- 3
imperative/python/megengine/functional/tensor.py View File

@@ -1352,10 +1352,11 @@ def roll(
if shift_ == 0:
continue
size = shp[axis_normalized_]
if shift_ > 0:
a, b = split(out, [size - shift_,], axis=axis_normalized_)
shift_normalized_ = 0 if size == 0 else shift_ % size
if shift_normalized_ > 0:
a, b = split(out, [size - shift_normalized_,], axis=axis_normalized_)
else:
a, b = split(out, [-shift_,], axis=axis_normalized_)
a, b = split(out, [-shift_normalized_,], axis=axis_normalized_)
out = concat((b, a), axis=axis_normalized_)
if shp_bak is not None:
out = out.reshape(shp_bak)


+ 23
- 0
imperative/python/test/unit/functional/test_tensor.py View File

@@ -806,6 +806,8 @@ def test_tile(shape, reps, is_varnode):
[
((2, 3), 0, None),
((2, 3), 1, 0),
((2, 3), 100, 0),
((2, 3), -100, 0),
((2, 3, 4, 5), (-1, 1), (0, 1)),
((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)),
],
@@ -829,3 +831,24 @@ def test_roll(shape, shifts, axis, is_varnode):
opr_test(
cases, func, ref_fn=lambda inp: np.roll(inp, shifts, axis), network=network
)


@pytest.mark.parametrize(
"shape, shifts, axis", [((10, 0), 5, 1), ((10, 0), -10, 1),],
)
@pytest.mark.parametrize("is_symbolic", [None, True, False])
def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
inp = tensor(np.random.randn(*shape).astype("float32"))

def func(inp):
return F.roll(inp, shifts, axis)

if is_symbolic is not None:
func = trace(symbolic=is_symbolic)(func)

out_ref = np.roll(inp.numpy(), shifts, axis)
for _ in range(3):
out = F.roll(inp, shifts, axis)
np.testing.assert_equal(out.numpy(), out_ref)
if is_symbolic is None:
break

+ 4
- 2
src/opr/impl/tensor_manip.cpp View File

@@ -1339,8 +1339,10 @@ void Concat::scn_do_execute() {
if (real_axis < 0)
real_axis += in.shape().ndim;
end = begin + in.shape().shape[real_axis];
out.sub(Slice(begin, end).apply(out.layout(), real_axis)).
copy_from_fixlayout(in);
if (!in.layout().is_empty()) {
out.sub(Slice(begin, end).apply(out.layout(), real_axis)).
copy_from_fixlayout(in);
}
}
}



Loading…
Cancel
Save