|
@@ -136,7 +136,6 @@ class LayerNorm(Module): |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
def forward(self, x): |
|
|
x_shape = x.shape |
|
|
x_shape = x.shape |
|
|
assert x_shape[-len(self.normalized_shape) :] == self.normalized_shape |
|
|
|
|
|
dim_delta = len(x_shape) - len(self.normalized_shape) |
|
|
dim_delta = len(x_shape) - len(self.normalized_shape) |
|
|
non_flatten_shape = x_shape[:dim_delta] |
|
|
non_flatten_shape = x_shape[:dim_delta] |
|
|
x = x.reshape(*non_flatten_shape, -1) |
|
|
x = x.reshape(*non_flatten_shape, -1) |
|
|