diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 66f9ad65..0bc05023 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -48,6 +48,7 @@ __all__ = [ "tile", "copy", "transpose", + "swapaxes", "where", "zeros", "zeros_like", @@ -715,6 +716,32 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: return inp.transpose(pattern) +def swapaxes(inp: Tensor, axis1: int, axis2: int) -> Tensor: + r"""Interchange two axes of a tensor. + + Args: + inp: input tensor to swapaxes. + axis1: first axis. + axis2: second axis. + + Returns: + a tensor after swapping the two axes of 'inp'. + + Examples: + >>> x = Tensor(np.array([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=np.int32)) + >>> F.swapaxes(x, 0, 2) + Tensor([[[0 4] + [2 6]] + [[1 5] + [3 7]]], dtype=int32, device=xpux:0) + """ + pattern = list(range(inp.ndim)) + tempAxis = pattern[axis1] + pattern[axis1] = pattern[axis2] + pattern[axis2] = tempAxis + return inp.transpose(pattern) + + def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: r"""Reshapes a tensor without changing its data. diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 05d2abd5..26b807a8 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -215,6 +215,18 @@ def test_split(symbolic): @pytest.mark.parametrize("is_varnode", [True, False]) +def test_swapaxes(is_varnode): + if is_varnode: + network = Network() + else: + network = None + + x = tensor(np.array([[1, 2, 3]], dtype=np.int32)) + y = F.swapaxes(x, 0, 1) + np.testing.assert_equal(y.numpy(), np.array([[1], [2], [3]]).astype(np.int32)) + + +@pytest.mark.parametrize("is_varnode", [True, False]) def test_reshape(is_varnode): if is_varnode: network = Network()