
model.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from collections import OrderedDict

import torch
import torch.nn as nn

# Import paths below are reconstructed assumptions: the mutables API (LayerChoice /
# InputChoice) matches NNI's classic NAS package, and `ops` is the local module from
# the DARTS example providing PoolBN, SepConv, DilConv, FactorizedReduce, StdConv, DropPath.
from nni.nas.pytorch import mutables
import ops

class AuxiliaryHead(nn.Module):
    """Auxiliary head at the 2/3 point of the network, added to improve gradient flow."""

    def __init__(self, input_size, C, n_classes):
        """Assumes an input size of 7x7 or 8x8."""
        assert input_size in [7, 8]
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False),  # 2x2 out
            nn.Conv2d(C, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, kernel_size=2, bias=False),  # 1x1 out
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.linear = nn.Linear(768, n_classes)

    def forward(self, x):
        out = self.net(x)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)
        return logits

class Node(nn.Module):
    def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect, search, dropout_rate):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.ops = nn.ModuleList()
        choice_keys = []
        for i in range(num_prev_nodes):
            stride = 2 if i < num_downsample_connect else 1
            choice_keys.append("{}_p{}".format(node_id, i))
            skip_op = nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False)
            # During search, apply op-level dropout to the skip-connect.
            if search and self.dropout_rate > 0:
                skip_op = nn.Sequential(skip_op, nn.Dropout(self.dropout_rate))
            self.ops.append(
                mutables.LayerChoice(OrderedDict([
                    ("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)),
                    ("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)),
                    ("skipconnect", skip_op),
                    ("sepconv3x3", ops.SepConv(channels, channels, 3, stride, 1, affine=False)),
                    ("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)),
                    ("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)),
                    ("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False))
                ]), key=choice_keys[-1]))
        # During retraining, DropPath is applied to non-skip-connect ops; its p defaults to 0.
        self.drop_path = ops.DropPath()
        self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))

    def forward(self, prev_nodes):
        assert len(self.ops) == len(prev_nodes)
        output = []
        for op, node in zip(self.ops, prev_nodes):
            out = op(node)
            # During retraining, DropPath is applied to every chosen op except the identity skip-connect.
            if out is not None and not isinstance(op, nn.Identity):
                out = self.drop_path(out)
            output.append(out)
        return self.input_switch(output)

class Cell(nn.Module):
    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction, search, dropout_rate):
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes
        # If the previous cell is a reduction cell, the current input size does not match the
        # output size of cell[k-2], so output[k-2] has to be downsampled during preprocessing.
        if reduction_p:
            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)
        # Generate the DAG of intermediate nodes.
        self.mutable_ops = nn.ModuleList()
        for depth in range(2, self.n_nodes + 2):
            self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
                                         depth, channels, 2 if reduction else 0, search, dropout_rate))

    def forward(self, s0, s1):
        # s0 and s1 are the outputs of the previous-previous cell and the previous cell, respectively.
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        for node in self.mutable_ops:
            cur_tensor = node(tensors)
            tensors.append(cur_tensor)
        output = torch.cat(tensors[2:], dim=1)
        return output

class CNN(nn.Module):
    def __init__(self, input_size, in_channels, channels, n_classes, n_layers, dropout_rate,
                 n_nodes=4, stem_multiplier=3, auxiliary=False, search=True):
        super().__init__()
        self.in_channels = in_channels
        self.channels = channels
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.aux_pos = 2 * n_layers // 3 if auxiliary else -1
        c_cur = stem_multiplier * self.channels
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c_cur)
        )
        # For the first cell, the stem is used for both s0 and s1.
        # [!] channels_pp and channels_p are output channel sizes, while c_cur is an input channel size.
        channels_pp, channels_p, c_cur = c_cur, c_cur, channels
        self.cells = nn.ModuleList()
        reduction_p, reduction = False, False
        for i in range(n_layers):
            reduction_p, reduction = reduction, False
            # Halve the feature map size and double the channels at the 1/3 and 2/3 layers.
            if i in [n_layers // 3, 2 * n_layers // 3]:
                c_cur *= 2
                reduction = True
            cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction, search, dropout_rate)
            self.cells.append(cell)
            c_cur_out = c_cur * n_nodes
            channels_pp, channels_p = channels_p, c_cur_out
            if i == self.aux_pos:
                self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(channels_p, n_classes)

    def forward(self, x):
        s0 = s1 = self.stem(x)
        aux_logits = None
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1)
            if i == self.aux_pos and self.training:
                aux_logits = self.aux_head(s1)
        out = self.gap(s1)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)
        if aux_logits is not None:
            return logits, aux_logits
        return logits

    def drop_path_prob(self, p, search=True):
        if search:
            # During search, update the dropout rate on the skip-connects, which were wrapped
            # as Sequential(Identity, Dropout). nn.Dropout stores its rate in the attribute .p
            # (assigning a .dropout_rate attribute would silently have no effect).
            for module in self.modules():
                if isinstance(module, nn.Sequential) and isinstance(module[0], nn.Identity):
                    module[1].p = p
        else:
            # During retraining, update the probability of every ops.DropPath module.
            for module in self.modules():
                if isinstance(module, ops.DropPath):
                    module.p = p
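
For orientation, here is a minimal usage sketch. It assumes a CIFAR-10-style setup (32x32 RGB inputs) and the NNI classic NAS workflow, in which the LayerChoice/InputChoice modules must be resolved by a mutator-driven trainer (during search) or by apply_fixed_architecture (during retraining) before a forward pass can run; all hyperparameter values and the checkpoint filename below are illustrative, not prescribed by this file.

import torch

# Illustrative hyperparameters: 8 cells, 16 initial channels, auxiliary head on,
# skip-connect dropout 0.2 during search.
model = CNN(input_size=32, in_channels=3, channels=16, n_classes=10,
            n_layers=8, dropout_rate=0.2, auxiliary=True, search=True)

# During search, an NNI mutator/trainer drives the LayerChoice and InputChoice
# modules; for retraining, fix the architecture first, e.g.:
#   from nni.nas.pytorch.fixed import apply_fixed_architecture
#   apply_fixed_architecture(model, "final_arch.json")  # hypothetical checkpoint path
#
# Once the choices are resolved, a training-mode forward pass returns both heads,
# since aux_pos = 2 * n_layers // 3 = 5 and the aux head only fires in training:
#   logits, aux_logits = model(torch.randn(4, 3, 32, 32))

# Anneal the regularization strength across epochs: skip-connect dropout in
# search mode, DropPath probability in retrain mode.
model.drop_path_prob(0.2, search=True)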
