You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

sampler.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. # -*- coding: utf-8 -*-
  2. import collections.abc
  3. import math
  4. from abc import ABC, abstractmethod
  5. from typing import Any, Generator, Iterator, List, Union
  6. import numpy as np
  7. from .. import distributed as dist
  8. class Sampler(ABC):
  9. r"""An abstract base class for all Sampler"""
  10. @abstractmethod
  11. def __init__(self):
  12. pass
  13. class MapSampler(Sampler):
  14. r"""Sampler for map dataset.
  15. Args:
  16. dataset: dataset to sample from.
  17. batch_size: batch size for batch method.
  18. drop_last: set ``True`` to drop the last incomplete batch,
  19. if the dataset size is not divisible by the batch size. If ``False`` and
  20. the size of dataset is not divisible by the batch_size, then the last batch will
  21. be smaller. Default: False
  22. num_samples: number of samples assigned to one rank.
  23. world_size: number of ranks.
  24. rank: rank id, non-negative interger within 0 and ``world_size``.
  25. seed: seed for random operators.
  26. """
  27. def __init__(
  28. self,
  29. dataset,
  30. batch_size=1,
  31. drop_last=False,
  32. num_samples=None,
  33. world_size=None,
  34. rank=None,
  35. seed=None,
  36. ):
  37. if (
  38. not isinstance(batch_size, int)
  39. or isinstance(batch_size, bool)
  40. or batch_size <= 0
  41. ):
  42. raise ValueError(
  43. "batch_size should be a positive integer value, "
  44. "but got batch_size={}".format(batch_size)
  45. )
  46. if not isinstance(drop_last, bool):
  47. raise ValueError(
  48. "drop_last should be a boolean value, but got "
  49. "drop_last={}".format(drop_last)
  50. )
  51. if num_samples is not None and (
  52. not isinstance(num_samples, int)
  53. or isinstance(num_samples, bool)
  54. or num_samples <= 0
  55. ):
  56. raise ValueError(
  57. "num_samples should be a positive integer "
  58. "value, but got num_samples={}".format(num_samples)
  59. )
  60. self.batch_size = batch_size
  61. self.dataset = dataset
  62. self.drop_last = drop_last
  63. if world_size is None:
  64. world_size = dist.get_world_size() if dist.is_distributed() else 1
  65. self.world_size = world_size
  66. if rank is None:
  67. rank = dist.get_rank() if dist.is_distributed() else 0
  68. self.rank = rank
  69. if num_samples is None:
  70. num_samples = len(self.dataset)
  71. self.num_samples = int(math.ceil(num_samples / self.world_size))
  72. # Make sure seeds are the same at each rank
  73. if seed is None and self.world_size > 1:
  74. seed = 0
  75. self.rng = np.random.RandomState(seed)
  76. def __iter__(self) -> Union[Generator, Iterator]:
  77. return self.batch()
  78. def __len__(self) -> int:
  79. if self.drop_last:
  80. return self.num_samples // self.batch_size
  81. else:
  82. return int(math.ceil(self.num_samples / self.batch_size))
  83. def sample(self):
  84. r"""Return a list contains all sample indices."""
  85. raise NotImplementedError
  86. def scatter(self, indices) -> List:
  87. r"""Scatter method is used for splitting indices into subset, each subset
  88. will be assigned to a rank. Indices are evenly splitted by default.
  89. If customized indices assignment method is needed, please rewrite this method.
  90. """
  91. total_size = self.num_samples * self.world_size
  92. # add extra indices to make it evenly divisible
  93. indices += indices[: (total_size - len(indices))]
  94. assert len(indices) == total_size
  95. # subsample
  96. indices = indices[self.rank : total_size : self.world_size]
  97. assert len(indices) == self.num_samples
  98. return indices
  99. def batch(self) -> Iterator[List[Any]]:
  100. r"""Batch method provides a batch indices generator."""
  101. indices = list(self.sample())
  102. # user might pass the world_size parameter without dist,
  103. # so dist.is_distributed() should not be used
  104. if self.world_size > 1:
  105. indices = self.scatter(indices)
  106. step, length = self.batch_size, len(indices)
  107. batch_index = [indices[i : i + step] for i in range(0, length, step)]
  108. if self.drop_last and len(batch_index[-1]) < self.batch_size:
  109. batch_index.pop()
  110. return iter(batch_index)
  111. class StreamSampler(Sampler):
  112. r"""Sampler for stream dataset.
  113. Warning:
  114. In the case of multiple machines, sampler should ensure that each worker gets
  115. different data. But this class cannot do it yet, please build your own
  116. dataset and sampler to achieve this goal.
  117. Usually, :meth:`~.StreamDataset.__iter__` can return different iterator by
  118. ``rank = dist.get_rank()``. So that they will get different data.
  119. """
  120. def __init__(self, batch_size=1):
  121. self.batch_size = batch_size
  122. def __iter__(self):
  123. return self
  124. def __next__(self):
  125. return iter(range(self.batch_size))
  126. class SequentialSampler(MapSampler):
  127. r"""Sample elements sequentially.
  128. Args:
  129. dataset: dataset to sample from.
  130. batch_size: batch size for batch method.
  131. drop_last: set ``True`` to drop the last incomplete batch,
  132. if the dataset size is not divisible by the batch size. If ``False`` and
  133. the size of dataset is not divisible by the batch_size, then the last batch will
  134. be smaller. Default: False
  135. indices: indice of samples.
  136. world_size: number of ranks.
  137. rank: rank id, non-negative interger within 0 and ``world_size``.
  138. """
  139. def __init__(
  140. self,
  141. dataset,
  142. batch_size=1,
  143. drop_last=False,
  144. indices=None,
  145. world_size=None,
  146. rank=None,
  147. ):
  148. super().__init__(dataset, batch_size, drop_last, None, world_size, rank)
  149. if indices is not None and not isinstance(indices, collections.abc.Sequence):
  150. raise ValueError(
  151. "indices should be None or a sequence, "
  152. "but got indices={}".format(indices)
  153. )
  154. self.indices = indices
  155. def sample(self) -> Iterator[Any]:
  156. r"""Return a generator."""
  157. if self.indices is None:
  158. return iter(range(len(self.dataset)))
  159. else:
  160. return self.indices
  161. class RandomSampler(MapSampler):
  162. r"""Sample elements randomly without replacement.
  163. Args:
  164. dataset: dataset to sample from.
  165. batch_size: batch size for batch method.
  166. drop_last: set ``True`` to drop the last incomplete batch,
  167. if the dataset size is not divisible by the batch size. If ``False`` and
  168. the size of dataset is not divisible by the batch_size, then the last batch will
  169. be smaller. Default: False
  170. indices: indice of samples.
  171. world_size: number of ranks.
  172. rank: rank id, non-negative interger within 0 and ``world_size``.
  173. seed: seed for random operators.
  174. """
  175. def __init__(
  176. self,
  177. dataset,
  178. batch_size=1,
  179. drop_last=False,
  180. indices=None,
  181. world_size=None,
  182. rank=None,
  183. seed=None,
  184. ):
  185. super().__init__(dataset, batch_size, drop_last, None, world_size, rank, seed)
  186. if indices is not None and not isinstance(indices, collections.abc.Sequence):
  187. raise ValueError(
  188. "indices should be None or a sequence, "
  189. "but got indices={}".format(indices)
  190. )
  191. self.indices = indices
  192. def sample(self) -> List:
  193. if self.indices is None:
  194. return self.rng.permutation(len(self.dataset)).tolist()
  195. else:
  196. return self.rng.permutation(self.indices).tolist()
  197. class ReplacementSampler(MapSampler):
  198. r"""Sample elements randomly with replacement.
  199. Args:
  200. dataset: dataset to sample from.
  201. batch_size: batch size for batch method.
  202. drop_last: set ``True`` to drop the last incomplete batch,
  203. if the dataset size is not divisible by the batch size. If ``False`` and
  204. the size of dataset is not divisible by the batch_size, then the last batch will
  205. be smaller. Default: False
  206. num_samples: number of samples assigned to one rank.
  207. weights: weights for sampling indices, it could be unnormalized weights.
  208. world_size: number of ranks.
  209. rank: rank id, non-negative interger within 0 and ``world_size``.
  210. seed: seed for random operators.
  211. """
  212. def __init__(
  213. self,
  214. dataset,
  215. batch_size=1,
  216. drop_last=False,
  217. num_samples=None,
  218. weights=None,
  219. world_size=None,
  220. rank=None,
  221. seed=None,
  222. ):
  223. super().__init__(
  224. dataset, batch_size, drop_last, num_samples, world_size, rank, seed
  225. )
  226. if weights is not None:
  227. if not isinstance(weights, collections.abc.Sequence):
  228. raise ValueError(
  229. "weights should be None or a sequence, "
  230. "but got weights={}".format(weights)
  231. )
  232. if len(weights) != len(dataset):
  233. raise ValueError(
  234. "len(dataset)={} should be equal to"
  235. "len(weights)={}".format(len(dataset), len(weights))
  236. )
  237. self.weights = weights
  238. if self.weights is not None:
  239. self.weights = np.array(weights) / sum(weights)
  240. def sample(self) -> List:
  241. n = len(self.dataset)
  242. if self.weights is None:
  243. return self.rng.randint(n, size=self.num_samples).tolist()
  244. else:
  245. return self.rng.multinomial(n, self.weights, self.num_samples).tolist()
  246. class Infinite(MapSampler):
  247. r"""Infinite Sampler warper for basic sampler."""
  248. def sample(self):
  249. raise NotImplementedError("sample method not supported in Infinite")
  250. def __init__(self, sampler):
  251. self.sampler = sampler
  252. self.sampler_iter = iter(self.sampler)
  253. def __iter__(self):
  254. return self
  255. def __next__(self):
  256. try:
  257. index = next(self.sampler_iter)
  258. except StopIteration:
  259. self.sampler_iter = iter(self.sampler)
  260. index = next(self.sampler_iter)
  261. return index
  262. def __len__(self):
  263. return np.iinfo(np.int64).max