
dataloader.py 20 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import math
import multiprocessing
import queue
import random
import time

import numpy as np

import megengine as mge

from ..logger import get_logger
from .collator import Collator
from .dataset import Dataset
from .sampler import Sampler, SequentialSampler
from .transform import PseudoTransform, Transform

logger = get_logger(__name__)

MP_QUEUE_GET_TIMEOUT = 5
class DataLoader:
    __initialized = False

    def __init__(
        self,
        dataset: Dataset,
        sampler: Sampler = None,
        transform: Transform = None,
        collator: Collator = None,
        num_workers: int = 0,
        timeout: int = 0,
        divide: bool = False,
    ):
        r"""Provides a convenient way to iterate on a given dataset.

        `DataLoader` combines a dataset with a sampler, a transform and a collator,
        making it flexible to continually get minibatches from a dataset.

        :type dataset: Dataset
        :param dataset: dataset from which to load the minibatch.
        :type sampler: Sampler
        :param sampler: defines the strategy to sample data from the dataset.
        :type transform: Transform
        :param transform: defines the transforming strategy for a sampled batch.
            (default: ``None``)
        :type collator: Collator
        :param collator: defines the merging strategy for a transformed batch.
            (default: ``None``)
        :type num_workers: int
        :param num_workers: the number of sub-processes used to load, transform and
            collate the batch. ``0`` means using a single process. (default: ``0``)
        :type timeout: int
        :param timeout: if positive, the timeout value (in seconds) for collecting a
            batch from workers. (default: ``0``)
        :type divide: bool
        :param divide: defines the parallelization strategy in multi-processing mode.
            ``True`` means one batch is divided into :attr:`num_workers` pieces, and
            the workers process these pieces in parallel. ``False`` means
            different sub-processes process different batches. (default: ``False``)
        """
        if num_workers < 0:
            raise ValueError("num_workers should not be negative")
        if timeout < 0:
            raise ValueError("timeout should not be negative")
        if divide and num_workers <= 1:
            raise ValueError("divide should not be set to True when num_workers <= 1")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout
        self.divide = divide

        self.rng = np.random.RandomState()

        if sampler is None:
            self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
        else:
            self.sampler = sampler

        if divide:
            if self.sampler.batch_size <= self.num_workers:
                raise ValueError(
                    "batch size must not be smaller than num_workers in divide mode."
                )
            elif self.sampler.batch_size % self.num_workers:
                logger.warning(
                    "batch size is not divisible by num_workers, may lose performance in divide mode."
                )

        if transform is None:
            self.transform = PseudoTransform()
        else:
            self.transform = transform

        if collator is None:
            self.collator = Collator()
        else:
            self.collator = collator

        self.__initialized = True
    def __iter__(self):
        if self.num_workers == 0:
            return _SerialDataLoaderIter(self)
        else:
            return _ParallelDataLoaderIter(self)

    def __len__(self):
        return len(self.sampler)
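
# A minimal usage sketch for the class above (illustrative only; `MyDataset`
# is a hypothetical `Dataset` subclass implementing __getitem__ and __len__):
#
#     dataset = MyDataset()
#     sampler = SequentialSampler(dataset, batch_size=4, drop_last=False)
#     dataloader = DataLoader(dataset, sampler=sampler, num_workers=2)
#     for minibatch in dataloader:
#         ...  # each minibatch is the collated result of one sampled batch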
class _BaseDataLoaderIter:
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.seed = loader.rng.randint(1e9)
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.divide = loader.divide
        self.num_processed = 0

    def _get_next_batch(self):
        raise NotImplementedError

    def __len__(self):
        return len(self.sampler)

    def __iter__(self):
        return self

    def __next__(self):
        if self.num_processed >= len(self):
            raise StopIteration
        minibatch = self._get_next_batch()
        self.num_processed += 1
        return minibatch
class _SerialDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SerialDataLoaderIter, self).__init__(loader)
        self.indices_iter = iter(self.sampler)

    def _get_next_batch(self):
        indices = next(self.indices_iter)
        items = [self.dataset[idx] for idx in indices]
        trans_items = self.transform.apply_batch(items)
        return self.collator.apply(trans_items)
class _ParallelDataLoaderIter(_BaseDataLoaderIter):
    __initialized = False

    def __init__(self, loader):
        super(_ParallelDataLoaderIter, self).__init__(loader)

        # If any worker dies, all workers will be shut down.
        self.strict = True
        # TODO: put `strict` into DataLoader args or not?

        self.task_queues = [
            multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
        ]

        self.feed_batch_idx = multiprocessing.Value("i", 0)
        self.target_batch_idx = multiprocessing.Value("i", 0)
        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.batch_part_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # Use a shared-memory queue implemented by the pyarrow plasma store.
        from ._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.task_feeding_worker = multiprocessing.Process(
            target=self._task_feeding_loop,
            args=(iter(self.sampler), self.divide),
            daemon=True,
        )
        self.task_feeding_worker.start()

        self.workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=self._worker_loop,
                args=(
                    self.task_queues[worker_id],
                    self.batch_part_queues[worker_id],
                    self.transform,
                    self.collator,
                    self.seed + worker_id + 1,
                ),
                daemon=True,
            )
            worker.start()
            self.workers.append(worker)

        if self.divide:
            self.data_collecting_worker = multiprocessing.Process(
                target=self._data_gathering_loop,
                args=(self.batch_part_queues, self.batch_queue),
                daemon=True,
            )
        else:
            self.data_collecting_worker = multiprocessing.Process(
                target=self._data_selecting_loop,
                args=(self.batch_part_queues, self.batch_queue),
                daemon=True,
            )
        self.data_collecting_worker.start()

        self.__initialized = True
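
    # Process topology set up above (a summary of the existing code, not new
    # behavior):
    #
    #   task_feeding_worker --(task_queues)--> worker[0..num_workers-1]
    #       --(batch_part_queues)--> data_collecting_worker
    #       --(batch_queue, plasma shared memory)--> main process
    #
    # In divide mode every worker receives a slice of each batch and the
    # collector concatenates the slices; otherwise batch i goes whole to
    # worker i % num_workers and the collector restores batch order.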
    def _task_feeding_loop(self, indices_iter, divide):
        while True:
            if self.shutdown_flag.value == 1:
                break
            batch_idx = self.feed_batch_idx.value
            try:
                indices = next(indices_iter)
            except StopIteration:
                break
            if divide:
                # Make sure all task_queues are ready for put.
                while any([q.full() for q in self.task_queues]):
                    if self.shutdown_flag.value == 1:
                        return
                # Divide the batch into small pieces and feed them to different
                # workers, e.g. 10 indices across 4 workers give pieces of
                # sizes 3, 3, 3, 1 (the last slice may be short or empty).
                sub_num = math.ceil(len(indices) / self.num_workers)
                for worker_id in range(self.num_workers):
                    sub_indices = indices[
                        worker_id * sub_num : (worker_id + 1) * sub_num
                    ]
                    self.task_queues[worker_id].put((batch_idx, sub_indices))
            else:
                # Distribute whole batches to workers round-robin.
                target_id = batch_idx % self.num_workers
                while self.task_queues[target_id].full():
                    if self.shutdown_flag.value == 1:
                        return
                self.task_queues[target_id].put((batch_idx, indices))
            with self.feed_batch_idx.get_lock():
                self.feed_batch_idx.value += 1
    def _worker_loop(self, task_queue, data_queue, transform, collator, seed):
        random.seed(seed)
        np.random.seed(seed)
        while True:
            if self.shutdown_flag.value == 1:
                break
            try:
                batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
            except queue.Empty:
                continue
            if len(indices) > 0:
                items = [self.dataset[idx] for idx in indices]
                trans_items = transform.apply_batch(items)
                batch_data = collator.apply(trans_items)
            else:
                # In case of an incomplete last batch.
                batch_data = ()
            while True:
                try:
                    data_queue.put((np.array([batch_idx]), batch_data), timeout=1)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("batch part queue is full!")
                    continue
    def _data_gathering_loop(self, batch_part_queues, batch_queue):
        r"""Gathers the small pieces of batch data into full batches."""
        gathered_data = collections.defaultdict(dict)
        while True:
            if self.shutdown_flag.value == 1:
                break
            target_batch_idx = self.target_batch_idx.value
            if target_batch_idx >= len(self):
                break
            for worker_id in range(self.num_workers):
                if worker_id in gathered_data[target_batch_idx]:
                    continue
                while True:
                    try:
                        (batch_idx,), batch_part = batch_part_queues[worker_id].get(
                            timeout=MP_QUEUE_GET_TIMEOUT
                        )
                        break
                    except queue.Empty:
                        if self.shutdown_flag.value == 1:
                            break
                        logger.debug(
                            "worker:{} data queue get timeout! target batch idx:{}".format(
                                worker_id, target_batch_idx
                            )
                        )
                if batch_idx < target_batch_idx:
                    raise RuntimeError(
                        "Unexpected batch_idx in data gathering loop. worker_id:{}.".format(
                            worker_id
                        )
                    )
                else:
                    gathered_data[batch_idx][worker_id] = batch_part

            if len(gathered_data[target_batch_idx]) < self.num_workers:
                length = len(gathered_data[target_batch_idx])
                if self.strict:
                    raise RuntimeError("Parts missing in data gathering loop.")
                logger.warning(
                    "target_batch_idx:{}, {} part(s) missing.".format(
                        target_batch_idx, self.num_workers - length
                    )
                )
                del gathered_data[target_batch_idx]
                with self.target_batch_idx.get_lock():
                    self.target_batch_idx.value += 1
                continue

            # Merge the different parts back into one full batch, field by field.
            full_batch = [[] for _ in range(len(gathered_data[target_batch_idx][0]))]
            for idx in range(self.num_workers):
                for i, field in enumerate(gathered_data[target_batch_idx][idx]):
                    full_batch[i].append(field)
            full_batch = tuple([np.concatenate(field, axis=0) for field in full_batch])

            while True:
                try:
                    batch_queue.put(full_batch, timeout=1)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("batch queue is full!")
                    continue

            del gathered_data[target_batch_idx]
            with self.target_batch_idx.get_lock():
                self.target_batch_idx.value += 1
        batch_queue.disconnect_client()
    def _data_selecting_loop(self, batch_part_queues, batch_queue):
        r"""Ensures batches are yielded in exactly the same order as the generated indices."""
        buffer_batches = {}
        while True:
            if self.shutdown_flag.value == 1:
                break
            target_batch_idx = self.target_batch_idx.value
            if target_batch_idx >= len(self):
                break
            if target_batch_idx in buffer_batches:
                while True:
                    try:
                        batch_queue.put(buffer_batches[target_batch_idx], timeout=1)
                        break
                    except queue.Full:
                        if self.shutdown_flag.value == 1:
                            break
                        logger.debug("batch queue is full!")
                with self.target_batch_idx.get_lock():
                    self.target_batch_idx.value += 1
                del buffer_batches[target_batch_idx]
                continue
            target_worker_id = target_batch_idx % self.num_workers
            while True:
                try:
                    (batch_idx,), batch_data = batch_part_queues[target_worker_id].get(
                        timeout=MP_QUEUE_GET_TIMEOUT
                    )
                    break
                except queue.Empty:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug(
                        "worker:{} data queue get timeout! target batch idx:{}".format(
                            target_worker_id, target_batch_idx
                        )
                    )
            if batch_idx < target_batch_idx:
                raise RuntimeError("batch_idx smaller than target_batch_idx")
            elif batch_idx > target_batch_idx:
                if self.strict:
                    raise RuntimeError("batch_idx larger than target_batch_idx")
                logger.warning(
                    "missing target batch idx:{}, batch idx:{}".format(
                        target_batch_idx, batch_idx
                    )
                )
                buffer_batches[batch_idx] = batch_data
            else:
                try:
                    batch_queue.put(batch_data, timeout=1)
                except queue.Full:
                    # Keep the batch and retry it on the next pass.
                    buffer_batches[batch_idx] = batch_data
                    continue
            with self.target_batch_idx.get_lock():
                self.target_batch_idx.value += 1
        batch_queue.disconnect_client()
    def _check_workers(self):
        r"""Checks the status of each worker and restarts dead ones if necessary."""
        if not self.data_collecting_worker.is_alive():
            # Note: the original read task_feeding_worker.exitcode here, which
            # looks like a copy-paste slip; check the collecting worker itself.
            exitcode = self.data_collecting_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("data collecting worker died. {}".format(exitcode))

        if self.strict:
            if not self.task_feeding_worker.is_alive():
                exitcode = self.task_feeding_worker.exitcode
                if exitcode != 0:
                    raise RuntimeError("task feeding worker died. {}".format(exitcode))
            for worker_id, worker in enumerate(self.workers):
                if not worker.is_alive():
                    exitcode = worker.exitcode
                    if exitcode != 0:
                        raise RuntimeError(
                            "worker:{} died. {}".format(worker_id, exitcode)
                        )
        else:
            if not self.task_feeding_worker.is_alive():
                exitcode = self.task_feeding_worker.exitcode
                if exitcode != 0:
                    logger.error(
                        "task feeding worker died {}. Restarting".format(exitcode)
                    )
                    self.task_feeding_worker.join()
                    self.task_feeding_worker = multiprocessing.Process(
                        target=self._task_feeding_loop,
                        args=(iter(self.sampler), self.divide),
                        daemon=True,
                    )
                    self.task_feeding_worker.start()
            failed_num = 0
            for worker_id in range(self.num_workers):
                if self.workers[worker_id].is_alive():
                    continue
                # Note: the original read a stale `worker` variable here; the
                # exitcode must come from the worker being inspected.
                exitcode = self.workers[worker_id].exitcode
                if exitcode == 0:
                    continue
                logger.error("worker {} died. Restarting".format(worker_id))
                failed_num += 1
                self.workers[worker_id].join()
                worker = multiprocessing.Process(
                    target=self._worker_loop,
                    args=(
                        self.task_queues[worker_id],
                        self.batch_part_queues[worker_id],
                        self.transform,
                        self.collator,
                        self.seed + worker_id + 1,
                    ),
                    daemon=True,
                )
                worker.start()
                self.workers[worker_id] = worker
            if failed_num > 0:
                logger.error("{} worker(s) had exited".format(failed_num))
            else:
                logger.debug("all workers are alive.")
    def _try_get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0:
                if waited_time > self.timeout:
                    raise RuntimeError("get_next_batch timeout!")

    def _get_next_batch(self):
        batch_data = self._try_get_next_batch()
        return batch_data
    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.task_feeding_worker.is_alive():
            self.task_feeding_worker.terminate()
        self.task_feeding_worker.join()

        if self.data_collecting_worker.is_alive():
            self.data_collecting_worker.terminate()
        self.data_collecting_worker.join()

        for worker in self.workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.batch_part_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.task_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()
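
For reference, a minimal end-to-end sketch of how this module is typically driven. This assumes the public megengine.data API of this era (DataLoader, SequentialSampler, and ArrayDataset, which wraps in-memory numpy arrays); treat the names as assumptions if your version differs. num_workers=0 keeps everything in one process, so the plasma shared-memory queue is not needed:

import numpy as np
from megengine.data import DataLoader, SequentialSampler
from megengine.data.dataset import ArrayDataset

# 100 fake images with integer labels.
data = np.random.rand(100, 3, 32, 32).astype("float32")
labels = np.random.randint(0, 10, size=(100,)).astype("int32")
dataset = ArrayDataset(data, labels)

# batch_size=8 with drop_last=False yields ceil(100 / 8) = 13 minibatches.
sampler = SequentialSampler(dataset, batch_size=8)
dataloader = DataLoader(dataset, sampler=sampler, num_workers=0)

for images, batch_labels in dataloader:
    # (8, 3, 32, 32) and (8,) for full batches; the last batch holds
    # the remaining 4 samples.
    print(images.shape, batch_labels.shape)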

The MegEngine package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has a GPU installed along with the proper driver. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
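
To verify that MegEngine can see the GPU, a quick check (assuming the `megengine.is_cuda_available()` helper is present in your release) is:

import megengine as mge

# Prints True when a CUDA-capable device and its driver are detected.
print(mge.is_cuda_available())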
