# Copyright (c) Microsoft Corporation. # Licensed under the MIT license. import numpy as np import torch from torchvision import transforms from torchvision.datasets import CIFAR10 class Cutout(object): def __init__(self, length): self.length = length def __call__(self, img): h, w = img.size(1), img.size(2) mask = np.ones((h, w), np.float32) y = np.random.randint(h) x = np.random.randint(w) y1 = np.clip(y - self.length // 2, 0, h) y2 = np.clip(y + self.length // 2, 0, h) x1 = np.clip(x - self.length // 2, 0, w) x2 = np.clip(x + self.length // 2, 0, w) mask[y1: y2, x1: x2] = 0. mask = torch.from_numpy(mask) mask = mask.expand_as(img) img *= mask return img def get_dataset(cls, cutout_length=0, root=None): MEAN = [0.49139968, 0.48215827, 0.44653124] STD = [0.24703233, 0.24348505, 0.26158768] transf = [ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip() ] normalize = [ transforms.ToTensor(), transforms.Normalize(MEAN, STD) ] cutout = [] if cutout_length > 0: cutout.append(Cutout(cutout_length)) train_transform = transforms.Compose(transf + normalize + cutout) valid_transform = transforms.Compose(normalize) if cls == "cifar10": dataset_train = CIFAR10(root=root, train=True, download=True, transform=train_transform) dataset_valid = CIFAR10(root=root, train=False, download=True, transform=valid_transform) else: raise NotImplementedError return dataset_train, dataset_valid