Browse Source

feat(mge/functional): add arange in tensor.py

GitOrigin-RevId: ad88a4c18e
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
ce30d045ee
3 changed files with 74 additions and 0 deletions
  1. +1
    -0
      python_module/megengine/functional/__init__.py
  2. +41
    -0
      python_module/megengine/functional/tensor.py
  3. +32
    -0
      python_module/test/unit/functional/test_functional.py

+ 1
- 0
python_module/megengine/functional/__init__.py View File

@@ -75,6 +75,7 @@ from .nn import (
from .sort import argsort, sort, top_k
from .tensor import (
add_axis,
arange,
broadcast_to,
concat,
dimshuffle,


+ 41
- 0
python_module/megengine/functional/tensor.py View File

@@ -17,6 +17,7 @@ from megengine._internal import CompGraph, CompNode
from ..core import zeros
from ..core.graph import _use_default_if_none
from ..core.tensor import Tensor, wrap_io_tensor
from .elemwise import ceil
from .utils import _decide_comp_node_and_comp_graph


@@ -553,6 +554,46 @@ def linspace(
return ret.astype(dtype)


def arange(
start: Union[int, float, Tensor],
end: Union[int, float, Tensor],
step: Union[int, float, Tensor] = 1,
dtype=np.float32,
device: Optional[CompNode] = None,
comp_graph: Optional[CompGraph] = None,
) -> Tensor:
r"""
Returns a Tensor with values from `start` to `end` with adjacent interval `step`

:param start: starting value of the squence, shoule be scalar
:param end: ending value of the squence, shoule be scalar
:param step: the gap between each pair of adjacent values. Default 1
:param dtype: result data type
:return: The generated tensor

Examples:

.. testcode::

import numpy as np
import megengine.functional as F

a = F.arange(1, 5, 1)
print(a.numpy())

.. testoutput::

[1. 2. 3. 4.]

"""
if dtype is not np.float32:
raise ValueError("arange is only implemented for float32")
num = ceil((end - start) / step)
stop = start + step * (num - 1)
ret = linspace(start, stop, num, device=device, comp_graph=comp_graph)
return ret


def zeros_like(inp: Tensor) -> Tensor:
r"""
Returns a zero tensor with the same shape as input tensor


+ 32
- 0
python_module/test/unit/functional/test_functional.py View File

@@ -157,6 +157,38 @@ def test_broadcast_to():
opr_test(cases, F.broadcast_to, compare_fn=compare_fn)


def test_arange():
cases = [
{"input": [1, 9, 1]},
{"input": [2, 10, 2]},
]
opr_test(
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
)

cases = [
{"input": [9, 1, -1]},
{"input": [10, 2, -2]},
]
opr_test(
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
)

cases = [
{"input": [9.3, 1.2, -0.5]},
{"input": [10.3, 2.1, -1.7]},
]
opr_test(
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
)


def test_add_update():
shape = (2, 3)
v = np.random.random(shape).astype(np.float32)


Loading…
Cancel
Save