You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

torch_data.py 2.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import torch
  2. from functools import reduce
  3. from torch.utils.data import Dataset, DataLoader, DistributedSampler
  4. from torch.utils.data.sampler import SequentialSampler, BatchSampler
  5. class TorchNormalDataset(Dataset):
  6. def __init__(self, num_of_data=1000):
  7. self.num_of_data = num_of_data
  8. self._data = list(range(num_of_data))
  9. def __len__(self):
  10. return self.num_of_data
  11. def __getitem__(self, item):
  12. return self._data[item]
  13. # 该类专门用于为 tests.helpers.models.torch_model.py/ TorchNormalModel_Classification_1 创建数据;
  14. class TorchNormalDataset_Classification(Dataset):
  15. def __init__(self, num_labels, feature_dimension=2, each_label_data=1000, seed=0):
  16. self.num_labels = num_labels
  17. self.feature_dimension = feature_dimension
  18. self.each_label_data = each_label_data
  19. self.seed = seed
  20. torch.manual_seed(seed)
  21. self.x_center = torch.randint(low=-100, high=100, size=[num_labels, feature_dimension])
  22. random_shuffle = torch.randn([num_labels, each_label_data, feature_dimension]) / 10
  23. self.x = self.x_center.unsqueeze(1).expand(num_labels, each_label_data, feature_dimension) + random_shuffle
  24. self.x = self.x.view(num_labels * each_label_data, feature_dimension)
  25. self.y = reduce(lambda x, y: x+y, [[i] * each_label_data for i in range(num_labels)])
  26. def __len__(self):
  27. return self.num_labels * self.each_label_data
  28. def __getitem__(self, item):
  29. return {"x": self.x[item], "y": self.y[item]}
  30. class TorchArgMaxDatset(Dataset):
  31. def __init__(self, feature_dimension=10, data_num=1000, seed=0):
  32. self.num_labels = feature_dimension
  33. self.feature_dimension = feature_dimension
  34. self.data_num = data_num
  35. self.seed = seed
  36. g = torch.Generator()
  37. g.manual_seed(1000)
  38. self.x = torch.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()
  39. self.y = torch.max(self.x, dim=-1)[1]
  40. def __len__(self):
  41. return self.data_num
  42. def __getitem__(self, item):
  43. return {"x": self.x[item], "y": self.y[item]}
  44. if __name__ == "__main__":
  45. a = TorchNormalDataset_Classification(2, each_label_data=4)
  46. print(a.x)
  47. print(a.y)
  48. print(a[0])