|
|
@@ -1,3 +1,5 @@ |
|
|
|
import multiprocessing as mp |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import pytest |
|
|
|
|
|
|
@@ -88,8 +90,7 @@ class ResNet(M.Module): |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.require_ngpu(1) |
|
|
|
def test_dtr_resnet1202(): |
|
|
|
def run_dtr_resnet1202(): |
|
|
|
batch_size = 8 |
|
|
|
resnet1202 = ResNet(BasicBlock, [200, 200, 200]) |
|
|
|
opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4) |
|
|
@@ -124,3 +125,13 @@ def test_dtr_resnet1202(): |
|
|
|
|
|
|
|
t.numpy() |
|
|
|
mge.dtr.disable() |
|
|
|
mge._exit(0) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.require_ngpu(1) |
|
|
|
@pytest.mark.isolated_distributed |
|
|
|
def test_dtr_resnet1202(): |
|
|
|
p = mp.Process(target=run_dtr_resnet1202) |
|
|
|
p.start() |
|
|
|
p.join() |
|
|
|
assert p.exitcode == 0 |