You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

conv_bn.py 1.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. from typing import Tuple, Union
  2. from ..functional import relu
  3. from .batchnorm import BatchNorm2d
  4. from .conv import Conv2d
  5. from .module import Module
  6. class _ConvBnActivation2d(Module):
  7. def __init__(
  8. self,
  9. in_channels: int,
  10. out_channels: int,
  11. kernel_size: Union[int, Tuple[int, int]],
  12. stride: Union[int, Tuple[int, int]] = 1,
  13. padding: Union[int, Tuple[int, int]] = 0,
  14. dilation: Union[int, Tuple[int, int]] = 1,
  15. groups: int = 1,
  16. bias: bool = True,
  17. conv_mode: str = "cross_correlation",
  18. compute_mode: str = "default",
  19. eps=1e-5,
  20. momentum=0.9,
  21. affine=True,
  22. track_running_stats=True,
  23. padding_mode: str = "zeros",
  24. **kwargs
  25. ):
  26. super().__init__(**kwargs)
  27. self.conv = Conv2d(
  28. in_channels,
  29. out_channels,
  30. kernel_size,
  31. stride,
  32. padding,
  33. dilation,
  34. groups,
  35. bias,
  36. conv_mode,
  37. compute_mode,
  38. padding_mode,
  39. **kwargs,
  40. )
  41. self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
  42. class ConvBn2d(_ConvBnActivation2d):
  43. r"""A fused :class:`~.Module` including :class:`~.module.Conv2d` and :class:`~.module.BatchNorm2d`.
  44. Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvBn2d` using
  45. :func:`~.quantize.quantize_qat`.
  46. """
  47. def forward(self, inp):
  48. return self.bn(self.conv(inp))
  49. class ConvBnRelu2d(_ConvBnActivation2d):
  50. r"""A fused :class:`~.Module` including :class:`~.module.Conv2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu`.
  51. Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvBnRelu2d` using :func:`~.quantize.quantize_qat`.
  52. """
  53. def forward(self, inp):
  54. return relu(self.bn(self.conv(inp)))