
test_dtr.py 4.2 kB

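# Integration test for MegEngine's DTR (Dynamic Tensor Rematerialization):
# train a very deep CIFAR-style ResNet on random data under artificial GPU
# memory pressure and verify that training still completes with DTR enabled.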
import multiprocessing as mp

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
import megengine.tensor as tensor
from megengine.autodiff import GradManager
from megengine.data import DataLoader, RandomSampler, transform
from megengine.data.dataset import CIFAR10

def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, M.Linear) or isinstance(m, M.Conv2d):
        M.init.msra_normal_(m.weight)


# Per-channel CIFAR-10 mean/std in the 0-255 pixel scale (not used in this test).
mean = [125.3, 123.0, 113.9]
std = [63.0, 62.1, 66.7]

class BasicBlock(M.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = M.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = M.BatchNorm2d(planes)
        self.conv2 = M.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = M.BatchNorm2d(planes)
        self.shortcut = M.Sequential()
        if stride != 1 or in_planes != planes:
            # Projection shortcut when the spatial size or channel count changes.
            self.shortcut = M.Sequential(
                M.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                M.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

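# CIFAR-style ResNet (He et al.): a 3x3 stem conv, three stages of BasicBlocks
# at 16/32/64 channels, global average pooling, then a linear classifier.
# Depth is 6n + 2 for num_blocks = [n, n, n], so [200, 200, 200] below
# yields ResNet-1202.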
class ResNet(M.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16
        self.conv1 = M.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = M.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = M.Linear(64, num_classes)
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return M.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.mean(3).mean(2)  # global average pooling over H and W
        out = self.linear(out)
        return out

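# Squeeze free GPU memory down to roughly 1 GiB by pinning a large float32
# tensor, then run two training iterations of ResNet-1202 with DTR enabled.
# Without rematerialization, the activations of a 1202-layer network would
# typically not fit in the remaining memory.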
def run_dtr_resnet1202():
    batch_size = 8
    resnet1202 = ResNet(BasicBlock, [200, 200, 200])
    opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
    gm = GradManager().attach(resnet1202.parameters())

    def train_func(data, label, *, net, gm):
        net.train()
        with gm:
            pred = net(data)
            loss = F.loss.cross_entropy(pred, label)
            gm.backward(loss)
        return pred, loss

    # Pin a tensor that occupies all but ~1 GiB of the currently free device
    # memory, forcing DTR to work within a tight budget.
    _, free_mem = mge.device.get_mem_status_bytes()
    tensor_mem = free_mem - (2 ** 30)
    if tensor_mem > 0:
        x = np.ones((1, int(tensor_mem / 4)), dtype=np.float32)  # float32 = 4 bytes
    else:
        x = np.ones((1,), dtype=np.float32)
    t = mge.tensor(x)

    mge.dtr.enable()
    mge.dtr.enable_sqrt_sampling = True

    data = np.random.randn(batch_size, 3, 32, 32).astype("float32")
    label = np.random.randint(0, 10, size=(batch_size,)).astype("int32")

    for _ in range(2):
        opt.clear_grad()
        _, loss = train_func(mge.tensor(data), mge.tensor(label), net=resnet1202, gm=gm)
        opt.step()
        loss.item()  # materialize the loss so the iteration actually executes

    t.numpy()  # the pinned tensor must still be readable after training
    mge.dtr.disable()
    mge._exit(0)

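# Run the workload in a child process: mge._exit(0) terminates the interpreter
# it runs in, and isolating the run keeps DTR's global state away from other
# tests. The parent only checks the exit code.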
@pytest.mark.require_ngpu(1)
@pytest.mark.isolated_distributed
def test_dtr_resnet1202():
    p = mp.Process(target=run_dtr_resnet1202)
    p.start()
    p.join()
    assert p.exitcode == 0