
transform.py 36 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections.abc
import math
from typing import Sequence, Tuple

import cv2
import numpy as np

from megengine.data.transform import Transform
from megengine.data.transform.vision import functional as F

__all__ = [
    "VisionTransform",
    "ToMode",
    "Compose",
    "TorchTransformCompose",
    "Pad",
    "Resize",
    "ShortestEdgeResize",
    "RandomResize",
    "RandomCrop",
    "RandomResizedCrop",
    "CenterCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "Normalize",
    "GaussianNoise",
    "BrightnessTransform",
    "SaturationTransform",
    "ContrastTransform",
    "HueTransform",
    "ColorJitter",
    "Lighting",
]


class VisionTransform(Transform):
    r"""
    Base class of all transforms used in computer vision.
    Calling logic: apply_batch() -> apply() -> _apply_image() and other _apply_*()
    methods. If you want to implement a self-defined transform method for images,
    rewrite the _apply_image method in a subclass.

    :param order: input type order. Input is a tuple containing different structures,
        order is used to specify the order of structures. For example, if your input
        is (image, boxes) type, then the ``order`` should be ("image", "boxes").
        Currently available strings and data types are described below:

        * "image": input image, with shape of `(H, W, C)`.
        * "coords": coordinates, with shape of `(N, 2)`.
        * "boxes": bounding boxes, with shape of `(N, 4)`, "xyxy" format,
          the 1st "xy" represents the top left point of a box,
          the 2nd "xy" represents the right bottom point.
        * "mask": map used for segmentation, with shape of `(H, W, 1)`.
        * "keypoints": keypoints with shape of `(N, K, 3)`, N for number of instances,
          and K for number of keypoints in one instance. The first two dimensions
          of the last axis are the coordinates of keypoints and the 3rd dimension is
          the label of keypoints.
        * "polygons": a sequence containing numpy arrays; its length is the number of instances.
          Each numpy array represents the polygon coordinates of one instance.
        * "category": categories for some data type. For example, "image_category"
          means the category of the input image and "boxes_category" means the categories
          of bounding boxes.
        * "info": information for images such as image shapes and image paths.

        You can also customize your data types; you only need to implement the
        corresponding _apply_*() methods, otherwise ``NotImplementedError`` will be raised.
    """

    def __init__(self, order=None):
        super().__init__()
        if order is None:
            order = ("image",)
        elif not isinstance(order, collections.abc.Sequence):
            raise ValueError(
                "order should be a sequence, but got order={}".format(order)
            )
        for k in order:
            if k in ("batch",):
                raise ValueError("{} is invalid data type".format(k))
            elif k.endswith("category") or k.endswith("info"):
                # when the key is *category or info, we should do nothing
                # if the corresponding apply methods are not implemented.
                continue
            elif self._get_apply(k) is None:
                raise NotImplementedError("{} is unsupported data type".format(k))
        self.order = order

    def apply_batch(self, inputs: Sequence[Tuple]):
        r"""Apply transform on batch input data."""
        return tuple(self.apply(input) for input in inputs)

    def apply(self, input: Tuple):
        r"""Apply transform on single input data."""
        if not isinstance(input, tuple):
            input = (input,)
        output = []
        for i in range(min(len(input), len(self.order))):
            apply_func = self._get_apply(self.order[i])
            if apply_func is None:
                output.append(input[i])
            else:
                output.append(apply_func(input[i]))
        if len(input) > len(self.order):
            output.extend(input[len(self.order) :])
        if len(output) == 1:
            output = output[0]
        else:
            output = tuple(output)
        return output

    def _get_apply(self, key):
        return getattr(self, "_apply_{}".format(key), None)

    def _get_image(self, input: Tuple):
        if not isinstance(input, tuple):
            input = (input,)
        return input[self.order.index("image")]

    def _apply_image(self, image):
        raise NotImplementedError

    def _apply_coords(self, coords):
        raise NotImplementedError

    def _apply_boxes(self, boxes):
        idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
        coords = np.asarray(boxes).reshape(-1, 4)[:, idxs].reshape(-1, 2)
        coords = self._apply_coords(coords).reshape((-1, 4, 2))
        minxy = coords.min(axis=1)
        maxxy = coords.max(axis=1)
        trans_boxes = np.concatenate((minxy, maxxy), axis=1)
        return trans_boxes

    def _apply_mask(self, mask):
        raise NotImplementedError

    def _apply_keypoints(self, keypoints):
        coords, visibility = keypoints[..., :2], keypoints[..., 2:]
        trans_coords = [self._apply_coords(p) for p in coords]
        return np.concatenate((trans_coords, visibility), axis=-1)

    def _apply_polygons(self, polygons):
        return [[self._apply_coords(p) for p in instance] for instance in polygons]
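
# Illustrative sketch, not part of the original file: the usual way to extend
# VisionTransform is to subclass it and override the _apply_*() hooks you need.
# "BGR2RGB" is a hypothetical name chosen for this example; it is not exported
# in __all__. A pure color-space change never moves geometry, so coords and
# mask pass through unchanged.
class BGR2RGB(VisionTransform):
    r"""Example transform: swap the channel order of an HWC image."""

    def _apply_image(self, image):
        return np.ascontiguousarray(image[:, :, ::-1])

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask
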
class ToMode(VisionTransform):
    r"""
    Change input data to a target mode.
    For example, most transforms use HWC mode images,
    while the neural network might use CHW mode input tensors.

    :param mode: output mode of input. Default: "CHW"
    :param order: the same with :class:`VisionTransform`
    """

    def __init__(self, mode="CHW", *, order=None):
        super().__init__(order)
        assert mode in ["CHW"], "unsupported mode: {}".format(mode)
        self.mode = mode

    def _apply_image(self, image):
        if self.mode == "CHW":
            return np.ascontiguousarray(np.rollaxis(image, 2))
        return image

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        if self.mode == "CHW":
            return np.ascontiguousarray(np.rollaxis(mask, 2))
        return mask


class Compose(VisionTransform):
    r"""
    Composes several transforms together.

    :param transforms: list of :class:`VisionTransform` to compose.
    :param batch_compose: whether to apply transforms along the whole batch.
        If True, transforms are applied in the original order and shuffling
        is not supported. Otherwise, shuffle_indices can be used to shuffle
        the transforms.
    :param shuffle_indices: indices used for random shuffle, starting at 1.
        For example, if shuffle_indices is [(1, 3), (2, 4)], then the 1st and 3rd
        transforms will be randomly shuffled, and the 2nd and 4th transforms will
        also be shuffled.
    :param order: the same with :class:`VisionTransform`

    Examples:

    .. testcode::

        from megengine.data.transform import RandomHorizontalFlip, RandomVerticalFlip, CenterCrop, ToMode, Compose

        transform_func = Compose([
            RandomHorizontalFlip(),
            RandomVerticalFlip(),
            CenterCrop(100),
            ToMode("CHW"),
            ],
            shuffle_indices=[(1, 2, 3)]
        )
    """

    def __init__(
        self, transforms=[], batch_compose=False, shuffle_indices=None, *, order=None
    ):
        super().__init__(order)
        self.transforms = transforms
        self._set_order()
        if batch_compose and shuffle_indices is not None:
            raise ValueError(
                "Do not support shuffle when apply transforms along the whole batch"
            )
        self.batch_compose = batch_compose
        if shuffle_indices is not None:
            shuffle_indices = [tuple(x - 1 for x in idx) for idx in shuffle_indices]
        self.shuffle_indices = shuffle_indices

    def _set_order(self):
        for t in self.transforms:
            t.order = self.order
            if isinstance(t, Compose):
                t._set_order()

    def apply_batch(self, inputs: Sequence[Tuple]):
        if self.batch_compose:
            for t in self.transforms:
                inputs = t.apply_batch(inputs)
            return inputs
        else:
            return super().apply_batch(inputs)

    def apply(self, input: Tuple):
        for t in self._shuffle():
            input = t.apply(input)
        return input

    def _shuffle(self):
        if self.shuffle_indices is not None:
            source_idx = list(range(len(self.transforms)))
            for idx in self.shuffle_indices:
                shuffled = np.random.permutation(idx).tolist()
                for src, dst in zip(idx, shuffled):
                    source_idx[src] = dst
            return [self.transforms[i] for i in source_idx]
        else:
            return self.transforms


class TorchTransformCompose(VisionTransform):
    r"""
    Compose class used for transforms in torchvision. Only PIL images are
    supported; torchvision transforms that operate on tensors,
    such as Normalize and ToTensor, are not supported.

    :param transforms: the same with ``Compose``.
    :param order: the same with :class:`VisionTransform`.
    """

    def __init__(self, transforms, *, order=None):
        super().__init__(order)
        self.transforms = transforms

    def _apply_image(self, image):
        from PIL import Image

        try:
            import accimage
        except ImportError:
            accimage = None

        if image.shape[0] == 3:  # CHW
            image = np.ascontiguousarray(image[[2, 1, 0]])
        elif image.shape[2] == 3:  # HWC
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image.astype(np.uint8))

        for t in self.transforms:
            image = t(image)

        if isinstance(image, Image.Image) or (
            accimage is not None and isinstance(image, accimage.Image)
        ):
            image = np.array(image, dtype=np.uint8)
        if image.shape[0] == 3:  # CHW
            image = np.ascontiguousarray(image[[2, 1, 0]])
        elif image.shape[2] == 3:  # HWC
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        return image


class Pad(VisionTransform):
    r"""
    Pad the input data.

    :param size: padding size of the input image, either an integer or a sequence.
        If it is an integer, the input image will be padded by that amount in all
        four directions.
        If it is a sequence containing two integers, the bottom and right sides
        of the image will be padded.
        If it is a sequence containing four integers, the top, bottom, left and
        right sides of the image will be padded with the given sizes.
    :param value: padding value of the image, could be a sequence of int or float.
        If it is a float value, the dtype of the image will also be cast to float32.
    :param mask_value: padding value of the segmentation map.
    :param order: the same with :class:`VisionTransform`.
    """

    def __init__(self, size=0, value=0, mask_value=0, *, order=None):
        super().__init__(order)
        if isinstance(size, int):
            size = (size, size, size, size)
        elif isinstance(size, collections.abc.Sequence) and len(size) == 2:
            size = (0, size[0], 0, size[1])
        elif not (isinstance(size, collections.abc.Sequence) and len(size) == 4):
            raise ValueError(
                "size should be a list/tuple which contains "
                "(top, down, left, right) four pad sizes."
            )
        self.size = size
        self.value = value
        if not isinstance(mask_value, int):
            raise ValueError(
                "mask_value should be a positive integer, "
                "but got mask_value={}".format(mask_value)
            )
        self.mask_value = mask_value

    def _apply_image(self, image):
        return F.pad(image, self.size, self.value)

    def _apply_coords(self, coords):
        coords[:, 0] += self.size[2]
        coords[:, 1] += self.size[0]
        return coords

    def _apply_mask(self, mask):
        return F.pad(mask, self.size, self.mask_value)


class Resize(VisionTransform):
    r"""
    Resize the input data.

    :param output_size: target size of the image, with (height, width) shape.
    :param interpolation: interpolation method. All methods are listed below:

        * cv2.INTER_NEAREST – a nearest-neighbor interpolation.
        * cv2.INTER_LINEAR – a bilinear interpolation (used by default).
        * cv2.INTER_AREA – resampling using pixel area relation.
        * cv2.INTER_CUBIC – a bicubic interpolation over a 4×4 pixel neighborhood.
        * cv2.INTER_LANCZOS4 – a Lanczos interpolation over an 8×8 pixel neighborhood.
    :param order: the same with :class:`VisionTransform`.
    """

    def __init__(self, output_size, interpolation=cv2.INTER_LINEAR, *, order=None):
        super().__init__(order)
        self.output_size = output_size
        self.interpolation = interpolation

    def apply(self, input: Tuple):
        self._shape_info = self._get_shape(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return image
        return F.resize(image, (th, tw), self.interpolation)

    def _apply_coords(self, coords):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return coords
        coords[:, 0] = coords[:, 0] * (tw / w)
        coords[:, 1] = coords[:, 1] * (th / h)
        return coords

    def _apply_mask(self, mask):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return mask
        return F.resize(mask, (th, tw), cv2.INTER_NEAREST)

    def _get_shape(self, image):
        h, w, _ = image.shape
        if isinstance(self.output_size, int):
            if min(h, w) == self.output_size:
                return h, w, h, w
            if h < w:
                th = self.output_size
                tw = int(self.output_size * w / h)
            else:
                tw = self.output_size
                th = int(self.output_size * h / w)
            return h, w, th, tw
        else:
            return (h, w, *self.output_size)
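
# Illustrative sketch, not part of the original file: with an integer
# output_size, Resize scales the shorter edge to that value and keeps the
# aspect ratio; with a (height, width) tuple it resizes to that exact shape.
# "_demo_resize" is a hypothetical helper name; it is never called on import.
def _demo_resize():
    image = np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8)
    print(Resize(120).apply(image).shape)         # (120, 160, 3): shorter edge -> 120
    print(Resize((100, 200)).apply(image).shape)  # (100, 200, 3): exact target shape
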
class ShortestEdgeResize(VisionTransform):
    r"""
    Resize the input data by the specified shortest edge.
    """

    def __init__(
        self,
        min_size,
        max_size,
        sample_style="range",
        interpolation=cv2.INTER_LINEAR,
        *,
        order=None
    ):
        super().__init__(order)
        if sample_style not in ("range", "choice"):
            raise NotImplementedError(
                "{} is unsupported sample style".format(sample_style)
            )
        self.sample_style = sample_style
        if isinstance(min_size, int):
            min_size = (min_size, min_size)
        self.min_size = min_size
        self.max_size = max_size
        self.interpolation = interpolation

    def apply(self, input: Tuple):
        self._shape_info = self._get_shape(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return image
        return F.resize(image, (th, tw), self.interpolation)

    def _apply_coords(self, coords):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return coords
        coords[:, 0] = coords[:, 0] * (tw / w)
        coords[:, 1] = coords[:, 1] * (th / h)
        return coords

    def _apply_mask(self, mask):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return mask
        return F.resize(mask, (th, tw), cv2.INTER_NEAREST)

    def _get_shape(self, image):
        h, w, _ = image.shape
        if self.sample_style == "range":
            size = np.random.randint(self.min_size[0], self.min_size[1] + 1)
        else:
            size = np.random.choice(self.min_size)
        scale = size / min(h, w)
        if h < w:
            th, tw = size, scale * w
        else:
            th, tw = scale * h, size
        if max(th, tw) > self.max_size:
            scale = self.max_size / max(th, tw)
            th = th * scale
            tw = tw * scale
        th = int(round(th))
        tw = int(round(tw))
        return h, w, th, tw
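
# Illustrative sketch, not part of the original file: "range" samples the
# target short edge uniformly from [min_size[0], min_size[1]], while "choice"
# picks one of the values in min_size; max_size caps the longer edge and
# rescales both sides when exceeded. "_demo_shortest_edge_resize" is a
# hypothetical helper name.
def _demo_shortest_edge_resize():
    image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
    t = ShortestEdgeResize(min_size=(400, 500), max_size=600, sample_style="range")
    out = t.apply(image)
    print(out.shape)  # short edge drawn from [400, 500], unless clipped by max_size
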
class RandomResize(VisionTransform):
    r"""
    Resize the input data randomly.

    :param scale_range: range of scaling.
    :param order: the same with :class:`VisionTransform`.
    """

    def __init__(self, scale_range, interpolation=cv2.INTER_LINEAR, *, order=None):
        super().__init__(order)
        self.scale_range = scale_range
        self.interpolation = interpolation

    def apply(self, input: Tuple):
        self._shape_info = self._get_shape(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return image
        return F.resize(image, (th, tw), self.interpolation)

    def _apply_coords(self, coords):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return coords
        coords[:, 0] = coords[:, 0] * (tw / w)
        coords[:, 1] = coords[:, 1] * (th / h)
        return coords

    def _apply_mask(self, mask):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return mask
        return F.resize(mask, (th, tw), cv2.INTER_NEAREST)

    def _get_shape(self, image):
        h, w, _ = image.shape
        scale = np.random.uniform(*self.scale_range)
        th = int(round(h * scale))
        tw = int(round(w * scale))
        return h, w, th, tw


class RandomCrop(VisionTransform):
    r"""
    Crop the input data randomly. Before applying the crop transform,
    the image is padded first. If the target size is still bigger than the
    padded image, the image is padded up to the target size.

    :param output_size: target size of the output image, with (height, width) shape.
    :param padding_size: the same with `size` in ``Pad``.
    :param padding_value: the same with `value` in ``Pad``.
    :param order: the same with :class:`VisionTransform`.
    """

    def __init__(
        self,
        output_size,
        padding_size=0,
        padding_value=[0, 0, 0],
        padding_maskvalue=0,
        *,
        order=None
    ):
        super().__init__(order)
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size
        self.pad = Pad(padding_size, padding_value, order=self.order)
        self.padding_value = padding_value
        self.padding_maskvalue = padding_maskvalue

    def apply(self, input):
        input = self.pad.apply(input)
        self._h, self._w, _ = self._get_image(input).shape
        self._th, self._tw = self.output_size
        self._x = np.random.randint(0, max(0, self._w - self._tw) + 1)
        self._y = np.random.randint(0, max(0, self._h - self._th) + 1)
        return super().apply(input)

    def _apply_image(self, image):
        if self._th > self._h:
            image = F.pad(image, (self._th - self._h, 0), self.padding_value)
        if self._tw > self._w:
            image = F.pad(image, (0, self._tw - self._w), self.padding_value)
        return image[self._y : self._y + self._th, self._x : self._x + self._tw]

    def _apply_coords(self, coords):
        coords[:, 0] -= self._x
        coords[:, 1] -= self._y
        return coords

    def _apply_mask(self, mask):
        if self._th > self._h:
            mask = F.pad(mask, (self._th - self._h, 0), self.padding_maskvalue)
        if self._tw > self._w:
            mask = F.pad(mask, (0, self._tw - self._w), self.padding_maskvalue)
        return mask[self._y : self._y + self._th, self._x : self._x + self._tw]
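
# Illustrative sketch, not part of the original file: RandomCrop first applies
# Pad, then pads further if the target is still larger than the padded image,
# so the output always has exactly output_size. "_demo_random_crop" is a
# hypothetical helper name.
def _demo_random_crop():
    image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
    t = RandomCrop(output_size=120, padding_size=4)
    print(t.apply(image).shape)  # (120, 120, 3) even though the input is smaller
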
class RandomResizedCrop(VisionTransform):
    r"""
    Crop the input data to a random size and aspect ratio.
    A crop of random size (default: 0.08 to 1.0) of the original size and a random
    aspect ratio (default: 3/4 to 4/3) of the original aspect ratio is made.
    After applying the crop transform, the input data will be resized to the given size.

    :param output_size: target size of the output image, with (height, width) shape.
    :param scale_range: range of the cropped size relative to the original size.
        Default: (0.08, 1.0)
    :param ratio_range: range of the cropped aspect ratio relative to the original
        aspect ratio. Default: (0.75, 1.33)
    :param order: the same with :class:`VisionTransform`.
    """

    def __init__(
        self,
        output_size,
        scale_range=(0.08, 1.0),
        ratio_range=(3.0 / 4, 4.0 / 3),
        interpolation=cv2.INTER_LINEAR,
        *,
        order=None
    ):
        super().__init__(order)
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size
        assert (
            scale_range[0] <= scale_range[1]
        ), "scale_range should be of kind (min, max)"
        assert (
            ratio_range[0] <= ratio_range[1]
        ), "ratio_range should be of kind (min, max)"
        self.scale_range = scale_range
        self.ratio_range = ratio_range
        self.interpolation = interpolation

    def apply(self, input: Tuple):
        self._coord_info = self._get_coord(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        x, y, w, h = self._coord_info
        cropped_img = image[y : y + h, x : x + w]
        return F.resize(cropped_img, self.output_size, self.interpolation)

    def _apply_coords(self, coords):
        x, y, w, h = self._coord_info
        coords[:, 0] = (coords[:, 0] - x) * self.output_size[1] / w
        coords[:, 1] = (coords[:, 1] - y) * self.output_size[0] / h
        return coords

    def _apply_mask(self, mask):
        x, y, w, h = self._coord_info
        cropped_mask = mask[y : y + h, x : x + w]
        return F.resize(cropped_mask, self.output_size, cv2.INTER_NEAREST)

    def _get_coord(self, image, attempts=10):
        height, width, _ = image.shape
        area = height * width

        for _ in range(attempts):
            target_area = np.random.uniform(*self.scale_range) * area
            log_ratio = tuple(math.log(x) for x in self.ratio_range)
            aspect_ratio = math.exp(np.random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                x = np.random.randint(0, width - w + 1)
                y = np.random.randint(0, height - h + 1)
                return x, y, w, h

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(self.ratio_range):
            w = width
            h = int(round(w / min(self.ratio_range)))
        elif in_ratio > max(self.ratio_range):
            h = height
            w = int(round(h * max(self.ratio_range)))
        else:  # whole image
            w = width
            h = height
        x = (width - w) // 2
        y = (height - h) // 2
        return x, y, w, h


class CenterCrop(VisionTransform):
    r"""
    Crop the given input data at the center.

    :param output_size: target size of the output image, with (height, width) shape.
    :param order: the same with :class:`VisionTransform`.
    """

    def __init__(self, output_size, *, order=None):
        super().__init__(order)
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size

    def apply(self, input: Tuple):
        self._coord_info = self._get_coord(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        x, y = self._coord_info
        th, tw = self.output_size
        return image[y : y + th, x : x + tw]

    def _apply_coords(self, coords):
        x, y = self._coord_info
        coords[:, 0] -= x
        coords[:, 1] -= y
        return coords

    def _apply_mask(self, mask):
        x, y = self._coord_info
        th, tw = self.output_size
        return mask[y : y + th, x : x + tw]

    def _get_coord(self, image):
        th, tw = self.output_size
        h, w, _ = image.shape
        assert th <= h and tw <= w, "output size is bigger than image size"
        x = int(round((w - tw) / 2.0))
        y = int(round((h - th) / 2.0))
        return x, y


class RandomHorizontalFlip(VisionTransform):
    r"""
    Horizontally flip the input data randomly with a given probability.

    :param prob: probability of the input data being flipped. Default: 0.5
    :param order: the same with :class:`VisionTransform`.
    """

    def __init__(self, prob: float = 0.5, *, order=None):
        super().__init__(order)
        self.prob = prob

    def apply(self, input: Tuple):
        self._flipped = np.random.random() < self.prob
        self._w = self._get_image(input).shape[1]
        return super().apply(input)

    def _apply_image(self, image):
        if self._flipped:
            return F.flip(image, flipCode=1)
        return image

    def _apply_coords(self, coords):
        if self._flipped:
            coords[:, 0] = self._w - coords[:, 0]
        return coords

    def _apply_mask(self, mask):
        if self._flipped:
            return F.flip(mask, flipCode=1)
        return mask


class RandomVerticalFlip(VisionTransform):
    r"""
    Vertically flip the input data randomly with a given probability.

    :param prob: probability of the input data being flipped. Default: 0.5
    :param order: the same with :class:`VisionTransform`.
    """

    def __init__(self, prob: float = 0.5, *, order=None):
        super().__init__(order)
        self.prob = prob

    def apply(self, input: Tuple):
        self._flipped = np.random.random() < self.prob
        self._h = self._get_image(input).shape[0]
        return super().apply(input)

    def _apply_image(self, image):
        if self._flipped:
            return F.flip(image, flipCode=0)
        return image

    def _apply_coords(self, coords):
        if self._flipped:
            coords[:, 1] = self._h - coords[:, 1]
        return coords

    def _apply_mask(self, mask):
        if self._flipped:
            return F.flip(mask, flipCode=0)
        return mask


class Normalize(VisionTransform):
    r"""
    Normalize the input data with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
    this transform will normalize each channel of the input data.
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    :param mean: sequence of means for each channel.
    :param std: sequence of standard deviations for each channel.
    :param order: the same with :class:`VisionTransform`.
    """

    def __init__(self, mean=0.0, std=1.0, *, order=None):
        super().__init__(order)
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)

    def _apply_image(self, image):
        return (image - self.mean) / self.std

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask
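
# Illustrative sketch, not part of the original file: per-channel normalization
# with dataset statistics. The mean/std values below are the widely used
# ImageNet numbers on a 0-255 scale, chosen only as an example;
# "_demo_normalize" is a hypothetical helper name.
def _demo_normalize():
    image = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
    t = Normalize(mean=[103.53, 116.28, 123.675], std=[57.375, 57.12, 58.395])
    out = t.apply(image)
    print(out.dtype, out.mean())  # float output, roughly centered around 0
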
class GaussianNoise(VisionTransform):
    r"""
    Add random Gaussian noise to the input data.
    The Gaussian noise is generated with the given mean and std.

    :param mean: Gaussian mean used to generate noise.
    :param std: Gaussian standard deviation used to generate noise.
    :param order: the same with :class:`VisionTransform`
    """

    def __init__(self, mean=0.0, std=1.0, *, order=None):
        super().__init__(order)
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)

    def _apply_image(self, image):
        dtype = image.dtype
        noise = np.random.normal(self.mean, self.std, image.shape) * 255
        image = image + noise.astype(np.float32)
        return np.clip(image, 0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask


class BrightnessTransform(VisionTransform):
    r"""
    Adjust the brightness of the input data.

    :param value: how much to adjust the brightness. Can be any
        non-negative number. 0 gives the original image.
    :param order: the same with :class:`VisionTransform`.
    """

    def __init__(self, value, *, order=None):
        super().__init__(order)
        if value < 0:
            raise ValueError("brightness value should be non-negative")
        self.value = value

    def _apply_image(self, image):
        if self.value == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.float32)
        alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
        image = image * alpha
        return image.clip(0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask


class ContrastTransform(VisionTransform):
    r"""
    Adjust the contrast of the input data.

    :param value: how much to adjust the contrast. Can be any
        non-negative number. 0 gives the original image.
    :param order: the same with :class:`VisionTransform`.
    """

    def __init__(self, value, *, order=None):
        super().__init__(order)
        if value < 0:
            raise ValueError("contrast value should be non-negative")
        self.value = value

    def _apply_image(self, image):
        if self.value == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.float32)
        alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
        image = image * alpha + F.to_gray(image).mean() * (1 - alpha)
        return image.clip(0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask


class SaturationTransform(VisionTransform):
    r"""
    Adjust the saturation of the input data.

    :param value: how much to adjust the saturation. Can be any
        non-negative number. 0 gives the original image.
    :param order: the same with :class:`VisionTransform`.
    """

    def __init__(self, value, *, order=None):
        super().__init__(order)
        if value < 0:
            raise ValueError("saturation value should be non-negative")
        self.value = value

    def _apply_image(self, image):
        if self.value == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.float32)
        alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
        image = image * alpha + F.to_gray(image) * (1 - alpha)
        return image.clip(0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask


class HueTransform(VisionTransform):
    r"""
    Adjust the hue of the input data.

    :param value: how much to adjust the hue. Can be any number
        between 0 and 0.5; 0 gives the original image.
    :param order: the same with :class:`VisionTransform`.
    """

    def __init__(self, value, *, order=None):
        super().__init__(order)
        if value < 0 or value > 0.5:
            raise ValueError("hue value should be in [0.0, 0.5]")
        self.value = value

    def _apply_image(self, image):
        if self.value == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.uint8)
        hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV_FULL)
        h, s, v = cv2.split(hsv_image)

        alpha = np.random.uniform(-self.value, self.value)
        h = h.astype(np.uint8)
        # uint8 addition takes care of rotation across boundaries
        with np.errstate(over="ignore"):
            h += np.uint8(alpha * 255)
        hsv_image = cv2.merge([h, s, v])
        return cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR_FULL).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask


class ColorJitter(VisionTransform):
    r"""
    Randomly change the brightness, contrast, saturation and hue of an image.

    :param brightness: how much to jitter brightness.
        Chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
        or the given [min, max]. Should be non-negative numbers.
    :param contrast: how much to jitter contrast.
        Chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
        or the given [min, max]. Should be non-negative numbers.
    :param saturation: how much to jitter saturation.
        Chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
        or the given [min, max]. Should be non-negative numbers.
    :param hue: how much to jitter hue.
        Chosen uniformly from [-hue, hue] or the given [min, max].
        Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
    :param order: the same with :class:`VisionTransform`.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, *, order=None):
        super().__init__(order)
        transforms = []
        if brightness != 0:
            transforms.append(BrightnessTransform(brightness))
        if contrast != 0:
            transforms.append(ContrastTransform(contrast))
        if saturation != 0:
            transforms.append(SaturationTransform(saturation))
        if hue != 0:
            transforms.append(HueTransform(hue))
        self.transforms = Compose(
            transforms,
            shuffle_indices=[tuple(range(1, len(transforms) + 1))],
            order=order,
        )

    def apply(self, input):
        return self.transforms.apply(input)
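
# Illustrative sketch, not part of the original file: ColorJitter composes the
# four photometric transforms above and shuffles their order per sample.
# "_demo_color_jitter" is a hypothetical helper name.
def _demo_color_jitter():
    image = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
    t = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
    out = t.apply(image)
    print(out.shape, out.dtype)  # geometry and dtype are unchanged
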
class Lighting(VisionTransform):
    r"""
    Apply AlexNet-style "lighting" augmentation to the input data.
    Input images are assumed to have 'RGB' channel order.
    The degree of color jittering is randomly sampled via a normal distribution,
    with standard deviation given by the scale parameter.
    """

    def __init__(self, scale, *, order=None):
        super().__init__(order)
        if scale < 0:
            raise ValueError("lighting scale should be non-negative")
        self.scale = scale
        self.eigvec = np.array(
            [
                [-0.5836, -0.6948, 0.4203],
                [-0.5808, -0.0045, -0.8140],
                [-0.5675, 0.7192, 0.4009],
            ]
        )  # reverse the first dimension for BGR
        self.eigval = np.array([0.2175, 0.0188, 0.0045])

    def _apply_image(self, image):
        if self.scale == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.float32)
        alpha = np.random.normal(scale=self.scale * 255, size=3)
        image = image + self.eigvec.dot(alpha * self.eigval)
        return image.clip(0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask
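
# Illustrative end-to-end sketch, not part of the original file: a typical
# training-time pipeline that chains geometric and photometric transforms,
# normalizes with example (ImageNet-style) statistics, then converts to CHW
# for the network. "_demo_pipeline" is a hypothetical helper name; wrapping it
# in a function keeps module import side-effect free.
def _demo_pipeline():
    transform = Compose(
        [
            RandomResizedCrop(224),
            RandomHorizontalFlip(),
            ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            Normalize(mean=[103.53, 116.28, 123.675], std=[57.375, 57.12, 58.395]),
            ToMode("CHW"),
        ]
    )
    image = np.random.randint(0, 256, (256, 341, 3), dtype=np.uint8)
    out = transform.apply(image)
    print(out.shape)  # (3, 224, 224)
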

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and the driver is installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.