You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

transform.py 35 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections.abc
  10. import math
  11. from typing import Sequence, Tuple
  12. import cv2
  13. import numpy as np
  14. from megengine.data.transform import Transform
  15. from megengine.data.transform.vision import functional as F
  16. __all__ = [
  17. "VisionTransform",
  18. "ToMode",
  19. "Compose",
  20. "TorchTransformCompose",
  21. "Pad",
  22. "Resize",
  23. "ShortestEdgeResize",
  24. "RandomResize",
  25. "RandomCrop",
  26. "RandomResizedCrop",
  27. "CenterCrop",
  28. "RandomHorizontalFlip",
  29. "RandomVerticalFlip",
  30. "Normalize",
  31. "GaussianNoise",
  32. "BrightnessTransform",
  33. "SaturationTransform",
  34. "ContrastTransform",
  35. "HueTransform",
  36. "ColorJitter",
  37. "Lighting",
  38. ]
  39. class VisionTransform(Transform):
  40. r"""
  41. Base class of all transforms used in computer vision.
  42. Calling logic: apply_batch() -> apply() -> _apply_image() and other _apply_*()
  43. method. If you want to implement a self-defined transform method for image,
  44. rewrite _apply_image method in subclass.
  45. :param order: input type order. Input is a tuple containing different structures,
  46. order is used to specify the order of structures. For example, if your input
  47. is (image, boxes) type, then the ``order`` should be ("image", "boxes").
  48. Current available strings and data type are describe below:
  49. * "image": input image, with shape of `(H, W, C)`.
  50. * "coords": coordinates, with shape of `(N, 2)`.
  51. * "boxes": bounding boxes, with shape of `(N, 4)`, "xyxy" format,
  52. the 1st "xy" represents top left point of a box,
  53. the 2nd "xy" represents right bottom point.
  54. * "mask": map used for segmentation, with shape of `(H, W, 1)`.
  55. * "keypoints": keypoints with shape of `(N, K, 3)`, N for number of instances,
  56. and K for number of keypoints in one instance. The first two dimensions
  57. of last axis is coordinate of keypoints and the the 3rd dimension is
  58. the label of keypoints.
  59. * "polygons": a sequence containing numpy arrays, its length is the number of instances.
  60. Each numpy array represents polygon coordinate of one instance.
  61. * "category": categories for some data type. For example, "image_category"
  62. means category of the input image and "boxes_category" means categories of
  63. bounding boxes.
  64. * "info": information for images such as image shapes and image path.
  65. You can also customize your data types only if you implement the corresponding
  66. _apply_*() methods, otherwise ``NotImplementedError`` will be raised.
  67. """
  68. def __init__(self, order=None):
  69. super().__init__()
  70. if order is None:
  71. order = ("image",)
  72. elif not isinstance(order, collections.abc.Sequence):
  73. raise ValueError(
  74. "order should be a sequence, but got order={}".format(order)
  75. )
  76. for k in order:
  77. if k in ("batch",):
  78. raise ValueError("{} is invalid data type".format(k))
  79. elif k.endswith("category") or k.endswith("info"):
  80. # when the key is *category or info, we should do nothing
  81. # if the corresponding apply methods are not implemented.
  82. continue
  83. elif self._get_apply(k) is None:
  84. raise NotImplementedError("{} is unsupported data type".format(k))
  85. self.order = order
  86. def apply_batch(self, inputs: Sequence[Tuple]):
  87. r"""Apply transform on batch input data."""
  88. return tuple(self.apply(input) for input in inputs)
  89. def apply(self, input: Tuple):
  90. r"""Apply transform on single input data."""
  91. if not isinstance(input, tuple):
  92. input = (input,)
  93. output = []
  94. for i in range(min(len(input), len(self.order))):
  95. apply_func = self._get_apply(self.order[i])
  96. if apply_func is None:
  97. output.append(input[i])
  98. else:
  99. output.append(apply_func(input[i]))
  100. if len(input) > len(self.order):
  101. output.extend(input[len(self.order) :])
  102. if len(output) == 1:
  103. output = output[0]
  104. else:
  105. output = tuple(output)
  106. return output
  107. def _get_apply(self, key):
  108. return getattr(self, "_apply_{}".format(key), None)
  109. def _get_image(self, input: Tuple):
  110. if not isinstance(input, tuple):
  111. input = (input,)
  112. return input[self.order.index("image")]
  113. def _apply_image(self, image):
  114. raise NotImplementedError
  115. def _apply_coords(self, coords):
  116. raise NotImplementedError
  117. def _apply_boxes(self, boxes):
  118. idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
  119. coords = np.asarray(boxes).reshape(-1, 4)[:, idxs].reshape(-1, 2)
  120. coords = self._apply_coords(coords).reshape((-1, 4, 2))
  121. minxy = coords.min(axis=1)
  122. maxxy = coords.max(axis=1)
  123. trans_boxes = np.concatenate((minxy, maxxy), axis=1)
  124. return trans_boxes
  125. def _apply_mask(self, mask):
  126. raise NotImplementedError
  127. def _apply_keypoints(self, keypoints):
  128. coords, visibility = keypoints[..., :2], keypoints[..., 2:]
  129. trans_coords = [self._apply_coords(p) for p in coords]
  130. return np.concatenate((trans_coords, visibility), axis=-1)
  131. def _apply_polygons(self, polygons):
  132. return [[self._apply_coords(p) for p in instance] for instance in polygons]
  133. class ToMode(VisionTransform):
  134. r"""
  135. Change input data to a target mode.
  136. For example, most transforms use HWC mode image,
  137. while the neural network might use CHW mode input tensor.
  138. :param mode: output mode of input. Default: "CHW"
  139. :param order: the same with :class:`VisionTransform`
  140. """
  141. def __init__(self, mode="CHW", *, order=None):
  142. super().__init__(order)
  143. assert mode in ["CHW"], "unsupported mode: {}".format(mode)
  144. self.mode = mode
  145. def _apply_image(self, image):
  146. if self.mode == "CHW":
  147. return np.ascontiguousarray(np.rollaxis(image, 2))
  148. return image
  149. def _apply_coords(self, coords):
  150. return coords
  151. def _apply_mask(self, mask):
  152. if self.mode == "CHW":
  153. return np.ascontiguousarray(np.rollaxis(mask, 2))
  154. return mask
  155. class Compose(VisionTransform):
  156. r"""
  157. Composes several transforms together.
  158. :param transforms: list of :class:`VisionTransform` to compose.
  159. :param batch_compose: whether use shuffle_indices for batch data or not.
  160. If True, use original input sequence.
  161. Otherwise, the shuffle_indices will be used for transforms.
  162. :param shuffle_indices: indices used for random shuffle, start at 1.
  163. For example, if shuffle_indices is [(1, 3), (2, 4)], then the 1st and 3rd transform
  164. will be random shuffled, the 2nd and 4th transform will also be shuffled.
  165. :param order: the same with :class:`VisionTransform`
  166. Examples:
  167. .. testcode::
  168. from megengine.data.transform import RandomHorizontalFlip, RandomVerticalFlip, CenterCrop, ToMode, Compose
  169. transform_func = Compose([
  170. RandomHorizontalFlip(),
  171. RandomVerticalFlip(),
  172. CenterCrop(100),
  173. ToMode("CHW"),
  174. ],
  175. shuffle_indices=[(1, 2, 3)]
  176. )
  177. """
  178. def __init__(
  179. self, transforms=[], batch_compose=False, shuffle_indices=None, *, order=None
  180. ):
  181. super().__init__(order)
  182. self.transforms = transforms
  183. self._set_order()
  184. if batch_compose and shuffle_indices is not None:
  185. raise ValueError(
  186. "Do not support shuffle when apply transforms along the whole batch"
  187. )
  188. self.batch_compose = batch_compose
  189. if shuffle_indices is not None:
  190. shuffle_indices = [tuple(x - 1 for x in idx) for idx in shuffle_indices]
  191. self.shuffle_indices = shuffle_indices
  192. def _set_order(self):
  193. for t in self.transforms:
  194. t.order = self.order
  195. if isinstance(t, Compose):
  196. t._set_order()
  197. def apply_batch(self, inputs: Sequence[Tuple]):
  198. if self.batch_compose:
  199. for t in self.transforms:
  200. inputs = t.apply_batch(inputs)
  201. return inputs
  202. else:
  203. return super().apply_batch(inputs)
  204. def apply(self, input: Tuple):
  205. for t in self._shuffle():
  206. input = t.apply(input)
  207. return input
  208. def _shuffle(self):
  209. if self.shuffle_indices is not None:
  210. source_idx = list(range(len(self.transforms)))
  211. for idx in self.shuffle_indices:
  212. shuffled = np.random.permutation(idx).tolist()
  213. for src, dst in zip(idx, shuffled):
  214. source_idx[src] = dst
  215. return [self.transforms[i] for i in source_idx]
  216. else:
  217. return self.transforms
  218. class TorchTransformCompose(VisionTransform):
  219. r"""
  220. Compose class used for transforms in torchvision, only support PIL image,
  221. some transforms with tensor in torchvision are not supported,
  222. such as Normalize and ToTensor in torchvision.
  223. :param transforms: the same with ``Compose``.
  224. :param order: the same with :class:`VisionTransform`.
  225. """
  226. def __init__(self, transforms, *, order=None):
  227. super().__init__(order)
  228. self.transforms = transforms
  229. def _apply_image(self, image):
  230. from PIL import Image
  231. try:
  232. import accimage
  233. except ImportError:
  234. accimage = None
  235. if image.shape[0] == 3: # CHW
  236. image = np.ascontiguousarray(image[[2, 1, 0]])
  237. elif image.shape[2] == 3: # HWC
  238. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  239. image = Image.fromarray(image.astype(np.uint8))
  240. for t in self.transforms:
  241. image = t(image)
  242. if isinstance(image, Image.Image) or (
  243. accimage is not None and isinstance(image, accimage.Image)
  244. ):
  245. image = np.array(image, dtype=np.uint8)
  246. if image.shape[0] == 3: # CHW
  247. image = np.ascontiguousarray(image[[2, 1, 0]])
  248. elif image.shape[2] == 3: # HWC
  249. image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  250. return image
  251. class Pad(VisionTransform):
  252. r"""
  253. Pad the input data.
  254. :param size: padding size of input image, it could be integer or sequence.
  255. If it is an integer, the input image will be padded in four directions.
  256. If it is a sequence containing two integers, the bottom and right side
  257. of image will be padded.
  258. If it is a sequence containing four integers, the top, bottom, left, right
  259. side of image will be padded with given size.
  260. :param value: padding value of image, could be a sequence of int or float.
  261. if it is float value, the dtype of image will be casted to float32 also.
  262. :param mask_value: padding value of segmentation map.
  263. :param order: the same with :class:`VisionTransform`.
  264. """
  265. def __init__(self, size=0, value=0, mask_value=0, *, order=None):
  266. super().__init__(order)
  267. if isinstance(size, int):
  268. size = (size, size, size, size)
  269. elif isinstance(size, collections.abc.Sequence) and len(size) == 2:
  270. size = (0, size[0], 0, size[1])
  271. elif not (isinstance(size, collections.abc.Sequence) and len(size) == 4):
  272. raise ValueError(
  273. "size should be a list/tuple which contains "
  274. "(top, down, left, right) four pad sizes."
  275. )
  276. self.size = size
  277. self.value = value
  278. if not isinstance(mask_value, int):
  279. raise ValueError(
  280. "mask_value should be a positive integer, "
  281. "but got mask_value={}".format(mask_value)
  282. )
  283. self.mask_value = mask_value
  284. def _apply_image(self, image):
  285. return F.pad(image, self.size, self.value)
  286. def _apply_coords(self, coords):
  287. coords[:, 0] += self.size[2]
  288. coords[:, 1] += self.size[0]
  289. return coords
  290. def _apply_mask(self, mask):
  291. return F.pad(mask, self.size, self.mask_value)
  292. class Resize(VisionTransform):
  293. r"""
  294. Resize the input data.
  295. :param output_size: target size of image, with (height, width) shape.
  296. :param interpolation: interpolation method. All methods are listed below:
  297. * cv2.INTER_NEAREST – a nearest-neighbor interpolation.
  298. * cv2.INTER_LINEAR – a bilinear interpolation (used by default).
  299. * cv2.INTER_AREA – resampling using pixel area relation.
  300. * cv2.INTER_CUBIC – a bicubic interpolation over 4×4 pixel neighborhood.
  301. * cv2.INTER_LANCZOS4 – a Lanczos interpolation over 8×8 pixel neighborhood.
  302. :param order: the same with :class:`VisionTransform`.
  303. """
  304. def __init__(self, output_size, interpolation=cv2.INTER_LINEAR, *, order=None):
  305. super().__init__(order)
  306. self.output_size = output_size
  307. self.interpolation = interpolation
  308. def apply(self, input: Tuple):
  309. self._shape_info = self._get_shape(self._get_image(input))
  310. return super().apply(input)
  311. def _apply_image(self, image):
  312. h, w, th, tw = self._shape_info
  313. if h == th and w == tw:
  314. return image
  315. return F.resize(image, (th, tw), self.interpolation)
  316. def _apply_coords(self, coords):
  317. h, w, th, tw = self._shape_info
  318. if h == th and w == tw:
  319. return coords
  320. coords[:, 0] = coords[:, 0] * (tw / w)
  321. coords[:, 1] = coords[:, 1] * (th / h)
  322. return coords
  323. def _apply_mask(self, mask):
  324. h, w, th, tw = self._shape_info
  325. if h == th and w == tw:
  326. return mask
  327. return F.resize(mask, (th, tw), cv2.INTER_NEAREST)
  328. def _get_shape(self, image):
  329. h, w, _ = image.shape
  330. if isinstance(self.output_size, int):
  331. if min(h, w) == self.output_size:
  332. return h, w, h, w
  333. if h < w:
  334. th = self.output_size
  335. tw = int(self.output_size * w / h)
  336. else:
  337. tw = self.output_size
  338. th = int(self.output_size * h / w)
  339. return h, w, th, tw
  340. else:
  341. return (h, w, *self.output_size)
  342. class ShortestEdgeResize(VisionTransform):
  343. def __init__(
  344. self,
  345. min_size,
  346. max_size,
  347. sample_style="range",
  348. interpolation=cv2.INTER_LINEAR,
  349. *,
  350. order=None
  351. ):
  352. super().__init__(order)
  353. if sample_style not in ("range", "choice"):
  354. raise NotImplementedError(
  355. "{} is unsupported sample style".format(sample_style)
  356. )
  357. self.sample_style = sample_style
  358. if isinstance(min_size, int):
  359. min_size = (min_size, min_size)
  360. self.min_size = min_size
  361. self.max_size = max_size
  362. self.interpolation = interpolation
  363. def apply(self, input: Tuple):
  364. self._shape_info = self._get_shape(self._get_image(input))
  365. return super().apply(input)
  366. def _apply_image(self, image):
  367. h, w, th, tw = self._shape_info
  368. if h == th and w == tw:
  369. return image
  370. return F.resize(image, (th, tw), self.interpolation)
  371. def _apply_coords(self, coords):
  372. h, w, th, tw = self._shape_info
  373. if h == th and w == tw:
  374. return coords
  375. coords[:, 0] = coords[:, 0] * (tw / w)
  376. coords[:, 1] = coords[:, 1] * (th / h)
  377. return coords
  378. def _apply_mask(self, mask):
  379. h, w, th, tw = self._shape_info
  380. if h == th and w == tw:
  381. return mask
  382. return F.resize(mask, (th, tw), cv2.INTER_NEAREST)
  383. def _get_shape(self, image):
  384. h, w, _ = image.shape
  385. if self.sample_style == "range":
  386. size = np.random.randint(self.min_size[0], self.min_size[1] + 1)
  387. else:
  388. size = np.random.choice(self.min_size)
  389. scale = size / min(h, w)
  390. if h < w:
  391. th, tw = size, scale * w
  392. else:
  393. th, tw = scale * h, size
  394. if max(th, tw) > self.max_size:
  395. scale = self.max_size / max(th, tw)
  396. th = th * scale
  397. tw = tw * scale
  398. th = int(round(th))
  399. tw = int(round(tw))
  400. return h, w, th, tw
  401. class RandomResize(VisionTransform):
  402. r"""
  403. Resize the input data randomly.
  404. :param scale_range: range of scaling.
  405. :param order: the same with :class:`VisionTransform`.
  406. """
  407. def __init__(self, scale_range, interpolation=cv2.INTER_LINEAR, *, order=None):
  408. super().__init__(order)
  409. self.scale_range = scale_range
  410. self.interpolation = interpolation
  411. def apply(self, input: Tuple):
  412. self._shape_info = self._get_shape(self._get_image(input))
  413. return super().apply(input)
  414. def _apply_image(self, image):
  415. h, w, th, tw = self._shape_info
  416. if h == th and w == tw:
  417. return image
  418. return F.resize(image, (th, tw), self.interpolation)
  419. def _apply_coords(self, coords):
  420. h, w, th, tw = self._shape_info
  421. if h == th and w == tw:
  422. return coords
  423. coords[:, 0] = coords[:, 0] * (tw / w)
  424. coords[:, 1] = coords[:, 1] * (th / h)
  425. return coords
  426. def _apply_mask(self, mask):
  427. h, w, th, tw = self._shape_info
  428. if h == th and w == tw:
  429. return mask
  430. return F.resize(mask, (th, tw), cv2.INTER_NEAREST)
  431. def _get_shape(self, image):
  432. h, w, _ = image.shape
  433. scale = np.random.uniform(*self.scale_range)
  434. th = int(round(h * scale))
  435. tw = int(round(w * scale))
  436. return h, w, th, tw
  437. class RandomCrop(VisionTransform):
  438. r"""
  439. Crop the input data randomly. Before applying the crop transform,
  440. pad the image first. If target size is still bigger than the size of
  441. padded image, pad the image size to target size.
  442. :param output_size: target size of output image, with (height, width) shape.
  443. :param padding_size: the same with `size` in ``Pad``.
  444. :param padding_value: the same with `value` in ``Pad``.
  445. :param order: the same with :class:`VisionTransform`.
  446. """
  447. def __init__(
  448. self,
  449. output_size,
  450. padding_size=0,
  451. padding_value=[0, 0, 0],
  452. padding_maskvalue=0,
  453. *,
  454. order=None
  455. ):
  456. super().__init__(order)
  457. if isinstance(output_size, int):
  458. self.output_size = (output_size, output_size)
  459. else:
  460. self.output_size = output_size
  461. self.pad = Pad(padding_size, padding_value, order=self.order)
  462. self.padding_value = padding_value
  463. self.padding_maskvalue = padding_maskvalue
  464. def apply(self, input):
  465. input = self.pad.apply(input)
  466. self._h, self._w, _ = self._get_image(input).shape
  467. self._th, self._tw = self.output_size
  468. self._x = np.random.randint(0, max(0, self._w - self._tw) + 1)
  469. self._y = np.random.randint(0, max(0, self._h - self._th) + 1)
  470. return super().apply(input)
  471. def _apply_image(self, image):
  472. if self._th > self._h:
  473. image = F.pad(image, (self._th - self._h, 0), self.padding_value)
  474. if self._tw > self._w:
  475. image = F.pad(image, (0, self._tw - self._w), self.padding_value)
  476. return image[self._y : self._y + self._th, self._x : self._x + self._tw]
  477. def _apply_coords(self, coords):
  478. coords[:, 0] -= self._x
  479. coords[:, 1] -= self._y
  480. return coords
  481. def _apply_mask(self, mask):
  482. if self._th > self._h:
  483. mask = F.pad(mask, (self._th - self._h, 0), self.padding_maskvalue)
  484. if self._tw > self._w:
  485. mask = F.pad(mask, (0, self._tw - self._w), self.padding_maskvalue)
  486. return mask[self._y : self._y + self._th, self._x : self._x + self._tw]
  487. class RandomResizedCrop(VisionTransform):
  488. r"""
  489. Crop the input data to random size and aspect ratio.
  490. A crop of random size (default: of 0.08 to 1.0) of the original size and a random
  491. aspect ratio (default: of 3/4 to 1.33) of the original aspect ratio is made.
  492. After applying crop transfrom, the input data will be resized to given size.
  493. :param output_size: target size of output image, with (height, width) shape.
  494. :param scale_range: range of size of the origin size cropped. Default: (0.08, 1.0)
  495. :param ratio_range: range of aspect ratio of the origin aspect ratio cropped. Default: (0.75, 1.33)
  496. :param order: the same with :class:`VisionTransform`.
  497. """
  498. def __init__(
  499. self,
  500. output_size,
  501. scale_range=(0.08, 1.0),
  502. ratio_range=(3.0 / 4, 4.0 / 3),
  503. interpolation=cv2.INTER_LINEAR,
  504. *,
  505. order=None
  506. ):
  507. super().__init__(order)
  508. if isinstance(output_size, int):
  509. self.output_size = (output_size, output_size)
  510. else:
  511. self.output_size = output_size
  512. assert (
  513. scale_range[0] <= scale_range[1]
  514. ), "scale_range should be of kind (min, max)"
  515. assert (
  516. ratio_range[0] <= ratio_range[1]
  517. ), "ratio_range should be of kind (min, max)"
  518. self.scale_range = scale_range
  519. self.ratio_range = ratio_range
  520. self.interpolation = interpolation
  521. def apply(self, input: Tuple):
  522. self._coord_info = self._get_coord(self._get_image(input))
  523. return super().apply(input)
  524. def _apply_image(self, image):
  525. x, y, w, h = self._coord_info
  526. cropped_img = image[y : y + h, x : x + w]
  527. return F.resize(cropped_img, self.output_size, self.interpolation)
  528. def _apply_coords(self, coords):
  529. x, y, w, h = self._coord_info
  530. coords[:, 0] = (coords[:, 0] - x) * self.output_size[1] / w
  531. coords[:, 1] = (coords[:, 1] - y) * self.output_size[0] / h
  532. return coords
  533. def _apply_mask(self, mask):
  534. x, y, w, h = self._coord_info
  535. cropped_mask = mask[y : y + h, x : x + w]
  536. return F.resize(cropped_mask, self.output_size, cv2.INTER_NEAREST)
  537. def _get_coord(self, image, attempts=10):
  538. height, width, _ = image.shape
  539. area = height * width
  540. for _ in range(attempts):
  541. target_area = np.random.uniform(*self.scale_range) * area
  542. log_ratio = tuple(math.log(x) for x in self.ratio_range)
  543. aspect_ratio = math.exp(np.random.uniform(*log_ratio))
  544. w = int(round(math.sqrt(target_area * aspect_ratio)))
  545. h = int(round(math.sqrt(target_area / aspect_ratio)))
  546. if 0 < w <= width and 0 < h <= height:
  547. x = np.random.randint(0, width - w + 1)
  548. y = np.random.randint(0, height - h + 1)
  549. return x, y, w, h
  550. # Fallback to central crop
  551. in_ratio = float(width) / float(height)
  552. if in_ratio < min(self.ratio_range):
  553. w = width
  554. h = int(round(w / min(self.ratio_range)))
  555. elif in_ratio > max(self.ratio_range):
  556. h = height
  557. w = int(round(h * max(self.ratio_range)))
  558. else: # whole image
  559. w = width
  560. h = height
  561. x = (width - w) // 2
  562. y = (height - h) // 2
  563. return x, y, w, h
  564. class CenterCrop(VisionTransform):
  565. r"""
  566. Crops the given the input data at the center.
  567. :param output_size: target size of output image, with (height, width) shape.
  568. :param order: the same with :class:`VisionTransform`.
  569. """
  570. def __init__(self, output_size, *, order=None):
  571. super().__init__(order)
  572. if isinstance(output_size, int):
  573. self.output_size = (output_size, output_size)
  574. else:
  575. self.output_size = output_size
  576. def apply(self, input: Tuple):
  577. self._coord_info = self._get_coord(self._get_image(input))
  578. return super().apply(input)
  579. def _apply_image(self, image):
  580. x, y = self._coord_info
  581. th, tw = self.output_size
  582. return image[y : y + th, x : x + tw]
  583. def _apply_coords(self, coords):
  584. x, y = self._coord_info
  585. coords[:, 0] -= x
  586. coords[:, 1] -= y
  587. return coords
  588. def _apply_mask(self, mask):
  589. x, y = self._coord_info
  590. th, tw = self.output_size
  591. return mask[y : y + th, x : x + tw]
  592. def _get_coord(self, image):
  593. th, tw = self.output_size
  594. h, w, _ = image.shape
  595. assert th <= h and tw <= w, "output size is bigger than image size"
  596. x = int(round((w - tw) / 2.0))
  597. y = int(round((h - th) / 2.0))
  598. return x, y
  599. class RandomHorizontalFlip(VisionTransform):
  600. r"""
  601. Horizontally flip the input data randomly with a given probability.
  602. :param p: probability of the input data being flipped. Default: 0.5
  603. :param order: the same with :class:`VisionTransform`.
  604. """
  605. def __init__(self, prob: float = 0.5, *, order=None):
  606. super().__init__(order)
  607. self.prob = prob
  608. def apply(self, input: Tuple):
  609. self._flipped = np.random.random() < self.prob
  610. self._w = self._get_image(input).shape[1]
  611. return super().apply(input)
  612. def _apply_image(self, image):
  613. if self._flipped:
  614. return F.flip(image, flipCode=1)
  615. return image
  616. def _apply_coords(self, coords):
  617. if self._flipped:
  618. coords[:, 0] = self._w - coords[:, 0]
  619. return coords
  620. def _apply_mask(self, mask):
  621. if self._flipped:
  622. return F.flip(mask, flipCode=1)
  623. return mask
  624. class RandomVerticalFlip(VisionTransform):
  625. r"""
  626. Vertically flip the input data randomly with a given probability.
  627. :param p: probability of the input data being flipped. Default: 0.5
  628. :param order: the same with :class:`VisionTransform`.
  629. """
  630. def __init__(self, prob: float = 0.5, *, order=None):
  631. super().__init__(order)
  632. self.prob = prob
  633. def apply(self, input: Tuple):
  634. self._flipped = np.random.random() < self.prob
  635. self._h = self._get_image(input).shape[0]
  636. return super().apply(input)
  637. def _apply_image(self, image):
  638. if self._flipped:
  639. return F.flip(image, flipCode=0)
  640. return image
  641. def _apply_coords(self, coords):
  642. if self._flipped:
  643. coords[:, 1] = self._h - coords[:, 1]
  644. return coords
  645. def _apply_mask(self, mask):
  646. if self._flipped:
  647. return F.flip(mask, flipCode=0)
  648. return mask
  649. class Normalize(VisionTransform):
  650. r"""
  651. Normalize the input data with mean and standard deviation.
  652. Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
  653. this transform will normalize each channel of the input data.
  654. ``output[channel] = (input[channel] - mean[channel]) / std[channel]``
  655. :param mean: sequence of means for each channel.
  656. :param std: sequence of standard deviations for each channel.
  657. :param order: the same with :class:`VisionTransform`.
  658. """
  659. def __init__(self, mean=0.0, std=1.0, *, order=None):
  660. super().__init__(order)
  661. self.mean = np.array(mean, dtype=np.float32)
  662. self.std = np.array(std, dtype=np.float32)
  663. def _apply_image(self, image):
  664. return (image - self.mean) / self.std
  665. def _apply_coords(self, coords):
  666. return coords
  667. def _apply_mask(self, mask):
  668. return mask
  669. class GaussianNoise(VisionTransform):
  670. r"""
  671. Add random gaussian noise to the input data.
  672. Gaussian noise is generated with given mean and std.
  673. :param mean: Gaussian mean used to generate noise.
  674. :param std: Gaussian standard deviation used to generate noise.
  675. :param order: the same with :class:`VisionTransform`
  676. """
  677. def __init__(self, mean=0.0, std=1.0, *, order=None):
  678. super().__init__(order)
  679. self.mean = np.array(mean, dtype=np.float32)
  680. self.std = np.array(std, dtype=np.float32)
  681. def _apply_image(self, image):
  682. dtype = image.dtype
  683. noise = np.random.normal(self.mean, self.std, image.shape) * 255
  684. image = image + noise.astype(np.float32)
  685. return np.clip(image, 0, 255).astype(dtype)
  686. def _apply_coords(self, coords):
  687. return coords
  688. def _apply_mask(self, mask):
  689. return mask
  690. class BrightnessTransform(VisionTransform):
  691. r"""
  692. Adjust brightness of the input data.
  693. :param value: how much to adjust the brightness. Can be any
  694. non negative number. 0 gives the original image.
  695. :param order: the same with :class:`VisionTransform`.
  696. """
  697. def __init__(self, value, *, order=None):
  698. super().__init__(order)
  699. if value < 0:
  700. raise ValueError("brightness value should be non-negative")
  701. self.value = value
  702. def _apply_image(self, image):
  703. if self.value == 0:
  704. return image
  705. dtype = image.dtype
  706. image = image.astype(np.float32)
  707. alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
  708. image = image * alpha
  709. return image.clip(0, 255).astype(dtype)
  710. def _apply_coords(self, coords):
  711. return coords
  712. def _apply_mask(self, mask):
  713. return mask
  714. class ContrastTransform(VisionTransform):
  715. r"""
  716. Adjust contrast of the input data.
  717. :param value: how much to adjust the contrast. Can be any
  718. non negative number. 0 gives the original image.
  719. :param order: the same with :class:`VisionTransform`.
  720. """
  721. def __init__(self, value, *, order=None):
  722. super().__init__(order)
  723. if value < 0:
  724. raise ValueError("contrast value should be non-negative")
  725. self.value = value
  726. def _apply_image(self, image):
  727. if self.value == 0:
  728. return image
  729. dtype = image.dtype
  730. image = image.astype(np.float32)
  731. alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
  732. image = image * alpha + F.to_gray(image).mean() * (1 - alpha)
  733. return image.clip(0, 255).astype(dtype)
  734. def _apply_coords(self, coords):
  735. return coords
  736. def _apply_mask(self, mask):
  737. return mask
  738. class SaturationTransform(VisionTransform):
  739. r"""
  740. Adjust saturation of the input data.
  741. :param value: how much to adjust the saturation. Can be any
  742. non negative number. 0 gives the original image.
  743. :param order: the same with :class:`VisionTransform`.
  744. """
  745. def __init__(self, value, *, order=None):
  746. super().__init__(order)
  747. if value < 0:
  748. raise ValueError("saturation value should be non-negative")
  749. self.value = value
  750. def _apply_image(self, image):
  751. if self.value == 0:
  752. return image
  753. dtype = image.dtype
  754. image = image.astype(np.float32)
  755. alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
  756. image = image * alpha + F.to_gray(image) * (1 - alpha)
  757. return image.clip(0, 255).astype(dtype)
  758. def _apply_coords(self, coords):
  759. return coords
  760. def _apply_mask(self, mask):
  761. return mask
  762. class HueTransform(VisionTransform):
  763. r"""
  764. Adjust hue of the input data.
  765. :param value: how much to adjust the hue. Can be any number
  766. between 0 and 0.5, 0 gives the original image.
  767. :param order: the same with :class:`VisionTransform`.
  768. """
  769. def __init__(self, value, *, order=None):
  770. super().__init__(order)
  771. if value < 0 or value > 0.5:
  772. raise ValueError("hue value should be in [0.0, 0.5]")
  773. self.value = value
  774. def _apply_image(self, image):
  775. if self.value == 0:
  776. return image
  777. dtype = image.dtype
  778. image = image.astype(np.uint8)
  779. hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV_FULL)
  780. h, s, v = cv2.split(hsv_image)
  781. alpha = np.random.uniform(-self.value, self.value)
  782. h = h.astype(np.uint8)
  783. # uint8 addition take cares of rotation across boundaries
  784. with np.errstate(over="ignore"):
  785. h += np.uint8(alpha * 255)
  786. hsv_image = cv2.merge([h, s, v])
  787. return cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR_FULL).astype(dtype)
  788. def _apply_coords(self, coords):
  789. return coords
  790. def _apply_mask(self, mask):
  791. return mask
  792. class ColorJitter(VisionTransform):
  793. r"""
  794. Randomly change the brightness, contrast, saturation and hue of an image.
  795. :param brightness: how much to jitter brightness.
  796. Chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
  797. or the given [min, max]. Should be non negative numbers.
  798. :param contrast: how much to jitter contrast.
  799. Chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
  800. or the given [min, max]. Should be non negative numbers.
  801. :param saturation: how much to jitter saturation.
  802. Chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
  803. or the given [min, max]. Should be non negative numbers.
  804. :param hue: how much to jitter hue.
  805. Chosen uniformly from [-hue, hue] or the given [min, max].
  806. Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
  807. :param order: the same with :class:`VisionTransform`.
  808. """
  809. def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, *, order=None):
  810. super().__init__(order)
  811. transforms = []
  812. if brightness != 0:
  813. transforms.append(BrightnessTransform(brightness))
  814. if contrast != 0:
  815. transforms.append(ContrastTransform(contrast))
  816. if saturation != 0:
  817. transforms.append(SaturationTransform(saturation))
  818. if hue != 0:
  819. transforms.append(HueTransform(hue))
  820. self.transforms = Compose(
  821. transforms,
  822. shuffle_indices=[tuple(range(1, len(transforms) + 1))],
  823. order=order,
  824. )
  825. def apply(self, input):
  826. return self.transforms.apply(input)
  827. class Lighting(VisionTransform):
  828. def __init__(self, scale, *, order=None):
  829. super().__init__(order)
  830. if scale < 0:
  831. raise ValueError("lighting scale should be non-negative")
  832. self.scale = scale
  833. self.eigvec = np.array(
  834. [
  835. [-0.5836, -0.6948, 0.4203],
  836. [-0.5808, -0.0045, -0.8140],
  837. [-0.5675, 0.7192, 0.4009],
  838. ]
  839. ) # reverse the first dimension for BGR
  840. self.eigval = np.array([0.2175, 0.0188, 0.0045])
  841. def _apply_image(self, image):
  842. if self.scale == 0:
  843. return image
  844. dtype = image.dtype
  845. image = image.astype(np.float32)
  846. alpha = np.random.normal(scale=self.scale, size=3)
  847. image = image + self.eigvec.dot(alpha * self.eigval)
  848. return image.clip(0, 255).astype(dtype)
  849. def _apply_coords(self, coords):
  850. return coords
  851. def _apply_mask(self, mask):
  852. return mask

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版本和 GPU 版本。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。