Browse Source

Merge pull request #410 from TaoweiZhang/docstring-zeros

docs(mge/functional): update functional.zeros docstring
revert-410-docstring-zeros
haolongzhangm GitHub 3 years ago
parent
commit
76578994e2
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 9 deletions
  1. +1
    -3
      imperative/python/megengine/functional/nn.py
  2. +20
    -6
      imperative/python/megengine/functional/tensor.py

+ 1
- 3
imperative/python/megengine/functional/nn.py View File

@@ -1587,9 +1587,7 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor:
[0 0 1 0]
[0 0 0 1]]
"""
zeros_tensor = zeros(
list(inp.shape) + [num_classes], dtype=inp.dtype, device=inp.device
)
zeros_tensor = zeros(list(inp.shape) + [num_classes], dtype=inp.dtype, device=inp.device)
ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device)

op = builtin.IndexingSetOneHot(axis=inp.ndim)


+ 20
- 6
imperative/python/megengine/functional/tensor.py View File

@@ -195,14 +195,28 @@ def ones(
return full(shape, 1.0, dtype=dtype, device=device)


def zeros(shape, dtype="float32", device=None) -> Tensor:
r"""Returns a zero tensor with given shape.
def zeros(
shape: Union[int, Tuple[int, ...]],
*,
dtype="float32",
device: Optional[CompNode] = None
) -> Tensor:
r"""Returns a new tensor having a specified shape and filled with zeros.

Args:
shape: a list, tuple or integer defining the shape of the output tensor.
dtype: the desired data type of the output tensor. Default: ``float32``.
device: the desired device of the output tensor. Default: if ``None``,
use the default device (see :func:`~.megengine.get_default_device`).
shape (int or sequence of ints): the shape of the output tensor.

Keyword args:
dtype (:attr:`.Tensor.dtype`): output tensor data type. Default: ``float32``.
device (:attr:`.Tensor.device`): device on which to place the created tensor. Default: ``None``.

Returns:
a tensor containing zeros.

Examples:
>>> F.zeros((2, 1))
Tensor([[0.]
[0.]], device=xpux:0)
"""
return full(shape, 0.0, dtype=dtype, device=device)



Loading…
Cancel
Save