You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

sampler.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. # -*- coding: utf-8 -*-
  2. import collections.abc
  3. import math
  4. from abc import ABC, abstractmethod
  5. from itertools import count
  6. from typing import Any, Generator, Iterator, List, Union
  7. import numpy as np
  8. from .. import distributed as dist
  9. class Sampler(ABC):
  10. r"""An abstract base class for all Sampler"""
  11. @abstractmethod
  12. def __init__(self):
  13. pass
  14. class MapSampler(Sampler):
  15. r"""Sampler for map dataset.
  16. Args:
  17. dataset: dataset to sample from.
  18. batch_size: batch size for batch method.
  19. drop_last: set ``True`` to drop the last incomplete batch,
  20. if the dataset size is not divisible by the batch size. If ``False`` and
  21. the size of dataset is not divisible by the batch_size, then the last batch will
  22. be smaller. Default: False
  23. num_samples: number of samples assigned to one rank.
  24. world_size: number of ranks.
  25. rank: rank id, non-negative interger within 0 and ``world_size``.
  26. seed: seed for random operators.
  27. """
  28. def __init__(
  29. self,
  30. dataset,
  31. batch_size=1,
  32. drop_last=False,
  33. num_samples=None,
  34. world_size=None,
  35. rank=None,
  36. seed=None,
  37. ):
  38. if (
  39. not isinstance(batch_size, int)
  40. or isinstance(batch_size, bool)
  41. or batch_size <= 0
  42. ):
  43. raise ValueError(
  44. "batch_size should be a positive integer value, "
  45. "but got batch_size={}".format(batch_size)
  46. )
  47. if not isinstance(drop_last, bool):
  48. raise ValueError(
  49. "drop_last should be a boolean value, but got "
  50. "drop_last={}".format(drop_last)
  51. )
  52. if num_samples is not None and (
  53. not isinstance(num_samples, int)
  54. or isinstance(num_samples, bool)
  55. or num_samples <= 0
  56. ):
  57. raise ValueError(
  58. "num_samples should be a positive integer "
  59. "value, but got num_samples={}".format(num_samples)
  60. )
  61. self.batch_size = batch_size
  62. self.dataset = dataset
  63. self.drop_last = drop_last
  64. if world_size is None:
  65. world_size = dist.get_world_size() if dist.is_distributed() else 1
  66. self.world_size = world_size
  67. if rank is None:
  68. rank = dist.get_rank() if dist.is_distributed() else 0
  69. self.rank = rank
  70. if num_samples is None:
  71. num_samples = len(self.dataset)
  72. self.num_samples = int(math.ceil(num_samples / self.world_size))
  73. # Make sure seeds are the same at each rank
  74. if seed is None and self.world_size > 1:
  75. seed = 0
  76. self.rng = np.random.RandomState(seed)
  77. def __iter__(self) -> Union[Generator, Iterator]:
  78. return self.batch()
  79. def __len__(self) -> int:
  80. if self.drop_last:
  81. return self.num_samples // self.batch_size
  82. else:
  83. return int(math.ceil(self.num_samples / self.batch_size))
  84. def sample(self):
  85. r"""Return a list contains all sample indices."""
  86. raise NotImplementedError
  87. def scatter(self, indices) -> List:
  88. r"""Scatter method is used for splitting indices into subset, each subset
  89. will be assigned to a rank. Indices are evenly splitted by default.
  90. If customized indices assignment method is needed, please rewrite this method.
  91. """
  92. total_size = self.num_samples * self.world_size
  93. # add extra indices to make it evenly divisible
  94. indices += indices[: (total_size - len(indices))]
  95. assert len(indices) == total_size
  96. # subsample
  97. indices = indices[self.rank : total_size : self.world_size]
  98. assert len(indices) == self.num_samples
  99. return indices
  100. def batch(self) -> Iterator[List[Any]]:
  101. r"""Batch method provides a batch indices generator."""
  102. indices = list(self.sample())
  103. # user might pass the world_size parameter without dist,
  104. # so dist.is_distributed() should not be used
  105. if self.world_size > 1:
  106. indices = self.scatter(indices)
  107. batch = []
  108. for idx in indices:
  109. batch.append(idx)
  110. if len(batch) == self.batch_size:
  111. yield batch
  112. batch = []
  113. if len(batch) > 0 and not self.drop_last:
  114. yield batch
  115. class StreamSampler(Sampler):
  116. r"""Sampler for stream dataset.
  117. Warning:
  118. In the case of multiple machines, sampler should ensure that each worker gets
  119. different data. But this class cannot do it yet, please build your own
  120. dataset and sampler to achieve this goal.
  121. Usually, :meth:`~.StreamDataset.__iter__` can return different iterator by
  122. ``rank = dist.get_rank()``. So that they will get different data.
  123. """
  124. def __init__(self, batch_size=1):
  125. self.batch_size = batch_size
  126. def __iter__(self):
  127. return self.batch()
  128. def batch(self):
  129. batch = []
  130. for idx in self.sample():
  131. batch.append(idx)
  132. if len(batch) == self.batch_size:
  133. yield batch
  134. batch = []
  135. def sample(self):
  136. return count(start=0)
  137. class SequentialSampler(MapSampler):
  138. r"""Sample elements sequentially.
  139. Args:
  140. dataset: dataset to sample from.
  141. batch_size: batch size for batch method.
  142. drop_last: set ``True`` to drop the last incomplete batch,
  143. if the dataset size is not divisible by the batch size. If ``False`` and
  144. the size of dataset is not divisible by the batch_size, then the last batch will
  145. be smaller. Default: False
  146. indices: indice of samples.
  147. world_size: number of ranks.
  148. rank: rank id, non-negative interger within 0 and ``world_size``.
  149. """
  150. def __init__(
  151. self,
  152. dataset,
  153. batch_size=1,
  154. drop_last=False,
  155. indices=None,
  156. world_size=None,
  157. rank=None,
  158. ):
  159. super().__init__(dataset, batch_size, drop_last, None, world_size, rank)
  160. if indices is not None and not isinstance(indices, collections.abc.Sequence):
  161. raise ValueError(
  162. "indices should be None or a sequence, "
  163. "but got indices={}".format(indices)
  164. )
  165. self.indices = indices
  166. def sample(self) -> Iterator[Any]:
  167. r"""Return a generator."""
  168. if self.indices is None:
  169. return iter(range(len(self.dataset)))
  170. else:
  171. return self.indices
  172. class RandomSampler(MapSampler):
  173. r"""Sample elements randomly without replacement.
  174. Args:
  175. dataset: dataset to sample from.
  176. batch_size: batch size for batch method.
  177. drop_last: set ``True`` to drop the last incomplete batch,
  178. if the dataset size is not divisible by the batch size. If ``False`` and
  179. the size of dataset is not divisible by the batch_size, then the last batch will
  180. be smaller. Default: False
  181. indices: indice of samples.
  182. world_size: number of ranks.
  183. rank: rank id, non-negative interger within 0 and ``world_size``.
  184. seed: seed for random operators.
  185. """
  186. def __init__(
  187. self,
  188. dataset,
  189. batch_size=1,
  190. drop_last=False,
  191. indices=None,
  192. world_size=None,
  193. rank=None,
  194. seed=None,
  195. ):
  196. super().__init__(dataset, batch_size, drop_last, None, world_size, rank, seed)
  197. if indices is not None and not isinstance(indices, collections.abc.Sequence):
  198. raise ValueError(
  199. "indices should be None or a sequence, "
  200. "but got indices={}".format(indices)
  201. )
  202. self.indices = indices
  203. def sample(self) -> List:
  204. if self.indices is None:
  205. return self.rng.permutation(len(self.dataset)).tolist()
  206. else:
  207. return self.rng.permutation(self.indices).tolist()
  208. class ReplacementSampler(MapSampler):
  209. r"""Sample elements randomly with replacement.
  210. Args:
  211. dataset: dataset to sample from.
  212. batch_size: batch size for batch method.
  213. drop_last: set ``True`` to drop the last incomplete batch,
  214. if the dataset size is not divisible by the batch size. If ``False`` and
  215. the size of dataset is not divisible by the batch_size, then the last batch will
  216. be smaller. Default: False
  217. num_samples: number of samples assigned to one rank.
  218. weights: weights for sampling indices, it could be unnormalized weights.
  219. world_size: number of ranks.
  220. rank: rank id, non-negative interger within 0 and ``world_size``.
  221. seed: seed for random operators.
  222. """
  223. def __init__(
  224. self,
  225. dataset,
  226. batch_size=1,
  227. drop_last=False,
  228. num_samples=None,
  229. weights=None,
  230. world_size=None,
  231. rank=None,
  232. seed=None,
  233. ):
  234. super().__init__(
  235. dataset, batch_size, drop_last, num_samples, world_size, rank, seed
  236. )
  237. if weights is not None:
  238. if not isinstance(weights, collections.abc.Sequence):
  239. raise ValueError(
  240. "weights should be None or a sequence, "
  241. "but got weights={}".format(weights)
  242. )
  243. if len(weights) != len(dataset):
  244. raise ValueError(
  245. "len(dataset)={} should be equal to"
  246. "len(weights)={}".format(len(dataset), len(weights))
  247. )
  248. self.weights = weights
  249. if self.weights is not None:
  250. self.weights = np.array(weights) / sum(weights)
  251. def sample(self) -> List:
  252. n = len(self.dataset)
  253. if self.weights is None:
  254. return self.rng.randint(n, size=self.num_samples).tolist()
  255. else:
  256. return self.rng.multinomial(n, self.weights, self.num_samples).tolist()
  257. class Infinite(MapSampler):
  258. r"""Infinite Sampler warper for basic sampler."""
  259. def sample(self):
  260. raise NotImplementedError("sample method not supported in Infinite")
  261. def __init__(self, sampler):
  262. self.sampler = sampler
  263. self.sampler_iter = iter(self.sampler)
  264. def __iter__(self):
  265. return self
  266. def __next__(self):
  267. try:
  268. index = next(self.sampler_iter)
  269. except StopIteration:
  270. self.sampler_iter = iter(self.sampler)
  271. index = next(self.sampler_iter)
  272. return index
  273. def __len__(self):
  274. return np.iinfo(np.int64).max