
model.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from collections import OrderedDict

import torch
import torch.nn as nn

import ops
from nni.nas.pytorch import mutables

class AuxiliaryHead(nn.Module):
    """ Auxiliary head at the 2/3 position of the network, added to improve gradient flow """

    def __init__(self, input_size, C, n_classes):
        """ assuming input size 7x7 or 8x8 """
        assert input_size in [7, 8]
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False),  # 2x2 out
            nn.Conv2d(C, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, kernel_size=2, bias=False),  # 1x1 out
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.linear = nn.Linear(768, n_classes)

    def forward(self, x):
        out = self.net(x)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)
        return logits
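
# A minimal shape sketch (not part of the original file): with input_size=8 the
# AvgPool2d(5, stride=3) layer maps an 8x8 feature map to 2x2, and the
# kernel_size=2 convolution then reduces it to 1x1, so the flattened feature is
# exactly 768-dimensional per sample.
#
#     head = AuxiliaryHead(input_size=8, C=256, n_classes=10)
#     x = torch.randn(2, 256, 8, 8)
#     assert head(x).shape == (2, 10)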

class Node(nn.Module):
    def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
        super().__init__()
        self.ops = nn.ModuleList()
        choice_keys = []
        for i in range(num_prev_nodes):
            stride = 2 if i < num_downsample_connect else 1
            choice_keys.append("{}_p{}".format(node_id, i))
            self.ops.append(
                mutables.LayerChoice(OrderedDict([
                    ("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)),
                    ("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)),
                    ("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False)),
                    ("sepconv3x3", ops.SepConv(channels, channels, 3, stride, 1, affine=False)),
                    ("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)),
                    ("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)),
                    ("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False))
                ]), key=choice_keys[-1]))
        self.drop_path = ops.DropPath()
        self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))

    def forward(self, prev_nodes):
        assert len(self.ops) == len(prev_nodes)
        out = [op(node) for op, node in zip(self.ops, prev_nodes)]
        out = [self.drop_path(o) if o is not None else None for o in out]
        return self.input_switch(out)
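
# How a Node is searched (illustrative summary, not part of the original file):
# each of the num_prev_nodes LayerChoice modules picks one of the 7 candidate
# ops for the edge from one predecessor, and the InputChoice with n_chosen=2
# keeps exactly two of those edge outputs, matching the DARTS constraint that
# every intermediate node receives two inputs. For example, the node with
# node_id "normal_n2" has two predecessors and exposes the choices
# "normal_n2_p0", "normal_n2_p1", and "normal_n2_switch" to the NAS algorithm.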

class Cell(nn.Module):
    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes
        # If the previous cell is a reduction cell, the current input size does not match the
        # output size of cell[k-2], so output[k-2] must be downsampled by the preprocessing.
        if reduction_p:
            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag
        self.mutable_ops = nn.ModuleList()
        for depth in range(2, self.n_nodes + 2):
            self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
                                         depth, channels, 2 if reduction else 0))

    def forward(self, s0, s1):
        # s0, s1 are the outputs of the cell before the previous cell and of the previous cell, respectively.
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        for node in self.mutable_ops:
            cur_tensor = node(tensors)
            tensors.append(cur_tensor)

        output = torch.cat(tensors[2:], dim=1)
        return output
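
# Channel bookkeeping sketch (not part of the original file): a cell
# concatenates its n_nodes intermediate outputs along the channel dimension, so
# a cell built with `channels` per node emits n_nodes * channels channels.
# This is why CNN below tracks c_cur_out = c_cur * n_nodes when chaining cells.
#
#     cell = Cell(n_nodes=4, channels_pp=48, channels_p=48, channels=16,
#                 reduction_p=False, reduction=False)
#     # once an NNI mutator is attached to resolve the choices, inputs s0, s1
#     # of shape (N, 48, 32, 32) yield cell(s0, s1) of shape (N, 64, 32, 32)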

class CNN(nn.Module):
    def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
                 stem_multiplier=3, auxiliary=False):
        super().__init__()
        self.in_channels = in_channels
        self.channels = channels
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.aux_pos = 2 * n_layers // 3 if auxiliary else -1

        c_cur = stem_multiplier * self.channels
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c_cur)
        )

        # for the first cell, stem is used for both s0 and s1
        # [!] channels_pp and channels_p are output channel sizes, but c_cur is input channel size.
        channels_pp, channels_p, c_cur = c_cur, c_cur, channels

        self.cells = nn.ModuleList()
        reduction_p, reduction = False, False
        for i in range(n_layers):
            reduction_p, reduction = reduction, False
            # Reduce feature map size and double the channels at the 1/3 and 2/3 layers.
            if i in [n_layers // 3, 2 * n_layers // 3]:
                c_cur *= 2
                reduction = True

            cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
            self.cells.append(cell)
            c_cur_out = c_cur * n_nodes
            channels_pp, channels_p = channels_p, c_cur_out

            if i == self.aux_pos:
                self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes)

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(channels_p, n_classes)

    def forward(self, x):
        s0 = s1 = self.stem(x)

        aux_logits = None
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1)
            if i == self.aux_pos and self.training:
                aux_logits = self.aux_head(s1)

        out = self.gap(s1)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)
        if aux_logits is not None:
            return logits, aux_logits
        return logits

    def drop_path_prob(self, p):
        for module in self.modules():
            if isinstance(module, ops.DropPath):
                module.p = p
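
# Usage sketch (not part of the original file; assumes NNI's legacy one-shot
# NAS API, where the mutable choices above are resolved by a trainer such as
# nni.nas.pytorch.darts.DartsTrainer rather than by calling the model directly):
#
#     model = CNN(input_size=32, in_channels=3, channels=16, n_classes=10,
#                 n_layers=8)
#     # trainer = DartsTrainer(model, loss=nn.CrossEntropyLoss(), ...)
#     # trainer.train()
#
# When auxiliary=True and the model is in training mode, forward() returns
# (logits, aux_logits); DARTS-style training typically combines them as
# loss = criterion(logits, y) + 0.4 * criterion(aux_logits, y), where 0.4 is
# the conventional auxiliary weight (an assumption here, not set in this file).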
