Browse Source

fix(mge/functional): simplify the api of add/remove_axis

GitOrigin-RevId: 2482529704
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
35d46dbb41
1 changed files with 14 additions and 10 deletions
  1. +14
    -10
      python_module/megengine/functional/tensor.py

+ 14
- 10
python_module/megengine/functional/tensor.py View File

@@ -488,12 +488,12 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:


@wrap_io_tensor
def add_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
def add_axis(inp: Tensor, axis: int) -> Tensor:
r"""
Add dimension(s) before given axis/axes
Add dimension before given axis.

:param inp: Input tensor
:param axis: Place(s) of new axes
:param axis: Place of new axes
:return: The output tensor

Examples:
@@ -504,26 +504,28 @@ def add_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
from megengine import tensor
import megengine.functional as F
x = tensor([1, 2])
out = F.add_axis(x, (0, 2))
out = F.add_axis(x, 0)
print(out.shape)

Outputs:

.. testoutput::

(1, 2, 1)
(1, 2)

"""
if not isinstance(axis, int):
raise ValueError("axis must be int, but got type:{}".format(type(axis)))
return mgb.opr.add_axis(inp, axis)


@wrap_io_tensor
def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
def remove_axis(inp: Tensor, axis: int) -> Tensor:
r"""
Remove dimension(s) of shape 1
Remove dimension of shape 1.

:param inp: Input tensor
:param axis: Place(s) of axes to be removed
:param axis: Place of axis to be removed
:return: The output tensor

Examples:
@@ -534,16 +536,18 @@ def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
from megengine import tensor
import megengine.functional as F
x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
out = F.remove_axis(x, (0, 0, 1))
out = F.remove_axis(x, 3)
print(out.shape)

Outputs:

.. testoutput::

(2,)
(1, 1, 2)

"""
if not isinstance(axis, int):
raise ValueError("axis must be int, but got type:{}".format(type(axis)))
return mgb.opr.remove_axis(inp, axis)




Loading…
Cancel
Save