From 9fd2e66350f1664352106ab3483b6239f26b3e98 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 23 Dec 2021 14:32:37 +0800 Subject: [PATCH] feat(mge/elwise): removed back to fp32 mode GitOrigin-RevId: a665a279a684c41a5ff746c95e7a5246125035c1 --- .../python/megengine/core/tensor/array_method.py | 39 ++++++++++++---------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index 92a080c8..da6392f3 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -47,24 +47,27 @@ def _elwise_apply(args, mode): def _elwise(*args, mode): args = convert_inputs(*args) - if mode in ( - _ElwMod.TRUE_DIV, - _ElwMod.EXP, - _ElwMod.POW, - _ElwMod.LOG, - _ElwMod.EXPM1, - _ElwMod.LOG1P, - _ElwMod.TANH, - _ElwMod.ACOS, - _ElwMod.ASIN, - _ElwMod.ATAN2, - _ElwMod.COS, - _ElwMod.H_SWISH, - _ElwMod.SIGMOID, - _ElwMod.SIN, - _ElwMod.LOG_SUM_EXP, - ) and ( - amp._enabled or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args]) + if ( + mode + in ( + _ElwMod.EXP, + _ElwMod.POW, + _ElwMod.LOG, + _ElwMod.EXPM1, + _ElwMod.LOG1P, + _ElwMod.ACOS, + _ElwMod.ASIN, + _ElwMod.ATAN2, + _ElwMod.COS, + _ElwMod.SIN, + _ElwMod.LOG_SUM_EXP, + ) + and ( + amp._enabled + or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args]) + ) + or mode in (_ElwMod.TRUE_DIV, _ElwMod.TANH,) + and np.all([np.issubdtype(arg.dtype, np.integer) for arg in args]) ): # autocast to FP32 to maintain precision # or to avoid op's not supporting all int args