GitOrigin-RevId: cf3bf8cb80
tags/v1.3.0
@@ -27,9 +27,31 @@ from .utils import setscalar | |||
_ElwMod = Elemwise.Mode | |||
def _elwise(*args, mode): | |||
def _elwise_apply(args, mode): | |||
op = builtin.Elemwise(mode) | |||
if mode in (_ElwMod.TRUE_DIV, _ElwMod.POW): | |||
_isscalar = True | |||
for i in args: | |||
if isscalar(i) == False: | |||
_isscalar = False | |||
break | |||
(result,) = apply(op, *args) | |||
if _isscalar: | |||
setscalar(result) | |||
return result | |||
def _elwise(*args, mode): | |||
if mode in ( | |||
_ElwMod.TRUE_DIV, | |||
_ElwMod.POW, | |||
_ElwMod.CEIL, | |||
_ElwMod.FLOOR, | |||
_ElwMod.ROUND, | |||
): | |||
if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND) and np.issubdtype( | |||
args[0].dtype, np.integer | |||
): | |||
return args[0] | |||
args = tuple( | |||
map( | |||
lambda x: x.astype("float32") | |||
@@ -39,16 +61,7 @@ def _elwise(*args, mode): | |||
) | |||
) | |||
args = utils.convert_inputs(*args) | |||
(result,) = apply(op, *args) | |||
_isscalar = True | |||
for i in args: | |||
if isscalar(i) == False: | |||
_isscalar = False | |||
break | |||
if _isscalar: | |||
setscalar(result) | |||
return result | |||
return _elwise_apply(args, mode) | |||
def _matmul(inp1, inp2): | |||
@@ -9,10 +9,13 @@ | |||
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | |||
import functools | |||
import numpy as np | |||
from ..core._imperative_rt.core2 import apply | |||
from ..core.ops import builtin | |||
from ..core.ops.builtin import Elemwise | |||
from ..core.tensor import megbrain_graph, utils | |||
from ..core.tensor.array_method import _elwise_apply | |||
from ..core.tensor.utils import isscalar, setscalar | |||
from ..device import get_default_device | |||
from ..jit.tracing import is_tracing | |||
@@ -74,7 +77,6 @@ __all__ = [ | |||
def _elwise(*args, mode): | |||
op = builtin.Elemwise(mode) | |||
tensor_args = list( | |||
filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) | |||
) | |||
@@ -84,17 +86,33 @@ def _elwise(*args, mode): | |||
args = utils.convert_inputs(first_arg, *args[1:]) | |||
else: | |||
args = utils.convert_inputs(*args) | |||
if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"): | |||
if mode in ( | |||
Elemwise.Mode.TRUE_DIV, | |||
Elemwise.Mode.EXP, | |||
Elemwise.Mode.POW, | |||
Elemwise.Mode.LOG, | |||
Elemwise.Mode.EXPM1, | |||
Elemwise.Mode.LOG1P, | |||
Elemwise.Mode.TANH, | |||
Elemwise.Mode.ACOS, | |||
Elemwise.Mode.ASIN, | |||
Elemwise.Mode.ATAN2, | |||
Elemwise.Mode.CEIL, | |||
Elemwise.Mode.COS, | |||
Elemwise.Mode.FLOOR, | |||
Elemwise.Mode.H_SWISH, | |||
Elemwise.Mode.ROUND, | |||
Elemwise.Mode.SIGMOID, | |||
Elemwise.Mode.SIN, | |||
): | |||
if mode in ( | |||
Elemwise.Mode.CEIL, | |||
Elemwise.Mode.FLOOR, | |||
Elemwise.Mode.ROUND, | |||
) and np.issubdtype(args[0].dtype, np.integer): | |||
return args[0] | |||
args = tuple(map(lambda x: x.astype("float32"), args)) | |||
_isscalar = True | |||
for i in args: | |||
if isscalar(i) == False: | |||
_isscalar = False | |||
break | |||
(result,) = apply(op, *args) | |||
if _isscalar: | |||
setscalar(result) | |||
return result | |||
return _elwise_apply(args, mode) | |||
def _elemwise_multi_type(*args, mode, **kwargs): | |||
@@ -9,6 +9,7 @@ | |||
import numpy as np | |||
import megengine.functional as F | |||
import megengine.functional.elemwise as elemwise | |||
from megengine import tensor | |||
from megengine.core.tensor import dtype | |||
from megengine.functional.elemwise import _elwise | |||
@@ -166,3 +167,20 @@ def test_qadd(): | |||
result_mge = result_mge.astype("float32").numpy() | |||
result_expect = x.astype("float32").numpy() + y.astype("float32").numpy() | |||
np.testing.assert_almost_equal(result_mge, result_expect, decimal=6) | |||
def test_int32_input(): | |||
x = tensor(np.array([1, 2, 3, 4, 5]), dtype="int32") | |||
for op_name in elemwise.__all__: | |||
op = getattr(elemwise, op_name) | |||
nargs = op.__code__.co_argcount | |||
if op_name == "clip": | |||
inp = (x, 0, 1) | |||
elif op_name.endswith("_shift"): | |||
inp = (x, 1) | |||
elif op_name.startswith("logical_"): | |||
continue | |||
else: | |||
inp = (x,) * nargs | |||
y = op(*inp) | |||
y.numpy() |