You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

datasets.py 1.7 kB

2 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import numpy as np
  4. import torch
  5. from torchvision import transforms
  6. from torchvision.datasets import CIFAR10
  7. class Cutout(object):
  8. def __init__(self, length):
  9. self.length = length
  10. def __call__(self, img):
  11. h, w = img.size(1), img.size(2)
  12. mask = np.ones((h, w), np.float32)
  13. y = np.random.randint(h)
  14. x = np.random.randint(w)
  15. y1 = np.clip(y - self.length // 2, 0, h)
  16. y2 = np.clip(y + self.length // 2, 0, h)
  17. x1 = np.clip(x - self.length // 2, 0, w)
  18. x2 = np.clip(x + self.length // 2, 0, w)
  19. mask[y1: y2, x1: x2] = 0.
  20. mask = torch.from_numpy(mask)
  21. mask = mask.expand_as(img)
  22. img *= mask
  23. return img
  24. def get_dataset(cls, cutout_length=0, root=None):
  25. MEAN = [0.49139968, 0.48215827, 0.44653124]
  26. STD = [0.24703233, 0.24348505, 0.26158768]
  27. transf = [
  28. transforms.RandomCrop(32, padding=4),
  29. transforms.RandomHorizontalFlip()
  30. ]
  31. normalize = [
  32. transforms.ToTensor(),
  33. transforms.Normalize(MEAN, STD)
  34. ]
  35. cutout = []
  36. if cutout_length > 0:
  37. cutout.append(Cutout(cutout_length))
  38. train_transform = transforms.Compose(transf + normalize + cutout)
  39. valid_transform = transforms.Compose(normalize)
  40. if cls == "cifar10":
  41. dataset_train = CIFAR10(root=root, train=True, download=True, transform=train_transform)
  42. dataset_valid = CIFAR10(root=root, train=False, download=True, transform=valid_transform)
  43. else:
  44. raise NotImplementedError
  45. return dataset_train, dataset_valid

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已在产学研等各领域为近千家单位及个人提供AI应用赋能。