You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

datasets.py 1.9 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import numpy as np
  2. import torch
  3. import torchvision.transforms as transforms
  4. from utils import Constant
  5. class Cutout:
  6. """Randomly mask out one or more patches from an image.
  7. Args:
  8. n_holes (int): Number of patches to cut out of each image.
  9. length (int): The length (in pixels) of each square patch.
  10. """
  11. def __init__(self, n_holes, length):
  12. self.n_holes = n_holes
  13. self.length = length
  14. def __call__(self, img):
  15. """
  16. Args:
  17. img (Tensor): Tensor image of size (C, H, W).
  18. Returns:
  19. Tensor: Image with n_holes of dimension length x length cut out of it.
  20. """
  21. h, w = img.size(1), img.size(2)
  22. mask = np.ones((h, w), np.float32)
  23. for _ in range(self.n_holes):
  24. y = np.random.randint(h)
  25. x = np.random.randint(w)
  26. y1 = np.clip(y - self.length // 2, 0, h)
  27. y2 = np.clip(y + self.length // 2, 0, h)
  28. x1 = np.clip(x - self.length // 2, 0, w)
  29. x2 = np.clip(x + self.length // 2, 0, w)
  30. mask[y1:y2, x1:x2] = 0.0
  31. mask = torch.from_numpy(mask)
  32. mask = mask.expand_as(img)
  33. img *= mask
  34. return img
  35. def data_transforms_cifar10():
  36. """ data_transforms for cifar10 dataset
  37. """
  38. cifar_mean = [0.49139968, 0.48215827, 0.44653124]
  39. cifar_std = [0.24703233, 0.24348505, 0.26158768]
  40. train_transform = transforms.Compose(
  41. [
  42. transforms.RandomCrop(32, padding=4),
  43. transforms.RandomHorizontalFlip(),
  44. transforms.ToTensor(),
  45. transforms.Normalize(cifar_mean, cifar_std),
  46. Cutout(n_holes=Constant.CUTOUT_HOLES,
  47. length=int(32 * Constant.CUTOUT_RATIO))
  48. ]
  49. )
  50. valid_transform = transforms.Compose(
  51. [transforms.ToTensor(), transforms.Normalize(cifar_mean, cifar_std)]
  52. )
  53. return train_transform, valid_transform

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已在产学研等各领域为近千家单位及个人提供AI应用赋能。