|
@@ -1090,6 +1090,9 @@ def layer_norm( |
|
|
eps_mode |
|
|
eps_mode |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
if amp._enabled: |
|
|
|
|
|
inp, weight, bias = cast_tensors(inp, weight, bias, promote=True) |
|
|
|
|
|
|
|
|
_device = inp.device |
|
|
_device = inp.device |
|
|
_dtype = inp.dtype |
|
|
_dtype = inp.dtype |
|
|
_dim = len(inp.shape) - len(normalized_shape) |
|
|
_dim = len(inp.shape) - len(normalized_shape) |
|
|