You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

padding.py 2.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. from typing import Tuple
  2. from ..functional import nn
  3. from .module import Module
  4. class Pad(Module):
  5. r"""Pads the input tensor.
  6. Args:
  7. pad_width: A tuple. Each element in the tuple is the tuple of 2-elements,
  8. the 2 elements represent the padding size on both sides of the current dimension, ``(front_offset, back_offset)``
  9. mode: One of the following string values. Default: ``'constant'``
  10. * ``'constant'``: Pads with a constant value.
  11. * ``'reflect'``: Pads with the edge values of tensor.
  12. * ``'replicate'``: Pads with the reflection of the tensor mirrored on the first and last values of the tensor along each axis.
  13. constant_val: Fill value for ``'constant'`` padding. Default: 0
  14. Examples:
  15. >>> import numpy as np
  16. >>> inp = Tensor([[1., 2., 3.],[4., 5., 6.]])
  17. >>> inp
  18. Tensor([[1. 2. 3.]
  19. [4. 5. 6.]], device=xpux:0)
  20. >>> m = M.Pad(pad_width=((1, 1),), mode="constant")
  21. >>> m(inp)
  22. Tensor([[0. 0. 0.]
  23. [1. 2. 3.]
  24. [4. 5. 6.]
  25. [0. 0. 0.]], device=xpux:0)
  26. >>> m = M.Pad(pad_width=((1, 1),), mode="constant", constant_val=9)
  27. >>> m(inp)
  28. Tensor([[9. 9. 9.]
  29. [1. 2. 3.]
  30. [4. 5. 6.]
  31. [9. 9. 9.]], device=xpux:0)
  32. >>> m = M.Pad(pad_width=((1, 1), (1, 2)), mode="reflect")
  33. >>> m(inp)
  34. Tensor([[5. 4. 5. 6. 5. 4.]
  35. [2. 1. 2. 3. 2. 1.]
  36. [5. 4. 5. 6. 5. 4.]
  37. [2. 1. 2. 3. 2. 1.]], device=xpux:0)
  38. >>> m = M.Pad(pad_width=((1, 1), (1, 2)), mode="replicate")
  39. >>> m(inp)
  40. Tensor([[1. 1. 2. 3. 3. 3.]
  41. [1. 1. 2. 3. 3. 3.]
  42. [4. 4. 5. 6. 6. 6.]
  43. [4. 4. 5. 6. 6. 6.]], device=xpux:0)
  44. """
  45. def __init__(
  46. self,
  47. pad_width: Tuple[Tuple[int, int], ...],
  48. mode: str = "constant",
  49. constant_val: float = 0.0,
  50. ):
  51. super().__init__()
  52. self.pad_width = pad_width
  53. self.mode = mode
  54. self.pad_val = constant_val
  55. def forward(self, src):
  56. return nn.pad(
  57. src, pad_width=self.pad_width, mode=self.mode, constant_value=self.pad_val
  58. )