You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

torch_data.py 2.9 kB

3 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. from functools import reduce
  2. from numpy import dtype
  3. from fastNLP.envs.imports import _NEED_IMPORT_TORCH
  4. if _NEED_IMPORT_TORCH:
  5. import torch
  6. from torch.utils.data import Dataset
  7. else:
  8. from fastNLP.core.utils.dummy_class import DummyClass as Dataset
  9. class TorchNormalDataset(Dataset):
  10. def __init__(self, num_of_data=1000):
  11. self.num_of_data = num_of_data
  12. self._data = list(range(num_of_data))
  13. def __len__(self):
  14. return self.num_of_data
  15. def __getitem__(self, item):
  16. return self._data[item]
  17. class TorchNormalXYDataset(Dataset):
  18. """
  19. 可以被输入到分类模型中的普通数据集
  20. """
  21. def __init__(self, num_of_data=1000):
  22. self.num_of_data = num_of_data
  23. self._data = list(range(num_of_data))
  24. def __len__(self):
  25. return self.num_of_data
  26. def __getitem__(self, item):
  27. return {
  28. "x": torch.tensor([self._data[item]], dtype=torch.float),
  29. "y": torch.tensor([self._data[item]], dtype=torch.float)
  30. }
  31. # 该类专门用于为 tests.helpers.models.torch_model.py/ TorchNormalModel_Classification_1 创建数据;
  32. class TorchNormalDataset_Classification(Dataset):
  33. def __init__(self, num_labels, feature_dimension=2, each_label_data=1000, seed=0):
  34. self.num_labels = num_labels
  35. self.feature_dimension = feature_dimension
  36. self.each_label_data = each_label_data
  37. self.seed = seed
  38. torch.manual_seed(seed)
  39. self.x_center = torch.randint(low=-100, high=100, size=[num_labels, feature_dimension])
  40. random_shuffle = torch.randn([num_labels, each_label_data, feature_dimension]) / 10
  41. self.x = self.x_center.unsqueeze(1).expand(num_labels, each_label_data, feature_dimension) + random_shuffle
  42. self.x = self.x.view(num_labels * each_label_data, feature_dimension)
  43. self.y = reduce(lambda x, y: x+y, [[i] * each_label_data for i in range(num_labels)])
  44. def __len__(self):
  45. return self.num_labels * self.each_label_data
  46. def __getitem__(self, item):
  47. return {"x": self.x[item], "y": self.y[item]}
  48. class TorchArgMaxDataset(Dataset):
  49. def __init__(self, feature_dimension=10, data_num=1000, seed=0):
  50. self.num_labels = feature_dimension
  51. self.feature_dimension = feature_dimension
  52. self.data_num = data_num
  53. self.seed = seed
  54. g = torch.Generator()
  55. g.manual_seed(1000)
  56. self.x = torch.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()
  57. self.y = torch.max(self.x, dim=-1)[1]
  58. def __len__(self):
  59. return self.data_num
  60. def __getitem__(self, item):
  61. return {"x": self.x[item], "y": self.y[item]}
  62. if __name__ == "__main__":
  63. a = TorchNormalDataset_Classification(2, each_label_data=4)
  64. print(a.x)
  65. print(a.y)
  66. print(a[0])