# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from torchvision import transforms
from torchvision.datasets import CIFAR10


def get_dataset(cls, datadir, *, download=True):
    """Build the training and validation splits of a supported dataset.

    Args:
        cls: Name of the dataset to load. Only ``"cifar10"`` is supported.
            (Despite the parameter name, this is a string, not a class.)
        datadir: Root directory handed to the torchvision dataset constructor.
        download: Forwarded to the torchvision dataset; when true (the
            default, matching previous behavior) missing data files are
            downloaded into ``datadir``.

    Returns:
        A ``(dataset_train, dataset_valid)`` tuple. The training split applies
        random 32px crops with 4px padding and random horizontal flips before
        tensor conversion and normalization. The validation split is the
        CIFAR-10 test split (``train=False``) with tensor conversion and
        normalization only.

    Raises:
        NotImplementedError: If ``cls`` names a dataset other than
            ``"cifar10"``.
    """
    # Reject unsupported names before doing any work.
    if cls != "cifar10":
        raise NotImplementedError(f"Unsupported dataset: {cls!r}")

    # Per-channel normalization statistics (presumably computed over the
    # CIFAR-10 training images -- verify against the source of these values).
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]

    # Augmentation is applied to the training split only.
    augment = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    # Both splits share the same tensor conversion + normalization, so that
    # evaluation inputs are scaled exactly like training inputs.
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ]
    train_transform = transforms.Compose(augment + normalize)
    valid_transform = transforms.Compose(normalize)

    dataset_train = CIFAR10(
        root=datadir, train=True, download=download, transform=train_transform
    )
    dataset_valid = CIFAR10(
        root=datadir, train=False, download=download, transform=valid_transform
    )
    return dataset_train, dataset_valid