|
|
@@ -488,12 +488,12 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: |
|
|
|
|
|
|
|
|
|
|
|
@wrap_io_tensor |
|
|
|
def add_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor: |
|
|
|
def add_axis(inp: Tensor, axis: int) -> Tensor: |
|
|
|
r""" |
|
|
|
Add dimension(s) before given axis/axes |
|
|
|
Add dimension before given axis. |
|
|
|
|
|
|
|
:param inp: Input tensor |
|
|
|
:param axis: Place(s) of new axes |
|
|
|
:param axis: Place of new axes |
|
|
|
:return: The output tensor |
|
|
|
|
|
|
|
Examples: |
|
|
@@ -504,26 +504,28 @@ def add_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor: |
|
|
|
from megengine import tensor |
|
|
|
import megengine.functional as F |
|
|
|
x = tensor([1, 2]) |
|
|
|
out = F.add_axis(x, (0, 2)) |
|
|
|
out = F.add_axis(x, 0) |
|
|
|
print(out.shape) |
|
|
|
|
|
|
|
Outputs: |
|
|
|
|
|
|
|
.. testoutput:: |
|
|
|
|
|
|
|
(1, 2, 1) |
|
|
|
(1, 2) |
|
|
|
|
|
|
|
""" |
|
|
|
if not isinstance(axis, int): |
|
|
|
raise ValueError("axis must be int, but got type:{}".format(type(axis))) |
|
|
|
return mgb.opr.add_axis(inp, axis) |
|
|
|
|
|
|
|
|
|
|
|
@wrap_io_tensor |
|
|
|
def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor: |
|
|
|
def remove_axis(inp: Tensor, axis: int) -> Tensor: |
|
|
|
r""" |
|
|
|
Remove dimension(s) of shape 1 |
|
|
|
Remove dimension of shape 1. |
|
|
|
|
|
|
|
:param inp: Input tensor |
|
|
|
:param axis: Place(s) of axes to be removed |
|
|
|
:param axis: Place of axis to be removed |
|
|
|
:return: The output tensor |
|
|
|
|
|
|
|
Examples: |
|
|
@@ -534,16 +536,18 @@ def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor: |
|
|
|
from megengine import tensor |
|
|
|
import megengine.functional as F |
|
|
|
x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1)) |
|
|
|
out = F.remove_axis(x, (0, 0, 1)) |
|
|
|
out = F.remove_axis(x, 3) |
|
|
|
print(out.shape) |
|
|
|
|
|
|
|
Outputs: |
|
|
|
|
|
|
|
.. testoutput:: |
|
|
|
|
|
|
|
(2,) |
|
|
|
(1, 1, 2) |
|
|
|
|
|
|
|
""" |
|
|
|
if not isinstance(axis, int): |
|
|
|
raise ValueError("axis must be int, but got type:{}".format(type(axis))) |
|
|
|
return mgb.opr.remove_axis(inp, axis) |
|
|
|
|
|
|
|
|
|
|
|