You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

functional.py 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections.abc
  10. import functools
  11. import random
  12. import cv2
  13. import numpy as np
  def wrap_keepdims(func):
      r"""Decorate an image function so that a 2-D result regains a third axis.

      The returned wrapper requires its first argument to expose a ``shape``
      with exactly three entries and raises ``ValueError`` otherwise, before
      *func* is called. When *func* returns a two-dimensional array, a
      trailing axis of length one is appended; results of any other rank are
      returned untouched.

      Args:
          func: callable taking an image as its first positional argument.

      Returns:
          the wrapping callable, carrying *func*'s metadata.
      """

      @functools.wraps(func)
      def wrapper(image, *args, **kwargs):
          ndim = len(image.shape)
          if ndim != 3:
              raise ValueError(
                  "image must have 3 dims, but got {} dims".format(ndim)
              )
          result = func(image, *args, **kwargs)
          # A 2-D result has lost its channel axis; put it back at the end.
          return result[:, :, np.newaxis] if len(result.shape) == 2 else result

      return wrapper
  @wrap_keepdims
  def to_gray(image):
      r"""Convert a BGR image to grayscale.

      Args:
          image: BGR format image, with `(H, W, C)` shape.

      Returns:
          grayscale image. The ``wrap_keepdims`` decorator appends a trailing
          axis whenever the conversion yields a 2-D array, so the result
          keeps three dimensions.
      """
      gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
      return gray
  @wrap_keepdims
  def to_bgr(image):
      r"""Convert a grayscale image to BGR.

      Args:
          image: gray format image, with `(H, W, C)` shape.

      Returns:
          BGR format image, with `(H, W, C)` shape.
      """
      bgr = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
      return bgr
  @wrap_keepdims
  def pad(input, size, value):
      r"""Pad input data with *value* and given *size*.

      Args:
          input: input data, with `(H, W, C)` shape.
          size: padding size of input data, an integer or an iterable of
              integers.

              * integer: all four sides are padded by that amount.
              * two integers: only the bottom and right sides are padded,
                by ``size[0]`` and ``size[1]`` respectively.
              * four integers: the top, bottom, left and right sides are
                padded, in that order.
          value: padding value, a scalar or a sequence of int or float.
              If it has any floating-point dtype, the input is cast to
              float32 before padding.

      Returns:
          padded image.

      Raises:
          ValueError: if *size* is not an integer and does not contain
              exactly two or four entries.
      """
      if isinstance(size, (int, np.integer)):
          # NumPy integer scalars are accepted too; hand OpenCV a plain int.
          amount = int(size)
          borders = (amount, amount, amount, amount)
      else:
          borders = tuple(size)
          if len(borders) == 2:
              # (bottom, right) -> (top, bottom, left, right)
              borders = (0, borders[0], 0, borders[1])
          elif len(borders) != 4:
              raise ValueError(
                  "pad needs an int or a sequence of 2 or 4 ints, "
                  "but got {}".format(size)
              )
      # Match every floating dtype (float16/32/64), not only float64, so the
      # cast described in the docstring also applies to e.g. np.float32 values.
      if np.issubdtype(np.asarray(value).dtype, np.floating):
          input = input.astype(np.float32)
      return cv2.copyMakeBorder(input, *borders, cv2.BORDER_CONSTANT, value=value)
  @wrap_keepdims
  def flip(image, flipCode):
      r"""Flip the input image in the direction selected by *flipCode*.

      Args:
          image: input image, with `(H, W, C)` shape.
          flipCode: code that indicates the type of flip.

              * 1 : Flip horizontally
              * 0 : Flip vertically
              * -1: Flip horizontally and vertically

      Returns:
          flipped image. The ``wrap_keepdims`` decorator appends a trailing
          axis if the flip yields a 2-D array, so the result keeps three
          dimensions.
      """
      flipped = cv2.flip(image, flipCode=flipCode)
      return flipped
  @wrap_keepdims
  def resize(input, size, interpolation=cv2.INTER_LINEAR):
      r"""Resize the input data to the given size.

      Args:
          input: input data, could be image or masks, with `(H, W, C)` shape.
          size: target size, given as (height, width).
          interpolation: interpolation method. If a sequence of methods is
              given, one of them is picked at random for this call.

      Returns:
          resized data. The ``wrap_keepdims`` decorator appends a trailing
          axis if the resize yields a 2-D array, so the result keeps three
          dimensions.

      Raises:
          ValueError: if *size* does not contain exactly two entries.
      """
      if len(size) != 2:
          raise ValueError("resize needs (h, w), but got {}".format(size))
      method = (
          random.choice(interpolation)
          if isinstance(interpolation, collections.abc.Sequence)
          else interpolation
      )
      # size is (height, width); it is passed to OpenCV reversed.
      target = size[::-1]
      return cv2.resize(input, target, interpolation=method)