You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import torch
  4. import torch.nn as nn
  5. class StdConv(nn.Module):
  6. def __init__(self, C_in, C_out):
  7. super(StdConv, self).__init__()
  8. self.conv = nn.Sequential(
  9. nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
  10. nn.BatchNorm2d(C_out, affine=False),
  11. nn.ReLU()
  12. )
  13. def forward(self, x):
  14. return self.conv(x)
  15. def __str__(self):
  16. return 'StdConv'
  17. class PoolBranch(nn.Module):
  18. def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False):
  19. super().__init__()
  20. self.kernel_size = kernel_size
  21. self.pool_type = pool_type
  22. self.preproc = StdConv(C_in, C_out)
  23. self.pool = Pool(pool_type, kernel_size, stride, padding)
  24. self.bn = nn.BatchNorm2d(C_out, affine=affine)
  25. def forward(self, x):
  26. out = self.preproc(x)
  27. out = self.pool(out)
  28. out = self.bn(out)
  29. return out
  30. def __str__(self):
  31. return '{}PoolBranch_{}'.format(self.pool_type, self.kernel_size)
  32. class SeparableConv(nn.Module):
  33. def __init__(self, C_in, C_out, kernel_size, stride, padding):
  34. self.kernel_size = kernel_size
  35. super(SeparableConv, self).__init__()
  36. self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride,
  37. groups=C_in, bias=False)
  38. self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False)
  39. def forward(self, x):
  40. out = self.depthwise(x)
  41. out = self.pointwise(out)
  42. return out
  43. def __str__(self):
  44. return 'SeparableConv_{}'.format(self.kernel_size)
  45. class ConvBranch(nn.Module):
  46. def __init__(self, C_in, C_out, kernel_size, stride, padding, separable):
  47. super(ConvBranch, self).__init__()
  48. self.kernel_size = kernel_size
  49. self.preproc = StdConv(C_in, C_out)
  50. if separable:
  51. self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding)
  52. else:
  53. self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding)
  54. self.postproc = nn.Sequential(
  55. nn.BatchNorm2d(C_out, affine=False),
  56. nn.ReLU()
  57. )
  58. def forward(self, x):
  59. out = self.preproc(x)
  60. out = self.conv(out)
  61. out = self.postproc(out)
  62. return out
  63. def __str__(self):
  64. return 'ConvBranch_{}'.format(self.kernel_size)
  65. class FactorizedReduce(nn.Module):
  66. def __init__(self, C_in, C_out, affine=False):
  67. super().__init__()
  68. self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
  69. self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
  70. self.bn = nn.BatchNorm2d(C_out, affine=affine)
  71. def forward(self, x):
  72. out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
  73. out = self.bn(out)
  74. return out
  75. def __str__(self):
  76. return 'FactorizedReduce'
  77. class Pool(nn.Module):
  78. def __init__(self, pool_type, kernel_size, stride, padding):
  79. super().__init__()
  80. self.kernel_size = kernel_size
  81. self.pool_type = pool_type
  82. if pool_type.lower() == 'max':
  83. self.pool = nn.MaxPool2d(kernel_size, stride, padding)
  84. elif pool_type.lower() == 'avg':
  85. self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
  86. else:
  87. raise ValueError()
  88. def forward(self, x):
  89. return self.pool(x)
  90. def __str__(self):
  91. return '{}Pool_{}'.format(self.pool_type, self.kernel_size)
  92. class SepConvBN(nn.Module):
  93. def __init__(self, C_in, C_out, kernel_size, padding):
  94. super().__init__()
  95. self.kernel_size = kernel_size
  96. self.relu = nn.ReLU()
  97. self.conv = SeparableConv(C_in, C_out, kernel_size, 1, padding)
  98. self.bn = nn.BatchNorm2d(C_out, affine=True)
  99. def forward(self, x):
  100. x = self.relu(x)
  101. x = self.conv(x)
  102. x = self.bn(x)
  103. return x
  104. def __str__(self):
  105. return 'SepConvBN_{}'.format(self.kernel_size)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。