@@ -222,45 +222,40 @@ def _normalize_axis(
     raise


+_opr_map = {
+    ("-", 1): builtin.Elemwise(mode="negate"),
+    ("fma3", 3): builtin.Elemwise(mode="FUSE_MUL_ADD3"),
+    ("fma4", 4): builtin.Elemwise(mode="FUSE_MUL_ADD4"),
+}
+
+for name, mode in [
+    ("+", "add"),
+    ("-", "sub"),
+    ("*", "mul"),
+    ("/", "true_div"),
+    ("//", "floor_div"),
+    ("**", "pow"),
+    ("max", "max"),
+    ("additive", "add"),
+]:
+    _opr_map[(name, 2)] = builtin.Elemwise(mode=mode)
+
+
 def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
     if device.physical_name.startswith("cpu"):
         gopt_level = None  # disable jit and compile

-    binary_ops = {
-        "+": lambda: builtin.Elemwise(mode="add"),
-        "-": lambda: builtin.Elemwise(mode="sub"),
-        "*": lambda: builtin.Elemwise(mode="mul"),
-        "/": lambda: builtin.Elemwise(mode="true_div"),
-        "//": lambda: builtin.Elemwise(mode="floor_div"),
-        "**": lambda: builtin.Elemwise(mode="pow"),
-        "√": lambda: builtin.Elemwise(mode="expm1"),
-        "max": lambda: builtin.Elemwise(mode="max"),
-        "additive": lambda: builtin.Elemwise(mode="add"),
-    }
-
-    unary_ops = {
-        "-": lambda: builtin.Elemwise(mode="negate"),
-    }
-
-    ternary_ops = {
-        "fma3": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD3"),
-    }
-
-    quaternary_ops = {"fma4": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD4")}
+    def as_op(op, nargs):
+        if isinstance(op, str):
+            assert (op, nargs) in _opr_map, "unknown operator"
+            op = _opr_map[(op, nargs)]
+        return op

     def decorator(func):
         builder = _SubgraphBuilder(name)

         def apply_expr(op, *args, nr_out=None):
-            if isinstance(op, str):
-                if len(args) == 2:
-                    op = binary_ops[op]()
-                elif len(args) == 1:
-                    op = unary_ops[op]()
-                elif len(args) == 3:
-                    op = ternary_ops[op]()
-                elif len(args) == 4:
-                    op = quaternary_ops[op]()
+            op = as_op(op, len(args))
             results = builder.apply(op, args, 1 if nr_out is None else nr_out)
             if nr_out is None:
                 assert len(results) == 1
@@ -282,3 +277,40 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         return lambda: builder.compile(gopt_level)

     return decorator
+
+
+def interpret_subgraph(func, dtype, device):
+    def as_op(op, nargs):
+        if isinstance(op, str) and (op, nargs) in _opr_map:
+            op = _opr_map[(op, nargs)]
+        return op
+
+    def decorated_func(*args):
+        def apply_expr(op, *args, nr_out=None):
+            op = as_op(op, len(args))
+            results = apply(op, *args)
+            if nr_out is None:
+                assert len(results) == 1
+                return results[0]
+            else:
+                assert len(results) == nr_out
+                return results
+
+        def apply_const(value, dtype=dtype, device=device):
+            return Const(value, dtype=dtype, device=device)()[0]
+
+        outputs, outputs_has_grad = func(args, apply_expr, apply_const)
+        return outputs
+
+    return decorated_func
+
+
+def subgraph_fn(name, dtype, device, nr_inputs, gopt_level=None, interpret=False):
+    def decorator(func):
+        if not interpret:
+            op = subgraph(name, dtype, device, nr_inputs, gopt_level=gopt_level)(func)
+            return lambda *args: apply(op(), *args)
+        else:
+            return interpret_subgraph(func, dtype, device)
+
+    return decorator
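Below is a minimal usage sketch of the `subgraph_fn` decorator added above, modeled on the test that follows; the name `scale_shift` and the sample tensors are illustrative only, not part of this diff.

import megengine
from megengine.core.tensor.utils import subgraph_fn
from megengine.device import CompNode, get_default_device

device = CompNode(get_default_device())

@subgraph_fn("ScaleShift", dtype="float32", device=device, nr_inputs=3, interpret=False)
def scale_shift(inputs, f, c):
    # f applies an op (a builtin op or a string alias from _opr_map) to subgraph vars,
    # c embeds a constant; the function returns (outputs, outputs_has_grad).
    x, scale, shift = inputs[0:3]
    y = f("fma3", x, scale, shift)  # y = x * scale + shift, as one fused elemwise
    return (y,), (True,)

(y,) = scale_shift(
    megengine.tensor([1.0, 2.0]),
    megengine.tensor([3.0, 3.0]),
    megengine.tensor([0.5, 0.5]),
)

With `interpret=True`, the same decorated function is evaluated eagerly op by op via `interpret_subgraph` instead of being compiled into a single subgraph op, which is exactly what the test below compares against.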
@@ -0,0 +1,108 @@
+import functools
+
+import numpy as np
+import pytest
+
+import megengine
+from megengine.autodiff.grad_manager import GradManager
+from megengine.core.ops.builtin import GetVarShape, Reduce, TypeCvt
+from megengine.core.tensor.utils import subgraph_fn
+from megengine.device import CompNode, get_default_device
+from megengine.jit import trace
+
+_assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
+
+
+@functools.lru_cache(maxsize=None)
+def _get_batch_norm_fn(dtype, device, channels, ndim, interpret, gopt_level):
+    @subgraph_fn(
+        "BatchNormNd",
+        dtype=dtype,
+        device=device,
+        nr_inputs=4,
+        interpret=interpret,
+        gopt_level=gopt_level,
+    )
+    def batch_norm_nd(inputs, f, c):
+        input, eps, weight, bias = inputs[0:4]
+        reduce_shape = c(
+            (1, channels) + (1,) * (ndim - 2), dtype="int32", device=device
+        )
+        input_shape = f(GetVarShape(), input)
+        input_elems = f(Reduce(mode="product", axis=0), input_shape)
+        reduce_elems = f(Reduce(mode="product", axis=0), reduce_shape)
+        reduce_size = f("//", input_elems, reduce_elems)
+        reduce_size = f(TypeCvt(dtype=dtype), reduce_size)
+        channel_x1s = f(Reduce(mode="sum"), input, reduce_shape)
+        channel_x2s = f(Reduce(mode="sum_sqr"), input, reduce_shape)
+        channel_mean = f("/", channel_x1s, reduce_size)
+        channel_var = f(
+            "-", f("/", channel_x2s, reduce_size), f("*", channel_mean, channel_mean),
+        )
+        invsqrt_channel_var = f("**", f("+", channel_var, eps), c(-0.5))
+        inv_var_wt = f("*", invsqrt_channel_var, weight)
+        neg_channel_mean = f("-", channel_mean)
+        outvar = f(
+            "fma3", input, inv_var_wt, f("fma3", neg_channel_mean, inv_var_wt, bias),
+        )
+        return (outvar,), (True,)
+
+    return batch_norm_nd
+
+
+@pytest.mark.parametrize("device", [get_default_device(), "cpux"])
+@pytest.mark.parametrize("batch_size", [1, 8])
+@pytest.mark.parametrize("channels", [3])
+@pytest.mark.parametrize(
+    "use_trace, symbolic", [(False, None), (True, False), (True, True)]
+)
+@pytest.mark.parametrize("gopt_level", [None, 1, 2])
+@pytest.mark.parametrize("dtype", ["float32"])
+def test_subgraph(device, batch_size, channels, use_trace, symbolic, gopt_level, dtype):
+    device = CompNode(device)
+
+    def subgraph_batch_norm(inp, weight, bias, eps, diff):
+        inp = inp.detach()
+        with GradManager().attach(inp) as gm:
+            batch_norm_fn = _get_batch_norm_fn(
+                dtype, device, channels, ndim, interpret=False, gopt_level=gopt_level
+            )
+            out, *_ = batch_norm_fn(inp, eps, weight, bias)
+            gm.backward(out * 1e3 + 1e3, diff)
+        return out, inp.grad
+
+    def primitive_batch_norm(inp, weight, bias, eps, diff):
+        inp = inp.detach()
+        with GradManager().attach(inp) as gm:
+            batch_norm_fn = _get_batch_norm_fn(
+                dtype, device, channels, ndim, interpret=True, gopt_level=gopt_level
+            )
+            (out,) = batch_norm_fn(inp, eps, weight, bias)
+            gm.backward(out * 1e3 + 1e3, diff)
+        return out, inp.grad
+
+    if use_trace:
+        subgraph_batch_norm = trace(symbolic=symbolic)(subgraph_batch_norm)
+        primitive_batch_norm = trace(symbolic=symbolic)(primitive_batch_norm)
+
+    def rand_tensor(shape, dtype=dtype, device=device):
+        return megengine.tensor(np.random.random(shape), dtype=dtype, device=device)
+
+    # test shape change
+    for image_shape in [(223, 223), (10, 20)]:
+        ndim = len(image_shape) + 2
+        input_shape = (batch_size, channels) + image_shape
+        param_shape = (1, channels) + (1,) * len(image_shape)
+
+        inp = rand_tensor(input_shape) * 1e3 + 1e3
+        weight = rand_tensor(param_shape)
+        bias = rand_tensor(param_shape)
+        eps = megengine.tensor(1e-5, dtype=dtype, device=device)
+        diff = rand_tensor(input_shape)
+
+        out1, grad1 = subgraph_batch_norm(inp, weight, bias, eps, diff)
+        out2, grad2 = primitive_batch_norm(inp, weight, bias, eps, diff)
+
+        _assert_allclose(out1.numpy(), out2.numpy())
+        _assert_allclose(grad1.numpy(), grad2.numpy())
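For reference, the elemwise/reduce graph built by `batch_norm_nd` above computes standard batch normalization. A NumPy sketch of the same arithmetic (a hypothetical helper, not part of this diff) makes the nested `fma3` calls explicit: `fma3(x, w, fma3(-mean, w, b))` expands to `(x - mean) * w + b`.

import numpy as np

def batch_norm_nd_reference(x, eps, weight, bias):
    # reduce over every axis except the channel axis (axis 1), matching reduce_shape
    axes = (0,) + tuple(range(2, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)                     # channel_mean
    var = (x ** 2).mean(axis=axes, keepdims=True) - mean ** 2   # channel_var
    inv_var_wt = weight * (var + eps) ** -0.5                   # invsqrt_channel_var * weight
    # fma3(x, inv_var_wt, fma3(-mean, inv_var_wt, bias)) == (x - mean) * inv_var_wt + bias
    return x * inv_var_wt + (-mean * inv_var_wt + bias)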
@@ -15,6 +15,7 @@ import pytest
 import megengine as mge
 import megengine.distributed as dist
 from megengine import Tensor
+from megengine.autodiff.grad_manager import GradManager
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm

@@ -337,3 +338,33 @@ def test_syncbn2d_no_stats():
         yv_expect = (xv - mean) / sd

         _assert_allclose(yv.numpy(), yv_expect)
+
+
+def test_syncbn2d_grad():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 16, 16)
+    syncbn = SyncBatchNorm(8, track_running_stats=False)
+    bn = BatchNorm2d(8, track_running_stats=False)
+    for i in range(4):
+        if i == 2:
+            syncbn.training = False
+            bn.training = False
+        inp = Tensor(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
+        diff = Tensor(np.random.normal(size=data_shape).astype(np.float32))
+
+        with GradManager().attach(inp) as gm:
+            oup = syncbn(inp)
+            gm.backward(oup, diff)
+
+        grad = inp.grad
+        inp.grad = None
+
+        with GradManager().attach(inp) as gm:
+            oup_expect = bn(inp)
+            gm.backward(oup_expect, diff)
+
+        grad_expect = inp.grad
+        inp.grad = None
+
+        _assert_allclose(oup.numpy(), oup_expect.numpy())
+        _assert_allclose(grad.numpy(), grad_expect.numpy())