You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

alexnet.py 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # Modified from https://github.com/pytorch/vision
  2. import torch.nn as nn
  3. import torch.utils.model_zoo as model_zoo
  4. __all__ = ['AlexNet', 'alexnet']
  5. model_urls = {
  6. 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
  7. }
  8. class AlexNet(nn.Module):
  9. def __init__(self, num_classes=1000):
  10. super(AlexNet, self).__init__()
  11. self.features = nn.Sequential(
  12. nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
  13. nn.ReLU(inplace=True),
  14. nn.MaxPool2d(kernel_size=3, stride=2),
  15. nn.Conv2d(64, 192, kernel_size=5, padding=2),
  16. nn.ReLU(inplace=True),
  17. nn.MaxPool2d(kernel_size=3, stride=2),
  18. nn.Conv2d(192, 384, kernel_size=3, padding=1),
  19. nn.ReLU(inplace=True),
  20. nn.Conv2d(384, 256, kernel_size=3, padding=1),
  21. nn.ReLU(inplace=True),
  22. nn.Conv2d(256, 256, kernel_size=3, padding=1),
  23. nn.ReLU(inplace=True),
  24. nn.MaxPool2d(kernel_size=3, stride=2),
  25. )
  26. self.classifier = nn.Sequential(
  27. nn.Dropout(),
  28. nn.Linear(256 * 6 * 6, 4096),
  29. nn.ReLU(inplace=True),
  30. nn.Dropout(),
  31. nn.Linear(4096, 4096),
  32. nn.ReLU(inplace=True),
  33. nn.Linear(4096, num_classes),
  34. )
  35. def forward(self, x):
  36. x = self.features(x)
  37. x = x.view(x.size(0), 256 * 6 * 6)
  38. x = self.classifier(x)
  39. return x
  40. def alexnet(pretrained=False, **kwargs):
  41. r"""AlexNet model architecture from the
  42. `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
  43. Args:
  44. pretrained (bool): If True, returns a model pre-trained on ImageNet
  45. """
  46. num_classes = kwargs.pop('num_classes', None)
  47. model = AlexNet(**kwargs)
  48. if pretrained:
  49. model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
  50. if num_classes is not None and num_classes!=1000:
  51. model.classifier[-1] = nn.Linear(4096, num_classes)
  52. return model

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能

Contributors (1)