diff --git a/imperative/python/test/integration/test_dtr.py b/imperative/python/test/integration/test_dtr.py index 7e4b98a4..68b91594 100644 --- a/imperative/python/test/integration/test_dtr.py +++ b/imperative/python/test/integration/test_dtr.py @@ -1,3 +1,5 @@ +import multiprocessing as mp + import numpy as np import pytest @@ -88,8 +90,7 @@ class ResNet(M.Module): return out -@pytest.mark.require_ngpu(1) -def test_dtr_resnet1202(): +def run_dtr_resnet1202(): batch_size = 8 resnet1202 = ResNet(BasicBlock, [200, 200, 200]) opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4) @@ -124,3 +125,13 @@ def test_dtr_resnet1202(): t.numpy() mge.dtr.disable() + mge._exit(0) + + +@pytest.mark.require_ngpu(1) +@pytest.mark.isolated_distributed +def test_dtr_resnet1202(): + p = mp.Process(target=run_dtr_resnet1202) + p.start() + p.join() + assert p.exitcode == 0