test(param_pack): more cases for param pack concat

GitOrigin-RevId: 0700b548ab

Branch: release-1.4
Author: Megvii Engine Team, 4 years ago
Parent commit: 0a1bd90f6b
1 changed file with 27 additions and 41 deletions
+27 -41  imperative/python/test/integration/test_param_pack.py

@@ -22,9 +22,11 @@ from megengine.optimizer import SGD
 
 
 class Simple(Module):
-    def __init__(self):
+    def __init__(self, param_shape):
         super().__init__()
-        self.params = [Parameter(1.0, dtype=np.float32) for i in range(10)]
+        self.params = [
+            Parameter(np.ones(param_shape), dtype=np.float32) for i in range(10)
+        ]
 
     def forward(self, x):
         for p in self.params:
@@ -34,51 +36,35 @@ class Simple(Module):
 
 @pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
-def test_param_pack():
-    data = np.ones([1], dtype="float32")
-
-    @dist.launcher
+@pytest.mark.parametrize(
+    "threshold", [0, 128, None], ids=["no_pack", "small_pack", "large_pack"]
+)
+@pytest.mark.parametrize("param_shape", [(16,), (128, 256), (2, 1024, 1024)])
+def test_param_pack(param_shape, threshold, n_iters=100):
+    data = np.ones(param_shape, dtype="float32")
+
+    @dist.launcher(n_gpus=2)
     def worker():
-        net = Simple()
-        opt = SGD(net.parameters(), lr=0.1)
-
-        gm = ad.GradManager().attach(
-            net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)]
-        )
-
-        opt.clear_grad()
-        with gm:
-            x = tensor(data)
-            loss = net(x)
-            loss = loss.sum()
-            gm.backward(loss)
-        for p in net.params:
-            np.testing.assert_equal(p.grad.numpy(), 1)
-
-    worker()
-
-
-@pytest.mark.require_ngpu(2)
-@pytest.mark.isolated_distributed
-def test_param_pack_with_no_param():
-    data = np.ones([1], dtype="float32")
-
-    @dist.launcher
-    def worker():
-        net = Simple()
+        net = Simple(param_shape)
         opt = SGD(net.parameters(), lr=0.1)
 
         allreduce_cb = dist.make_allreduce_cb("MEAN", dist.WORLD)
-        allreduce_cb._param_pack_thd = 0
+        if threshold is not None:
+            allreduce_cb._param_pack_thd = threshold
         gm = ad.GradManager().attach(net.parameters(), callbacks=[allreduce_cb])
 
-        opt.clear_grad()
-        with gm:
-            x = tensor(data)
-            loss = net(x)
-            loss = loss.sum()
-            gm.backward(loss)
+        def run():
+            opt.clear_grad()
+            with gm:
+                x = tensor(data)
+                loss = net(x)
+                loss = loss.sum()
+                gm.backward(loss)
+
+        for i in range(n_iters):
+            run()
+
         for p in net.params:
-            np.testing.assert_equal(p.grad.numpy(), 1)
+            np.testing.assert_equal(p.grad.numpy(), np.ones_like(p.grad.numpy()))
 
     worker()
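For context: the new parametrization drives a single knob, `_param_pack_thd`, the private threshold on the allreduce callback that controls how aggressively gradients are concatenated into packed buffers before the collective runs (per the test ids, 0 disables packing, 128 yields small packs, and the default, reached via `None`, yields large packs). Below is a minimal sketch, not part of this commit, of the same wiring in an ordinary data-parallel loop; `ToyNet`, the shapes, the iteration count, and the threshold value 128 are illustrative assumptions, and `_param_pack_thd` is an internal attribute that may change between releases.

    # Minimal sketch (not from the commit): pack-threshold allreduce in a
    # plain training loop. Assumes 2 GPUs and MegEngine; ToyNet, shapes,
    # and the threshold value are illustrative.
    import numpy as np

    import megengine.autodiff as ad
    import megengine.distributed as dist
    from megengine import Parameter, tensor
    from megengine.module import Module
    from megengine.optimizer import SGD


    class ToyNet(Module):
        def __init__(self, param_shape):
            super().__init__()
            # Many small parameters: the case where packing gradients into
            # one buffer reduces the number of allreduce launches.
            self.params = [
                Parameter(np.ones(param_shape), dtype=np.float32)
                for _ in range(10)
            ]

        def forward(self, x):
            for p in self.params:
                x = x + p
            return x


    @dist.launcher(n_gpus=2)
    def train():
        net = ToyNet((128, 256))
        opt = SGD(net.parameters(), lr=0.1)

        # Same callback the test configures: averages gradients across the
        # WORLD group; _param_pack_thd is the private packing threshold.
        allreduce_cb = dist.make_allreduce_cb("MEAN", dist.WORLD)
        allreduce_cb._param_pack_thd = 128  # assumed small value, as in the test

        gm = ad.GradManager().attach(net.parameters(), callbacks=[allreduce_cb])

        data = np.ones((128, 256), dtype="float32")
        for _ in range(10):
            opt.clear_grad()
            with gm:
                loss = net(tensor(data)).sum()
                gm.backward(loss)  # the callback allreduces grads during backward
            opt.step()


    if __name__ == "__main__":
        train()

Presumably the 100 iterations (`n_iters=100`) matter because repeated backward passes exercise reuse of the packed buffers, not just the first allocation; the commit does not state this explicitly.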
