You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

meta_dataset.py 1.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # -*- coding: utf-8 -*-
  2. from abc import ABC, abstractmethod
  3. from typing import Tuple
  4. class Dataset(ABC):
  5. r"""An abstract base class for all datasets.
  6. __getitem__ and __len__ method are aditionally needed.
  7. """
  8. @abstractmethod
  9. def __init__(self):
  10. pass
  11. @abstractmethod
  12. def __getitem__(self, index):
  13. pass
  14. @abstractmethod
  15. def __len__(self):
  16. pass
  17. class StreamDataset(Dataset):
  18. r"""An abstract class for stream data.
  19. __iter__ method is aditionally needed.
  20. """
  21. @abstractmethod
  22. def __init__(self):
  23. pass
  24. @abstractmethod
  25. def __iter__(self):
  26. pass
  27. def __getitem__(self, idx):
  28. raise AssertionError("can not get item from StreamDataset by index")
  29. def __len__(self):
  30. raise AssertionError("StreamDataset does not have length")
  31. class ArrayDataset(Dataset):
  32. r"""ArrayDataset is a dataset for numpy array data.
  33. One or more numpy arrays are needed to initiate the dataset.
  34. And the dimensions represented sample number are expected to be the same.
  35. """
  36. def __init__(self, *arrays):
  37. super().__init__()
  38. if not all(len(arrays[0]) == len(array) for array in arrays):
  39. raise ValueError("lengths of input arrays are inconsistent")
  40. self.arrays = arrays
  41. def __getitem__(self, index: int) -> Tuple:
  42. return tuple(array[index] for array in self.arrays)
  43. def __len__(self) -> int:
  44. return len(self.arrays[0])