You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

transform.py 36 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections.abc
  10. import math
  11. from typing import Sequence, Tuple
  12. import cv2
  13. import numpy as np
  14. from megengine.data.transform import Transform
  15. from megengine.data.transform.vision import functional as F
  16. __all__ = [
  17. "VisionTransform",
  18. "ToMode",
  19. "Compose",
  20. "TorchTransformCompose",
  21. "Pad",
  22. "Resize",
  23. "ShortestEdgeResize",
  24. "RandomResize",
  25. "RandomCrop",
  26. "RandomResizedCrop",
  27. "CenterCrop",
  28. "RandomHorizontalFlip",
  29. "RandomVerticalFlip",
  30. "Normalize",
  31. "GaussianNoise",
  32. "BrightnessTransform",
  33. "SaturationTransform",
  34. "ContrastTransform",
  35. "HueTransform",
  36. "ColorJitter",
  37. "Lighting",
  38. ]
  39. class VisionTransform(Transform):
  40. r"""Base class of all transforms used in computer vision.
  41. Calling logic: apply_batch() -> apply() -> _apply_image() and other _apply_*()
  42. method. If you want to implement a self-defined transform method for image,
  43. rewrite _apply_image method in subclass.
  44. Args:
  45. order: input type order. Input is a tuple containing different structures,
  46. order is used to specify the order of structures. For example, if your input
  47. is (image, boxes) type, then the ``order`` should be ("image", "boxes").
  48. Current available strings and data type are describe below:
  49. * "image": input image, with shape of `(H, W, C)`.
  50. * "coords": coordinates, with shape of `(N, 2)`.
  51. * "boxes": bounding boxes, with shape of `(N, 4)`, "xyxy" format,
  52. the 1st "xy" represents top left point of a box,
  53. the 2nd "xy" represents right bottom point.
  54. * "mask": map used for segmentation, with shape of `(H, W, 1)`.
  55. * "keypoints": keypoints with shape of `(N, K, 3)`, N for number of instances,
  56. and K for number of keypoints in one instance. The first two dimensions
  57. of last axis is coordinate of keypoints and the the 3rd dimension is
  58. the label of keypoints.
  59. * "polygons": a sequence containing numpy arrays, its length is the number of instances.
  60. Each numpy array represents polygon coordinate of one instance.
  61. * "category": categories for some data type. For example, "image_category"
  62. means category of the input image and "boxes_category" means categories of
  63. bounding boxes.
  64. * "info": information for images such as image shapes and image path.
  65. You can also customize your data types only if you implement the corresponding
  66. _apply_*() methods, otherwise ``NotImplementedError`` will be raised.
  67. """
  68. def __init__(self, order=None):
  69. super().__init__()
  70. if order is None:
  71. order = ("image",)
  72. elif not isinstance(order, collections.abc.Sequence):
  73. raise ValueError(
  74. "order should be a sequence, but got order={}".format(order)
  75. )
  76. for k in order:
  77. if k in ("batch",):
  78. raise ValueError("{} is invalid data type".format(k))
  79. elif k.endswith("category") or k.endswith("info"):
  80. # when the key is *category or info, we should do nothing
  81. # if the corresponding apply methods are not implemented.
  82. continue
  83. elif self._get_apply(k) is None:
  84. raise NotImplementedError("{} is unsupported data type".format(k))
  85. self.order = order
  86. def apply_batch(self, inputs: Sequence[Tuple]):
  87. r"""Apply transform on batch input data."""
  88. return tuple(self.apply(input) for input in inputs)
  89. def apply(self, input: Tuple):
  90. r"""Apply transform on single input data."""
  91. if not isinstance(input, tuple):
  92. input = (input,)
  93. output = []
  94. for i in range(min(len(input), len(self.order))):
  95. apply_func = self._get_apply(self.order[i])
  96. if apply_func is None:
  97. output.append(input[i])
  98. else:
  99. output.append(apply_func(input[i]))
  100. if len(input) > len(self.order):
  101. output.extend(input[len(self.order) :])
  102. if len(output) == 1:
  103. output = output[0]
  104. else:
  105. output = tuple(output)
  106. return output
  107. def _get_apply(self, key):
  108. return getattr(self, "_apply_{}".format(key), None)
  109. def _get_image(self, input: Tuple):
  110. if not isinstance(input, tuple):
  111. input = (input,)
  112. return input[self.order.index("image")]
  113. def _apply_image(self, image):
  114. raise NotImplementedError
  115. def _apply_coords(self, coords):
  116. raise NotImplementedError
  117. def _apply_boxes(self, boxes):
  118. idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
  119. coords = np.asarray(boxes).reshape(-1, 4)[:, idxs].reshape(-1, 2)
  120. coords = self._apply_coords(coords).reshape((-1, 4, 2))
  121. minxy = coords.min(axis=1)
  122. maxxy = coords.max(axis=1)
  123. trans_boxes = np.concatenate((minxy, maxxy), axis=1)
  124. return trans_boxes
  125. def _apply_mask(self, mask):
  126. raise NotImplementedError
  127. def _apply_keypoints(self, keypoints):
  128. coords, visibility = keypoints[..., :2], keypoints[..., 2:]
  129. trans_coords = [self._apply_coords(p) for p in coords]
  130. return np.concatenate((trans_coords, visibility), axis=-1)
  131. def _apply_polygons(self, polygons):
  132. return [[self._apply_coords(p) for p in instance] for instance in polygons]
  133. class ToMode(VisionTransform):
  134. r"""Change input data to a target mode.
  135. For example, most transforms use HWC mode image,
  136. while the neural network might use CHW mode input tensor.
  137. Args:
  138. mode: output mode of input. Default: "CHW"
  139. order: the same with :class:`VisionTransform`
  140. """
  141. def __init__(self, mode="CHW", *, order=None):
  142. super().__init__(order)
  143. assert mode in ["CHW"], "unsupported mode: {}".format(mode)
  144. self.mode = mode
  145. def _apply_image(self, image):
  146. if self.mode == "CHW":
  147. return np.ascontiguousarray(np.rollaxis(image, 2))
  148. return image
  149. def _apply_coords(self, coords):
  150. return coords
  151. def _apply_mask(self, mask):
  152. if self.mode == "CHW":
  153. return np.ascontiguousarray(np.rollaxis(mask, 2))
  154. return mask
  155. class Compose(VisionTransform):
  156. r"""Composes several transforms together.
  157. Args:
  158. transforms: list of :class:`VisionTransform` to compose.
  159. batch_compose: whether use shuffle_indices for batch data or not.
  160. If True, use original input sequence.
  161. Otherwise, the shuffle_indices will be used for transforms.
  162. shuffle_indices: indices used for random shuffle, start at 1.
  163. For example, if shuffle_indices is [(1, 3), (2, 4)], then the 1st and 3rd transform
  164. will be random shuffled, the 2nd and 4th transform will also be shuffled.
  165. order: the same with :class:`VisionTransform`
  166. Examples:
  167. .. testcode::
  168. from megengine.data.transform import RandomHorizontalFlip, RandomVerticalFlip, CenterCrop, ToMode, Compose
  169. transform_func = Compose([
  170. RandomHorizontalFlip(),
  171. RandomVerticalFlip(),
  172. CenterCrop(100),
  173. ToMode("CHW"),
  174. ],
  175. shuffle_indices=[(1, 2, 3)]
  176. )
  177. """
  178. def __init__(
  179. self, transforms=[], batch_compose=False, shuffle_indices=None, *, order=None
  180. ):
  181. super().__init__(order)
  182. self.transforms = transforms
  183. self._set_order()
  184. if batch_compose and shuffle_indices is not None:
  185. raise ValueError(
  186. "Do not support shuffle when apply transforms along the whole batch"
  187. )
  188. self.batch_compose = batch_compose
  189. if shuffle_indices is not None:
  190. shuffle_indices = [tuple(x - 1 for x in idx) for idx in shuffle_indices]
  191. self.shuffle_indices = shuffle_indices
  192. def _set_order(self):
  193. for t in self.transforms:
  194. t.order = self.order
  195. if isinstance(t, Compose):
  196. t._set_order()
  197. def apply_batch(self, inputs: Sequence[Tuple]):
  198. if self.batch_compose:
  199. for t in self.transforms:
  200. inputs = t.apply_batch(inputs)
  201. return inputs
  202. else:
  203. return super().apply_batch(inputs)
  204. def apply(self, input: Tuple):
  205. for t in self._shuffle():
  206. input = t.apply(input)
  207. return input
  208. def _shuffle(self):
  209. if self.shuffle_indices is not None:
  210. source_idx = list(range(len(self.transforms)))
  211. for idx in self.shuffle_indices:
  212. shuffled = np.random.permutation(idx).tolist()
  213. for src, dst in zip(idx, shuffled):
  214. source_idx[src] = dst
  215. return [self.transforms[i] for i in source_idx]
  216. else:
  217. return self.transforms
  218. class TorchTransformCompose(VisionTransform):
  219. r"""Compose class used for transforms in torchvision, only support PIL image,
  220. some transforms with tensor in torchvision are not supported,
  221. such as Normalize and ToTensor in torchvision.
  222. Args:
  223. transforms: the same with ``Compose``.
  224. order: the same with :class:`VisionTransform`.
  225. """
  226. def __init__(self, transforms, *, order=None):
  227. super().__init__(order)
  228. self.transforms = transforms
  229. def _apply_image(self, image):
  230. from PIL import Image
  231. try:
  232. import accimage
  233. except ImportError:
  234. accimage = None
  235. if image.shape[0] == 3: # CHW
  236. image = np.ascontiguousarray(image[[2, 1, 0]])
  237. elif image.shape[2] == 3: # HWC
  238. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  239. image = Image.fromarray(image.astype(np.uint8))
  240. for t in self.transforms:
  241. image = t(image)
  242. if isinstance(image, Image.Image) or (
  243. accimage is not None and isinstance(image, accimage.Image)
  244. ):
  245. image = np.array(image, dtype=np.uint8)
  246. if image.shape[0] == 3: # CHW
  247. image = np.ascontiguousarray(image[[2, 1, 0]])
  248. elif image.shape[2] == 3: # HWC
  249. image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  250. return image
  251. class Pad(VisionTransform):
  252. r"""Pad the input data.
  253. Args:
  254. size: padding size of input image, it could be integer or sequence.
  255. If it is an integer, the input image will be padded in four directions.
  256. If it is a sequence containing two integers, the bottom and right side
  257. of image will be padded.
  258. If it is a sequence containing four integers, the top, bottom, left, right
  259. side of image will be padded with given size.
  260. value: padding value of image, could be a sequence of int or float.
  261. if it is float value, the dtype of image will be casted to float32 also.
  262. mask_value: padding value of segmentation map.
  263. order: the same with :class:`VisionTransform`.
  264. """
  265. def __init__(self, size=0, value=0, mask_value=0, *, order=None):
  266. super().__init__(order)
  267. if isinstance(size, int):
  268. size = (size, size, size, size)
  269. elif isinstance(size, collections.abc.Sequence) and len(size) == 2:
  270. size = (0, size[0], 0, size[1])
  271. elif not (isinstance(size, collections.abc.Sequence) and len(size) == 4):
  272. raise ValueError(
  273. "size should be a list/tuple which contains "
  274. "(top, down, left, right) four pad sizes."
  275. )
  276. self.size = size
  277. self.value = value
  278. if not isinstance(mask_value, int):
  279. raise ValueError(
  280. "mask_value should be a positive integer, "
  281. "but got mask_value={}".format(mask_value)
  282. )
  283. self.mask_value = mask_value
  284. def _apply_image(self, image):
  285. return F.pad(image, self.size, self.value)
  286. def _apply_coords(self, coords):
  287. coords[:, 0] += self.size[2]
  288. coords[:, 1] += self.size[0]
  289. return coords
  290. def _apply_mask(self, mask):
  291. return F.pad(mask, self.size, self.mask_value)
  292. class Resize(VisionTransform):
  293. r"""Resize the input data.
  294. Args:
  295. output_size: target size of image, with (height, width) shape.
  296. interpolation: interpolation method. All methods are listed below:
  297. * cv2.INTER_NEAREST – a nearest-neighbor interpolation.
  298. * cv2.INTER_LINEAR – a bilinear interpolation (used by default).
  299. * cv2.INTER_AREA – resampling using pixel area relation.
  300. * cv2.INTER_CUBIC – a bicubic interpolation over 4×4 pixel neighborhood.
  301. * cv2.INTER_LANCZOS4 – a Lanczos interpolation over 8×8 pixel neighborhood.
  302. order: the same with :class:`VisionTransform`.
  303. """
  304. def __init__(self, output_size, interpolation=cv2.INTER_LINEAR, *, order=None):
  305. super().__init__(order)
  306. self.output_size = output_size
  307. self.interpolation = interpolation
  308. def apply(self, input: Tuple):
  309. self._shape_info = self._get_shape(self._get_image(input))
  310. return super().apply(input)
  311. def _apply_image(self, image):
  312. h, w, th, tw = self._shape_info
  313. if h == th and w == tw:
  314. return image
  315. return F.resize(image, (th, tw), self.interpolation)
  316. def _apply_coords(self, coords):
  317. h, w, th, tw = self._shape_info
  318. if h == th and w == tw:
  319. return coords
  320. coords[:, 0] = coords[:, 0] * (tw / w)
  321. coords[:, 1] = coords[:, 1] * (th / h)
  322. return coords
  323. def _apply_mask(self, mask):
  324. h, w, th, tw = self._shape_info
  325. if h == th and w == tw:
  326. return mask
  327. return F.resize(mask, (th, tw), cv2.INTER_NEAREST)
  328. def _get_shape(self, image):
  329. h, w, _ = image.shape
  330. if isinstance(self.output_size, int):
  331. if min(h, w) == self.output_size:
  332. return h, w, h, w
  333. if h < w:
  334. th = self.output_size
  335. tw = int(self.output_size * w / h)
  336. else:
  337. tw = self.output_size
  338. th = int(self.output_size * h / w)
  339. return h, w, th, tw
  340. else:
  341. return (h, w, *self.output_size)
  342. class ShortestEdgeResize(VisionTransform):
  343. r"""Resize the input data with specified shortset edge."""
  344. def __init__(
  345. self,
  346. min_size,
  347. max_size,
  348. sample_style="range",
  349. interpolation=cv2.INTER_LINEAR,
  350. *,
  351. order=None
  352. ):
  353. super().__init__(order)
  354. if sample_style not in ("range", "choice"):
  355. raise NotImplementedError(
  356. "{} is unsupported sample style".format(sample_style)
  357. )
  358. self.sample_style = sample_style
  359. if isinstance(min_size, int):
  360. min_size = (min_size, min_size)
  361. self.min_size = min_size
  362. self.max_size = max_size
  363. self.interpolation = interpolation
  364. def apply(self, input: Tuple):
  365. self._shape_info = self._get_shape(self._get_image(input))
  366. return super().apply(input)
  367. def _apply_image(self, image):
  368. h, w, th, tw = self._shape_info
  369. if h == th and w == tw:
  370. return image
  371. return F.resize(image, (th, tw), self.interpolation)
  372. def _apply_coords(self, coords):
  373. h, w, th, tw = self._shape_info
  374. if h == th and w == tw:
  375. return coords
  376. coords[:, 0] = coords[:, 0] * (tw / w)
  377. coords[:, 1] = coords[:, 1] * (th / h)
  378. return coords
  379. def _apply_mask(self, mask):
  380. h, w, th, tw = self._shape_info
  381. if h == th and w == tw:
  382. return mask
  383. return F.resize(mask, (th, tw), cv2.INTER_NEAREST)
  384. def _get_shape(self, image):
  385. h, w, _ = image.shape
  386. if self.sample_style == "range":
  387. size = np.random.randint(self.min_size[0], self.min_size[1] + 1)
  388. else:
  389. size = np.random.choice(self.min_size)
  390. scale = size / min(h, w)
  391. if h < w:
  392. th, tw = size, scale * w
  393. else:
  394. th, tw = scale * h, size
  395. if max(th, tw) > self.max_size:
  396. scale = self.max_size / max(th, tw)
  397. th = th * scale
  398. tw = tw * scale
  399. th = int(round(th))
  400. tw = int(round(tw))
  401. return h, w, th, tw
  402. class RandomResize(VisionTransform):
  403. r"""Resize the input data randomly.
  404. Args:
  405. scale_range: range of scaling.
  406. order: the same with :class:`VisionTransform`.
  407. """
  408. def __init__(self, scale_range, interpolation=cv2.INTER_LINEAR, *, order=None):
  409. super().__init__(order)
  410. self.scale_range = scale_range
  411. self.interpolation = interpolation
  412. def apply(self, input: Tuple):
  413. self._shape_info = self._get_shape(self._get_image(input))
  414. return super().apply(input)
  415. def _apply_image(self, image):
  416. h, w, th, tw = self._shape_info
  417. if h == th and w == tw:
  418. return image
  419. return F.resize(image, (th, tw), self.interpolation)
  420. def _apply_coords(self, coords):
  421. h, w, th, tw = self._shape_info
  422. if h == th and w == tw:
  423. return coords
  424. coords[:, 0] = coords[:, 0] * (tw / w)
  425. coords[:, 1] = coords[:, 1] * (th / h)
  426. return coords
  427. def _apply_mask(self, mask):
  428. h, w, th, tw = self._shape_info
  429. if h == th and w == tw:
  430. return mask
  431. return F.resize(mask, (th, tw), cv2.INTER_NEAREST)
  432. def _get_shape(self, image):
  433. h, w, _ = image.shape
  434. scale = np.random.uniform(*self.scale_range)
  435. th = int(round(h * scale))
  436. tw = int(round(w * scale))
  437. return h, w, th, tw
  438. class RandomCrop(VisionTransform):
  439. r"""Crop the input data randomly. Before applying the crop transform,
  440. pad the image first. If target size is still bigger than the size of
  441. padded image, pad the image size to target size.
  442. Args:
  443. output_size: target size of output image, with (height, width) shape.
  444. padding_size: the same with `size` in ``Pad``.
  445. padding_value: the same with `value` in ``Pad``.
  446. order: the same with :class:`VisionTransform`.
  447. """
  448. def __init__(
  449. self,
  450. output_size,
  451. padding_size=0,
  452. padding_value=[0, 0, 0],
  453. padding_maskvalue=0,
  454. *,
  455. order=None
  456. ):
  457. super().__init__(order)
  458. if isinstance(output_size, int):
  459. self.output_size = (output_size, output_size)
  460. else:
  461. self.output_size = output_size
  462. self.pad = Pad(padding_size, padding_value, order=self.order)
  463. self.padding_value = padding_value
  464. self.padding_maskvalue = padding_maskvalue
  465. def apply(self, input):
  466. input = self.pad.apply(input)
  467. self._h, self._w, _ = self._get_image(input).shape
  468. self._th, self._tw = self.output_size
  469. self._x = np.random.randint(0, max(0, self._w - self._tw) + 1)
  470. self._y = np.random.randint(0, max(0, self._h - self._th) + 1)
  471. return super().apply(input)
  472. def _apply_image(self, image):
  473. if self._th > self._h:
  474. image = F.pad(image, (self._th - self._h, 0), self.padding_value)
  475. if self._tw > self._w:
  476. image = F.pad(image, (0, self._tw - self._w), self.padding_value)
  477. return image[self._y : self._y + self._th, self._x : self._x + self._tw]
  478. def _apply_coords(self, coords):
  479. coords[:, 0] -= self._x
  480. coords[:, 1] -= self._y
  481. return coords
  482. def _apply_mask(self, mask):
  483. if self._th > self._h:
  484. mask = F.pad(mask, (self._th - self._h, 0), self.padding_maskvalue)
  485. if self._tw > self._w:
  486. mask = F.pad(mask, (0, self._tw - self._w), self.padding_maskvalue)
  487. return mask[self._y : self._y + self._th, self._x : self._x + self._tw]
  488. class RandomResizedCrop(VisionTransform):
  489. r"""Crop the input data to random size and aspect ratio.
  490. A crop of random size (default: of 0.08 to 1.0) of the original size and a random
  491. aspect ratio (default: of 3/4 to 1.33) of the original aspect ratio is made.
  492. After applying crop transfrom, the input data will be resized to given size.
  493. Args:
  494. output_size: target size of output image, with (height, width) shape.
  495. scale_range: range of size of the origin size cropped. Default: (0.08, 1.0)
  496. ratio_range: range of aspect ratio of the origin aspect ratio cropped. Default: (0.75, 1.33)
  497. order: the same with :class:`VisionTransform`.
  498. """
  499. def __init__(
  500. self,
  501. output_size,
  502. scale_range=(0.08, 1.0),
  503. ratio_range=(3.0 / 4, 4.0 / 3),
  504. interpolation=cv2.INTER_LINEAR,
  505. *,
  506. order=None
  507. ):
  508. super().__init__(order)
  509. if isinstance(output_size, int):
  510. self.output_size = (output_size, output_size)
  511. else:
  512. self.output_size = output_size
  513. assert (
  514. scale_range[0] <= scale_range[1]
  515. ), "scale_range should be of kind (min, max)"
  516. assert (
  517. ratio_range[0] <= ratio_range[1]
  518. ), "ratio_range should be of kind (min, max)"
  519. self.scale_range = scale_range
  520. self.ratio_range = ratio_range
  521. self.interpolation = interpolation
  522. def apply(self, input: Tuple):
  523. self._coord_info = self._get_coord(self._get_image(input))
  524. return super().apply(input)
  525. def _apply_image(self, image):
  526. x, y, w, h = self._coord_info
  527. cropped_img = image[y : y + h, x : x + w]
  528. return F.resize(cropped_img, self.output_size, self.interpolation)
  529. def _apply_coords(self, coords):
  530. x, y, w, h = self._coord_info
  531. coords[:, 0] = (coords[:, 0] - x) * self.output_size[1] / w
  532. coords[:, 1] = (coords[:, 1] - y) * self.output_size[0] / h
  533. return coords
  534. def _apply_mask(self, mask):
  535. x, y, w, h = self._coord_info
  536. cropped_mask = mask[y : y + h, x : x + w]
  537. return F.resize(cropped_mask, self.output_size, cv2.INTER_NEAREST)
  538. def _get_coord(self, image, attempts=10):
  539. height, width, _ = image.shape
  540. area = height * width
  541. for _ in range(attempts):
  542. target_area = np.random.uniform(*self.scale_range) * area
  543. log_ratio = tuple(math.log(x) for x in self.ratio_range)
  544. aspect_ratio = math.exp(np.random.uniform(*log_ratio))
  545. w = int(round(math.sqrt(target_area * aspect_ratio)))
  546. h = int(round(math.sqrt(target_area / aspect_ratio)))
  547. if 0 < w <= width and 0 < h <= height:
  548. x = np.random.randint(0, width - w + 1)
  549. y = np.random.randint(0, height - h + 1)
  550. return x, y, w, h
  551. # Fallback to central crop
  552. in_ratio = float(width) / float(height)
  553. if in_ratio < min(self.ratio_range):
  554. w = width
  555. h = int(round(w / min(self.ratio_range)))
  556. elif in_ratio > max(self.ratio_range):
  557. h = height
  558. w = int(round(h * max(self.ratio_range)))
  559. else: # whole image
  560. w = width
  561. h = height
  562. x = (width - w) // 2
  563. y = (height - h) // 2
  564. return x, y, w, h
  565. class CenterCrop(VisionTransform):
  566. r"""Crops the given the input data at the center.
  567. Args:
  568. output_size: target size of output image, with (height, width) shape.
  569. order: the same with :class:`VisionTransform`.
  570. """
  571. def __init__(self, output_size, *, order=None):
  572. super().__init__(order)
  573. if isinstance(output_size, int):
  574. self.output_size = (output_size, output_size)
  575. else:
  576. self.output_size = output_size
  577. def apply(self, input: Tuple):
  578. self._coord_info = self._get_coord(self._get_image(input))
  579. return super().apply(input)
  580. def _apply_image(self, image):
  581. x, y = self._coord_info
  582. th, tw = self.output_size
  583. return image[y : y + th, x : x + tw]
  584. def _apply_coords(self, coords):
  585. x, y = self._coord_info
  586. coords[:, 0] -= x
  587. coords[:, 1] -= y
  588. return coords
  589. def _apply_mask(self, mask):
  590. x, y = self._coord_info
  591. th, tw = self.output_size
  592. return mask[y : y + th, x : x + tw]
  593. def _get_coord(self, image):
  594. th, tw = self.output_size
  595. h, w, _ = image.shape
  596. assert th <= h and tw <= w, "output size is bigger than image size"
  597. x = int(round((w - tw) / 2.0))
  598. y = int(round((h - th) / 2.0))
  599. return x, y
  600. class RandomHorizontalFlip(VisionTransform):
  601. r"""Horizontally flip the input data randomly with a given probability.
  602. Args:
  603. p: probability of the input data being flipped. Default: 0.5
  604. order: the same with :class:`VisionTransform`.
  605. """
  606. def __init__(self, prob: float = 0.5, *, order=None):
  607. super().__init__(order)
  608. self.prob = prob
  609. def apply(self, input: Tuple):
  610. self._flipped = np.random.random() < self.prob
  611. self._w = self._get_image(input).shape[1]
  612. return super().apply(input)
  613. def _apply_image(self, image):
  614. if self._flipped:
  615. return F.flip(image, flipCode=1)
  616. return image
  617. def _apply_coords(self, coords):
  618. if self._flipped:
  619. coords[:, 0] = self._w - coords[:, 0]
  620. return coords
  621. def _apply_mask(self, mask):
  622. if self._flipped:
  623. return F.flip(mask, flipCode=1)
  624. return mask
  625. class RandomVerticalFlip(VisionTransform):
  626. r"""Vertically flip the input data randomly with a given probability.
  627. Args:
  628. p: probability of the input data being flipped. Default: 0.5
  629. order: the same with :class:`VisionTransform`.
  630. """
  631. def __init__(self, prob: float = 0.5, *, order=None):
  632. super().__init__(order)
  633. self.prob = prob
  634. def apply(self, input: Tuple):
  635. self._flipped = np.random.random() < self.prob
  636. self._h = self._get_image(input).shape[0]
  637. return super().apply(input)
  638. def _apply_image(self, image):
  639. if self._flipped:
  640. return F.flip(image, flipCode=0)
  641. return image
  642. def _apply_coords(self, coords):
  643. if self._flipped:
  644. coords[:, 1] = self._h - coords[:, 1]
  645. return coords
  646. def _apply_mask(self, mask):
  647. if self._flipped:
  648. return F.flip(mask, flipCode=0)
  649. return mask
  650. class Normalize(VisionTransform):
  651. r"""Normalize the input data with mean and standard deviation.
  652. Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
  653. this transform will normalize each channel of the input data.
  654. ``output[channel] = (input[channel] - mean[channel]) / std[channel]``
  655. Args:
  656. mean: sequence of means for each channel.
  657. std: sequence of standard deviations for each channel.
  658. order: the same with :class:`VisionTransform`.
  659. """
  660. def __init__(self, mean=0.0, std=1.0, *, order=None):
  661. super().__init__(order)
  662. self.mean = np.array(mean, dtype=np.float32)
  663. self.std = np.array(std, dtype=np.float32)
  664. def _apply_image(self, image):
  665. return (image - self.mean) / self.std
  666. def _apply_coords(self, coords):
  667. return coords
  668. def _apply_mask(self, mask):
  669. return mask
  670. class GaussianNoise(VisionTransform):
  671. r"""Add random gaussian noise to the input data.
  672. Gaussian noise is generated with given mean and std.
  673. Args:
  674. mean: Gaussian mean used to generate noise.
  675. std: Gaussian standard deviation used to generate noise.
  676. order: the same with :class:`VisionTransform`
  677. """
  678. def __init__(self, mean=0.0, std=1.0, *, order=None):
  679. super().__init__(order)
  680. self.mean = np.array(mean, dtype=np.float32)
  681. self.std = np.array(std, dtype=np.float32)
  682. def _apply_image(self, image):
  683. dtype = image.dtype
  684. noise = np.random.normal(self.mean, self.std, image.shape) * 255
  685. image = image + noise.astype(np.float32)
  686. return np.clip(image, 0, 255).astype(dtype)
  687. def _apply_coords(self, coords):
  688. return coords
  689. def _apply_mask(self, mask):
  690. return mask
  class BrightnessTransform(VisionTransform):
      r"""Adjust brightness of the input data.

      A single scale factor is drawn uniformly from
      ``[max(0, 1 - value), 1 + value]`` and multiplied into every element.
      The result is clipped to ``[0, 255]`` and cast back to the input dtype.

      Args:
          value: how much to adjust the brightness. Can be any
              non negative number. 0 gives the original image.
          order: the same with :class:`VisionTransform`.

      Raises:
          ValueError: if ``value`` is negative.
      """

      def __init__(self, value, *, order=None):
          super().__init__(order)
          if value < 0:
              raise ValueError("brightness value should be non-negative")
          self.value = value

      def _apply_image(self, image):
          if not self.value:
              return image
          as_float = image.astype(np.float32)
          factor = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
          return (as_float * factor).clip(0, 255).astype(image.dtype)

      def _apply_coords(self, coords):
          # brightness changes never move coordinates
          return coords

      def _apply_mask(self, mask):
          # the mask is returned unchanged
          return mask
  class ContrastTransform(VisionTransform):
      r"""Adjust contrast of the input data.

      A factor is drawn uniformly from ``[max(0, 1 - value), 1 + value]``.
      The image is blended with the scalar mean of ``F.to_gray(image)``:
      ``image * factor + gray_mean * (1 - factor)``. The result is clipped to
      ``[0, 255]`` and cast back to the input dtype.

      Args:
          value: how much to adjust the contrast. Can be any
              non negative number. 0 gives the original image.
          order: the same with :class:`VisionTransform`.

      Raises:
          ValueError: if ``value`` is negative.
      """

      def __init__(self, value, *, order=None):
          super().__init__(order)
          if value < 0:
              raise ValueError("contrast value should be non-negative")
          self.value = value

      def _apply_image(self, image):
          if not self.value:
              return image
          as_float = image.astype(np.float32)
          factor = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
          scaled = as_float * factor
          gray_mean = F.to_gray(as_float).mean()
          blended = scaled + gray_mean * (1 - factor)
          return blended.clip(0, 255).astype(image.dtype)

      def _apply_coords(self, coords):
          # contrast changes never move coordinates
          return coords

      def _apply_mask(self, mask):
          # the mask is returned unchanged
          return mask
  class SaturationTransform(VisionTransform):
      r"""Adjust saturation of the input data.

      A factor is drawn uniformly from ``[max(0, 1 - value), 1 + value]``.
      The image is blended element-wise with its gray version:
      ``image * factor + F.to_gray(image) * (1 - factor)``. The result is
      clipped to ``[0, 255]`` and cast back to the input dtype.

      Args:
          value: how much to adjust the saturation. Can be any
              non negative number. 0 gives the original image.
          order: the same with :class:`VisionTransform`.

      Raises:
          ValueError: if ``value`` is negative.
      """

      def __init__(self, value, *, order=None):
          super().__init__(order)
          if value < 0:
              raise ValueError("saturation value should be non-negative")
          self.value = value

      def _apply_image(self, image):
          if not self.value:
              return image
          as_float = image.astype(np.float32)
          factor = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
          scaled = as_float * factor
          blended = scaled + F.to_gray(as_float) * (1 - factor)
          return blended.clip(0, 255).astype(image.dtype)

      def _apply_coords(self, coords):
          # saturation changes never move coordinates
          return coords

      def _apply_mask(self, mask):
          # the mask is returned unchanged
          return mask
  class HueTransform(VisionTransform):
      r"""Adjust hue of the input data.

      The image is cast to ``uint8`` and converted with OpenCV's
      ``COLOR_BGR2HSV_FULL`` code, so it is expected in BGR channel order.
      The hue channel is rotated by ``int(alpha * 255)`` steps, where
      ``alpha`` is drawn uniformly from ``[-value, value]``. The image is then
      converted back to BGR and cast to the input dtype.

      Args:
          value: how much to adjust the hue. Can be any number
              between 0 and 0.5, 0 gives the original image.
          order: the same with :class:`VisionTransform`.

      Raises:
          ValueError: if ``value`` is outside ``[0.0, 0.5]`` (NaN included).
      """

      def __init__(self, value, *, order=None):
          super().__init__(order)
          # The chained comparison also rejects NaN, which
          # ``value < 0 or value > 0.5`` would have let through.
          if not 0 <= value <= 0.5:
              raise ValueError("hue value should be in [0.0, 0.5]")
          self.value = value

      def _apply_image(self, image):
          if self.value == 0:
              return image
          dtype = image.dtype
          # cvtColor is fed 8-bit data, so its output (and h below) is uint8
          image = image.astype(np.uint8)
          hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV_FULL)
          h, s, v = cv2.split(hsv_image)
          alpha = np.random.uniform(-self.value, self.value)
          # Reduce the (possibly negative) shift modulo 256 with Python ints
          # before making it a uint8. Casting a negative float straight to
          # np.uint8 is undefined in C: it wraps on x86, can clamp to 0
          # elsewhere, and recent NumPy emits a RuntimeWarning. int()
          # truncates toward zero, matching what the x86 cast produced.
          shift = np.uint8(int(alpha * 255) % 256)
          # uint8 array addition wraps modulo 256, rotating hue across the
          # 0/255 boundary
          with np.errstate(over="ignore"):
              h += shift
          hsv_image = cv2.merge([h, s, v])
          return cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR_FULL).astype(dtype)

      def _apply_coords(self, coords):
          # hue changes never move coordinates
          return coords

      def _apply_mask(self, mask):
          # the mask is returned unchanged
          return mask
  class ColorJitter(VisionTransform):
      r"""Randomly change the brightness, contrast, saturation and hue of an image.

      Each non-zero argument enables the matching transform
      (:class:`BrightnessTransform`, :class:`ContrastTransform`,
      :class:`SaturationTransform`, :class:`HueTransform`); a zero argument
      skips it. The enabled transforms are wrapped in a :class:`Compose`
      whose ``shuffle_indices`` cover all of them. This presumably makes
      their application order random; see :class:`Compose` for the exact
      semantics.

      Every argument must be a single number. ``[min, max]`` ranges are not
      supported, and passing a list or tuple raises ``TypeError``.

      Args:
          brightness: how much to jitter brightness. The factor is chosen
              uniformly from ``[max(0, 1 - brightness), 1 + brightness]``.
              Should be non negative.
          contrast: how much to jitter contrast. The factor is chosen
              uniformly from ``[max(0, 1 - contrast), 1 + contrast]``.
              Should be non negative.
          saturation: how much to jitter saturation. The factor is chosen
              uniformly from ``[max(0, 1 - saturation), 1 + saturation]``.
              Should be non negative.
          hue: how much to jitter hue. The shift is chosen uniformly from
              ``[-hue, hue]``. Should satisfy ``0 <= hue <= 0.5``.
          order: the same with :class:`VisionTransform`.

      Raises:
          TypeError: if any argument is a list or tuple.
          ValueError: propagated from the sub-transform when a value is out
              of its accepted range.
      """

      def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, *, order=None):
          super().__init__(order)
          candidates = (
              ("brightness", brightness, BrightnessTransform),
              ("contrast", contrast, ContrastTransform),
              ("saturation", saturation, SaturationTransform),
              ("hue", hue, HueTransform),
          )
          transforms = []
          for name, value, transform_cls in candidates:
              # Fail early with a clear message: the sub-transforms would
              # otherwise die on ``list < 0`` with an opaque TypeError.
              if isinstance(value, (list, tuple)):
                  raise TypeError(
                      "{} should be a single number, got {!r}".format(name, value)
                  )
              if value != 0:
                  transforms.append(transform_cls(value))
          self.transforms = Compose(
              transforms,
              shuffle_indices=[tuple(range(1, len(transforms) + 1))],
              order=order,
          )

      def apply(self, input):
          return self.transforms.apply(input)
  class Lighting(VisionTransform):
      r"""Apply AlexNet-Style "lighting" augmentation to input data.

      Input images are assumed to have 'BGR' channel order. The PCA
      eigenvector matrix below has its first dimension reversed relative to
      the usual RGB listing, so row ``i`` applies to channel ``i`` of a BGR
      image.

      The offset ``eigvec.dot(alpha * eigval)`` is added to every pixel,
      where ``alpha`` holds three samples from a normal distribution with
      standard deviation ``scale * 255``. The 255 factor presumably accounts
      for the eigenvalues being computed on [0, 1]-scaled pixels. The result
      is clipped to ``[0, 255]`` and cast back to the input dtype.

      Args:
          scale: standard deviation of the lighting noise, as a fraction of
              255. Must be non-negative; 0 returns the image unchanged.
          order: the same with :class:`VisionTransform`.

      Raises:
          ValueError: if ``scale`` is negative.
      """

      def __init__(self, scale, *, order=None):
          super().__init__(order)
          if scale < 0:
              raise ValueError("lighting scale should be non-negative")
          self.scale = scale
          self.eigvec = np.array(
              [
                  [-0.5836, -0.6948, 0.4203],
                  [-0.5808, -0.0045, -0.8140],
                  [-0.5675, 0.7192, 0.4009],
              ]
          )  # reverse the first dimension for BGR
          self.eigval = np.array([0.2175, 0.0188, 0.0045])

      def _apply_image(self, image):
          if self.scale == 0:
              return image
          dtype = image.dtype
          image = image.astype(np.float32)
          alpha = np.random.normal(scale=self.scale * 255, size=3)
          # one offset per channel, broadcast over the last axis of the image
          image = image + self.eigvec.dot(alpha * self.eigval)
          return image.clip(0, 255).astype(dtype)

      def _apply_coords(self, coords):
          # lighting changes never move coordinates
          return coords

      def _apply_mask(self, mask):
          # the mask is returned unchanged
          return mask