|
|
@@ -48,6 +48,7 @@ __all__ = [ |
|
|
|
"tile", |
|
|
|
"copy", |
|
|
|
"transpose", |
|
|
|
"swapaxes", |
|
|
|
"where", |
|
|
|
"zeros", |
|
|
|
"zeros_like", |
|
|
@@ -715,6 +716,32 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: |
|
|
|
return inp.transpose(pattern) |
|
|
|
|
|
|
|
|
|
|
|
def swapaxes(inp: Tensor, axis1: int, axis2: int) -> Tensor: |
|
|
|
r"""Interchange two axes of a tensor. |
|
|
|
|
|
|
|
Args: |
|
|
|
inp: input tensor to swapaxes. |
|
|
|
axis1: first axis. |
|
|
|
axis2: second axis. |
|
|
|
|
|
|
|
Returns: |
|
|
|
a tensor after swapping the two axes of 'inp'. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> x = Tensor(np.array([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=np.int32)) |
|
|
|
>>> F.swapaxes(x, 0, 2) |
|
|
|
Tensor([[[0 4] |
|
|
|
[2 6]] |
|
|
|
[[1 5] |
|
|
|
[3 7]]], dtype=int32, device=xpux:0) |
|
|
|
""" |
|
|
|
pattern = list(range(inp.ndim)) |
|
|
|
tempAxis = pattern[axis1] |
|
|
|
pattern[axis1] = pattern[axis2] |
|
|
|
pattern[axis2] = tempAxis |
|
|
|
return inp.transpose(pattern) |
|
|
|
|
|
|
|
|
|
|
|
def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: |
|
|
|
r"""Reshapes a tensor without changing its data. |
|
|
|
|
|
|
|