Browse Source

docs(mge/functional): enhance tensor creation function docstring

GitOrigin-RevId: f7682a3b02
release-1.11.1
Megvii Engine Team 2 years ago
parent
commit
eddb0aba2b
1 changed files with 229 additions and 176 deletions
  1. +229
    -176
      imperative/python/megengine/functional/tensor.py

+ 229
- 176
imperative/python/megengine/functional/tensor.py View File

@@ -55,53 +55,154 @@ __all__ = [
] ]




def diag(inp, k=0) -> Tensor:
r"""If ``inp`` is a 1D tensor, then returns a 2D tensor with the elements of ``inp`` as the diagonal.
If ``inp`` is a 2D tensor, then returns a 1D tensor with the diagonal elements of ``inp``.
# creation functions


def arange(
start: Union[int, float] = 0,
stop: Optional[Union[int, float]] = None,
step: Union[int, float] = 1,
*,
dtype="float32",
device=None,
) -> Tensor:
r"""Returns evenly spaced values within the half-open interval ``[start, stop)`` as a one-dimensional tensor.

Note:
This function cannot guarantee that the interval does not include the stop value in those cases
where step is not an integer and floating-point rounding errors affect the length of the output tensor.


Args: Args:
inp: input tensor.
k: diagonal in consider. Use :math:`k=0` for the main diagonal, :math:`k>0` for diagonals above the
main diagonal, and :math:`k<0` for diagonals below the main diagonal. Default: 0.
start(Number): if ``stop`` is specified, the start of interval (inclusive); otherwise,
the end of the interval (exclusive). If ``stop`` is not specified, the default starting value is ``0``.
stop(Number): the end of the interval.
step(Number): the distance between two adjacent elements ( ``out[i+1] - out[i]`` ). Must not be 0 ;
may be negative, this results i an empty tensor if stop >= start .

Keyword args:
dtype(:attr:`.Tensor.dtype`, optional): output tensor data type.
device(:attr:`.Tensor.device`, optional): device on which to place the created tensor.

.. seealso:: :func:`~.functional.linspace`


Returns: Returns:
the extracted diagonal or constructed diagonal array.
A one-dimensional tensor containing evenly spaced values.

The length of the output tensor must be ``ceil((stop-start)/step)``
if ``stop - start`` and ``step`` have the same sign, and length 0 otherwise.


Examples: Examples:
>>> inp = F.arange(6, dtype='int32').reshape(2,3)
>>> out = F.diag(inp, k=1)
>>> out
Tensor([1 5], dtype=int32, device=xpux:0)
>>> F.diag(out)
Tensor([[1 0]
[0 5]], dtype=int32, device=xpux:0)
>>> F.arange(5)
Tensor([0. 1. 2. 3. 4.], device=xpux:0)
>>> F.arange(1, 4)
Tensor([1. 2. 3.], device=xpux:0)

""" """
op = builtin.Diag(k=k)
(result,) = apply(op, inp)
if stop is None:
start, stop = 0, start

if not isinstance(start, Tensor):
start = Tensor(start, dtype="float32")
if not isinstance(stop, Tensor):
stop = Tensor(stop, dtype="float32")
if not isinstance(step, Tensor):
step = Tensor(step, dtype="float32")

num = ceil((stop - start) / step)
stop = start + step * (num - 1)
result = linspace(start, stop, num, device=device)
if np.dtype(dtype) != np.float32:
return result.astype(dtype)
return result


