Browse Source

fix(mge/functional): support non-float32 input when call isxxx

GitOrigin-RevId: ea8f394958
release-1.11.1
Megvii Engine Team 2 years ago
parent
commit
034c7787fa
1 changed files with 25 additions and 1 deletions
  1. +25
    -1
      imperative/python/megengine/functional/math.py

+ 25
- 1
imperative/python/megengine/functional/math.py View File

@@ -3,6 +3,8 @@ import collections
import math import math
from typing import Iterable, Optional, Sequence, Tuple, Union from typing import Iterable, Optional, Sequence, Tuple, Union


import numpy as np

from ..core._imperative_rt.core2 import Const, apply from ..core._imperative_rt.core2 import Const, apply
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core.ops import builtin from ..core.ops import builtin
@@ -11,7 +13,7 @@ from ..core.tensor.utils import _normalize_axis
from ..tensor import Tensor from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default from ..utils.deprecation import deprecated_kwargs_default
from .elemwise import _elemwise_multi_type, clip from .elemwise import _elemwise_multi_type, clip
from .tensor import expand_dims, squeeze
from .tensor import broadcast_to, expand_dims, squeeze


__all__ = [ __all__ = [
"argmax", "argmax",
@@ -55,11 +57,22 @@ def isnan(inp: Tensor) -> Tensor:
The returned array should have a data type of bool. The returned array should have a data type of bool.


Examples: Examples:
>>> F.isnan(Tensor(1))
Tensor(False, dtype=bool, device=xpux:0)

.. TODO: Remove these comments when _elemwise_multi_type support scalar input
.. >>> F.isnan(Tensor(float("nan")))
.. Tensor(True, dtype=bool, device=xpux:0)

Element-wise isnan:


>>> x = Tensor([1, float("nan"), 0]) >>> x = Tensor([1, float("nan"), 0])
>>> F.isnan(x) >>> F.isnan(x)
Tensor([False True False], dtype=bool, device=xpux:0) Tensor([False True False], dtype=bool, device=xpux:0)
""" """
if not np.issubdtype(inp.dtype, np.floating):
return broadcast_to(Tensor(False), inp.shape)
return _elemwise_multi_type(inp, mode="isnan", dtype="bool") return _elemwise_multi_type(inp, mode="isnan", dtype="bool")




@@ -79,10 +92,21 @@ def isinf(inp: Tensor) -> Tensor:


Examples: Examples:


>>> F.isinf(Tensor(1))
Tensor(False, dtype=bool, device=xpux:0)

.. TODO: Remove these comments when _elemwise_multi_type support scalar input
.. >>> F.isinf(Tensor(float("inf")))
.. Tensor(True, dtype=bool, device=xpux:0)

Element-wise isinf:

>>> x = Tensor([1, float("inf"), 0]) >>> x = Tensor([1, float("inf"), 0])
>>> F.isinf(x) >>> F.isinf(x)
Tensor([False True False], dtype=bool, device=xpux:0) Tensor([False True False], dtype=bool, device=xpux:0)
""" """
if not np.issubdtype(inp.dtype, np.floating):
return broadcast_to(Tensor(False), inp.shape)
return _elemwise_multi_type(inp, mode="isinf", dtype="bool") return _elemwise_multi_type(inp, mode="isinf", dtype="bool")






Loading…
Cancel
Save