Browse Source

feat(mge/functional): add isnan and isinf oprs

GitOrigin-RevId: b4a347751c
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
000663c30c
3 changed files with 60 additions and 0 deletions
  1. +2
    -0
      python_module/megengine/functional/__init__.py
  2. +48
    -0
      python_module/megengine/functional/elemwise.py
  3. +10
    -0
      python_module/test/unit/functional/test_elemwise.py

+ 2
- 0
python_module/megengine/functional/__init__.py View File

@@ -21,6 +21,8 @@ from .elemwise import (
floor, floor,
greater, greater,
greater_equal, greater_equal,
isinf,
isnan,
less, less,
less_equal, less_equal,
log, log,


+ 48
- 0
python_module/megengine/functional/elemwise.py View File

@@ -27,6 +27,8 @@ __all__ = [
"greater", "greater",
"greater_equal", "greater_equal",
"floor", "floor",
"isinf",
"isnan",
"less", "less",
"less_equal", "less_equal",
"log", "log",
@@ -244,3 +246,49 @@ def clamp(inp: Tensor, lower=None, upper=None) -> Tensor:
return maximum(inp, lower) return maximum(inp, lower)
else: else:
return minimum(inp, upper) return minimum(inp, upper)


def isnan(inp: Tensor) -> Tensor:
r"""Returns a new tensor representing if each element is NaN or not.

:param: inp
:return: a new tensor representing if each element in :attr:`inp` is NaN or not.

Examples:

.. testcode::
from megengine import tensor
import megengine.functional as F

x = tensor([1, float("nan"), 0])

print(F.isnan(x))

.. testoutput::
Tensor([0 1 0], dtype=uint8)

"""
return (inp != inp).astype("uint8")


def isinf(inp: Tensor) -> Tensor:
r"""Returns a new tensor representing if each element is Inf or not.

:param: inp
:return: a new tensor representing if each element in :attr:`inp` is Inf or not.

Examples:

.. testcode::
from megengine import tensor
import megengine.functional as F

x = tensor([1, float("inf"), 0])

print(F.isinf(x))

.. testoutput::
Tensor([0 1 0], dtype=uint8)

"""
return (abs(inp) == float("inf")).astype("uint8")

+ 10
- 0
python_module/test/unit/functional/test_elemwise.py View File

@@ -53,3 +53,13 @@ def test_clamp():
x = np.linspace(-6, 6, dtype="float32") x = np.linspace(-6, 6, dtype="float32")
assertTensorClose(F.clamp(tensor(x) + 3, 0, 6).numpy(), np.clip(x + 3, 0, 6)) assertTensorClose(F.clamp(tensor(x) + 3, 0, 6).numpy(), np.clip(x + 3, 0, 6))
assertTensorClose(F.clamp(tensor(x) - 3, -6, 0).numpy(), np.clip(x - 3, -6, 0)) assertTensorClose(F.clamp(tensor(x) - 3, -6, 0).numpy(), np.clip(x - 3, -6, 0))


def test_isnan():
for case in [[1, float("nan"), 0]]:
assertTensorClose(F.isnan(tensor(case)), np.isnan(case).astype("uint8"))


def test_isinf():
for case in [[1, float("inf"), 0]]:
assertTensorClose(F.isinf(tensor(case)), np.isinf(case).astype("uint8"))

Loading…
Cancel
Save