You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

sampler.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections.abc
  10. import math
  11. from abc import ABC, abstractmethod
  12. from typing import Any, Generator, Iterator, List, Union
  13. import numpy as np
  14. from .. import distributed as dist
  15. class Sampler(ABC):
  16. r"""An abstract base class for all Sampler"""
  17. @abstractmethod
  18. def __init__(self):
  19. pass
  20. class MapSampler(Sampler):
  21. r"""Sampler for map dataset.
  22. Args:
  23. dataset: dataset to sample from.
  24. batch_size: batch size for batch method.
  25. drop_last: set ``True`` to drop the last incomplete batch,
  26. if the dataset size is not divisible by the batch size. If ``False`` and
  27. the size of dataset is not divisible by the batch_size, then the last batch will
  28. be smaller. Default: False
  29. num_samples: number of samples assigned to one rank.
  30. world_size: number of ranks.
  31. rank: rank id, non-negative interger within 0 and ``world_size``.
  32. seed: seed for random operators.
  33. """
  34. def __init__(
  35. self,
  36. dataset,
  37. batch_size=1,
  38. drop_last=False,
  39. num_samples=None,
  40. world_size=None,
  41. rank=None,
  42. seed=None,
  43. ):
  44. if (
  45. not isinstance(batch_size, int)
  46. or isinstance(batch_size, bool)
  47. or batch_size <= 0
  48. ):
  49. raise ValueError(
  50. "batch_size should be a positive integer value, "
  51. "but got batch_size={}".format(batch_size)
  52. )
  53. if not isinstance(drop_last, bool):
  54. raise ValueError(
  55. "drop_last should be a boolean value, but got "
  56. "drop_last={}".format(drop_last)
  57. )
  58. if num_samples is not None and (
  59. not isinstance(num_samples, int)
  60. or isinstance(num_samples, bool)
  61. or num_samples <= 0
  62. ):
  63. raise ValueError(
  64. "num_samples should be a positive integer "
  65. "value, but got num_samples={}".format(num_samples)
  66. )
  67. self.batch_size = batch_size
  68. self.dataset = dataset
  69. self.drop_last = drop_last
  70. if world_size is None:
  71. world_size = dist.get_world_size() if dist.is_distributed() else 1
  72. self.world_size = world_size
  73. if rank is None:
  74. rank = dist.get_rank() if dist.is_distributed() else 0
  75. self.rank = rank
  76. if num_samples is None:
  77. num_samples = len(self.dataset)
  78. self.num_samples = int(math.ceil(num_samples / self.world_size))
  79. # Make sure seeds are the same at each rank
  80. if seed is None and self.world_size > 1:
  81. seed = 0
  82. self.rng = np.random.RandomState(seed)
  83. def __iter__(self) -> Union[Generator, Iterator]:
  84. return self.batch()
  85. def __len__(self) -> int:
  86. if self.drop_last:
  87. return self.num_samples // self.batch_size
  88. else:
  89. return int(math.ceil(self.num_samples / self.batch_size))
  90. def sample(self):
  91. r"""Return a list contains all sample indices."""
  92. raise NotImplementedError
  93. def scatter(self, indices) -> List:
  94. r"""Scatter method is used for splitting indices into subset, each subset
  95. will be assigned to a rank. Indices are evenly splitted by default.
  96. If customized indices assignment method is needed, please rewrite this method.
  97. """
  98. total_size = self.num_samples * self.world_size
  99. # add extra indices to make it evenly divisible
  100. indices += indices[: (total_size - len(indices))]
  101. assert len(indices) == total_size
  102. # subsample
  103. indices = indices[self.rank : total_size : self.world_size]
  104. assert len(indices) == self.num_samples
  105. return indices
  106. def batch(self) -> Iterator[List[Any]]:
  107. r"""Batch method provides a batch indices generator."""
  108. indices = list(self.sample())
  109. # user might pass the world_size parameter without dist,
  110. # so dist.is_distributed() should not be used
  111. if self.world_size > 1:
  112. indices = self.scatter(indices)
  113. step, length = self.batch_size, len(indices)
  114. batch_index = [indices[i : i + step] for i in range(0, length, step)]
  115. if self.drop_last and len(batch_index[-1]) < self.batch_size:
  116. batch_index.pop()
  117. return iter(batch_index)
  118. class StreamSampler(Sampler):
  119. r"""Sampler for stream dataset.
  120. Warning:
  121. In the case of multiple machines, sampler should ensure that each worker gets
  122. different data. But this class cannot do it yet, please build your own
  123. dataset and sampler to achieve this goal.
  124. Usually, :meth:`~.StreamDataset.__iter__` can return different iterator by
  125. ``rank = dist.get_rank()``. So that they will get different data.
  126. """
  127. def __init__(self, batch_size=1):
  128. self.batch_size = batch_size
  129. def __iter__(self):
  130. return self
  131. def __next__(self):
  132. return iter(range(self.batch_size))
  133. class SequentialSampler(MapSampler):
  134. r"""Sample elements sequentially.
  135. Args:
  136. dataset: dataset to sample from.
  137. batch_size: batch size for batch method.
  138. drop_last: set ``True`` to drop the last incomplete batch,
  139. if the dataset size is not divisible by the batch size. If ``False`` and
  140. the size of dataset is not divisible by the batch_size, then the last batch will
  141. be smaller. Default: False
  142. indices: indice of samples.
  143. world_size: number of ranks.
  144. rank: rank id, non-negative interger within 0 and ``world_size``.
  145. """
  146. def __init__(
  147. self,
  148. dataset,
  149. batch_size=1,
  150. drop_last=False,
  151. indices=None,
  152. world_size=None,
  153. rank=None,
  154. ):
  155. super().__init__(dataset, batch_size, drop_last, None, world_size, rank)
  156. if indices is not None and not isinstance(indices, collections.abc.Sequence):
  157. raise ValueError(
  158. "indices should be None or a sequence, "
  159. "but got indices={}".format(indices)
  160. )
  161. self.indices = indices
  162. def sample(self) -> Iterator[Any]:
  163. r"""Return a generator."""
  164. if self.indices is None:
  165. return iter(range(len(self.dataset)))
  166. else:
  167. return self.indices
  168. class RandomSampler(MapSampler):
  169. r"""Sample elements randomly without replacement.
  170. Args:
  171. dataset: dataset to sample from.
  172. batch_size: batch size for batch method.
  173. drop_last: set ``True`` to drop the last incomplete batch,
  174. if the dataset size is not divisible by the batch size. If ``False`` and
  175. the size of dataset is not divisible by the batch_size, then the last batch will
  176. be smaller. Default: False
  177. indices: indice of samples.
  178. world_size: number of ranks.
  179. rank: rank id, non-negative interger within 0 and ``world_size``.
  180. seed: seed for random operators.
  181. """
  182. def __init__(
  183. self,
  184. dataset,
  185. batch_size=1,
  186. drop_last=False,
  187. indices=None,
  188. world_size=None,
  189. rank=None,
  190. seed=None,
  191. ):
  192. super().__init__(dataset, batch_size, drop_last, None, world_size, rank, seed)
  193. if indices is not None and not isinstance(indices, collections.abc.Sequence):
  194. raise ValueError(
  195. "indices should be None or a sequence, "
  196. "but got indices={}".format(indices)
  197. )
  198. self.indices = indices
  199. def sample(self) -> List:
  200. if self.indices is None:
  201. return self.rng.permutation(len(self.dataset)).tolist()
  202. else:
  203. return self.rng.permutation(self.indices).tolist()
  204. class ReplacementSampler(MapSampler):
  205. r"""Sample elements randomly with replacement.
  206. Args:
  207. dataset: dataset to sample from.
  208. batch_size: batch size for batch method.
  209. drop_last: set ``True`` to drop the last incomplete batch,
  210. if the dataset size is not divisible by the batch size. If ``False`` and
  211. the size of dataset is not divisible by the batch_size, then the last batch will
  212. be smaller. Default: False
  213. num_samples: number of samples assigned to one rank.
  214. weights: weights for sampling indices, it could be unnormalized weights.
  215. world_size: number of ranks.
  216. rank: rank id, non-negative interger within 0 and ``world_size``.
  217. seed: seed for random operators.
  218. """
  219. def __init__(
  220. self,
  221. dataset,
  222. batch_size=1,
  223. drop_last=False,
  224. num_samples=None,
  225. weights=None,
  226. world_size=None,
  227. rank=None,
  228. seed=None,
  229. ):
  230. super().__init__(
  231. dataset, batch_size, drop_last, num_samples, world_size, rank, seed
  232. )
  233. if weights is not None:
  234. if not isinstance(weights, collections.abc.Sequence):
  235. raise ValueError(
  236. "weights should be None or a sequence, "
  237. "but got weights={}".format(weights)
  238. )
  239. if len(weights) != len(dataset):
  240. raise ValueError(
  241. "len(dataset)={} should be equal to"
  242. "len(weights)={}".format(len(dataset), len(weights))
  243. )
  244. self.weights = weights
  245. if self.weights is not None:
  246. self.weights = np.array(weights) / sum(weights)
  247. def sample(self) -> List:
  248. n = len(self.dataset)
  249. if self.weights is None:
  250. return self.rng.randint(n, size=self.num_samples).tolist()
  251. else:
  252. return self.rng.multinomial(n, self.weights, self.num_samples).tolist()
  253. class Infinite(MapSampler):
  254. r"""Infinite Sampler warper for basic sampler."""
  255. def sample(self):
  256. raise NotImplementedError("sample method not supported in Infinite")
  257. def __init__(self, sampler):
  258. self.sampler = sampler
  259. self.sampler_iter = iter(self.sampler)
  260. def __iter__(self):
  261. return self
  262. def __next__(self):
  263. try:
  264. index = next(self.sampler_iter)
  265. except StopIteration:
  266. self.sampler_iter = iter(self.sampler)
  267. index = next(self.sampler_iter)
  268. return index
  269. def __len__(self):
  270. return np.iinfo(np.int64).max