|
|
@@ -45,6 +45,7 @@ __all__ = [ |
|
|
|
"ones_like", |
|
|
|
"repeat", |
|
|
|
"reshape", |
|
|
|
"roll", |
|
|
|
"split", |
|
|
|
"squeeze", |
|
|
|
"stack", |
|
|
@@ -1172,3 +1173,75 @@ def copy(inp, device=None): |
|
|
|
if device is None: |
|
|
|
return apply(Identity(), inp)[0] |
|
|
|
return apply(Copy(comp_node=as_device(device).to_c()), inp)[0] |
|
|
|
|
|
|
|
|
|
|
|
def roll( |
|
|
|
inp: Tensor, |
|
|
|
shift: Union[int, Iterable[int]], |
|
|
|
axis: Optional[Union[int, Iterable[int]]] = None, |
|
|
|
): |
|
|
|
""" |
|
|
|
Roll the tensor along the given axis(or axes). Elements that are shifted |
|
|
|
beyond the last position are re-introduced at the first position. |
|
|
|
|
|
|
|
:param inp: input tensor. |
|
|
|
:param shift: the number of places by which the elements of the tensor are |
|
|
|
shifted. If shift is a tuple, axis must be a tuple of the same size, |
|
|
|
and each axis will be rolled by the corresponding shift value. |
|
|
|
:param axis: axis along which to roll. If axis is not specified, the tensor |
|
|
|
will be flattened before rolling and then restored to the original shape. |
|
|
|
Duplicate axes is allowed if it is a tuple. Default: None. |
|
|
|
|
|
|
|
Examples: |
|
|
|
|
|
|
|
.. testcode:: |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
from megengine import tensor |
|
|
|
import megengine.functional as F |
|
|
|
|
|
|
|
x = tensor([[1,2],[3,4],[5,6]], np.int32) |
|
|
|
y = F.roll(x, 1, 0) |
|
|
|
print(y.numpy()) |
|
|
|
|
|
|
|
Outputs: |
|
|
|
|
|
|
|
.. testoutput:: |
|
|
|
|
|
|
|
[[5 6] |
|
|
|
[1 2] |
|
|
|
[3 4]] |
|
|
|
|
|
|
|
""" |
|
|
|
shp_bak = None |
|
|
|
if axis is None: |
|
|
|
shp_bak = inp.shape |
|
|
|
inp = inp.flatten() |
|
|
|
axis = 0 |
|
|
|
shp = inp.shape |
|
|
|
dim = len(shp) |
|
|
|
if isinstance(shift, int): |
|
|
|
assert isinstance(axis, int) |
|
|
|
shift, axis = [shift,], [axis,] |
|
|
|
assert len(shift) == len(axis) |
|
|
|
out = inp |
|
|
|
for i in range(len(shift)): |
|
|
|
axis_ = axis[i] |
|
|
|
shift_ = shift[i] |
|
|
|
axis_normalized_ = axis_ + dim if axis_ < 0 else axis_ |
|
|
|
assert ( |
|
|
|
dim > axis_normalized_ >= 0 |
|
|
|
), "axis out of range (expected to be in range of [{}, {}], but got {})".format( |
|
|
|
-dim, dim - 1, axis_ |
|
|
|
) |
|
|
|
if shift_ == 0: |
|
|
|
continue |
|
|
|
size = shp[axis_normalized_] |
|
|
|
if shift_ > 0: |
|
|
|
a, b = split(out, [size - shift_,], axis=axis_normalized_) |
|
|
|
else: |
|
|
|
a, b = split(out, [-shift_,], axis=axis_normalized_) |
|
|
|
out = concat((b, a), axis=axis_normalized_) |
|
|
|
if shp_bak is not None: |
|
|
|
out = out.reshape(shp_bak) |
|
|
|
return out |