You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

normal_data.py 2.4 kB

3 years ago
3 years ago
3 years ago
3 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import numpy as np
  2. import random
  3. class NormalSampler:
  4. def __init__(self, num_of_data=1000, shuffle=False):
  5. self._num_of_data = num_of_data
  6. self._data = list(range(num_of_data))
  7. if shuffle:
  8. random.shuffle(self._data)
  9. self.shuffle = shuffle
  10. self._index = 0
  11. self.need_reinitialize = False
  12. def __iter__(self):
  13. if self.need_reinitialize:
  14. self._index = 0
  15. if self.shuffle:
  16. random.shuffle(self._data)
  17. else:
  18. self.need_reinitialize = True
  19. return self
  20. def __next__(self):
  21. if self._index >= self._num_of_data:
  22. raise StopIteration
  23. _data = self._data[self._index]
  24. self._index += 1
  25. return _data
  26. def __len__(self):
  27. return self._num_of_data
  28. class NormalBatchSampler:
  29. def __init__(self, sampler, batch_size: int, drop_last: bool) -> None:
  30. # Since collections.abc.Iterable does not check for `__getitem__`, which
  31. # is one way for an object to be an iterable, we don't do an `isinstance`
  32. # check here.
  33. if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
  34. batch_size <= 0:
  35. raise ValueError("batch_size should be a positive integer value, "
  36. "but got batch_size={}".format(batch_size))
  37. if not isinstance(drop_last, bool):
  38. raise ValueError("drop_last should be a boolean value, but got "
  39. "drop_last={}".format(drop_last))
  40. self.sampler = sampler
  41. self.batch_size = batch_size
  42. self.drop_last = drop_last
  43. def __iter__(self):
  44. batch = []
  45. for idx in self.sampler:
  46. batch.append(idx)
  47. if len(batch) == self.batch_size:
  48. yield batch
  49. batch = []
  50. if len(batch) > 0 and not self.drop_last:
  51. yield batch
  52. def __len__(self) -> int:
  53. if self.drop_last:
  54. return len(self.sampler) // self.batch_size
  55. else:
  56. return (len(self.sampler) + self.batch_size - 1) // self.batch_size
  57. class RandomDataset:
  58. def __init__(self, num_data=10):
  59. self.data = np.random.rand(num_data)
  60. def __len__(self):
  61. return len(self.data)
  62. def __getitem__(self, item):
  63. return self.data[item]