You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ops.py 4.1 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import torch
  4. import torch.nn as nn
  class DropPath(nn.Module):
      """
      Drop path with probability, applied per sample.

      In training mode, each sample along dim 0 is zeroed as a whole with
      probability ``p``. Surviving samples are scaled by ``1 / (1 - p)`` so the
      expected output equals the input. In eval mode, or when ``p <= 0``, the
      input is returned unchanged.
      """

      def __init__(self, p=0.):
          """
          Parameters
          ----------
          p : float
              Probability of a path being zeroed. ``p >= 1`` drops every path.
          """
          super().__init__()
          self.p = p

      def forward(self, x):
          if not self.training or self.p <= 0.:
              return x
          keep_prob = 1. - self.p
          if keep_prob <= 0.:
              # Every path is dropped. Computing x / 0 * 0 would yield NaN,
              # so return zeros directly.
              return torch.zeros_like(x)
          # One Bernoulli draw per sample, broadcast over all remaining dims.
          # This supports inputs of any rank, not only 4-D feature maps.
          mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
          # Match x's dtype so that, e.g., fp16 inputs are not promoted to fp32.
          mask = torch.empty(mask_shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob)
          return x / keep_prob * mask
  class PoolBN(nn.Module):
      """
      AvgPool or MaxPool followed by BatchNorm2d.

      Parameters
      ----------
      pool_type : str
          ``'max'`` or ``'avg'`` (case-insensitive).
      C : int
          Number of channels normalized by the BatchNorm layer.
      kernel_size, stride, padding
          Forwarded to the pooling layer.
      affine : bool
          Whether the BatchNorm layer has learnable affine parameters.

      Raises
      ------
      ValueError
          If ``pool_type`` is neither ``'max'`` nor ``'avg'``.
      """

      def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
          super().__init__()
          kind = pool_type.lower()
          if kind == 'max':
              self.pool = nn.MaxPool2d(kernel_size, stride, padding)
          elif kind == 'avg':
              # Exclude zero padding from the average so border outputs are not
              # pulled toward zero.
              self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
          else:
              raise ValueError("pool_type must be 'max' or 'avg', got {!r}".format(pool_type))
          self.bn = nn.BatchNorm2d(C, affine=affine)

      def forward(self, x):
          out = self.pool(x)
          out = self.bn(out)
          return out
  class StdConv(nn.Module):
      """
      Standard convolution unit: ReLU -> Conv2d (no bias) -> BatchNorm2d.

      ``kernel_size``, ``stride`` and ``padding`` are passed straight to the
      convolution. ``affine`` controls whether the BatchNorm has learnable
      scale/shift parameters.
      """

      def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
          super().__init__()
          layers = [
              nn.ReLU(),
              nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
              nn.BatchNorm2d(C_out, affine=affine),
          ]
          # Keep the layers inside a single Sequential named `net` so
          # state_dict keys stay `net.<idx>.*`.
          self.net = nn.Sequential(*layers)

      def forward(self, x):
          out = self.net(x)
          return out
  class FacConv(nn.Module):
      """
      Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN.

      A KxK convolution is factorized into a Kx1 convolution followed by a 1xK
      convolution. Each factor applies ``stride`` and ``padding`` only along
      the axis its kernel spans. The output spatial size therefore matches a
      single KxK convolution with the same ``stride`` and ``padding``.

      Parameters
      ----------
      C_in, C_out : int
          Input and output channels. The Kx1 convolution keeps ``C_in``
          channels; the 1xK convolution maps to ``C_out``.
      kernel_length : int
          K, the length of each factorized kernel.
      stride, padding : int
          Stride and zero padding of the equivalent KxK convolution.
      affine : bool
          Whether the BatchNorm has learnable affine parameters.
      """

      def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
          super().__init__()
          self.net = nn.Sequential(
              nn.ReLU(),
              # Kx1: stride and pad along height only. Applying them to both
              # axes would enlarge the output and apply the stride twice.
              nn.Conv2d(C_in, C_in, (kernel_length, 1), (stride, 1), (padding, 0), bias=False),
              # 1xK: stride and pad along width only.
              nn.Conv2d(C_in, C_out, (1, kernel_length), (1, stride), (0, padding), bias=False),
              nn.BatchNorm2d(C_out, affine=affine)
          )

      def forward(self, x):
          return self.net(x)
  class DilConv(nn.Module):
      """
      (Dilated) depthwise separable convolution.

      Layer order: ReLU -> depthwise (optionally dilated) conv -> 1x1 pointwise
      conv -> BatchNorm2d. The depthwise conv uses ``groups=C_in`` and carries
      the stride; the pointwise conv maps ``C_in`` to ``C_out``.

      With dilation == 2, a 3x3 kernel covers a 5x5 receptive field and a
      5x5 kernel covers 9x9.
      """

      def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
          super().__init__()
          depthwise = nn.Conv2d(
              C_in, C_in, kernel_size, stride, padding,
              dilation=dilation, groups=C_in, bias=False,
          )
          pointwise = nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False)
          norm = nn.BatchNorm2d(C_out, affine=affine)
          self.net = nn.Sequential(nn.ReLU(), depthwise, pointwise, norm)

      def forward(self, x):
          out = self.net(x)
          return out
  class SepConv(nn.Module):
      """
      Depthwise separable convolution made of two stacked DilConv units
      (dilation 1).

      Only the first unit applies ``stride`` and keeps ``C_in`` channels. The
      second unit uses stride 1 and maps ``C_in`` to ``C_out``.
      """

      def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
          super().__init__()
          first = DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine)
          second = DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
          self.net = nn.Sequential(first, second)

      def forward(self, x):
          out = self.net(x)
          return out
  class FactorizedReduce(nn.Module):
      """
      Reduce feature map size by factorized pointwise (stride=2).

      After a ReLU, two stride-2 1x1 convolutions each produce ``C_out // 2``
      channels. One reads the input; the other reads the input shifted by one
      pixel in both spatial dims. Their channel concatenation is passed
      through BatchNorm2d.

      Raises
      ------
      ValueError
          If ``C_out`` is odd, since the two halves could not add up to
          ``C_out`` channels.
      """

      def __init__(self, C_in, C_out, affine=True):
          super().__init__()
          if C_out % 2:
              raise ValueError("C_out must be even, got {}".format(C_out))
          self.relu = nn.ReLU()
          self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
          self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
          self.bn = nn.BatchNorm2d(C_out, affine=affine)

      def forward(self, x):
          x = self.relu(x)
          # Shift by one pixel. Padding one zero row/column at the bottom/right
          # first keeps both branches the same size when H or W is odd. For
          # even sizes, the stride-2 conv never samples the padded element, so
          # results match the unpadded x[:, :, 1:, 1:].
          shifted = nn.functional.pad(x, (0, 1, 0, 1))[:, :, 1:, 1:]
          out = torch.cat([self.conv1(x), self.conv2(shifted)], dim=1)
          out = self.bn(out)
          return out

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。