GitOrigin-RevId: cf3bf8cb80
tags/v1.3.0
@@ -27,9 +27,31 @@ from .utils import setscalar | |||||
_ElwMod = Elemwise.Mode | _ElwMod = Elemwise.Mode | ||||
def _elwise(*args, mode): | |||||
def _elwise_apply(args, mode): | |||||
op = builtin.Elemwise(mode) | op = builtin.Elemwise(mode) | ||||
if mode in (_ElwMod.TRUE_DIV, _ElwMod.POW): | |||||
_isscalar = True | |||||
for i in args: | |||||
if isscalar(i) == False: | |||||
_isscalar = False | |||||
break | |||||
(result,) = apply(op, *args) | |||||
if _isscalar: | |||||
setscalar(result) | |||||
return result | |||||
def _elwise(*args, mode): | |||||
if mode in ( | |||||
_ElwMod.TRUE_DIV, | |||||
_ElwMod.POW, | |||||
_ElwMod.CEIL, | |||||
_ElwMod.FLOOR, | |||||
_ElwMod.ROUND, | |||||
): | |||||
if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND) and np.issubdtype( | |||||
args[0].dtype, np.integer | |||||
): | |||||
return args[0] | |||||
args = tuple( | args = tuple( | ||||
map( | map( | ||||
lambda x: x.astype("float32") | lambda x: x.astype("float32") | ||||
@@ -39,16 +61,7 @@ def _elwise(*args, mode): | |||||
) | ) | ||||
) | ) | ||||
args = utils.convert_inputs(*args) | args = utils.convert_inputs(*args) | ||||
(result,) = apply(op, *args) | |||||
_isscalar = True | |||||
for i in args: | |||||
if isscalar(i) == False: | |||||
_isscalar = False | |||||
break | |||||
if _isscalar: | |||||
setscalar(result) | |||||
return result | |||||
return _elwise_apply(args, mode) | |||||
def _matmul(inp1, inp2): | def _matmul(inp1, inp2): | ||||
@@ -9,10 +9,13 @@ | |||||
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | ||||
import functools | import functools | ||||
import numpy as np | |||||
from ..core._imperative_rt.core2 import apply | from ..core._imperative_rt.core2 import apply | ||||
from ..core.ops import builtin | from ..core.ops import builtin | ||||
from ..core.ops.builtin import Elemwise | from ..core.ops.builtin import Elemwise | ||||
from ..core.tensor import megbrain_graph, utils | from ..core.tensor import megbrain_graph, utils | ||||
from ..core.tensor.array_method import _elwise_apply | |||||
from ..core.tensor.utils import isscalar, setscalar | from ..core.tensor.utils import isscalar, setscalar | ||||
from ..device import get_default_device | from ..device import get_default_device | ||||
from ..jit.tracing import is_tracing | from ..jit.tracing import is_tracing | ||||
@@ -74,7 +77,6 @@ __all__ = [ | |||||
def _elwise(*args, mode): | def _elwise(*args, mode): | ||||
op = builtin.Elemwise(mode) | |||||
tensor_args = list( | tensor_args = list( | ||||
filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) | filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) | ||||
) | ) | ||||
@@ -84,17 +86,33 @@ def _elwise(*args, mode): | |||||
args = utils.convert_inputs(first_arg, *args[1:]) | args = utils.convert_inputs(first_arg, *args[1:]) | ||||
else: | else: | ||||
args = utils.convert_inputs(*args) | args = utils.convert_inputs(*args) | ||||
if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"): | |||||
if mode in ( | |||||
Elemwise.Mode.TRUE_DIV, | |||||
Elemwise.Mode.EXP, | |||||
Elemwise.Mode.POW, | |||||
Elemwise.Mode.LOG, | |||||
Elemwise.Mode.EXPM1, | |||||
Elemwise.Mode.LOG1P, | |||||
Elemwise.Mode.TANH, | |||||
Elemwise.Mode.ACOS, | |||||
Elemwise.Mode.ASIN, | |||||
Elemwise.Mode.ATAN2, | |||||
Elemwise.Mode.CEIL, | |||||
Elemwise.Mode.COS, | |||||
Elemwise.Mode.FLOOR, | |||||
Elemwise.Mode.H_SWISH, | |||||
Elemwise.Mode.ROUND, | |||||
Elemwise.Mode.SIGMOID, | |||||
Elemwise.Mode.SIN, | |||||
): | |||||
if mode in ( | |||||
Elemwise.Mode.CEIL, | |||||
Elemwise.Mode.FLOOR, | |||||
Elemwise.Mode.ROUND, | |||||
) and np.issubdtype(args[0].dtype, np.integer): | |||||
return args[0] | |||||
args = tuple(map(lambda x: x.astype("float32"), args)) | args = tuple(map(lambda x: x.astype("float32"), args)) | ||||
_isscalar = True | |||||
for i in args: | |||||
if isscalar(i) == False: | |||||
_isscalar = False | |||||
break | |||||
(result,) = apply(op, *args) | |||||
if _isscalar: | |||||
setscalar(result) | |||||
return result | |||||
return _elwise_apply(args, mode) | |||||
def _elemwise_multi_type(*args, mode, **kwargs): | def _elemwise_multi_type(*args, mode, **kwargs): | ||||
@@ -9,6 +9,7 @@ | |||||
import numpy as np | import numpy as np | ||||
import megengine.functional as F | import megengine.functional as F | ||||
import megengine.functional.elemwise as elemwise | |||||
from megengine import tensor | from megengine import tensor | ||||
from megengine.core.tensor import dtype | from megengine.core.tensor import dtype | ||||
from megengine.functional.elemwise import _elwise | from megengine.functional.elemwise import _elwise | ||||
@@ -166,3 +167,20 @@ def test_qadd(): | |||||
result_mge = result_mge.astype("float32").numpy() | result_mge = result_mge.astype("float32").numpy() | ||||
result_expect = x.astype("float32").numpy() + y.astype("float32").numpy() | result_expect = x.astype("float32").numpy() + y.astype("float32").numpy() | ||||
np.testing.assert_almost_equal(result_mge, result_expect, decimal=6) | np.testing.assert_almost_equal(result_mge, result_expect, decimal=6) | ||||
def test_int32_input(): | |||||
x = tensor(np.array([1, 2, 3, 4, 5]), dtype="int32") | |||||
for op_name in elemwise.__all__: | |||||
op = getattr(elemwise, op_name) | |||||
nargs = op.__code__.co_argcount | |||||
if op_name == "clip": | |||||
inp = (x, 0, 1) | |||||
elif op_name.endswith("_shift"): | |||||
inp = (x, 1) | |||||
elif op_name.startswith("logical_"): | |||||
continue | |||||
else: | |||||
inp = (x,) * nargs | |||||
y = op(*inp) | |||||
y.numpy() |