|
@@ -1174,7 +1174,7 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels): |
|
|
reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size) |
|
|
reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size) |
|
|
return (reduce_shape, reduce_size_f, channel_x1s, channel_x2s), (False, False, True, True) |
|
|
return (reduce_shape, reduce_size_f, channel_x1s, channel_x2s), (False, False, True, True) |
|
|
|
|
|
|
|
|
@subgraph("SyncBnStage1", dtype, device, 7, gopt_level=3) |
|
|
|
|
|
|
|
|
@subgraph("SyncBnStage1", dtype, device, 7) |
|
|
def syncbn_stage1(inputs, f, c): |
|
|
def syncbn_stage1(inputs, f, c): |
|
|
input, reduce_size, channel_x1s, channel_x2s, eps = inputs[0:5] |
|
|
input, reduce_size, channel_x1s, channel_x2s, eps = inputs[0:5] |
|
|
weight, bias = inputs[5:7] |
|
|
weight, bias = inputs[5:7] |
|
@@ -1187,12 +1187,12 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels): |
|
|
inv_var_wt = f("*", invsqrt_channel_var, weight) |
|
|
inv_var_wt = f("*", invsqrt_channel_var, weight) |
|
|
neg_channel_mean = f("-", channel_mean) |
|
|
neg_channel_mean = f("-", channel_mean) |
|
|
outvar =\ |
|
|
outvar =\ |
|
|
f("+", f("*", input, inv_var_wt), |
|
|
|
|
|
|
|
|
f("fma3", input, inv_var_wt, |
|
|
f("+", f("*", neg_channel_mean, inv_var_wt), |
|
|
f("+", f("*", neg_channel_mean, inv_var_wt), |
|
|
bias)) |
|
|
bias)) |
|
|
return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False) |
|
|
return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False) |
|
|
|
|
|
|
|
|
@subgraph("SyncBnStage1Inference", dtype, device, 6, gopt_level=3) |
|
|
|
|
|
|
|
|
@subgraph("SyncBnStage1Inference", dtype, device, 6) |
|
|
def syncbn_stage1_inference(inputs, f, c): |
|
|
def syncbn_stage1_inference(inputs, f, c): |
|
|
input, channel_mean, channel_var, eps = inputs[0:4] |
|
|
input, channel_mean, channel_var, eps = inputs[0:4] |
|
|
weight, bias = inputs[4:6] |
|
|
weight, bias = inputs[4:6] |
|
@@ -1205,36 +1205,36 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels): |
|
|
bias)) |
|
|
bias)) |
|
|
return (outvar,), (True,) |
|
|
return (outvar,), (True,) |
|
|
|
|
|
|
|
|
@subgraph("SyncBnStage2", dtype, device, 7, gopt_level=3) |
|
|
|
|
|
|
|
|
@subgraph("SyncBnStage2", dtype, device, 7) |
|
|
def syncbn_stage2(inputs, f, c): |
|
|
def syncbn_stage2(inputs, f, c): |
|
|
running_mean, running_var, momentum = inputs[0:3] |
|
|
running_mean, running_var, momentum = inputs[0:3] |
|
|
reduce_size, channel_x1s, channel_x2s, channel_mean = inputs[3:7] |
|
|
reduce_size, channel_x1s, channel_x2s, channel_mean = inputs[3:7] |
|
|
running_mean = f("*", running_mean, momentum) |
|
|
|
|
|
running_mean =\ |
|
|
|
|
|
f("+", running_mean, |
|
|
|
|
|
f("*", f("-", c(1), momentum), |
|
|
|
|
|
channel_mean)) |
|
|
|
|
|
|
|
|
c1_minus_momentum = f("-", c(1), momentum) |
|
|
|
|
|
reduce_size_minus_c1 = f("-", reduce_size, c(1)) |
|
|
|
|
|
running_mean = f("fma4", |
|
|
|
|
|
running_mean, momentum, |
|
|
|
|
|
c1_minus_momentum, channel_mean, |
|
|
|
|
|
) |
|
|
channel_variance_unbiased =\ |
|
|
channel_variance_unbiased =\ |
|
|
f("+", f("/", f("**", channel_x1s, c(2)), |
|
|
f("+", f("/", f("**", channel_x1s, c(2)), |
|
|
f("*", f("-", reduce_size), |
|
|
f("*", f("-", reduce_size), |
|
|
f("-", reduce_size, c(1)))), |
|
|
|
|
|
|
|
|
reduce_size_minus_c1)), |
|
|
f("/", channel_x2s, |
|
|
f("/", channel_x2s, |
|
|
f("-", reduce_size, c(1)))) |
|
|
|
|
|
running_var = f("*", running_var, momentum) |
|
|
|
|
|
running_var =\ |
|
|
|
|
|
f("+", running_var, |
|
|
|
|
|
f("*", f("-", c(1), momentum), |
|
|
|
|
|
channel_variance_unbiased)) |
|
|
|
|
|
|
|
|
reduce_size_minus_c1)) |
|
|
|
|
|
running_var = f("fma4", |
|
|
|
|
|
running_var, momentum, |
|
|
|
|
|
c1_minus_momentum, channel_variance_unbiased |
|
|
|
|
|
) |
|
|
return (running_mean, running_var), (True, True) |
|
|
return (running_mean, running_var), (True, True) |
|
|
|
|
|
|
|
|
@subgraph("SyncBnConcatStats", dtype, device, 3, gopt_level=3) |
|
|
|
|
|
|
|
|
@subgraph("SyncBnConcatStats", dtype, device, 3) |
|
|
def syncbn_concat_stats(inputs, f, c): |
|
|
def syncbn_concat_stats(inputs, f, c): |
|
|
reduce_size, channel_x1s, channel_x2s = inputs[0:3] |
|
|
reduce_size, channel_x1s, channel_x2s = inputs[0:3] |
|
|
reduce_size = f(builtin.Broadcast(), reduce_size, c([1]*ndim, dtype="int32")) |
|
|
reduce_size = f(builtin.Broadcast(), reduce_size, c([1]*ndim, dtype="int32")) |
|
|
stats = f(builtin.Concat(axis=1, comp_node=device), reduce_size, channel_x1s, channel_x2s) |
|
|
stats = f(builtin.Concat(axis=1, comp_node=device), reduce_size, channel_x1s, channel_x2s) |
|
|
return (stats,), (True,) |
|
|
return (stats,), (True,) |
|
|
|
|
|
|
|
|
@subgraph("SyncBnSplitStats", dtype, device, 1, gopt_level=3) |
|
|
|
|
|
|
|
|
@subgraph("SyncBnSplitStats", dtype, device, 1) |
|
|
def syncbn_split_stats(inputs, f, c): |
|
|
def syncbn_split_stats(inputs, f, c): |
|
|
stats = inputs[0] |
|
|
stats = inputs[0] |
|
|
c_1 = c(1, dtype="int32") |
|
|
c_1 = c(1, dtype="int32") |
|
|