micro.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F

from nni.nas.pytorch import mutables
from ops import FactorizedReduce, StdConv, SepConvBN, Pool

class AuxiliaryHead(nn.Module):
    """Auxiliary classifier attached after the last reduction layer;
    its logits serve only as an extra loss term during training."""

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.pooling = nn.Sequential(
            nn.ReLU(),
            nn.AvgPool2d(5, 3, 2)
        )
        self.proj = nn.Sequential(
            StdConv(in_channels, 128),
            StdConv(128, 768)
        )
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Use num_classes rather than a hard-coded 10, so the head matches the network.
        self.fc = nn.Linear(768, num_classes, bias=False)

    def forward(self, x):
        bs = x.size(0)
        x = self.pooling(x)
        x = self.proj(x)
        x = self.avg_pool(x).view(bs, -1)
        x = self.fc(x)
        return x

class Cell(nn.Module):
    """One branch of a node: an InputChoice picks one previous layer as
    input, and a LayerChoice picks the operation applied to it."""

    def __init__(self, cell_name, prev_labels, channels):
        super().__init__()
        self.input_choice = mutables.InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True,
                                                 key=cell_name + "_input")
        self.op_choice = mutables.LayerChoice([
            SepConvBN(channels, channels, 3, 1),
            SepConvBN(channels, channels, 5, 2),
            Pool("avg", 3, 1, 1),
            Pool("max", 3, 1, 1),
            nn.Identity()
        ], key=cell_name + "_op")

    def forward(self, prev_layers):
        chosen_input, chosen_mask = self.input_choice(prev_layers)
        cell_out = self.op_choice(chosen_input)
        return cell_out, chosen_mask

class Node(mutables.MutableScope):
    """A DAG node built from two cells: their outputs are summed and
    their input-usage masks are merged."""

    def __init__(self, node_name, prev_node_names, channels):
        super().__init__(node_name)
        self.cell_x = Cell(node_name + "_x", prev_node_names, channels)
        self.cell_y = Cell(node_name + "_y", prev_node_names, channels)

    def forward(self, prev_layers):
        out_x, mask_x = self.cell_x(prev_layers)
        out_y, mask_y = self.cell_y(prev_layers)
        return out_x + out_y, mask_x | mask_y

class Calibration(nn.Module):
    """1x1 convolution that aligns channel counts; a no-op when the
    input and output channels already match."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.process = None
        if in_channels != out_channels:
            self.process = StdConv(in_channels, out_channels)

    def forward(self, x):
        if self.process is None:
            return x
        return self.process(x)

class ReductionLayer(nn.Module):
    """Halves the spatial resolution of both skip inputs via factorized
    reduction before a reduction cell."""

    def __init__(self, in_channels_pp, in_channels_p, out_channels):
        super().__init__()
        self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False)
        self.reduce1 = FactorizedReduce(in_channels_p, out_channels, affine=False)

    def forward(self, pprev, prev):
        return self.reduce0(pprev), self.reduce1(prev)

class ENASLayer(nn.Module):
    """A searchable ENAS micro cell: `num_nodes` nodes form a DAG over the
    two preprocessed inputs; outputs of nodes that no later node consumed
    are concatenated and fused back to `out_channels` with a 1x1 conv."""

    def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduction):
        super().__init__()
        self.preproc0 = Calibration(in_channels_pp, out_channels)
        self.preproc1 = Calibration(in_channels_p, out_channels)

        self.num_nodes = num_nodes
        name_prefix = "reduce" if reduction else "normal"
        self.nodes = nn.ModuleList()
        node_labels = [mutables.InputChoice.NO_KEY, mutables.InputChoice.NO_KEY]
        for i in range(num_nodes):
            node_labels.append("{}_node_{}".format(name_prefix, i))
            self.nodes.append(Node(node_labels[-1], node_labels[:-1], out_channels))
        # Weight bank for the final 1x1 conv: one slice per candidate input
        # (num_nodes + 2 of them); only the slices of unused nodes are applied.
        self.final_conv_w = nn.Parameter(torch.zeros(out_channels, self.num_nodes + 2, out_channels, 1, 1),
                                         requires_grad=True)
        self.bn = nn.BatchNorm2d(out_channels, affine=False)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_normal_(self.final_conv_w)

    def forward(self, pprev, prev):
        pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev)

        prev_nodes_out = [pprev_, prev_]
        nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device)
        for i in range(self.num_nodes):
            node_out, mask = self.nodes[i](prev_nodes_out)
            # Each node's mask only covers the candidates visible to it so far.
            nodes_used_mask[:mask.size(0)] |= mask.to(node_out.device)
            prev_nodes_out.append(node_out)

        # Concatenate every output that was never chosen as an input, then
        # project back to out_channels with the matching weight slices.
        unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1)
        unused_nodes = F.relu(unused_nodes)
        conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :]
        conv_weight = conv_weight.view(conv_weight.size(0), -1, 1, 1)
        out = F.conv2d(unused_nodes, conv_weight)
        return prev, self.bn(out)

class MicroNetwork(nn.Module):
    """Full ENAS micro search space: a stem, `num_layers + 2` searchable
    cells with two reduction layers spaced through the network, an optional
    auxiliary head, and a linear classifier."""

    def __init__(self, num_layers=2, num_nodes=5, out_channels=24, in_channels=3, num_classes=10,
                 dropout_rate=0.0, use_aux_heads=False):
        super().__init__()
        self.num_layers = num_layers
        self.use_aux_heads = use_aux_heads

        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * 3, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels * 3)
        )

        pool_distance = self.num_layers // 3
        pool_layers = [pool_distance, 2 * pool_distance + 1]
        self.dropout = nn.Dropout(dropout_rate)

        self.layers = nn.ModuleList()
        c_pp = c_p = out_channels * 3
        c_cur = out_channels
        for layer_id in range(self.num_layers + 2):
            reduction = False
            if layer_id in pool_layers:
                # Double the channel count and downsample before a reduction cell.
                c_cur, reduction = c_p * 2, True
                self.layers.append(ReductionLayer(c_pp, c_p, c_cur))
                c_pp = c_p = c_cur
            self.layers.append(ENASLayer(num_nodes, c_pp, c_p, c_cur, reduction))
            if self.use_aux_heads and layer_id == pool_layers[-1] + 1:
                self.layers.append(AuxiliaryHead(c_cur, num_classes))
            c_pp, c_p = c_p, c_cur

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.dense = nn.Linear(c_cur, num_classes)
        self.reset_parameters()

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        bs = x.size(0)
        prev = cur = self.stem(x)
        aux_logits = None

        for layer in self.layers:
            if isinstance(layer, AuxiliaryHead):
                # Auxiliary logits are only computed (and returned) in training mode.
                if self.training:
                    aux_logits = layer(cur)
            else:
                prev, cur = layer(prev, cur)

        cur = self.gap(F.relu(cur)).view(bs, -1)
        cur = self.dropout(cur)
        logits = self.dense(cur)

        if aux_logits is not None:
            return logits, aux_logits
        return logits
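
The classes above only define the search space; nothing is sampled until a mutator drives the InputChoice/LayerChoice mutables. Below is a minimal search sketch, assuming the NNI 1.x package layout (nni.nas.pytorch.enas) that matches the import at the top of the file; later NNI releases moved and reworked this trainer interface, and the accuracy/reward_accuracy helpers here are simplified stand-ins for the metric utilities shipped alongside this example.

# search_micro.py -- hypothetical ENAS search driver, not part of micro.py.
# Assumes the NNI 1.x EnasTrainer API; check your NNI version before use.
import torch
import torch.nn as nn
from torchvision import datasets, transforms

from nni.nas.pytorch import enas
from micro import MicroNetwork


def accuracy(output, target):
    # Top-1 accuracy, returned as the dict of metrics the trainer logs.
    pred = output.argmax(dim=1)
    return {"acc1": (pred == target).float().mean().item()}


def reward_accuracy(output, target):
    # Scalar reward used to update the ENAS controller.
    pred = output.argmax(dim=1)
    return (pred == target).float().mean().item()


transform = transforms.Compose([transforms.ToTensor()])
dataset_train = datasets.CIFAR10("./data", train=True, download=True, transform=transform)
dataset_valid = datasets.CIFAR10("./data", train=False, download=True, transform=transform)

model = MicroNetwork(num_layers=2, num_nodes=5, out_channels=20,
                     dropout_rate=0.1, use_aux_heads=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)

trainer = enas.EnasTrainer(model,
                           loss=nn.CrossEntropyLoss(),
                           metrics=accuracy,
                           reward_function=reward_accuracy,
                           optimizer=optimizer,
                           batch_size=128,
                           num_epochs=10,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid)
trainer.train()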
