Browse Source

feat(imperative/functional): add roll in functional

GitOrigin-RevId: ff630f1fe7
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
c64b1c94cd
2 changed files with 103 additions and 0 deletions
  1. +73
    -0
      imperative/python/megengine/functional/tensor.py
  2. +30
    -0
      imperative/python/test/unit/functional/test_tensor.py

+ 73
- 0
imperative/python/megengine/functional/tensor.py View File

@@ -45,6 +45,7 @@ __all__ = [
"ones_like",
"repeat",
"reshape",
"roll",
"split",
"squeeze",
"stack",
@@ -1172,3 +1173,75 @@ def copy(inp, device=None):
if device is None:
return apply(Identity(), inp)[0]
return apply(Copy(comp_node=as_device(device).to_c()), inp)[0]


def roll(
inp: Tensor,
shift: Union[int, Iterable[int]],
axis: Optional[Union[int, Iterable[int]]] = None,
):
"""
Roll the tensor along the given axis(or axes). Elements that are shifted
beyond the last position are re-introduced at the first position.

:param inp: input tensor.
:param shift: the number of places by which the elements of the tensor are
shifted. If shift is a tuple, axis must be a tuple of the same size,
and each axis will be rolled by the corresponding shift value.
:param axis: axis along which to roll. If axis is not specified, the tensor
will be flattened before rolling and then restored to the original shape.
Duplicate axes is allowed if it is a tuple. Default: None.

Examples:

.. testcode::

import numpy as np
from megengine import tensor
import megengine.functional as F

x = tensor([[1,2],[3,4],[5,6]], np.int32)
y = F.roll(x, 1, 0)
print(y.numpy())

Outputs:

.. testoutput::

[[5 6]
[1 2]
[3 4]]

"""
shp_bak = None
if axis is None:
shp_bak = inp.shape
inp = inp.flatten()
axis = 0
shp = inp.shape
dim = len(shp)
if isinstance(shift, int):
assert isinstance(axis, int)
shift, axis = [shift,], [axis,]
assert len(shift) == len(axis)
out = inp
for i in range(len(shift)):
axis_ = axis[i]
shift_ = shift[i]
axis_normalized_ = axis_ + dim if axis_ < 0 else axis_
assert (
dim > axis_normalized_ >= 0
), "axis out of range (expected to be in range of [{}, {}], but got {})".format(
-dim, dim - 1, axis_
)
if shift_ == 0:
continue
size = shp[axis_normalized_]
if shift_ > 0:
a, b = split(out, [size - shift_,], axis=axis_normalized_)
else:
a, b = split(out, [-shift_,], axis=axis_normalized_)
out = concat((b, a), axis=axis_normalized_)
if shp_bak is not None:
out = out.reshape(shp_bak)
return out

+ 30
- 0
imperative/python/test/unit/functional/test_tensor.py View File

@@ -636,3 +636,33 @@ def test_tile(shape, reps, is_varnode):
cases = [{"input": np.random.randn(*shape).astype("float32")}]

opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network)


@pytest.mark.parametrize(
"shape, shifts, axis",
[
((2, 3), 0, None),
((2, 3), 1, 0),
((2, 3, 4, 5), (-1, 1), (0, 1)),
((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)),
],
)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_roll(shape, shifts, axis, is_varnode):
if is_varnode:
network = Network()
else:
network = None

inp = np.random.randn(*shape).astype("float32")

def func(inp):
return F.roll(inp, shifts, axis)

cases = [
{"input": inp},
]

opr_test(
cases, func, ref_fn=lambda inp: np.roll(inp, shifts, axis), network=network
)

Loading…
Cancel
Save