GitOrigin-RevId: 9b7fa821f8
revert-211-master
@@ -19,6 +19,7 @@ from ..core.ops.builtin import (
     GetVarShape,
     Identity,
     Reduce,
+    Reshape,
     TypeCvt,
 )
 from ..core.ops.special import Const
@@ -1022,6 +1023,92 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
     return cached / down


+@lru_cache(maxsize=None)
+def _get_layerNorm(device, dtype, dim, gopt_level=2):
+    @subgraph("LayerNormAffine", dtype, device, 5, gopt_level=gopt_level)
+    def layerNormAffine(inputs, f, c):
+        inp, eps, _flatten_shape, weight, bias = inputs
+        inp_shape = f(GetVarShape(), inp)
+        inp = f(Reshape(axis=dim), inp, _flatten_shape)
+        mean = f(Reduce(mode="mean", axis=-1), inp)
+        x2s = f(Reduce(mode="sum_sqr", axis=-1), inp)
+        reduce_shape = f(GetVarShape(), x2s)
+        reduce_size = f(
+            "//",
+            f(Reduce(mode="product", axis=0), inp_shape),
+            f(Reduce(mode="product", axis=0), reduce_shape),
+        )
+        reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
+        var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2)))
+        inv_sqrt_var = f("**", f("+", var, eps), c(-0.5))
+        oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var))
+        affine_oup = f(Reshape(), oup, inp_shape)
+        affine_oup = f("fma3", affine_oup, weight, bias)
+        # NOTE: returning oup makes backward faster but takes more memory
+        return (affine_oup, oup, mean, x2s), (True, False, False, False)
+
+    @subgraph("LayerNorm", dtype, device, 3, gopt_level=gopt_level)
+    def layerNorm(inputs, f, c):
+        inp, eps, _flatten_shape = inputs
+        inp_shape = f(GetVarShape(), inp)
+        inp = f(Reshape(axis=dim), inp, _flatten_shape)
+        mean = f(Reduce(mode="mean", axis=-1), inp)
+        x2s = f(Reduce(mode="sum_sqr", axis=-1), inp)
+        reduce_shape = f(GetVarShape(), x2s)
+        reduce_size = f(
+            "//",
+            f(Reduce(mode="product", axis=0), inp_shape),
+            f(Reduce(mode="product", axis=0), reduce_shape),
+        )
+        reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
+        var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2)))
+        inv_sqrt_var = f("**", f("+", var, eps), c(-0.5))
+        oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var))
+        oup = f(Reshape(), oup, inp_shape)
+        return (oup,), (True,)
+
+    return (layerNorm, layerNormAffine)
+
+
+def layer_norm(
+    inp: Tensor,
+    normalized_shape: tuple,
+    affine: bool,
+    weight: Optional[Tensor] = None,
+    bias: Optional[Tensor] = None,
+    eps: float = 1e-5,
+    eps_mode="additive",
+):
+    assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format(
+        eps_mode
+    )
+    _device = inp.device
+    _dtype = inp.dtype
+    _dim = len(inp.shape) - len(normalized_shape)
+    _flatten_shape = concat(
+        (
+            convert_single_value(inp.shape[:_dim], dtype="int32", device=inp.device),
+            convert_single_value(-1, dtype="int32", device=inp.device),
+        )
+    )
+    (layerNorm, layerNormAffine) = _get_layerNorm(_device, _dtype, _dim)
+
+    eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
+    if affine:
+        outvar, *_ = apply(layerNormAffine(), inp, eps, _flatten_shape, weight, bias)
+    else:
+        outvar, *_ = apply(layerNorm(), inp, eps, _flatten_shape)
+
+    return outvar
+
+
 def batch_norm(
     inp: Tensor,
     running_mean: Tensor = None,
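For reference while reviewing: the two subgraphs above fuse the standard layer-norm decomposition over the flattened trailing dimensions, with the normalization expressed as fused multiply-adds. A rough NumPy sketch of the same math (illustrative only; the helper name is made up here, eps is treated additively, and dtype/device plumbing is omitted):

import numpy as np

def layer_norm_reference(x, normalized_shape, eps=1e-5, weight=None, bias=None):
    # Reshape(axis=dim): flatten the normalized dims into one trailing axis.
    dim = x.ndim - len(normalized_shape)
    flat = x.reshape(*x.shape[:dim], -1)
    # Reduce(mode="mean") and Reduce(mode="sum_sqr") along the last axis.
    mean = flat.mean(axis=-1, keepdims=True)
    x2s = (flat ** 2).sum(axis=-1, keepdims=True)
    # prod(inp_shape) // prod(reduce_shape) == number of normalized elements.
    reduce_size = flat.shape[-1]
    var = x2s / reduce_size - mean ** 2          # E[x^2] - E[x]^2
    inv_sqrt_var = (var + eps) ** -0.5
    # fma3: x * inv_sqrt_var + (-mean) * inv_sqrt_var
    out = flat * inv_sqrt_var + (-mean) * inv_sqrt_var
    out = out.reshape(x.shape)
    if weight is not None:
        out = out * weight + bias                # second fma3 in LayerNormAffine (assumes bias is given with weight)
    return out

For an input of shape (32, 64, 28, 28) with normalized_shape=(28, 28), _dim is 2 and _flatten_shape resolves to (32, 64, -1).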
@@ -132,18 +132,9 @@ class LayerNorm(Module):
             zeros_(self.bias)

     def forward(self, x):
-        x_shape = x.shape
-        dim_delta = len(x_shape) - len(self.normalized_shape)
-        non_flatten_shape = x_shape[:dim_delta]
-        x = x.reshape(*non_flatten_shape, -1)
-        mean = x.mean(axis=-1, keepdims=True)
-        var = (x ** 2).mean(axis=-1, keepdims=True) - mean * mean
-        x = (x - mean) / F.sqrt(var + self.eps)
-        x = x.reshape(x_shape)
-        if self.affine:
-            x = self.weight * x + self.bias
+        x = F.nn.layer_norm(
+            x, self.normalized_shape, self.affine, self.weight, self.bias, self.eps
+        )
         return x

     def _module_info_string(self) -> str:
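With forward now delegating to the fused functional path, module usage is unchanged. A minimal sketch (shapes are illustrative; assumes the usual megengine imports):

import numpy as np
from megengine import Tensor
from megengine.module import LayerNorm

# Normalize over the last two dims of an NCHW feature map.
ln = LayerNorm(normalized_shape=(28, 28), affine=True)
x = Tensor(np.random.randn(32, 64, 28, 28), dtype="float32")
y = ln(x)  # forward now routes through F.nn.layer_norm
print(y.shape)  # (32, 64, 28, 28)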
@@ -24,6 +24,7 @@ from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.autodiff.grad import Grad
 from megengine.core.tensor.utils import make_shape_tuple
 from megengine.device import get_device_count
+from megengine.module import LayerNorm


 def test_where():
@@ -862,6 +863,61 @@ def test_conv1d():
     )


+def test_layer_norm():
+    def _layer_norm(x, normalized_shape, affine, weight=None, bias=None, eps=1e-5):
+        __layer_norm = LayerNorm(normalized_shape=normalized_shape, affine=affine)
+        __layer_norm.weight = weight
+        __layer_norm.bias = bias
+        return __layer_norm(x)
+
+    def _layer_norm_numpy(
+        x, normalized_shape, affine, weight=None, bias=None, eps=1e-5
+    ):
+        x_shape = x.shape
+        dim_delta = len(x_shape) - len(normalized_shape)
+        non_flatten_shape = x_shape[:dim_delta]
+        x = x.reshape(*non_flatten_shape, -1)
+
+        mean = x.mean(axis=-1, keepdims=True)
+        var = (x ** 2).mean(axis=-1, keepdims=True) - mean * mean
+
+        x = (x - mean) / F.sqrt(var + eps)
+        x = x.reshape(x_shape)
+        if affine:
+            x = weight * x + bias
+        return x
+
+    normalized_shape = (28, 28)
+    inp_feat = Tensor(np.random.randn(32, 64, 28, 28), dtype="float32")
+    weight = Tensor(np.random.randn(28, 28), dtype="float32")
+    bias = Tensor(np.random.randn(28, 28), dtype="float32")
+
+    inp_feat = inp_feat + 1
+    weight = weight + 1
+    bias = bias
+
+    affine = False
+
+    outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, bias)
+    targetvar = _layer_norm_numpy(inp_feat, normalized_shape, affine, weight, bias)
+    assert abs(outvar - targetvar).mean() < 1e-7
+
+    # no random, affine True
+    normalized_shape = (28, 28)
+    inp_feat = Tensor(np.ones((32, 64, 28, 28)), dtype="float32")
+    weight = Tensor(np.ones((28, 28)), dtype="float32")
+    bias = Tensor(np.zeros((28, 28)), dtype="float32")
+
+    affine = True
+
+    outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, bias)
+    targetvar = _layer_norm(inp_feat, normalized_shape, affine, weight, bias)
+    assert abs((outvar - targetvar).mean()) < 1e-7
+    assert abs(outvar.mean()) < 1e-7
+
+
 def test_batchnorm2d_io16c32():
     amp.enabled = True
     inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
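Given the NOTE about the backward/memory trade-off in LayerNormAffine, a gradient sanity check through the fused path may also be worth adding. A sketch only (not part of this diff; it uses megengine.autodiff.GradManager, and shapes are illustrative):

import numpy as np
import megengine.functional as F
from megengine import Tensor
from megengine.autodiff import GradManager

x = Tensor(np.random.randn(4, 8, 28, 28), dtype="float32")
w = Tensor(np.ones((28, 28)), dtype="float32")
b = Tensor(np.zeros((28, 28)), dtype="float32")

gm = GradManager().attach([x, w, b])
with gm:
    out = F.nn.layer_norm(x, (28, 28), True, w, b)
    gm.backward(out.mean())

# x.grad / w.grad / b.grad now hold gradients through the fused subgraph and
# can be compared against gradients of a hand-written reference.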