def linspace(
start: Union[int, float],
stop: Union[int, float],
num: int,
*,
dtype="float32",
device: Optional[CompNode] = None,
) -> Tensor:
r"""Returns evenly spaced numbers over a specified interval.

Returns ``num`` evenly spaced samples, calculated over the interval ``[start, stop]``.

Args:
start(Number): the start of the interval.
stop(Number): the end of the interval.
num(int): number of values to generate.

Keyword args:
dtype(:attr:`.Tensor.dtype`, optional): output tensor data type.
If ``dtype`` is not given, the data type is inferred from ``start`` and ``stop``.
device(:attr:`.Tensor.device`, optional): device on which to place the created tensor.

Returns:
a one-dimensional tensor containing evenly spaced values.

.. seealso:: :func:`~.functional.arange`

Examples:
>>> F.linspace(1, 10, 10)
Tensor([ 1. 2. 3. 4. 5. 6. 7. 8. 9. 10.], device=xpux:0)

>>> F.linspace(2., 3., 5)
Tensor([2. 2.25 2.5 2.75 3. ], device=xpux:0)
"""
for item in (start, stop, num):
cur_device = getattr(item, "device", None)
if device is None:
device = cur_device
else:
if not (cur_device is None or device == cur_device):
raise ("ambiguous device for linspace opr")

if not isinstance(start, Tensor):
start = Tensor(start, device=device)
if not isinstance(stop, Tensor):
stop = Tensor(stop, device=device)
if not isinstance(num, Tensor):
num = Tensor(num, device=device)

op = builtin.Linspace(comp_node=device)
(result,) = apply(op, start, stop, num)
if np.dtype(dtype) != np.float32:
return result.astype(dtype)
return result return result




def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
r"""Returns a 2D tensor with ones on the diagonal and zeros elsewhere.
def eye(N: int, M: int = None, *, dtype="float32", device=None) -> Tensor:
r"""Returns a two-dimensional tensor with ones on the diagonal and zeros elsewhere.


Args: Args:
N: an integer defining the number of rows.
M: an integer defining the number of columns. If ``M`` is not specified, the number of columns is ``N``. Default: ``None``.
dtype: the desired data type of the output tensor. Default: ``float32``.
device: the desired device of the output tensor. Default: if ``None``,
use the default device (see :func:`~.megengine.get_default_device`).
N: number of rows in the output tesnor.
M: number of columns in the output tesnor.
If ``None``, the default number of columns in the output tesnor is equal tos ``N``.

Keyword args:
dtype(:attr:`.Tensor.dtype`, optional): output tesnor data type.
If ``None``, the output tesnor data type must be the default floating-point data type.
device(:attr:`.Tensor.device`, optional): device on which to place the created tensor.

.. seealso:: If you want to create a diagonal matrix, see :func:`~.functional.diag`.


Returns: Returns:
eye matrix.
a tensor where all elements are equal to zero,
except for the diagonal, whose values are equal to one.


Examples: Examples:
>>> import numpy as np
>>> out = F.eye(4, 6, dtype=np.float32)
>>> out.numpy()
array([[1., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0.],
[0., 0., 0., 1., 0., 0.]], dtype=float32)

>>> F.eye(3)
Tensor([[1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]], device=xpux:0)

>>> F.eye(4, 6)
Tensor([[1. 0. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0.]
[0. 0. 0. 1. 0. 0.]], device=xpux:0)
""" """
if M is not None: if M is not None:
if isinstance(N, Tensor) or isinstance(M, Tensor): if isinstance(N, Tensor) or isinstance(M, Tensor):
@@ -117,34 +218,86 @@ def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Ten
return result return result




def diag(inp, k: int = 0) -> Tensor:
r"""Extract a diagonal or construct a diagonal tensor.
If ``inp`` is a 1D tensor, then returns a 2D tensor with the elements of ``inp`` as the diagonal.
If ``inp`` is a 2D tensor, then returns a 1D tensor with the diagonal elements of ``inp``.

Args:
inp: input tensor.
k: diagonal in consider. Use :math:`k=0` for the main diagonal, :math:`k>0` for diagonals above the
main diagonal, and :math:`k<0` for diagonals below the main diagonal.

.. seealso:: If you want to create a identity matrix, see :func:`~.functional.eye`.

Returns:
the extracted diagonal or constructed diagonal tensor.

Examples:

Input is a 1D tensor:

>>> F.diag(Tensor([1, 2, 3]))
Tensor([[1 0 0]
[0 2 0]
[0 0 3]], dtype=int32, device=xpux:0)
>>> F.diag(Tensor([1, 2, 3]), k=1)
Tensor([[0 1 0 0]
[0 0 2 0]
[0 0 0 3]
[0 0 0 0]], dtype=int32, device=xpux:0)

Input is a 2D tensor:

>>> x = F.arange(9).reshape(3, 3)
>>> x
Tensor([[0. 1. 2.]
[3. 4. 5.]
[6. 7. 8.]], device=xpux:0)
>>> F.diag(x)
Tensor([0. 4. 8.], device=xpux:0)

Get the k-th diagonal of a given matrix:

>>> F.diag(x, k=1)
Tensor([1. 5.], device=xpux:0)
>>> F.diag(x, k=-1)
Tensor([3. 7.], device=xpux:0)
"""
op = builtin.Diag(k=k)
(result,) = apply(op, inp)
return result


