
perf(syncbn): reimplement with subgraph

GitOrigin-RevId: 13e7e3d3c0
tags/v1.6.0-rc1
Megvii Engine Team, 3 years ago
parent commit 8c47c1f149
5 changed files with 257 additions and 56 deletions
  1. imperative/python/megengine/core/tensor/utils.py (+47 -0)
  2. imperative/python/megengine/functional/nn.py (+158 -55)
  3. imperative/python/megengine/jit/tracing.py (+1 -1)
  4. imperative/python/src/common.cpp (+4 -0)
  5. imperative/python/src/ops.cpp (+47 -0)

imperative/python/megengine/core/tensor/utils.py (+47 -0)

@@ -13,6 +13,7 @@ import numpy as np


from .._imperative_rt import make_const
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device
from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from .._wrap import as_device
from ..ops import builtin
from ..ops.special import Const
@@ -219,3 +220,49 @@ def _normalize_axis(
)
return axis
raise


def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
if device.physical_name.startswith("cpu"):
gopt_level = None # disable jit and compile

binary_ops = {
"+": builtin.Elemwise(mode="add"),
"-": builtin.Elemwise(mode="sub"),
"*": builtin.Elemwise(mode="mul"),
"/": builtin.Elemwise(mode="true_div"),
"//": builtin.Elemwise(mode="floor_div"),
"**": builtin.Elemwise(mode="pow"),
"√": builtin.Elemwise(mode="expm1"),
"max": builtin.Elemwise(mode="max"),
"additive": builtin.Elemwise(mode="add"),
}

unary_ops = {
"-": builtin.Elemwise(mode="negate"),
}

def decorator(func):
builder = _SubgraphBuilder(name)

def apply_expr(op, *args):
if isinstance(op, str):
if len(args) == 2:
op = binary_ops[op]
elif len(args) == 1:
op = unary_ops[op]
return builder.apply(op, args, 1)[0]

def apply_const(value, dtype=dtype, device=device):
return builder.apply_const(value, dtype, device)

inputs = [builder.input() for _ in range(nr_inputs)]
outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
builder.outputs(outputs)
builder.outputs_has_grad(outputs_has_grad)
if gopt_level is None:
return builder.get()
else:
return builder.compile(gopt_level)

return decorator
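
For reference, a minimal usage sketch of this new helper (not part of the diff; the op name and constants are illustrative). The decorated function is built eagerly and becomes a single OpDef that apply() treats like any builtin op:

import numpy as np
import megengine as mge
from megengine.core._imperative_rt.core2 import apply
from megengine.core.tensor.utils import subgraph

x = mge.tensor(np.arange(4, dtype="float32"))

@subgraph("ScaleShift", x.dtype, x.device, nr_inputs=1)
def scale_shift(inputs, f, c):
    (inp,) = inputs
    # f applies an op (a builtin or one of the string shorthands above), c makes a constant
    out = f("+", f("*", inp, c(2.0)), c(1.0))  # 2 * x + 1, fused into one subgraph
    return (out,), (True,)                     # outputs, outputs_has_grad

(y,) = apply(scale_shift, x)                   # y == 2 * x + 1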

imperative/python/megengine/functional/nn.py (+158 -55)

@@ -7,11 +7,13 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=too-many-lines
from typing import Optional, Sequence, Tuple, Union
from functools import lru_cache
from typing import NamedTuple, Optional, Sequence, Tuple, Union


from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise
from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
from ..core.ops.special import Const
from ..core.tensor import amp, megbrain_graph
from ..core.tensor.array_method import _elwise_apply
@@ -20,10 +22,13 @@ from ..core.tensor.utils import (
astype,
cast_tensors,
convert_single_value,
make_shape_tuple,
setscalar,
subgraph,
)
from ..device import get_default_device
from ..distributed import WORLD, is_distributed
from ..jit import exclude_from_trace
from ..random import uniform
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
@@ -1153,6 +1158,111 @@ def batch_norm(
return inp




@lru_cache(maxsize=None)
def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
# fmt: off
@subgraph("SyncBnStage0", dtype, device, 1)
def syncbn_stage0(inputs, f, c):
input = inputs[0]
reduce_shape = c((1, channels) + (1,) * (ndim - 2), dtype="int32", device=device)
input_shape = f(GetVarShape(), input)
input_elems = f(Reduce(mode="product", axis=0), input_shape)
reduce_elems = f(Reduce(mode="product", axis=0), reduce_shape)
reduce_size = f("//", input_elems, reduce_elems)
channel_x1s = f(Reduce(mode="sum"), input, reduce_shape)
channel_x2s = f(Reduce(mode="sum_sqr"), input, reduce_shape)
reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
return (reduce_shape, reduce_size_f, channel_x1s, channel_x2s), (False, False, True, True)

@subgraph("SyncBnStage1", dtype, device, 7, gopt_level=3)
def syncbn_stage1(inputs, f, c):
input, reduce_size, channel_x1s, channel_x2s, eps = inputs[0:5]
weight, bias = inputs[5:7]
channel_mean = f("/", channel_x1s, reduce_size)
channel_var =\
f("+", f("/", f("**", channel_x1s, c(2)),
f("-", f("*", reduce_size, reduce_size))),
f("/", channel_x2s, reduce_size))
invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
inv_var_wt = f("*", invsqrt_channel_var, weight)
neg_channel_mean = f("-", channel_mean)
outvar =\
f("+", f("*", input, inv_var_wt),
f("+", f("*", neg_channel_mean, inv_var_wt),
bias))
return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False)

@subgraph("SyncBnStage1Inference", dtype, device, 6, gopt_level=3)
def syncbn_stage1_inference(inputs, f, c):
input, channel_mean, channel_var, eps = inputs[0:4]
weight, bias = inputs[4:6]
invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
inv_var_wt = f("*", invsqrt_channel_var, weight)
neg_channel_mean = f("-", channel_mean)
outvar =\
f("+", f("*", input, inv_var_wt),
f("+", f("*", neg_channel_mean, inv_var_wt),
bias))
return (outvar,), (True,)

@subgraph("SyncBnStage2", dtype, device, 7, gopt_level=3)
def syncbn_stage2(inputs, f, c):
running_mean, running_var, momentum = inputs[0:3]
reduce_size, channel_x1s, channel_x2s, channel_mean = inputs[3:7]
running_mean = f("*", running_mean, momentum)
running_mean =\
f("+", running_mean,
f("*", f("-", c(1), momentum),
channel_mean))
channel_variance_unbiased =\
f("+", f("/", f("**", channel_x1s, c(2)),
f("*", f("-", reduce_size),
f("-", reduce_size, c(1)))),
f("/", channel_x2s,
f("-", reduce_size, c(1))))
running_var = f("*", running_var, momentum)
running_var =\
f("+", running_var,
f("*", f("-", c(1), momentum),
channel_variance_unbiased))
return (running_mean, running_var), (True, True)

@subgraph("SyncBnConcatStats", dtype, device, 3, gopt_level=3)
def syncbn_concat_stats(inputs, f, c):
reduce_size, channel_x1s, channel_x2s = inputs[0:3]
reduce_size = f(builtin.Broadcast(), reduce_size, c([1]*ndim, dtype="int32"))
stats = f(builtin.Concat(axis=1, comp_node=device), reduce_size, channel_x1s, channel_x2s)
return (stats,), (True,)

@subgraph("SyncBnSplitStats", dtype, device, 1, gopt_level=3)
def syncbn_split_stats(inputs, f, c):
stats = inputs[0]
c_1 = c(1, dtype="int32")
channel_x1s_end = c(channels+1, dtype="int32")
def _subtensor(src, axis, begin, end):
items = (axis, (begin is not None), (end is not None), False, False),
args = ()
if begin is not None:
args += begin,
if end is not None:
args += end,
return f(builtin.Subtensor(items=items), src, *args)
reduce_size = _subtensor(stats, 1, None, c_1)
channel_x1s = _subtensor(stats, 1, c_1, channel_x1s_end)
channel_x2s = _subtensor(stats, 1, channel_x1s_end, None)
reduce_size = f(builtin.Reshape(), reduce_size, c_1)
return (reduce_size, channel_x1s, channel_x2s), (False, True, True)
# fmt: on
return (
syncbn_stage0,
syncbn_stage1,
syncbn_stage1_inference,
syncbn_stage2,
syncbn_concat_stats,
syncbn_split_stats,
)
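
The stage-1 and stage-2 expressions rely on the identities var = E[x^2] - E[x]^2 and its unbiased counterpart, rewritten in terms of the per-channel sums. A small numpy check (illustrative, not part of the commit):

import numpy as np

x = np.random.rand(2, 3, 4, 5)                  # NCHW layout, channels = 3
n = x.size // x.shape[1]                        # reduce_size
x1s = x.sum(axis=(0, 2, 3))                     # channel_x1s
x2s = (x ** 2).sum(axis=(0, 2, 3))              # channel_x2s

var_biased = x1s ** 2 / (-n * n) + x2s / n                 # SyncBnStage1
var_unbiased = x1s ** 2 / (-n * (n - 1)) + x2s / (n - 1)   # SyncBnStage2

assert np.allclose(var_biased, x.var(axis=(0, 2, 3)))
assert np.allclose(var_unbiased, x.var(axis=(0, 2, 3), ddof=1))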


def sync_batch_norm( def sync_batch_norm(
inp: Tensor, inp: Tensor,
running_mean: Tensor, running_mean: Tensor,
@@ -1193,52 +1303,55 @@ def sync_batch_norm(
assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format(
eps_mode
)
_channels = inp.shape[1]
# TODO: cudnnBn fastpath
_channels = make_shape_tuple(inp.shape)[1]
_ndim = inp.ndim
_device = inp.device
_dtype = inp.dtype
_param_shape = (1, _channels) + (1,) * (_ndim - 2)
_reduce_axis = [0] + [i for i in range(2, _ndim)]


if training:
def _make_full_if_none(x, value):
if x is None:
(x,) = Const(value, dtype=inp.dtype, device=_device)()
(result,) = apply(builtin.Broadcast(), x, reduce_shape)
return result
elif x.ndim == 1:
(result,) = apply(builtin.Reshape(), x, reduce_shape)
return result
return x

(
syncbn_stage0,
syncbn_stage1,
syncbn_stage1_inference,
syncbn_stage2,
syncbn_concat_stats,
syncbn_split_stats,
) = _get_sync_bn_ops(_device, _dtype, eps_mode, _ndim, _channels)

reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0, inp)


def _sum_on_channel(inp):
return inp.sum(axis=_reduce_axis, keepdims=True)
eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)


reduce_size = inp.shape[0]
for i in range(2, _ndim):
reduce_size = reduce_size * inp.shape[i]
channel_x1s = _sum_on_channel(inp)
channel_x2s = _sum_on_channel(inp ** 2)
weight = _make_full_if_none(weight, 1)
bias = _make_full_if_none(bias, 0)


if training:
if is_distributed():
# reduce all nodes' data to calculate mean and variance
reduce_size = broadcast_to(
Tensor(reduce_size).astype(dtype=_dtype), [1] * _ndim
)
stat = concat([reduce_size, channel_x1s, channel_x2s], axis=1)
(stat,) = apply(syncbn_concat_stats, reduce_size, channel_x1s, channel_x2s)
stat = all_reduce_sum(stat, group)
reduce_size = stat[:, :1].reshape(1)
channel_x1s = stat[:, 1 : 1 + _channels]
channel_x2s = stat[:, 1 + _channels :]
reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats, stat)


channel_mean = channel_x1s / reduce_size
channel_variance = (
channel_x1s ** 2 / (-reduce_size * reduce_size) + channel_x2s / reduce_size
outvar, channel_mean, *_ = apply(
syncbn_stage1, inp, reduce_size, channel_x1s, channel_x2s, eps, weight, bias
)
else:
assert running_var is not None and running_mean is not None
channel_variance = running_var.reshape(*_param_shape)
channel_mean = running_mean.reshape(*_param_shape)

invsqrt_channel_variance = (
maximum(channel_variance, eps) if eps_mode == "max" else channel_variance + eps
) ** -0.5

if weight is not None:
weight = weight.reshape(*_param_shape)
if bias is not None:
bias = bias.reshape(*_param_shape)
channel_mean = running_mean
channel_var = running_var
outvar, *_ = apply(
syncbn_stage1_inference, inp, channel_mean, channel_var, eps, weight, bias
)


# outvar = output * weight + bias
# where output = inp * invsqrt_channel_variance + (
@@ -1246,28 +1359,18 @@ def sync_batch_norm(
# )
# Manually expand output for gopt


if weight is not None:
inv_var_wt = invsqrt_channel_variance * weight
neg_channel_mean = -channel_mean
if bias is not None:
outvar = inp * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
else:
outvar = inp * inv_var_wt + neg_channel_mean * inv_var_wt
else:
outvar = inp * invsqrt_channel_variance + (
-channel_mean * invsqrt_channel_variance
)
if bias is not None:
outvar = outvar + bias

if training and running_var is not None and running_mean is not None:
running_mean *= momentum
running_mean += (1 - momentum) * channel_mean
channel_variance_unbiased = channel_x1s ** 2 / (
-reduce_size * (reduce_size - 1)
) + channel_x2s / (reduce_size - 1)
running_var *= momentum
running_var += (1 - momentum) * channel_variance_unbiased
momentum = convert_single_value(momentum, dtype=inp.dtype, device=inp.device)
running_mean[...], running_var[...] = apply(
syncbn_stage2,
running_mean,
running_var,
momentum,
reduce_size,
channel_x1s,
channel_x2s,
channel_mean,
)


return outvar
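
For context, a minimal single-process call of the rewritten function (a sketch; shapes and hyper-parameters are illustrative, and with is_distributed() false the all-reduce path is skipped):

import numpy as np
import megengine as mge
import megengine.functional as F

channels = 16
inp = mge.tensor(np.random.rand(8, channels, 14, 14).astype("float32"))
running_mean = mge.tensor(np.zeros((1, channels, 1, 1), dtype="float32"))
running_var = mge.tensor(np.ones((1, channels, 1, 1), dtype="float32"))
weight = mge.tensor(np.ones((channels,), dtype="float32"))
bias = mge.tensor(np.zeros((channels,), dtype="float32"))

out = F.nn.sync_batch_norm(
    inp, running_mean, running_var, weight, bias,
    training=True, momentum=0.9, eps=1e-5,
)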




imperative/python/megengine/jit/tracing.py (+1 -1)

@@ -66,7 +66,7 @@ def is_tracing():
@contextlib.contextmanager
def exclude_from_trace():
global skip_tracing
if skip_tracing:
if skip_tracing or (active_trace is None):
yield
return
try:
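
With the extra active_trace check, exclude_from_trace() degrades to a plain pass-through when nothing is being traced, so the new import in functional/nn.py can be used unconditionally. A hedged sketch:

from megengine.jit import exclude_from_trace

with exclude_from_trace():
    pass  # runs eagerly; with no active trace the context manager now simply yields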


imperative/python/src/common.cpp (+4 -0)

@@ -58,6 +58,9 @@ void init_common(py::module m) {
.def_property_readonly("logical_name", [](const CompNode& cn) { .def_property_readonly("logical_name", [](const CompNode& cn) {
return cn.to_string_logical(); return cn.to_string_logical();
}) })
.def_property_readonly("physical_name", [](const CompNode& cn) {
return cn.to_string();
})
.def_property_readonly("get_mem_status_bytes", [](const CompNode& cn) { .def_property_readonly("get_mem_status_bytes", [](const CompNode& cn) {
return cn.get_mem_status_bytes(); return cn.get_mem_status_bytes();
}) })
@@ -70,6 +73,7 @@ void init_common(py::module m) {
cn.to_string_physical().c_str(),
cn.to_string_logical().c_str());
})
.def("__hash__", [](CompNode cn){ return mgb::hash(cn); })
.def_static("_sync_all", &CompNode::sync_all) .def_static("_sync_all", &CompNode::sync_all)
.def(py::self == py::self) .def(py::self == py::self)
.def_static("_get_device_count", &CompNode::get_device_count, .def_static("_get_device_count", &CompNode::get_device_count,


imperative/python/src/ops.cpp (+47 -0)

@@ -15,6 +15,7 @@


#include "megbrain/common.h" #include "megbrain/common.h"
#include "megbrain/imperative.h" #include "megbrain/imperative.h"
#include "megbrain/imperative/graph_builder.h"
#include "megbrain/imperative/ops/backward_graph.h" #include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/ops/utility.h"
@@ -477,4 +478,50 @@ void init_ops(py::module m) {
m.def("set_global_rng_seed", &rng::set_global_rng_seed); m.def("set_global_rng_seed", &rng::set_global_rng_seed);
m.def("get_global_rng_seed", &rng::get_global_rng_seed); m.def("get_global_rng_seed", &rng::get_global_rng_seed);
m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode); m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode);

struct PySubgraphBuilder {
explicit PySubgraphBuilder(std::string name) : name{name}{}
std::string name;
Subgraph graph;
mgb::SmallVector<bool> output_grad_mask;
Subgraph::var_t next_var = 1;
};

py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
.def(py::init<std::string>())
.def("input", [](PySubgraphBuilder& self){
auto var = self.next_var++;
self.graph.inputs.push_back(var);
return var;
})
.def("apply", [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op, Subgraph::vars_t inputs, size_t nr_outputs){
Subgraph::vars_t outputs;
for (size_t i = 0; i < nr_outputs; ++i) {
outputs.push_back(self.next_var++);
}
self.graph.exprs.push_back({op, inputs, outputs});
return outputs;
})
.def("apply_const", [](PySubgraphBuilder& self, py::object value, mgb::DType dtype, mgb::CompNode cn){
auto var = self.next_var++;
mgb::HostTensorND hvalue(cn);
npy::np2tensor(value.cast<py::array>().ptr(), npy::Meth::copy_into(&hvalue), dtype);
self.graph.constants.push_back({var, Tensor::make(hvalue)});
return var;
})
.def("outputs", [](PySubgraphBuilder& self, Subgraph::vars_t outputs){
self.graph.outputs = outputs;
self.output_grad_mask.resize(outputs.size(), true);
})
.def("outputs_has_grad", [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad){
mgb_assert(self.graph.outputs.size() == outputs_has_grad.size());
self.output_grad_mask = outputs_has_grad;
})
.def("get", [](PySubgraphBuilder& self){
return (std::shared_ptr<OpDef>)SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
})
.def("compile", [](PySubgraphBuilder& self, int gopt_level){
auto op = SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
return (std::shared_ptr<OpDef>)CompiledOp::make(op, gopt_level);
});
}
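
For reference, the raw builder surface exposed above, driven directly from Python without the utils.subgraph decorator (a sketch; the op choice and values are illustrative):

import numpy as np
import megengine as mge
from megengine.core._imperative_rt.core2 import apply
from megengine.core._imperative_rt.ops import SubgraphBuilder
from megengine.core.ops import builtin

x = mge.tensor(np.ones(4, dtype="float32"))
builder = SubgraphBuilder("PlusOne")
v_in = builder.input()
v_one = builder.apply_const(1.0, np.dtype("float32"), x.device)
(v_out,) = builder.apply(builtin.Elemwise(mode="add"), [v_in, v_one], 1)
builder.outputs([v_out])
builder.outputs_has_grad([True])
op = builder.get()              # SubgraphOp; builder.compile(2) would wrap it in a CompiledOp
(y,) = apply(op, x)             # y == x + 1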
