You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

oneflow_data.py 1.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW
  2. if _NEED_IMPORT_ONEFLOW:
  3. import oneflow
  4. from oneflow.utils.data import Dataset
  5. else:
  6. from fastNLP.core.utils.dummy_class import DummyClass as Dataset
  7. class OneflowNormalDataset(Dataset):
  8. def __init__(self, num_of_data=1000):
  9. self.num_of_data = num_of_data
  10. self._data = list(range(num_of_data))
  11. def __len__(self):
  12. return self.num_of_data
  13. def __getitem__(self, item):
  14. return self._data[item]
  15. class OneflowNormalXYDataset(Dataset):
  16. """
  17. 可以被输入到分类模型中的普通数据集
  18. """
  19. def __init__(self, num_of_data=1000):
  20. self.num_of_data = num_of_data
  21. self._data = list(range(num_of_data))
  22. def __len__(self):
  23. return self.num_of_data
  24. def __getitem__(self, item):
  25. return {
  26. "x": oneflow.tensor([self._data[item]], dtype=oneflow.float),
  27. "y": oneflow.tensor([self._data[item]], dtype=oneflow.float)
  28. }
  29. class OneflowArgMaxDataset(Dataset):
  30. def __init__(self, data_num=1000, feature_dimension=10, seed=0):
  31. self.num_labels = feature_dimension
  32. self.feature_dimension = feature_dimension
  33. self.data_num = data_num
  34. self.seed = seed
  35. g = oneflow.Generator()
  36. g.manual_seed(1000)
  37. self.x = oneflow.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()
  38. self.y = oneflow.max(self.x, dim=-1)[1]
  39. def __len__(self):
  40. return self.data_num
  41. def __getitem__(self, item):
  42. return {"x": self.x[item], "y": self.y[item]}