
dataloader.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import collections
import gc
import math
import multiprocessing
import os
import platform
import queue
import random
import threading
import time
from typing import Callable, Union

import numpy as np

from ..device import _sh, get_default_device
from ..functional.tensor import copy
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from ..tensor import Tensor
from .collator import Collator
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

try:
    import thread
except ImportError:
    import _thread as thread


logger = get_logger(__name__)


GLOBAL_TIMEOUT = 5


def raise_timeout_error():
    raise RuntimeError("dataloader timeout")


class DataLoader:
    r"""Provides a convenient way to iterate on a given dataset.

    DataLoader combines a dataset with :class:`~.Sampler`, :class:`~.Transform`
    and :class:`~.Collator`, making it flexible to get a minibatch continually
    from a dataset.

    Args:
        dataset: dataset from which to load the minibatch.
        sampler: defines the strategy to sample data from the dataset.
        transform: defines the transforming strategy for a sampled batch.
            Default: None
        collator: defines the merging strategy for a transformed batch.
            Default: None
        num_workers: the number of sub-processes to load, transform and collate
            the batch. ``0`` means using a single process. Default: 0
        timeout: if positive, the timeout value (in seconds) for collecting a
            batch from workers. Default: 0
        timeout_event: callback function triggered by timeout; defaults to
            raising a runtime error.
        divide: defines the parallelism strategy in multi-processing mode.
            ``True`` means one batch is divided into :attr:`num_workers` pieces,
            and the workers process these pieces in parallel. ``False`` means
            different sub-processes process different batches. Default: False
        preload: defines whether to apply the preloading strategy of the
            dataloader, which parallelizes the host-to-device copy with kernel
            execution to improve loading speed. Default: False. When enabled,
            the output changes from ``np.ndarray`` to a tensor of the matching
            dtype; the dtypes supported for preload are int, float,
            list[int, float] and tuple[int, float]; other types are not
            supported.
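
    Examples:
        A minimal usage sketch (illustrative: it assumes ``ArrayDataset`` and
        ``RandomSampler`` from ``megengine.data``, which this file itself does
        not import):

        .. code-block:: python

            import numpy as np
            from megengine.data import DataLoader, RandomSampler
            from megengine.data.dataset import ArrayDataset

            data = np.random.random((100, 3, 8, 8)).astype("float32")
            label = np.random.randint(0, 10, size=(100,)).astype("int32")
            dataset = ArrayDataset(data, label)
            sampler = RandomSampler(dataset, batch_size=16, drop_last=True)
            dataloader = DataLoader(dataset, sampler=sampler, num_workers=0)
            for step, (image, label) in enumerate(dataloader):
                # each batch is collated to numpy arrays by the default Collator
                print(step, image.shape, label.shape)  # (16, 3, 8, 8) (16,)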
  62. """
    __initialized = False

    def __init__(
        self,
        dataset: Dataset,
        sampler: Sampler = None,
        transform: Transform = None,
        collator: Collator = None,
        num_workers: int = 0,
        timeout: int = 0,
        timeout_event: Callable = raise_timeout_error,
        divide: bool = False,
        preload: bool = False,
    ):
        if num_workers < 0:
            raise ValueError("num_workers should not be negative")

        if timeout < 0:
            raise ValueError("timeout should not be negative")

        if divide and num_workers <= 1:
            raise ValueError("divide should not be set to True when num_workers <= 1")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout
        self.timeout_event = timeout_event
        self.divide = divide
        self.preload = preload

        if isinstance(dataset, StreamDataset):
            self.sampler = sampler if sampler else StreamSampler(batch_size=1)
            assert isinstance(
                self.sampler, StreamSampler
            ), "types of dataset and sampler do not match"
        else:
            assert isinstance(
                dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(dataset)
            self.sampler = (
                sampler
                if sampler
                else SequentialSampler(dataset, batch_size=1, drop_last=False)
            )
            assert isinstance(
                self.sampler, MapSampler
            ), "types of dataset and sampler do not match"

        if divide:
            if self.sampler.batch_size <= self.num_workers:
                raise ValueError(
                    "batch size must be larger than num_workers in divide mode."
                )
            elif self.sampler.batch_size % self.num_workers:
                logger.warning(
                    "batch size is not divisible by num_workers, may lose performance in divide mode."
                )

        if transform is None:
            self.transform = PseudoTransform()
        else:
            self.transform = transform

        if collator is None:
            self.collator = Collator()
        else:
            self.collator = collator

        self.__initialized = True

    def __iter__(self):
        if platform.system() == "Windows" and self.num_workers > 0:
            print(
                "pyarrow.plasma does not support ParallelDataLoader on Windows, changing num_workers to zero"
            )
            self.num_workers = 0

        if os.getenv("TERMUX_VERSION"):
            # FIXME: installing pyarrow on termux currently fails to build;
            # remove this logic after pyarrow fixes this issue
            print(
                "pyarrow is not supported in the termux environment for now, changing num_workers to zero"
            )
            self.num_workers = 0

        if isinstance(self.dataset, StreamDataset):
            if not self.num_workers:
                return _SerialStreamDataLoaderIter(self, self.preload)
            else:
                return _ParallelStreamDataLoaderIter(self, self.preload)
        else:
            assert isinstance(
                self.dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
            if not self.num_workers:
                return _SerialMapDataLoaderIter(self, self.preload)
            else:
                return _ParallelMapDataLoaderIter(self, self.preload)

    def __len__(self):
        return len(self.sampler)


class PreLoader:
    def __init__(self, preload):
        if preload:
            self.default_device = get_default_device()
            self.pre_load_device = self.default_device + ":" + str(_sh.get_next())
            self.pre_load_device_cache = None
        self.preload = preload

    # strategy one: load from numpy data and generate a tensor of the
    # corresponding dtype
    def _load_tensor(self, batch, cached=True):
        if isinstance(batch, np.ndarray):
            device = self.pre_load_device if cached else self.default_device
            return Tensor(batch, device=device)
        elif isinstance(batch, collections.abc.Mapping):
            return {k: self._load_tensor(v, cached) for k, v in batch.items()}
        elif isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
            return type(batch)(*(self._load_tensor(value, cached) for value in batch))
        elif isinstance(batch, collections.abc.Sequence):
            return [self._load_tensor(value, cached) for value in batch]
        else:
            return batch

    # strategy two: load from a cache that already holds tensors; only a
    # device-to-device copy is needed
    def _load_cache(self, data):
        if isinstance(data, Tensor):
            if data.device == self.default_device:
                return data
            return copy(data, device=self.default_device)
        elif isinstance(data, collections.abc.Mapping):
            return {k: self._load_cache(v) for k, v in data.items()}
        elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
            return type(data)(*(self._load_cache(value) for value in data))
        elif isinstance(data, collections.abc.Sequence):
            return [self._load_cache(value) for value in data]
        else:
            return data

    def _swap_out_cache(self):
        out = self._load_cache(self.pre_load_device_cache)
        self.pre_load_device_cache = None  # clean the cache
        return out


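# Note on how the two strategies cooperate (descriptive only): with preload
# enabled, each step first swaps the previously staged batch out of
# pre_load_device_cache onto the default device via _load_cache (strategy two,
# a device-to-device copy), then immediately stages the next batch onto the
# side device via _load_tensor (strategy one, a host-to-device copy). The
# upload of batch N+1 thereby overlaps with the compute that consumes batch N.

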
class _BaseMapDataLoaderIter(PreLoader):
    def __init__(self, loader, preload):
        super().__init__(preload)
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.seed = _random_seed_generator().__next__()
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event
        self.divide = loader.divide
        self.num_processed = 0

    def _get_next_batch(self):
        raise NotImplementedError

    def __len__(self):
        return len(self.sampler)

    def __iter__(self):
        return self

    def __next__(self):
        if self.preload:
            cached = self.pre_load_device_cache
            if cached is None:  # first and last
                if self.num_processed >= len(self):  # last
                    raise StopIteration
                elif self.num_processed == 0:  # first
                    self._try_load_tensor(cached=False)  # first do the h2d
            out = self._swap_out_cache()
            self._try_load_tensor()
            return out
        else:
            if self.num_processed >= len(self):
                raise StopIteration
            minibatch = self._get_next_batch()
            self.num_processed += 1
            return minibatch

    def _try_load_tensor(self, cached=True):
        if self.num_processed >= len(self):
            return
        else:
            self.num_processed += 1
            batch = self._get_next_batch()
            self.pre_load_device_cache = self._load_tensor(batch, cached)


class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
    def __init__(self, loader, preload):
        super(_SerialMapDataLoaderIter, self).__init__(loader, preload)
        self.indices_iter = iter(self.sampler)

    def _get_next_batch(self):
        indices = next(self.indices_iter)
        items = [self.dataset[idx] for idx in indices]
        trans_items = self.transform.apply_batch(items)
        return self.collator.apply(trans_items)


class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
    __initialized = False

    def __init__(self, loader, preload):
        super(_ParallelMapDataLoaderIter, self).__init__(loader, preload)

        self.task_queues = [
            multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
        ]

        self.feed_batch_idx = multiprocessing.Value("i", 0)
        self.target_batch_idx = multiprocessing.Value("i", 0)
        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # use shared-memory queue implemented by pyarrow plasma store.
        from .tools._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.task_feeding_worker = multiprocessing.Process(
            target=_task_feeding_loop,
            args=(
                iter(self.sampler),
                self.task_queues,
                self.num_workers,
                self.divide,
                self.shutdown_flag,
                self.feed_batch_idx,
            ),
            daemon=True,
        )
        gc.collect()
        self.task_feeding_worker.start()

        self.workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=_worker_loop,
                args=(
                    self.dataset,
                    self.task_queues[worker_id],
                    self.trans_data_queues[worker_id],
                    self.transform,
                    self.seed + worker_id + 1,
                    self.shutdown_flag,
                ),
                daemon=True,
            )
            gc.collect()
            worker.start()
            self.workers.append(worker)

        if self.divide:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_gathering_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        else:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_selecting_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        gc.collect()
        self.data_collecting_worker.start()

        self.__initialized = True

    def _check_workers(self):
        # Check the status of each worker.
        if not self.data_collecting_worker.is_alive():
            exitcode = self.data_collecting_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("data collecting worker died. {}".format(exitcode))

        if not self.task_feeding_worker.is_alive():
            exitcode = self.task_feeding_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("task feeding worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))

        logger.debug("all workers are alive.")

    def _get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0:
                if waited_time > self.timeout:
                    raise RuntimeError("get_next_batch timeout!")

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.task_feeding_worker.is_alive():
            self.task_feeding_worker.terminate()
        self.task_feeding_worker.join()

        if self.data_collecting_worker.is_alive():
            self.data_collecting_worker.terminate()
        self.data_collecting_worker.join()

        for worker in self.workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.task_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()


class _BaseStreamDataLoaderIter(PreLoader):
    def __init__(self, loader, preload):
        super().__init__(preload)
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event

    def _get_next_batch(self):
        raise NotImplementedError

    def _process_raw_data(self, raw_data):
        assert len(raw_data) == 2 and isinstance(
            raw_data[0], bool
        ), "StreamDataset should provide a two-item tuple whose first item indicates whether the data is batched."
        if not raw_data[0]:
            data = list((x,) for x in raw_data[1])
        else:
            data = raw_data[1]
        ret = []
        for idx in range(len(data[0])):
            ret.append(tuple(e[idx] for e in data))
        return ret

    def __iter__(self):
        return self

    def __next__(self):
        if self.preload:
            if self.pre_load_device_cache is None:
                self._try_load_tensor(cached=False)  # load on the current device
            out = self._swap_out_cache()
            self._try_load_tensor()  # load into the cache
            return out
        else:
            return self._get_next_batch()

    def _try_load_tensor(self, cached=True):
        batch = self._get_next_batch()
        self.pre_load_device_cache = self._load_tensor(batch, cached)


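# Shape of the (batched, data) tuples a StreamDataset iterator must yield,
# by example (illustrative field names, not enforced beyond the assert above):
#   yield False, (image, label)                    # one raw sample; one value per field
#   yield True, ([img0, img1], [label0, label1])   # a pre-batched group; one
#                                                  # equal-length sequence per field
# A runnable sketch built on this protocol is at the bottom of this file.

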
class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    def __init__(self, loader, preload):
        super().__init__(loader, preload)
        self.dataset_iter = iter(self.dataset)
        self.idx = 0
        self.unused = []

    def _try_get_raw_data(self, start_time):
        raw_data = None
        while not raw_data:
            try:
                if self.timeout > 0:
                    timer = threading.Timer(self.timeout, thread.interrupt_main)
                    timer.start()
                raw_data = next(self.dataset_iter)
                if self.timeout > 0:
                    timer.cancel()
            except KeyboardInterrupt:
                raw_data = self.timeout_event()
            except:
                if self.timeout > 0:
                    timer.cancel()
                    waited_time = time.time() - start_time
                    if waited_time > self.timeout:
                        raw_data = self.timeout_event()
        return raw_data

    def _get_next_batch(self):
        ret = []
        start_time = time.time()
        while len(ret) < self.sampler.batch_size:
            if len(self.unused) != 0:
                batch_data = self.unused
            else:
                raw_data = self._try_get_raw_data(start_time)
                batch_data = self._process_raw_data(raw_data)
            while len(batch_data) != 0 and len(ret) < self.sampler.batch_size:
                data = batch_data.pop()
                ret.append(self.transform.apply(data))
            self.unused = batch_data
        return self.collator.apply(ret)


class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    __initialized = False

    def __init__(self, loader, preload):
        super().__init__(loader, preload)

        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.raw_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # shared-memory queue implemented by pyarrow plasma store
        from .tools._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.receive_worker = multiprocessing.Process(
            target=self._worker_to_raw_data_queues, daemon=True
        )
        gc.collect()
        self.receive_worker.start()

        self.transform_workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True
            )
            gc.collect()
            worker.start()
            self.transform_workers.append(worker)

        self.collect_worker = multiprocessing.Process(
            target=self._worker_to_batch_queue, daemon=True
        )
        gc.collect()
        self.collect_worker.start()

        self.__initialized = True

    def _put_raw_data_queues(self, raw_data, qidx):
        batch_data = self._process_raw_data(raw_data)
        for data in batch_data:
            while True:
                qidx = qidx % self.num_workers
                try:
                    self.raw_data_queues[qidx].put(data)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("raw data queue %d is full" % qidx)
                finally:
                    qidx += 1
        return qidx

    def _worker_to_raw_data_queues(self):
        dataset_iter = iter(self.dataset)
        qidx = 0
        while True:
            if self.shutdown_flag.value == 1:
                break
            raw_data = next(dataset_iter)
            qidx = self._put_raw_data_queues(raw_data, qidx)

    def _worker_to_trans_data_queues(self, worker_id):
        while True:
            if self.shutdown_flag.value == 1:
                break
            try:
                data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT)
            except queue.Empty:
                continue
            trans_data = self.transform.apply(data)
            while True:
                try:
                    self.trans_data_queues[worker_id].put(trans_data)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("trans data queue is full")

    def _worker_to_batch_queue(self):
        cnt = -1
        trans_items = []
        while True:
            if self.shutdown_flag.value == 1:
                break
            cnt += 1
            queue_id = cnt % self.num_workers
            try:
                trans_item = self.trans_data_queues[queue_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
            except queue.Empty:
                continue
            trans_items.append(trans_item)
            if len(trans_items) == self.sampler.batch_size:
                batch_data = self.collator.apply(trans_items)
                while True:
                    try:
                        self.batch_queue.put(batch_data, timeout=1)
                        break
                    except queue.Full:
                        if self.shutdown_flag.value == 1:
                            break
                        logger.debug("batch queue is full")
                trans_items = []

    def _check_workers(self):
        if not self.collect_worker.is_alive():
            exitcode = self.collect_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("collator worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.transform_workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError(
                        "worker: {} died. {}".format(worker_id, exitcode)
                    )

    def _get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0 and waited_time > self.timeout:
                self._put_raw_data_queues(self.timeout_event(), 0)

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.receive_worker.is_alive():
            self.receive_worker.terminate()
        self.receive_worker.join()

        if self.collect_worker.is_alive():
            self.collect_worker.terminate()
        self.collect_worker.join()

        for worker in self.transform_workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.raw_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()


def _task_feeding_loop(
    indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
):
    # Feed the indices into the task queues
    while True:
        if shutdown_flag.value == 1:
            break
        batch_idx = feed_batch_idx.value
        try:
            indices = next(indices_iter)
        except StopIteration:
            break
        if divide:
            # make sure all task_queues are ready for put
            while any([q.full() for q in task_queues]):
                if shutdown_flag.value == 1:
                    return
            # divide into small pieces, feed to different workers.
            sub_num = math.ceil(len(indices) / num_workers)
            for worker_id in range(num_workers):
                sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
                task_queues[worker_id].put((batch_idx, sub_indices))
        else:
            # distribute tasks to different workers uniformly.
            target_id = batch_idx % num_workers
            while task_queues[target_id].full():
                if shutdown_flag.value == 1:
                    return
            task_queues[target_id].put((batch_idx, indices))
        with feed_batch_idx.get_lock():
            feed_batch_idx.value += 1


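# Illustrative numbers for the divide-mode branch above: with batch_size=8 and
# num_workers=4, sub_num = ceil(8 / 4) = 2, so worker 0 receives indices[0:2],
# worker 1 receives indices[2:4], and so on; _data_gathering_loop then
# re-assembles the pieces in worker-id order to restore the full batch.

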
def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
    # Get dataset items and do the transform
    random.seed(seed)
    np.random.seed(seed)
    while True:
        if shutdown_flag.value == 1:
            break
        try:
            batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT)
        except queue.Empty:
            continue
        if len(indices) > 0:
            items = [dataset[idx] for idx in indices]
            trans_items = transform.apply_batch(items)
        else:
            # in case of incomplete last batch
            trans_items = ()
        while True:
            try:
                trans_data_queue.put((batch_idx, trans_items), timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch part queue is full!")


def _data_gathering_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Gather the small pieces of batch data into full batches
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        full_trans_items = []
        for worker_id in range(num_workers):
            while True:
                try:
                    batch_idx, trans_items = trans_data_queues[worker_id].get(
                        timeout=GLOBAL_TIMEOUT
                    )
                    break
                except queue.Empty:
                    if shutdown_flag.value == 1:
                        break
                    logger.debug(
                        "worker:{} data queue get timeout! target batch idx:{}".format(
                            worker_id, target_batch_idx
                        )
                    )
            if batch_idx != target_batch_idx:
                raise RuntimeError(
                    "Unexpected batch_idx in data gathering loop. worker_id:{}.".format(
                        worker_id
                    )
                )
            else:
                full_trans_items.extend(trans_items)

        # Merge different parts into a batch.
        full_batch = collator.apply(full_trans_items)

        while True:
            try:
                batch_queue.put(full_batch, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()


def _data_selecting_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Make sure that batches are generated in exactly the same order as the
    # generated indices
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        target_worker_id = target_batch_idx % num_workers
        while True:
            try:
                batch_idx, trans_items = trans_data_queues[target_worker_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
                batch_data = collator.apply(trans_items)
                break
            except queue.Empty:
                if shutdown_flag.value == 1:
                    break
                logger.debug(
                    "worker:{} data queue get timeout! target batch idx:{}".format(
                        target_worker_id, target_batch_idx
                    )
                )

        if batch_idx != target_batch_idx:
            raise RuntimeError(
                "batch_idx {} mismatches the target_batch_idx {}".format(
                    batch_idx, target_batch_idx
                )
            )

        while True:
            try:
                batch_queue.put(batch_data, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()
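

if __name__ == "__main__":
    # Usage sketch for the stream path (illustrative, not part of the library).
    # It uses absolute megengine imports so it can be copied into a standalone
    # script; running this module directly would fail at the relative imports
    # above. Assumes StreamSampler is re-exported by megengine.data.
    import numpy as np
    from megengine.data import DataLoader, StreamSampler
    from megengine.data.dataset import StreamDataset

    class ToyStream(StreamDataset):
        """An infinite stream yielding one raw (unbatched) sample at a time."""

        def __iter__(self):
            while True:
                image = np.random.random((3, 8, 8)).astype("float32")
                label = np.random.randint(0, 10, dtype="int32")
                yield False, (image, label)  # (batched=False, per-field values)

    loader = DataLoader(ToyStream(), sampler=StreamSampler(batch_size=4))
    for step, (image, label) in enumerate(loader):
        print(step, image.shape, label.shape)  # expected: (4, 3, 8, 8) and (4,)
        if step >= 2:  # the stream is infinite; stop after a few batches
            break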