diff --git a/python_module/megengine/functional/__init__.py b/python_module/megengine/functional/__init__.py index 767e9942..651037ee 100644 --- a/python_module/megengine/functional/__init__.py +++ b/python_module/megengine/functional/__init__.py @@ -75,6 +75,7 @@ from .nn import ( from .sort import argsort, sort, top_k from .tensor import ( add_axis, + arange, broadcast_to, concat, dimshuffle, diff --git a/python_module/megengine/functional/tensor.py b/python_module/megengine/functional/tensor.py index e1a7322c..31aedc27 100644 --- a/python_module/megengine/functional/tensor.py +++ b/python_module/megengine/functional/tensor.py @@ -17,6 +17,7 @@ from megengine._internal import CompGraph, CompNode from ..core import zeros from ..core.graph import _use_default_if_none from ..core.tensor import Tensor, wrap_io_tensor +from .elemwise import ceil from .utils import _decide_comp_node_and_comp_graph @@ -553,6 +554,46 @@ def linspace( return ret.astype(dtype) +def arange( + start: Union[int, float, Tensor], + end: Union[int, float, Tensor], + step: Union[int, float, Tensor] = 1, + dtype=np.float32, + device: Optional[CompNode] = None, + comp_graph: Optional[CompGraph] = None, +) -> Tensor: + r""" + Returns a Tensor with values from `start` to `end` with adjacent interval `step` + + :param start: starting value of the squence, shoule be scalar + :param end: ending value of the squence, shoule be scalar + :param step: the gap between each pair of adjacent values. Default 1 + :param dtype: result data type + :return: The generated tensor + + Examples: + + .. testcode:: + + import numpy as np + import megengine.functional as F + + a = F.arange(1, 5, 1) + print(a.numpy()) + + .. testoutput:: + + [1. 2. 3. 4.] + + """ + if dtype is not np.float32: + raise ValueError("arange is only implemented for float32") + num = ceil((end - start) / step) + stop = start + step * (num - 1) + ret = linspace(start, stop, num, device=device, comp_graph=comp_graph) + return ret + + def zeros_like(inp: Tensor) -> Tensor: r""" Returns a zero tensor with the same shape as input tensor diff --git a/python_module/test/unit/functional/test_functional.py b/python_module/test/unit/functional/test_functional.py index 261e0644..3b88cd8c 100644 --- a/python_module/test/unit/functional/test_functional.py +++ b/python_module/test/unit/functional/test_functional.py @@ -157,6 +157,38 @@ def test_broadcast_to(): opr_test(cases, F.broadcast_to, compare_fn=compare_fn) +def test_arange(): + cases = [ + {"input": [1, 9, 1]}, + {"input": [2, 10, 2]}, + ] + opr_test( + cases, + F.arange, + ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), + ) + + cases = [ + {"input": [9, 1, -1]}, + {"input": [10, 2, -2]}, + ] + opr_test( + cases, + F.arange, + ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), + ) + + cases = [ + {"input": [9.3, 1.2, -0.5]}, + {"input": [10.3, 2.1, -1.7]}, + ] + opr_test( + cases, + F.arange, + ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), + ) + + def test_add_update(): shape = (2, 3) v = np.random.random(shape).astype(np.float32)