diff --git a/imperative/python/test/integration/test_dtr.py b/imperative/python/test/integration/test_dtr.py
index 25ad054f..7e4b98a4 100644
--- a/imperative/python/test/integration/test_dtr.py
+++ b/imperative/python/test/integration/test_dtr.py
@@ -90,7 +90,7 @@ class ResNet(M.Module):
 
 @pytest.mark.require_ngpu(1)
 def test_dtr_resnet1202():
-    batch_size = 64
+    batch_size = 8
     resnet1202 = ResNet(BasicBlock, [200, 200, 200])
     opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
     gm = GradManager().attach(resnet1202.parameters())
@@ -103,12 +103,24 @@ def test_dtr_resnet1202():
         gm.backward(loss)
         return pred, loss
 
+    _, free_mem = mge.device.get_mem_status_bytes()
+    tensor_mem = free_mem - (2 ** 30)
+    if tensor_mem > 0:
+        x = np.ones((1, int(tensor_mem / 4)), dtype=np.float32)
+    else:
+        x = np.ones((1,), dtype=np.float32)
+    t = mge.tensor(x)
+
     mge.dtr.enable()
+    mge.dtr.enable_sqrt_sampling = True
     data = np.random.randn(batch_size, 3, 32, 32).astype("float32")
     label = np.random.randint(0, 10, size=(batch_size,)).astype("int32")
-    for step in range(10):
+    for _ in range(2):
         opt.clear_grad()
         _, loss = train_func(mge.tensor(data), mge.tensor(label), net=resnet1202, gm=gm)
         opt.step()
         loss.item()
+
+    t.numpy()
+    mge.dtr.disable()
 