You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

paddle_data.py 1.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. import numpy as np
  2. from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
  3. if _NEED_IMPORT_PADDLE:
  4. import paddle
  5. from paddle.io import Dataset
  6. else:
  7. from fastNLP.core.utils.dummy_class import DummyClass as Dataset
  class PaddleNormalDataset(Dataset):
      """Map-style dataset backed by the list ``[0, 1, ..., num_of_data - 1]``.

      Indexing is delegated directly to the underlying Python list, and the
      reported length is the ``num_of_data`` value given at construction.
      """

      def __init__(self, num_of_data=1000):
          """Build the backing list and remember the requested size.

          :param num_of_data: number of consecutive integers to generate,
              starting at 0.
          """
          self._data = [*range(num_of_data)]
          self.num_of_data = num_of_data

      def __len__(self):
          """Return the size that was passed to the constructor."""
          size = self.num_of_data
          return size

      def __getitem__(self, item):
          """Return whatever the backing list holds at ``item``."""
          samples = self._data
          return samples[item]
  class PaddleNormalXYDataset(Dataset):
      """
      A plain dataset that can be fed into a classification model.

      Every sample is a dict with the keys ``"x"`` and ``"y"``; both hold a
      separate ``float32`` paddle tensor built from the same stored value.
      """

      def __init__(self, num_of_data=1000):
          """Build the backing list and remember the requested size.

          :param num_of_data: number of consecutive integers to generate,
              starting at 0.
          """
          self._data = [*range(num_of_data)]
          self.num_of_data = num_of_data

      def __len__(self):
          """Return the size that was passed to the constructor."""
          size = self.num_of_data
          return size

      def __getitem__(self, item):
          """Return the ``{"x": ..., "y": ...}`` sample for ``item``.

          Needs the module-level ``paddle`` name, which this file only binds
          when ``_NEED_IMPORT_PADDLE`` is truthy.
          """
          value = self._data[item]
          sample = {}
          # Two independent tensors are created on purpose: "x" and "y" must
          # not alias the same tensor object.
          sample["x"] = paddle.to_tensor([value], dtype="float32")
          sample["y"] = paddle.to_tensor([value], dtype="float32")
          return sample
  class PaddleArgMaxDataset(Dataset):
      """Dataset whose label is the arg-max position of each random feature row.

      ``x`` is produced by ``paddle.randn`` with shape
      ``(num_samples, num_features)`` and ``y`` is ``x.argmax(axis=-1)``.
      """

      def __init__(self, num_samples, num_features):
          """Generate the random features and derive their labels.

          :param num_samples: size of the first dimension of ``x``.
          :param num_features: size of the second dimension of ``x``.
          """
          shape = (num_samples, num_features)
          features = paddle.randn(shape)
          self.x = features
          # The label is the index of the largest entry along the last axis.
          self.y = features.argmax(axis=-1)

      def __len__(self):
          """Return ``len(self.x)``."""
          features = self.x
          return len(features)

      def __getitem__(self, item):
          """Return the ``{"x": ..., "y": ...}`` pair indexed by ``item``."""
          return dict(x=self.x[item], y=self.y[item])