Browse Source

feat(mge/elwise): removed back to fp32 mode

GitOrigin-RevId: a665a279a6
tags/v1.7.2.m1
Megvii Engine Team 3 years ago
parent
commit
9fd2e66350
1 changed files with 21 additions and 18 deletions
  1. +21
    -18
      imperative/python/megengine/core/tensor/array_method.py

+ 21
- 18
imperative/python/megengine/core/tensor/array_method.py View File

@@ -47,24 +47,27 @@ def _elwise_apply(args, mode):

def _elwise(*args, mode):
args = convert_inputs(*args)
if mode in (
_ElwMod.TRUE_DIV,
_ElwMod.EXP,
_ElwMod.POW,
_ElwMod.LOG,
_ElwMod.EXPM1,
_ElwMod.LOG1P,
_ElwMod.TANH,
_ElwMod.ACOS,
_ElwMod.ASIN,
_ElwMod.ATAN2,
_ElwMod.COS,
_ElwMod.H_SWISH,
_ElwMod.SIGMOID,
_ElwMod.SIN,
_ElwMod.LOG_SUM_EXP,
) and (
amp._enabled or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
if (
mode
in (
_ElwMod.EXP,
_ElwMod.POW,
_ElwMod.LOG,
_ElwMod.EXPM1,
_ElwMod.LOG1P,
_ElwMod.ACOS,
_ElwMod.ASIN,
_ElwMod.ATAN2,
_ElwMod.COS,
_ElwMod.SIN,
_ElwMod.LOG_SUM_EXP,
)
and (
amp._enabled
or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
)
or mode in (_ElwMod.TRUE_DIV, _ElwMod.TANH,)
and np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
):
# autocast to FP32 to maintain precision
# or to avoid op's not supporting all int args


Loading…
Cancel
Save