From 000663c30c43d21aa5441c1a6e1a1075333a636e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 21 Apr 2020 12:11:57 +0800 Subject: [PATCH] feat(mge/functional): add isnan and isinf oprs GitOrigin-RevId: b4a347751c2022a2b0b2eca5a7e62c625b5e8c27 --- python_module/megengine/functional/__init__.py | 2 + python_module/megengine/functional/elemwise.py | 48 ++++++++++++++++++++++ .../test/unit/functional/test_elemwise.py | 10 +++++ 3 files changed, 60 insertions(+) diff --git a/python_module/megengine/functional/__init__.py b/python_module/megengine/functional/__init__.py index 651037ee..6b262bfd 100644 --- a/python_module/megengine/functional/__init__.py +++ b/python_module/megengine/functional/__init__.py @@ -21,6 +21,8 @@ from .elemwise import ( floor, greater, greater_equal, + isinf, + isnan, less, less_equal, log, diff --git a/python_module/megengine/functional/elemwise.py b/python_module/megengine/functional/elemwise.py index 6bed2d3d..2bb59255 100644 --- a/python_module/megengine/functional/elemwise.py +++ b/python_module/megengine/functional/elemwise.py @@ -27,6 +27,8 @@ __all__ = [ "greater", "greater_equal", "floor", + "isinf", + "isnan", "less", "less_equal", "log", @@ -244,3 +246,49 @@ def clamp(inp: Tensor, lower=None, upper=None) -> Tensor: return maximum(inp, lower) else: return minimum(inp, upper) + + +def isnan(inp: Tensor) -> Tensor: + r"""Returns a new tensor representing if each element is NaN or not. + + :param: inp + :return: a new tensor representing if each element in :attr:`inp` is NaN or not. + + Examples: + + .. testcode:: + from megengine import tensor + import megengine.functional as F + + x = tensor([1, float("nan"), 0]) + + print(F.isnan(x)) + + .. testoutput:: + Tensor([0 1 0], dtype=uint8) + + """ + return (inp != inp).astype("uint8") + + +def isinf(inp: Tensor) -> Tensor: + r"""Returns a new tensor representing if each element is Inf or not. + + :param: inp + :return: a new tensor representing if each element in :attr:`inp` is Inf or not. + + Examples: + + .. testcode:: + from megengine import tensor + import megengine.functional as F + + x = tensor([1, float("inf"), 0]) + + print(F.isinf(x)) + + .. testoutput:: + Tensor([0 1 0], dtype=uint8) + + """ + return (abs(inp) == float("inf")).astype("uint8") diff --git a/python_module/test/unit/functional/test_elemwise.py b/python_module/test/unit/functional/test_elemwise.py index ef9cf6fa..c02bd58b 100644 --- a/python_module/test/unit/functional/test_elemwise.py +++ b/python_module/test/unit/functional/test_elemwise.py @@ -53,3 +53,13 @@ def test_clamp(): x = np.linspace(-6, 6, dtype="float32") assertTensorClose(F.clamp(tensor(x) + 3, 0, 6).numpy(), np.clip(x + 3, 0, 6)) assertTensorClose(F.clamp(tensor(x) - 3, -6, 0).numpy(), np.clip(x - 3, -6, 0)) + + +def test_isnan(): + for case in [[1, float("nan"), 0]]: + assertTensorClose(F.isnan(tensor(case)), np.isnan(case).astype("uint8")) + + +def test_isinf(): + for case in [[1, float("inf"), 0]]: + assertTensorClose(F.isinf(tensor(case)), np.isinf(case).astype("uint8"))