Browse Source

feat(imperative): add swapaxes

GitOrigin-RevId: e84014a011
release-1.10
Megvii Engine Team 3 years ago
parent
commit
07bdb3bf1e
2 changed files with 39 additions and 0 deletions
  1. +27
    -0
      imperative/python/megengine/functional/tensor.py
  2. +12
    -0
      imperative/python/test/unit/functional/test_tensor.py

+ 27
- 0
imperative/python/megengine/functional/tensor.py View File

@@ -48,6 +48,7 @@ __all__ = [
"tile", "tile",
"copy", "copy",
"transpose", "transpose",
"swapaxes",
"where", "where",
"zeros", "zeros",
"zeros_like", "zeros_like",
@@ -715,6 +716,32 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
return inp.transpose(pattern) return inp.transpose(pattern)




def swapaxes(inp: Tensor, axis1: int, axis2: int) -> Tensor:
r"""Interchange two axes of a tensor.

Args:
inp: input tensor to swapaxes.
axis1: first axis.
axis2: second axis.

Returns:
a tensor after swapping the two axes of 'inp'.

Examples:
>>> x = Tensor(np.array([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=np.int32))
>>> F.swapaxes(x, 0, 2)
Tensor([[[0 4]
[2 6]]
[[1 5]
[3 7]]], dtype=int32, device=xpux:0)
"""
pattern = list(range(inp.ndim))
tempAxis = pattern[axis1]
pattern[axis1] = pattern[axis2]
pattern[axis2] = tempAxis
return inp.transpose(pattern)


def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
r"""Reshapes a tensor without changing its data. r"""Reshapes a tensor without changing its data.




+ 12
- 0
imperative/python/test/unit/functional/test_tensor.py View File

@@ -215,6 +215,18 @@ def test_split(symbolic):




@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])
def test_swapaxes(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x = tensor(np.array([[1, 2, 3]], dtype=np.int32))
y = F.swapaxes(x, 0, 1)
np.testing.assert_equal(y.numpy(), np.array([[1], [2], [3]]).astype(np.int32))


@pytest.mark.parametrize("is_varnode", [True, False])
def test_reshape(is_varnode): def test_reshape(is_varnode):
if is_varnode: if is_varnode:
network = Network() network = Network()


Loading…
Cancel
Save