diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 040a4981..e8ff2e6a 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -6,7 +6,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from typing import Iterable, Optional, Sequence, Union +from typing import Iterable, Optional, Sequence, Tuple, Union import numpy as np @@ -148,17 +148,23 @@ def full( return broadcast_to(x, shape) -def ones(shape, dtype="float32", device=None) -> Tensor: - r"""Returns a ones tensor with given shape. +def ones( + shape: Union[int, Tuple[int, ...]], + *, + dtype="float32", + device: Optional[CompNode] = None +) -> Tensor: + r"""Returns a new tensor having a specified shape and filled with ones. Args: - shape: a list, tuple or integer defining the shape of the output tensor. - dtype: the desired data type of the output tensor. Default: ``float32``. - device: the desired device of the output tensor. Default: if ``None``, - use the default device (see :func:`~.megengine.get_default_device`). + shape (int or sequence of ints): the shape of the output tensor. + + Keyword args: + dtype (:attr:`.Tensor.dtype`): output tensor data type. Default: ``float32``. + device (:attr:`.Tensor.device`): device on which to place the created tensor. Default: ``None``. Returns: - output tensor. + a tensor containing ones. Examples: @@ -166,13 +172,23 @@ def ones(shape, dtype="float32", device=None) -> Tensor: import megengine.functional as F - out = F.ones((2, 1)) + out = F.ones(5) + print(out.numpy()) + out = F.ones((5, ), dtype='int32') + print(out.numpy()) + out = F.ones((2, 2)) + print(out.numpy()) + out = F.ones([2, 1]) print(out.numpy()) Outputs: .. testoutput:: + [1. 1. 1. 1. 1.] + [1 1 1 1 1] + [[1. 1.] + [1. 1.]] [[1.] [1.]] """