
transform.py

# -*- coding: utf-8 -*-
import collections.abc
import math
from typing import List, Sequence, Tuple

import cv2
import numpy as np

from megengine.data.transform import Transform
from megengine.data.transform.vision import functional as F

__all__ = [
    "VisionTransform",
    "ToMode",
    "Compose",
    "TorchTransformCompose",
    "Pad",
    "Resize",
    "ShortestEdgeResize",
    "RandomResize",
    "RandomCrop",
    "RandomResizedCrop",
    "CenterCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "Normalize",
    "GaussianNoise",
    "BrightnessTransform",
    "SaturationTransform",
    "ContrastTransform",
    "HueTransform",
    "ColorJitter",
    "Lighting",
]


class VisionTransform(Transform):
    r"""Base class of all transforms used in computer vision.

    Calling logic: apply_batch() -> apply() -> _apply_image() and other _apply_*()
    methods. If you want to implement a self-defined transform method for images,
    override the _apply_image method in your subclass.

    Args:
        order: input type order. Input is a tuple containing different structures,
            order is used to specify the order of structures. For example, if your input
            is (image, boxes) type, then the ``order`` should be ("image", "boxes").
            Currently available strings and data types are described below:

            * "image": input image, with shape of `(H, W, C)`.
            * "coords": coordinates, with shape of `(N, 2)`.
            * "boxes": bounding boxes, with shape of `(N, 4)`, "xyxy" format,
              the 1st "xy" represents the top-left point of a box,
              the 2nd "xy" represents the bottom-right point.
            * "mask": map used for segmentation, with shape of `(H, W, 1)`.
            * "keypoints": keypoints with shape of `(N, K, 3)`, N for number of instances,
              and K for number of keypoints in one instance. The first two entries of
              the last axis are the coordinates of a keypoint and the 3rd entry is
              the label of the keypoint.
            * "polygons": a sequence containing numpy arrays; its length is the number of instances.
              Each numpy array represents the polygon coordinates of one instance.
            * "category": categories for some data type. For example, "image_category"
              means the category of the input image and "boxes_category" means the
              categories of bounding boxes.
            * "info": information for images, such as image shapes and image paths.

            You can also customize your data types as long as you implement the corresponding
            _apply_*() methods; otherwise ``NotImplementedError`` will be raised.
    """

    def __init__(self, order=None):
        super().__init__()
        if order is None:
            order = ("image",)
        elif not isinstance(order, collections.abc.Sequence):
            raise ValueError(
                "order should be a sequence, but got order={}".format(order)
            )
        for k in order:
            if k in ("batch",):
                raise ValueError("{} is invalid data type".format(k))
            elif k.endswith("category") or k.endswith("info"):
                # when the key is *category or info, we should do nothing
                # if the corresponding apply methods are not implemented.
                continue
            elif self._get_apply(k) is None:
                raise NotImplementedError("{} is unsupported data type".format(k))
        self.order = order

    def apply_batch(self, inputs: Sequence[Tuple]):
        r"""Apply transform on batch input data."""
        return tuple(self.apply(input) for input in inputs)

    def apply(self, input: Tuple):
        r"""Apply transform on single input data."""
        if not isinstance(input, tuple):
            input = (input,)
        output = []
        for i in range(min(len(input), len(self.order))):
            apply_func = self._get_apply(self.order[i])
            if apply_func is None:
                output.append(input[i])
            else:
                output.append(apply_func(input[i]))
        if len(input) > len(self.order):
            output.extend(input[len(self.order) :])
        if len(output) == 1:
            output = output[0]
        else:
            output = tuple(output)
        return output

    def _get_apply(self, key):
        return getattr(self, "_apply_{}".format(key), None)

    def _get_image(self, input: Tuple):
        if not isinstance(input, tuple):
            input = (input,)
        return input[self.order.index("image")]

    def _apply_image(self, image):
        raise NotImplementedError

    def _apply_coords(self, coords):
        raise NotImplementedError

    def _apply_boxes(self, boxes):
        # expand each box to its four corner points, transform the points,
        # then take the axis-aligned bounding box of the transformed corners.
        idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
        coords = np.asarray(boxes).reshape(-1, 4)[:, idxs].reshape(-1, 2)
        coords = self._apply_coords(coords).reshape((-1, 4, 2))
        minxy = coords.min(axis=1)
        maxxy = coords.max(axis=1)
        trans_boxes = np.concatenate((minxy, maxxy), axis=1)
        return trans_boxes

    def _apply_mask(self, mask):
        raise NotImplementedError

    def _apply_keypoints(self, keypoints):
        coords, visibility = keypoints[..., :2], keypoints[..., 2:]
        trans_coords = [self._apply_coords(p) for p in coords]
        return np.concatenate((trans_coords, visibility), axis=-1)

    def _apply_polygons(self, polygons):
        return [[self._apply_coords(p) for p in instance] for instance in polygons]
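
# Example (illustrative sketch, not part of the original file): a subclass only
# needs to override the _apply_*() hooks it supports; "boxes" then reuses the
# base-class corner-point logic via _apply_coords. The name AddOne is hypothetical.
#
# >>> class AddOne(VisionTransform):
# ...     def _apply_image(self, image):
# ...         return image + 1
# ...     def _apply_coords(self, coords):
# ...         return coords
# >>> t = AddOne(order=("image", "boxes"))
# >>> img = np.zeros((2, 2, 3), dtype=np.uint8)
# >>> boxes = np.array([[0.0, 0.0, 1.0, 1.0]])
# >>> out_img, out_boxes = t.apply((img, boxes))
# >>> int(out_img.max())
# 1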


class ToMode(VisionTransform):
    r"""Change input data to a target mode.
    For example, most transforms use HWC mode images,
    while the neural network might take CHW mode input tensors.

    Args:
        mode: output mode of input. Default: "CHW"
        order: the same with :class:`VisionTransform`
    """

    def __init__(self, mode="CHW", *, order=None):
        super().__init__(order)
        assert mode in ["CHW"], "unsupported mode: {}".format(mode)
        self.mode = mode

    def _apply_image(self, image):
        if self.mode == "CHW":
            return np.ascontiguousarray(np.rollaxis(image, 2))
        return image

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        if self.mode == "CHW":
            return np.ascontiguousarray(np.rollaxis(mask, 2))
        return mask
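
# Example (sketch): HWC -> CHW conversion demonstrated with plain numpy data.
#
# >>> img = np.random.randint(0, 256, size=(16, 16, 3), dtype=np.uint8)
# >>> ToMode("CHW").apply(img).shape
# (3, 16, 16)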


class Compose(VisionTransform):
    r"""Composes several transformations together.

    Args:
        transforms: list of :class:`VisionTransform` to compose.
        batch_compose: whether to keep the same transform order across a batch when shuffling.
        shuffle_indices: indices used for random shuffle, starting at 1.
        order: the same with :class:`VisionTransform`

    .. seealso:: Refer to :mod:`~.data.transform` module for vision transform APIs.

    Examples:
        >>> import megengine.data.transform as T
        >>> T.Compose([  # doctest: +SKIP
        ...     T.RandomHorizontalFlip(),  # 1st
        ...     T.RandomVerticalFlip(),    # 2nd
        ...     T.CenterCrop(100),         # 3rd
        ...     T.ToMode("CHW"),           # 4th
        ...     ],
        ...     shuffle_indices=[(1, 2, 3)]
        ... )

    In this case, ``shuffle_indices`` is given, so each input data will be transformed
    out of order:

    .. math::

        \begin{array}{cc}
            [{\color{red}1 \quad 2 \quad 3} \quad 4] & [{\color{red}1 \quad 3 \quad 2} \quad 4] \\
            [{\color{red}2 \quad 1 \quad 3} \quad 4] & [{\color{red}2 \quad 3 \quad 1} \quad 4] \\
            [{\color{red}3 \quad 1 \quad 2} \quad 4] & [{\color{red}3 \quad 2 \quad 1} \quad 4]
        \end{array}

    In another case, if ``[(1, 3), (2, 4)]`` is given, then the 1st and 3rd transformations
    will be randomly shuffled, and the 2nd and 4th transformations will also be shuffled:

    .. math::

        \begin{array}{cc}
            [{\color{red}1} \quad {\color{blue}2} \quad {\color{red}3} \quad {\color{blue}4}] &
            [{\color{red}1} \quad {\color{blue}4} \quad {\color{red}3} \quad {\color{blue}2}] \\
            [{\color{red}3} \quad {\color{blue}2} \quad {\color{red}1} \quad {\color{blue}4}] &
            [{\color{red}3} \quad {\color{blue}4} \quad {\color{red}1} \quad {\color{blue}2}]
        \end{array}

    Different colors represent different groups that need to be internally shuffled.

    .. warning::

        Different samples within each batch will also use random transformation orders,
        unless ``batch_compose`` is set to ``True``.
    """

    def __init__(
        self,
        transforms: List[VisionTransform] = [],
        batch_compose: bool = False,
        shuffle_indices: List[Tuple] = None,
        *,
        order=None
    ):
        super().__init__(order)
        self.transforms = transforms
        self._set_order()
        if batch_compose and shuffle_indices is not None:
            raise ValueError(
                "Do not support shuffle when applying transforms along the whole batch"
            )
        self.batch_compose = batch_compose

        if shuffle_indices is not None:
            # user-facing indices start at 1; convert to 0-based indices.
            shuffle_indices = [tuple(x - 1 for x in idx) for idx in shuffle_indices]
        self.shuffle_indices = shuffle_indices

    def _set_order(self):
        for t in self.transforms:
            t.order = self.order
            if isinstance(t, Compose):
                t._set_order()

    def apply_batch(self, inputs: Sequence[Tuple]):
        if self.batch_compose:
            for t in self.transforms:
                inputs = t.apply_batch(inputs)
            return inputs
        else:
            return super().apply_batch(inputs)

    def apply(self, input: Tuple):
        for t in self._shuffle():
            input = t.apply(input)
        return input

    def _shuffle(self):
        if self.shuffle_indices is not None:
            source_idx = list(range(len(self.transforms)))
            for idx in self.shuffle_indices:
                shuffled = np.random.permutation(idx).tolist()
                for src, dst in zip(idx, shuffled):
                    source_idx[src] = dst
            return [self.transforms[i] for i in source_idx]
        else:
            return self.transforms
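
# Example (sketch): shuffling only a subgroup of transforms. Group (1, 2) may
# swap brightness and contrast per sample, while ToMode always runs last.
#
# >>> trans = Compose([
# ...     BrightnessTransform(0.1),
# ...     ContrastTransform(0.1),
# ...     ToMode("CHW"),
# ... ], shuffle_indices=[(1, 2)])
# >>> img = np.random.randint(0, 256, size=(8, 8, 3), dtype=np.uint8)
# >>> trans.apply(img).shape
# (3, 8, 8)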


class TorchTransformCompose(VisionTransform):
    r"""Compose class used for transforms from torchvision. Only PIL images are
    supported; torchvision transforms that operate on tensors,
    such as torchvision's Normalize and ToTensor, are not supported.

    Args:
        transforms: the same with ``Compose``.
        order: the same with :class:`VisionTransform`.
    """

    def __init__(self, transforms, *, order=None):
        super().__init__(order)
        self.transforms = transforms

    def _apply_image(self, image):
        from PIL import Image

        try:
            import accimage
        except ImportError:
            accimage = None

        # convert the BGR ndarray input to an RGB PIL image.
        if image.shape[0] == 3:  # CHW
            image = np.ascontiguousarray(image[[2, 1, 0]])
        elif image.shape[2] == 3:  # HWC
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image.astype(np.uint8))

        for t in self.transforms:
            image = t(image)

        # convert the transformed PIL image back to a BGR ndarray.
        if isinstance(image, Image.Image) or (
            accimage is not None and isinstance(image, accimage.Image)
        ):
            image = np.array(image, dtype=np.uint8)
        if image.shape[0] == 3:  # CHW
            image = np.ascontiguousarray(image[[2, 1, 0]])
        elif image.shape[2] == 3:  # HWC
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        return image


class Pad(VisionTransform):
    r"""Pad the input data.

    Args:
        size: padding size of the input image, an integer or a sequence.
            If it is an integer, the input image will be padded by that amount
            on all four sides.
            If it is a sequence containing two integers, the bottom and right sides
            of the image will be padded.
            If it is a sequence containing four integers, the top, bottom, left and
            right sides of the image will be padded with the given sizes.
        value: padding value of the image, could be a sequence of int or float.
            If it is a float value, the dtype of the image will also be cast to float32.
        mask_value: padding value of the segmentation map.
        order: the same with :class:`VisionTransform`.
    """

    def __init__(self, size=0, value=0, mask_value=0, *, order=None):
        super().__init__(order)
        if isinstance(size, int):
            size = (size, size, size, size)
        elif isinstance(size, collections.abc.Sequence) and len(size) == 2:
            size = (0, size[0], 0, size[1])
        elif not (isinstance(size, collections.abc.Sequence) and len(size) == 4):
            raise ValueError(
                "size should be a list/tuple which contains "
                "(top, bottom, left, right) four pad sizes."
            )
        self.size = size
        self.value = value
        if not isinstance(mask_value, int):
            raise ValueError(
                "mask_value should be an integer, "
                "but got mask_value={}".format(mask_value)
            )
        self.mask_value = mask_value

    def _apply_image(self, image):
        return F.pad(image, self.size, self.value)

    def _apply_coords(self, coords):
        # shift coordinates by the left and top padding.
        coords[:, 0] += self.size[2]
        coords[:, 1] += self.size[0]
        return coords

    def _apply_mask(self, mask):
        return F.pad(mask, self.size, self.mask_value)
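
# Example (sketch): a four-element size is (top, bottom, left, right), so
# coordinates are shifted by (left, top). Using order=("coords",) keeps the
# demonstration numpy-only.
#
# >>> pad = Pad((2, 4, 6, 8), order=("coords",))
# >>> pad.apply(np.array([[0.0, 0.0]])).tolist()
# [[6.0, 2.0]]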


class Resize(VisionTransform):
    r"""Resize the input data.

    Args:
        output_size: target size of image, with (height, width) shape.
        interpolation: interpolation method. All methods are listed below:

            * cv2.INTER_NEAREST – a nearest-neighbor interpolation.
            * cv2.INTER_LINEAR – a bilinear interpolation (used by default).
            * cv2.INTER_AREA – resampling using pixel area relation.
            * cv2.INTER_CUBIC – a bicubic interpolation over a 4×4 pixel neighborhood.
            * cv2.INTER_LANCZOS4 – a Lanczos interpolation over an 8×8 pixel neighborhood.
        order: the same with :class:`VisionTransform`.
    """

    def __init__(self, output_size, interpolation=cv2.INTER_LINEAR, *, order=None):
        super().__init__(order)
        self.output_size = output_size
        self.interpolation = interpolation

    def apply(self, input: Tuple):
        self._shape_info = self._get_shape(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return image
        return F.resize(image, (th, tw), self.interpolation)

    def _apply_coords(self, coords):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return coords
        coords[:, 0] = coords[:, 0] * (tw / w)
        coords[:, 1] = coords[:, 1] * (th / h)
        return coords

    def _apply_mask(self, mask):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return mask
        return F.resize(mask, (th, tw), cv2.INTER_NEAREST)

    def _get_shape(self, image):
        h, w, _ = image.shape
        if isinstance(self.output_size, int):
            # an integer output_size resizes the shorter edge to output_size
            # while keeping the aspect ratio.
            if min(h, w) == self.output_size:
                return h, w, h, w
            if h < w:
                th = self.output_size
                tw = int(self.output_size * w / h)
            else:
                tw = self.output_size
                th = int(self.output_size * h / w)
            return h, w, th, tw
        else:
            return (h, w, *self.output_size)
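
# Example (sketch, assuming F.resize behaves as called above): an integer
# output_size rescales the shorter edge; coordinates scale by (tw / w, th / h).
#
# >>> r = Resize(20, order=("image", "coords"))
# >>> img = np.zeros((10, 40, 3), dtype=np.uint8)
# >>> coords = np.array([[40.0, 10.0]])
# >>> out_img, out = r.apply((img, coords))
# >>> out_img.shape, out.tolist()
# ((20, 80, 3), [[80.0, 20.0]])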


class ShortestEdgeResize(VisionTransform):
    r"""Resize the input data so that its shortest edge matches a sampled target size.

    With ``sample_style="range"``, the target size is sampled uniformly from
    ``[min_size[0], min_size[1]]``; with ``sample_style="choice"``, it is sampled
    from the candidates in ``min_size``. If the resized longest edge would exceed
    ``max_size``, the scale is reduced accordingly.
    """

    def __init__(
        self,
        min_size,
        max_size,
        sample_style="range",
        interpolation=cv2.INTER_LINEAR,
        *,
        order=None
    ):
        super().__init__(order)
        if sample_style not in ("range", "choice"):
            raise NotImplementedError(
                "{} is unsupported sample style".format(sample_style)
            )
        self.sample_style = sample_style
        if isinstance(min_size, int):
            min_size = (min_size, min_size)
        self.min_size = min_size
        self.max_size = max_size
        self.interpolation = interpolation

    def apply(self, input: Tuple):
        self._shape_info = self._get_shape(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return image
        return F.resize(image, (th, tw), self.interpolation)

    def _apply_coords(self, coords):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return coords
        coords[:, 0] = coords[:, 0] * (tw / w)
        coords[:, 1] = coords[:, 1] * (th / h)
        return coords

    def _apply_mask(self, mask):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return mask
        return F.resize(mask, (th, tw), cv2.INTER_NEAREST)

    def _get_shape(self, image):
        h, w, _ = image.shape
        if self.sample_style == "range":
            size = np.random.randint(self.min_size[0], self.min_size[1] + 1)
        else:
            size = np.random.choice(self.min_size)

        scale = size / min(h, w)
        if h < w:
            th, tw = size, scale * w
        else:
            th, tw = scale * h, size
        # cap the longest edge at max_size, rescaling both edges if needed.
        if max(th, tw) > self.max_size:
            scale = self.max_size / max(th, tw)
            th = th * scale
            tw = tw * scale
        th = int(round(th))
        tw = int(round(tw))
        return h, w, th, tw


class RandomResize(VisionTransform):
    r"""Resize the input data randomly.

    Args:
        scale_range: range of scaling.
        order: the same with :class:`VisionTransform`.
    """

    def __init__(self, scale_range, interpolation=cv2.INTER_LINEAR, *, order=None):
        super().__init__(order)
        self.scale_range = scale_range
        self.interpolation = interpolation

    def apply(self, input: Tuple):
        self._shape_info = self._get_shape(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return image
        return F.resize(image, (th, tw), self.interpolation)

    def _apply_coords(self, coords):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return coords
        coords[:, 0] = coords[:, 0] * (tw / w)
        coords[:, 1] = coords[:, 1] * (th / h)
        return coords

    def _apply_mask(self, mask):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return mask
        return F.resize(mask, (th, tw), cv2.INTER_NEAREST)

    def _get_shape(self, image):
        h, w, _ = image.shape
        scale = np.random.uniform(*self.scale_range)
        th = int(round(h * scale))
        tw = int(round(w * scale))
        return h, w, th, tw


class RandomCrop(VisionTransform):
    r"""Crop the input data randomly. The input image is padded first. If the
    target size is still bigger than the padded image, the image is further
    padded up to the target size.

    Args:
        output_size: target size of the output image, with (height, width) shape.
        padding_size: the same with `size` in ``Pad``.
        padding_value: the same with `value` in ``Pad``.
        padding_maskvalue: the same with `mask_value` in ``Pad``.
        order: the same with :class:`VisionTransform`.
    """

    def __init__(
        self,
        output_size,
        padding_size=0,
        padding_value=[0, 0, 0],
        padding_maskvalue=0,
        *,
        order=None
    ):
        super().__init__(order)
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size
        self.pad = Pad(padding_size, padding_value, order=self.order)
        self.padding_value = padding_value
        self.padding_maskvalue = padding_maskvalue

    def apply(self, input):
        input = self.pad.apply(input)
        self._h, self._w, _ = self._get_image(input).shape
        self._th, self._tw = self.output_size
        # sample the top-left corner of the crop window.
        self._x = np.random.randint(0, max(0, self._w - self._tw) + 1)
        self._y = np.random.randint(0, max(0, self._h - self._th) + 1)
        return super().apply(input)

    def _apply_image(self, image):
        if self._th > self._h:
            image = F.pad(image, (self._th - self._h, 0), self.padding_value)
        if self._tw > self._w:
            image = F.pad(image, (0, self._tw - self._w), self.padding_value)
        return image[self._y : self._y + self._th, self._x : self._x + self._tw]

    def _apply_coords(self, coords):
        coords[:, 0] -= self._x
        coords[:, 1] -= self._y
        return coords

    def _apply_mask(self, mask):
        if self._th > self._h:
            mask = F.pad(mask, (self._th - self._h, 0), self.padding_maskvalue)
        if self._tw > self._w:
            mask = F.pad(mask, (0, self._tw - self._w), self.padding_maskvalue)
        return mask[self._y : self._y + self._th, self._x : self._x + self._tw]


class RandomResizedCrop(VisionTransform):
    r"""Crop the input data to a random size and aspect ratio.
    A crop of random area (default: 0.08 to 1.0 of the original area) and random
    aspect ratio (default: 3/4 to 4/3 of the original aspect ratio) is made.
    After the crop, the input data will be resized to the given size.

    Args:
        output_size: target size of the output image, with (height, width) shape.
        scale_range: range of the cropped area relative to the original area. Default: (0.08, 1.0)
        ratio_range: range of the crop's aspect ratio relative to the original aspect ratio. Default: (0.75, 1.33)
        order: the same with :class:`VisionTransform`.
    """

    def __init__(
        self,
        output_size,
        scale_range=(0.08, 1.0),
        ratio_range=(3.0 / 4, 4.0 / 3),
        interpolation=cv2.INTER_LINEAR,
        *,
        order=None
    ):
        super().__init__(order)
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size
        assert (
            scale_range[0] <= scale_range[1]
        ), "scale_range should be of kind (min, max)"
        assert (
            ratio_range[0] <= ratio_range[1]
        ), "ratio_range should be of kind (min, max)"
        self.scale_range = scale_range
        self.ratio_range = ratio_range
        self.interpolation = interpolation

    def apply(self, input: Tuple):
        self._coord_info = self._get_coord(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        x, y, w, h = self._coord_info
        cropped_img = image[y : y + h, x : x + w]
        return F.resize(cropped_img, self.output_size, self.interpolation)

    def _apply_coords(self, coords):
        x, y, w, h = self._coord_info
        coords[:, 0] = (coords[:, 0] - x) * self.output_size[1] / w
        coords[:, 1] = (coords[:, 1] - y) * self.output_size[0] / h
        return coords

    def _apply_mask(self, mask):
        x, y, w, h = self._coord_info
        cropped_mask = mask[y : y + h, x : x + w]
        return F.resize(cropped_mask, self.output_size, cv2.INTER_NEAREST)

    def _get_coord(self, image, attempts=10):
        height, width, _ = image.shape
        area = height * width

        for _ in range(attempts):
            # sample a target area and a log-uniform aspect ratio.
            target_area = np.random.uniform(*self.scale_range) * area
            log_ratio = tuple(math.log(x) for x in self.ratio_range)
            aspect_ratio = math.exp(np.random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                x = np.random.randint(0, width - w + 1)
                y = np.random.randint(0, height - h + 1)
                return x, y, w, h

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(self.ratio_range):
            w = width
            h = int(round(w / min(self.ratio_range)))
        elif in_ratio > max(self.ratio_range):
            h = height
            w = int(round(h * max(self.ratio_range)))
        else:  # whole image
            w = width
            h = height
        x = (width - w) // 2
        y = (height - h) // 2
        return x, y, w, h
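
# Example (sketch, assuming F.resize behaves as called above): whatever crop
# window is sampled, the result is always resized to output_size.
#
# >>> rrc = RandomResizedCrop(32)
# >>> img = np.random.randint(0, 256, size=(64, 48, 3), dtype=np.uint8)
# >>> rrc.apply(img).shape
# (32, 32, 3)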


class CenterCrop(VisionTransform):
    r"""Crop the input data at the center.

    Args:
        output_size: target size of the output image, with (height, width) shape.
        order: the same with :class:`VisionTransform`.
    """

    def __init__(self, output_size, *, order=None):
        super().__init__(order)
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size

    def apply(self, input: Tuple):
        self._coord_info = self._get_coord(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        x, y = self._coord_info
        th, tw = self.output_size
        return image[y : y + th, x : x + tw]

    def _apply_coords(self, coords):
        x, y = self._coord_info
        coords[:, 0] -= x
        coords[:, 1] -= y
        return coords

    def _apply_mask(self, mask):
        x, y = self._coord_info
        th, tw = self.output_size
        return mask[y : y + th, x : x + tw]

    def _get_coord(self, image):
        th, tw = self.output_size
        h, w, _ = image.shape
        assert th <= h and tw <= w, "output size is bigger than image size"
        x = int(round((w - tw) / 2.0))
        y = int(round((h - th) / 2.0))
        return x, y


class RandomHorizontalFlip(VisionTransform):
    r"""Horizontally flip the input data randomly with a given probability.

    Args:
        prob: probability of the input data being flipped. Default: 0.5
        order: the same with :class:`VisionTransform`.
    """

    def __init__(self, prob: float = 0.5, *, order=None):
        super().__init__(order)
        self.prob = prob

    def apply(self, input: Tuple):
        self._flipped = np.random.random() < self.prob
        self._w = self._get_image(input).shape[1]
        return super().apply(input)

    def _apply_image(self, image):
        if self._flipped:
            return F.flip(image, flipCode=1)
        return image

    def _apply_coords(self, coords):
        if self._flipped:
            coords[:, 0] = self._w - coords[:, 0]
        return coords

    def _apply_mask(self, mask):
        if self._flipped:
            return F.flip(mask, flipCode=1)
        return mask


class RandomVerticalFlip(VisionTransform):
    r"""Vertically flip the input data randomly with a given probability.

    Args:
        prob: probability of the input data being flipped. Default: 0.5
        order: the same with :class:`VisionTransform`.
    """

    def __init__(self, prob: float = 0.5, *, order=None):
        super().__init__(order)
        self.prob = prob

    def apply(self, input: Tuple):
        self._flipped = np.random.random() < self.prob
        self._h = self._get_image(input).shape[0]
        return super().apply(input)

    def _apply_image(self, image):
        if self._flipped:
            return F.flip(image, flipCode=0)
        return image

    def _apply_coords(self, coords):
        if self._flipped:
            coords[:, 1] = self._h - coords[:, 1]
        return coords

    def _apply_mask(self, mask):
        if self._flipped:
            return F.flip(mask, flipCode=0)
        return mask


class Normalize(VisionTransform):
    r"""Normalize the input data with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
    this transform will normalize each channel of the input data.

    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    Args:
        mean: sequence of means for each channel.
        std: sequence of standard deviations for each channel.
        order: the same with :class:`VisionTransform`.
    """

    def __init__(self, mean=0.0, std=1.0, *, order=None):
        super().__init__(order)
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)

    def _apply_image(self, image):
        return (image - self.mean) / self.std

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask
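
# Example (sketch): per-channel normalization broadcasts over the last axis of
# an (H, W, C) image. The mean/std values here are arbitrary placeholders.
#
# >>> norm = Normalize(mean=[128.0, 128.0, 128.0], std=[2.0, 2.0, 2.0])
# >>> img = np.full((2, 2, 3), 130.0, dtype=np.float32)
# >>> norm.apply(img)[0, 0].tolist()
# [1.0, 1.0, 1.0]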


class GaussianNoise(VisionTransform):
    r"""Add random Gaussian noise to the input data.
    Gaussian noise is generated with given mean and std.

    Args:
        mean: Gaussian mean used to generate noise.
        std: Gaussian standard deviation used to generate noise.
        order: the same with :class:`VisionTransform`
    """

    def __init__(self, mean=0.0, std=1.0, *, order=None):
        super().__init__(order)
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)

    def _apply_image(self, image):
        dtype = image.dtype
        # noise is sampled in unit scale and then scaled to the 0-255 pixel range.
        noise = np.random.normal(self.mean, self.std, image.shape) * 255
        image = image + noise.astype(np.float32)
        return np.clip(image, 0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask


class BrightnessTransform(VisionTransform):
    r"""Adjust brightness of the input data.

    Args:
        value: how much to adjust the brightness. Can be any
            non-negative number. 0 gives the original image.
        order: the same with :class:`VisionTransform`.
    """

    def __init__(self, value, *, order=None):
        super().__init__(order)
        if value < 0:
            raise ValueError("brightness value should be non-negative")
        self.value = value

    def _apply_image(self, image):
        if self.value == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.float32)
        # scale the image by a factor sampled around 1.
        alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
        image = image * alpha
        return image.clip(0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask


class ContrastTransform(VisionTransform):
    r"""Adjust contrast of the input data.

    Args:
        value: how much to adjust the contrast. Can be any
            non-negative number. 0 gives the original image.
        order: the same with :class:`VisionTransform`.
    """

    def __init__(self, value, *, order=None):
        super().__init__(order)
        if value < 0:
            raise ValueError("contrast value should be non-negative")
        self.value = value

    def _apply_image(self, image):
        if self.value == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.float32)
        # blend between the image and its mean gray level.
        alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
        image = image * alpha + F.to_gray(image).mean() * (1 - alpha)
        return image.clip(0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask


class SaturationTransform(VisionTransform):
    r"""Adjust saturation of the input data.

    Args:
        value: how much to adjust the saturation. Can be any
            non-negative number. 0 gives the original image.
        order: the same with :class:`VisionTransform`.
    """

    def __init__(self, value, *, order=None):
        super().__init__(order)
        if value < 0:
            raise ValueError("saturation value should be non-negative")
        self.value = value

    def _apply_image(self, image):
        if self.value == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.float32)
        # blend between the image and its grayscale version.
        alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
        image = image * alpha + F.to_gray(image) * (1 - alpha)
        return image.clip(0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask


class HueTransform(VisionTransform):
    r"""Adjust hue of the input data.

    Args:
        value: how much to adjust the hue. Can be any number
            between 0 and 0.5; 0 gives the original image.
        order: the same with :class:`VisionTransform`.
    """

    def __init__(self, value, *, order=None):
        super().__init__(order)
        if value < 0 or value > 0.5:
            raise ValueError("hue value should be in [0.0, 0.5]")
        self.value = value

    def _apply_image(self, image):
        if self.value == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.uint8)
        hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV_FULL)
        h, s, v = cv2.split(hsv_image)

        alpha = np.random.uniform(-self.value, self.value)
        h = h.astype(np.uint8)
        # uint8 addition takes care of rotation across boundaries:
        # hue is cyclic, so values past 255 wrap back through 0.
        with np.errstate(over="ignore"):
            h += np.uint8(alpha * 255)
        hsv_image = cv2.merge([h, s, v])
        return cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR_FULL).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask
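
# Example (sketch): uint8 wraparound is what implements the cyclic hue shift;
# adding past 255 rotates back through 0 instead of saturating.
#
# >>> h = np.array([200], dtype=np.uint8)
# >>> with np.errstate(over="ignore"):
# ...     h += np.uint8(100)
# >>> int(h[0])
# 44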


class ColorJitter(VisionTransform):
    r"""Randomly change the brightness, contrast, saturation and hue of an image.

    Args:
        brightness: how much to jitter brightness.
            Chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non-negative numbers.
        contrast: how much to jitter contrast.
            Chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non-negative numbers.
        saturation: how much to jitter saturation.
            Chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non-negative numbers.
        hue: how much to jitter hue.
            Chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
        order: the same with :class:`VisionTransform`.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, *, order=None):
        super().__init__(order)
        transforms = []
        if brightness != 0:
            transforms.append(BrightnessTransform(brightness))
        if contrast != 0:
            transforms.append(ContrastTransform(contrast))
        if saturation != 0:
            transforms.append(SaturationTransform(saturation))
        if hue != 0:
            transforms.append(HueTransform(hue))
        # apply the enabled sub-transforms in a random order per sample.
        self.transforms = Compose(
            transforms,
            shuffle_indices=[tuple(range(1, len(transforms) + 1))],
            order=order,
        )

    def apply(self, input):
        return self.transforms.apply(input)
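
# Example (sketch): ColorJitter delegates to a Compose whose sub-transforms run
# in a random order per sample; the image shape and dtype are preserved.
#
# >>> jitter = ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
# >>> img = np.random.randint(0, 256, size=(8, 8, 3), dtype=np.uint8)
# >>> jitter.apply(img).shape
# (8, 8, 3)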


class Lighting(VisionTransform):
    r"""Apply AlexNet-style "lighting" augmentation to the input data.
    Input images are assumed to have 'RGB' channel order.

    The degree of color jittering is randomly sampled via a normal distribution,
    with standard deviation given by the scale parameter.
    """

    def __init__(self, scale, *, order=None):
        super().__init__(order)
        if scale < 0:
            raise ValueError("lighting scale should be non-negative")
        self.scale = scale
        self.eigvec = np.array(
            [
                [-0.5836, -0.6948, 0.4203],
                [-0.5808, -0.0045, -0.8140],
                [-0.5675, 0.7192, 0.4009],
            ]
        )  # reverse the first dimension for BGR
        self.eigval = np.array([0.2175, 0.0188, 0.0045])

    def _apply_image(self, image):
        if self.scale == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.float32)
        # jitter along the PCA eigenvectors of pixel colors, scaled by the eigenvalues.
        alpha = np.random.normal(scale=self.scale * 255, size=3)
        image = image + self.eigvec.dot(alpha * self.eigval)
        return image.clip(0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask
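
# End-to-end usage sketch (assumes a MegEngine installation; the mean/std
# values are illustrative placeholders, not prescribed by this module):
#
# >>> trans = Compose([
# ...     RandomHorizontalFlip(),
# ...     RandomResizedCrop(224),
# ...     Normalize(mean=[103.530, 116.280, 123.675], std=[57.375, 57.120, 58.395]),
# ...     ToMode("CHW"),
# ... ])
# >>> img = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)
# >>> trans.apply(img).shape
# (3, 224, 224)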