
dataloader.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import gc
import math
import multiprocessing
import os
import platform
import queue
import random
import threading
import time
from typing import Callable

import numpy as np

from ..device import _sh, get_default_device
from ..functional.tensor import copy
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from ..tensor import Tensor
from .collator import Collator
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

try:
    import thread
except ImportError:
    import _thread as thread

logger = get_logger(__name__)

GLOBAL_TIMEOUT = 5


def raise_timeout_error():
    raise RuntimeError("dataloader timeout")


class DataLoader:
    r"""Provides a convenient way to iterate on a given dataset.

    DataLoader combines a dataset with :class:`~.Sampler`, :class:`~.Transform`
    and :class:`~.Collator`, making it flexible to get minibatches continually
    from a dataset.

    Args:
        dataset: dataset from which to load the minibatch.
        sampler: defines the strategy to sample data from the dataset.
        transform: defines the transforming strategy for a sampled batch.
            Default: None
        collator: defines the merging strategy for a transformed batch.
            Default: None
        num_workers: the number of sub-processes to load, transform and collate
            the batch. ``0`` means using a single process. Default: 0
        timeout: if positive, the timeout value (in seconds) for collecting a
            batch from workers. Default: 0
        timeout_event: callback function triggered by timeout; defaults to
            raising a runtime error.
        divide: defines the parallelism strategy in multi-processing mode.
            ``True`` means one batch is divided into :attr:`num_workers` pieces
            that the workers process in parallel. ``False`` means different
            sub-processes process different batches. Default: False
        preload: whether to enable the preloading strategy of the dataloader.
            When enabled, the dataloader preloads one batch to device memory to
            speed up the whole training process. All values in maps, lists, and
            tuples are converted to :class:`~.Tensor` by preloading, so you will
            get :class:`~.Tensor` instead of the original NumPy array or Python
            number.

    .. note::

        By enabling preload, tensors' host2device copies and device kernel
        execution are overlapped, which improves training speed at the cost of
        higher device memory usage (one extra batch of data resides in device
        memory). This feature saves the most time when the NN training step is
        short or the machine's host PCIe bandwidth per device is low.
    """
    __initialized = False

    def __init__(
        self,
        dataset: Dataset,
        sampler: Sampler = None,
        transform: Transform = None,
        collator: Collator = None,
        num_workers: int = 0,
        timeout: int = 0,
        timeout_event: Callable = raise_timeout_error,
        divide: bool = False,
        preload: bool = False,
    ):
        if num_workers < 0:
            raise ValueError("num_workers should not be negative")
        if timeout < 0:
            raise ValueError("timeout should not be negative")
        if divide and num_workers <= 1:
            raise ValueError("divide should not be set to True when num_workers <= 1")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout
        self.timeout_event = timeout_event
        self.divide = divide
        self.preload = preload

        if isinstance(dataset, StreamDataset):
            self.sampler = sampler if sampler else StreamSampler(batch_size=1)
            assert isinstance(
                self.sampler, StreamSampler
            ), "types of dataset and sampler do not match"
        else:
            assert isinstance(
                dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(dataset)
            self.sampler = (
                sampler
                if sampler
                else SequentialSampler(dataset, batch_size=1, drop_last=False)
            )
            assert isinstance(
                self.sampler, MapSampler
            ), "types of dataset and sampler do not match"

        if divide:
            if self.sampler.batch_size <= self.num_workers:
                raise ValueError(
                    "batch size must not be smaller than num_workers in divide mode."
                )
            elif self.sampler.batch_size % self.num_workers:
                logger.warning(
                    "batch size is not divisible by num_workers, may lose performance in divide mode."
                )

        if transform is None:
            self.transform = PseudoTransform()
        else:
            self.transform = transform

        if collator is None:
            self.collator = Collator()
        else:
            self.collator = collator

        self.__initialized = True
    def __iter__(self):
        if platform.system() == "Windows" and self.num_workers > 0:
            print(
                "pyarrow.plasma does not support ParallelDataLoader on Windows, "
                "setting num_workers to zero"
            )
            self.num_workers = 0
        if os.getenv("TERMUX_VERSION"):
            # FIXME: installing pyarrow on termux currently fails to build;
            # remove this logic after pyarrow fixes the issue
            print(
                "pyarrow is not supported in the termux environment for now, "
                "setting num_workers to zero"
            )
            self.num_workers = 0
        if isinstance(self.dataset, StreamDataset):
            if not self.num_workers:
                return _SerialStreamDataLoaderIter(self, self.preload)
            else:
                return _ParallelStreamDataLoaderIter(self, self.preload)
        else:
            assert isinstance(
                self.dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
            if not self.num_workers:
                return _SerialMapDataLoaderIter(self, self.preload)
            else:
                return _ParallelMapDataLoaderIter(self, self.preload)

    def __len__(self):
        return len(self.sampler)


class PreLoader:
    def __init__(self, preload):
        if preload:
            self.default_device = get_default_device()
            self.pre_load_device = self.default_device + ":" + str(_sh.get_next())
            self.pre_load_device_cache = None
        self.preload = preload

    """
    strategy one: load from numpy data, generating tensors of the right dtype
    """

    def _load_tensor(self, batch, cached=True):
        if isinstance(batch, np.ndarray):
            device = self.pre_load_device if cached else self.default_device
            return Tensor(batch, device=device)
        elif isinstance(batch, collections.abc.Mapping):
            return {k: self._load_tensor(v, cached) for k, v in batch.items()}
        elif isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
            return type(batch)(*(self._load_tensor(value, cached) for value in batch))
        elif isinstance(batch, collections.abc.Sequence):
            return [self._load_tensor(value, cached) for value in batch]
        else:
            return batch

    """
    strategy two: load from a cache that already holds tensors; just do a d2d copy
    """

    def _load_cache(self, data):
        if isinstance(data, Tensor):
            if data.device == self.default_device:
                return data
            return copy(data, device=self.default_device)
        elif isinstance(data, collections.abc.Mapping):
            return {k: self._load_cache(v) for k, v in data.items()}
        elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
            return type(data)(*(self._load_cache(value) for value in data))
        elif isinstance(data, collections.abc.Sequence):
            return [self._load_cache(value) for value in data]
        else:
            return data

    def _swap_out_cache(self):
        out = self._load_cache(self.pre_load_device_cache)
        self.pre_load_device_cache = None  # clean cache
        return out


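# How PreLoader's two strategies cooperate (a descriptive sketch): with preload
# enabled, each __next__ swaps the cached batch onto the default device via
# _load_cache (device-to-device copy) and immediately issues the host-to-device
# copy of the next batch onto pre_load_device via _load_tensor, so the h2d
# transfer overlaps with the model's compute on the previous batch.

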
class _BaseMapDataLoaderIter(PreLoader):
    def __init__(self, loader, preload):
        super().__init__(preload)
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.seed = _random_seed_generator().__next__()
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event
        self.divide = loader.divide
        self.num_processed = 0

    def _get_next_batch(self):
        raise NotImplementedError

    def __len__(self):
        return len(self.sampler)

    def __iter__(self):
        return self

    def __next__(self):
        if self.preload:
            cached = self.pre_load_device_cache
            if cached is None:  # first and last
                if self.num_processed >= len(self):  # last
                    raise StopIteration
                elif self.num_processed == 0:  # first
                    self._try_load_tensor(cached=False)  # first, do the h2d
            out = self._swap_out_cache()
            self._try_load_tensor()
            return out
        else:
            if self.num_processed >= len(self):
                raise StopIteration
            minibatch = self._get_next_batch()
            self.num_processed += 1
            return minibatch

    def _try_load_tensor(self, cached=True):
        if self.num_processed >= len(self):
            return
        else:
            self.num_processed += 1
            batch = self._get_next_batch()
            self.pre_load_device_cache = self._load_tensor(batch, cached)


class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
    def __init__(self, loader, preload):
        super(_SerialMapDataLoaderIter, self).__init__(loader, preload)
        self.indices_iter = iter(self.sampler)

    def _get_next_batch(self):
        indices = next(self.indices_iter)
        items = [self.dataset[idx] for idx in indices]
        trans_items = self.transform.apply_batch(items)
        return self.collator.apply(trans_items)


class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
    __initialized = False

    def __init__(self, loader, preload):
        super(_ParallelMapDataLoaderIter, self).__init__(loader, preload)

        self.task_queues = [
            multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
        ]

        self.feed_batch_idx = multiprocessing.Value("i", 0)
        self.target_batch_idx = multiprocessing.Value("i", 0)
        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # use shared-memory queue implemented by pyarrow plasma store.
        from .tools._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.task_feeding_worker = multiprocessing.Process(
            target=_task_feeding_loop,
            args=(
                iter(self.sampler),
                self.task_queues,
                self.num_workers,
                self.divide,
                self.shutdown_flag,
                self.feed_batch_idx,
            ),
            daemon=True,
        )
        gc.collect()
        self.task_feeding_worker.start()

        self.workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=_worker_loop,
                args=(
                    self.dataset,
                    self.task_queues[worker_id],
                    self.trans_data_queues[worker_id],
                    self.transform,
                    self.seed + worker_id + 1,
                    self.shutdown_flag,
                ),
                daemon=True,
            )
            gc.collect()
            worker.start()
            self.workers.append(worker)

        if self.divide:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_gathering_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        else:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_selecting_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        gc.collect()
        self.data_collecting_worker.start()

        self.__initialized = True
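    # Topology of the processes started above (a descriptive note):
    #
    #   _task_feeding_loop --(task_queues)--> _worker_loop x num_workers
    #       --(trans_data_queues)--> data_collecting_worker
    #       --(batch_queue, plasma shared memory)--> main process
    #
    # In divide mode every worker handles a slice of each batch and the pieces
    # are gathered; otherwise whole batches are round-robined across workers
    # and re-ordered by batch index.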
    def _check_workers(self):
        # Check the status of each worker.
        if not self.data_collecting_worker.is_alive():
            exitcode = self.data_collecting_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("data collecting worker died. {}".format(exitcode))

        if not self.task_feeding_worker.is_alive():
            exitcode = self.task_feeding_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("task feeding worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))

        logger.debug("all workers are alive.")

    def _get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0:
                if waited_time > self.timeout:
                    raise RuntimeError("get_next_batch timeout!")

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.task_feeding_worker.is_alive():
            self.task_feeding_worker.terminate()
        self.task_feeding_worker.join()

        if self.data_collecting_worker.is_alive():
            self.data_collecting_worker.terminate()
        self.data_collecting_worker.join()

        for worker in self.workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.task_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()


class _BaseStreamDataLoaderIter(PreLoader):
    def __init__(self, loader, preload):
        super().__init__(preload)
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event

    def _get_next_batch(self):
        raise NotImplementedError

    def _process_raw_data(self, raw_data):
        assert len(raw_data) == 2 and isinstance(
            raw_data[0], bool
        ), "StreamDataset should provide a two-element tuple whose first item indicates whether the data is batched."
        if not raw_data[0]:
            data = list((x,) for x in raw_data[1])
        else:
            data = raw_data[1]
        ret = []
        for idx in range(len(data[0])):
            ret.append(tuple(e[idx] for e in data))
        return ret
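    # Illustrative shapes for the protocol above (the values are made up):
    # a batched item (True, ([img0, img1], [label0, label1])) is transposed
    # into per-sample tuples [(img0, label0), (img1, label1)]; an unbatched
    # item (False, (img0, label0)) first has each field wrapped into a
    # length-1 column, yielding [(img0, label0)].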
    def __iter__(self):
        return self

    def __next__(self):
        if self.preload:
            if self.pre_load_device_cache is None:
                self._try_load_tensor(cached=False)  # load in current
            out = self._swap_out_cache()
            self._try_load_tensor()  # load in cached
            return out
        else:
            return self._get_next_batch()

    def _try_load_tensor(self, cached=True):
        batch = self._get_next_batch()
        self.pre_load_device_cache = self._load_tensor(batch, cached)


class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    def __init__(self, loader, preload):
        super().__init__(loader, preload)
        self.dataset_iter = iter(self.dataset)
        self.idx = 0
        self.unused = []

    def _try_get_raw_data(self, start_time):
        raw_data = None
        while not raw_data:
            try:
                if self.timeout > 0:
                    timer = threading.Timer(self.timeout, thread.interrupt_main)
                    timer.start()
                raw_data = next(self.dataset_iter)
                if self.timeout > 0:
                    timer.cancel()
            except KeyboardInterrupt:
                raw_data = self.timeout_event()
            except:
                if self.timeout > 0:
                    timer.cancel()
                    waited_time = time.time() - start_time
                    if waited_time > self.timeout:
                        raw_data = self.timeout_event()
        return raw_data

    def _get_next_batch(self):
        ret = []
        start_time = time.time()
        while len(ret) < self.sampler.batch_size:
            if len(self.unused) != 0:
                batch_data = self.unused
            else:
                raw_data = self._try_get_raw_data(start_time)
                batch_data = self._process_raw_data(raw_data)
            while len(batch_data) != 0 and len(ret) < self.sampler.batch_size:
                data = batch_data.pop()
                ret.append(self.transform.apply(data))
            self.unused = batch_data
        return self.collator.apply(ret)


class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    __initialized = False

    def __init__(self, loader, preload):
        super().__init__(loader, preload)

        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.raw_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # shared-memory queue implemented by pyarrow plasma store
        from .tools._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.receive_worker = multiprocessing.Process(
            target=self._worker_to_raw_data_queues, daemon=True
        )
        gc.collect()
        self.receive_worker.start()

        self.transform_workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True
            )
            gc.collect()
            worker.start()
            self.transform_workers.append(worker)

        self.collect_worker = multiprocessing.Process(
            target=self._worker_to_batch_queue, daemon=True
        )
        gc.collect()
        self.collect_worker.start()

        self.__initialized = True
    def _put_raw_data_queues(self, raw_data, qidx):
        batch_data = self._process_raw_data(raw_data)
        for data in batch_data:
            while True:
                qidx = qidx % self.num_workers
                try:
                    # bounded put so queue.Full can be raised and the shutdown
                    # flag below is re-checked instead of blocking forever
                    self.raw_data_queues[qidx].put(data, timeout=1)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("raw data queue %d is full" % qidx)
                finally:
                    qidx += 1
        return qidx

    def _worker_to_raw_data_queues(self):
        dataset_iter = iter(self.dataset)
        qidx = 0
        while True:
            if self.shutdown_flag.value == 1:
                break
            raw_data = next(dataset_iter)
            qidx = self._put_raw_data_queues(raw_data, qidx)

    def _worker_to_trans_data_queues(self, worker_id):
        while True:
            if self.shutdown_flag.value == 1:
                break
            try:
                data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT)
            except queue.Empty:
                continue
            trans_data = self.transform.apply(data)
            while True:
                try:
                    # bounded put so the shutdown flag is re-checked on Full
                    self.trans_data_queues[worker_id].put(trans_data, timeout=1)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("trans data queue is full")

    def _worker_to_batch_queue(self):
        cnt = -1
        trans_items = []
        while True:
            if self.shutdown_flag.value == 1:
                break
            cnt += 1
            queue_id = cnt % self.num_workers
            try:
                trans_item = self.trans_data_queues[queue_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
            except queue.Empty:
                continue
            trans_items.append(trans_item)
            if len(trans_items) == self.sampler.batch_size:
                batch_data = self.collator.apply(trans_items)
                while True:
                    try:
                        self.batch_queue.put(batch_data, timeout=1)
                        break
                    except queue.Full:
                        if self.shutdown_flag.value == 1:
                            break
                        logger.debug("batch queue is full")
                trans_items = []
    def _check_workers(self):
        if not self.collect_worker.is_alive():
            exitcode = self.collect_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("collator worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.transform_workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError(
                        "worker: {} died. {}".format(worker_id, exitcode)
                    )

    def _get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0 and waited_time > self.timeout:
                self._put_raw_data_queues(self.timeout_event(), 0)

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.receive_worker.is_alive():
            self.receive_worker.terminate()
        self.receive_worker.join()

        if self.collect_worker.is_alive():
            self.collect_worker.terminate()
        self.collect_worker.join()

        for worker in self.transform_workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.raw_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()


def _task_feeding_loop(
    indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
):
    # Feed the indices into the task queues.
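    # A worked example of divide mode with illustrative numbers: for
    # batch_size=10 and num_workers=4, sub_num = ceil(10 / 4) = 3, so the
    # slices have sizes 3, 3, 3 and 1, and every slice carries the same
    # batch_idx so the gathering loop can reassemble them in order.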
    while True:
        if shutdown_flag.value == 1:
            break
        batch_idx = feed_batch_idx.value
        try:
            indices = next(indices_iter)
        except StopIteration:
            break
        if divide:
            # make sure all task_queues are ready for put
            while any([q.full() for q in task_queues]):
                if shutdown_flag.value == 1:
                    return
            # divide into small pieces, feed to different workers.
            sub_num = math.ceil(len(indices) / num_workers)
            for worker_id in range(num_workers):
                sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
                task_queues[worker_id].put((batch_idx, sub_indices))
        else:
            # distribute tasks to different workers uniformly.
            target_id = batch_idx % num_workers
            while task_queues[target_id].full():
                if shutdown_flag.value == 1:
                    return
            task_queues[target_id].put((batch_idx, indices))
        with feed_batch_idx.get_lock():
            feed_batch_idx.value += 1


def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
    # Get dataset items and do the transform.
    random.seed(seed)
    np.random.seed(seed)
    while True:
        if shutdown_flag.value == 1:
            break
        try:
            batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT)
        except queue.Empty:
            continue
        if len(indices) > 0:
            items = [dataset[idx] for idx in indices]
            trans_items = transform.apply_batch(items)
        else:
            # in case of an incomplete last batch
            trans_items = ()
        while True:
            try:
                trans_data_queue.put((batch_idx, trans_items), timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch part queue is full!")


def _data_gathering_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Gather the small pieces of batch data into full batch data.
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        full_trans_items = []
        for worker_id in range(num_workers):
            while True:
                try:
                    batch_idx, trans_items = trans_data_queues[worker_id].get(
                        timeout=GLOBAL_TIMEOUT
                    )
                    break
                except queue.Empty:
                    if shutdown_flag.value == 1:
                        break
                    logger.debug(
                        "worker:{} data queue get timeout! target batch idx:{}".format(
                            worker_id, target_batch_idx
                        )
                    )
            if batch_idx != target_batch_idx:
                raise RuntimeError(
                    "Unexpected batch_idx in data gathering loop. worker_id:{}.".format(
                        worker_id
                    )
                )
            else:
                full_trans_items.extend(trans_items)

        # Merge different parts into a batch.
        full_batch = collator.apply(full_trans_items)

        while True:
            try:
                batch_queue.put(full_batch, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()


def _data_selecting_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Make sure batches are generated in exactly the same order as the
    # generated indices.
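    # Why this works (a descriptive note): in non-divide mode the feeding loop
    # routes batch k to worker k % num_workers, so polling that single worker's
    # queue is enough to fetch the next batch in index order; the batch_idx
    # check below guards against any mismatch.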
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        target_worker_id = target_batch_idx % num_workers
        while True:
            try:
                batch_idx, trans_items = trans_data_queues[target_worker_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
                batch_data = collator.apply(trans_items)
                break
            except queue.Empty:
                if shutdown_flag.value == 1:
                    break
                logger.debug(
                    "worker:{} data queue get timeout! target batch idx:{}".format(
                        target_worker_id, target_batch_idx
                    )
                )

        if batch_idx != target_batch_idx:
            raise RuntimeError(
                "batch_idx {} mismatches the target_batch_idx {}".format(
                    batch_idx, target_batch_idx
                )
            )

        while True:
            try:
                batch_queue.put(batch_data, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()