You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

jittor_data.py 1.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
  2. if _NEED_IMPORT_JITTOR:
  3. import jittor as jt
  4. from jittor.dataset import Dataset
  5. else:
  6. from fastNLP.core.utils.dummy_class import DummyClass as Dataset
  7. class JittorNormalDataset(Dataset):
  8. def __init__(self, num_of_data=100, **kwargs):
  9. super(JittorNormalDataset, self).__init__(**kwargs)
  10. self._data = list(range(num_of_data))
  11. self.set_attrs(total_len=num_of_data)
  12. def __getitem__(self, item):
  13. return self._data[item]
  14. class JittorNormalXYDataset(Dataset):
  15. """
  16. 可以被输入到分类模型中的普通数据集
  17. """
  18. def __init__(self, num_of_data=1000, **kwargs):
  19. super(JittorNormalXYDataset, self).__init__(**kwargs)
  20. self.num_of_data = num_of_data
  21. self._data = list(range(num_of_data))
  22. self.set_attrs(total_len=num_of_data)
  23. def __getitem__(self, item):
  24. return {
  25. "x": jt.Var([self._data[item]]),
  26. "y": jt.Var([self._data[item]])
  27. }
  28. class JittorArgMaxDataset(Dataset):
  29. def __init__(self, num_samples, num_features, **kwargs):
  30. super(JittorArgMaxDataset, self).__init__(**kwargs)
  31. self.x = jt.randn(num_samples, num_features)
  32. self.y = self.x.argmax(dim=-1)
  33. self.set_attrs(total_len=num_samples)
  34. def __getitem__(self, item):
  35. return {"x": self.x[item], "y": self.y[item]}
  36. if __name__ == "__main__":
  37. dataset = JittorNormalDataset()
  38. print(len(dataset))