
transform.py 35 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections.abc
import math
from typing import Sequence, Tuple

import cv2
import numpy as np

from megengine.data.transform import Transform
from megengine.data.transform.vision import functional as F

__all__ = [
    "VisionTransform",
    "ToMode",
    "Compose",
    "TorchTransformCompose",
    "Pad",
    "Resize",
    "ShortestEdgeResize",
    "RandomResize",
    "RandomCrop",
    "RandomResizedCrop",
    "CenterCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "Normalize",
    "GaussianNoise",
    "BrightnessTransform",
    "SaturationTransform",
    "ContrastTransform",
    "HueTransform",
    "ColorJitter",
    "Lighting",
]


class VisionTransform(Transform):
    r"""
    Base class of all transforms used in computer vision.
    Calling logic: apply_batch() -> apply() -> _apply_image() and the other _apply_*()
    methods. If you want to implement a self-defined transform method for images,
    rewrite the _apply_image method in a subclass.

    :param order: Input type order. Input is a tuple containing different structures;
        order is used to specify the order of the structures. For example, if your input
        is (image, boxes), then the order should be ("image", "boxes").
        The currently available strings and data types are described below:

        * "image": input image, with shape of (H, W, C)
        * "coords": coordinates, with shape of (N, 2)
        * "boxes": bounding boxes, with shape of (N, 4), "xyxy" format;
          the 1st "xy" represents the top left point of a box,
          the 2nd "xy" represents the bottom right point.
        * "mask": map used for segmentation, with shape of (H, W, 1)
        * "keypoints": keypoints with shape of (N, K, 3), N for the number of instances
          and K for the number of keypoints in one instance. The first two values
          of the last axis are the coordinates of a keypoint and the 3rd value is
          its label.
        * "polygons": a sequence of numpy arrays whose length is the number of instances.
          Each numpy array represents the polygon coordinates of one instance.
        * "category": categories for some data type. For example, "image_category"
          means the category of the input image and "boxes_category" means the
          categories of the bounding boxes.
        * "info": information about images such as image shape and image path.

        You can also customize your data types as long as you implement the
        corresponding _apply_*() methods; otherwise ``NotImplementedError`` will be raised.
    """

    def __init__(self, order=None):
        super().__init__()
        if order is None:
            order = ("image",)
        elif not isinstance(order, collections.abc.Sequence):
            raise ValueError(
                "order should be a sequence, but got order={}".format(order)
            )
        for k in order:
            if k in ("batch",):
                raise ValueError("{} is invalid data type".format(k))
            elif k.endswith("category") or k.endswith("info"):
                # when the key is *category or info, we should do nothing
                # if the corresponding apply methods are not implemented.
                continue
            elif self._get_apply(k) is None:
                raise NotImplementedError("{} is unsupported data type".format(k))
        self.order = order

    def apply_batch(self, inputs: Sequence[Tuple]):
        r"""Apply transform on batch input data."""
        return tuple(self.apply(input) for input in inputs)

    def apply(self, input: Tuple):
        r"""Apply transform on single input data."""
        if not isinstance(input, tuple):
            input = (input,)
        output = []
        for i in range(min(len(input), len(self.order))):
            apply_func = self._get_apply(self.order[i])
            if apply_func is None:
                output.append(input[i])
            else:
                output.append(apply_func(input[i]))
        if len(input) > len(self.order):
            output.extend(input[len(self.order) :])
        if len(output) == 1:
            output = output[0]
        else:
            output = tuple(output)
        return output

    def _get_apply(self, key):
        return getattr(self, "_apply_{}".format(key), None)

    def _get_image(self, input: Tuple):
        if not isinstance(input, tuple):
            input = (input,)
        return input[self.order.index("image")]

    def _apply_image(self, image):
        raise NotImplementedError

    def _apply_coords(self, coords):
        raise NotImplementedError

    def _apply_boxes(self, boxes):
        idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
        coords = np.asarray(boxes).reshape(-1, 4)[:, idxs].reshape(-1, 2)
        coords = self._apply_coords(coords).reshape((-1, 4, 2))
        minxy = coords.min(axis=1)
        maxxy = coords.max(axis=1)
        trans_boxes = np.concatenate((minxy, maxxy), axis=1)
        return trans_boxes

    def _apply_mask(self, mask):
        raise NotImplementedError

    def _apply_keypoints(self, keypoints):
        coords, visibility = keypoints[..., :2], keypoints[..., 2:]
        trans_coords = [self._apply_coords(p) for p in coords]
        return np.concatenate((trans_coords, visibility), axis=-1)

    def _apply_polygons(self, polygons):
        return [[self._apply_coords(p) for p in instance] for instance in polygons]
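
# A minimal sketch (illustrative, not part of the library) of a self-defined
# transform: only the _apply_*() methods for the data types you actually pass
# need to be overridden. Shown as a comment so the module itself is unchanged;
# the GrayScale name is hypothetical:
#
#   class GrayScale(VisionTransform):
#       def _apply_image(self, image):
#           return F.to_gray(image)  # color conversion only touches the image
#
#       def _apply_coords(self, coords):
#           return coords            # geometry is unaffected
#
#       def _apply_mask(self, mask):
#           return mask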


class ToMode(VisionTransform):
    r"""Change input data to a target mode.
    For example, most transforms use HWC-mode images,
    while the neural network might take CHW-mode input tensors.

    :param mode: Output mode of input. Use "CHW" mode by default.
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(self, mode="CHW", *, order=None):
        super().__init__(order)
        assert mode in ["CHW"], "unsupported mode: {}".format(mode)
        self.mode = mode

    def _apply_image(self, image):
        if self.mode == "CHW":
            return np.ascontiguousarray(np.rollaxis(image, 2))
        return image

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        if self.mode == "CHW":
            return np.ascontiguousarray(np.rollaxis(mask, 2))
        return mask


class Compose(VisionTransform):
    r"""
    Composes several transforms together.

    :param transforms: List of :class:`VisionTransform` to compose.
    :param batch_compose: Whether to apply the transforms along the whole batch.
        If True, each transform is applied to the entire batch in the original
        order and shuffling is not supported. Otherwise, transforms are applied
        per sample and shuffle_indices can be used to randomize their order.
    :param shuffle_indices: Indices used for random shuffle, starting at 1.
        For example, if shuffle_indices is [(1, 3), (2, 4)], then the 1st and 3rd
        transforms will be randomly shuffled, and so will the 2nd and 4th transforms.
    :param order: The same with :class:`VisionTransform`.

    Example:

    .. testcode::

        from megengine.data.transform import RandomHorizontalFlip, RandomVerticalFlip, CenterCrop, ToMode, Compose

        transform_func = Compose([
            RandomHorizontalFlip(),
            RandomVerticalFlip(),
            CenterCrop(100),
            ToMode("CHW"),
        ], shuffle_indices=[(1, 2, 3)])
    """

    def __init__(
        self, transforms=[], batch_compose=False, shuffle_indices=None, *, order=None
    ):
        super().__init__(order)
        self.transforms = transforms
        self._set_order()
        if batch_compose and shuffle_indices is not None:
            raise ValueError(
                "Do not support shuffle when applying transforms along the whole batch"
            )
        self.batch_compose = batch_compose
        if shuffle_indices is not None:
            # convert the 1-based indices of the docstring convention to 0-based
            shuffle_indices = [tuple(x - 1 for x in idx) for idx in shuffle_indices]
        self.shuffle_indices = shuffle_indices

    def _set_order(self):
        for t in self.transforms:
            t.order = self.order
            if isinstance(t, Compose):
                t._set_order()

    def apply_batch(self, inputs: Sequence[Tuple]):
        if self.batch_compose:
            for t in self.transforms:
                inputs = t.apply_batch(inputs)
            return inputs
        else:
            return super().apply_batch(inputs)

    def apply(self, input: Tuple):
        for t in self._shuffle():
            input = t.apply(input)
        return input

    def _shuffle(self):
        if self.shuffle_indices is not None:
            source_idx = list(range(len(self.transforms)))
            for idx in self.shuffle_indices:
                shuffled = np.random.permutation(idx).tolist()
                for src, dst in zip(idx, shuffled):
                    source_idx[src] = dst
            return [self.transforms[i] for i in source_idx]
        else:
            return self.transforms
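
# Usage sketch (illustrative): with order=("image", "boxes"), every transform in
# the pipeline is applied consistently to the image and its bounding boxes:
#
#   aug = Compose(
#       [RandomHorizontalFlip(), Resize((224, 224))],
#       order=("image", "boxes"),
#   )
#   image = np.random.randint(0, 256, (256, 320, 3), dtype=np.uint8)
#   boxes = np.array([[10.0, 20.0, 100.0, 120.0]])  # xyxy
#   image, boxes = aug.apply((image, boxes))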


class TorchTransformCompose(VisionTransform):
    r"""
    Compose class used for transforms from torchvision. Only PIL images are
    supported; torchvision transforms that operate on tensors,
    such as Normalize and ToTensor, are not supported.

    :param transforms: The same with ``Compose``.
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(self, transforms, *, order=None):
        super().__init__(order)
        self.transforms = transforms

    def _apply_image(self, image):
        from PIL import Image

        try:
            import accimage
        except ImportError:
            accimage = None

        # convert the BGR ndarray to an RGB PIL image
        if image.shape[0] == 3:  # CHW
            image = np.ascontiguousarray(image[[2, 1, 0]])
        elif image.shape[2] == 3:  # HWC
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image.astype(np.uint8))

        for t in self.transforms:
            image = t(image)

        # convert back to a BGR ndarray
        if isinstance(image, Image.Image) or (
            accimage is not None and isinstance(image, accimage.Image)
        ):
            image = np.array(image, dtype=np.uint8)
        if image.shape[0] == 3:  # CHW
            image = np.ascontiguousarray(image[[2, 1, 0]])
        elif image.shape[2] == 3:  # HWC
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        return image
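
# Usage sketch (assuming torchvision is installed; the transforms below are from
# its public PIL-based API):
#
#   from torchvision import transforms as T
#
#   aug = TorchTransformCompose([
#       T.RandomRotation(15),
#       T.RandomGrayscale(p=0.1),
#   ])
#   out = aug.apply(image)  # BGR ndarray in, BGR ndarray out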


class Pad(VisionTransform):
    r"""Pad the input data.

    :param size: Padding size of the input image, either an integer or a sequence.
        If it's an integer, the input image will be padded in all four directions.
        If it's a sequence of two integers, the bottom and right sides
        of the image will be padded.
        If it's a sequence of four integers, the top, bottom, left and right
        sides of the image will be padded with the given sizes.
    :param value: Padding value of the image, could be a sequence of int or float.
        If it's a float value, the dtype of the image will also be cast to float32.
    :param mask_value: Padding value of the segmentation map.
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(self, size=0, value=0, mask_value=0, *, order=None):
        super().__init__(order)
        if isinstance(size, int):
            size = (size, size, size, size)
        elif isinstance(size, collections.abc.Sequence) and len(size) == 2:
            size = (0, size[0], 0, size[1])
        elif not (isinstance(size, collections.abc.Sequence) and len(size) == 4):
            raise ValueError(
                "size should be a list/tuple which contains "
                "(top, down, left, right) four pad sizes."
            )
        self.size = size
        self.value = value
        if not isinstance(mask_value, int):
            raise ValueError(
                "mask_value should be an integer, "
                "but got mask_value={}".format(mask_value)
            )
        self.mask_value = mask_value

    def _apply_image(self, image):
        return F.pad(image, self.size, self.value)

    def _apply_coords(self, coords):
        coords[:, 0] += self.size[2]  # left padding shifts x
        coords[:, 1] += self.size[0]  # top padding shifts y
        return coords

    def _apply_mask(self, mask):
        return F.pad(mask, self.size, self.mask_value)
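
# Illustrative example: padding 10 pixels on every side grows an (H, W) image to
# (H + 20, W + 20) and shifts box coordinates by the (left, top) offsets:
#
#   pad = Pad(10, order=("image", "boxes"))
#   image, boxes = pad.apply((image, boxes))  # box x and y both move by +10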


class Resize(VisionTransform):
    r"""Resize the input data.

    :param output_size: Target size of the image, with (height, width) shape.
        If it's an integer, the shorter edge is resized to it, keeping the
        aspect ratio.
    :param interpolation: Interpolation method. All methods are listed below:

        * cv2.INTER_NEAREST – a nearest-neighbor interpolation.
        * cv2.INTER_LINEAR – a bilinear interpolation (used by default).
        * cv2.INTER_AREA – resampling using pixel area relation.
        * cv2.INTER_CUBIC – a bicubic interpolation over a 4×4 pixel neighborhood.
        * cv2.INTER_LANCZOS4 – a Lanczos interpolation over an 8×8 pixel neighborhood.
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(self, output_size, interpolation=cv2.INTER_LINEAR, *, order=None):
        super().__init__(order)
        self.output_size = output_size
        self.interpolation = interpolation

    def apply(self, input: Tuple):
        self._shape_info = self._get_shape(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return image
        return F.resize(image, (th, tw), self.interpolation)

    def _apply_coords(self, coords):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return coords
        coords[:, 0] = coords[:, 0] * (tw / w)
        coords[:, 1] = coords[:, 1] * (th / h)
        return coords

    def _apply_mask(self, mask):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return mask
        return F.resize(mask, (th, tw), cv2.INTER_NEAREST)

    def _get_shape(self, image):
        h, w, _ = image.shape
        if isinstance(self.output_size, int):
            # resize the shorter edge to output_size, keeping the aspect ratio
            if min(h, w) == self.output_size:
                return h, w, h, w
            if h < w:
                th = self.output_size
                tw = int(self.output_size * w / h)
            else:
                tw = self.output_size
                th = int(self.output_size * h / w)
            return h, w, th, tw
        else:
            return (h, w, *self.output_size)
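
# Illustrative: Resize(256) maps a 480x640 image to 256x341 (shorter edge 256,
# aspect ratio kept), while Resize((256, 256)) maps it to exactly 256x256.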


class ShortestEdgeResize(VisionTransform):
    r"""Resize the input data so that its shorter edge matches a sampled target
    size, capping the longer edge at ``max_size``.

    :param min_size: Target size(s) of the shorter edge. With "range" style, a
        (min, max) pair sampled from uniformly; with "choice" style, a list of
        candidate sizes to pick from. An integer is treated as a fixed size.
    :param max_size: Upper bound of the longer edge after resizing.
    :param sample_style: Either "range" or "choice".
    :param interpolation: The same with :class:`Resize`.
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(
        self,
        min_size,
        max_size,
        sample_style="range",
        interpolation=cv2.INTER_LINEAR,
        *,
        order=None
    ):
        super().__init__(order)
        if sample_style not in ("range", "choice"):
            raise NotImplementedError(
                "{} is unsupported sample style".format(sample_style)
            )
        self.sample_style = sample_style
        if isinstance(min_size, int):
            min_size = (min_size, min_size)
        self.min_size = min_size
        self.max_size = max_size
        self.interpolation = interpolation

    def apply(self, input: Tuple):
        self._shape_info = self._get_shape(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return image
        return F.resize(image, (th, tw), self.interpolation)

    def _apply_coords(self, coords):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return coords
        coords[:, 0] = coords[:, 0] * (tw / w)
        coords[:, 1] = coords[:, 1] * (th / h)
        return coords

    def _apply_mask(self, mask):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return mask
        return F.resize(mask, (th, tw), cv2.INTER_NEAREST)

    def _get_shape(self, image):
        h, w, _ = image.shape
        if self.sample_style == "range":
            size = np.random.randint(self.min_size[0], self.min_size[1] + 1)
        else:
            size = np.random.choice(self.min_size)

        scale = size / min(h, w)
        if h < w:
            th, tw = size, scale * w
        else:
            th, tw = scale * h, size
        # shrink further if the longer edge exceeds max_size
        if max(th, tw) > self.max_size:
            scale = self.max_size / max(th, tw)
            th = th * scale
            tw = tw * scale
        th = int(round(th))
        tw = int(round(tw))
        return h, w, th, tw


class RandomResize(VisionTransform):
    r"""Resize the input data randomly.

    :param scale_range: (min, max) range from which a scale factor is sampled
        uniformly and applied to both height and width.
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(self, scale_range, interpolation=cv2.INTER_LINEAR, *, order=None):
        super().__init__(order)
        self.scale_range = scale_range
        self.interpolation = interpolation

    def apply(self, input: Tuple):
        self._shape_info = self._get_shape(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return image
        return F.resize(image, (th, tw), self.interpolation)

    def _apply_coords(self, coords):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return coords
        coords[:, 0] = coords[:, 0] * (tw / w)
        coords[:, 1] = coords[:, 1] * (th / h)
        return coords

    def _apply_mask(self, mask):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return mask
        return F.resize(mask, (th, tw), cv2.INTER_NEAREST)

    def _get_shape(self, image):
        h, w, _ = image.shape
        scale = np.random.uniform(*self.scale_range)
        th = int(round(h * scale))
        tw = int(round(w * scale))
        return h, w, th, tw


class RandomCrop(VisionTransform):
    r"""Crop the input data randomly. Before applying the crop transform,
    the image is padded first; if the target size is still bigger than the
    padded image, the image is padded again up to the target size.

    :param output_size: Target size of the output image, with (height, width) shape.
    :param padding_size: The same with `size` in ``Pad``.
    :param padding_value: The same with `value` in ``Pad``.
    :param padding_maskvalue: The same with `mask_value` in ``Pad``.
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(
        self,
        output_size,
        padding_size=0,
        padding_value=[0, 0, 0],
        padding_maskvalue=0,
        *,
        order=None
    ):
        super().__init__(order)
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size
        self.pad = Pad(padding_size, padding_value, order=self.order)
        self.padding_value = padding_value
        self.padding_maskvalue = padding_maskvalue

    def apply(self, input):
        input = self.pad.apply(input)
        self._h, self._w, _ = self._get_image(input).shape
        self._th, self._tw = self.output_size
        self._x = np.random.randint(0, max(0, self._w - self._tw) + 1)
        self._y = np.random.randint(0, max(0, self._h - self._th) + 1)
        return super().apply(input)

    def _apply_image(self, image):
        if self._th > self._h:
            image = F.pad(image, (self._th - self._h, 0), self.padding_value)
        if self._tw > self._w:
            image = F.pad(image, (0, self._tw - self._w), self.padding_value)
        return image[self._y : self._y + self._th, self._x : self._x + self._tw]

    def _apply_coords(self, coords):
        coords[:, 0] -= self._x
        coords[:, 1] -= self._y
        return coords

    def _apply_mask(self, mask):
        if self._th > self._h:
            mask = F.pad(mask, (self._th - self._h, 0), self.padding_maskvalue)
        if self._tw > self._w:
            mask = F.pad(mask, (0, self._tw - self._w), self.padding_maskvalue)
        return mask[self._y : self._y + self._th, self._x : self._x + self._tw]
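
# Illustrative: pad-then-crop to a fixed size, e.g. the common CIFAR-style
# recipe of padding 4 pixels on every side before taking a 32x32 crop:
#
#   aug = RandomCrop(32, padding_size=4)
#   out = aug.apply(image)  # always (32, 32, C)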


class RandomResizedCrop(VisionTransform):
    r"""Crop the input data to a random size and aspect ratio.
    A crop of random size (default: 0.08 to 1.0) of the original size and a random
    aspect ratio (default: 3/4 to 4/3) of the original aspect ratio is made.
    After applying the crop transform, the input data will be resized to the given size.

    :param output_size: Target size of the output image, with (height, width) shape.
    :param scale_range: Range of size of the origin size cropped. Default: (0.08, 1.0)
    :param ratio_range: Range of aspect ratio of the origin aspect ratio cropped. Default: (3/4, 4/3)
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(
        self,
        output_size,
        scale_range=(0.08, 1.0),
        ratio_range=(3.0 / 4, 4.0 / 3),
        interpolation=cv2.INTER_LINEAR,
        *,
        order=None
    ):
        super().__init__(order)
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size
        assert (
            scale_range[0] <= scale_range[1]
        ), "scale_range should be of kind (min, max)"
        assert (
            ratio_range[0] <= ratio_range[1]
        ), "ratio_range should be of kind (min, max)"
        self.scale_range = scale_range
        self.ratio_range = ratio_range
        self.interpolation = interpolation

    def apply(self, input: Tuple):
        self._coord_info = self._get_coord(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        x, y, w, h = self._coord_info
        cropped_img = image[y : y + h, x : x + w]
        return F.resize(cropped_img, self.output_size, self.interpolation)

    def _apply_coords(self, coords):
        x, y, w, h = self._coord_info
        coords[:, 0] = (coords[:, 0] - x) * self.output_size[1] / w
        coords[:, 1] = (coords[:, 1] - y) * self.output_size[0] / h
        return coords

    def _apply_mask(self, mask):
        x, y, w, h = self._coord_info
        cropped_mask = mask[y : y + h, x : x + w]
        return F.resize(cropped_mask, self.output_size, cv2.INTER_NEAREST)

    def _get_coord(self, image, attempts=10):
        height, width, _ = image.shape
        area = height * width

        for _ in range(attempts):
            target_area = np.random.uniform(*self.scale_range) * area
            log_ratio = tuple(math.log(x) for x in self.ratio_range)
            aspect_ratio = math.exp(np.random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                x = np.random.randint(0, width - w + 1)
                y = np.random.randint(0, height - h + 1)
                return x, y, w, h

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(self.ratio_range):
            w = width
            h = int(round(w / min(self.ratio_range)))
        elif in_ratio > max(self.ratio_range):
            h = height
            w = int(round(h * max(self.ratio_range)))
        else:  # whole image
            w = width
            h = height
        x = (width - w) // 2
        y = (height - h) // 2
        return x, y, w, h
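
# Illustrative: the standard ImageNet-style training crop,
#
#   aug = RandomResizedCrop(224)  # random area in [8%, 100%], ratio in [3/4, 4/3]
#   out = aug.apply(image)        # always (224, 224, C)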


class CenterCrop(VisionTransform):
    r"""Crop the input data at the center.

    :param output_size: Target size of the output image, with (height, width) shape.
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(self, output_size, *, order=None):
        super().__init__(order)
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size

    def apply(self, input: Tuple):
        self._coord_info = self._get_coord(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        x, y = self._coord_info
        th, tw = self.output_size
        return image[y : y + th, x : x + tw]

    def _apply_coords(self, coords):
        x, y = self._coord_info
        coords[:, 0] -= x
        coords[:, 1] -= y
        return coords

    def _apply_mask(self, mask):
        x, y = self._coord_info
        th, tw = self.output_size
        return mask[y : y + th, x : x + tw]

    def _get_coord(self, image):
        th, tw = self.output_size
        h, w, _ = image.shape
        assert th <= h and tw <= w, "output size is bigger than image size"
        x = int(round((w - tw) / 2.0))
        y = int(round((h - th) / 2.0))
        return x, y


class RandomHorizontalFlip(VisionTransform):
    r"""Horizontally flip the input data randomly with a given probability.

    :param prob: Probability of the input data being flipped. Default: 0.5
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(self, prob: float = 0.5, *, order=None):
        super().__init__(order)
        self.prob = prob

    def apply(self, input: Tuple):
        self._flipped = np.random.random() < self.prob
        self._w = self._get_image(input).shape[1]
        return super().apply(input)

    def _apply_image(self, image):
        if self._flipped:
            return F.flip(image, flipCode=1)
        return image

    def _apply_coords(self, coords):
        if self._flipped:
            coords[:, 0] = self._w - coords[:, 0]
        return coords

    def _apply_mask(self, mask):
        if self._flipped:
            return F.flip(mask, flipCode=1)
        return mask


class RandomVerticalFlip(VisionTransform):
    r"""Vertically flip the input data randomly with a given probability.

    :param prob: Probability of the input data being flipped. Default: 0.5
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(self, prob: float = 0.5, *, order=None):
        super().__init__(order)
        self.prob = prob

    def apply(self, input: Tuple):
        self._flipped = np.random.random() < self.prob
        self._h = self._get_image(input).shape[0]
        return super().apply(input)

    def _apply_image(self, image):
        if self._flipped:
            return F.flip(image, flipCode=0)
        return image

    def _apply_coords(self, coords):
        if self._flipped:
            coords[:, 1] = self._h - coords[:, 1]
        return coords

    def _apply_mask(self, mask):
        if self._flipped:
            return F.flip(mask, flipCode=0)
        return mask


class Normalize(VisionTransform):
    r"""Normalize the input data with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
    this transform will normalize each channel of the input data.
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    :param mean: Sequence of means for each channel.
    :param std: Sequence of standard deviations for each channel.
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(self, mean=0.0, std=1.0, *, order=None):
        super().__init__(order)
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)

    def _apply_image(self, image):
        return (image - self.mean) / self.std

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask
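
# Illustrative: per-channel normalization with dataset statistics. The BGR
# values below are commonly used ImageNet statistics, shown only as an example,
# not official MegEngine constants:
#
#   norm = Normalize(mean=[103.53, 116.28, 123.675], std=[57.375, 57.12, 58.395])
#   out = norm.apply(image.astype("float32"))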


class GaussianNoise(VisionTransform):
    r"""Add random Gaussian noise to the input data.
    Gaussian noise is generated with the given mean and std.

    :param mean: Gaussian mean used to generate noise.
    :param std: Gaussian standard deviation used to generate noise.
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(self, mean=0.0, std=1.0, *, order=None):
        super().__init__(order)
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)

    def _apply_image(self, image):
        dtype = image.dtype
        # noise is sampled in unit scale and mapped to the 0-255 pixel range
        noise = np.random.normal(self.mean, self.std, image.shape) * 255
        image = image + noise.astype(np.float32)
        return np.clip(image, 0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask


class BrightnessTransform(VisionTransform):
    r"""Adjust the brightness of the input data.

    :param value: How much to adjust the brightness. Can be any
        non-negative number; 0 gives the original image.
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(self, value, *, order=None):
        super().__init__(order)
        if value < 0:
            raise ValueError("brightness value should be non-negative")
        self.value = value

    def _apply_image(self, image):
        if self.value == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.float32)
        alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
        image = image * alpha
        return image.clip(0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask


class ContrastTransform(VisionTransform):
    r"""Adjust the contrast of the input data.

    :param value: How much to adjust the contrast. Can be any
        non-negative number; 0 gives the original image.
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(self, value, *, order=None):
        super().__init__(order)
        if value < 0:
            raise ValueError("contrast value should be non-negative")
        self.value = value

    def _apply_image(self, image):
        if self.value == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.float32)
        alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
        # blend towards the mean gray level of the image
        image = image * alpha + F.to_gray(image).mean() * (1 - alpha)
        return image.clip(0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask


class SaturationTransform(VisionTransform):
    r"""Adjust the saturation of the input data.

    :param value: How much to adjust the saturation. Can be any
        non-negative number; 0 gives the original image.
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(self, value, *, order=None):
        super().__init__(order)
        if value < 0:
            raise ValueError("saturation value should be non-negative")
        self.value = value

    def _apply_image(self, image):
        if self.value == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.float32)
        alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
        # blend towards the grayscale version of the image
        image = image * alpha + F.to_gray(image) * (1 - alpha)
        return image.clip(0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask


class HueTransform(VisionTransform):
    r"""Adjust the hue of the input data.

    :param value: How much to adjust the hue. Can be any number
        between 0 and 0.5; 0 gives the original image.
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(self, value, *, order=None):
        super().__init__(order)
        if value < 0 or value > 0.5:
            raise ValueError("hue value should be in [0.0, 0.5]")
        self.value = value

    def _apply_image(self, image):
        if self.value == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.uint8)
        hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV_FULL)
        h, s, v = cv2.split(hsv_image)

        alpha = np.random.uniform(-self.value, self.value)
        h = h.astype(np.uint8)
        # uint8 addition takes care of the hue channel wrapping around at the boundary
        with np.errstate(over="ignore"):
            h += np.uint8(alpha * 255)
        hsv_image = cv2.merge([h, s, v])
        return cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR_FULL).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask


class ColorJitter(VisionTransform):
    r"""Randomly change the brightness, contrast, saturation and hue of an image.

    :param brightness: How much to jitter brightness.
        Chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        Should be a non-negative number.
    :param contrast: How much to jitter contrast.
        Chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        Should be a non-negative number.
    :param saturation: How much to jitter saturation.
        Chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        Should be a non-negative number.
    :param hue: How much to jitter hue.
        Chosen uniformly from [-hue, hue]. Should have 0 <= hue <= 0.5.
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, *, order=None):
        super().__init__(order)
        transforms = []
        if brightness != 0:
            transforms.append(BrightnessTransform(brightness))
        if contrast != 0:
            transforms.append(ContrastTransform(contrast))
        if saturation != 0:
            transforms.append(SaturationTransform(saturation))
        if hue != 0:
            transforms.append(HueTransform(hue))
        # apply the enabled sub-transforms in a random order on every call
        self.transforms = Compose(
            transforms,
            shuffle_indices=[tuple(range(1, len(transforms) + 1))],
            order=order,
        )

    def apply(self, input):
        return self.transforms.apply(input)
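
# Illustrative: jitter all four properties at once; the sub-transforms are
# shuffled and applied in a random order each call:
#
#   aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
#   out = aug.apply(image)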


class Lighting(VisionTransform):
    r"""Apply AlexNet-style PCA lighting noise to the input image:
    a random combination of the color principal components is added to every pixel.

    :param scale: Standard deviation of the normal distribution used to sample
        the strength of the noise; 0 gives the original image.
    :param order: The same with :class:`VisionTransform`.
    """

    def __init__(self, scale, *, order=None):
        super().__init__(order)
        if scale < 0:
            raise ValueError("lighting scale should be non-negative")
        self.scale = scale
        self.eigvec = np.array(
            [
                [-0.5836, -0.6948, 0.4203],
                [-0.5808, -0.0045, -0.8140],
                [-0.5675, 0.7192, 0.4009],
            ]
        )  # reverse the first dimension for BGR
        self.eigval = np.array([0.2175, 0.0188, 0.0045])

    def _apply_image(self, image):
        if self.scale == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.float32)
        alpha = np.random.normal(scale=self.scale, size=3)
        image = image + self.eigvec.dot(alpha * self.eigval)
        return image.clip(0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask
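
# Putting it together (an illustrative sketch of a typical training pipeline,
# not an official recipe; mean/std are left for the dataset at hand):
#
#   train_aug = Compose([
#       RandomResizedCrop(224),
#       RandomHorizontalFlip(),
#       ColorJitter(0.4, 0.4, 0.4, 0.1),
#       Lighting(0.1),
#       Normalize(mean=..., std=...),
#       ToMode("CHW"),
#   ])
#   out = train_aug.apply(image)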

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there are no separate CPU and GPU builds. To run GPU programs, make sure the machine has GPU hardware installed along with its driver. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.