You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

sampler.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections.abc
  10. import math
  11. from abc import ABC, abstractmethod
  12. from typing import Any, Generator, Iterator, List, Union
  13. import numpy as np
  14. import megengine.distributed as dist
  15. class Sampler(ABC):
  16. r"""
  17. An abstract base class for all Sampler
  18. """
  19. @abstractmethod
  20. def __init__(self):
  21. pass
  22. class MapSampler(Sampler):
  23. r"""
  24. Sampler for map dataset.
  25. :param dataset: dataset to sample from.
  26. :param batch_size: batch size for batch method.
  27. :param drop_last: set ``True`` to drop the last incomplete batch,
  28. if the dataset size is not divisible by the batch size. If ``False`` and
  29. the size of dataset is not divisible by the batch_size, then the last batch will
  30. be smaller. Default: False
  31. :param num_samples: number of samples assigned to one rank.
  32. :param world_size: number of ranks.
  33. :param rank: rank id, non-negative interger within 0 and ``world_size``.
  34. :param seed: seed for random operators.
  35. """
  36. def __init__(
  37. self,
  38. dataset,
  39. batch_size=1,
  40. drop_last=False,
  41. num_samples=None,
  42. world_size=None,
  43. rank=None,
  44. seed=None,
  45. ):
  46. if (
  47. not isinstance(batch_size, int)
  48. or isinstance(batch_size, bool)
  49. or batch_size <= 0
  50. ):
  51. raise ValueError(
  52. "batch_size should be a positive integer value, "
  53. "but got batch_size={}".format(batch_size)
  54. )
  55. if not isinstance(drop_last, bool):
  56. raise ValueError(
  57. "drop_last should be a boolean value, but got "
  58. "drop_last={}".format(drop_last)
  59. )
  60. if num_samples is not None and (
  61. not isinstance(num_samples, int)
  62. or isinstance(num_samples, bool)
  63. or num_samples <= 0
  64. ):
  65. raise ValueError(
  66. "num_samples should be a positive integer "
  67. "value, but got num_samples={}".format(num_samples)
  68. )
  69. self.batch_size = batch_size
  70. self.dataset = dataset
  71. self.drop_last = drop_last
  72. if world_size is None:
  73. world_size = dist.get_world_size() if dist.is_distributed() else 1
  74. self.world_size = world_size
  75. if rank is None:
  76. rank = dist.get_rank() if dist.is_distributed() else 0
  77. self.rank = rank
  78. if num_samples is None:
  79. num_samples = len(self.dataset)
  80. self.num_samples = int(math.ceil(num_samples / self.world_size))
  81. # Make sure seeds are the same at each rank
  82. if seed is None and self.world_size > 1:
  83. seed = 0
  84. self.rng = np.random.RandomState(seed)
  85. def __iter__(self) -> Union[Generator, Iterator]:
  86. return self.batch()
  87. def __len__(self) -> int:
  88. if self.drop_last:
  89. return self.num_samples // self.batch_size
  90. else:
  91. return int(math.ceil(self.num_samples / self.batch_size))
  92. def sample(self):
  93. """
  94. Return a list contains all sample indices.
  95. """
  96. raise NotImplementedError
  97. def scatter(self, indices) -> List:
  98. r"""
  99. Scatter method is used for splitting indices into subset, each subset
  100. will be assigned to a rank. Indices are evenly splitted by default.
  101. If customized indices assignment method is needed, please rewrite this method.
  102. """
  103. total_size = self.num_samples * self.world_size
  104. # add extra indices to make it evenly divisible
  105. indices += indices[: (total_size - len(indices))]
  106. assert len(indices) == total_size
  107. # subsample
  108. indices = indices[self.rank : total_size : self.world_size]
  109. assert len(indices) == self.num_samples
  110. return indices
  111. def batch(self) -> Iterator[List[Any]]:
  112. r"""
  113. Batch method provides a batch indices generator.
  114. """
  115. indices = list(self.sample())
  116. # user might pass the world_size parameter without dist,
  117. # so dist.is_distributed() should not be used
  118. if self.world_size > 1:
  119. indices = self.scatter(indices)
  120. step, length = self.batch_size, len(indices)
  121. batch_index = [indices[i : i + step] for i in range(0, length, step)]
  122. if self.drop_last and len(batch_index[-1]) < self.batch_size:
  123. batch_index.pop()
  124. return iter(batch_index)
  125. class StreamSampler(Sampler):
  126. r"""
  127. Sampler for stream dataset.
  128. .. warning::
  129. In the case of multiple machines, sampler should ensure that each worker gets
  130. different data. But this class cannot do it yet, please build your own
  131. dataset and sampler to achieve this goal.
  132. Usually, :meth:`~.StreamDataset.__iter__` can return different iterator by
  133. ``rank = dist.get_rank()``. So that they will get different data.
  134. """
  135. def __init__(self, batch_size=1):
  136. self.batch_size = batch_size
  137. def __iter__(self):
  138. return self
  139. def __next__(self):
  140. return iter(range(self.batch_size))
  141. class SequentialSampler(MapSampler):
  142. r"""
  143. Sample elements sequentially.
  144. :param dataset: dataset to sample from.
  145. :param batch_size: batch size for batch method.
  146. :param drop_last: set ``True`` to drop the last incomplete batch,
  147. if the dataset size is not divisible by the batch size. If ``False`` and
  148. the size of dataset is not divisible by the batch_size, then the last batch will
  149. be smaller. Default: False
  150. :param indices: indice of samples.
  151. :param world_size: number of ranks.
  152. :param rank: rank id, non-negative interger within 0 and ``world_size``.
  153. """
  154. def __init__(
  155. self,
  156. dataset,
  157. batch_size=1,
  158. drop_last=False,
  159. indices=None,
  160. world_size=None,
  161. rank=None,
  162. ):
  163. super().__init__(dataset, batch_size, drop_last, None, world_size, rank)
  164. if indices is not None and not isinstance(indices, collections.abc.Sequence):
  165. raise ValueError(
  166. "indices should be None or a sequence, "
  167. "but got indices={}".format(indices)
  168. )
  169. self.indices = indices
  170. def sample(self) -> Iterator[Any]:
  171. r"""
  172. Return a generator.
  173. """
  174. if self.indices is None:
  175. return iter(range(len(self.dataset)))
  176. else:
  177. return self.indices
  178. class RandomSampler(MapSampler):
  179. r"""
  180. Sample elements randomly without replacement.
  181. :param dataset: dataset to sample from.
  182. :param batch_size: batch size for batch method.
  183. :param drop_last: set ``True`` to drop the last incomplete batch,
  184. if the dataset size is not divisible by the batch size. If ``False`` and
  185. the size of dataset is not divisible by the batch_size, then the last batch will
  186. be smaller. Default: False
  187. :param indices: indice of samples.
  188. :param world_size: number of ranks.
  189. :param rank: rank id, non-negative interger within 0 and ``world_size``.
  190. :param seed: seed for random operators.
  191. """
  192. def __init__(
  193. self,
  194. dataset,
  195. batch_size=1,
  196. drop_last=False,
  197. indices=None,
  198. world_size=None,
  199. rank=None,
  200. seed=None,
  201. ):
  202. super().__init__(dataset, batch_size, drop_last, None, world_size, rank, seed)
  203. if indices is not None and not isinstance(indices, collections.abc.Sequence):
  204. raise ValueError(
  205. "indices should be None or a sequence, "
  206. "but got indices={}".format(indices)
  207. )
  208. self.indices = indices
  209. def sample(self) -> List:
  210. if self.indices is None:
  211. return self.rng.permutation(len(self.dataset)).tolist()
  212. else:
  213. return self.rng.permutation(self.indices).tolist()
  214. class ReplacementSampler(MapSampler):
  215. r"""
  216. Sample elements randomly with replacement.
  217. :param dataset: dataset to sample from.
  218. :param batch_size: batch size for batch method.
  219. :param drop_last: set ``True`` to drop the last incomplete batch,
  220. if the dataset size is not divisible by the batch size. If ``False`` and
  221. the size of dataset is not divisible by the batch_size, then the last batch will
  222. be smaller. Default: False
  223. :param num_samples: number of samples assigned to one rank.
  224. :param weights: weights for sampling indices, it could be unnormalized weights.
  225. :param world_size: number of ranks.
  226. :param rank: rank id, non-negative interger within 0 and ``world_size``.
  227. :param seed: seed for random operators.
  228. """
  229. def __init__(
  230. self,
  231. dataset,
  232. batch_size=1,
  233. drop_last=False,
  234. num_samples=None,
  235. weights=None,
  236. world_size=None,
  237. rank=None,
  238. seed=None,
  239. ):
  240. super().__init__(
  241. dataset, batch_size, drop_last, num_samples, world_size, rank, seed
  242. )
  243. if weights is not None:
  244. if not isinstance(weights, collections.abc.Sequence):
  245. raise ValueError(
  246. "weights should be None or a sequence, "
  247. "but got weights={}".format(weights)
  248. )
  249. if len(weights) != len(dataset):
  250. raise ValueError(
  251. "len(dataset)={} should be equal to"
  252. "len(weights)={}".format(len(dataset), len(weights))
  253. )
  254. self.weights = weights
  255. if self.weights is not None:
  256. self.weights = np.array(weights) / sum(weights)
  257. def sample(self) -> List:
  258. n = len(self.dataset)
  259. if self.weights is None:
  260. return self.rng.randint(n, size=self.num_samples).tolist()
  261. else:
  262. return self.rng.multinomial(n, self.weights, self.num_samples).tolist()
  263. class Infinite(MapSampler):
  264. r"""Infinite Sampler warper for basic sampler."""
  265. def sample(self):
  266. raise NotImplementedError("sample method not supported in Infinite")
  267. def __init__(self, sampler):
  268. self.sampler = sampler
  269. self.sampler_iter = iter(self.sampler)
  270. def __iter__(self):
  271. return self
  272. def __next__(self):
  273. try:
  274. index = next(self.sampler_iter)
  275. except StopIteration:
  276. self.sampler_iter = iter(self.sampler)
  277. index = next(self.sampler_iter)
  278. return index
  279. def __len__(self):
  280. return np.iinfo(np.int64).max

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台