GitOrigin-RevId: efc6377197
@@ -784,10 +784,10 @@ def sync_batch_norm(
     if is_distributed():
         # reduce all nodes' data to calculate mean and variance
-        reduce_size = broadcast_to(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
-        stat = concat(
-            [reduce_size.astype(_dtype), channel_x1s, channel_x2s], axis=1
-        )
+        reduce_size = broadcast_to(
+            Tensor(reduce_size).astype(dtype=_dtype), [1] * _ndim
+        )
+        stat = concat([reduce_size, channel_x1s, channel_x2s], axis=1)
         stat = all_reduce_sum(stat, group)
         reduce_size = stat[:, :1].reshape(1)
         channel_x1s = stat[:, 1 : 1 + _channels]
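
For context on the hunk above: each worker packs [count, per-channel sum, per-channel sum of squares] into a single tensor so that one `all_reduce_sum` carries all three statistics, from which the global mean and variance follow. A minimal NumPy sketch of that reduction (the plain `np.sum` over workers stands in for the collective; names are illustrative, not MegEngine API):

    import numpy as np

    def global_mean_var(per_worker_batches):
        # Pack [count, per-channel sum, per-channel sum of squares] per worker
        # so a single sum-reduction carries all statistics at once.
        stats = []
        for x in per_worker_batches:  # each x has shape (N, C, H, W)
            count = x.shape[0] * x.shape[2] * x.shape[3]
            x1s = x.sum(axis=(0, 2, 3))        # per-channel sum
            x2s = (x * x).sum(axis=(0, 2, 3))  # per-channel sum of squares
            stats.append(np.concatenate([[count], x1s, x2s]))
        total = np.sum(stats, axis=0)  # stand-in for all_reduce_sum
        channels = (total.size - 1) // 2
        count = total[0]
        mean = total[1 : 1 + channels] / count
        var = total[1 + channels :] / count - mean ** 2
        return mean, var

    # Two "workers" with different data agree with single-process statistics.
    a, b = np.random.rand(2, 3, 4, 4), np.random.rand(2, 3, 4, 4)
    mean, var = global_mean_var([a, b])
    both = np.concatenate([a, b])
    assert np.allclose(mean, both.mean(axis=(0, 2, 3)))
    assert np.allclose(var, both.var(axis=(0, 2, 3)))
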
@@ -18,6 +18,7 @@ from .core._wrap import device as as_device
 from .core.ops.builtin import Copy, GetVarShape
 from .core.tensor.array_method import ArrayMethodMixin
 from .device import _valid_device, get_default_device
+from .logger import get_logger
 from .utils.deprecation import deprecated
@@ -41,6 +42,10 @@ class Tensor(_Tensor, ArrayMethodMixin):
         cn = device._cn
         if isinstance(data, _Tensor):
+            if dtype is not None:
+                get_logger().warning(
+                    "dtype does not work when creating a new Tensor with another Tensor"
+                )
             obj = _Tensor.__new__(cls, data)
         else:
             if isinstance(data, np.ndarray):
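
The tensor.py hunks above make `Tensor(other, dtype=...)` warn instead of silently ignoring the requested dtype. A hedged usage sketch (assuming the public `megengine` package; the warning text is the one added above):

    import megengine as mge

    a = mge.Tensor([1, 2, 3], dtype="int32")
    b = mge.Tensor(a, dtype="float32")  # logs the warning; dtype is ignored
    c = a.astype("float32")             # an explicit cast is the supported path
    print(b.dtype, c.dtype)             # int32 float32
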
@@ -17,7 +17,7 @@ import megengine.optimizer as optimizer
 from megengine import Parameter, tensor
 from megengine.distributed.helper import get_device_count_by_fork
 from megengine.jit import trace
-from megengine.module import BatchNorm2d, Module, SyncBatchNorm
+from megengine.module import BatchNorm2d, Conv2d, Module, Sequential, SyncBatchNorm
 
 
 def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
@@ -68,7 +68,7 @@ def test_frozen_bn():
     run_frozen_bn(BatchNorm2d, True, True)
 
 
-@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
+@pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
 def test_frozen_synced_bn():
     @dist.launcher(n_gpus=2)
@@ -151,6 +151,45 @@ def test_trace_bn_forward_twice():
     np.testing.assert_equal(y.numpy(), 0)
 
 
+def run_syncbn(trace_mode):
+    x = F.ones([2, 16, 4, 4], dtype="float32")
+
+    net = Sequential(
+        Conv2d(16, 16, 1), SyncBatchNorm(16), Conv2d(16, 16, 1), SyncBatchNorm(16),
+    )
+
+    gm = ad.GradManager().attach(
+        net.parameters(), callbacks=dist.make_allreduce_cb("MEAN")
+    )
+    opt = optimizer.SGD(net.parameters(), 1e-3)
+
+    def train_func(x):
+        with gm:
+            y = net(x)
+            loss = y.mean()
+            gm.backward(loss)
+        opt.step().clear_grad()
+        return loss
+
+    if trace_mode is not None:
+        train_func = trace(train_func, symbolic=trace_mode)
+
+    for _ in range(3):
+        loss = train_func(x)
+        loss.numpy()
+
+
+@pytest.mark.require_ngpu(2)
+@pytest.mark.isolated_distributed
+@pytest.mark.parametrize("trace_mode", [None, True, False])
+def test_trace_several_syncbn(trace_mode):
+    @dist.launcher(n_gpus=2)
+    def worker():
+        run_syncbn(trace_mode)
+
+    worker()
+
+
 # https://github.com/MegEngine/MegEngine/issues/145
 def test_frozen_bn_no_affine():
     nchannel = 3
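
The tests above replace the explicit `skipif` on `get_device_count_by_fork` with a `require_ngpu` marker. The marker's registration is not part of this diff; a hypothetical `conftest.py` sketch of one way such a marker can be implemented:

    # conftest.py -- hypothetical sketch, not the project's actual conftest
    import pytest

    from megengine.distributed.helper import get_device_count_by_fork

    def pytest_collection_modifyitems(config, items):
        for item in items:
            marker = item.get_closest_marker("require_ngpu")
            if marker is None:
                continue
            ngpu = marker.args[0]
            if get_device_count_by_fork("gpu") < ngpu:
                item.add_marker(
                    pytest.mark.skip(reason="need at least %d gpu devices" % ngpu)
                )
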
@@ -226,8 +226,14 @@ void DelayBroadcastPass::apply(OptState& opt) const {
             if (!prev)
                 prev = rewriter.get_var(opr->input(inp_idx));
             if (!opr->same_type<opr::Broadcast>()) {
-                VarNodeArray new_inp = opr->input();
-                new_inp.at(inp_idx) = prev;
+                VarNodeArray new_inp(opr->input().size());
+                for (size_t i = 0; i < opr->input().size(); i++) {
+                    if (i == inp_idx) {
+                        new_inp[i] = prev;
+                    } else {
+                        new_inp[i] = rewriter.get_var(opr->input(i));
+                    }
+                }
                 opt.call_with_opr(opr, [&] {
                     // create new opr with the original opr's properties
                     auto new_opr = serialization::copy_opr_shallow(
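
The fix above addresses a rewriting pitfall: when an operator is rebuilt, every input has to be looked up through the rewriter, not only the slot being replaced, or sibling inputs keep pointing at nodes that have already been rewritten. A small Python sketch of the corrected pattern (the toy `Rewriter` is hypothetical and only mirrors the `rewriter.get_var` calls above):

    class Rewriter:
        # Toy stand-in: maps replaced nodes to their new versions,
        # defaulting to identity for untouched nodes.
        def __init__(self):
            self.replaced = {}

        def get_var(self, var):
            return self.replaced.get(var, var)

    def rebuild_inputs(rewriter, inputs, inp_idx, prev):
        # Corrected pattern: route *every* input through the rewriter,
        # substituting `prev` only at the delayed-broadcast slot.
        return [prev if i == inp_idx else rewriter.get_var(v)
                for i, v in enumerate(inputs)]

    rw = Rewriter()
    rw.replaced["b_old"] = "b_new"
    # The buggy version copied inputs verbatim and would have kept "b_old".
    assert rebuild_inputs(rw, ["a", "b_old", "c"], 0, "prev") == ["prev", "b_new", "c"]
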
@@ -177,6 +177,32 @@ TEST_PASS(DelayBroadcastPass, LongChain) {
     ASSERT_EQ(bcast(bcast(relu(relu(x)), y), z), out);
 }
 
+TEST_PASS(DelayBroadcastPass, ElemwiseChain) {
+    auto typecvt = [](SymbolVar x) {
+        return opr::TypeCvt::make(x, dtype::Int32());
+    };
+
+    auto reduce = [](SymbolVar x) {
+        SymbolVar tshp = x.make_scalar(1);
+        opr::Reduce::Param param_default{opr::Reduce::Mode::SUM, INT_MAX,
+                                         opr::Reduce::Param::DataType::DEFAULT};
+        return opr::Reduce::make(x, param_default, tshp);
+    };
+
+    auto shp = TensorShape{2, 2};
+    auto x = mkvar("x", {1, 1});
+    auto val = x.make_scalar(3);
+
+    auto out = reduce(typecvt(x.broadcast(shp))) + val.broadcast(shp);
+    out = gopt::GraphOptimizer{}.
+            add_pass<gopt::DelayBroadcastPass>().
+            apply({{out}}).endpoint_vars()[0];
+
+    auto expected = (reduce(typecvt(x).broadcast(shp)) + val).broadcast(shp);
+    ASSERT_EQ(out, expected);
+}
+
 TEST_PASS(ExpandVirtualGradPass, Simple) {
     auto x = mkvar("x");
     check(x * 2,