You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

blocks.py 3.3 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import torch
  2. import torch.nn as nn
  3. class ShuffleNetBlock(nn.Module):
  4. """
  5. When stride = 1, the block receives input with 2 * inp channels. Otherwise inp channels.
  6. """
  7. def __init__(self, inp, oup, mid_channels, ksize, stride, sequence="pdp", affine=True):
  8. super().__init__()
  9. assert stride in [1, 2]
  10. assert ksize in [3, 5, 7]
  11. self.channels = inp // 2 if stride == 1 else inp
  12. self.inp = inp
  13. self.oup = oup
  14. self.mid_channels = mid_channels
  15. self.ksize = ksize
  16. self.stride = stride
  17. self.pad = ksize // 2
  18. self.oup_main = oup - self.channels
  19. self._affine = affine
  20. assert self.oup_main > 0
  21. self.branch_main = nn.Sequential(*self._decode_point_depth_conv(sequence))
  22. if stride == 2:
  23. self.branch_proj = nn.Sequential(
  24. # dw
  25. nn.Conv2d(self.channels, self.channels, ksize, stride, self.pad,
  26. groups=self.channels, bias=False),
  27. nn.BatchNorm2d(self.channels, affine=affine),
  28. # pw-linear
  29. nn.Conv2d(self.channels, self.channels, 1, 1, 0, bias=False),
  30. nn.BatchNorm2d(self.channels, affine=affine),
  31. nn.ReLU(inplace=True)
  32. )
  33. def forward(self, x):
  34. if self.stride == 2:
  35. x_proj, x = self.branch_proj(x), x
  36. else:
  37. x_proj, x = self._channel_shuffle(x)
  38. return torch.cat((x_proj, self.branch_main(x)), 1)
  39. def _decode_point_depth_conv(self, sequence):
  40. result = []
  41. first_depth = first_point = True
  42. pc = c = self.channels
  43. for i, token in enumerate(sequence):
  44. # compute output channels of this conv
  45. if i + 1 == len(sequence):
  46. assert token == "p", "Last conv must be point-wise conv."
  47. c = self.oup_main
  48. elif token == "p" and first_point:
  49. c = self.mid_channels
  50. if token == "d":
  51. # depth-wise conv
  52. assert pc == c, "Depth-wise conv must not change channels."
  53. result.append(nn.Conv2d(pc, c, self.ksize, self.stride if first_depth else 1, self.pad,
  54. groups=c, bias=False))
  55. result.append(nn.BatchNorm2d(c, affine=self._affine))
  56. first_depth = False
  57. elif token == "p":
  58. # point-wise conv
  59. result.append(nn.Conv2d(pc, c, 1, 1, 0, bias=False))
  60. result.append(nn.BatchNorm2d(c, affine=self._affine))
  61. result.append(nn.ReLU(inplace=True))
  62. first_point = False
  63. else:
  64. raise ValueError("Conv sequence must be d and p.")
  65. pc = c
  66. return result
  67. def _channel_shuffle(self, x):
  68. bs, num_channels, height, width = x.data.size()
  69. assert (num_channels % 4 == 0)
  70. x = x.reshape(bs * num_channels // 2, 2, height * width)
  71. x = x.permute(1, 0, 2)
  72. x = x.reshape(2, -1, num_channels // 2, height, width)
  73. return x[0], x[1]
  74. class ShuffleXceptionBlock(ShuffleNetBlock):
  75. def __init__(self, inp, oup, mid_channels, stride, affine=True):
  76. super().__init__(inp, oup, mid_channels, 3, stride, "dpdpdp", affine)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能