You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

functional.py 3.4 kB

(line numbers 1–115)
  1. # -*- coding: utf-8 -*-
  2. import collections.abc
  3. import functools
  4. import random
  5. import cv2
  6. import numpy as np
  7. def wrap_keepdims(func):
  8. r"""Wraper to keep the dimension of input images unchanged."""
  9. @functools.wraps(func)
  10. def wrapper(image, *args, **kwargs):
  11. if len(image.shape) != 3:
  12. raise ValueError(
  13. "image must have 3 dims, but got {} dims".format(len(image.shape))
  14. )
  15. ret = func(image, *args, **kwargs)
  16. if len(ret.shape) == 2:
  17. ret = ret[:, :, np.newaxis]
  18. return ret
  19. return wrapper
  20. @wrap_keepdims
  21. def to_gray(image):
  22. r"""Change BGR format image's color space to gray.
  23. Args:
  24. image: input BGR format image, with `(H, W, C)` shape.
  25. Returns:
  26. gray format image, with `(H, W, C)` shape.
  27. """
  28. return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  29. @wrap_keepdims
  30. def to_bgr(image):
  31. r"""Change gray format image's color space to BGR.
  32. Args:
  33. image: input Gray format image, with `(H, W, C)` shape.
  34. Returns:
  35. BGR format image, with `(H, W, C)` shape.
  36. """
  37. return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
  38. @wrap_keepdims
  39. def pad(input, size, value):
  40. r"""Pad input data with *value* and given *size*.
  41. Args:
  42. input: input data, with `(H, W, C)` shape.
  43. size: padding size of input data, it could be integer or sequence.
  44. If it is an integer, the input data will be padded in four directions.
  45. If it is a sequence contains two integer, the bottom and right side
  46. of input data will be padded.
  47. If it is a sequence contains four integer, the top, bottom, left, right
  48. side of input data will be padded with given size.
  49. value: padding value of data, could be a sequence of int or float.
  50. If it is float value, the dtype of image will be casted to float32 also.
  51. Returns:
  52. padded image.
  53. """
  54. if isinstance(size, int):
  55. size = (size, size, size, size)
  56. elif isinstance(size, collections.abc.Sequence) and len(size) == 2:
  57. size = (0, size[0], 0, size[1])
  58. if np.array(value).dtype == float:
  59. input = input.astype(np.float32)
  60. return cv2.copyMakeBorder(input, *size, cv2.BORDER_CONSTANT, value=value)
  61. @wrap_keepdims
  62. def flip(image, flipCode):
  63. r"""Accordding to the flipCode (the type of flip), flip the input image.
  64. Args:
  65. image: input image, with `(H, W, C)` shape.
  66. flipCode: code that indicates the type of flip.
  67. * 1 : Flip horizontally
  68. * 0 : Flip vertically
  69. * -1: Flip horizontally and vertically
  70. Returns:
  71. BGR format image, with `(H, W, C)` shape.
  72. """
  73. return cv2.flip(image, flipCode=flipCode)
  74. @wrap_keepdims
  75. def resize(input, size, interpolation=cv2.INTER_LINEAR):
  76. r"""Resize the input data to given size.
  77. Args:
  78. input: input data, could be image or masks, with `(H, W, C)` shape.
  79. size: target size of input data, with (height, width) shape.
  80. interpolation: interpolation method.
  81. Returns:
  82. resized data, with `(H, W, C)` shape.
  83. """
  84. if len(size) != 2:
  85. raise ValueError("resize needs (h, w), but got {}".format(size))
  86. if isinstance(interpolation, collections.abc.Sequence):
  87. interpolation = random.choice(interpolation)
  88. return cv2.resize(input, size[::-1], interpolation=interpolation)