Browse Source

feat(mgb): add megbrain layer norm opr with subgraph

GitOrigin-RevId: 9b7fa821f8
revert-211-master
Megvii Engine Team 3 years ago
parent
commit
97b1b7774d
3 changed files with 146 additions and 12 deletions
  1. +87
    -0
      imperative/python/megengine/functional/nn.py
  2. +3
    -12
      imperative/python/megengine/module/normalization.py
  3. +56
    -0
      imperative/python/test/unit/functional/test_functional.py

+ 87
- 0
imperative/python/megengine/functional/nn.py View File

@@ -19,6 +19,7 @@ from ..core.ops.builtin import (
GetVarShape,
Identity,
Reduce,
Reshape,
TypeCvt,
)
from ..core.ops.special import Const
@@ -1022,6 +1023,92 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
return cached / down


@lru_cache(maxsize=None)
def _get_layerNorm(device, dtype, dim, gopt_level=2):
@subgraph("LayerNormAffine", dtype, device, 5, gopt_level=gopt_level)
def layerNormAffine(inputs, f, c):
inp, eps, _flatten_shape, weight, bias = inputs
inp_shape = f(GetVarShape(), inp)

inp = f(Reshape(axis=dim), inp, _flatten_shape)
mean = f(Reduce(mode="mean", axis=-1), inp)
x2s = f(Reduce(mode="sum_sqr", axis=-1), inp)
reduce_shape = f(GetVarShape(), x2s)
reduce_size = f(
"//",
f(Reduce(mode="product", axis=0), inp_shape),
f(Reduce(mode="product", axis=0), reduce_shape),
)
reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2)))
inv_sqrt_var = f("**", f("+", var, eps), c(-0.5))
oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var))
affine_oup = f(Reshape(), oup, inp_shape)
affine_oup = f("fma3", affine_oup, weight, bias)

# NOTE: return oup make backward faster but take more memory
return (affine_oup, oup, mean, x2s), (True, False, False, False)

@subgraph("LayerNorm", dtype, device, 3, gopt_level=gopt_level)
def layerNorm(inputs, f, c):
inp, eps, _flatten_shape = inputs
inp_shape = f(GetVarShape(), inp)

inp = f(Reshape(axis=dim), inp, _flatten_shape)
mean = f(Reduce(mode="mean", axis=-1), inp)
x2s = f(Reduce(mode="sum_sqr", axis=-1), inp)
reduce_shape = f(GetVarShape(), x2s)
reduce_size = f(
"//",
f(Reduce(mode="product", axis=0), inp_shape),
f(Reduce(mode="product", axis=0), reduce_shape),
)
reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2)))
inv_sqrt_var = f("**", f("+", var, eps), c(-0.5))
oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var))
oup = f(Reshape(), oup, inp_shape)

return (oup,), (True,)

return (layerNorm, layerNormAffine)


def layer_norm(
inp: Tensor,
normalized_shape: tuple,
affine: bool,
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
eps: float = 1e-5,
eps_mode="additive",
):

assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format(
eps_mode
)

_device = inp.device
_dtype = inp.dtype
_dim = len(inp.shape) - len(normalized_shape)

_flatten_shape = concat(
(
convert_single_value(inp.shape[:_dim], dtype="int32", device=inp.device),
convert_single_value(-1, dtype="int32", device=inp.device),
)
)
(layerNorm, layerNormAffine) = _get_layerNorm(_device, _dtype, _dim)

eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
if affine:
outvar, *_ = apply(layerNormAffine(), inp, eps, _flatten_shape, weight, bias)
else:
outvar, *_ = apply(layerNorm(), inp, eps, _flatten_shape)

return outvar


def batch_norm(
inp: Tensor,
running_mean: Tensor = None,


+ 3
- 12
imperative/python/megengine/module/normalization.py View File

@@ -132,18 +132,9 @@ class LayerNorm(Module):
zeros_(self.bias)

def forward(self, x):
x_shape = x.shape
dim_delta = len(x_shape) - len(self.normalized_shape)
non_flatten_shape = x_shape[:dim_delta]
x = x.reshape(*non_flatten_shape, -1)

mean = x.mean(axis=-1, keepdims=True)
var = (x ** 2).mean(axis=-1, keepdims=True) - mean * mean

x = (x - mean) / F.sqrt(var + self.eps)
x = x.reshape(x_shape)
if self.affine:
x = self.weight * x + self.bias
x = F.nn.layer_norm(
x, self.normalized_shape, self.affine, self.weight, self.bias, self.eps
)
return x

def _module_info_string(self) -> str:


+ 56
- 0
imperative/python/test/unit/functional/test_functional.py View File

@@ -24,6 +24,7 @@ from megengine.core._trace_option import use_symbolic_shape
from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.utils import make_shape_tuple
from megengine.device import get_device_count
from megengine.module import LayerNorm


def test_where():
@@ -862,6 +863,61 @@ def test_conv1d():
)


def test_layer_norm():
def _layer_norm(x, normalized_shape, affine, weight=None, bias=None, eps=1e-5):
__layer_norm = LayerNorm(normalized_shape=normalized_shape, affine=affine)
__layer_norm.weight = weight
__layer_norm.bias = bias
return __layer_norm(x)

def _layer_norm_numpy(
x, normalized_shape, affine, weight=None, bias=None, eps=1e-5
):
x_shape = x.shape
dim_delta = len(x_shape) - len(normalized_shape)
non_flatten_shape = x_shape[:dim_delta]
x = x.reshape(*non_flatten_shape, -1)

mean = x.mean(axis=-1, keepdims=True)
var = (x ** 2).mean(axis=-1, keepdims=True) - mean * mean

x = (x - mean) / F.sqrt(var + eps)
x = x.reshape(x_shape)
if affine:
x = weight * x + bias

return x

normalized_shape = (28, 28)
inp_feat = Tensor(np.random.randn(32, 64, 28, 28), dtype="float32")
weight = Tensor(np.random.randn(28, 28), dtype="float32")
bias = Tensor(np.random.randn(28, 28), dtype="float32")

inp_feat = inp_feat + 1
weight = weight + 1
bias = bias

affine = False

outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, bias)
targetvar = _layer_norm_numpy(inp_feat, normalized_shape, affine, weight, bias)

assert abs(outvar - targetvar).mean() < 1e-7

# no random, affine True
normalized_shape = (28, 28)
inp_feat = Tensor(np.ones((32, 64, 28, 28)), dtype="float32")
weight = Tensor(np.ones((28, 28)), dtype="float32")
bias = Tensor(np.zeros((28, 28)), dtype="float32")

affine = True

outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, bias)
targetvar = _layer_norm(inp_feat, normalized_shape, affine, weight, bias)
assert abs((outvar - targetvar).mean()) < 1e-7
assert abs(outvar.mean()) < 1e-7


def test_batchnorm2d_io16c32():
amp.enabled = True
inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)


Loading…
Cancel
Save