
wrn.py

# Adapted from https://github.com/polo5/ZeroShotKnowledgeTransfer/blob/master/models/wresnet.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2']


class BasicBlock(nn.Module):
    """Pre-activation residual block: BN-ReLU-Conv3x3, BN-ReLU-Dropout-Conv3x3."""

    def __init__(self, in_planes, out_planes, stride, dropout_rate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        self.equalInOut = (in_planes == out_planes)
        # 1x1 convolution on the shortcut path when the input/output shapes differ
        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                                                stride=stride, padding=0, bias=False) or None

    def forward(self, x):
        if not self.equalInOut:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        out = self.dropout(out)
        out = self.conv2(out)
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)


class NetworkBlock(nn.Module):
    """A stage of nb_layers BasicBlocks; only the first block changes stride and width."""

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropout_rate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropout_rate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropout_rate):
        layers = []
        for i in range(nb_layers):
            layers.append(block(i == 0 and in_planes or out_planes, out_planes,
                                i == 0 and stride or 1, dropout_rate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


class WideResNet(nn.Module):
    def __init__(self, depth, num_classes, widen_factor=1, dropout_rate=0.0):
        super(WideResNet, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropout_rate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropout_rate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropout_rate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        # He initialization for convolutions; unit scale / zero shift for batch norm
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x, return_features=False):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)  # assumes 32x32 inputs (the feature map is 8x8 here)
        features = out.view(-1, self.nChannels)
        out = self.fc(features)
        if return_features:
            return out, features
        else:
            return out


def wrn_16_1(num_classes, dropout_rate=0):
    return WideResNet(depth=16, num_classes=num_classes, widen_factor=1, dropout_rate=dropout_rate)


def wrn_16_2(num_classes, dropout_rate=0):
    return WideResNet(depth=16, num_classes=num_classes, widen_factor=2, dropout_rate=dropout_rate)


def wrn_40_1(num_classes, dropout_rate=0):
    return WideResNet(depth=40, num_classes=num_classes, widen_factor=1, dropout_rate=dropout_rate)


def wrn_40_2(num_classes, dropout_rate=0):
    return WideResNet(depth=40, num_classes=num_classes, widen_factor=2, dropout_rate=dropout_rate)
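
Since the module only defines the network, a quick smoke test clarifies the expected I/O. This is a minimal sketch (not part of the original file), assuming the file is importable as `wrn`; the 32x32 input size follows from the hard-coded `F.avg_pool2d(out, 8)` after three stages at strides 1, 2, 2. For `wrn_16_2`, `n = (16 - 4) // 6 = 2` blocks per stage and the final feature width is `64 * widen_factor = 128`.

    # Sanity check: instantiate WRN-16-2 and verify output shapes on a random
    # CIFAR-sized batch. Assumes this module is saved as wrn.py.
    import torch

    from wrn import wrn_16_2

    model = wrn_16_2(num_classes=10)
    model.eval()

    x = torch.randn(4, 3, 32, 32)  # forward() pools with avg_pool2d(out, 8), so inputs must be 32x32
    with torch.no_grad():
        logits, features = model(x, return_features=True)

    print(logits.shape)    # torch.Size([4, 10])
    print(features.shape)  # torch.Size([4, 128]) -- 64 * widen_factor channels

The `return_features` flag exposes the penultimate (pre-classifier) representation, which is what knowledge-distillation pipelines such as the one this file was adapted from typically consume.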

A suite of platforms and tools including a one-stop algorithm development platform, a high-performance distributed deep learning framework, an advanced algorithm model library, a visual model knowledge-amalgamation platform, and a data visualization and analysis platform. It has built distinctive strengths in efficient distributed model training, data processing and visual analysis, and model knowledge amalgamation and light-weighting, and has so far enabled AI applications for nearly a thousand organizations and individuals across industry, academia, and research.
