GitOrigin-RevId: dc434fa7ec
tags/v1.0.0-rc1
@@ -23,6 +23,15 @@ from .tensor import Tensor | |||||
def _elwise(*args, mode): | def _elwise(*args, mode): | ||||
op = builtin.Elemwise(mode=mode) | op = builtin.Elemwise(mode=mode) | ||||
if mode in ("TRUE_DIV", "POW"): | |||||
args = tuple( | |||||
map( | |||||
lambda x: x.astype("float32") | |||||
if hasattr(x, "dtype") and x.dtype != np.float32 | |||||
else x, | |||||
args, | |||||
) | |||||
) | |||||
args = utils.convert_inputs(*args) | args = utils.convert_inputs(*args) | ||||
(result,) = apply(op, *args) | (result,) = apply(op, *args) | ||||
return result | return result | ||||
@@ -76,6 +76,10 @@ __all__ = [ | |||||
def _elwise(*args, mode): | def _elwise(*args, mode): | ||||
op = builtin.Elemwise(mode=mode) | op = builtin.Elemwise(mode=mode) | ||||
if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"): | |||||
args = tuple( | |||||
map(lambda x: x.astype("float32") if hasattr(x, "dtype") else x, args) | |||||
) | |||||
args = utils.convert_inputs(*args) | args = utils.convert_inputs(*args) | ||||
(result,) = apply(op, *args) | (result,) = apply(op, *args) | ||||
return result | return result | ||||