
dataloader.py 31 kB

# -*- coding: utf-8 -*-
import collections
import gc
import math
import multiprocessing
import os
import platform
import queue
import random
import threading
import time
from typing import Callable, Union

import numpy as np

from ..device import _sh, get_default_device
from ..functional.tensor import copy
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from ..tensor import Tensor
from .collator import Collator
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

try:
    import thread
except:
    import _thread as thread

logger = get_logger(__name__)

GLOBAL_TIMEOUT = 5


def raise_timeout_error():
    raise RuntimeError("dataloader timeout")
class DataLoader:
    r"""Provides a convenient way to iterate on a given dataset.

    The process is as follows:

    .. mermaid::
       :align: center

       flowchart LR
          Dataset.__len__ -- Sampler --> Indices
          batch_size -- Sampler --> Indices
          Indices -- Dataset.__getitem__ --> Samples
          Samples -- Transform + Collator --> mini-batch

    DataLoader combines a :class:`~.Dataset` with
    :class:`~.Sampler`, :class:`~.Transform` and :class:`~.Collator`,
    making it flexible to get minibatches continually from a dataset.
    See :ref:`data-guide` for more details.

    Args:
        dataset: dataset from which to load the minibatch.
        sampler: defines the strategy to sample data from the dataset.
            If ``None``, it will sequentially sample from the dataset one by one.
        transform: defines the transforming strategy for a sampled batch.
        collator: defines the merging strategy for a transformed batch.
        num_workers: the number of sub-processes to load, transform and collate
            the batch. ``0`` means using a single process. Default: 0
        timeout: if positive, the timeout value (in seconds) for collecting a
            batch from workers. Default: 0
        timeout_event: callback function triggered by timeout; defaults to raising
            a runtime error.
        divide: defines the parallelism strategy in multi-processing mode.
            ``True`` means one batch is divided into :attr:`num_workers` pieces, and
            the workers process these pieces in parallel. ``False`` means
            different sub-processes process different batches. Default: False
        preload: whether to enable the preloading strategy of the dataloader.
            When enabled, the dataloader will preload one batch to the device memory to speed up the whole training process.

    .. admonition:: The effect of enabling preload
        :class: warning

        * All elements in :class:`map`, :class:`list`, and :class:`tuple` will be converted to :class:`~.Tensor` by preloading,
          and you will get :class:`~.Tensor` instead of the original NumPy array or Python built-in data structure.
        * Tensors' host2device copy and device kernel execution will be overlapped,
          which will improve the training speed at the cost of **higher device memory usage** (due to one more batch of data in device memory).
          This feature saves more time when your NN training time is short or your machine's host PCIe bandwidth for each device is low.
    """

    __initialized = False

    def __init__(
        self,
        dataset: Dataset,
        sampler: Sampler = None,
        transform: Transform = None,
        collator: Collator = None,
        num_workers: int = 0,
        timeout: int = 0,
        timeout_event: Callable = raise_timeout_error,
        divide: bool = False,
        preload: bool = False,
    ):
        if num_workers < 0:
            raise ValueError("num_workers should not be negative")
        if timeout < 0:
            raise ValueError("timeout should not be negative")
        if divide and num_workers <= 1:
            raise ValueError("divide should not be set to True when num_workers <= 1")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout
        self.timeout_event = timeout_event
        self.divide = divide
        self.preload = preload

        if isinstance(dataset, StreamDataset):
            self.sampler = sampler if sampler else StreamSampler(batch_size=1)
            assert isinstance(
                self.sampler, StreamSampler
            ), "types of dataset and sampler do not match"
        else:
            assert isinstance(
                dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(dataset)
            self.sampler = (
                sampler
                if sampler
                else SequentialSampler(dataset, batch_size=1, drop_last=False)
            )
            assert isinstance(
                self.sampler, MapSampler
            ), "types of dataset and sampler do not match"

        if divide:
            if self.sampler.batch_size <= self.num_workers:
                raise ValueError(
                    "batch size must not be smaller than num_workers in divide mode."
                )
            elif self.sampler.batch_size % self.num_workers:
                logger.warning(
                    "batch size is not divisible by num_workers, may lose performance in divide mode."
                )

        if transform is None:
            self.transform = PseudoTransform()
        else:
            self.transform = transform

        if collator is None:
            self.collator = Collator()
        else:
            self.collator = collator

        self.__initialized = True

    def __iter__(self):
        if platform.system() == "Windows" and self.num_workers > 0:
            print(
                "pyarrow.plasma does not support ParallelDataLoader on Windows, changing num_workers to zero"
            )
            self.num_workers = 0
        if os.getenv("TERMUX_VERSION"):
            # FIXME: installing pyarrow on termux currently fails to build;
            # remove this logic after pyarrow fixes this issue
            print(
                "pyarrow is not supported in the termux env now, changing num_workers to zero"
            )
            self.num_workers = 0
        if isinstance(self.dataset, StreamDataset):
            if not self.num_workers:
                return _SerialStreamDataLoaderIter(self, self.preload)
            else:
                return _ParallelStreamDataLoaderIter(self, self.preload)
        else:
            assert isinstance(
                self.dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
            if not self.num_workers:
                return _SerialMapDataLoaderIter(self, self.preload)
            else:
                return _ParallelMapDataLoaderIter(self, self.preload)

    def __len__(self):
        return len(self.sampler)
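The docstring above describes how Dataset, Sampler, Transform and Collator cooperate. The following is a minimal usage sketch, not part of dataloader.py: it assumes the usual megengine.data import paths, a map-style Dataset subclass implementing __getitem__ and __len__, and the SequentialSampler this module imports; all names in it are hypothetical examples.

import numpy as np

from megengine.data import DataLoader, SequentialSampler
from megengine.data.dataset import Dataset


class ToyDataset(Dataset):
    """A tiny in-memory map-style dataset (hypothetical example)."""

    def __init__(self):
        self.data = np.arange(32, dtype="float32").reshape(8, 4)  # 8 samples, 4 features
        self.labels = np.arange(8, dtype="int32")

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def __len__(self):
        return len(self.data)


dataset = ToyDataset()
# SequentialSampler yields lists of indices of size batch_size; the default
# PseudoTransform and Collator then turn the sampled items into a minibatch.
sampler = SequentialSampler(dataset, batch_size=4, drop_last=False)
dataloader = DataLoader(dataset, sampler=sampler, num_workers=0)

for step, (inputs, targets) in enumerate(dataloader):
    print(step, inputs.shape, targets.shape)  # expected: (4, 4) and (4,) per step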
class PreLoader:
    def __init__(self, preload):
        if preload:
            self.default_device = get_default_device()
            self.pre_load_device = self.default_device + ":" + str(_sh.get_next())
            self.pre_load_device_cache = None
        self.preload = preload

    """
    strategy one: load from numpy data, and generate dtype tensor
    """

    def _load_tensor(self, batch, cached=True):
        if isinstance(batch, np.ndarray):
            device = self.pre_load_device if cached else self.default_device
            return Tensor(batch, device=device)
        elif isinstance(batch, collections.abc.Mapping):
            return {k: self._load_tensor(v, cached) for k, v in batch.items()}
        elif isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
            return type(batch)(*(self._load_tensor(value, cached) for value in batch))
        elif isinstance(batch, collections.abc.Sequence):
            return [self._load_tensor(value, cached) for value in batch]
        else:
            return batch

    """
    strategy two: load from cache that is already a tensor, just do a d2d copy
    """

    def _load_cache(self, data):
        if isinstance(data, Tensor):
            if data.device == self.default_device:
                return data
            return copy(data, device=self.default_device)
        elif isinstance(data, collections.abc.Mapping):
            return {k: self._load_cache(v) for k, v in data.items()}
        elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
            return type(data)(*(self._load_cache(value) for value in data))
        elif isinstance(data, collections.abc.Sequence):
            return [self._load_cache(value) for value in data]
        else:
            return data

    def _swap_out_cache(self):
        out = self._load_cache(self.pre_load_device_cache)
        self.pre_load_device_cache = None  # clean cache
        return out
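The two strategy docstrings above describe the preload path: strategy one walks a possibly nested batch and turns NumPy leaves into Tensors on the cache device; strategy two copies already-cached Tensors to the default device. The standalone sketch below (an illustration only, not part of this module) mirrors that structure-preserving recursion with a generic leaf function, which makes the handling of mappings, namedtuples and sequences easier to see.

import collections.abc

import numpy as np


def map_structure(batch, leaf_fn):
    # Mirrors PreLoader._load_tensor / _load_cache: recurse over mappings,
    # namedtuples and sequences, apply leaf_fn to ndarray leaves,
    # and pass anything else through unchanged.
    if isinstance(batch, np.ndarray):
        return leaf_fn(batch)
    elif isinstance(batch, collections.abc.Mapping):
        return {k: map_structure(v, leaf_fn) for k, v in batch.items()}
    elif isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
        return type(batch)(*(map_structure(v, leaf_fn) for v in batch))
    elif isinstance(batch, collections.abc.Sequence):
        return [map_structure(v, leaf_fn) for v in batch]
    else:
        return batch


batch = {
    "image": np.zeros((4, 3, 8, 8), dtype="float32"),
    "label": np.zeros((4,), dtype="int32"),
}
print(map_structure(batch, lambda arr: arr.shape))
# {'image': (4, 3, 8, 8), 'label': (4,)}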
class _BaseMapDataLoaderIter(PreLoader):
    def __init__(self, loader, preload):
        super().__init__(preload)
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.seed = _random_seed_generator().__next__()
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event
        self.divide = loader.divide
        self.num_processed = 0

    def _get_next_batch(self):
        raise NotImplementedError

    def __len__(self):
        return len(self.sampler)

    def __iter__(self):
        return self

    def __next__(self):
        if self.preload:
            cached = self.pre_load_device_cache
            if cached is None:  # first and last
                if self.num_processed >= len(self):  # last
                    raise StopIteration
                elif self.num_processed == 0:  # first
                    self._try_load_tensor(cached=False)  # first do the h2d
            out = self._swap_out_cache()
            self._try_load_tensor()
            return out
        else:
            if self.num_processed >= len(self):
                raise StopIteration
            minibatch = self._get_next_batch()
            self.num_processed += 1
            return minibatch

    def _try_load_tensor(self, cached=True):
        if self.num_processed >= len(self):
            return
        else:
            self.num_processed += 1
            batch = self._get_next_batch()
            self.pre_load_device_cache = self._load_tensor(batch, cached)
class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
    def __init__(self, loader, preload):
        super(_SerialMapDataLoaderIter, self).__init__(loader, preload)
        self.indices_iter = iter(self.sampler)

    def _get_next_batch(self):
        indices = next(self.indices_iter)
        items = [self.dataset[idx] for idx in indices]
        trans_items = self.transform.apply_batch(items)
        return self.collator.apply(trans_items)
class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
    __initialized = False

    def __init__(self, loader, preload):
        super(_ParallelMapDataLoaderIter, self).__init__(loader, preload)

        self.task_queues = [
            multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
        ]

        self.feed_batch_idx = multiprocessing.Value("i", 0)
        self.target_batch_idx = multiprocessing.Value("i", 0)
        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # use shared-memory queue implemented by pyarrow plasma store.
        from .tools._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.task_feeding_worker = multiprocessing.Process(
            target=_task_feeding_loop,
            args=(
                iter(self.sampler),
                self.task_queues,
                self.num_workers,
                self.divide,
                self.shutdown_flag,
                self.feed_batch_idx,
            ),
            daemon=True,
        )
        gc.collect()
        self.task_feeding_worker.start()

        self.workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=_worker_loop,
                args=(
                    self.dataset,
                    self.task_queues[worker_id],
                    self.trans_data_queues[worker_id],
                    self.transform,
                    self.seed + worker_id + 1,
                    self.shutdown_flag,
                ),
                daemon=True,
            )
            gc.collect()
            worker.start()
            self.workers.append(worker)

        if self.divide:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_gathering_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        else:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_selecting_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        gc.collect()
        self.data_collecting_worker.start()

        self.__initialized = True

    def _check_workers(self):
        # Check the status of each worker.
        if not self.data_collecting_worker.is_alive():
            exitcode = self.data_collecting_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("data collecting worker died. {}".format(exitcode))

        if not self.task_feeding_worker.is_alive():
            exitcode = self.task_feeding_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("task feeding worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))

        logger.debug("all workers are alive.")

    def _get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0:
                if waited_time > self.timeout:
                    raise RuntimeError("get_next_batch timeout!")

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.task_feeding_worker.is_alive():
            self.task_feeding_worker.terminate()
        self.task_feeding_worker.join()

        if self.data_collecting_worker.is_alive():
            self.data_collecting_worker.terminate()
        self.data_collecting_worker.join()

        for worker in self.workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.task_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()
class _BaseStreamDataLoaderIter(PreLoader):
    def __init__(self, loader, preload):
        super().__init__(preload)
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event

    def _get_next_batch(self):
        raise NotImplementedError

    def _process_raw_data(self, raw_data):
        assert len(raw_data) == 2 and isinstance(
            raw_data[0], bool
        ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
        if not raw_data[0]:
            data = list((x,) for x in raw_data[1])
        else:
            data = raw_data[1]
        ret = []
        for idx in range(len(data[0])):
            ret.append(tuple(e[idx] for e in data))
        return ret

    def __iter__(self):
        return self

    def __next__(self):
        if self.preload:
            if self.pre_load_device_cache is None:
                self._try_load_tensor(cached=False)  # load in current
            out = self._swap_out_cache()
            self._try_load_tensor()  # load in cached
            return out
        else:
            return self._get_next_batch()

    def _try_load_tensor(self, cached=True):
        batch = self._get_next_batch()
        self.pre_load_device_cache = self._load_tensor(batch, cached)
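_process_raw_data above expects every element yielded by a StreamDataset to be a two-element tuple: a bool saying whether the second element is already a batch of fields, followed by the fields themselves. Below is a hypothetical sketch of a stream dataset yielding un-batched samples; it assumes the usual megengine.data import paths (StreamDataset, StreamSampler) and is not part of dataloader.py.

import numpy as np

from megengine.data import DataLoader, StreamSampler
from megengine.data.dataset import StreamDataset


class ToyStream(StreamDataset):
    """Hypothetical stream dataset yielding one un-batched sample at a time."""

    def __iter__(self):
        for i in range(100):
            image = np.random.rand(3, 8, 8).astype("float32")
            label = np.int32(i % 10)
            # (False, fields): False marks the fields as NOT pre-batched,
            # so each yield carries exactly one (image, label) sample.
            yield (False, (image, label))


dataloader = DataLoader(ToyStream(), sampler=StreamSampler(batch_size=4))
images, labels = next(iter(dataloader))
print(images.shape, labels.shape)  # expected: (4, 3, 8, 8) and (4,)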
class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    def __init__(self, loader, preload):
        super().__init__(loader, preload)
        self.dataset_iter = iter(self.dataset)
        self.idx = 0
        self.unused = []

    def _try_get_raw_data(self, start_time):
        raw_data = None
        while not raw_data:
            try:
                if self.timeout > 0:
                    timer = threading.Timer(self.timeout, thread.interrupt_main)
                    timer.start()
                raw_data = next(self.dataset_iter)
                if self.timeout > 0:
                    timer.cancel()
            except KeyboardInterrupt:
                raw_data = self.timeout_event()
            except:
                if self.timeout > 0:
                    timer.cancel()
                    waited_time = time.time() - start_time
                    if waited_time > self.timeout:
                        raw_data = self.timeout_event()
        return raw_data

    def _get_next_batch(self):
        ret = []
        start_time = time.time()
        while len(ret) < self.sampler.batch_size:
            if len(self.unused) != 0:
                batch_data = self.unused
            else:
                raw_data = self._try_get_raw_data(start_time)
                batch_data = self._process_raw_data(raw_data)
            while len(batch_data) != 0 and len(ret) < self.sampler.batch_size:
                data = batch_data.pop()
                ret.append(self.transform.apply(data))
            self.unused = batch_data
        return self.collator.apply(ret)
class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    __initialized = False

    def __init__(self, loader, preload):
        super().__init__(loader, preload)

        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.raw_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # shared-memory queue implemented by pyarrow plasma store
        from .tools._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.recieve_worker = multiprocessing.Process(
            target=self._worker_to_raw_data_queues, daemon=True
        )
        gc.collect()
        self.recieve_worker.start()

        self.transform_workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True
            )
            gc.collect()
            worker.start()
            self.transform_workers.append(worker)

        self.collect_worker = multiprocessing.Process(
            target=self._worker_to_batch_queue, daemon=True
        )
        gc.collect()
        self.collect_worker.start()

        self.__initialized = True

    def _put_raw_data_queues(self, raw_data, qidx):
        batch_data = self._process_raw_data(raw_data)
        for data in batch_data:
            while True:
                qidx = qidx % self.num_workers
                try:
                    self.raw_data_queues[qidx].put(data)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("raw data queue %d is full" % qidx)
                finally:
                    qidx += 1
        return qidx

    def _worker_to_raw_data_queues(self):
        dataset_iter = iter(self.dataset)
        qidx = 0
        while True:
            if self.shutdown_flag.value == 1:
                break
            raw_data = next(dataset_iter)
            qidx = self._put_raw_data_queues(raw_data, qidx)

    def _worker_to_trans_data_queues(self, worker_id):
        while True:
            if self.shutdown_flag.value == 1:
                break
            try:
                data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT)
            except queue.Empty:
                continue
            trans_data = self.transform.apply(data)
            while True:
                try:
                    self.trans_data_queues[worker_id].put(trans_data)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("batch queue is full")

    def _worker_to_batch_queue(self):
        cnt = -1
        trans_items = []
        while True:
            if self.shutdown_flag.value == 1:
                break
            cnt += 1
            queue_id = cnt % self.num_workers
            try:
                trans_item = self.trans_data_queues[queue_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
            except queue.Empty:
                continue
            trans_items.append(trans_item)
            if len(trans_items) == self.sampler.batch_size:
                batch_data = self.collator.apply(trans_items)
                while True:
                    try:
                        self.batch_queue.put(batch_data, timeout=1)
                        break
                    except queue.Full:
                        if self.shutdown_flag.value == 1:
                            break
                        logger.debug("batch queue is full")
                trans_items = []

    def _check_workers(self):
        if not self.collect_worker.is_alive():
            exitcode = self.collect_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("collator worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.transform_workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError(
                        "worker: {} died. {}".format(worker_id, exitcode)
                    )

    def _get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0 and waited_time > self.timeout:
                self._put_raw_data_queues(self.timeout_event(), 0)

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.recieve_worker.is_alive():
            self.recieve_worker.terminate()
        self.recieve_worker.join()

        if self.collect_worker.is_alive():
            self.collect_worker.terminate()
        self.collect_worker.join()

        for worker in self.transform_workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.raw_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()
def _task_feeding_loop(
    indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
):
    # Feed the indices into the task queues
    while True:
        if shutdown_flag.value == 1:
            break
        batch_idx = feed_batch_idx.value
        try:
            indices = next(indices_iter)
        except StopIteration:
            break
        if divide:
            # make sure all task_queues are ready for put
            while any([q.full() for q in task_queues]):
                if shutdown_flag.value == 1:
                    return
            # divide into small pieces, feed to different workers.
            sub_num = math.ceil(len(indices) / num_workers)
            for worker_id in range(num_workers):
                sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
                task_queues[worker_id].put((batch_idx, sub_indices))
        else:
            # distribute tasks to different workers uniformly.
            target_id = batch_idx % num_workers
            while task_queues[target_id].full():
                if shutdown_flag.value == 1:
                    return
            task_queues[target_id].put((batch_idx, indices))
        with feed_batch_idx.get_lock():
            feed_batch_idx.value += 1
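In divide mode, _task_feeding_loop above splits each batch of indices into num_workers contiguous slices of size ceil(len(indices) / num_workers), one per task queue. A small standalone sketch of just that slicing, using hypothetical values:

import math


def split_for_divide_mode(indices, num_workers):
    # Same slicing as the divide branch of _task_feeding_loop: ceil-sized
    # contiguous pieces, one piece per worker (the last piece may be shorter).
    sub_num = math.ceil(len(indices) / num_workers)
    return [indices[w * sub_num : (w + 1) * sub_num] for w in range(num_workers)]


print(split_for_divide_mode(list(range(10)), 4))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]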
def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
    # Get dataset items and do the transform
    random.seed(seed)
    np.random.seed(seed)
    while True:
        if shutdown_flag.value == 1:
            break
        try:
            batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT)
        except queue.Empty:
            continue
        if len(indices) > 0:
            items = [dataset[idx] for idx in indices]
            trans_items = transform.apply_batch(items)
        else:
            # in case of incomplete last batch
            trans_items = ()
        while True:
            try:
                trans_data_queue.put((batch_idx, trans_items), timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch part queue is full!")
def _data_gathering_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Gather the small pieces of batch data into full batch data
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        full_trans_items = []

        for worker_id in range(num_workers):
            while True:
                try:
                    batch_idx, trans_items = trans_data_queues[worker_id].get(
                        timeout=GLOBAL_TIMEOUT
                    )
                    break
                except queue.Empty:
                    if shutdown_flag.value == 1:
                        break
                    logger.debug(
                        "worker:{} data queue get timeout! target batch idx:{}".format(
                            worker_id, target_batch_idx
                        )
                    )
            if batch_idx != target_batch_idx:
                raise RuntimeError(
                    "Unexpected batch_idx in data gathering loop. worker_id:{}.".format(
                        worker_id
                    )
                )
            else:
                full_trans_items.extend(trans_items)

        # Merge different parts into a batch.
        full_batch = collator.apply(full_trans_items)

        while True:
            try:
                batch_queue.put(full_batch, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()
def _data_selecting_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Make sure that batches are generated in exactly the same order as the generated indices
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        target_worker_id = target_batch_idx % num_workers
        while True:
            try:
                batch_idx, trans_items = trans_data_queues[target_worker_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
                batch_data = collator.apply(trans_items)
                break
            except queue.Empty:
                if shutdown_flag.value == 1:
                    break
                logger.debug(
                    "worker:{} data queue get timeout! target batch idx:{}".format(
                        target_worker_id, target_batch_idx
                    )
                )

        if batch_idx != target_batch_idx:
            raise RuntimeError(
                "batch_idx {} mismatches the target_batch_idx {}".format(
                    batch_idx, target_batch_idx
                )
            )

        while True:
            try:
                batch_queue.put(batch_data, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()