
functional.py 30 kB

# A synchronized version modified from https://github.com/pytorch/vision
from __future__ import division
import torch
import sys
import math
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import collections
import warnings

if sys.version_info < (3, 3):
    Sequence = collections.Sequence
    Iterable = collections.Iterable
else:
    Sequence = collections.abc.Sequence
    Iterable = collections.abc.Iterable

def _is_pil_image(img):
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)

def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3

def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})

def _is_numpy(img):
    return isinstance(img, np.ndarray)

def to_tensor(pic, normalize=True, dtype=None):
    img = _to_tensor(pic, normalize=normalize)
    if dtype is not None:
        img = img.to(dtype=dtype)
    return img

def _to_tensor(pic, normalize):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    See ``ToTensor`` for more details.
    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
    Returns:
        Tensor: Converted image.
    """
    if not (_is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            pic = pic[:, :, None]
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        if normalize and isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img
    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)
    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32))
    elif pic.mode == '1':
        img = 255 * torch.from_numpy(np.array(pic, np.uint8))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1)).contiguous()
    if normalize and isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img
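
# Illustrative usage sketch (comment only, not part of the original file):
# converting an 8-bit RGB PIL image yields a float CHW tensor in [0, 1] when
# normalize=True, and the raw uint8 values when normalize=False.
# >>> img = Image.new('RGB', (4, 3), color=(255, 0, 0))
# >>> to_tensor(img).shape                     # torch.Size([3, 3, 4])
# >>> to_tensor(img, normalize=False).dtype    # torch.uint8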

def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image.
    See :class:`~torchvision.transforms.ToPILImage` for more details.
    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
    Returns:
        PIL Image: Image converted to PIL Image.
    """
    if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
    elif isinstance(pic, torch.Tensor):
        if pic.ndimension() not in {2, 3}:
            raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension()))
        elif pic.ndimension() == 2:
            # if 2D image, add channel dimension (CHW)
            pic = pic.unsqueeze(0)
    elif isinstance(pic, np.ndarray):
        if pic.ndim not in {2, 3}:
            raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
        elif pic.ndim == 2:
            # if 2D image, add channel dimension (HWC)
            pic = np.expand_dims(pic, 2)
    npimg = pic
    if isinstance(pic, torch.FloatTensor):
        pic = pic.mul(255).byte()
    if isinstance(pic, torch.Tensor):
        npimg = np.transpose(pic.numpy(), (1, 2, 0))
    if not isinstance(npimg, np.ndarray):
        raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' +
                        'not {}'.format(type(npimg)))
    if npimg.shape[2] == 1:
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = 'L'
        elif npimg.dtype == np.int16:
            expected_mode = 'I;16'
        elif npimg.dtype == np.int32:
            expected_mode = 'I'
        elif npimg.dtype == np.float32:
            expected_mode = 'F'
        if mode is not None and mode != expected_mode:
            raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
                             .format(mode, npimg.dtype, expected_mode))
        mode = expected_mode
    elif npimg.shape[2] == 2:
        permitted_2_channel_modes = ['LA']
        if mode is not None and mode not in permitted_2_channel_modes:
            raise ValueError("Only modes {} are supported for 2 channel inputs".format(permitted_2_channel_modes))
        if mode is None and npimg.dtype == np.uint8:
            mode = 'LA'
    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError("Only modes {} are supported for 4 channel inputs".format(permitted_4_channel_modes))
        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError("Only modes {} are supported for 3 channel inputs".format(permitted_3_channel_modes))
        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGB'
    if mode is None:
        raise TypeError('Input type {} is not supported'.format(npimg.dtype))
    return Image.fromarray(npimg, mode=mode)
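
# Roundtrip sketch (illustrative): float CHW tensors in [0, 1] map back to
# 8-bit PIL images, with the mode inferred from channel count and dtype.
# >>> to_pil_image(torch.rand(3, 8, 8)).mode   # 'RGB'
# >>> to_pil_image(torch.rand(1, 8, 8)).mode   # 'L'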

def normalize(tensor, mean, std, inplace=False):
    """Normalize a tensor image with mean and standard deviation.
    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.
    See :class:`~torchvision.transforms.Normalize` for more details.
    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    Returns:
        Tensor: Normalized Tensor image.
    """
    if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')
    if not inplace:
        tensor = tensor.clone()
    mean = torch.as_tensor(mean, dtype=torch.float32, device=tensor.device)
    std = torch.as_tensor(std, dtype=torch.float32, device=tensor.device)
    tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
    return tensor
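
# Usage sketch: per-channel standardization. The mean/std values here are the
# commonly used ImageNet statistics, shown purely as an example.
# >>> t = torch.rand(3, 224, 224)
# >>> out = normalize(t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# >>> out is t                                 # False: out of place by default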

def resize(img, size, interpolation=Image.BILINEAR):
    r"""Resize the input PIL Image to the given size.
    Args:
        img (PIL Image): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio, i.e., if height > width, then the image will be rescaled to
            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    Returns:
        PIL Image: Resized image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))
    if isinstance(size, int):
        w, h = img.size
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            return img.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * w / h)
            return img.resize((ow, oh), interpolation)
    else:
        # size is (h, w); PIL expects (w, h)
        return img.resize(size[::-1], interpolation)
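
# Behaviour sketch: an int size scales the smaller edge and preserves aspect
# ratio; a (h, w) sequence resizes to exactly that shape.
# >>> img = Image.new('RGB', (400, 200))       # PIL size is (width, height)
# >>> resize(img, 100).size                    # (200, 100): short edge -> 100
# >>> resize(img, (50, 80)).size               # (80, 50)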

def scale(*args, **kwargs):
    warnings.warn("The use of the transforms.Scale transform is deprecated, " +
                  "please use transforms.Resize instead.")
    return resize(*args, **kwargs)

def pad(img, padding, fill=0, padding_mode='constant'):
    r"""Pad the given PIL Image on all sides with specified padding mode and fill value.
    Args:
        img (PIL Image): Image to be padded.
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
        padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
            - constant: pads with a constant value, this value is specified with fill
            - edge: pads with the last value on the edge of the image
            - reflect: pads with reflection of image (without repeating the last value on the edge)
              padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]
            - symmetric: pads with reflection of image (repeating the last value on the edge)
              padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    Returns:
        PIL Image: Padded image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')
    if not isinstance(padding_mode, str):
        raise TypeError('Got inappropriate padding_mode arg')
    if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))
    assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \
        'Padding mode should be either constant, edge, reflect or symmetric'
    if padding_mode == 'constant':
        if img.mode == 'P':
            palette = img.getpalette()
            image = ImageOps.expand(img, border=padding, fill=fill)
            image.putpalette(palette)
            return image
        return ImageOps.expand(img, border=padding, fill=fill)
    else:
        if isinstance(padding, int):
            pad_left = pad_right = pad_top = pad_bottom = padding
        if isinstance(padding, Sequence) and len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        if isinstance(padding, Sequence) and len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]
        if img.mode == 'P':
            palette = img.getpalette()
            img = np.asarray(img)
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
            img = Image.fromarray(img)
            img.putpalette(palette)
            return img
        img = np.asarray(img)
        # RGB image
        if len(img.shape) == 3:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
        # Grayscale image
        if len(img.shape) == 2:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
        return Image.fromarray(img)
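
# Padding-mode sketch: np.pad uses the same mode names, so the 1-D example
# from the docstring can be checked directly.
# >>> np.pad([1, 2, 3, 4], 2, 'reflect').tolist()     # [3, 2, 1, 2, 3, 4, 3, 2]
# >>> np.pad([1, 2, 3, 4], 2, 'symmetric').tolist()   # [2, 1, 1, 2, 3, 4, 4, 3]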

def crop(img, i, j, h, w):
    """Crop the given PIL Image.
    Args:
        img (PIL Image): Image to be cropped.
        i (int): i in (i, j), i.e. the row coordinate of the upper left corner.
        j (int): j in (i, j), i.e. the column coordinate of the upper left corner.
        h (int): Height of the cropped image.
        w (int): Width of the cropped image.
    Returns:
        PIL Image: Cropped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return img.crop((j, i, j + w, i + h))

def center_crop(img, output_size):
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    w, h = img.size
    th, tw = output_size
    i = int(round((h - th) / 2.))
    j = int(round((w - tw) / 2.))
    return crop(img, i, j, th, tw)
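
# Coordinate sketch: (i, j) is the top-left corner in (row, column) order,
# while PIL's own crop box and size are in (x, y) / (width, height) order.
# >>> img = Image.new('L', (10, 8))
# >>> crop(img, 2, 3, 4, 5).size               # (5, 4): width 5, height 4
# >>> center_crop(img, 4).size                 # (4, 4)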

def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
    """Crop the given PIL Image and resize it to desired size.
    Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.
    Args:
        img (PIL Image): Image to be cropped.
        i (int): i in (i, j), i.e. the row coordinate of the upper left corner.
        j (int): j in (i, j), i.e. the column coordinate of the upper left corner.
        h (int): Height of the cropped image.
        w (int): Width of the cropped image.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.
    Returns:
        PIL Image: Cropped image.
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    img = crop(img, i, j, h, w)
    img = resize(img, size, interpolation)
    return img

def hflip(img):
    """Horizontally flip the given PIL Image.
    Args:
        img (PIL Image): Image to be flipped.
    Returns:
        PIL Image: Horizontally flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return img.transpose(Image.FLIP_LEFT_RIGHT)

def _get_perspective_coeffs(startpoints, endpoints):
    """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
    In a perspective transform each pixel (x, y) in the original image gets transformed as,
        (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
    Args:
        startpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the original image.
        endpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
    Returns:
        octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
    """
    matrix = []
    for p1, p2 in zip(endpoints, startpoints):
        matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
        matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
    A = torch.tensor(matrix, dtype=torch.float)
    B = torch.tensor(startpoints, dtype=torch.float).view(8)
    # solve the 8x8 least-squares system for the coefficients
    # (torch.gels was later renamed torch.lstsq in newer PyTorch releases)
    res = torch.gels(B, A)[0]
    return res.squeeze_(1).tolist()
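
# Sanity-check sketch: mapping a square onto itself should recover the
# identity transform, i.e. coefficients close to (1, 0, 0, 0, 1, 0, 0, 0).
# >>> pts = [(0, 0), (10, 0), (10, 10), (0, 10)]
# >>> _get_perspective_coeffs(pts, pts)        # ~[1, 0, 0, 0, 1, 0, 0, 0]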

def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC):
    """Perform perspective transform of the given PIL Image.
    Args:
        img (PIL Image): Image to be transformed.
        startpoints: List of [top-left, top-right, bottom-right, bottom-left] of the original image.
        endpoints: List of [top-left, top-right, bottom-right, bottom-left] of the transformed image.
        interpolation: Default is ``Image.BICUBIC``.
    Returns:
        PIL Image: Perspectively transformed Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    coeffs = _get_perspective_coeffs(startpoints, endpoints)
    return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation)

def vflip(img):
    """Vertically flip the given PIL Image.
    Args:
        img (PIL Image): Image to be flipped.
    Returns:
        PIL Image: Vertically flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return img.transpose(Image.FLIP_TOP_BOTTOM)

def five_crop(img, size):
    """Crop the given PIL Image into four corners and the central crop.
    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.
    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    Returns:
        tuple: tuple (tl, tr, bl, br, center)
            Corresponding top left, top right, bottom left, bottom right and center crop.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
    w, h = img.size
    crop_h, crop_w = size
    if crop_w > w or crop_h > h:
        raise ValueError("Requested crop size {} is bigger than input size {}".format(size, (h, w)))
    tl = img.crop((0, 0, crop_w, crop_h))
    tr = img.crop((w - crop_w, 0, w, crop_h))
    bl = img.crop((0, h - crop_h, crop_w, h))
    br = img.crop((w - crop_w, h - crop_h, w, h))
    center = center_crop(img, (crop_h, crop_w))
    return (tl, tr, bl, br, center)

def ten_crop(img, size, vertical_flip=False):
    r"""Crop the given PIL Image into four corners and the central crop plus the
    flipped version of these (horizontal flipping is used by default).
    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.
    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        vertical_flip (bool): Use vertical flipping instead of horizontal.
    Returns:
        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
            Corresponding top left, top right, bottom left, bottom right and center crop
            and same for the flipped image.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
    first_five = five_crop(img, size)
    if vertical_flip:
        img = vflip(img)
    else:
        img = hflip(img)
    second_five = five_crop(img, size)
    return first_five + second_five
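
# Usage sketch: five_crop returns the four corner crops plus the center crop;
# ten_crop adds the same five crops of the flipped image.
# >>> img = Image.new('RGB', (100, 100))
# >>> len(five_crop(img, 32))                  # 5
# >>> len(ten_crop(img, 32))                   # 10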

def adjust_brightness(img, brightness_factor):
    """Adjust brightness of an Image.
    Args:
        img (PIL Image): PIL Image to be adjusted.
        brightness_factor (float): How much to adjust the brightness. Can be
            any non-negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.
    Returns:
        PIL Image: Brightness adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    enhancer = ImageEnhance.Brightness(img)
    img = enhancer.enhance(brightness_factor)
    return img

def adjust_contrast(img, contrast_factor):
    """Adjust contrast of an Image.
    Args:
        img (PIL Image): PIL Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non-negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.
    Returns:
        PIL Image: Contrast adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    enhancer = ImageEnhance.Contrast(img)
    img = enhancer.enhance(contrast_factor)
    return img

def adjust_saturation(img, saturation_factor):
    """Adjust color saturation of an image.
    Args:
        img (PIL Image): PIL Image to be adjusted.
        saturation_factor (float): How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.
    Returns:
        PIL Image: Saturation adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    enhancer = ImageEnhance.Color(img)
    img = enhancer.enhance(saturation_factor)
    return img

def adjust_hue(img, hue_factor):
    """Adjust hue of an image.
    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.
    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.
    See `Hue`_ for more details.
    .. _Hue: https://en.wikipedia.org/wiki/Hue
    Args:
        img (PIL Image): PIL Image to be adjusted.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.
    Returns:
        PIL Image: Hue adjusted image.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    input_mode = img.mode
    if input_mode in {'L', '1', 'I', 'F'}:
        return img
    h, s, v = img.convert('HSV').split()
    np_h = np.array(h, dtype=np.uint8)
    # uint8 addition takes care of rotation across boundaries
    with np.errstate(over='ignore'):
        np_h += np.uint8(hue_factor * 255)
    h = Image.fromarray(np_h, 'L')
    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img
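
# Wraparound sketch: the hue shift relies on uint8 overflow, so values past
# 255 wrap around to 0, which is exactly the cyclic behaviour hue requires.
# >>> np_h = np.array([250], dtype=np.uint8)
# >>> with np.errstate(over='ignore'):
# ...     np_h += np.uint8(10)
# >>> np_h                                     # array([4], dtype=uint8)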

def adjust_gamma(img, gamma, gain=1):
    r"""Perform gamma correction on an image.
    Also known as Power Law Transform. Intensities in RGB mode are adjusted
    based on the following equation:
    .. math::
        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}
    See `Gamma Correction`_ for more details.
    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
    Args:
        img (PIL Image): PIL Image to be adjusted.
        gamma (float): Non-negative real number, same as :math:`\gamma` in the equation.
            gamma larger than 1 makes the shadows darker,
            while gamma smaller than 1 makes dark regions lighter.
        gain (float): The constant multiplier.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')
    input_mode = img.mode
    img = img.convert('RGB')
    gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
    img = img.point(gamma_map)  # use PIL's point function to accelerate this part
    img = img.convert(input_mode)
    return img
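
# Lookup-table sketch: point() applies a 256-entry map per channel, so the
# correction costs one table build plus a single pass over the pixels.
# >>> table = [255 * 1 * pow(x / 255., 2.2) for x in range(256)]
# >>> int(table[128])                          # 55: mid-grey darkens for gamma > 1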

def rotate(img, angle, resample=False, expand=False, center=None):
    """Rotate the image by angle.
    Args:
        img (PIL Image): PIL Image to be rotated.
        angle (float or int): Rotation angle in degrees, counter-clockwise.
        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
            An optional resampling filter. See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output image to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (2-tuple, optional): Optional center of rotation.
            Origin is the upper left corner.
            Default is the center of the image.
    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return img.rotate(angle, resample, expand, center)

def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
    # Helper method to compute the inverse matrix for an affine transformation,
    # as explained in PIL.Image.rotate.
    # We need to compute the INVERSE of the affine transformation matrix: M = T * C * RSS * C^-1
    # where T is the translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    # C is the translation matrix that keeps the center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    # RSS is the rotation matrix with scale and shear:
    # RSS(a, scale, shear) = [ cos(a)*scale   -sin(a + shear)*scale   0 ]
    #                        [ sin(a)*scale    cos(a + shear)*scale   0 ]
    #                        [ 0               0                      1 ]
    # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
    angle = math.radians(angle)
    shear = math.radians(shear)
    scale = 1.0 / scale
    # Inverted rotation matrix with scale and shear
    d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle)
    matrix = [
        math.cos(angle + shear), math.sin(angle + shear), 0,
        -math.sin(angle), math.cos(angle), 0
    ]
    matrix = [scale / d * m for m in matrix]
    # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
    matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
    matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])
    # Apply center translation: C * RSS^-1 * C^-1 * T^-1
    matrix[2] += center[0]
    matrix[5] += center[1]
    return matrix
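
# Sanity-check sketch: with no rotation, translation, rescaling or shear the
# inverse matrix should be the identity in PIL's 6-tuple (a, b, c, d, e, f) form.
# >>> _get_inverse_affine_matrix((0, 0), 0.0, (0, 0), 1.0, 0.0)
# ~[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]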

def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None):
    """Apply affine transformation on the image keeping image center invariant.
    Args:
        img (PIL Image): PIL Image to be rotated.
        angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction.
        translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation)
        scale (float): overall scale
        shear (float): shear angle value in degrees between -180 to 180, clockwise direction.
        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
            An optional resampling filter.
            See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
        fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
        "Argument translate should be a list or tuple of length 2"
    assert scale > 0.0, "Argument scale should be positive"
    output_size = img.size
    center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)
    matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
    # compare the major version numerically; a plain string comparison would
    # misorder Pillow >= 10
    kwargs = {"fillcolor": fillcolor} if int(PILLOW_VERSION.split('.')[0]) >= 5 else {}
    return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs)

def to_grayscale(img, num_output_channels=1):
    """Convert image to grayscale version of image.
    Args:
        img (PIL Image): Image to be converted to grayscale.
        num_output_channels (int): Number of channels of the output image (1 or 3).
    Returns:
        PIL Image: Grayscale version of the image.
            if num_output_channels = 1 : returned image is single channel
            if num_output_channels = 3 : returned image is 3 channel with r = g = b
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if num_output_channels == 1:
        img = img.convert('L')
    elif num_output_channels == 3:
        img = img.convert('L')
        np_img = np.array(img, dtype=np.uint8)
        np_img = np.dstack([np_img, np_img, np_img])
        img = Image.fromarray(np_img, 'RGB')
    else:
        raise ValueError('num_output_channels should be either 1 or 3')
    return img

def flip_channels(img):
    """Reverse the channel order of a PIL Image (e.g. RGB <-> BGR)."""
    img = np.array(img)[:, :, ::-1]
    return Image.fromarray(img.astype(np.uint8))
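
# Channel-order sketch: reversing the last axis converts between RGB and BGR
# byte order (handy when mixing PIL data with OpenCV-style arrays).
# >>> img = Image.new('RGB', (2, 2), color=(255, 0, 0))
# >>> flip_channels(img).getpixel((0, 0))      # (0, 0, 255)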
