Browse Source

test(imperative): speed up dtr test

GitOrigin-RevId: 57f092d729
release-1.7
Megvii Engine Team 3 years ago
parent
commit
89ed7ab2ff
1 changed files with 14 additions and 2 deletions
  1. +14
    -2
      imperative/python/test/integration/test_dtr.py

+ 14
- 2
imperative/python/test/integration/test_dtr.py View File

@@ -90,7 +90,7 @@ class ResNet(M.Module):

@pytest.mark.require_ngpu(1)
def test_dtr_resnet1202():
batch_size = 64
batch_size = 8
resnet1202 = ResNet(BasicBlock, [200, 200, 200])
opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
gm = GradManager().attach(resnet1202.parameters())
@@ -103,12 +103,24 @@ def test_dtr_resnet1202():
gm.backward(loss)
return pred, loss

_, free_mem = mge.device.get_mem_status_bytes()
tensor_mem = free_mem - (2 ** 30)
if tensor_mem > 0:
x = np.ones((1, int(tensor_mem / 4)), dtype=np.float32)
else:
x = np.ones((1,), dtype=np.float32)
t = mge.tensor(x)

mge.dtr.enable()
mge.dtr.enable_sqrt_sampling = True

data = np.random.randn(batch_size, 3, 32, 32).astype("float32")
label = np.random.randint(0, 10, size=(batch_size,)).astype("int32")
for step in range(10):
for _ in range(2):
opt.clear_grad()
_, loss = train_func(mge.tensor(data), mge.tensor(label), net=resnet1202, gm=gm)
opt.step()
loss.item()

t.numpy()
mge.dtr.disable()

Loading…
Cancel
Save