You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

sampler.py 8.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections.abc
  10. import math
  11. from abc import ABC
  12. from typing import Any, Generator, Iterator, List, Union
  13. import numpy as np
  14. import megengine.distributed as dist
  15. class Sampler(ABC):
  16. def __init__(
  17. self,
  18. dataset,
  19. batch_size=1,
  20. drop_last=False,
  21. num_samples=None,
  22. world_size=None,
  23. rank=None,
  24. seed=None,
  25. ):
  26. r"""
  27. An abstract class for all sampler
  28. :type dataset: `dataset`
  29. :param dataset: dataset to sample from
  30. :type batch_size: positive integer
  31. :param batch_size: batch size for batch method
  32. :type drop_last: bool
  33. :param drop_last: set ``True`` to drop the last incomplete batch,
  34. if the dataset size is not divisible by the batch size. If ``False`` and
  35. the size of dataset is not divisible by the batch_size, then the last batch will
  36. be smaller. (default: ``False``)
  37. :type num_samples: positive integer
  38. :param num_samples: number of samples assigned to one rank
  39. :type world_size: positive integer
  40. :param world_size: number of ranks
  41. :type rank: non-negative integer within 0 and world_size
  42. :param rank: rank id, non-negative interger within 0 and ``world_size``
  43. :type seed: non-negative integer
  44. :param seed: seed for random operators
  45. """
  46. if (
  47. not isinstance(batch_size, int)
  48. or isinstance(batch_size, bool)
  49. or batch_size <= 0
  50. ):
  51. raise ValueError(
  52. "batch_size should be a positive integer value, "
  53. "but got batch_size={}".format(batch_size)
  54. )
  55. if not isinstance(drop_last, bool):
  56. raise ValueError(
  57. "drop_last should be a boolean value, but got "
  58. "drop_last={}".format(drop_last)
  59. )
  60. if num_samples is not None and (
  61. not isinstance(num_samples, int)
  62. or isinstance(num_samples, bool)
  63. or num_samples <= 0
  64. ):
  65. raise ValueError(
  66. "num_samples should be a positive integer "
  67. "value, but got num_samples={}".format(num_samples)
  68. )
  69. self.batch_size = batch_size
  70. self.dataset = dataset
  71. self.drop_last = drop_last
  72. if world_size is None:
  73. world_size = dist.get_world_size() if dist.is_distributed() else 1
  74. self.world_size = world_size
  75. if rank is None:
  76. rank = dist.get_rank() if dist.is_distributed() else 0
  77. self.rank = rank
  78. if num_samples is None:
  79. num_samples = len(self.dataset)
  80. self.num_samples = int(math.ceil(num_samples / self.world_size))
  81. # Make sure seeds are the same at each rank
  82. if seed is None and self.world_size > 1:
  83. seed = 0
  84. self.rng = np.random.RandomState(seed)
  85. def __iter__(self) -> Union[Generator, Iterator]:
  86. return self.batch()
  87. def __len__(self) -> int:
  88. if self.drop_last:
  89. return self.num_samples // self.batch_size
  90. else:
  91. return int(math.ceil(self.num_samples / self.batch_size))
  92. def sample(self):
  93. """
  94. return a list contains all sample indices
  95. """
  96. raise NotImplementedError
  97. def scatter(self, indices) -> List:
  98. r"""
  99. scatter method is used for splitting indices into subset, each subset
  100. will be assigned to a rank. Indices are evenly splitted by default.
  101. If customized indices assignment method is needed, please rewrite this method
  102. """
  103. total_size = self.num_samples * self.world_size
  104. # add extra indices to make it evenly divisible
  105. indices += indices[: (total_size - len(indices))]
  106. assert len(indices) == total_size
  107. # subsample
  108. indices = indices[self.rank : total_size : self.world_size]
  109. assert len(indices) == self.num_samples
  110. return indices
  111. def batch(self) -> Iterator[List[Any]]:
  112. r"""
  113. batch method provides a batch indices generator
  114. """
  115. indices = list(self.sample())
  116. # user might pass the world_size parameter without dist,
  117. # so dist.is_distributed() should not be used
  118. if self.world_size > 1:
  119. indices = self.scatter(indices)
  120. step, length = self.batch_size, len(indices)
  121. batch_index = [indices[i : i + step] for i in range(0, length, step)]
  122. if self.drop_last and len(batch_index[-1]) < self.batch_size:
  123. batch_index.pop()
  124. return iter(batch_index)
  125. class SequentialSampler(Sampler):
  126. def __init__(
  127. self,
  128. dataset,
  129. batch_size=1,
  130. drop_last=False,
  131. indices=None,
  132. world_size=None,
  133. rank=None,
  134. ):
  135. r"""
  136. Sample elements sequentially
  137. """
  138. super().__init__(dataset, batch_size, drop_last, None, world_size, rank)
  139. if indices is not None and not isinstance(indices, collections.abc.Sequence):
  140. raise ValueError(
  141. "indices should be None or a sequence, "
  142. "but got indices={}".format(indices)
  143. )
  144. self.indices = indices
  145. def sample(self) -> Iterator[Any]:
  146. r"""
  147. return a generator
  148. """
  149. if self.indices is None:
  150. return iter(range(len(self.dataset)))
  151. else:
  152. return self.indices
  153. class RandomSampler(Sampler):
  154. def __init__(
  155. self,
  156. dataset,
  157. batch_size=1,
  158. drop_last=False,
  159. indices=None,
  160. world_size=None,
  161. rank=None,
  162. seed=None,
  163. ):
  164. r"""
  165. Sample elements randomly without replacement
  166. """
  167. super().__init__(dataset, batch_size, drop_last, None, world_size, rank, seed)
  168. if indices is not None and not isinstance(indices, collections.abc.Sequence):
  169. raise ValueError(
  170. "indices should be None or a sequence, "
  171. "but got indices={}".format(indices)
  172. )
  173. self.indices = indices
  174. def sample(self) -> List:
  175. if self.indices is None:
  176. return self.rng.permutation(len(self.dataset)).tolist()
  177. else:
  178. return self.rng.permutation(self.indices).tolist()
  179. class ReplacementSampler(Sampler):
  180. def __init__(
  181. self,
  182. dataset,
  183. batch_size=1,
  184. drop_last=False,
  185. num_samples=None,
  186. weights=None,
  187. world_size=None,
  188. rank=None,
  189. seed=None,
  190. ):
  191. r"""
  192. Sample elements randomly with replacement
  193. :type weights: List
  194. :param weights: weights for sampling indices, it could be unnormalized weights
  195. """
  196. super().__init__(
  197. dataset, batch_size, drop_last, num_samples, world_size, rank, seed
  198. )
  199. if weights is not None:
  200. if not isinstance(weights, collections.abc.Sequence):
  201. raise ValueError(
  202. "weights should be None or a sequence, "
  203. "but got weights={}".format(weights)
  204. )
  205. if len(weights) != len(dataset):
  206. raise ValueError(
  207. "len(dataset)={} should be equal to"
  208. "len(weights)={}".format(len(dataset), len(weights))
  209. )
  210. self.weights = weights
  211. if self.weights is not None:
  212. self.weights = np.array(weights) / sum(weights)
  213. def sample(self) -> List:
  214. n = len(self.dataset)
  215. if self.weights is None:
  216. return self.rng.randint(n, size=self.num_samples).tolist()
  217. else:
  218. return self.rng.multinomial(n, self.weights, self.num_samples).tolist()
  219. class Infinite(Sampler):
  220. r"""Infinite Sampler warper for basic sampler"""
  221. def sample(self):
  222. raise NotImplementedError("sample method not supported in Infinite")
  223. def __init__(self, sampler):
  224. self.sampler = sampler
  225. self.sampler_iter = iter(self.sampler)
  226. def __iter__(self):
  227. return self
  228. def __next__(self):
  229. try:
  230. index = next(self.sampler_iter)
  231. except StopIteration:
  232. self.sampler_iter = iter(self.sampler)
  233. index = next(self.sampler_iter)
  234. return index
  235. def __len__(self):
  236. return np.iinfo(np.int64).max

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。