You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

sampler.py 9.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections.abc
  10. import math
  11. from abc import ABC, abstractmethod
  12. from typing import Any, Generator, Iterator, List, Union
  13. import numpy as np
  14. import megengine.distributed as dist
  15. class Sampler(ABC):
  16. r"""
  17. An abstract class for all Sampler
  18. """
  19. @abstractmethod
  20. def __init__(self):
  21. pass
  22. class MapSampler(Sampler):
  23. def __init__(
  24. self,
  25. dataset,
  26. batch_size=1,
  27. drop_last=False,
  28. num_samples=None,
  29. world_size=None,
  30. rank=None,
  31. seed=None,
  32. ):
  33. r"""
  34. An abstract class for all sampler.
  35. :type dataset: `dataset`
  36. :param dataset: dataset to sample from.
  37. :type batch_size: positive integer
  38. :param batch_size: batch size for batch method.
  39. :type drop_last: bool
  40. :param drop_last: set ``True`` to drop the last incomplete batch,
  41. if the dataset size is not divisible by the batch size. If ``False`` and
  42. the size of dataset is not divisible by the batch_size, then the last batch will
  43. be smaller. Default: False
  44. :type num_samples: positive integer
  45. :param num_samples: number of samples assigned to one rank.
  46. :type world_size: positive integer
  47. :param world_size: number of ranks.
  48. :type rank: non-negative integer within 0 and world_size
  49. :param rank: rank id, non-negative interger within 0 and ``world_size``.
  50. :type seed: non-negative integer
  51. :param seed: seed for random operators.
  52. """
  53. if (
  54. not isinstance(batch_size, int)
  55. or isinstance(batch_size, bool)
  56. or batch_size <= 0
  57. ):
  58. raise ValueError(
  59. "batch_size should be a positive integer value, "
  60. "but got batch_size={}".format(batch_size)
  61. )
  62. if not isinstance(drop_last, bool):
  63. raise ValueError(
  64. "drop_last should be a boolean value, but got "
  65. "drop_last={}".format(drop_last)
  66. )
  67. if num_samples is not None and (
  68. not isinstance(num_samples, int)
  69. or isinstance(num_samples, bool)
  70. or num_samples <= 0
  71. ):
  72. raise ValueError(
  73. "num_samples should be a positive integer "
  74. "value, but got num_samples={}".format(num_samples)
  75. )
  76. self.batch_size = batch_size
  77. self.dataset = dataset
  78. self.drop_last = drop_last
  79. if world_size is None:
  80. world_size = dist.get_world_size() if dist.is_distributed() else 1
  81. self.world_size = world_size
  82. if rank is None:
  83. rank = dist.get_rank() if dist.is_distributed() else 0
  84. self.rank = rank
  85. if num_samples is None:
  86. num_samples = len(self.dataset)
  87. self.num_samples = int(math.ceil(num_samples / self.world_size))
  88. # Make sure seeds are the same at each rank
  89. if seed is None and self.world_size > 1:
  90. seed = 0
  91. self.rng = np.random.RandomState(seed)
  92. def __iter__(self) -> Union[Generator, Iterator]:
  93. return self.batch()
  94. def __len__(self) -> int:
  95. if self.drop_last:
  96. return self.num_samples // self.batch_size
  97. else:
  98. return int(math.ceil(self.num_samples / self.batch_size))
  99. def sample(self):
  100. """
  101. Return a list contains all sample indices.
  102. """
  103. raise NotImplementedError
  104. def scatter(self, indices) -> List:
  105. r"""
  106. Scatter method is used for splitting indices into subset, each subset
  107. will be assigned to a rank. Indices are evenly splitted by default.
  108. If customized indices assignment method is needed, please rewrite this method.
  109. """
  110. total_size = self.num_samples * self.world_size
  111. # add extra indices to make it evenly divisible
  112. indices += indices[: (total_size - len(indices))]
  113. assert len(indices) == total_size
  114. # subsample
  115. indices = indices[self.rank : total_size : self.world_size]
  116. assert len(indices) == self.num_samples
  117. return indices
  118. def batch(self) -> Iterator[List[Any]]:
  119. r"""
  120. Batch method provides a batch indices generator.
  121. """
  122. indices = list(self.sample())
  123. # user might pass the world_size parameter without dist,
  124. # so dist.is_distributed() should not be used
  125. if self.world_size > 1:
  126. indices = self.scatter(indices)
  127. step, length = self.batch_size, len(indices)
  128. batch_index = [indices[i : i + step] for i in range(0, length, step)]
  129. if self.drop_last and len(batch_index[-1]) < self.batch_size:
  130. batch_index.pop()
  131. return iter(batch_index)
  132. class StreamSampler(Sampler):
  133. """
  134. Sampler for stream dataset.
  135. .. warning::
  136. In the case of multiple machines, sampler should ensure that each worker gets
  137. different data. But this class cannot do it yet, please build your own
  138. dataset and sampler to achieve this goal.
  139. Usually, meth::`~.StreamDataset.__iter__` can return different iterator by
  140. ``rank = dist.get_rank()``. So that they will get different data.
  141. """
  142. def __init__(self, batch_size=1):
  143. self.batch_size = batch_size
  144. def __iter__(self):
  145. return self
  146. def __next__(self):
  147. return iter(range(self.batch_size))
  148. class SequentialSampler(MapSampler):
  149. def __init__(
  150. self,
  151. dataset,
  152. batch_size=1,
  153. drop_last=False,
  154. indices=None,
  155. world_size=None,
  156. rank=None,
  157. ):
  158. r"""
  159. Sample elements sequentially.
  160. """
  161. super().__init__(dataset, batch_size, drop_last, None, world_size, rank)
  162. if indices is not None and not isinstance(indices, collections.abc.Sequence):
  163. raise ValueError(
  164. "indices should be None or a sequence, "
  165. "but got indices={}".format(indices)
  166. )
  167. self.indices = indices
  168. def sample(self) -> Iterator[Any]:
  169. r"""
  170. Return a generator.
  171. """
  172. if self.indices is None:
  173. return iter(range(len(self.dataset)))
  174. else:
  175. return self.indices
  176. class RandomSampler(MapSampler):
  177. def __init__(
  178. self,
  179. dataset,
  180. batch_size=1,
  181. drop_last=False,
  182. indices=None,
  183. world_size=None,
  184. rank=None,
  185. seed=None,
  186. ):
  187. r"""
  188. Sample elements randomly without replacement.
  189. """
  190. super().__init__(dataset, batch_size, drop_last, None, world_size, rank, seed)
  191. if indices is not None and not isinstance(indices, collections.abc.Sequence):
  192. raise ValueError(
  193. "indices should be None or a sequence, "
  194. "but got indices={}".format(indices)
  195. )
  196. self.indices = indices
  197. def sample(self) -> List:
  198. if self.indices is None:
  199. return self.rng.permutation(len(self.dataset)).tolist()
  200. else:
  201. return self.rng.permutation(self.indices).tolist()
  202. class ReplacementSampler(MapSampler):
  203. def __init__(
  204. self,
  205. dataset,
  206. batch_size=1,
  207. drop_last=False,
  208. num_samples=None,
  209. weights=None,
  210. world_size=None,
  211. rank=None,
  212. seed=None,
  213. ):
  214. r"""
  215. Sample elements randomly with replacement.
  216. :type weights: List
  217. :param weights: weights for sampling indices, it could be unnormalized weights.
  218. """
  219. super().__init__(
  220. dataset, batch_size, drop_last, num_samples, world_size, rank, seed
  221. )
  222. if weights is not None:
  223. if not isinstance(weights, collections.abc.Sequence):
  224. raise ValueError(
  225. "weights should be None or a sequence, "
  226. "but got weights={}".format(weights)
  227. )
  228. if len(weights) != len(dataset):
  229. raise ValueError(
  230. "len(dataset)={} should be equal to"
  231. "len(weights)={}".format(len(dataset), len(weights))
  232. )
  233. self.weights = weights
  234. if self.weights is not None:
  235. self.weights = np.array(weights) / sum(weights)
  236. def sample(self) -> List:
  237. n = len(self.dataset)
  238. if self.weights is None:
  239. return self.rng.randint(n, size=self.num_samples).tolist()
  240. else:
  241. return self.rng.multinomial(n, self.weights, self.num_samples).tolist()
  242. class Infinite(MapSampler):
  243. r"""Infinite Sampler warper for basic sampler."""
  244. def sample(self):
  245. raise NotImplementedError("sample method not supported in Infinite")
  246. def __init__(self, sampler):
  247. self.sampler = sampler
  248. self.sampler_iter = iter(self.sampler)
  249. def __iter__(self):
  250. return self
  251. def __next__(self):
  252. try:
  253. index = next(self.sampler_iter)
  254. except StopIteration:
  255. self.sampler_iter = iter(self.sampler)
  256. index = next(self.sampler_iter)
  257. return index
  258. def __len__(self):
  259. return np.iinfo(np.int64).max

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台