diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index 92a080c8..da6392f3 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -47,24 +47,27 @@ def _elwise_apply(args, mode): def _elwise(*args, mode): args = convert_inputs(*args) - if mode in ( - _ElwMod.TRUE_DIV, - _ElwMod.EXP, - _ElwMod.POW, - _ElwMod.LOG, - _ElwMod.EXPM1, - _ElwMod.LOG1P, - _ElwMod.TANH, - _ElwMod.ACOS, - _ElwMod.ASIN, - _ElwMod.ATAN2, - _ElwMod.COS, - _ElwMod.H_SWISH, - _ElwMod.SIGMOID, - _ElwMod.SIN, - _ElwMod.LOG_SUM_EXP, - ) and ( - amp._enabled or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args]) + if ( + mode + in ( + _ElwMod.EXP, + _ElwMod.POW, + _ElwMod.LOG, + _ElwMod.EXPM1, + _ElwMod.LOG1P, + _ElwMod.ACOS, + _ElwMod.ASIN, + _ElwMod.ATAN2, + _ElwMod.COS, + _ElwMod.SIN, + _ElwMod.LOG_SUM_EXP, + ) + and ( + amp._enabled + or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args]) + ) + or mode in (_ElwMod.TRUE_DIV, _ElwMod.TANH,) + and np.all([np.issubdtype(arg.dtype, np.integer) for arg in args]) ): # autocast to FP32 to maintain precision # or to avoid op's not supporting all int args