import numpy as np
import torch
import torchvision.transforms as transforms

from utils import Constant


class Cutout:
    """Randomly mask out one or more square patches from an image.

    Masked pixels are set to zero. When applied after ``Normalize`` (as in
    ``data_transforms_cifar10``), zero corresponds to the normalization mean.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The side length (in pixels) of each square patch.
            Patch centres are sampled uniformly over the image, so patches
            near the border are cropped and may be smaller than
            ``length x length``.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Apply cutout to a single tensor image.

        Args:
            img (Tensor): Tensor image of size (C, H, W).

        Returns:
            Tensor: A new tensor equal to ``img`` with ``n_holes`` patches of
            (at most) ``length x length`` pixels set to zero. Patches may
            overlap. The input tensor is not modified.
        """
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        for _ in range(self.n_holes):
            # The centre may lie anywhere in the image; the clip below crops
            # patches that hang off the edge.
            y = np.random.randint(h)
            x = np.random.randint(w)
            # Anchor the far edge at top/left + length (instead of
            # centre + length // 2) so odd lengths yield a full
            # length x length patch; even lengths are unaffected.
            top = y - self.length // 2
            left = x - self.length // 2
            y1 = np.clip(top, 0, h)
            y2 = np.clip(top + self.length, 0, h)
            x1 = np.clip(left, 0, w)
            x2 = np.clip(left + self.length, 0, w)
            mask[y1:y2, x1:x2] = 0.0
        # Match the image's device/dtype so CUDA and non-float32 inputs work,
        # and multiply out-of-place so the caller's tensor is left untouched.
        mask = torch.from_numpy(mask).to(device=img.device, dtype=img.dtype)
        return img * mask.expand_as(img)


def data_transforms_cifar10():
    """Build the train/validation transform pipelines for CIFAR-10.

    The training pipeline applies random 32x32 crops with 4-pixel padding,
    random horizontal flips, tensor conversion, per-channel normalization and
    finally ``Cutout`` (configured from ``Constant.CUTOUT_HOLES`` and
    ``Constant.CUTOUT_RATIO``). The validation pipeline only converts and
    normalizes.

    Returns:
        tuple: ``(train_transform, valid_transform)``, each a
        ``transforms.Compose``.
    """
    # Per-channel mean/std used for normalization (values as given here;
    # presumably CIFAR-10 training-set statistics).
    cifar_mean = [0.49139968, 0.48215827, 0.44653124]
    cifar_std = [0.24703233, 0.24348505, 0.26158768]

    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean, cifar_std),
            # Cutout runs after Normalize, so masked pixels become 0 in
            # normalized space; patch side is a fraction of the 32px image.
            Cutout(
                n_holes=Constant.CUTOUT_HOLES,
                length=int(32 * Constant.CUTOUT_RATIO),
            ),
        ]
    )

    valid_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(cifar_mean, cifar_std)]
    )
    return train_transform, valid_transform