|
|
@@ -90,7 +90,7 @@ class ResNet(M.Module): |
|
|
|
|
|
|
|
@pytest.mark.require_ngpu(1) |
|
|
|
def test_dtr_resnet1202(): |
|
|
|
batch_size = 64 |
|
|
|
batch_size = 8 |
|
|
|
resnet1202 = ResNet(BasicBlock, [200, 200, 200]) |
|
|
|
opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4) |
|
|
|
gm = GradManager().attach(resnet1202.parameters()) |
|
|
@@ -103,12 +103,24 @@ def test_dtr_resnet1202(): |
|
|
|
gm.backward(loss) |
|
|
|
return pred, loss |
|
|
|
|
|
|
|
_, free_mem = mge.device.get_mem_status_bytes() |
|
|
|
tensor_mem = free_mem - (2 ** 30) |
|
|
|
if tensor_mem > 0: |
|
|
|
x = np.ones((1, int(tensor_mem / 4)), dtype=np.float32) |
|
|
|
else: |
|
|
|
x = np.ones((1,), dtype=np.float32) |
|
|
|
t = mge.tensor(x) |
|
|
|
|
|
|
|
mge.dtr.enable() |
|
|
|
mge.dtr.enable_sqrt_sampling = True |
|
|
|
|
|
|
|
data = np.random.randn(batch_size, 3, 32, 32).astype("float32") |
|
|
|
label = np.random.randint(0, 10, size=(batch_size,)).astype("int32") |
|
|
|
for step in range(10): |
|
|
|
for _ in range(2): |
|
|
|
opt.clear_grad() |
|
|
|
_, loss = train_func(mge.tensor(data), mge.tensor(label), net=resnet1202, gm=gm) |
|
|
|
opt.step() |
|
|
|
loss.item() |
|
|
|
|
|
|
|
t.numpy() |
|
|
|
mge.dtr.disable() |