You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

conv_transpose_bn.py 2.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. from typing import Tuple, Union
  2. from ..functional import relu
  3. from .batchnorm import BatchNorm2d
  4. from .conv import ConvTranspose2d
  5. from .module import Module
  6. class _ConvTransposeBnActivation2d(Module):
  7. def __init__(
  8. self,
  9. in_channels: int,
  10. out_channels: int,
  11. kernel_size: Union[int, Tuple[int, int]],
  12. stride: Union[int, Tuple[int, int]] = 1,
  13. padding: Union[int, Tuple[int, int]] = 0,
  14. output_padding: Union[int, Tuple[int, int]] = 0,
  15. dilation: Union[int, Tuple[int, int]] = 1,
  16. groups: int = 1,
  17. bias: bool = True,
  18. conv_mode: str = "cross_correlation",
  19. compute_mode: str = "default",
  20. eps=1e-5,
  21. momentum=0.9,
  22. affine=True,
  23. track_running_stats=True,
  24. **kwargs
  25. ):
  26. super().__init__(**kwargs)
  27. self.conv_transpose2d = ConvTranspose2d(
  28. in_channels,
  29. out_channels,
  30. kernel_size,
  31. stride,
  32. padding,
  33. output_padding,
  34. dilation,
  35. groups,
  36. bias,
  37. conv_mode,
  38. compute_mode,
  39. **kwargs,
  40. )
  41. self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
  42. class ConvTransposeBn2d(_ConvTransposeBnActivation2d):
  43. r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d` and :class:`~.module.BatchNorm2d`.
  44. Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeBn2d` using:func:`~.quantize.quantize_qat`.
  45. """
  46. def forward(self, inp):
  47. return self.bn(self.conv_transpose2d(inp))
  48. class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d):
  49. r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu`.
  50. Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeBnRelu2d` using :func:`~.quantize.quantize_qat`.
  51. """
  52. def forward(self, inp):
  53. return relu(self.bn(self.conv_transpose2d(inp)))