|
- import numpy as np
- import torch
- import torchvision.transforms as transforms
- from utils import Constant
-
class Cutout:
    """Randomly mask out one or more square patches from an image tensor.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The side length (in pixels) of each square patch.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Zero out ``n_holes`` square patches of ``img`` in place.

        Args:
            img (Tensor): Tensor image of size (C, H, W).

        Returns:
            Tensor: The same ``img`` object, with n_holes patches of
            ``length x length`` pixels set to zero in every channel.
            Patches near the border are clipped to the image bounds.
        """
        h, w = img.size(1), img.size(2)
        # Create the mask with the image's own dtype and device so the
        # in-place multiply below also works for GPU, half-precision and
        # integer tensors (a CPU float32 mask would not).
        mask = torch.ones((h, w), dtype=img.dtype, device=img.device)
        half = self.length // 2

        for _ in range(self.n_holes):
            # The patch centre may be anywhere in the image; the patch itself
            # is clipped at the border, so edge holes are only partly inside.
            y = np.random.randint(h)
            x = np.random.randint(w)

            # Start at centre - length//2 and extend exactly `length` pixels,
            # so odd lengths are not shortened by one (2 * (length//2) < length).
            y1 = int(np.clip(y - half, 0, h))
            y2 = int(np.clip(y - half + self.length, 0, h))
            x1 = int(np.clip(x - half, 0, w))
            x2 = int(np.clip(x - half + self.length, 0, w))

            mask[y1:y2, x1:x2] = 0

        # Broadcast the (H, W) mask over all channels and mutate img in place.
        img *= mask.expand_as(img)
        return img
-
def data_transforms_cifar10():
    """Return the ``(train_transform, valid_transform)`` pair for CIFAR-10.

    Both pipelines convert the image with ``ToTensor`` and normalize it with
    fixed per-channel mean/std values. Before conversion, the training
    pipeline applies a random 32-pixel crop with 4 pixels of padding and a
    random horizontal flip. After normalization, it applies ``Cutout``,
    configured from ``Constant.CUTOUT_HOLES`` and ``Constant.CUTOUT_RATIO``.
    """
    image_size = 32
    channel_mean = [0.49139968, 0.48215827, 0.44653124]
    channel_std = [0.24703233, 0.24348505, 0.26158768]

    # The cutout patch side is a fraction of the image side, truncated to int.
    cutout_length = int(image_size * Constant.CUTOUT_RATIO)

    train_steps = [
        transforms.RandomCrop(image_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
        Cutout(n_holes=Constant.CUTOUT_HOLES, length=cutout_length),
    ]
    valid_steps = [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]

    train_pipeline = transforms.Compose(train_steps)
    valid_pipeline = transforms.Compose(valid_steps)
    return train_pipeline, valid_pipeline
|