
dataloader.py 16 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import math
import multiprocessing
import queue
import random
import time

import numpy as np

from ..logger import get_logger
from ..random.rng import _random_seed_generator
from .collator import Collator
from .dataset import Dataset
from .sampler import Sampler, SequentialSampler
from .transform import PseudoTransform, Transform

logger = get_logger(__name__)

MP_QUEUE_GET_TIMEOUT = 5


class DataLoader:
    __initialized = False

    def __init__(
        self,
        dataset: Dataset,
        sampler: Sampler = None,
        transform: Transform = None,
        collator: Collator = None,
        num_workers: int = 0,
        timeout: int = 0,
        divide: bool = False,
    ):
        r"""Provides a convenient way to iterate on a given dataset.

        `DataLoader` combines a dataset with `sampler`, `transform` and `collator`,
        making it flexible to get minibatches continually from a dataset.

        :type dataset: Dataset
        :param dataset: dataset from which to load the minibatch.
        :type sampler: Sampler
        :param sampler: defines the strategy to sample data from the dataset.
        :type transform: Transform
        :param transform: defines the transform strategy for a sampled batch.
            Default: None
        :type collator: Collator
        :param collator: defines the merge strategy for a transformed batch.
            Default: None
        :type num_workers: int
        :param num_workers: the number of sub-processes used to load, transform and
            collate the batch. ``0`` means using a single process. Default: 0
        :type timeout: int
        :param timeout: if positive, the timeout value (in seconds) for collecting a
            batch from workers. Default: 0
        :type divide: bool
        :param divide: defines the parallelization strategy in multi-processing mode.
            ``True`` means one batch is divided into :attr:`num_workers` pieces, and
            the workers process these pieces in parallel. ``False`` means
            different sub-processes process different batches. Default: False
        """
        if num_workers < 0:
            raise ValueError("num_workers should not be negative")
        if timeout < 0:
            raise ValueError("timeout should not be negative")
        if divide and num_workers <= 1:
            raise ValueError("divide should not be set to True when num_workers <= 1")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout
        self.divide = divide

        if sampler is None:
            self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
        else:
            self.sampler = sampler

        if divide:
            if self.sampler.batch_size <= self.num_workers:
                raise ValueError(
                    "batch size must not be smaller than num_workers in divide mode."
                )
            elif self.sampler.batch_size % self.num_workers:
                logger.warning(
                    "batch size is not divisible by num_workers, may lose performance in divide mode."
                )

        if transform is None:
            self.transform = PseudoTransform()
        else:
            self.transform = transform

        if collator is None:
            self.collator = Collator()
        else:
            self.collator = collator

        self.__initialized = True

    def __iter__(self):
        if self.num_workers == 0:
            return _SerialDataLoaderIter(self)
        else:
            return _ParallelDataLoaderIter(self)

    def __len__(self):
        return len(self.sampler)


class _BaseDataLoaderIter:
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.seed = _random_seed_generator().__next__()
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.divide = loader.divide
        self.num_processed = 0

    def _get_next_batch(self):
        raise NotImplementedError

    def __len__(self):
        return len(self.sampler)

    def __iter__(self):
        return self

    def __next__(self):
        if self.num_processed >= len(self):
            raise StopIteration
        minibatch = self._get_next_batch()
        self.num_processed += 1
        return minibatch


class _SerialDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SerialDataLoaderIter, self).__init__(loader)
        self.indices_iter = iter(self.sampler)

    def _get_next_batch(self):
        indices = next(self.indices_iter)
        items = [self.dataset[idx] for idx in indices]
        trans_items = self.transform.apply_batch(items)
        return self.collator.apply(trans_items)
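

# The parallel iterator below wires together a small multiprocessing
# pipeline (sketch of the data flow, as implemented in this file):
#
#   sampler -> _task_feeding_loop -> task_queues[i]
#           -> _worker_loop (load + transform) -> trans_data_queues[i]
#           -> _data_gathering_loop / _data_selecting_loop (collate)
#           -> batch_queue (plasma shared memory) -> __next__
#
# With divide=True each batch of indices is split across all workers and
# _data_gathering_loop reassembles the pieces; with divide=False whole
# batches go to workers round-robin and _data_selecting_loop restores the
# original batch order. All loops poll shutdown_flag to exit cleanly.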
class _ParallelDataLoaderIter(_BaseDataLoaderIter):
    __initialized = False

    def __init__(self, loader):
        super(_ParallelDataLoaderIter, self).__init__(loader)

        self.task_queues = [
            multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
        ]

        self.feed_batch_idx = multiprocessing.Value("i", 0)
        self.target_batch_idx = multiprocessing.Value("i", 0)
        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # use shared-memory queue implemented by pyarrow plasma store.
        from ._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.task_feeding_worker = multiprocessing.Process(
            target=_task_feeding_loop,
            args=(
                iter(self.sampler),
                self.task_queues,
                self.num_workers,
                self.divide,
                self.shutdown_flag,
                self.feed_batch_idx,
            ),
            daemon=True,
        )
        self.task_feeding_worker.start()

        self.workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=_worker_loop,
                args=(
                    self.dataset,
                    self.task_queues[worker_id],
                    self.trans_data_queues[worker_id],
                    self.transform,
                    self.seed + worker_id + 1,
                    self.shutdown_flag,
                ),
                daemon=True,
            )
            worker.start()
            self.workers.append(worker)

        if self.divide:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_gathering_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        else:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_selecting_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        self.data_collecting_worker.start()

        self.__initialized = True

    def _check_workers(self):
        # Check the status of each worker.
        if not self.data_collecting_worker.is_alive():
            exitcode = self.data_collecting_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("data collecting worker died. {}".format(exitcode))

        if not self.task_feeding_worker.is_alive():
            exitcode = self.task_feeding_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("task feeding worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))

        logger.debug("all workers are alive.")

    def _try_get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0:
                if waited_time > self.timeout:
                    raise RuntimeError("get_next_batch timeout!")

    def _get_next_batch(self):
        batch_data = self._try_get_next_batch()
        return batch_data

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.task_feeding_worker.is_alive():
            self.task_feeding_worker.terminate()
        self.task_feeding_worker.join()

        if self.data_collecting_worker.is_alive():
            self.data_collecting_worker.terminate()
        self.data_collecting_worker.join()

        for worker in self.workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.task_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()
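

# The module-level functions below are the bodies of the daemon processes
# started by _ParallelDataLoaderIter above; each loop polls shutdown_flag
# and uses short queue timeouts so the parent can tear the pipeline down.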
def _task_feeding_loop(
    indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
):
    # Feed the indices into the task queues.
    while True:
        if shutdown_flag.value == 1:
            break
        batch_idx = feed_batch_idx.value
        try:
            indices = next(indices_iter)
        except StopIteration:
            break
        if divide:
            # make sure all task_queues are ready for put
            while any([q.full() for q in task_queues]):
                if shutdown_flag.value == 1:
                    return
            # divide into small pieces, feed to different workers.
            sub_num = math.ceil(len(indices) / num_workers)
            for worker_id in range(num_workers):
                sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
                task_queues[worker_id].put((batch_idx, sub_indices))
        else:
            # distribute tasks to different workers uniformly.
            target_id = batch_idx % num_workers
            while task_queues[target_id].full():
                if shutdown_flag.value == 1:
                    return
            task_queues[target_id].put((batch_idx, indices))
        with feed_batch_idx.get_lock():
            feed_batch_idx.value += 1


def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
    # Get dataset items and do the transform.
    random.seed(seed)
    np.random.seed(seed)
    while True:
        if shutdown_flag.value == 1:
            break
        try:
            batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
        except queue.Empty:
            continue
        if len(indices) > 0:
            items = [dataset[idx] for idx in indices]
            trans_items = transform.apply_batch(items)
        else:
            # in case of incomplete last batch
            trans_items = ()
        while True:
            try:
                trans_data_queue.put((batch_idx, trans_items), timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch part queue is full!")


def _data_gathering_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Gather the small pieces of batch data into a full batch.
    while True:
        if shutdown_flag.value == 1:
            break
        target_batch_idx = target_idx.value
        if target_batch_idx >= length:
            break
        full_trans_items = []
        for worker_id in range(num_workers):
            while True:
                try:
                    batch_idx, trans_items = trans_data_queues[worker_id].get(
                        timeout=MP_QUEUE_GET_TIMEOUT
                    )
                    break
                except queue.Empty:
                    if shutdown_flag.value == 1:
                        break
                    logger.debug(
                        "worker:{} data queue get timeout! target batch idx:{}".format(
                            worker_id, target_batch_idx
                        )
                    )
            if batch_idx != target_batch_idx:
                raise RuntimeError(
                    "Unexpected batch_idx in data gathering loop. worker_id:{}.".format(
                        worker_id
                    )
                )
            else:
                full_trans_items.extend(trans_items)
        # Merge different parts into a batch.
        full_batch = collator.apply(full_trans_items)
        while True:
            try:
                batch_queue.put(full_batch, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")
        with target_idx.get_lock():
            target_idx.value += 1
    batch_queue.disconnect_client()


def _data_selecting_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Make sure batches are produced in exactly the same order as the
    # generated indices.
    while True:
        if shutdown_flag.value == 1:
            break
        target_batch_idx = target_idx.value
        if target_batch_idx >= length:
            break
        target_worker_id = target_batch_idx % num_workers
        while True:
            try:
                batch_idx, trans_items = trans_data_queues[target_worker_id].get(
                    timeout=MP_QUEUE_GET_TIMEOUT
                )
                batch_data = collator.apply(trans_items)
                break
            except queue.Empty:
                if shutdown_flag.value == 1:
                    break
                logger.debug(
                    "worker:{} data queue get timeout! target batch idx:{}".format(
                        target_worker_id, target_batch_idx
                    )
                )
        if batch_idx != target_batch_idx:
            raise RuntimeError(
                "batch_idx {} mismatch the target_batch_idx {}".format(
                    batch_idx, target_batch_idx
                )
            )
        while True:
            try:
                batch_queue.put(batch_data, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")
        with target_idx.get_lock():
            target_idx.value += 1
    batch_queue.disconnect_client()
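
For orientation, here is a minimal usage sketch of the API this file defines. The absolute import paths are assumed from the package layout implied by the relative imports at the top of the file, and ToyDataset is a hypothetical stand-in for a real map-style dataset; treat this as an illustration, not part of dataloader.py.

# A minimal usage sketch (assumed import paths; ToyDataset is hypothetical).
import numpy as np

from megengine.data import DataLoader
from megengine.data.dataset import Dataset
from megengine.data.sampler import SequentialSampler


class ToyDataset(Dataset):
    # A toy map-style dataset of 100 scalar samples.
    def __init__(self):
        self.data = np.arange(100, dtype="float32")

    def __getitem__(self, index):
        # The loader indexes the dataset with each index the sampler yields.
        return self.data[index]

    def __len__(self):
        return len(self.data)


dataset = ToyDataset()
# The sampler groups indices into minibatches; compare the default
# SequentialSampler(dataset, batch_size=1, drop_last=False) used above.
sampler = SequentialSampler(dataset, batch_size=8, drop_last=True)

# num_workers=0 uses _SerialDataLoaderIter in the main process; num_workers>0
# switches to the multiprocessing pipeline, and divide=True additionally
# splits every batch across the workers.
dataloader = DataLoader(dataset, sampler=sampler, num_workers=0)

for batch in dataloader:
    print(batch.shape)  # (8,) with the default Collator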

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build. If you want to run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.