@@ -222,45 +222,40 @@ def _normalize_axis(
     raise


+_opr_map = {
+    ("-", 1): builtin.Elemwise(mode="negate"),
+    ("fma3", 3): builtin.Elemwise(mode="FUSE_MUL_ADD3"),
+    ("fma4", 4): builtin.Elemwise(mode="FUSE_MUL_ADD4"),
+}
+
+for name, mode in [
+    ("+", "add"),
+    ("-", "sub"),
+    ("*", "mul"),
+    ("/", "true_div"),
+    ("//", "floor_div"),
+    ("**", "pow"),
+    ("max", "max"),
+    ("additive", "add"),
+]:
+    _opr_map[(name, 2)] = builtin.Elemwise(mode=mode)
+
+
 def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
     if device.physical_name.startswith("cpu"):
         gopt_level = None  # disable jit and compile
-    binary_ops = {
-        "+": lambda: builtin.Elemwise(mode="add"),
-        "-": lambda: builtin.Elemwise(mode="sub"),
-        "*": lambda: builtin.Elemwise(mode="mul"),
-        "/": lambda: builtin.Elemwise(mode="true_div"),
-        "//": lambda: builtin.Elemwise(mode="floor_div"),
-        "**": lambda: builtin.Elemwise(mode="pow"),
-        "√": lambda: builtin.Elemwise(mode="expm1"),
-        "max": lambda: builtin.Elemwise(mode="max"),
-        "additive": lambda: builtin.Elemwise(mode="add"),
-    }
-    unary_ops = {
-        "-": lambda: builtin.Elemwise(mode="negate"),
-    }
-    ternary_ops = {
-        "fma3": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD3"),
-    }
-    quaternary_ops = {"fma4": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD4")}
+
+    def as_op(op, nargs):
+        if isinstance(op, str):
+            assert (op, nargs) in _opr_map, "unknown operator"
+            op = _opr_map[(op, nargs)]
+        return op

     def decorator(func):
         builder = _SubgraphBuilder(name)

         def apply_expr(op, *args, nr_out=None):
-            if isinstance(op, str):
-                if len(args) == 2:
-                    op = binary_ops[op]()
-                elif len(args) == 1:
-                    op = unary_ops[op]()
-                elif len(args) == 3:
-                    op = ternary_ops[op]()
-                elif len(args) == 4:
-                    op = quaternary_ops[op]()
+            op = as_op(op, len(args))
             results = builder.apply(op, args, 1 if nr_out is None else nr_out)
             if nr_out is None:
                 assert len(results) == 1
@@ -282,3 +277,40 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         return lambda: builder.compile(gopt_level)

     return decorator
+
+
+def interpret_subgraph(func, dtype, device):
+    def as_op(op, nargs):
+        if isinstance(op, str) and (op, nargs) in _opr_map:
+            op = _opr_map[(op, nargs)]
+        return op
+
+    def decorated_func(*args):
+        def apply_expr(op, *args, nr_out=None):
+            op = as_op(op, len(args))
+            results = apply(op, *args)
+            if nr_out is None:
+                assert len(results) == 1
+                return results[0]
+            else:
+                assert len(results) == nr_out
+                return results
+
+        def apply_const(value, dtype=dtype, device=device):
+            return Const(value, dtype=dtype, device=device)()[0]
+
+        outputs, outputs_has_grad = func(args, apply_expr, apply_const)
+        return outputs
+
+    return decorated_func
+
+
+def subgraph_fn(name, dtype, device, nr_inputs, gopt_level=None, interpret=False):
+    def decorator(func):
+        if not interpret:
+            op = subgraph(name, dtype, device, nr_inputs, gopt_level=gopt_level)(func)
+            return lambda *args: apply(op(), *args)
+        else:
+            return interpret_subgraph(func, dtype, device)
+
+    return decorator
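A minimal usage sketch of the new `subgraph_fn` decorator (not part of the diff; the `double_shift` name and the input tensors are illustrative, and `device` is assumed to be a CompNode such as `CompNode(get_default_device())`). With `interpret=True` the subgraph is executed op by op, and the string shorthands resolve through `_opr_map`, e.g. "fma3" maps to `Elemwise(mode="FUSE_MUL_ADD3")`, i.e. `a * b + c`:

from megengine.core.tensor.utils import subgraph_fn

@subgraph_fn("DoubleShift", dtype="float32", device=device, nr_inputs=2, interpret=True)
def double_shift(inputs, f, c):
    # f is apply_expr, c is apply_const (see interpret_subgraph above)
    x, shift = inputs[0:2]
    out = f("fma3", x, c(2.0), shift)  # out = x * 2.0 + shift
    return (out,), (True,)

(out,) = double_shift(x, shift)  # x, shift: tensors on `device`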
@@ -0,0 +1,108 @@
+import functools
+
+import numpy as np
+import pytest
+
+import megengine
+from megengine.autodiff.grad_manager import GradManager
+from megengine.core.ops.builtin import GetVarShape, Reduce, TypeCvt
+from megengine.core.tensor.utils import subgraph_fn
+from megengine.device import CompNode, get_default_device
+from megengine.jit import trace
+
+_assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
+
+
+@functools.lru_cache(maxsize=None)
+def _get_batch_norm_fn(dtype, device, channels, ndim, interpret, gopt_level):
+    @subgraph_fn(
+        "BatchNormNd",
+        dtype=dtype,
+        device=device,
+        nr_inputs=4,
+        interpret=interpret,
+        gopt_level=gopt_level,
+    )
+    def batch_norm_nd(inputs, f, c):
+        input, eps, weight, bias = inputs[0:4]
+        reduce_shape = c(
+            (1, channels) + (1,) * (ndim - 2), dtype="int32", device=device
+        )
+        input_shape = f(GetVarShape(), input)
+        input_elems = f(Reduce(mode="product", axis=0), input_shape)
+        reduce_elems = f(Reduce(mode="product", axis=0), reduce_shape)
+        reduce_size = f("//", input_elems, reduce_elems)
+        reduce_size = f(TypeCvt(dtype=dtype), reduce_size)
+        channel_x1s = f(Reduce(mode="sum"), input, reduce_shape)
+        channel_x2s = f(Reduce(mode="sum_sqr"), input, reduce_shape)
+        channel_mean = f("/", channel_x1s, reduce_size)
+        channel_var = f(
+            "-", f("/", channel_x2s, reduce_size), f("*", channel_mean, channel_mean),
+        )
+        invsqrt_channel_var = f("**", f("+", channel_var, eps), c(-0.5))
+        inv_var_wt = f("*", invsqrt_channel_var, weight)
+        neg_channel_mean = f("-", channel_mean)
+        outvar = f(
+            "fma3", input, inv_var_wt, f("fma3", neg_channel_mean, inv_var_wt, bias),
+        )
+        return (outvar,), (True,)
+
+    return batch_norm_nd
+
+
+@pytest.mark.parametrize("device", [get_default_device(), "cpux"])
+@pytest.mark.parametrize("batch_size", [1, 8])
+@pytest.mark.parametrize("channels", [3])
+@pytest.mark.parametrize(
+    "use_trace, symbolic", [(False, None), (True, False), (True, True)]
+)
+@pytest.mark.parametrize("gopt_level", [None, 1, 2])
+@pytest.mark.parametrize("dtype", ["float32"])
+def test_subgraph(device, batch_size, channels, use_trace, symbolic, gopt_level, dtype):
+    device = CompNode(device)
+
+    def subgraph_batch_norm(inp, weight, bias, eps, diff):
+        inp = inp.detach()
+        with GradManager().attach(inp) as gm:
+            batch_norm_fn = _get_batch_norm_fn(
+                dtype, device, channels, ndim, interpret=False, gopt_level=gopt_level
+            )
+            out, *_ = batch_norm_fn(inp, eps, weight, bias)
+            gm.backward(out * 1e3 + 1e3, diff)
+            return out, inp.grad
+
+    def primitive_batch_norm(inp, weight, bias, eps, diff):
+        inp = inp.detach()
+        with GradManager().attach(inp) as gm:
+            batch_norm_fn = _get_batch_norm_fn(
+                dtype, device, channels, ndim, interpret=True, gopt_level=gopt_level
+            )
+            (out,) = batch_norm_fn(inp, eps, weight, bias)
+            gm.backward(out * 1e3 + 1e3, diff)
+            return out, inp.grad
+
+    if use_trace:
+        subgraph_batch_norm = trace(symbolic=symbolic)(subgraph_batch_norm)
+        primitive_batch_norm = trace(symbolic=symbolic)(primitive_batch_norm)
+
+    def rand_tensor(shape, dtype=dtype, device=device):
+        return megengine.tensor(np.random.random(shape), dtype=dtype, device=device)
+
+    # test shape change
+    for image_shape in [(223, 223), (10, 20)]:
+        ndim = len(image_shape) + 2
+        input_shape = (batch_size, channels) + image_shape
+        param_shape = (1, channels) + (1,) * len(image_shape)
+
+        inp = rand_tensor(input_shape) * 1e3 + 1e3
+        weight = rand_tensor(param_shape)
+        bias = rand_tensor(param_shape)
+        eps = megengine.tensor(1e-5, dtype=dtype, device=device)
+
+        diff = rand_tensor(input_shape)
+
+        out1, grad1 = subgraph_batch_norm(inp, weight, bias, eps, diff)
+        out2, grad2 = primitive_batch_norm(inp, weight, bias, eps, diff)
+
+        _assert_allclose(out1.numpy(), out2.numpy())
+        _assert_allclose(grad1.numpy(), grad2.numpy())
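For reference, the arithmetic that the `BatchNormNd` subgraph above encodes with Reduce/Elemwise primitives corresponds to this NumPy sketch (illustrative only, not part of the diff; the function name is hypothetical):

import numpy as np

def batch_norm_reference(x, eps, weight, bias):
    # reduce over every axis except the channel axis (axis 1)
    axes = (0,) + tuple(range(2, x.ndim))
    n = float(np.prod([x.shape[a] for a in axes]))                 # reduce_size
    mean = x.sum(axis=axes, keepdims=True) / n                     # channel_x1s / reduce_size
    var = (x ** 2).sum(axis=axes, keepdims=True) / n - mean ** 2   # channel_x2s / n - mean^2
    inv = (var + eps) ** -0.5 * weight                             # inv_var_wt
    return x * inv + (-mean * inv + bias)                          # fma3(x, inv, fma3(-mean, inv, bias))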
@@ -15,6 +15,7 @@ import pytest

 import megengine as mge
 import megengine.distributed as dist
 from megengine import Tensor
+from megengine.autodiff.grad_manager import GradManager
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
@@ -337,3 +338,33 @@ def test_syncbn2d_no_stats():
         yv_expect = (xv - mean) / sd

         _assert_allclose(yv.numpy(), yv_expect)
+
+
+def test_syncbn2d_grad():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 16, 16)
+    syncbn = SyncBatchNorm(8, track_running_stats=False)
+    bn = BatchNorm2d(8, track_running_stats=False)
+    for i in range(4):
+        if i == 2:
+            syncbn.training = False
+            bn.training = False
+        inp = Tensor(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
+        diff = Tensor(np.random.normal(size=data_shape).astype(np.float32))
+
+        with GradManager().attach(inp) as gm:
+            oup = syncbn(inp)
+            gm.backward(oup, diff)
+
+        grad = inp.grad
+        inp.grad = None
+
+        with GradManager().attach(inp) as gm:
+            oup_expect = bn(inp)
+            gm.backward(oup_expect, diff)
+
+        grad_expect = inp.grad
+        inp.grad = None
+
+        _assert_allclose(oup.numpy(), oup_expect.numpy())
+        _assert_allclose(grad.numpy(), grad_expect.numpy())