Browse Source

feat(mge): rename arange paramter end -> stop

GitOrigin-RevId: ea9064e351
release-1.1
Megvii Engine Team 4 years ago
parent
commit
f551d44e37
1 changed files with 8 additions and 8 deletions
  1. +8
    -8
      imperative/python/megengine/functional/tensor.py

+ 8
- 8
imperative/python/megengine/functional/tensor.py View File

@@ -929,15 +929,15 @@ def linspace(


def arange( def arange(
start: Union[int, float, Tensor] = 0, start: Union[int, float, Tensor] = 0,
end: Optional[Union[int, float, Tensor]] = None,
stop: Optional[Union[int, float, Tensor]] = None,
step: Union[int, float, Tensor] = 1, step: Union[int, float, Tensor] = 1,
dtype="float32", dtype="float32",
device: Optional[CompNode] = None, device: Optional[CompNode] = None,
) -> Tensor: ) -> Tensor:
r"""Returns a tensor with values from start to end with adjacent interval step.
r"""Returns a tensor with values from start to stop with adjacent interval step.


:param start: starting value of the squence, shoule be scalar. :param start: starting value of the squence, shoule be scalar.
:param end: ending value of the squence, shoule be scalar.
:param stop: ending value of the squence, shoule be scalar.
:param step: gap between each pair of adjacent values. Default: 1 :param step: gap between each pair of adjacent values. Default: 1
:param dtype: result data type. :param dtype: result data type.
:return: generated tensor. :return: generated tensor.
@@ -961,16 +961,16 @@ def arange(
[0. 1. 2. 3. 4.] [0. 1. 2. 3. 4.]


""" """
if end is None:
start, end = 0, start
if stop is None:
start, stop = 0, start


if isinstance(start, Tensor): if isinstance(start, Tensor):
start = start.astype("float32") start = start.astype("float32")
if isinstance(end, Tensor):
end = end.astype("float32")
if isinstance(stop, Tensor):
stop = stop.astype("float32")
if isinstance(step, Tensor): if isinstance(step, Tensor):
step = step.astype("float32") step = step.astype("float32")
num = ceil(Tensor((end - start) / step, device=device))
num = ceil(Tensor((stop - start) / step, device=device))
stop = start + step * (num - 1) stop = start + step * (num - 1)
result = linspace(start, stop, num, device=device) result = linspace(start, stop, num, device=device)
if np.dtype(dtype) == np.int32: if np.dtype(dtype) == np.int32:


Loading…
Cancel
Save