Browse Source

fix(mge): fix layer norm amp bug

GitOrigin-RevId: dba691fcbf
release-1.7
Megvii Engine Team 3 years ago
parent
commit
ca4c93dee7
1 changed files with 3 additions and 0 deletions
  1. +3
    -0
      imperative/python/megengine/functional/nn.py

+ 3
- 0
imperative/python/megengine/functional/nn.py View File

@@ -1090,6 +1090,9 @@ def layer_norm(
eps_mode eps_mode
) )


if amp._enabled:
inp, weight, bias = cast_tensors(inp, weight, bias, promote=True)

_device = inp.device _device = inp.device
_dtype = inp.dtype _dtype = inp.dtype
_dim = len(inp.shape) - len(normalized_shape) _dim = len(inp.shape) - len(normalized_shape)


Loading…
Cancel
Save