
test_dtr.py 4.0 kB

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
import megengine.tensor as tensor
from megengine.autodiff import GradManager
from megengine.data import DataLoader, RandomSampler, transform
from megengine.data.dataset import CIFAR10


def _weights_init(m):
    # MSRA (He) initialization for every Linear and Conv2d layer.
    classname = m.__class__.__name__
    if isinstance(m, M.Linear) or isinstance(m, M.Conv2d):
        M.init.msra_normal_(m.weight)


# Per-channel CIFAR10 pixel statistics (not used by the test itself).
mean = [125.3, 123.0, 113.9]
std = [63.0, 62.1, 66.7]


class BasicBlock(M.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = M.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = M.BatchNorm2d(planes)
        self.conv2 = M.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = M.BatchNorm2d(planes)
        # Projection shortcut when the spatial size or channel count changes.
        self.shortcut = M.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = M.Sequential(
                M.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                M.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(M.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16
        self.conv1 = M.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = M.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = M.Linear(64, num_classes)
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block in a stage may downsample; the rest use stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return M.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pooling over the spatial dimensions.
        out = out.mean(3).mean(2)
        out = self.linear(out)
        return out


@pytest.mark.require_ngpu(1)
def test_dtr_resnet1202():
    batch_size = 8
    resnet1202 = ResNet(BasicBlock, [200, 200, 200])
    opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
    gm = GradManager().attach(resnet1202.parameters())

    def train_func(data, label, *, net, gm):
        net.train()
        with gm:
            pred = net(data)
            loss = F.loss.cross_entropy(pred, label)
            gm.backward(loss)
        return pred, loss

    # Pin down all free device memory except ~1 GiB, so that training the
    # 1202-layer network is forced to rely on DTR's eviction/recompute path.
    _, free_mem = mge.device.get_mem_status_bytes()
    tensor_mem = free_mem - (2 ** 30)
    if tensor_mem > 0:
        x = np.ones((1, int(tensor_mem / 4)), dtype=np.float32)
    else:
        x = np.ones((1,), dtype=np.float32)
    t = mge.tensor(x)

    mge.dtr.enable()
    mge.dtr.enable_sqrt_sampling = True

    data = np.random.randn(batch_size, 3, 32, 32).astype("float32")
    label = np.random.randint(0, 10, size=(batch_size,)).astype("int32")
    for _ in range(2):
        opt.clear_grad()
        _, loss = train_func(mge.tensor(data), mge.tensor(label), net=resnet1202, gm=gm)
        opt.step()
        loss.item()
    # Touch the pinning tensor so it stays alive through training, then
    # restore normal execution.
    t.numpy()
    mge.dtr.disable()
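The test above exercises DTR (Dynamic Tensor Rematerialization): it deliberately occupies most of the free GPU memory, then trains a ResNet-1202 that would not otherwise fit, relying on MegEngine to evict activations and recompute them on demand during the backward pass. Distilled from the test, the user-facing pattern is just a pair of toggles around ordinary training code. A minimal sketch follows; the `eviction_threshold` knob is from MegEngine's DTR documentation as I recall it, and the value shown is purely illustrative:

import megengine as mge

# Optional memory budget: start evicting tensors once usage exceeds this.
# Assumption: eviction_threshold accepts a human-readable size string.
mge.dtr.eviction_threshold = "5GB"

mge.dtr.enable()   # subsequent computation may be evicted and rematerialized
# ... build the model and run training iterations as in test_dtr_resnet1202 ...
mge.dtr.disable()  # restore normal, non-evicting execution

Note the `@pytest.mark.require_ngpu(1)` marker: on a machine without a visible GPU the test is skipped rather than failed.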

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware installed along with the appropriate driver. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
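Since the same wheel serves both CPU-only and GPU machines, a quick runtime check tells you whether the GPU path (and hence the test above) is usable. A minimal sketch; `is_cuda_available` and `device.get_device_count` are the public MegEngine calls I believe cover this, so verify the names against your installed version:

import megengine as mge

if mge.is_cuda_available():
    # Number of visible CUDA devices; test_dtr_resnet1202 needs at least 1.
    print("GPUs visible:", mge.device.get_device_count("gpu"))
else:
    print("No GPU available; require_ngpu(1) tests will be skipped.")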