You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

datasets.py 906 B

2 years ago
12345678910111213141516171819202122232425262728
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. from torchvision import transforms
  4. from torchvision.datasets import CIFAR10
  5. def get_dataset(cls,datadir):
  6. MEAN = [0.49139968, 0.48215827, 0.44653124]
  7. STD = [0.24703233, 0.24348505, 0.26158768]
  8. transf = [
  9. transforms.RandomCrop(32, padding=4),
  10. transforms.RandomHorizontalFlip()
  11. ]
  12. normalize = [
  13. transforms.ToTensor(),
  14. transforms.Normalize(MEAN, STD)
  15. ]
  16. train_transform = transforms.Compose(transf + normalize)
  17. valid_transform = transforms.Compose(normalize)
  18. if cls == "cifar10":
  19. dataset_train = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
  20. dataset_valid = CIFAR10(root=datadir, train=False, download=True, transform=valid_transform)
  21. else:
  22. raise NotImplementedError
  23. return dataset_train, dataset_valid

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能