
dataloader.py 24 kB

# -*- coding: utf-8 -*-
import collections
import gc
import itertools
import multiprocessing
import os
import platform
import queue
import random
import threading
import time

import numpy as np

from ..device import _sh, get_default_device
from ..functional.tensor import copy
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from ..tensor import Tensor
from .collator import Collator
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

# Python 2 exposes ``thread``; on Python 3 fall back to ``_thread``.
try:
    import thread
except ImportError:
    import _thread as thread

logger = get_logger(__name__)

GLOBAL_TIMEOUT = 5


def raise_timeout_error():
    raise RuntimeError("dataloader timeout")


class DataLoader:
    r"""Provides a convenient way to iterate on a given dataset.

    The process is as follows:

    .. mermaid::
        :align: center

        flowchart LR
            Dataset.__len__ -- Sampler --> Indices
            batch_size -- Sampler --> Indices
            Indices -- Dataset.__getitem__ --> Samples
            Samples -- Transform + Collator --> mini-batch

    DataLoader combines a :class:`~.Dataset` with
    :class:`~.Sampler`, :class:`~.Transform` and :class:`~.Collator`,
    making it flexible to get mini-batches continually from a dataset.
    See :ref:`data-guide` for more details.

    Args:
        dataset: dataset from which to load the minibatch.
        sampler: defines the strategy to sample data from the dataset.
            If ``None``, it will sample from the dataset sequentially, one by one.
        transform: defines the transforming strategy for a sampled batch.
        collator: defines the merging strategy for a transformed batch.
        num_workers: the number of subprocesses used to load, transform and collate
            a batch. ``0`` means using a single process. Default: 0
        timeout: if positive, the timeout value (in seconds) for collecting a
            batch from workers. Default: 0
        preload: whether to enable the preloading strategy of the dataloader.
            When enabled, the dataloader preloads one batch to device memory to
            speed up the whole training process.

    .. admonition:: The effect of enabling preload
        :class: warning

        * All elements in :class:`map`, :class:`list`, and :class:`tuple` will be
          converted to :class:`~.Tensor` by preloading, so you will get
          :class:`~.Tensor` instead of the original NumPy array or Python
          built-in data structure.
        * Tensors' host-to-device copies and device kernel execution will be
          overlapped, which improves training speed at the cost of **higher
          device memory usage** (due to one extra batch kept in device memory).

        This feature saves the most time when your NN training time is short or
        your machine's host PCIe bandwidth per device is low.
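
    A minimal usage sketch (the dataset contents below are random toy data;
    :class:`~.ArrayDataset` and :class:`~.RandomSampler` come from
    :mod:`megengine.data`):

    .. code-block:: python

        import numpy as np
        from megengine.data import DataLoader, RandomSampler
        from megengine.data.dataset import ArrayDataset

        # a toy dataset: 100 images of shape (3, 32, 32) with integer labels
        images = np.random.random((100, 3, 32, 32)).astype("float32")
        labels = np.random.randint(0, 10, size=(100,)).astype("int32")
        dataset = ArrayDataset(images, labels)

        sampler = RandomSampler(dataset, batch_size=16)
        dataloader = DataLoader(dataset, sampler=sampler)
        for image, label in dataloader:
            pass  # each iteration yields one transformed and collated mini-batch

    Passing ``preload=True`` additionally keeps one upcoming batch in device
    memory, as described in the warning above.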
  63. """

    def __init__(
        self,
        dataset: Dataset,
        sampler: Sampler = None,
        transform: Transform = None,
        collator: Collator = None,
        num_workers: int = 0,
        timeout: int = 0,
        preload: bool = False,
        parallel_stream: bool = False,
    ):
        if num_workers < 0:
            raise ValueError("num_workers should not be negative")
        if timeout < 0:
            raise ValueError("timeout should not be negative")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout
        self.preload = preload
        self.parallel_stream = parallel_stream

        if isinstance(dataset, StreamDataset):
            self.sampler = sampler if sampler else StreamSampler(batch_size=1)
            assert isinstance(
                self.sampler, StreamSampler
            ), "types of dataset and sampler do not match"
            if parallel_stream is False and self.num_workers > 1:
                logger.warning(
                    "Data loading time will be dominated by fetching the raw "
                    "stream data in the main process; set parallel_stream=True "
                    "to speed up the dataloader!"
                )
            self.datakind = "stream"
        else:
            assert isinstance(
                dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(dataset)
            self.sampler = (
                sampler
                if sampler
                else SequentialSampler(dataset, batch_size=1, drop_last=False)
            )
            assert isinstance(
                self.sampler, MapSampler
            ), "types of dataset and sampler do not match"
            self.datakind = "map"

        if transform is None:
            self.transform = PseudoTransform()
        else:
            self.transform = transform

        if collator is None:
            self.collator = Collator()
        else:
            self.collator = collator

        if platform.system() == "Linux" and self.num_workers > 0:
            self.check_memory_rationality()

    def __iter__(self):
        if platform.system() == "Windows" and self.num_workers > 0:
            print(
                "pyarrow.plasma does not support ParallelDataLoader on Windows, "
                "changing num_workers to zero"
            )
            self.num_workers = 0
        if os.getenv("TERMUX_VERSION"):
            # FIXME: building pyarrow on termux currently fails;
            # remove this logic after pyarrow fixes the issue
            print(
                "pyarrow is not supported in the termux env for now, "
                "changing num_workers to zero"
            )
            self.num_workers = 0
        if isinstance(self.dataset, StreamDataset):
            if not self.num_workers:
                return _SerialStreamDataLoaderIter(self, self.preload)
            else:
                return _ParallelStreamDataLoaderIter(self, self.preload)
        else:
            assert isinstance(
                self.dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
            if not self.num_workers:
                return _SerialMapDataLoaderIter(self, self.preload)
            else:
                return _ParallelMapDataLoaderIter(self, self.preload)

    def __len__(self):
        return len(self.sampler)

    def check_memory_rationality(self):
        import psutil

        main_memory = (
            psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
        )
        total_memory = (self.num_workers + 1) * main_memory
        current_memory = (
            int(os.popen("cat /sys/fs/cgroup/memory/memory.limit_in_bytes").read())
            / 1024
            / 1024
            / 1024
        )
        if current_memory < total_memory:
            logger.warning(
                "Each worker needs to read the shared meta-data, which increases "
                "the reference count. The copy-on-write property will lead to a "
                "'memory leak'; memory usage will end up being %.2f GB. However, "
                "the currently requested memory is %.2f GB. Consider requesting "
                "more memory, or use np.ndarray instead of list or tuple to save "
                "the meta-data." % (total_memory, current_memory)
            )
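
    # For example, with num_workers=4 and a main process holding 2 GB of
    # meta-data, copy-on-write growth means peak usage may approach
    # (4 + 1) * 2 GB = 10 GB, which has to fit within the cgroup limit above.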


class PreLoader:
    def __init__(self, loader, preload):
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.seed = _random_seed_generator().__next__()
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.num_processed = 0
        self.datakind = loader.datakind
        self.parallel_stream = loader.parallel_stream
        if preload:
            self.default_device = get_default_device()
            self.pre_load_device = self.default_device + ":" + str(_sh.get_next())
            self.pre_load_device_cache = None
        self.preload = preload

    def __iter__(self):
        return self
  183. """
  184. strategy one: load from numpy data, and generate dtype tensor
  185. """

    def _load_tensor(self, batch, cached=True):
        if isinstance(batch, np.ndarray):
            device = self.pre_load_device if cached else self.default_device
            return Tensor(batch, device=device)
        elif isinstance(batch, collections.abc.Mapping):
            return {k: self._load_tensor(v, cached) for k, v in batch.items()}
        elif isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
            return type(batch)(*(self._load_tensor(value, cached) for value in batch))
        elif isinstance(batch, collections.abc.Sequence):
            return [self._load_tensor(value, cached) for value in batch]
        else:
            return batch
  198. """
  199. strategy two: load from cache that is already tensor just do d2d copy
  200. """

    def _load_cache(self, data):
        if isinstance(data, Tensor):
            if data.device == self.default_device:
                return data
            return copy(data, device=self.default_device)
        elif isinstance(data, collections.abc.Mapping):
            return {k: self._load_cache(v) for k, v in data.items()}
        elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
            return type(data)(*(self._load_cache(value) for value in data))
        elif isinstance(data, collections.abc.Sequence):
            return [self._load_cache(value) for value in data]
        else:
            return data

    def _swap_out_cache(self):
        out = self._load_cache(self.pre_load_device_cache)
        self.pre_load_device_cache = None  # clean the cache
        return out


class _ParallelDataLoaderIter:
    def __init__(self):
        self._worker_queue_idx_cycle = itertools.cycle(range(self.num_workers))

        from .tools._queue import PlasmaShmQueue

        self._worker_result_queue = PlasmaShmQueue()
        self._shutdown = False
        self._workers_done_event = multiprocessing.Event()

        self._index_queues = []
        self._workers = []
        for i in range(self.num_workers):
            index_queue = multiprocessing.Queue()
            index_queue.cancel_join_thread()
            w = multiprocessing.Process(
                target=_worker_loop,
                args=(
                    self.dataset,
                    index_queue,
                    self._worker_result_queue,
                    self._workers_done_event,
                    self.transform,
                    self.collator,
                    self.sampler.batch_size,
                    self.seed + i,
                    i,
                    self.num_workers,
                    self.datakind,
                    self.parallel_stream,
                ),
                daemon=True,
            )
            gc.collect()
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)
        self._data_queue = self._worker_result_queue
        self._reset()

    def _try_put_index(self):
        raise NotImplementedError

    def _reset(self):
        self._sampler_iter = iter(self.sampler)
        self._send_idx = 0
        self._rcvd_idx = 0
        self._task_info = {}
        self._workers_status = [True for _ in range(self.num_workers)]
        for _ in range(2 * self.num_workers):
            self._try_put_index()

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        return data

    def _get_data(self):
        if self.timeout > 0:
            success, data = self._try_get_data(self.timeout)
            if success:
                return data
            else:
                raise_timeout_error()
        else:
            while True:
                success, data = self._try_get_data()
                if success:
                    return data

    def _get_next_batch(self):
        while True:
            while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                if (
                    len(info) == 2 or self._workers_status[worker_id]
                ):  # has data or worker is still active
                    break
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                self._shutdown_workers()
                raise StopIteration

            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            idx, data = self._get_data()
            if isinstance(data, int):  # check for StopIteration in StreamDataset
                self._mark_worker_as_unavailable(data)
                self._try_put_index()
                continue
            if idx != self._rcvd_idx:
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx]
                return self._process_data(data)

    def _try_get_data(self, timeout=GLOBAL_TIMEOUT):
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            failed_workers = []
            for worker_id, w in enumerate(self._workers):
                if self._workers_status[worker_id] and not w.is_alive():
                    failed_workers.append((worker_id, w))
                    self._mark_worker_as_unavailable(worker_id)
                    if w.exitcode == -9:
                        logger.debug(
                            "Maybe memory is not enough, please request more memory!"
                        )
            if len(failed_workers) > 0:
                pids_str = ", ".join(str(w_info[1].pid) for w_info in failed_workers)
                w_ids_str = ", ".join(str(w_info[0]) for w_info in failed_workers)
                exitcode_str = ", ".join(
                    str(w_info[1].exitcode) for w_info in failed_workers
                )
                raise RuntimeError(
                    "DataLoader worker (worker(s): {} , pid(s): {}) exited unexpectedly, exitcode(s): {}".format(
                        w_ids_str, pids_str, exitcode_str
                    )
                )
            if isinstance(e, queue.Empty):
                return (False, None)
            raise  # re-raise unexpected errors instead of implicitly returning None

    def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
        q = self._index_queues[worker_id]
        q.put(None)
        self._workers_status[worker_id] = False
        assert self._workers_done_event.is_set() == shutdown

    def _shutdown_workers(self):
        if not self._shutdown:
            self._shutdown = True
            try:
                self._workers_done_event.set()
                for worker_id in range(len(self._workers)):
                    if self._workers_status[worker_id]:
                        self._mark_worker_as_unavailable(worker_id, shutdown=True)
                for w in self._workers:
                    w.join(timeout=GLOBAL_TIMEOUT)
                for q in self._index_queues:
                    q.cancel_join_thread()
                    q.close()
                self._data_queue.cancel_join_thread()
                self._data_queue.close()
            finally:
                for w in self._workers:
                    if w.is_alive():
                        w.terminate()

    def __del__(self):
        self._shutdown_workers()


class _BaseMapDataLoaderIter(PreLoader):
    def __init__(self, loader, preload):
        super().__init__(loader, preload)

    def __len__(self):
        return len(self.sampler)

    def __next__(self):
        if self.preload:
            cached = self.pre_load_device_cache
            if cached is None:  # first and last batches
                if self.num_processed >= len(self):  # last
                    raise StopIteration
                elif self.num_processed == 0:  # first
                    self._try_load_tensor(cached=False)  # first do the h2d copy
            out = self._swap_out_cache()
            self._try_load_tensor()
            return out
        else:
            data = self._get_next_batch()
            return data

    def _try_load_tensor(self, cached=True):
        if self.num_processed >= len(self):
            return
        else:
            self.num_processed += 1
            batch = self._get_next_batch()
            self.pre_load_device_cache = self._load_tensor(batch, cached)


class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
    def __init__(self, loader, preload):
        super(_SerialMapDataLoaderIter, self).__init__(loader, preload)
        self._sampler_iter = iter(self.sampler)

    def _get_next_batch(self):
        indices = next(self._sampler_iter)
        items = [self.dataset[idx] for idx in indices]
        trans_items = self.transform.apply_batch(items)
        return self.collator.apply(trans_items)


class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter, _ParallelDataLoaderIter):
    def __init__(self, loader, preload):
        _BaseMapDataLoaderIter.__init__(self, loader, preload)
        _ParallelDataLoaderIter.__init__(self)

    def _try_put_index(self):
        try:
            index = next(self._sampler_iter)
        except StopIteration:
            return

        for _ in range(self.num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break

        self._index_queues[worker_queue_idx].put((self._send_idx, index))
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._send_idx += 1


_worker_info = None


class WorkerInfo(object):
    __initialized = False

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.__keys = tuple(kwargs.keys())
        self.__initialized = True

    def __setattr__(self, key, val):
        if self.__initialized:
            raise RuntimeError(
                "Cannot assign attributes to {} objects".format(self.__class__.__name__)
            )
        return super(WorkerInfo, self).__setattr__(key, val)

    def __repr__(self):
        items = []
        for k in self.__keys:
            items.append("{}={}".format(k, getattr(self, k)))
        return "{}({})".format(self.__class__.__name__, ", ".join(items))


def get_worker_info():
    r"""Returns the :class:`WorkerInfo` of the current worker process
    (fields: ``idx``, ``worker`` and ``seed``), or ``None`` in the main process."""
    return _worker_info
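

# A minimal sketch of using ``get_worker_info`` to shard a stream across
# workers (illustrative only; ``source`` stands for any iterable of samples):
#
#     class ShardedStream(StreamDataset):
#         def __iter__(self):
#             info = get_worker_info()
#             for i, sample in enumerate(source):
#                 # each worker keeps every num_workers-th sample
#                 if info is None or i % info.worker == info.idx:
#                     yield sample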


class _BaseStreamDataLoaderIter(PreLoader):
    def __init__(self, loader, preload):
        super().__init__(loader, preload)
        self.dataset_iter = iter(self.dataset)

    def __next__(self):
        if self.preload:
            if self.pre_load_device_cache is None:
                self._try_load_tensor(cached=False)  # load onto the default device
            out = self._swap_out_cache()
            self._try_load_tensor()  # load into the cache device
            return out
        else:
            return self._get_next_batch()

    def _try_load_tensor(self, cached=True):
        batch = self._get_next_batch()
        self.pre_load_device_cache = self._load_tensor(batch, cached)


class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    def __init__(self, loader, preload):
        super().__init__(loader, preload)
        self.dataset_iter = iter(self.dataset)

    def _try_get_raw_data(self, start_time):
        raw_data = None
        while not raw_data:
            try:
                if self.timeout > 0:
                    # interrupt the main thread (KeyboardInterrupt) if fetching
                    # a sample takes longer than the timeout
                    timer = threading.Timer(self.timeout, thread.interrupt_main)
                    timer.start()
                raw_data = next(self.dataset_iter)
                if self.timeout > 0:
                    timer.cancel()
            except AttributeError as error:
                raise error
            except:  # noqa: E722  (catches the KeyboardInterrupt raised by the timer)
                if self.timeout > 0:
                    timer.cancel()
                    waited_time = time.time() - start_time
                    if waited_time > self.timeout:
                        raise_timeout_error()
        return raw_data

    def _get_next_batch(self):
        ret = []
        start_time = time.time()
        while len(ret) < self.sampler.batch_size:
            raw_data = self._try_get_raw_data(start_time)
            ret.append(self.transform.apply(raw_data))
        return self.collator.apply(ret)


class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter, _ParallelDataLoaderIter):
    def __init__(self, loader, preload):
        _BaseStreamDataLoaderIter.__init__(self, loader, preload)
        _ParallelDataLoaderIter.__init__(self)

    def _get_remaind_data(self, place_holder):
        num = self.sampler.batch_size
        for _ in range(num - 1):
            place_holder.append(next(self.dataset_iter))
        return place_holder

    def _try_put_index(self):
        try:
            if self.parallel_stream is False:
                # fetch the raw batch in the main process and ship it to a worker
                start_time = time.time()
                place_holder = [next(self.dataset_iter)]
                waited_time = time.time() - start_time
                if self.timeout > 0 and waited_time > self.timeout:
                    raise_timeout_error()
                place_holder = self._get_remaind_data(place_holder)
            else:
                place_holder = next(self._sampler_iter)
        except StopIteration:
            return

        for _ in range(self.num_workers):
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            return

        self._index_queues[worker_queue_idx].put((self._send_idx, place_holder))
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._send_idx += 1


class ManagerWatchdog(object):
    def __init__(self):
        self.manager_pid = os.getppid()
        self.manager_dead = False

    def is_alive(self):
        if not self.manager_dead:
            self.manager_dead = os.getppid() != self.manager_pid
        return not self.manager_dead


def stream_fetcher(
    dataset_iter, place_holder, transform, collate, parallel_stream, batch_size
):
    data = []
    for idx in place_holder:
        try:
            if parallel_stream is False:
                raw_data = idx
            else:
                raw_data = next(dataset_iter)
            trans_items = transform.apply(raw_data)
            data.append(trans_items)
        except StopIteration:
            break
    if len(data) == 0:
        raise StopIteration
    data = collate.apply(data)
    return data


def map_fetcher(dataset, place_holder, transform, collate, parallel_stream, batch_size):
    items = [dataset[idx] for idx in place_holder]
    trans_items = transform.apply_batch(items)
    data = collate.apply(trans_items)
    return data


def _worker_loop(
    dataset,
    index_queue,
    data_queue,
    done_event,
    transform,
    collate,
    batch_size,
    seed,
    worker_id,
    num_workers,
    datakind,
    parallel_stream,
):
    random.seed(seed)
    np.random.seed(seed)
    watchdog = ManagerWatchdog()
    iteration_end = False

    fetcher = map_fetcher
    if datakind == "stream":
        global _worker_info
        _worker_info = WorkerInfo(idx=worker_id, worker=num_workers, seed=seed)
        dataset = iter(dataset)
        fetcher = stream_fetcher

    while watchdog.is_alive():
        try:
            r = index_queue.get(timeout=GLOBAL_TIMEOUT)
        except queue.Empty:
            continue
        if r is None:
            assert done_event.is_set() or iteration_end
            break
        elif done_event.is_set() or iteration_end:
            continue
        idx, place_holder = r
        try:
            data = fetcher(
                dataset, place_holder, transform, collate, parallel_stream, batch_size
            )
        except Exception as e:
            if isinstance(e, StopIteration) and datakind == "stream":
                # signal stream exhaustion by sending the worker id back to the
                # main process (checked via ``isinstance(data, int)`` there)
                data = worker_id
                iteration_end = True
            else:
                raise e
        data_queue.put((idx, data))
        del data, idx, place_holder, r
    if done_event.is_set():
        data_queue.disconnect_client()
        data_queue.close()