model.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from collections import OrderedDict

import torch
import torch.nn as nn

from pytorch import mutables
from pytorch.darts import ops


def random_channel_shuffle(x):
    """Permute the channels of x with a random permutation."""
    num_channels = x.data.size()[1]
    indices = torch.randperm(num_channels)
    x = x[:, indices]
    return x


def channel_shuffle(x, groups):
    """Deterministic channel shuffle (as in ShuffleNet): interleave channel groups."""
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups
    # reshape to (N, groups, C // groups, H, W)
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    # flatten back to (N, C, H, W)
    x = x.view(batchsize, -1, height, width)
    return x
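
# Example: with num_channels=6 and groups=2, the channel order [0 1 2 | 3 4 5]
# becomes [0 3 1 4 2 5], so channels that passed through an op and channels
# that bypassed it are mixed before the next layer sees them.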


class AuxiliaryHead(nn.Module):
    """Auxiliary head at the 2/3 depth of the network, added to help gradients flow."""

    def __init__(self, input_size, C, n_classes):
        """ assuming input size 7x7 or 8x8 """
        assert input_size in [7, 8]
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False),  # 2x2 out
            nn.Conv2d(C, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, kernel_size=2, bias=False),  # 1x1 out
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.linear = nn.Linear(768, n_classes)

    def forward(self, x):
        out = self.net(x)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)
        return logits
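
# Shape check for the head above: with input_size=8, AvgPool2d(5, stride=3)
# maps 8x8 -> 2x2 and the 2x2 convolution maps 2x2 -> 1x1, so exactly 768
# features reach the linear classifier; with input_size=7 the pooling stride
# is 2 and the arithmetic gives the same 2x2 -> 1x1 result.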


class Node(nn.Module):
    def __init__(self, node_id, num_prev_nodes, channels, k, num_downsample_connect, search):
        super().__init__()
        if search:
            # partial channel connection: only 1/k of the channels enter the candidate ops
            self.k = k
            partial_channels = channels // k
        else:
            partial_channels = channels
        self.search = search
        self.ops = nn.ModuleList()
        choice_keys = []
        for i in range(num_prev_nodes):
            stride = 2 if i < num_downsample_connect else 1
            choice_keys.append("{}_p{}".format(node_id, i))
            self.ops.append(
                mutables.LayerChoice(OrderedDict([
                    ("maxpool", ops.PoolBN('max', partial_channels, 3, stride, 1, affine=False)),
                    ("avgpool", ops.PoolBN('avg', partial_channels, 3, stride, 1, affine=False)),
                    ("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(partial_channels, partial_channels, affine=False)),
                    ("sepconv3x3", ops.SepConv(partial_channels, partial_channels, 3, stride, 1, affine=False)),
                    ("sepconv5x5", ops.SepConv(partial_channels, partial_channels, 5, stride, 2, affine=False)),
                    ("dilconv3x3", ops.DilConv(partial_channels, partial_channels, 3, stride, 2, 2, affine=False)),
                    ("dilconv5x5", ops.DilConv(partial_channels, partial_channels, 5, stride, 4, 2, affine=False))
                ]), key=choice_keys[-1]))
        self.drop_path = ops.DropPath()
        self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, prev_nodes):
        assert len(self.ops) == len(prev_nodes), "len(self.ops) != len(prev_nodes) in Node"
        # iterate over each candidate predecessor of this intermediate node
        if self.search:
            # in search: partial channel connection followed by channel shuffle
            results = []
            for op, x in zip(self.ops, prev_nodes):
                channels = x.shape[1]
                # channel proportion k=4: send 1/k of the channels through the op
                temp0 = x[:, :channels // self.k, :, :]
                temp1 = x[:, channels // self.k:, :, :]
                out = op(temp0)
                if out.shape[2] == x.shape[2]:
                    # normal cell: bypassed channels are concatenated unchanged
                    result = torch.cat([out, temp1], dim=1)
                else:
                    # reduction cell: downsample the bypassed channels to match
                    result = torch.cat([out, self.pool(temp1)], dim=1)
                results.append(channel_shuffle(result, self.k))
                # # random channel shuffle variant (kept for reference):
                # channels = random_channel_shuffle(x).shape[1]
                # temp0 = x[:, :channels // self.k, :, :]
                # temp1 = x[:, channels // self.k:, :, :]
                # out = op(temp0)
                # if out.shape[2] == x.shape[2]:   # normal
                #     result = torch.cat([out, temp1], dim=1)
                # else:                            # reduction
                #     result = torch.cat([out, self.pool(temp1)], dim=1)
                # results.append(result)
        else:
            # in retrain: full channels, no channel shuffle
            results = [op(node) for op, node in zip(self.ops, prev_nodes)]
        output = [self.drop_path(re) if re is not None else None for re in results]
        return self.input_switch(output)
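
# During search, Node applies a PC-DARTS-style partial channel connection:
# only channels // k channels enter the candidate op while the rest bypass it,
# and channel_shuffle then re-mixes the two groups so that different channels
# are sampled across iterations. During retraining all channels are used.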


class Cell(nn.Module):
    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction, k, search):
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes
        # If the previous cell is a reduction cell, the current input size does not
        # match the output size of cell[k-2], so output[k-2] is reduced by preprocessing.
        if reduction_p:
            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)
        # generate dag
        self.mutable_ops = nn.ModuleList()
        for depth in range(2, self.n_nodes + 2):
            self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
                                         depth, channels, k, 2 if reduction else 0, search))

    def forward(self, s0, s1):
        # s0 and s1 are the outputs of the second-previous and previous cell, respectively.
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        for node in self.mutable_ops:
            cur_tensor = node(tensors)
            tensors.append(cur_tensor)
        output = torch.cat(tensors[2:], dim=1)
        return output
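
# A cell concatenates its n_nodes intermediate outputs along the channel
# dimension, so a cell built with `channels` channels per node emits
# n_nodes * channels output channels (c_cur_out in CNN below).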


class CNN(nn.Module):
    def __init__(self, input_size, in_channels, channels, n_classes, n_layers, k=4, n_nodes=4,
                 stem_multiplier=3, auxiliary=False, search=True):
        super().__init__()
        self.in_channels = in_channels
        self.channels = channels
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.n_nodes = n_nodes
        self.aux_pos = 2 * n_layers // 3 if auxiliary else -1
        c_cur = stem_multiplier * self.channels
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c_cur)
        )
        # for the first cell, stem is used for both s0 and s1
        # [!] channels_pp and channels_p are output channel sizes, but c_cur is the input channel size.
        channels_pp, channels_p, c_cur = c_cur, c_cur, channels
        self.cells = nn.ModuleList()
        reduction_p, reduction = False, False
        for i in range(n_layers):
            reduction_p, reduction = reduction, False
            # Reduce feature map size and double channels at the 1/3 and 2/3 layers.
            if i in [n_layers // 3, 2 * n_layers // 3]:
                c_cur *= 2
                reduction = True
            cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction, k, search)
            self.cells.append(cell)
            c_cur_out = c_cur * n_nodes
            channels_pp, channels_p = channels_p, c_cur_out
            if i == self.aux_pos:
                self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(channels_p, n_classes)

    def forward(self, x):
        s0 = s1 = self.stem(x)
        aux_logits = None
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1)
            if i == self.aux_pos and self.training:
                aux_logits = self.aux_head(s1)
        out = self.gap(s1)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)
        if aux_logits is not None:
            return logits, aux_logits
        return logits

    def drop_path_prob(self, p):
        # set the drop-path probability on every DropPath module
        for module in self.modules():
            if isinstance(module, ops.DropPath):
                module.p = p

    def _loss(self, input, target):
        # note: self._criterion is never defined in this file and is expected
        # to be assigned externally before _loss is called
        logits = self(input)
        return self._criterion(logits, target)
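

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes
# CIFAR-10-style 32x32 RGB inputs; a DARTS-style mutator is normally attached
# to the LayerChoice/InputChoice mutables before forward is called, so only
# construction and a parameter count are shown here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = CNN(input_size=32, in_channels=3, channels=16, n_classes=10, n_layers=8)
    n_params = sum(p.numel() for p in model.parameters())
    print("parameters: {:.2f}M".format(n_params / 1e6))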
