diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 3c71e82c..4cd1daba 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -45,6 +45,7 @@ __all__ = [ "ones_like", "repeat", "reshape", + "roll", "split", "squeeze", "stack", @@ -1172,3 +1173,75 @@ def copy(inp, device=None): if device is None: return apply(Identity(), inp)[0] return apply(Copy(comp_node=as_device(device).to_c()), inp)[0] + + +def roll( + inp: Tensor, + shift: Union[int, Iterable[int]], + axis: Optional[Union[int, Iterable[int]]] = None, +): + """ + Roll the tensor along the given axis(or axes). Elements that are shifted + beyond the last position are re-introduced at the first position. + + :param inp: input tensor. + :param shift: the number of places by which the elements of the tensor are + shifted. If shift is a tuple, axis must be a tuple of the same size, + and each axis will be rolled by the corresponding shift value. + :param axis: axis along which to roll. If axis is not specified, the tensor + will be flattened before rolling and then restored to the original shape. + Duplicate axes is allowed if it is a tuple. Default: None. + + Examples: + + .. testcode:: + + import numpy as np + from megengine import tensor + import megengine.functional as F + + x = tensor([[1,2],[3,4],[5,6]], np.int32) + y = F.roll(x, 1, 0) + print(y.numpy()) + + Outputs: + + .. testoutput:: + + [[5 6] + [1 2] + [3 4]] + + """ + shp_bak = None + if axis is None: + shp_bak = inp.shape + inp = inp.flatten() + axis = 0 + shp = inp.shape + dim = len(shp) + if isinstance(shift, int): + assert isinstance(axis, int) + shift, axis = [shift,], [axis,] + assert len(shift) == len(axis) + out = inp + for i in range(len(shift)): + axis_ = axis[i] + shift_ = shift[i] + axis_normalized_ = axis_ + dim if axis_ < 0 else axis_ + assert ( + dim > axis_normalized_ >= 0 + ), "axis out of range (expected to be in range of [{}, {}], but got {})".format( + -dim, dim - 1, axis_ + ) + if shift_ == 0: + continue + size = shp[axis_normalized_] + if shift_ > 0: + a, b = split(out, [size - shift_,], axis=axis_normalized_) + else: + a, b = split(out, [-shift_,], axis=axis_normalized_) + out = concat((b, a), axis=axis_normalized_) + if shp_bak is not None: + out = out.reshape(shp_bak) + return out diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index a0ed6d1f..144a5e28 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -636,3 +636,33 @@ def test_tile(shape, reps, is_varnode): cases = [{"input": np.random.randn(*shape).astype("float32")}] opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network) + + +@pytest.mark.parametrize( + "shape, shifts, axis", + [ + ((2, 3), 0, None), + ((2, 3), 1, 0), + ((2, 3, 4, 5), (-1, 1), (0, 1)), + ((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)), + ], +) +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_roll(shape, shifts, axis, is_varnode): + if is_varnode: + network = Network() + else: + network = None + + inp = np.random.randn(*shape).astype("float32") + + def func(inp): + return F.roll(inp, shifts, axis) + + cases = [ + {"input": inp}, + ] + + opr_test( + cases, func, ref_fn=lambda inp: np.roll(inp, shifts, axis), network=network + )