|
|
@@ -109,22 +109,24 @@ class InstanceNorm(Module): |
|
|
|
|
|
|
|
class LayerNorm(Module): |
|
|
|
""" |
|
|
|
Simple implementation of LayerNorm. Only support 4d tensor now. |
|
|
|
Simple implementation of LayerNorm. Supports tensors of any shape as input. |
|
|
|
Reference: https://arxiv.org/pdf/1803.08494.pdf. |
|
|
|
Note that LayerNorm is equivalent to GroupNorm with num_groups=1. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, normalized_shape, eps=1e-05, affine=True, **kwargs):
    """Set up a layer-normalization module.

    Args:
        normalized_shape: an int, or an iterable of ints, giving the
            trailing dimensions of the input that ``forward`` normalizes
            over. A single int ``n`` is treated as ``(n,)``.
        eps: value added to the variance before the square root is taken
            in ``forward``.
        affine: when true, create ``weight`` (ones) and ``bias`` (zeros)
            parameters of shape ``normalized_shape``; otherwise both are
            set to ``None``.
        **kwargs: forwarded unchanged to the parent class constructor.
    """
    super().__init__(**kwargs)
    # Accept NumPy integer scalars as well as plain ints: tuple() on a
    # scalar would raise, so wrap any scalar into a 1-tuple first.
    if isinstance(normalized_shape, (int, np.integer)):
        normalized_shape = (int(normalized_shape),)
    # Stored as a tuple so forward() can compare it against a shape slice.
    self.normalized_shape = tuple(normalized_shape)
    self.eps = eps
    self.affine = affine
    if self.affine:
        self.weight = Parameter(np.ones(self.normalized_shape, dtype="float32"))
        self.bias = Parameter(np.zeros(self.normalized_shape, dtype="float32"))
    else:
        self.weight = None
        self.bias = None

    self.reset_parameters()
|
|
|
|
|
|
|
def reset_parameters(self): |
|
|
@@ -133,20 +135,21 @@ class LayerNorm(Module): |
|
|
|
zeros_(self.bias) |
|
|
|
|
|
|
|
def forward(self, x):
    """Normalize ``x`` over its trailing ``normalized_shape`` dimensions.

    The trailing dimensions are flattened into one axis; mean and (biased)
    variance are taken over that axis; the standardized result is reshaped
    back to the input shape. When ``self.affine`` is true the result is
    multiplied by ``self.weight`` and shifted by ``self.bias``.

    Args:
        x: input whose trailing dimensions equal ``self.normalized_shape``.

    Returns:
        The normalized value, with the same shape as ``x``.

    Raises:
        AssertionError: if the trailing dimensions of ``x`` do not match
            ``self.normalized_shape``.
    """
    x_shape = x.shape
    assert x_shape[-len(self.normalized_shape) :] == self.normalized_shape, (
        "trailing dimensions of input shape {} do not match "
        "normalized_shape {}".format(x_shape, self.normalized_shape)
    )
    # Number of leading dimensions that are kept as-is (not normalized).
    dim_delta = len(x_shape) - len(self.normalized_shape)
    non_flatten_shape = x_shape[:dim_delta]
    # Collapse all normalized dimensions into a single trailing axis.
    x = x.reshape(*non_flatten_shape, -1)

    mean = x.mean(axis=-1, keepdims=True)
    # Variance from centred values: unlike E[x^2] - mean^2 this cannot go
    # negative through floating-point cancellation, so the sqrt stays real.
    centered = x - mean
    var = (centered ** 2).mean(axis=-1, keepdims=True)

    x = centered / F.sqrt(var + self.eps)
    x = x.reshape(x_shape)
    if self.affine:
        # weight/bias have shape normalized_shape and broadcast over the
        # leading dimensions.
        x = self.weight * x + self.bias

    return x
|
|
|
|
|
|
|
def _module_info_string(self) -> str: |
|
|
|
s = "channels={num_channels}, eps={eps}, affine={affine}" |
|
|
|
s = "normalized_shape={normalized_shape}, eps={eps}, affine={affine}" |
|
|
|
return s.format(**self.__dict__) |