You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transform.py 37 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections.abc
  10. import math
  11. from typing import List, Sequence, Tuple
  12. import cv2
  13. import numpy as np
  14. from megengine.data.transform import Transform
  15. from megengine.data.transform.vision import functional as F
  16. __all__ = [
  17. "VisionTransform",
  18. "ToMode",
  19. "Compose",
  20. "TorchTransformCompose",
  21. "Pad",
  22. "Resize",
  23. "ShortestEdgeResize",
  24. "RandomResize",
  25. "RandomCrop",
  26. "RandomResizedCrop",
  27. "CenterCrop",
  28. "RandomHorizontalFlip",
  29. "RandomVerticalFlip",
  30. "Normalize",
  31. "GaussianNoise",
  32. "BrightnessTransform",
  33. "SaturationTransform",
  34. "ContrastTransform",
  35. "HueTransform",
  36. "ColorJitter",
  37. "Lighting",
  38. ]
  39. class VisionTransform(Transform):
  40. r"""Base class of all transforms used in computer vision.
  41. Calling logic: apply_batch() -> apply() -> _apply_image() and other _apply_*()
  42. method. If you want to implement a self-defined transform method for image,
  43. rewrite _apply_image method in subclass.
  44. Args:
  45. order: input type order. Input is a tuple containing different structures,
  46. order is used to specify the order of structures. For example, if your input
  47. is (image, boxes) type, then the ``order`` should be ("image", "boxes").
  48. Current available strings and data type are describe below:
  49. * "image": input image, with shape of `(H, W, C)`.
  50. * "coords": coordinates, with shape of `(N, 2)`.
  51. * "boxes": bounding boxes, with shape of `(N, 4)`, "xyxy" format,
  52. the 1st "xy" represents top left point of a box,
  53. the 2nd "xy" represents right bottom point.
  54. * "mask": map used for segmentation, with shape of `(H, W, 1)`.
  55. * "keypoints": keypoints with shape of `(N, K, 3)`, N for number of instances,
  56. and K for number of keypoints in one instance. The first two dimensions
  57. of last axis is coordinate of keypoints and the the 3rd dimension is
  58. the label of keypoints.
  59. * "polygons": a sequence containing numpy arrays, its length is the number of instances.
  60. Each numpy array represents polygon coordinate of one instance.
  61. * "category": categories for some data type. For example, "image_category"
  62. means category of the input image and "boxes_category" means categories of
  63. bounding boxes.
  64. * "info": information for images such as image shapes and image path.
  65. You can also customize your data types only if you implement the corresponding
  66. _apply_*() methods, otherwise ``NotImplementedError`` will be raised.
  67. """
  68. def __init__(self, order=None):
  69. super().__init__()
  70. if order is None:
  71. order = ("image",)
  72. elif not isinstance(order, collections.abc.Sequence):
  73. raise ValueError(
  74. "order should be a sequence, but got order={}".format(order)
  75. )
  76. for k in order:
  77. if k in ("batch",):
  78. raise ValueError("{} is invalid data type".format(k))
  79. elif k.endswith("category") or k.endswith("info"):
  80. # when the key is *category or info, we should do nothing
  81. # if the corresponding apply methods are not implemented.
  82. continue
  83. elif self._get_apply(k) is None:
  84. raise NotImplementedError("{} is unsupported data type".format(k))
  85. self.order = order
  86. def apply_batch(self, inputs: Sequence[Tuple]):
  87. r"""Apply transform on batch input data."""
  88. return tuple(self.apply(input) for input in inputs)
  89. def apply(self, input: Tuple):
  90. r"""Apply transform on single input data."""
  91. if not isinstance(input, tuple):
  92. input = (input,)
  93. output = []
  94. for i in range(min(len(input), len(self.order))):
  95. apply_func = self._get_apply(self.order[i])
  96. if apply_func is None:
  97. output.append(input[i])
  98. else:
  99. output.append(apply_func(input[i]))
  100. if len(input) > len(self.order):
  101. output.extend(input[len(self.order) :])
  102. if len(output) == 1:
  103. output = output[0]
  104. else:
  105. output = tuple(output)
  106. return output
  107. def _get_apply(self, key):
  108. return getattr(self, "_apply_{}".format(key), None)
  109. def _get_image(self, input: Tuple):
  110. if not isinstance(input, tuple):
  111. input = (input,)
  112. return input[self.order.index("image")]
  113. def _apply_image(self, image):
  114. raise NotImplementedError
  115. def _apply_coords(self, coords):
  116. raise NotImplementedError
  117. def _apply_boxes(self, boxes):
  118. idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
  119. coords = np.asarray(boxes).reshape(-1, 4)[:, idxs].reshape(-1, 2)
  120. coords = self._apply_coords(coords).reshape((-1, 4, 2))
  121. minxy = coords.min(axis=1)
  122. maxxy = coords.max(axis=1)
  123. trans_boxes = np.concatenate((minxy, maxxy), axis=1)
  124. return trans_boxes
  125. def _apply_mask(self, mask):
  126. raise NotImplementedError
  127. def _apply_keypoints(self, keypoints):
  128. coords, visibility = keypoints[..., :2], keypoints[..., 2:]
  129. trans_coords = [self._apply_coords(p) for p in coords]
  130. return np.concatenate((trans_coords, visibility), axis=-1)
  131. def _apply_polygons(self, polygons):
  132. return [[self._apply_coords(p) for p in instance] for instance in polygons]
  133. class ToMode(VisionTransform):
  134. r"""Change input data to a target mode.
  135. For example, most transforms use HWC mode image,
  136. while the neural network might use CHW mode input tensor.
  137. Args:
  138. mode: output mode of input. Default: "CHW"
  139. order: the same with :class:`VisionTransform`
  140. """
  141. def __init__(self, mode="CHW", *, order=None):
  142. super().__init__(order)
  143. assert mode in ["CHW"], "unsupported mode: {}".format(mode)
  144. self.mode = mode
  145. def _apply_image(self, image):
  146. if self.mode == "CHW":
  147. return np.ascontiguousarray(np.rollaxis(image, 2))
  148. return image
  149. def _apply_coords(self, coords):
  150. return coords
  151. def _apply_mask(self, mask):
  152. if self.mode == "CHW":
  153. return np.ascontiguousarray(np.rollaxis(mask, 2))
  154. return mask
  155. class Compose(VisionTransform):
  156. r"""Composes several transfomations together.
  157. Args:
  158. transforms: list of :class:`VisionTransform` to compose.
  159. batch_compose: whether keep the same transform order in batch data when shuffle.
  160. shuffle_indices: indices used for random shuffle, start at 1.
  161. order: the same with :class:`VisionTransform`
  162. .. seealso:: Refer to :mod:`~.data.transform` module for vision transform APIs.
  163. Examples:
  164. >>> import megengine.data.transform as T
  165. >>> T.Compose([ # doctest: +SKIP
  166. ... T.RandomHorizontalFlip(), # 1st
  167. ... T.RandomVerticalFlip(), # 2nd
  168. ... T.CenterCrop(100), # 3rd
  169. ... T.ToMode("CHW"), # 4th
  170. ... ],
  171. ... shuffle_indices=[(1, 2, 3)]
  172. ... )
  173. In this case, ``shuffle_indices`` is given so each input data will be transformed
  174. out of order:
  175. .. math::
  176. \begin{array}{cc}
  177. [{\color{red}1 \quad 2 \quad 3} \quad 4] & [{\color{red}1 \quad 3 \quad 2} \quad 4] \\
  178. [{\color{red}2 \quad 1 \quad 3} \quad 4] & [{\color{red}2 \quad 3 \quad 1} \quad 4] \\
  179. [{\color{red}3 \quad 1 \quad 2} \quad 4] & [{\color{red}3 \quad 2 \quad 1} \quad 4]
  180. \end{array}
  181. In another case, if ``[(1, 3), (2, 4)]`` is given, then the 1st and 3rd transfomation
  182. will be random shuffled, the 2nd and 4th transfomation will also be shuffled:
  183. .. math::
  184. \begin{array}{cc}
  185. [{\color{red}1} \quad {\color{blue}2} \quad {\color{red}3} \quad {\color{blue}4}] &
  186. [{\color{red}1} \quad {\color{blue}4} \quad {\color{red}3} \quad {\color{blue}2}] \\
  187. [{\color{red}3} \quad {\color{blue}2} \quad {\color{red}1} \quad {\color{blue}4}] &
  188. [{\color{red}3} \quad {\color{blue}4} \quad {\color{red}1} \quad {\color{blue}2}]
  189. \end{array}
  190. Different colors represent different groups that need to be internally shuffled.
  191. .. warning::
  192. Different samples within each batch will also use random transfomation orders,
  193. unless ``batch_compose`` is set to ``True``.
  194. """
  195. def __init__(
  196. self,
  197. transforms: List[VisionTransform] = [],
  198. batch_compose: bool = False,
  199. shuffle_indices: List[Tuple] = None,
  200. *,
  201. order=None
  202. ):
  203. super().__init__(order)
  204. self.transforms = transforms
  205. self._set_order()
  206. if batch_compose and shuffle_indices is not None:
  207. raise ValueError(
  208. "Do not support shuffle when apply transforms along the whole batch"
  209. )
  210. self.batch_compose = batch_compose
  211. if shuffle_indices is not None:
  212. shuffle_indices = [tuple(x - 1 for x in idx) for idx in shuffle_indices]
  213. self.shuffle_indices = shuffle_indices
  214. def _set_order(self):
  215. for t in self.transforms:
  216. t.order = self.order
  217. if isinstance(t, Compose):
  218. t._set_order()
  219. def apply_batch(self, inputs: Sequence[Tuple]):
  220. if self.batch_compose:
  221. for t in self.transforms:
  222. inputs = t.apply_batch(inputs)
  223. return inputs
  224. else:
  225. return super().apply_batch(inputs)
  226. def apply(self, input: Tuple):
  227. for t in self._shuffle():
  228. input = t.apply(input)
  229. return input
  230. def _shuffle(self):
  231. if self.shuffle_indices is not None:
  232. source_idx = list(range(len(self.transforms)))
  233. for idx in self.shuffle_indices:
  234. shuffled = np.random.permutation(idx).tolist()
  235. for src, dst in zip(idx, shuffled):
  236. source_idx[src] = dst
  237. return [self.transforms[i] for i in source_idx]
  238. else:
  239. return self.transforms
  240. class TorchTransformCompose(VisionTransform):
  241. r"""Compose class used for transforms in torchvision, only support PIL image,
  242. some transforms with tensor in torchvision are not supported,
  243. such as Normalize and ToTensor in torchvision.
  244. Args:
  245. transforms: the same with ``Compose``.
  246. order: the same with :class:`VisionTransform`.
  247. """
  248. def __init__(self, transforms, *, order=None):
  249. super().__init__(order)
  250. self.transforms = transforms
  251. def _apply_image(self, image):
  252. from PIL import Image
  253. try:
  254. import accimage
  255. except ImportError:
  256. accimage = None
  257. if image.shape[0] == 3: # CHW
  258. image = np.ascontiguousarray(image[[2, 1, 0]])
  259. elif image.shape[2] == 3: # HWC
  260. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  261. image = Image.fromarray(image.astype(np.uint8))
  262. for t in self.transforms:
  263. image = t(image)
  264. if isinstance(image, Image.Image) or (
  265. accimage is not None and isinstance(image, accimage.Image)
  266. ):
  267. image = np.array(image, dtype=np.uint8)
  268. if image.shape[0] == 3: # CHW
  269. image = np.ascontiguousarray(image[[2, 1, 0]])
  270. elif image.shape[2] == 3: # HWC
  271. image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  272. return image
  273. class Pad(VisionTransform):
  274. r"""Pad the input data.
  275. Args:
  276. size: padding size of input image, it could be integer or sequence.
  277. If it is an integer, the input image will be padded in four directions.
  278. If it is a sequence containing two integers, the bottom and right side
  279. of image will be padded.
  280. If it is a sequence containing four integers, the top, bottom, left, right
  281. side of image will be padded with given size.
  282. value: padding value of image, could be a sequence of int or float.
  283. if it is float value, the dtype of image will be casted to float32 also.
  284. mask_value: padding value of segmentation map.
  285. order: the same with :class:`VisionTransform`.
  286. """
  287. def __init__(self, size=0, value=0, mask_value=0, *, order=None):
  288. super().__init__(order)
  289. if isinstance(size, int):
  290. size = (size, size, size, size)
  291. elif isinstance(size, collections.abc.Sequence) and len(size) == 2:
  292. size = (0, size[0], 0, size[1])
  293. elif not (isinstance(size, collections.abc.Sequence) and len(size) == 4):
  294. raise ValueError(
  295. "size should be a list/tuple which contains "
  296. "(top, down, left, right) four pad sizes."
  297. )
  298. self.size = size
  299. self.value = value
  300. if not isinstance(mask_value, int):
  301. raise ValueError(
  302. "mask_value should be a positive integer, "
  303. "but got mask_value={}".format(mask_value)
  304. )
  305. self.mask_value = mask_value
  306. def _apply_image(self, image):
  307. return F.pad(image, self.size, self.value)
  308. def _apply_coords(self, coords):
  309. coords[:, 0] += self.size[2]
  310. coords[:, 1] += self.size[0]
  311. return coords
  312. def _apply_mask(self, mask):
  313. return F.pad(mask, self.size, self.mask_value)
  314. class Resize(VisionTransform):
  315. r"""Resize the input data.
  316. Args:
  317. output_size: target size of image, with (height, width) shape.
  318. interpolation: interpolation method. All methods are listed below:
  319. * cv2.INTER_NEAREST – a nearest-neighbor interpolation.
  320. * cv2.INTER_LINEAR – a bilinear interpolation (used by default).
  321. * cv2.INTER_AREA – resampling using pixel area relation.
  322. * cv2.INTER_CUBIC – a bicubic interpolation over 4×4 pixel neighborhood.
  323. * cv2.INTER_LANCZOS4 – a Lanczos interpolation over 8×8 pixel neighborhood.
  324. order: the same with :class:`VisionTransform`.
  325. """
  326. def __init__(self, output_size, interpolation=cv2.INTER_LINEAR, *, order=None):
  327. super().__init__(order)
  328. self.output_size = output_size
  329. self.interpolation = interpolation
  330. def apply(self, input: Tuple):
  331. self._shape_info = self._get_shape(self._get_image(input))
  332. return super().apply(input)
  333. def _apply_image(self, image):
  334. h, w, th, tw = self._shape_info
  335. if h == th and w == tw:
  336. return image
  337. return F.resize(image, (th, tw), self.interpolation)
  338. def _apply_coords(self, coords):
  339. h, w, th, tw = self._shape_info
  340. if h == th and w == tw:
  341. return coords
  342. coords[:, 0] = coords[:, 0] * (tw / w)
  343. coords[:, 1] = coords[:, 1] * (th / h)
  344. return coords
  345. def _apply_mask(self, mask):
  346. h, w, th, tw = self._shape_info
  347. if h == th and w == tw:
  348. return mask
  349. return F.resize(mask, (th, tw), cv2.INTER_NEAREST)
  350. def _get_shape(self, image):
  351. h, w, _ = image.shape
  352. if isinstance(self.output_size, int):
  353. if min(h, w) == self.output_size:
  354. return h, w, h, w
  355. if h < w:
  356. th = self.output_size
  357. tw = int(self.output_size * w / h)
  358. else:
  359. tw = self.output_size
  360. th = int(self.output_size * h / w)
  361. return h, w, th, tw
  362. else:
  363. return (h, w, *self.output_size)
  364. class ShortestEdgeResize(VisionTransform):
  365. r"""Resize the input data with specified shortset edge."""
  366. def __init__(
  367. self,
  368. min_size,
  369. max_size,
  370. sample_style="range",
  371. interpolation=cv2.INTER_LINEAR,
  372. *,
  373. order=None
  374. ):
  375. super().__init__(order)
  376. if sample_style not in ("range", "choice"):
  377. raise NotImplementedError(
  378. "{} is unsupported sample style".format(sample_style)
  379. )
  380. self.sample_style = sample_style
  381. if isinstance(min_size, int):
  382. min_size = (min_size, min_size)
  383. self.min_size = min_size
  384. self.max_size = max_size
  385. self.interpolation = interpolation
  386. def apply(self, input: Tuple):
  387. self._shape_info = self._get_shape(self._get_image(input))
  388. return super().apply(input)
  389. def _apply_image(self, image):
  390. h, w, th, tw = self._shape_info
  391. if h == th and w == tw:
  392. return image
  393. return F.resize(image, (th, tw), self.interpolation)
  394. def _apply_coords(self, coords):
  395. h, w, th, tw = self._shape_info
  396. if h == th and w == tw:
  397. return coords
  398. coords[:, 0] = coords[:, 0] * (tw / w)
  399. coords[:, 1] = coords[:, 1] * (th / h)
  400. return coords
  401. def _apply_mask(self, mask):
  402. h, w, th, tw = self._shape_info
  403. if h == th and w == tw:
  404. return mask
  405. return F.resize(mask, (th, tw), cv2.INTER_NEAREST)
  406. def _get_shape(self, image):
  407. h, w, _ = image.shape
  408. if self.sample_style == "range":
  409. size = np.random.randint(self.min_size[0], self.min_size[1] + 1)
  410. else:
  411. size = np.random.choice(self.min_size)
  412. scale = size / min(h, w)
  413. if h < w:
  414. th, tw = size, scale * w
  415. else:
  416. th, tw = scale * h, size
  417. if max(th, tw) > self.max_size:
  418. scale = self.max_size / max(th, tw)
  419. th = th * scale
  420. tw = tw * scale
  421. th = int(round(th))
  422. tw = int(round(tw))
  423. return h, w, th, tw
  424. class RandomResize(VisionTransform):
  425. r"""Resize the input data randomly.
  426. Args:
  427. scale_range: range of scaling.
  428. order: the same with :class:`VisionTransform`.
  429. """
  430. def __init__(self, scale_range, interpolation=cv2.INTER_LINEAR, *, order=None):
  431. super().__init__(order)
  432. self.scale_range = scale_range
  433. self.interpolation = interpolation
  434. def apply(self, input: Tuple):
  435. self._shape_info = self._get_shape(self._get_image(input))
  436. return super().apply(input)
  437. def _apply_image(self, image):
  438. h, w, th, tw = self._shape_info
  439. if h == th and w == tw:
  440. return image
  441. return F.resize(image, (th, tw), self.interpolation)
  442. def _apply_coords(self, coords):
  443. h, w, th, tw = self._shape_info
  444. if h == th and w == tw:
  445. return coords
  446. coords[:, 0] = coords[:, 0] * (tw / w)
  447. coords[:, 1] = coords[:, 1] * (th / h)
  448. return coords
  449. def _apply_mask(self, mask):
  450. h, w, th, tw = self._shape_info
  451. if h == th and w == tw:
  452. return mask
  453. return F.resize(mask, (th, tw), cv2.INTER_NEAREST)
  454. def _get_shape(self, image):
  455. h, w, _ = image.shape
  456. scale = np.random.uniform(*self.scale_range)
  457. th = int(round(h * scale))
  458. tw = int(round(w * scale))
  459. return h, w, th, tw
  460. class RandomCrop(VisionTransform):
  461. r"""Crop the input data randomly. Before applying the crop transform,
  462. pad the image first. If target size is still bigger than the size of
  463. padded image, pad the image size to target size.
  464. Args:
  465. output_size: target size of output image, with (height, width) shape.
  466. padding_size: the same with `size` in ``Pad``.
  467. padding_value: the same with `value` in ``Pad``.
  468. order: the same with :class:`VisionTransform`.
  469. """
  470. def __init__(
  471. self,
  472. output_size,
  473. padding_size=0,
  474. padding_value=[0, 0, 0],
  475. padding_maskvalue=0,
  476. *,
  477. order=None
  478. ):
  479. super().__init__(order)
  480. if isinstance(output_size, int):
  481. self.output_size = (output_size, output_size)
  482. else:
  483. self.output_size = output_size
  484. self.pad = Pad(padding_size, padding_value, order=self.order)
  485. self.padding_value = padding_value
  486. self.padding_maskvalue = padding_maskvalue
  487. def apply(self, input):
  488. input = self.pad.apply(input)
  489. self._h, self._w, _ = self._get_image(input).shape
  490. self._th, self._tw = self.output_size
  491. self._x = np.random.randint(0, max(0, self._w - self._tw) + 1)
  492. self._y = np.random.randint(0, max(0, self._h - self._th) + 1)
  493. return super().apply(input)
  494. def _apply_image(self, image):
  495. if self._th > self._h:
  496. image = F.pad(image, (self._th - self._h, 0), self.padding_value)
  497. if self._tw > self._w:
  498. image = F.pad(image, (0, self._tw - self._w), self.padding_value)
  499. return image[self._y : self._y + self._th, self._x : self._x + self._tw]
  500. def _apply_coords(self, coords):
  501. coords[:, 0] -= self._x
  502. coords[:, 1] -= self._y
  503. return coords
  504. def _apply_mask(self, mask):
  505. if self._th > self._h:
  506. mask = F.pad(mask, (self._th - self._h, 0), self.padding_maskvalue)
  507. if self._tw > self._w:
  508. mask = F.pad(mask, (0, self._tw - self._w), self.padding_maskvalue)
  509. return mask[self._y : self._y + self._th, self._x : self._x + self._tw]
  510. class RandomResizedCrop(VisionTransform):
  511. r"""Crop the input data to random size and aspect ratio.
  512. A crop of random size (default: of 0.08 to 1.0) of the original size and a random
  513. aspect ratio (default: of 3/4 to 1.33) of the original aspect ratio is made.
  514. After applying crop transfrom, the input data will be resized to given size.
  515. Args:
  516. output_size: target size of output image, with (height, width) shape.
  517. scale_range: range of size of the origin size cropped. Default: (0.08, 1.0)
  518. ratio_range: range of aspect ratio of the origin aspect ratio cropped. Default: (0.75, 1.33)
  519. order: the same with :class:`VisionTransform`.
  520. """
  521. def __init__(
  522. self,
  523. output_size,
  524. scale_range=(0.08, 1.0),
  525. ratio_range=(3.0 / 4, 4.0 / 3),
  526. interpolation=cv2.INTER_LINEAR,
  527. *,
  528. order=None
  529. ):
  530. super().__init__(order)
  531. if isinstance(output_size, int):
  532. self.output_size = (output_size, output_size)
  533. else:
  534. self.output_size = output_size
  535. assert (
  536. scale_range[0] <= scale_range[1]
  537. ), "scale_range should be of kind (min, max)"
  538. assert (
  539. ratio_range[0] <= ratio_range[1]
  540. ), "ratio_range should be of kind (min, max)"
  541. self.scale_range = scale_range
  542. self.ratio_range = ratio_range
  543. self.interpolation = interpolation
  544. def apply(self, input: Tuple):
  545. self._coord_info = self._get_coord(self._get_image(input))
  546. return super().apply(input)
  547. def _apply_image(self, image):
  548. x, y, w, h = self._coord_info
  549. cropped_img = image[y : y + h, x : x + w]
  550. return F.resize(cropped_img, self.output_size, self.interpolation)
  551. def _apply_coords(self, coords):
  552. x, y, w, h = self._coord_info
  553. coords[:, 0] = (coords[:, 0] - x) * self.output_size[1] / w
  554. coords[:, 1] = (coords[:, 1] - y) * self.output_size[0] / h
  555. return coords
  556. def _apply_mask(self, mask):
  557. x, y, w, h = self._coord_info
  558. cropped_mask = mask[y : y + h, x : x + w]
  559. return F.resize(cropped_mask, self.output_size, cv2.INTER_NEAREST)
  560. def _get_coord(self, image, attempts=10):
  561. height, width, _ = image.shape
  562. area = height * width
  563. for _ in range(attempts):
  564. target_area = np.random.uniform(*self.scale_range) * area
  565. log_ratio = tuple(math.log(x) for x in self.ratio_range)
  566. aspect_ratio = math.exp(np.random.uniform(*log_ratio))
  567. w = int(round(math.sqrt(target_area * aspect_ratio)))
  568. h = int(round(math.sqrt(target_area / aspect_ratio)))
  569. if 0 < w <= width and 0 < h <= height:
  570. x = np.random.randint(0, width - w + 1)
  571. y = np.random.randint(0, height - h + 1)
  572. return x, y, w, h
  573. # Fallback to central crop
  574. in_ratio = float(width) / float(height)
  575. if in_ratio < min(self.ratio_range):
  576. w = width
  577. h = int(round(w / min(self.ratio_range)))
  578. elif in_ratio > max(self.ratio_range):
  579. h = height
  580. w = int(round(h * max(self.ratio_range)))
  581. else: # whole image
  582. w = width
  583. h = height
  584. x = (width - w) // 2
  585. y = (height - h) // 2
  586. return x, y, w, h
  587. class CenterCrop(VisionTransform):
  588. r"""Crops the given the input data at the center.
  589. Args:
  590. output_size: target size of output image, with (height, width) shape.
  591. order: the same with :class:`VisionTransform`.
  592. """
  593. def __init__(self, output_size, *, order=None):
  594. super().__init__(order)
  595. if isinstance(output_size, int):
  596. self.output_size = (output_size, output_size)
  597. else:
  598. self.output_size = output_size
  599. def apply(self, input: Tuple):
  600. self._coord_info = self._get_coord(self._get_image(input))
  601. return super().apply(input)
  602. def _apply_image(self, image):
  603. x, y = self._coord_info
  604. th, tw = self.output_size
  605. return image[y : y + th, x : x + tw]
  606. def _apply_coords(self, coords):
  607. x, y = self._coord_info
  608. coords[:, 0] -= x
  609. coords[:, 1] -= y
  610. return coords
  611. def _apply_mask(self, mask):
  612. x, y = self._coord_info
  613. th, tw = self.output_size
  614. return mask[y : y + th, x : x + tw]
  615. def _get_coord(self, image):
  616. th, tw = self.output_size
  617. h, w, _ = image.shape
  618. assert th <= h and tw <= w, "output size is bigger than image size"
  619. x = int(round((w - tw) / 2.0))
  620. y = int(round((h - th) / 2.0))
  621. return x, y
  622. class RandomHorizontalFlip(VisionTransform):
  623. r"""Horizontally flip the input data randomly with a given probability.
  624. Args:
  625. p: probability of the input data being flipped. Default: 0.5
  626. order: the same with :class:`VisionTransform`.
  627. """
  628. def __init__(self, prob: float = 0.5, *, order=None):
  629. super().__init__(order)
  630. self.prob = prob
  631. def apply(self, input: Tuple):
  632. self._flipped = np.random.random() < self.prob
  633. self._w = self._get_image(input).shape[1]
  634. return super().apply(input)
  635. def _apply_image(self, image):
  636. if self._flipped:
  637. return F.flip(image, flipCode=1)
  638. return image
  639. def _apply_coords(self, coords):
  640. if self._flipped:
  641. coords[:, 0] = self._w - coords[:, 0]
  642. return coords
  643. def _apply_mask(self, mask):
  644. if self._flipped:
  645. return F.flip(mask, flipCode=1)
  646. return mask
  647. class RandomVerticalFlip(VisionTransform):
  648. r"""Vertically flip the input data randomly with a given probability.
  649. Args:
  650. p: probability of the input data being flipped. Default: 0.5
  651. order: the same with :class:`VisionTransform`.
  652. """
  653. def __init__(self, prob: float = 0.5, *, order=None):
  654. super().__init__(order)
  655. self.prob = prob
  656. def apply(self, input: Tuple):
  657. self._flipped = np.random.random() < self.prob
  658. self._h = self._get_image(input).shape[0]
  659. return super().apply(input)
  660. def _apply_image(self, image):
  661. if self._flipped:
  662. return F.flip(image, flipCode=0)
  663. return image
  664. def _apply_coords(self, coords):
  665. if self._flipped:
  666. coords[:, 1] = self._h - coords[:, 1]
  667. return coords
  668. def _apply_mask(self, mask):
  669. if self._flipped:
  670. return F.flip(mask, flipCode=0)
  671. return mask
  672. class Normalize(VisionTransform):
  673. r"""Normalize the input data with mean and standard deviation.
  674. Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
  675. this transform will normalize each channel of the input data.
  676. ``output[channel] = (input[channel] - mean[channel]) / std[channel]``
  677. Args:
  678. mean: sequence of means for each channel.
  679. std: sequence of standard deviations for each channel.
  680. order: the same with :class:`VisionTransform`.
  681. """
  682. def __init__(self, mean=0.0, std=1.0, *, order=None):
  683. super().__init__(order)
  684. self.mean = np.array(mean, dtype=np.float32)
  685. self.std = np.array(std, dtype=np.float32)
  686. def _apply_image(self, image):
  687. return (image - self.mean) / self.std
  688. def _apply_coords(self, coords):
  689. return coords
  690. def _apply_mask(self, mask):
  691. return mask
  692. class GaussianNoise(VisionTransform):
  693. r"""Add random gaussian noise to the input data.
  694. Gaussian noise is generated with given mean and std.
  695. Args:
  696. mean: Gaussian mean used to generate noise.
  697. std: Gaussian standard deviation used to generate noise.
  698. order: the same with :class:`VisionTransform`
  699. """
  700. def __init__(self, mean=0.0, std=1.0, *, order=None):
  701. super().__init__(order)
  702. self.mean = np.array(mean, dtype=np.float32)
  703. self.std = np.array(std, dtype=np.float32)
  704. def _apply_image(self, image):
  705. dtype = image.dtype
  706. noise = np.random.normal(self.mean, self.std, image.shape) * 255
  707. image = image + noise.astype(np.float32)
  708. return np.clip(image, 0, 255).astype(dtype)
  709. def _apply_coords(self, coords):
  710. return coords
  711. def _apply_mask(self, mask):
  712. return mask
  713. class BrightnessTransform(VisionTransform):
  714. r"""Adjust brightness of the input data.
  715. Args:
  716. value: how much to adjust the brightness. Can be any
  717. non negative number. 0 gives the original image.
  718. order: the same with :class:`VisionTransform`.
  719. """
  720. def __init__(self, value, *, order=None):
  721. super().__init__(order)
  722. if value < 0:
  723. raise ValueError("brightness value should be non-negative")
  724. self.value = value
  725. def _apply_image(self, image):
  726. if self.value == 0:
  727. return image
  728. dtype = image.dtype
  729. image = image.astype(np.float32)
  730. alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
  731. image = image * alpha
  732. return image.clip(0, 255).astype(dtype)
  733. def _apply_coords(self, coords):
  734. return coords
  735. def _apply_mask(self, mask):
  736. return mask
  737. class ContrastTransform(VisionTransform):
  738. r"""Adjust contrast of the input data.
  739. Args:
  740. value: how much to adjust the contrast. Can be any
  741. non negative number. 0 gives the original image.
  742. order: the same with :class:`VisionTransform`.
  743. """
  744. def __init__(self, value, *, order=None):
  745. super().__init__(order)
  746. if value < 0:
  747. raise ValueError("contrast value should be non-negative")
  748. self.value = value
  749. def _apply_image(self, image):
  750. if self.value == 0:
  751. return image
  752. dtype = image.dtype
  753. image = image.astype(np.float32)
  754. alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
  755. image = image * alpha + F.to_gray(image).mean() * (1 - alpha)
  756. return image.clip(0, 255).astype(dtype)
  757. def _apply_coords(self, coords):
  758. return coords
  759. def _apply_mask(self, mask):
  760. return mask
  761. class SaturationTransform(VisionTransform):
  762. r"""Adjust saturation of the input data.
  763. Args:
  764. value: how much to adjust the saturation. Can be any
  765. non negative number. 0 gives the original image.
  766. order: the same with :class:`VisionTransform`.
  767. """
  768. def __init__(self, value, *, order=None):
  769. super().__init__(order)
  770. if value < 0:
  771. raise ValueError("saturation value should be non-negative")
  772. self.value = value
  773. def _apply_image(self, image):
  774. if self.value == 0:
  775. return image
  776. dtype = image.dtype
  777. image = image.astype(np.float32)
  778. alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
  779. image = image * alpha + F.to_gray(image) * (1 - alpha)
  780. return image.clip(0, 255).astype(dtype)
  781. def _apply_coords(self, coords):
  782. return coords
  783. def _apply_mask(self, mask):
  784. return mask
  785. class HueTransform(VisionTransform):
  786. r"""Adjust hue of the input data.
  787. Args:
  788. value: how much to adjust the hue. Can be any number
  789. between 0 and 0.5, 0 gives the original image.
  790. order: the same with :class:`VisionTransform`.
  791. """
  792. def __init__(self, value, *, order=None):
  793. super().__init__(order)
  794. if value < 0 or value > 0.5:
  795. raise ValueError("hue value should be in [0.0, 0.5]")
  796. self.value = value
  797. def _apply_image(self, image):
  798. if self.value == 0:
  799. return image
  800. dtype = image.dtype
  801. image = image.astype(np.uint8)
  802. hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV_FULL)
  803. h, s, v = cv2.split(hsv_image)
  804. alpha = np.random.uniform(-self.value, self.value)
  805. h = h.astype(np.uint8)
  806. # uint8 addition take cares of rotation across boundaries
  807. with np.errstate(over="ignore"):
  808. h += np.uint8(alpha * 255)
  809. hsv_image = cv2.merge([h, s, v])
  810. return cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR_FULL).astype(dtype)
  811. def _apply_coords(self, coords):
  812. return coords
  813. def _apply_mask(self, mask):
  814. return mask
  815. class ColorJitter(VisionTransform):
  816. r"""Randomly change the brightness, contrast, saturation and hue of an image.
  817. Args:
  818. brightness: how much to jitter brightness.
  819. Chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
  820. or the given [min, max]. Should be non negative numbers.
  821. contrast: how much to jitter contrast.
  822. Chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
  823. or the given [min, max]. Should be non negative numbers.
  824. saturation: how much to jitter saturation.
  825. Chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
  826. or the given [min, max]. Should be non negative numbers.
  827. hue: how much to jitter hue.
  828. Chosen uniformly from [-hue, hue] or the given [min, max].
  829. Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
  830. order: the same with :class:`VisionTransform`.
  831. """
  832. def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, *, order=None):
  833. super().__init__(order)
  834. transforms = []
  835. if brightness != 0:
  836. transforms.append(BrightnessTransform(brightness))
  837. if contrast != 0:
  838. transforms.append(ContrastTransform(contrast))
  839. if saturation != 0:
  840. transforms.append(SaturationTransform(saturation))
  841. if hue != 0:
  842. transforms.append(HueTransform(hue))
  843. self.transforms = Compose(
  844. transforms,
  845. shuffle_indices=[tuple(range(1, len(transforms) + 1))],
  846. order=order,
  847. )
  848. def apply(self, input):
  849. return self.transforms.apply(input)
  850. class Lighting(VisionTransform):
  851. r"""Apply AlexNet-Style "lighting" augmentation to input data.
  852. Input images are assumed to have 'RGB' channel order.
  853. The degree of color jittering is randomly sampled via a normal distribution,
  854. with standard deviation given by the scale parameter.
  855. """
  856. def __init__(self, scale, *, order=None):
  857. super().__init__(order)
  858. if scale < 0:
  859. raise ValueError("lighting scale should be non-negative")
  860. self.scale = scale
  861. self.eigvec = np.array(
  862. [
  863. [-0.5836, -0.6948, 0.4203],
  864. [-0.5808, -0.0045, -0.8140],
  865. [-0.5675, 0.7192, 0.4009],
  866. ]
  867. ) # reverse the first dimension for BGR
  868. self.eigval = np.array([0.2175, 0.0188, 0.0045])
  869. def _apply_image(self, image):
  870. if self.scale == 0:
  871. return image
  872. dtype = image.dtype
  873. image = image.astype(np.float32)
  874. alpha = np.random.normal(scale=self.scale * 255, size=3)
  875. image = image + self.eigvec.dot(alpha * self.eigval)
  876. return image.clip(0, 255).astype(dtype)
  877. def _apply_coords(self, coords):
  878. return coords
  879. def _apply_mask(self, mask):
  880. return mask