You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

residual_block.py 3.0 kB

2 years ago
(line-number gutter: lines 1–105)
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT License.
  3. # Written by Hao Du and Houwen Peng
  4. # email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. def conv3x3(in_planes, out_planes, stride=1):
  9. "3x3 convolution with padding"
  10. return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
  11. padding=1, bias=True)
  12. class BasicBlock(nn.Module):
  13. expansion = 1
  14. def __init__(self, inplanes, planes, stride=1, downsample=None):
  15. super(BasicBlock, self).__init__()
  16. self.conv1 = conv3x3(inplanes, planes, stride)
  17. self.bn1 = nn.BatchNorm2d(planes)
  18. self.relu = nn.ReLU(inplace=True)
  19. self.conv2 = conv3x3(planes, planes)
  20. self.bn2 = nn.BatchNorm2d(planes)
  21. self.downsample = downsample
  22. self.stride = stride
  23. def forward(self, x):
  24. residual = x
  25. out = self.conv1(x)
  26. out = self.bn1(out)
  27. out = self.relu(out)
  28. out = self.conv2(out)
  29. out = self.bn2(out)
  30. if self.downsample is not None:
  31. residual = self.downsample(x)
  32. out += residual
  33. out = self.relu(out)
  34. return out
  35. class Bottleneck(nn.Module):
  36. def __init__(self, inplanes, planes, stride=1, expansion=4):
  37. super(Bottleneck, self).__init__()
  38. planes = int(planes / expansion)
  39. self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True)
  40. self.bn1 = nn.BatchNorm2d(planes)
  41. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
  42. padding=1, bias=True)
  43. self.bn2 = nn.BatchNorm2d(planes)
  44. self.conv3 = nn.Conv2d(
  45. planes,
  46. planes * expansion,
  47. kernel_size=1,
  48. bias=True)
  49. self.bn3 = nn.BatchNorm2d(planes * expansion)
  50. self.relu = nn.ReLU(inplace=True)
  51. self.stride = stride
  52. self.expansion = expansion
  53. if inplanes != planes * self.expansion:
  54. self.downsample = nn.Sequential(
  55. nn.Conv2d(inplanes, planes * self.expansion,
  56. kernel_size=1, stride=stride, bias=True),
  57. nn.BatchNorm2d(planes * self.expansion),
  58. )
  59. else:
  60. self.downsample = None
  61. def forward(self, x):
  62. residual = x
  63. out = self.conv1(x)
  64. out = self.bn1(out)
  65. out = self.relu(out)
  66. out = self.conv2(out)
  67. out = self.bn2(out)
  68. out = self.relu(out)
  69. out = self.conv3(out)
  70. out = self.bn3(out)
  71. if self.downsample is not None:
  72. residual = self.downsample(x)
  73. out += residual
  74. out = self.relu(out)
  75. return out
  76. def get_Bottleneck(in_c, out_c, stride):
  77. return Bottleneck(in_c, out_c, stride=stride)
  78. def get_BasicBlock(in_c, out_c, stride):
  79. return BasicBlock(in_c, out_c, stride=stride)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。