dataloader.py
# -*- coding: utf-8 -*-
import collections
import gc
import math
import multiprocessing
import os
import platform
import queue
import random
import threading
import time
from typing import Callable, Union

import numpy as np

from ..device import _sh, get_default_device
from ..functional.tensor import copy
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from ..tensor import Tensor
from .collator import Collator
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

try:
    import thread
except:
    import _thread as thread


logger = get_logger(__name__)


GLOBAL_TIMEOUT = 5


def raise_timeout_error():
    raise RuntimeError("dataloader timeout")
class DataLoader:
    r"""Provides a convenient way to iterate on a given dataset.

    DataLoader combines a dataset with
    :class:`~.Sampler`, :class:`~.Transform` and :class:`~.Collator`,
    making it flexible to continually fetch minibatches from a dataset.

    Args:
        dataset: dataset from which to load the minibatch.
        sampler: defines the strategy to sample data from the dataset.
        transform: defines the transforming strategy for a sampled batch.
            Default: None
        collator: defines the merging strategy for a transformed batch.
            Default: None
        num_workers: the number of sub-processes used to load, transform and collate
            the batch. ``0`` means using a single process. Default: 0
        timeout: if positive, the timeout value (in seconds) for collecting a
            batch from workers. Default: 0
        timeout_event: callback function triggered on timeout; defaults to raising
            a runtime error.
        divide: defines the parallel strategy in multi-processing mode.
            ``True`` means one batch is divided into :attr:`num_workers` pieces, and
            the workers will process these pieces in parallel. ``False`` means
            different sub-processes will process different batches. Default: False
        preload: whether to enable the preloading strategy of the dataloader.
            When enabled, the dataloader will preload one batch to the device memory
            to speed up the whole training process.
            All values in the map, list, and tuple will be converted to :class:`~.Tensor`
            by preloading, and you will get :class:`~.Tensor` instead of the original
            NumPy array or Python number.

    .. note::

        By enabling preload, tensors' host2device copy and device kernel execution
        will be overlapped, which will improve the training speed at the cost of
        higher device memory usage (due to one more batch of data in device memory).
        This feature saves more time when your NN training time is short or your
        machine's host PCIe bandwidth for each device is low.
    """
    __initialized = False

    def __init__(
        self,
        dataset: Dataset,
        sampler: Sampler = None,
        transform: Transform = None,
        collator: Collator = None,
        num_workers: int = 0,
        timeout: int = 0,
        timeout_event: Callable = raise_timeout_error,
        divide: bool = False,
        preload: bool = False,
    ):
        if num_workers < 0:
            raise ValueError("num_workers should not be negative")
        if timeout < 0:
            raise ValueError("timeout should not be negative")
        if divide and num_workers <= 1:
            raise ValueError("divide should not be set to True when num_workers <= 1")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout
        self.timeout_event = timeout_event
        self.divide = divide
        self.preload = preload

        if isinstance(dataset, StreamDataset):
            self.sampler = sampler if sampler else StreamSampler(batch_size=1)
            assert isinstance(
                self.sampler, StreamSampler
            ), "types of dataset and sampler do not match"
        else:
            assert isinstance(
                dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(dataset)
            self.sampler = (
                sampler
                if sampler
                else SequentialSampler(dataset, batch_size=1, drop_last=False)
            )
            assert isinstance(
                self.sampler, MapSampler
            ), "types of dataset and sampler do not match"

        if divide:
            if self.sampler.batch_size <= self.num_workers:
                raise ValueError(
                    "batch size must not be smaller than num_workers in divide mode."
                )
            elif self.sampler.batch_size % self.num_workers:
                logger.warning(
                    "batch size is not divisible by num_workers, may lose performance in divide mode."
                )

        if transform is None:
            self.transform = PseudoTransform()
        else:
            self.transform = transform

        if collator is None:
            self.collator = Collator()
        else:
            self.collator = collator

        self.__initialized = True
    def __iter__(self):
        if platform.system() == "Windows" and self.num_workers > 0:
            print(
                "pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero"
            )
            self.num_workers = 0
        if os.getenv("TERMUX_VERSION"):
            # FIXME: building pyarrow on termux currently fails;
            # remove this logic after pyarrow fixes this issue
            print(
                "pyarrow is not supported in the termux env now, changing num_workers to be zero"
            )
            self.num_workers = 0
        if isinstance(self.dataset, StreamDataset):
            if not self.num_workers:
                return _SerialStreamDataLoaderIter(self, self.preload)
            else:
                return _ParallelStreamDataLoaderIter(self, self.preload)
        else:
            assert isinstance(
                self.dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
            if not self.num_workers:
                return _SerialMapDataLoaderIter(self, self.preload)
            else:
                return _ParallelMapDataLoaderIter(self, self.preload)

    def __len__(self):
        return len(self.sampler)
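
# A minimal usage sketch (illustrative only; ``images``/``labels`` are assumed
# in-memory numpy arrays, and ArrayDataset/RandomSampler come from this package):
#
#     from megengine.data import DataLoader, RandomSampler
#     from megengine.data.dataset import ArrayDataset
#
#     dataset = ArrayDataset(images, labels)
#     sampler = RandomSampler(dataset, batch_size=64, drop_last=True)
#     dataloader = DataLoader(dataset, sampler=sampler, num_workers=4)
#     for batch_images, batch_labels in dataloader:
#         ...  # collated numpy batches (Tensors instead if preload=True)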
class PreLoader:
    def __init__(self, preload):
        if preload:
            self.default_device = get_default_device()
            self.pre_load_device = self.default_device + ":" + str(_sh.get_next())
            self.pre_load_device_cache = None
        self.preload = preload

    """
    strategy one: load from numpy data and generate tensors of the corresponding dtype
    """

    def _load_tensor(self, batch, cached=True):
        if isinstance(batch, np.ndarray):
            device = self.pre_load_device if cached else self.default_device
            return Tensor(batch, device=device)
        elif isinstance(batch, collections.abc.Mapping):
            return {k: self._load_tensor(v, cached) for k, v in batch.items()}
        elif isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
            return type(batch)(*(self._load_tensor(value, cached) for value in batch))
        elif isinstance(batch, collections.abc.Sequence):
            return [self._load_tensor(value, cached) for value in batch]
        else:
            return batch

    """
    strategy two: load from a cache that already holds tensors; just do a d2d copy
    """

    def _load_cache(self, data):
        if isinstance(data, Tensor):
            if data.device == self.default_device:
                return data
            return copy(data, device=self.default_device)
        elif isinstance(data, collections.abc.Mapping):
            return {k: self._load_cache(v) for k, v in data.items()}
        elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
            return type(data)(*(self._load_cache(value) for value in data))
        elif isinstance(data, collections.abc.Sequence):
            return [self._load_cache(value) for value in data]
        else:
            return data

    def _swap_out_cache(self):
        out = self._load_cache(self.pre_load_device_cache)
        self.pre_load_device_cache = None  # clean cache
        return out
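
# Illustrative sketch of the two strategies above (hypothetical batch layout,
# not executed here):
#
#     batch = {"image": np.zeros((64, 3, 224, 224)), "label": np.zeros((64,))}
#
# _load_tensor(batch) returns the same nested structure with every np.ndarray
# leaf turned into a Tensor on the preload device (strategy one), while
# _swap_out_cache -> _load_cache copies those cached tensors onto the default
# device with a d2d copy (strategy two) before handing the batch to the user.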
class _BaseMapDataLoaderIter(PreLoader):
    def __init__(self, loader, preload):
        super().__init__(preload)
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.seed = _random_seed_generator().__next__()
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event
        self.divide = loader.divide
        self.num_processed = 0

    def _get_next_batch(self):
        raise NotImplementedError

    def __len__(self):
        return len(self.sampler)

    def __iter__(self):
        return self

    def __next__(self):
        if self.preload:
            cached = self.pre_load_device_cache
            if cached is None:  # first and last
                if self.num_processed >= len(self):  # last
                    raise StopIteration
                elif self.num_processed == 0:  # first
                    self._try_load_tensor(cached=False)  # first do the h2d
            out = self._swap_out_cache()
            self._try_load_tensor()
            return out
        else:
            if self.num_processed >= len(self):
                raise StopIteration
            minibatch = self._get_next_batch()
            self.num_processed += 1
            return minibatch

    def _try_load_tensor(self, cached=True):
        if self.num_processed >= len(self):
            return
        else:
            self.num_processed += 1
            batch = self._get_next_batch()
            self.pre_load_device_cache = self._load_tensor(batch, cached)
class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
    def __init__(self, loader, preload):
        super(_SerialMapDataLoaderIter, self).__init__(loader, preload)
        self.indices_iter = iter(self.sampler)

    def _get_next_batch(self):
        indices = next(self.indices_iter)
        items = [self.dataset[idx] for idx in indices]
        trans_items = self.transform.apply_batch(items)
        return self.collator.apply(trans_items)
class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
    __initialized = False

    def __init__(self, loader, preload):
        super(_ParallelMapDataLoaderIter, self).__init__(loader, preload)

        self.task_queues = [
            multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
        ]

        self.feed_batch_idx = multiprocessing.Value("i", 0)
        self.target_batch_idx = multiprocessing.Value("i", 0)
        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # use shared-memory queue implemented by pyarrow plasma store.
        from .tools._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.task_feeding_worker = multiprocessing.Process(
            target=_task_feeding_loop,
            args=(
                iter(self.sampler),
                self.task_queues,
                self.num_workers,
                self.divide,
                self.shutdown_flag,
                self.feed_batch_idx,
            ),
            daemon=True,
        )
        gc.collect()
        self.task_feeding_worker.start()

        self.workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=_worker_loop,
                args=(
                    self.dataset,
                    self.task_queues[worker_id],
                    self.trans_data_queues[worker_id],
                    self.transform,
                    self.seed + worker_id + 1,
                    self.shutdown_flag,
                ),
                daemon=True,
            )
            gc.collect()
            worker.start()
            self.workers.append(worker)

        if self.divide:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_gathering_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        else:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_selecting_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        gc.collect()
        self.data_collecting_worker.start()

        self.__initialized = True

    def _check_workers(self):
        # Check the status of each worker.
        if not self.data_collecting_worker.is_alive():
            exitcode = self.data_collecting_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("data collecting worker died. {}".format(exitcode))

        if not self.task_feeding_worker.is_alive():
            exitcode = self.task_feeding_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("task feeding worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))

        logger.debug("all workers are alive.")

    def _get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0:
                if waited_time > self.timeout:
                    raise RuntimeError("get_next_batch timeout!")

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.task_feeding_worker.is_alive():
            self.task_feeding_worker.terminate()
        self.task_feeding_worker.join()

        if self.data_collecting_worker.is_alive():
            self.data_collecting_worker.terminate()
        self.data_collecting_worker.join()

        for worker in self.workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.task_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()
class _BaseStreamDataLoaderIter(PreLoader):
    def __init__(self, loader, preload):
        super().__init__(preload)
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event

    def _get_next_batch(self):
        raise NotImplementedError

    def _process_raw_data(self, raw_data):
        assert len(raw_data) == 2 and isinstance(
            raw_data[0], bool
        ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
        if not raw_data[0]:
            data = list((x,) for x in raw_data[1])
        else:
            data = raw_data[1]
        ret = []
        for idx in range(len(data[0])):
            ret.append(tuple(e[idx] for e in data))
        return ret

    def __iter__(self):
        return self

    def __next__(self):
        if self.preload:
            if self.pre_load_device_cache is None:
                self._try_load_tensor(cached=False)  # load in current
            out = self._swap_out_cache()
            self._try_load_tensor()  # load in cached
            return out
        else:
            return self._get_next_batch()

    def _try_load_tensor(self, cached=True):
        batch = self._get_next_batch()
        self.pre_load_device_cache = self._load_tensor(batch, cached)
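
# Illustrative sketch of the raw-data contract checked in _process_raw_data
# (hypothetical sample names, not executed here): a StreamDataset yields
#
#     yield False, (image, label)              # a single, unbatched sample
#     yield True, (image_list, label_list)     # pre-batched columns of equal length
#
# and both forms are normalized into a list of per-sample (image, label)
# tuples, so the transform can be applied to one sample at a time.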
class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    def __init__(self, loader, preload):
        super().__init__(loader, preload)
        self.dataset_iter = iter(self.dataset)
        self.idx = 0
        self.unused = []

    def _try_get_raw_data(self, start_time):
        raw_data = None
        while not raw_data:
            try:
                if self.timeout > 0:
                    timer = threading.Timer(self.timeout, thread.interrupt_main)
                    timer.start()
                raw_data = next(self.dataset_iter)
                if self.timeout > 0:
                    timer.cancel()
            except KeyboardInterrupt:
                raw_data = self.timeout_event()
            except:
                if self.timeout > 0:
                    timer.cancel()
                    waited_time = time.time() - start_time
                    if waited_time > self.timeout:
                        raw_data = self.timeout_event()
        return raw_data

    def _get_next_batch(self):
        ret = []
        start_time = time.time()
        while len(ret) < self.sampler.batch_size:
            if len(self.unused) != 0:
                batch_data = self.unused
            else:
                raw_data = self._try_get_raw_data(start_time)
                batch_data = self._process_raw_data(raw_data)

            while len(batch_data) != 0 and len(ret) < self.sampler.batch_size:
                data = batch_data.pop()
                ret.append(self.transform.apply(data))
            self.unused = batch_data
        return self.collator.apply(ret)
class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    __initialized = False

    def __init__(self, loader, preload):
        super().__init__(loader, preload)

        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.raw_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # shared-memory queue implemented by pyarrow plasma store
        from .tools._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.recieve_worker = multiprocessing.Process(
            target=self._worker_to_raw_data_queues, daemon=True
        )
        gc.collect()
        self.recieve_worker.start()

        self.transform_workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True
            )
            gc.collect()
            worker.start()
            self.transform_workers.append(worker)

        self.collect_worker = multiprocessing.Process(
            target=self._worker_to_batch_queue, daemon=True
        )
        gc.collect()
        self.collect_worker.start()

        self.__initialized = True

    def _put_raw_data_queues(self, raw_data, qidx):
        batch_data = self._process_raw_data(raw_data)
        for data in batch_data:
            while True:
                qidx = qidx % self.num_workers
                try:
                    self.raw_data_queues[qidx].put(data)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("raw data queue %d is full" % qidx)
                finally:
                    qidx += 1
        return qidx

    def _worker_to_raw_data_queues(self):
        dataset_iter = iter(self.dataset)
        qidx = 0
        while True:
            if self.shutdown_flag.value == 1:
                break
            raw_data = next(dataset_iter)
            qidx = self._put_raw_data_queues(raw_data, qidx)

    def _worker_to_trans_data_queues(self, worker_id):
        while True:
            if self.shutdown_flag.value == 1:
                break
            try:
                data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT)
            except queue.Empty:
                continue
            trans_data = self.transform.apply(data)
            while True:
                try:
                    self.trans_data_queues[worker_id].put(trans_data)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("batch queue is full")

    def _worker_to_batch_queue(self):
        cnt = -1
        trans_items = []
        while True:
            if self.shutdown_flag.value == 1:
                break
            cnt += 1
            queue_id = cnt % self.num_workers
            try:
                trans_item = self.trans_data_queues[queue_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
            except queue.Empty:
                continue
            trans_items.append(trans_item)
            if len(trans_items) == self.sampler.batch_size:
                batch_data = self.collator.apply(trans_items)
                while True:
                    try:
                        self.batch_queue.put(batch_data, timeout=1)
                        break
                    except queue.Full:
                        if self.shutdown_flag.value == 1:
                            break
                        logger.debug("batch queue is full")
                trans_items = []

    def _check_workers(self):
        if not self.collect_worker.is_alive():
            exitcode = self.collect_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("collator worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.transform_workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError(
                        "worker: {} died. {}".format(worker_id, exitcode)
                    )

    def _get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0 and waited_time > self.timeout:
                self._put_raw_data_queues(self.timeout_event(), 0)

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.recieve_worker.is_alive():
            self.recieve_worker.terminate()
        self.recieve_worker.join()

        if self.collect_worker.is_alive():
            self.collect_worker.terminate()
        self.collect_worker.join()

        for worker in self.transform_workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.raw_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()
def _task_feeding_loop(
    indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
):
    # Feed the indices into the task queues
    while True:
        if shutdown_flag.value == 1:
            break
        batch_idx = feed_batch_idx.value
        try:
            indices = next(indices_iter)
        except StopIteration:
            break
        if divide:
            # make sure all task_queues are ready for put
            while any([q.full() for q in task_queues]):
                if shutdown_flag.value == 1:
                    return
            # divide into small pieces, feed to different workers.
            sub_num = math.ceil(len(indices) / num_workers)
            for worker_id in range(num_workers):
                sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
                task_queues[worker_id].put((batch_idx, sub_indices))
        else:
            # distribute tasks to different workers uniformly.
            target_id = batch_idx % num_workers
            while task_queues[target_id].full():
                if shutdown_flag.value == 1:
                    return
            task_queues[target_id].put((batch_idx, indices))
        with feed_batch_idx.get_lock():
            feed_batch_idx.value += 1
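
# Illustrative arithmetic for the divide branch above: with len(indices) == 10
# and num_workers == 4, sub_num = math.ceil(10 / 4) == 3, so the workers receive
#
#     indices[0:3], indices[3:6], indices[6:9], indices[9:12]  # last piece holds 1 index
#
# and _data_gathering_loop later concatenates the transformed pieces back into
# a single batch in worker order.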
def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
    # Get dataset items and do the transform
    random.seed(seed)
    np.random.seed(seed)
    while True:
        if shutdown_flag.value == 1:
            break
        try:
            batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT)
        except queue.Empty:
            continue
        if len(indices) > 0:
            items = [dataset[idx] for idx in indices]
            trans_items = transform.apply_batch(items)
        else:
            # in case of incomplete last batch
            trans_items = ()
        while True:
            try:
                trans_data_queue.put((batch_idx, trans_items), timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch part queue is full!")
def _data_gathering_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Gather the small pieces of batch data into full batch data
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        full_trans_items = []
        for worker_id in range(num_workers):
            while True:
                try:
                    batch_idx, trans_items = trans_data_queues[worker_id].get(
                        timeout=GLOBAL_TIMEOUT
                    )
                    break
                except queue.Empty:
                    if shutdown_flag.value == 1:
                        break
                    logger.debug(
                        "worker:{} data queue get timeout! target batch idx:{}".format(
                            worker_id, target_batch_idx
                        )
                    )
            if batch_idx != target_batch_idx:
                raise RuntimeError(
                    "Unexpected batch_idx in data gathering loop. worker_id:{}.".format(
                        worker_id
                    )
                )
            else:
                full_trans_items.extend(trans_items)

        # Merge different parts into a batch.
        full_batch = collator.apply(full_trans_items)

        while True:
            try:
                batch_queue.put(full_batch, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()
def _data_selecting_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Make sure that batches are generated in exactly the same order as the generated indices
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        target_worker_id = target_batch_idx % num_workers
        while True:
            try:
                batch_idx, trans_items = trans_data_queues[target_worker_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
                batch_data = collator.apply(trans_items)
                break
            except queue.Empty:
                if shutdown_flag.value == 1:
                    break
                logger.debug(
                    "worker:{} data queue get timeout! target batch idx:{}".format(
                        target_worker_id, target_batch_idx
                    )
                )

        if batch_idx != target_batch_idx:
            raise RuntimeError(
                "batch_idx {} mismatches the target_batch_idx {}".format(
                    batch_idx, target_batch_idx
                )
            )

        while True:
            try:
                batch_queue.put(batch_data, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()
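
# A minimal stream-mode usage sketch (illustrative only; MyStream and the
# sample source are hypothetical):
#
#     from megengine.data import DataLoader, StreamSampler
#     from megengine.data.dataset import StreamDataset
#
#     class MyStream(StreamDataset):
#         def __iter__(self):
#             for image, label in some_source():     # hypothetical generator
#                 yield False, (image, label)        # unbatched, see _process_raw_data
#
#     dataloader = DataLoader(MyStream(), sampler=StreamSampler(batch_size=32))
#     for batch_images, batch_labels in dataloader:
#         ...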