@@ -7,11 +7,13 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # pylint: disable=too-many-lines
-from typing import Optional, Sequence, Tuple, Union
+from functools import lru_cache
+from typing import NamedTuple, Optional, Sequence, Tuple, Union
 
 from ..core._imperative_rt.core2 import apply, dtype_promotion
+from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from ..core.ops import builtin
-from ..core.ops.builtin import BatchNorm, Elemwise
+from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
 from ..core.ops.special import Const
 from ..core.tensor import amp, megbrain_graph
 from ..core.tensor.array_method import _elwise_apply
@@ -20,10 +22,13 @@ from ..core.tensor.utils import (
     astype,
+    cast_tensors,
     convert_single_value,
+    make_shape_tuple,
     setscalar,
+    subgraph,
 )
 from ..device import get_default_device
 from ..distributed import WORLD, is_distributed
 from ..jit import exclude_from_trace
 from ..random import uniform
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func
@@ -1153,6 +1158,111 @@ def batch_norm(
     return inp
 
 
+@lru_cache(maxsize=None)
+def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
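+    # the generated subgraphs depend only on (device, dtype, eps_mode, ndim,
+    # channels), so lru_cache reuses the compiled ops across calls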
+    # fmt: off
+    @subgraph("SyncBnStage0", dtype, device, 1)
+    def syncbn_stage0(inputs, f, c):
+        input = inputs[0]
+        reduce_shape = c((1, channels) + (1,) * (ndim - 2), dtype="int32", device=device)
+        input_shape = f(GetVarShape(), input)
+        input_elems = f(Reduce(mode="product", axis=0), input_shape)
+        reduce_elems = f(Reduce(mode="product", axis=0), reduce_shape)
+        reduce_size = f("//", input_elems, reduce_elems)
+        channel_x1s = f(Reduce(mode="sum"), input, reduce_shape)
+        channel_x2s = f(Reduce(mode="sum_sqr"), input, reduce_shape)
+        reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
+        return (reduce_shape, reduce_size_f, channel_x1s, channel_x2s), (False, False, True, True)
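+    # stage0 yields the reduce shape (1, C, 1, ...), N = the element count
+    # reduced per channel (cast to dtype), and the per-channel sums
+    # x1s = sum(x) and x2s = sum(x**2); the second returned tuple appears to
+    # flag which outputs require gradient tracking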
+
+    @subgraph("SyncBnStage1", dtype, device, 7, gopt_level=3)
+    def syncbn_stage1(inputs, f, c):
+        input, reduce_size, channel_x1s, channel_x2s, eps = inputs[0:5]
+        weight, bias = inputs[5:7]
+        channel_mean = f("/", channel_x1s, reduce_size)
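+        # var = E[x^2] - E[x]^2 = x2s/N - (x1s/N)^2, written below as
+        # x1s**2 / -(N*N) + x2s/N to stay within elemwise primitives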
+        channel_var =\
+            f("+", f("/", f("**", channel_x1s, c(2)),
+                          f("-", f("*", reduce_size, reduce_size))),
+                   f("/", channel_x2s, reduce_size))
+        invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
+        inv_var_wt = f("*", invsqrt_channel_var, weight)
+        neg_channel_mean = f("-", channel_mean)
+        outvar =\
+            f("+", f("*", input, inv_var_wt),
+                   f("+", f("*", neg_channel_mean, inv_var_wt),
+                          bias))
+        return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False)
+
+    @subgraph("SyncBnStage1Inference", dtype, device, 6, gopt_level=3)
+    def syncbn_stage1_inference(inputs, f, c):
+        input, channel_mean, channel_var, eps = inputs[0:4]
+        weight, bias = inputs[4:6]
+        invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
+        inv_var_wt = f("*", invsqrt_channel_var, weight)
+        neg_channel_mean = f("-", channel_mean)
+        outvar =\
+            f("+", f("*", input, inv_var_wt),
+                   f("+", f("*", neg_channel_mean, inv_var_wt),
+                          bias))
+        return (outvar,), (True,)
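+
+    # stage2 folds the running-stat update r <- m*r + (1-m)*stat into one
+    # subgraph; the variance gets Bessel's correction:
+    # var_unbiased = x2s/(N-1) - x1s**2/(N*(N-1))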
+    @subgraph("SyncBnStage2", dtype, device, 7, gopt_level=3)
+    def syncbn_stage2(inputs, f, c):
+        running_mean, running_var, momentum = inputs[0:3]
+        reduce_size, channel_x1s, channel_x2s, channel_mean = inputs[3:7]
+        running_mean = f("*", running_mean, momentum)
+        running_mean =\
+            f("+", running_mean,
+                   f("*", f("-", c(1), momentum),
+                          channel_mean))
+        channel_variance_unbiased =\
+            f("+", f("/", f("**", channel_x1s, c(2)),
+                          f("*", f("-", reduce_size),
+                                 f("-", reduce_size, c(1)))),
+                   f("/", channel_x2s,
+                          f("-", reduce_size, c(1))))
+        running_var = f("*", running_var, momentum)
+        running_var =\
+            f("+", running_var,
+                   f("*", f("-", c(1), momentum),
+                          channel_variance_unbiased))
+        return (running_mean, running_var), (True, True)
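+
+    # concat/split pack (N, x1s, x2s) into a single tensor so that the
+    # distributed path below needs only one all_reduce_sum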
+    @subgraph("SyncBnConcatStats", dtype, device, 3, gopt_level=3)
+    def syncbn_concat_stats(inputs, f, c):
+        reduce_size, channel_x1s, channel_x2s = inputs[0:3]
+        reduce_size = f(builtin.Broadcast(), reduce_size, c([1]*ndim, dtype="int32"))
+        stats = f(builtin.Concat(axis=1, comp_node=device), reduce_size, channel_x1s, channel_x2s)
+        return (stats,), (True,)
+
+    @subgraph("SyncBnSplitStats", dtype, device, 1, gopt_level=3)
+    def syncbn_split_stats(inputs, f, c):
+        stats = inputs[0]
+        c_1 = c(1, dtype="int32")
+        channel_x1s_end = c(channels+1, dtype="int32")
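+        # each Subtensor item appears to be (axis, has_begin, has_end,
+        # has_step, has_idx) flags, emulating src[begin:end] along the axis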
+        def _subtensor(src, axis, begin, end):
+            items = (axis, (begin is not None), (end is not None), False, False),
+            args = ()
+            if begin is not None:
+                args += begin,
+            if end is not None:
+                args += end,
+            return f(builtin.Subtensor(items=items), src, *args)
+        reduce_size = _subtensor(stats, 1, None, c_1)
+        channel_x1s = _subtensor(stats, 1, c_1, channel_x1s_end)
+        channel_x2s = _subtensor(stats, 1, channel_x1s_end, None)
+        reduce_size = f(builtin.Reshape(), reduce_size, c_1)
+        return (reduce_size, channel_x1s, channel_x2s), (False, True, True)
+    # fmt: on
+    return (
+        syncbn_stage0,
+        syncbn_stage1,
+        syncbn_stage1_inference,
+        syncbn_stage2,
+        syncbn_concat_stats,
+        syncbn_split_stats,
+    )
+
+
 def sync_batch_norm(
     inp: Tensor,
     running_mean: Tensor,
@@ -1193,52 +1303,55 @@ def sync_batch_norm(
     assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format(
         eps_mode
     )
-    _channels = inp.shape[1]
+    # TODO: cudnnBn fastpath
+    _channels = make_shape_tuple(inp.shape)[1]
     _ndim = inp.ndim
     _device = inp.device
     _dtype = inp.dtype
     _param_shape = (1, _channels) + (1,) * (_ndim - 2)
     _reduce_axis = [0] + [i for i in range(2, _ndim)]
 
+    if training:
+
+        def _make_full_if_none(x, value):
+            if x is None:
+                (x,) = Const(value, dtype=inp.dtype, device=_device)()
+                (result,) = apply(builtin.Broadcast(), x, reduce_shape)
+                return result
+            elif x.ndim == 1:
+                (result,) = apply(builtin.Reshape(), x, reduce_shape)
+                return result
+            return x
+
+        (
+            syncbn_stage0,
+            syncbn_stage1,
+            syncbn_stage1_inference,
+            syncbn_stage2,
+            syncbn_concat_stats,
+            syncbn_split_stats,
+        ) = _get_sync_bn_ops(_device, _dtype, eps_mode, _ndim, _channels)
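+
+        # stage0 computes this process's local statistics of the input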
+        reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0, inp)
 
-    def _sum_on_channel(inp):
-        return inp.sum(axis=_reduce_axis, keepdims=True)
+        eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
 
-    reduce_size = inp.shape[0]
-    for i in range(2, _ndim):
-        reduce_size = reduce_size * inp.shape[i]
-    channel_x1s = _sum_on_channel(inp)
-    channel_x2s = _sum_on_channel(inp ** 2)
+        weight = _make_full_if_none(weight, 1)
+        bias = _make_full_if_none(bias, 0)
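+        # weight/bias default to 1/0 and are broadcast to reduce_shape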
 
-    if training:
-        if is_distributed():
+        if is_distributed():
             # reduce all nodes' data to calculate mean and variance
-            reduce_size = broadcast_to(
-                Tensor(reduce_size).astype(dtype=_dtype), [1] * _ndim
-            )
-            stat = concat([reduce_size, channel_x1s, channel_x2s], axis=1)
+            (stat,) = apply(syncbn_concat_stats, reduce_size, channel_x1s, channel_x2s)
             stat = all_reduce_sum(stat, group)
-            reduce_size = stat[:, :1].reshape(1)
-            channel_x1s = stat[:, 1 : 1 + _channels]
-            channel_x2s = stat[:, 1 + _channels :]
+            reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats, stat)
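+            # after the all_reduce every rank holds the global N, x1s and x2s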
 
-        channel_mean = channel_x1s / reduce_size
-        channel_variance = (
-            channel_x1s ** 2 / (-reduce_size * reduce_size) + channel_x2s / reduce_size
-        )
+        outvar, channel_mean, *_ = apply(
+            syncbn_stage1, inp, reduce_size, channel_x1s, channel_x2s, eps, weight, bias
+        )
     else:
         assert running_var is not None and running_mean is not None
-        channel_variance = running_var.reshape(*_param_shape)
-        channel_mean = running_mean.reshape(*_param_shape)
-
-    invsqrt_channel_variance = (
-        maximum(channel_variance, eps) if eps_mode == "max" else channel_variance + eps
-    ) ** -0.5
-
-    if weight is not None:
-        weight = weight.reshape(*_param_shape)
-    if bias is not None:
-        bias = bias.reshape(*_param_shape)
+        channel_mean = running_mean
+        channel_var = running_var
+        outvar, *_ = apply(
+            syncbn_stage1_inference, inp, channel_mean, channel_var, eps, weight, bias
+        )
 
     # outvar = output * weight + bias
     # where output = inp * invsqrt_channel_variance + (
     #    -channel_mean * invsqrt_channel_variance
@@ -1246,28 +1359,18 @@
     # )
     # Manually expand output for gopt
 
-    if weight is not None:
-        inv_var_wt = invsqrt_channel_variance * weight
-        neg_channel_mean = -channel_mean
-        if bias is not None:
-            outvar = inp * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
-        else:
-            outvar = inp * inv_var_wt + neg_channel_mean * inv_var_wt
-    else:
-        outvar = inp * invsqrt_channel_variance + (
-            -channel_mean * invsqrt_channel_variance
-        )
-        if bias is not None:
-            outvar = outvar + bias
-
     if training and running_var is not None and running_mean is not None:
-        running_mean *= momentum
-        running_mean += (1 - momentum) * channel_mean
-        channel_variance_unbiased = channel_x1s ** 2 / (
-            -reduce_size * (reduce_size - 1)
-        ) + channel_x2s / (reduce_size - 1)
-        running_var *= momentum
-        running_var += (1 - momentum) * channel_variance_unbiased
+        momentum = convert_single_value(momentum, dtype=inp.dtype, device=inp.device)
+        running_mean[...], running_var[...] = apply(
+            syncbn_stage2,
+            running_mean,
+            running_var,
+            momentum,
+            reduce_size,
+            channel_x1s,
+            channel_x2s,
+            channel_mean,
+        )
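+        # writing through [...] updates the existing buffers in place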
 
     return outvar