def full( def full(
shape: Union[int, tuple, list],
value: Union[bool, int, float, Tensor],
shape: Union[int, Tuple[int, ...]],
value: Union[bool, int, float],
*,
dtype=None, dtype=None,
device=None, device=None,
) -> Tensor: ) -> Tensor:
r"""Creates a tensor of shape ``shape`` filled with ``value``.
r"""Returns a new tensor having a specified shape and filled with given value.


Args: Args:
shape: output tensor shape.
value: fill value.
dtype: output tensor data type. If ``dtype`` is ``None``, the output tensor
data type must be inferred from ``value``. If the value is an ``int``,
the output tensor data type must be the default integer data type. If the
value is a ``float``, the output tensor data type must be the default
floating-point data type. If the value is a ``bool``, the output tensor
must have boolean data type. Default: ``None``.
device: device on which to place the created tensor. Default: ``None``.
shape(int...): output tensor shape.
value(Scalar): fill value.

Keyword args:
dtype(:attr:`.Tensor.dtype`, optional): output tensor data type.
If ``dtype`` is ``None``, the output tensor data type must be inferred from ``value``.
If the value is an ``int``, the output tensor data type must be the default integer data type.
If the value is a ``float``, the output tensor data type must be the default floating-point data type.
If the value is a ``bool``, the output tensor must have boolean data type.
device(:attr:`.Tensor.device`, optional): device on which to place the created tensor.


Returns: Returns:
a tensor where every element is equal to ``value``. a tensor where every element is equal to ``value``.


Examples: Examples:
>>> import numpy as np
>>> out = F.full([2,3], 1.5)
>>> out.numpy()
array([[1.5, 1.5, 1.5],
[1.5, 1.5, 1.5]], dtype=float32)
>>> F.full((2, 3), 6)
Tensor([[6 6 6]
[6 6 6]], dtype=int32, device=xpux:0)
""" """


