|
|
@@ -47,24 +47,27 @@ def _elwise_apply(args, mode): |
|
|
|
|
|
|
|
def _elwise(*args, mode): |
|
|
|
args = convert_inputs(*args) |
|
|
|
if mode in ( |
|
|
|
_ElwMod.TRUE_DIV, |
|
|
|
_ElwMod.EXP, |
|
|
|
_ElwMod.POW, |
|
|
|
_ElwMod.LOG, |
|
|
|
_ElwMod.EXPM1, |
|
|
|
_ElwMod.LOG1P, |
|
|
|
_ElwMod.TANH, |
|
|
|
_ElwMod.ACOS, |
|
|
|
_ElwMod.ASIN, |
|
|
|
_ElwMod.ATAN2, |
|
|
|
_ElwMod.COS, |
|
|
|
_ElwMod.H_SWISH, |
|
|
|
_ElwMod.SIGMOID, |
|
|
|
_ElwMod.SIN, |
|
|
|
_ElwMod.LOG_SUM_EXP, |
|
|
|
) and ( |
|
|
|
amp._enabled or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args]) |
|
|
|
if ( |
|
|
|
mode |
|
|
|
in ( |
|
|
|
_ElwMod.EXP, |
|
|
|
_ElwMod.POW, |
|
|
|
_ElwMod.LOG, |
|
|
|
_ElwMod.EXPM1, |
|
|
|
_ElwMod.LOG1P, |
|
|
|
_ElwMod.ACOS, |
|
|
|
_ElwMod.ASIN, |
|
|
|
_ElwMod.ATAN2, |
|
|
|
_ElwMod.COS, |
|
|
|
_ElwMod.SIN, |
|
|
|
_ElwMod.LOG_SUM_EXP, |
|
|
|
) |
|
|
|
and ( |
|
|
|
amp._enabled |
|
|
|
or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args]) |
|
|
|
) |
|
|
|
or mode in (_ElwMod.TRUE_DIV, _ElwMod.TANH,) |
|
|
|
and np.all([np.issubdtype(arg.dtype, np.integer) for arg in args]) |
|
|
|
): |
|
|
|
# autocast to FP32 to maintain precision |
|
|
|
# or to avoid op's not supporting all int args |
|
|
|