diff --git a/imperative/python/megengine/module/normalization.py b/imperative/python/megengine/module/normalization.py index f524d03e..317fa17d 100644 --- a/imperative/python/megengine/module/normalization.py +++ b/imperative/python/megengine/module/normalization.py @@ -136,7 +136,6 @@ class LayerNorm(Module): def forward(self, x): x_shape = x.shape - assert x_shape[-len(self.normalized_shape) :] == self.normalized_shape dim_delta = len(x_shape) - len(self.normalized_shape) non_flatten_shape = x_shape[:dim_delta] x = x.reshape(*non_flatten_shape, -1)