@@ -19,6 +19,7 @@ from ..core.ops.builtin import (
    GetVarShape,
    Identity,
    Reduce,
    Reshape,
    TypeCvt,
)
from ..core.ops.special import Const
@@ -1022,6 +1023,92 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
    return cached / down


@lru_cache(maxsize=None)
def _get_layerNorm(device, dtype, dim, gopt_level=2):
    @subgraph("LayerNormAffine", dtype, device, 5, gopt_level=gopt_level)
    def layerNormAffine(inputs, f, c):
        inp, eps, _flatten_shape, weight, bias = inputs
        inp_shape = f(GetVarShape(), inp)

        inp = f(Reshape(axis=dim), inp, _flatten_shape)
        mean = f(Reduce(mode="mean", axis=-1), inp)
        x2s = f(Reduce(mode="sum_sqr", axis=-1), inp)
        reduce_shape = f(GetVarShape(), x2s)
        # number of reduced elements = prod(inp_shape) // prod(reduce_shape)
        reduce_size = f(
            "//",
            f(Reduce(mode="product", axis=0), inp_shape),
            f(Reduce(mode="product", axis=0), reduce_shape),
        )
        reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
        # Var[x] = E[x**2] - E[x]**2, with E[x**2] = x2s / reduce_size
        var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2)))
        inv_sqrt_var = f("**", f("+", var, eps), c(-0.5))
        oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var))
        affine_oup = f(Reshape(), oup, inp_shape)
        affine_oup = f("fma3", affine_oup, weight, bias)

        # NOTE: returning oup makes backward faster but takes more memory
        return (affine_oup, oup, mean, x2s), (True, False, False, False)
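
    # A quick illustration of the "fma3" identity used above (fused
    # multiply-add, fma3(x, a, b) = x * a + b, cf. Elemwise FUSE_MUL_ADD3):
    #     inp * inv_sqrt_var + (-mean) * inv_sqrt_var
    #         == (inp - mean) * inv_sqrt_var
    # i.e. the usual (x - E[x]) / sqrt(Var[x] + eps). Numeric check with
    # inp=3.0, mean=1.0, inv_sqrt_var=0.5:
    #     3.0 * 0.5 + (-1.0) * 0.5 == (3.0 - 1.0) * 0.5 == 1.0
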
@subgraph("LayerNorm", dtype, device, 3, gopt_level=gopt_level) |
|
|
|
def layerNorm(inputs, f, c): |
|
|
|
inp, eps, _flatten_shape = inputs |
|
|
|
inp_shape = f(GetVarShape(), inp) |
|
|
|
|
|
|
|
inp = f(Reshape(axis=dim), inp, _flatten_shape) |
|
|
|
mean = f(Reduce(mode="mean", axis=-1), inp) |
|
|
|
x2s = f(Reduce(mode="sum_sqr", axis=-1), inp) |
|
|
|
reduce_shape = f(GetVarShape(), x2s) |
|
|
|
reduce_size = f( |
|
|
|
"//", |
|
|
|
f(Reduce(mode="product", axis=0), inp_shape), |
|
|
|
f(Reduce(mode="product", axis=0), reduce_shape), |
|
|
|
) |
|
|
|
reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size) |
|
|
|
var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2))) |
|
|
|
inv_sqrt_var = f("**", f("+", var, eps), c(-0.5)) |
|
|
|
oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var)) |
|
|
|
oup = f(Reshape(), oup, inp_shape) |
|
|
|
|
|
|
|
return (oup,), (True,) |

    return (layerNorm, layerNormAffine)
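
# Sketch of the caching behavior (device/dtype values are assumptions for
# illustration): _get_layerNorm is memoized per (device, dtype, dim), so every
# call normalizing over the same trailing-axis count on the same device and
# dtype reuses one compiled (layerNorm, layerNormAffine) subgraph pair, e.g.
#     layerNorm, layerNormAffine = _get_layerNorm("xpux", "float32", dim=2)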


def layer_norm(
    inp: Tensor,
    normalized_shape: tuple,
    affine: bool,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-5,
    eps_mode="additive",
):
    assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format(
        eps_mode
    )
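    # Hedged reading of eps_mode: "additive" adds eps to the variance before
    # the rsqrt, which is the only form the subgraphs above implement
    # (f("+", var, eps)); "max" would presumably clamp via max(var, eps).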

    _device = inp.device
    _dtype = inp.dtype
    _dim = len(inp.shape) - len(normalized_shape)

    _flatten_shape = concat(
        (
            convert_single_value(inp.shape[:_dim], dtype="int32", device=inp.device),
            convert_single_value(-1, dtype="int32", device=inp.device),
        )
    )
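    # For example (hypothetical shapes): inp.shape == (2, 3, 4, 5) with
    # normalized_shape == (4, 5) gives _dim == 2, so _flatten_shape is
    # [2, 3, -1] and the Reshape inside the subgraph flattens inp to
    # (2, 3, 20) before reducing over the last axis.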
    (layerNorm, layerNormAffine) = _get_layerNorm(_device, _dtype, _dim)

    eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
    if affine:
        outvar, *_ = apply(layerNormAffine(), inp, eps, _flatten_shape, weight, bias)
    else:
        outvar, *_ = apply(layerNorm(), inp, eps, _flatten_shape)

    return outvar
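
# Usage sketch (shapes hypothetical; assumes the usual megengine import):
#     x = megengine.random.normal(size=(2, 3, 4, 5))
#     y = layer_norm(x, normalized_shape=(4, 5), affine=False)
#     y.shape    # (2, 3, 4, 5): ~zero mean, ~unit variance over last 2 axes
# With affine=True, weight and bias must be broadcastable to the output of
# the normalization, since the final fma3 computes y * weight + bias.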


def batch_norm(
    inp: Tensor,
    running_mean: Tensor = None,