if isinstance(shape, int): if isinstance(shape, int):
@@ -166,11 +319,11 @@ def ones(
r"""Returns a new tensor having a specified shape and filled with ones. r"""Returns a new tensor having a specified shape and filled with ones.


Args: Args:
shape (int or sequence of ints): the shape of the output tensor.
shape(int...): the shape of the output tensor.


Keyword args: Keyword args:
dtype (:attr:`.Tensor.dtype`): output tensor data type. Default: ``float32``.
device (:attr:`.Tensor.device`): device on which to place the created tensor. Default: ``None``.
dtype(:attr:`.Tensor.dtype`, optional): output tensor data type.
device(:attr:`.Tensor.device`, optional): device on which to place the created tensor.


Returns: Returns:
a tensor containing ones. a tensor containing ones.
@@ -183,9 +336,6 @@ def ones(
>>> F.ones((2, 2)) >>> F.ones((2, 2))
Tensor([[1. 1.] Tensor([[1. 1.]
[1. 1.]], device=xpux:0) [1. 1.]], device=xpux:0)
>>> F.ones([2, 1])
Tensor([[1.]
[1.]], device=xpux:0)
""" """
return full(shape, 1.0, dtype=dtype, device=device) return full(shape, 1.0, dtype=dtype, device=device)


@@ -199,19 +349,19 @@ def zeros(
r"""Returns a new tensor having a specified shape and filled with zeros. r"""Returns a new tensor having a specified shape and filled with zeros.


Args: Args:
shape (int or sequence of ints): the shape of the output tensor.
shape(int...): the shape of the output tensor.


Keyword args: Keyword args:
dtype (:attr:`.Tensor.dtype`): output tensor data type. Default: ``float32``.
device (:attr:`.Tensor.device`): device on which to place the created tensor. Default: ``None``.
dtype(:attr:`.Tensor.dtype`, optional): output tensor data type.
device(:attr:`.Tensor.device`, optional): device on which to place the created tensor.


Returns: Returns:
a tensor containing zeros. a tensor containing zeros.


Examples: Examples:
>>> F.zeros((2, 1))
Tensor([[0.]
[0.]], device=xpux:0)
>>> F.zeros((2, 3))
Tensor([[0. 0. 0.]
[0. 0. 0.]], device=xpux:0)
""" """
return full(shape, 0.0, dtype=dtype, device=device) return full(shape, 0.0, dtype=dtype, device=device)


@@ -220,16 +370,15 @@ def zeros_like(inp: Tensor) -> Tensor:
r"""Returns a tensor filled with zeros with the same shape and data type as input tensor. r"""Returns a tensor filled with zeros with the same shape and data type as input tensor.


Args: Args:
inp (Tensor): input tensor.
inp(Tensor): input tensor from which to derive the output tensor shape.


Return: Return:
a tensor containing zeros.
a tensor having the same shape as input tensor and filled with zeros.


Examples: Examples:
>>> input = F.arange(9, dtype='int32').reshape(3,3)
>>> F.zeros_like(input)
>>> x = F.arange(6, dtype='int32').reshape(2, 3)
>>> F.zeros_like(x)
Tensor([[0 0 0] Tensor([[0 0 0]
[0 0 0]
[0 0 0]], dtype=int32, device=xpux:0) [0 0 0]], dtype=int32, device=xpux:0)
""" """
return full_like(inp, 0.0) return full_like(inp, 0.0)
@@ -239,14 +388,14 @@ def ones_like(inp: Tensor) -> Tensor:
r"""Returns a tensor filled with ones with the same shape and data type as input tensor. r"""Returns a tensor filled with ones with the same shape and data type as input tensor.


Args: Args:
inp (Tensor): input tensor.
inp(Tensor): input tensor from which to derive the output tensor shape.


Return: Return:
a tensor containing ones.
a tensor having the same shape as input tensor and filled with ones.


Examples: Examples:
>>> input = F.arange(6, dtype='int32').reshape(2,3)
>>> F.ones_like(input)
>>> x = F.arange(6, dtype='int32').reshape(2, 3)
>>> F.ones_like(x)
Tensor([[1 1 1] Tensor([[1 1 1]
[1 1 1]], dtype=int32, device=xpux:0) [1 1 1]], dtype=int32, device=xpux:0)
""" """
@@ -257,16 +406,15 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
r"""Returns a tensor filled with given value with the same shape as input tensor. r"""Returns a tensor filled with given value with the same shape as input tensor.


Args: Args:
inp: input tensor.
value: target value.
inp(Tensor): input tensor from which to derive the output tensor shape.
value(Scalar): fill value.


Return: Return:
output tensor.
a tensor having the same shape as input tensor and where every element is equal to fill value.


Examples: Examples:
>>> import numpy as np
>>> inp = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
>>> F.full_like(inp, 2)
>>> x = F.arange(6, dtype='int32').reshape(2, 3)
>>> F.full_like(x, 2)
Tensor([[2 2 2] Tensor([[2 2 2]
[2 2 2]], dtype=int32, device=xpux:0) [2 2 2]], dtype=int32, device=xpux:0)
""" """
@@ -280,6 +428,9 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
return rst return rst




# manipulation functions


def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
r"""Broadcasts a tensor to given shape. r"""Broadcasts a tensor to given shape.


@@ -821,107 +972,6 @@ def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Te
return squeeze_cpp(inp, axis) return squeeze_cpp(inp, axis)




def linspace(
start: Union[int, float, Tensor],
stop: Union[int, float, Tensor],
num: Union[int, Tensor],
dtype="float32",
device: Optional[CompNode] = None,
) -> Tensor:
r"""Returns equally spaced numbers over a specified interval.

Args:
start: starting value of the squence, shoule be scalar.
stop: last value of the squence, shoule be scalar.
num: number of values to generate.
dtype: result data type.

Returns:
generated tensor.

Examples:
>>> import numpy as np
>>> a = F.linspace(3, 10, 5)
>>> a.numpy()
array([ 3. , 4.75, 6.5 , 8.25, 10. ], dtype=float32)
"""
for item in (start, stop, num):
cur_device = getattr(item, "device", None)
if device is None:
device = cur_device
else:
if not (cur_device is None or device == cur_device):
raise ("ambiguous device for linspace opr")

if not isinstance(start, Tensor):
start = Tensor(start, device=device)
if not isinstance(stop, Tensor):
stop = Tensor(stop, device=device)
if not isinstance(num, Tensor):
num = Tensor(num, device=device)

op = builtin.Linspace(comp_node=device)
(result,) = apply(op, start, stop, num)
if np.dtype(dtype) != np.float32:
return result.astype(dtype)
return result


def arange(
start: Union[int, float, Tensor] = 0,
stop: Optional[Union[int, float, Tensor]] = None,
step: Union[int, float, Tensor] = 1,
dtype="float32",
device: Optional[CompNode] = None,
) -> Tensor:
r"""Returns evenly spaced values within the half-open interval ``[start, stop)`` as a one-dimensional tensor.

Note:
This function cannot guarantee that the interval does not include the stop value in those cases
where step is not an integer and floating-point rounding errors affect the length of the output tensor.

Args:
start: if ``stop`` is specified, the start of interval (inclusive); otherwise,
the end of the interval (exclusive). If ``stop`` is not specified, the default starting value is ``0``.
stop: the end of the interval. Default: ``None``.
step: the distance between two adjacent elements ( ``out[i+1] - out[i]`` ). Must not be 0 ;
may be negative, this results i an empty tensor if stop >= start . Default: 1 .

Keyword args:
dtype( :attr:`.Tensor.dtype` ): output tensor data type. Default: ``float32``.
device( :attr:`.Tensor.device` ): device on which to place the created tensor. Default: ``None``.

Returns:
A one-dimensional tensor containing evenly spaced values.

The length of the output tensor must be ``ceil((stop-start)/step)``
if ``stop - start`` and ``step`` have the same sign, and length 0 otherwise.

Examples:
>>> F.arange(5)
Tensor([0. 1. 2. 3. 4.], device=xpux:0)
>>> F.arange(1, 4)
Tensor([1. 2. 3.], device=xpux:0)

"""
if stop is None:
start, stop = 0, start

if not isinstance(start, Tensor):
start = Tensor(start, dtype="float32")
if not isinstance(stop, Tensor):
stop = Tensor(stop, dtype="float32")
if not isinstance(step, Tensor):
step = Tensor(step, dtype="float32")

num = ceil((stop - start) / step)
stop = start + step * (num - 1)
result = linspace(start, stop, num, device=device)
if np.dtype(dtype) != np.float32:
return result.astype(dtype)
return result


def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None): def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
r"""Repeat elements of an array. r"""Repeat elements of an array.


@@ -1125,6 +1175,9 @@ def roll(
return out return out




# TODO: Should be moved to math - statistical functions


def cumsum(inp: Tensor, axis: int): def cumsum(inp: Tensor, axis: int):
r"""Calculates the cumulative sum of tensor elements over a given axis. r"""Calculates the cumulative sum of tensor elements over a given axis.




Loading…
Cancel
Save