You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

macro.py 3.3 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import torch.nn as nn
  4. import sys
  5. sys.path.append('..'+ '/' + '..')
  6. from pytorch import mutables # LayerChoice, InputChoice, MutableScope
  7. from ops import FactorizedReduce, ConvBranch, PoolBranch
  8. class ENASLayer(mutables.MutableScope):
  9. def __init__(self, key, prev_labels, in_filters, out_filters):
  10. super().__init__(key)
  11. self.in_filters = in_filters
  12. self.out_filters = out_filters
  13. self.mutable = mutables.LayerChoice([
  14. ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False),
  15. ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True),
  16. ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False),
  17. ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True),
  18. PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
  19. PoolBranch('max', in_filters, out_filters, 3, 1, 1)
  20. ])
  21. if len(prev_labels) > 0:
  22. self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None)
  23. else:
  24. self.skipconnect = None
  25. self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
  26. def forward(self, prev_layers):
  27. out = self.mutable(prev_layers[-1])
  28. if self.skipconnect is not None:
  29. connection = self.skipconnect(prev_layers[:-1])
  30. if connection is not None:
  31. out += connection
  32. return self.batch_norm(out)
  33. class GeneralNetwork(nn.Module):
  34. def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10,
  35. dropout_rate=0.0):
  36. super().__init__()
  37. self.num_layers = num_layers
  38. self.num_classes = num_classes
  39. self.out_filters = out_filters
  40. self.stem = nn.Sequential(
  41. nn.Conv2d(in_channels, out_filters, 3, 1, 1, bias=False),
  42. nn.BatchNorm2d(out_filters)
  43. )
  44. pool_distance = self.num_layers // 3
  45. self.pool_layers_idx = [pool_distance - 1, 2 * pool_distance - 1]
  46. self.dropout_rate = dropout_rate
  47. self.dropout = nn.Dropout(self.dropout_rate)
  48. self.layers = nn.ModuleList()
  49. self.pool_layers = nn.ModuleList()
  50. labels = []
  51. for layer_id in range(self.num_layers):
  52. labels.append("layer_{}".format(layer_id))
  53. if layer_id in self.pool_layers_idx:
  54. self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters))
  55. self.layers.append(ENASLayer(labels[-1], labels[:-1], self.out_filters, self.out_filters))
  56. self.gap = nn.AdaptiveAvgPool2d(1)
  57. self.dense = nn.Linear(self.out_filters, self.num_classes)
  58. def forward(self, x):
  59. bs = x.size(0)
  60. cur = self.stem(x)
  61. layers = [cur]
  62. for layer_id in range(self.num_layers):
  63. cur = self.layers[layer_id](layers)
  64. layers.append(cur)
  65. if layer_id in self.pool_layers_idx:
  66. for i, layer in enumerate(layers):
  67. layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer)
  68. cur = layers[-1]
  69. cur = self.gap(cur).view(bs, -1)
  70. cur = self.dropout(cur)
  71. logits = self.dense(cur)
  72. return logits

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能