
dataloader.py 24 kB

# -*- coding: utf-8 -*-
import collections
import gc
import itertools
import multiprocessing
import os
import platform
import queue
import random
import threading
import time

import numpy as np

from ..device import _sh, get_default_device
from ..functional.tensor import copy
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from ..tensor import Tensor
from .collator import Collator
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

try:
    import thread
except:
    import _thread as thread

logger = get_logger(__name__)

GLOBAL_TIMEOUT = 5


def _raise_timeout_error():
    raise RuntimeError("dataloader timeout")

class DataLoader:
    r"""Provides a convenient way to iterate on a given dataset.

    The process is as follows:

    .. mermaid::
       :align: center

       flowchart LR
          Dataset.__len__ -- Sampler --> Indices
          batch_size -- Sampler --> Indices
          Indices -- Dataset.__getitem__ --> Samples
          Samples -- Transform + Collator --> mini-batch

    DataLoader combines a :class:`~.Dataset` with
    :class:`~.Sampler`, :class:`~.Transform` and :class:`~.Collator`,
    making it flexible to continually get mini-batches from a dataset.

    See :ref:`data-guide` for more details.

    Args:
        dataset: dataset from which to load the minibatch.
        sampler: defines the strategy to sample data from the dataset.
            If ``None``, it will sequentially sample from the dataset one by one.
        transform: defines the transforming strategy for a sampled batch.
        collator: defines the merging strategy for a transformed batch.
        num_workers: the number of sub-processes used to load, transform and collate
            the batch. ``0`` means using a single process. Default: 0
        timeout: if positive, the timeout value (in seconds) for collecting a
            batch from workers. Default: 0
        preload: whether to enable the preloading strategy of the dataloader.
            When enabled, the dataloader will preload one batch to device memory to speed up the whole training process.
        parallel_stream: whether the workers fetch data from a :class:`~.StreamDataset` in parallel.
            If ``False``, the raw stream data is fetched in the main process and only
            transformed and collated in the workers. Default: False

    .. admonition:: The effect of enabling preload
        :class: warning

        * All elements in :class:`map`, :class:`list`, and :class:`tuple` will be converted to :class:`~.Tensor` by preloading,
          and you will get :class:`~.Tensor` instead of the original Numpy array or Python built-in data structure.
        * Tensors' host2device copy and device kernel execution will be overlapped,
          which will improve the training speed at the cost of **higher device memory usage** (due to one extra batch of data in device memory).
          This feature saves more time when your NN training time is short or your machine's host PCIe bandwidth for each device is low.
    """
    def __init__(
        self,
        dataset: Dataset,
        sampler: Sampler = None,
        transform: Transform = None,
        collator: Collator = None,
        num_workers: int = 0,
        timeout: int = 0,
        preload: bool = False,
        parallel_stream: bool = False,
    ):
        if num_workers < 0:
            raise ValueError("num_workers should not be negative")
        if timeout < 0:
            raise ValueError("timeout should not be negative")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout
        self.preload = preload
        self.parallel_stream = parallel_stream

        if isinstance(dataset, StreamDataset):
            self.sampler = sampler if sampler else StreamSampler(batch_size=1)
            assert isinstance(
                self.sampler, StreamSampler
            ), "types of dataset and sampler do not match"
            if parallel_stream is False and self.num_workers > 1:
                logger.warning(
                    "Data loading time will be affected by fetching the raw data in the main process; "
                    "set parallel_stream=True to speed up the dataloader!"
                )
            self.datakind = "stream"
        else:
            assert isinstance(
                dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(dataset)
            self.sampler = (
                sampler
                if sampler
                else SequentialSampler(dataset, batch_size=1, drop_last=False)
            )
            assert isinstance(
                self.sampler, MapSampler
            ), "types of dataset and sampler do not match"
            self.datakind = "map"

        if transform is None:
            self.transform = PseudoTransform()
        else:
            self.transform = transform

        if collator is None:
            self.collator = Collator()
        else:
            self.collator = collator

        if platform.system() == "Linux" and self.num_workers > 0:
            self.check_memory_rationality()

    def __iter__(self):
        if platform.system() == "Windows" and self.num_workers > 0:
            print(
                "pyarrow.plasma does not support ParallelDataLoader on Windows, changing num_workers to zero"
            )
            self.num_workers = 0

        if os.getenv("TERMUX_VERSION"):
            # FIXME: building pyarrow on termux currently fails
            # remove this logic after pyarrow fixes this issue
            print(
                "pyarrow is not supported in the termux environment now, changing num_workers to zero"
            )
            self.num_workers = 0

        if isinstance(self.dataset, StreamDataset):
            if not self.num_workers:
                return _SerialStreamDataLoaderIter(self, self.preload)
            else:
                return _ParallelStreamDataLoaderIter(self, self.preload)
        else:
            assert isinstance(
                self.dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
            if not self.num_workers:
                return _SerialMapDataLoaderIter(self, self.preload)
            else:
                return _ParallelMapDataLoaderIter(self, self.preload)

    def __len__(self):
        return len(self.sampler)

    def check_memory_rationality(self):
        import psutil

        main_memory = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
        total_memory = (self.num_workers + 1) * main_memory
        current_memory = (
            int(os.popen("cat /sys/fs/cgroup/memory/memory.limit_in_bytes").read())
            / 1024
            / 1024
            / 1024
        )
        if current_memory < total_memory:
            logger.warning(
                "Each worker needs to read the shared meta-data, which will increase the reference count. "
                "The Copy-On-Write property will lead to a 'memory leak'; the memory usage will end up being "
                + str(total_memory)
                + " GB, "
                "however the currently requested memory is "
                + str(current_memory)
                + " GB. "
                "Maybe you can request more memory or use a numpy array to store the meta-data rather than a list or tuple."
            )


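# _PreLoader holds the per-iterator state shared by all iterator implementations and
# implements the optional preloading scheme: the next batch is staged as tensors on
# ``pre_load_device`` (the default device plus a slot suffix from ``_sh.get_next()``)
# so the host-to-device copy can overlap with computation, and ``_swap_out_cache``
# moves it to the default device when the batch is actually consumed.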
class _PreLoader:
    def __init__(self, loader, preload):
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.seed = _random_seed_generator().__next__()
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.num_processed = 0
        self.datakind = loader.datakind
        self.parallel_stream = loader.parallel_stream

        if preload:
            self.default_device = get_default_device()
            self.pre_load_device = self.default_device + ":" + str(_sh.get_next())
            self.pre_load_device_cache = None
        self.preload = preload

    def __iter__(self):
        return self

    # strategy one: load from numpy data and create tensors of the corresponding dtype
    def _load_tensor(self, batch, cached=True):
        if isinstance(batch, np.ndarray):
            device = self.pre_load_device if cached else self.default_device
            return Tensor(batch, device=device)
        elif isinstance(batch, collections.abc.Mapping):
            return {k: self._load_tensor(v, cached) for k, v in batch.items()}
        elif isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
            return type(batch)(*(self._load_tensor(value, cached) for value in batch))
        elif isinstance(batch, collections.abc.Sequence):
            return [self._load_tensor(value, cached) for value in batch]
        else:
            return batch

    # strategy two: the cache already holds tensors, so only a device-to-device copy is needed
    def _load_cache(self, data):
        if isinstance(data, Tensor):
            if data.device == self.default_device:
                return data
            return copy(data, device=self.default_device)
        elif isinstance(data, collections.abc.Mapping):
            return {k: self._load_cache(v) for k, v in data.items()}
        elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
            return type(data)(*(self._load_cache(value) for value in data))
        elif isinstance(data, collections.abc.Sequence):
            return [self._load_cache(value) for value in data]
        else:
            return data

    def _swap_out_cache(self):
        out = self._load_cache(self.pre_load_device_cache)
        self.pre_load_device_cache = None  # clean cache
        return out


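# Mixin implementing the multi-process machinery: each worker process gets its own
# index queue, and results come back through a shared PlasmaShmQueue. Tasks (indices
# or raw stream data) are dispatched with ``_try_put_index`` and reassembled in
# submission order by ``_get_next_batch`` using ``_task_info`` and ``_rcvd_idx``.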
class _ParallelDataLoaderIter:
    def __init__(self):
        self._worker_queue_idx_cycle = itertools.cycle(range(self.num_workers))

        from .tools._queue import PlasmaShmQueue

        self._worker_result_queue = PlasmaShmQueue()

        self._shutdown = False
        self._workers_done_event = multiprocessing.Event()

        self._index_queues = []
        self._workers = []

        for i in range(self.num_workers):
            index_queue = multiprocessing.Queue()
            index_queue.cancel_join_thread()
            w = multiprocessing.Process(
                target=_worker_loop,
                args=(
                    self.dataset,
                    index_queue,
                    self._worker_result_queue,
                    self._workers_done_event,
                    self.transform,
                    self.collator,
                    self.sampler.batch_size,
                    self.seed + i,
                    i,
                    self.num_workers,
                    self.datakind,
                    self.parallel_stream,
                ),
                daemon=True,
            )
            gc.collect()
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)

        self._data_queue = self._worker_result_queue
        self._reset()

    def _try_put_index(self):
        raise NotImplementedError

    def _reset(self):
        self._sampler_iter = iter(self.sampler)
        self._send_idx = 0
        self._rcvd_idx = 0
        self._task_info = {}
        self._workers_status = [True for _ in range(self.num_workers)]
        for _ in range(2 * self.num_workers):
            self._try_put_index()

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        return data

    def _get_data(self):
        if self.timeout > 0:
            success, data = self._try_get_data(self.timeout)
            if success:
                return data
            else:
                _raise_timeout_error()
        else:
            while True:
                success, data = self._try_get_data()
                if success:
                    return data

    def _get_next_batch(self):
        while True:
            while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                if (
                    len(info) == 2 or self._workers_status[worker_id]
                ):  # has data or worker is still active
                    break
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                self._shutdown_workers()
                raise StopIteration

            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            idx, data = self._get_data()

            if isinstance(data, int):  # check if StopIteration in StreamDataset
                self._mark_worker_as_unavailable(data)
                self._try_put_index()
                continue

            if idx != self._rcvd_idx:
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx]
                return self._process_data(data)

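    # Pull one result from the shared result queue. If the get times out, check
    # whether any worker died (e.g. killed by the OOM killer, exitcode -9) and turn
    # that into an explicit RuntimeError instead of hanging forever.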
    def _try_get_data(self, timeout=GLOBAL_TIMEOUT):
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            failed_workers = []
            for worker_id, w in enumerate(self._workers):
                if self._workers_status[worker_id] and not w.is_alive():
                    failed_workers.append((worker_id, w))
                    self._mark_worker_as_unavailable(worker_id)
                    if w.exitcode == -9:
                        logger.debug(
                            "Maybe there is not enough memory; please request more memory!"
                        )
            if len(failed_workers) > 0:
                pids_str = ", ".join(str(w_info[1].pid) for w_info in failed_workers)
                w_ids_str = ", ".join(str(w_info[0]) for w_info in failed_workers)
                exitcode_str = ", ".join(
                    str(w_info[1].exitcode) for w_info in failed_workers
                )
                raise RuntimeError(
                    "DataLoader worker (worker(s): {}, pid(s): {}) exited unexpectedly, exitcode(s): {}".format(
                        w_ids_str, pids_str, exitcode_str
                    )
                )
            if isinstance(e, queue.Empty):
                return (False, None)

    def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
        q = self._index_queues[worker_id]
        q.put(None)
        self._workers_status[worker_id] = False
        assert self._workers_done_event.is_set() == shutdown

    def _shutdown_workers(self):
        if not self._shutdown:
            self._shutdown = True
            try:
                self._workers_done_event.set()
                for worker_id in range(len(self._workers)):
                    if self._workers_status[worker_id]:
                        self._mark_worker_as_unavailable(worker_id, shutdown=True)
                for w in self._workers:
                    w.join(timeout=GLOBAL_TIMEOUT)
                for q in self._index_queues:
                    q.cancel_join_thread()
                    q.close()
                self._data_queue.cancel_join_thread()
                self._data_queue.close()
            finally:
                for w in self._workers:
                    if w.is_alive():
                        w.terminate()

    def __del__(self):
        self._shutdown_workers()


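# Iterators over map-style datasets: a batch is produced by drawing indices from the
# sampler, reading ``dataset[idx]`` for each index, then applying the transform and
# the collator.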
class _BaseMapDataLoaderIter(_PreLoader):
    def __init__(self, loader, preload):
        super().__init__(loader, preload)

    def __len__(self):
        return len(self.sampler)

    def __next__(self):
        if self.preload:
            cached = self.pre_load_device_cache
            if cached is None:  # first and last
                if self.num_processed >= len(self):  # last
                    raise StopIteration
                elif self.num_processed == 0:  # first
                    self._try_load_tensor(cached=False)  # first do the h2d
            out = self._swap_out_cache()
            self._try_load_tensor()
            return out
        else:
            data = self._get_next_batch()
            return data

    def _try_load_tensor(self, cached=True):
        if self.num_processed >= len(self):
            return
        else:
            self.num_processed += 1
            batch = self._get_next_batch()
            self.pre_load_device_cache = self._load_tensor(batch, cached)


class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
    def __init__(self, loader, preload):
        super(_SerialMapDataLoaderIter, self).__init__(loader, preload)
        self._sampler_iter = iter(self.sampler)

    def _get_next_batch(self):
        indices = next(self._sampler_iter)
        items = [self.dataset[idx] for idx in indices]
        trans_items = self.transform.apply_batch(items)
        return self.collator.apply(trans_items)


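# Parallel map-style iterator: index lists from the sampler are handed out round-robin
# to the worker index queues; ``_task_info`` remembers which worker owns which batch so
# results can be returned in submission order.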
class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter, _ParallelDataLoaderIter):
    def __init__(self, loader, preload):
        _BaseMapDataLoaderIter.__init__(self, loader, preload)
        _ParallelDataLoaderIter.__init__(self)

    def _try_put_index(self):
        try:
            index = next(self._sampler_iter)
        except StopIteration:
            return

        for _ in range(self.num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break

        self._index_queues[worker_queue_idx].put((self._send_idx, index))
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._send_idx += 1


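# Per-worker bookkeeping for stream datasets. Inside a worker process,
# ``get_worker_info()`` returns a read-only ``WorkerInfo(idx=..., worker=..., seed=...)``
# record (set in ``_worker_loop``), which a ``StreamDataset`` can use to shard its
# stream across workers; in the main process it returns ``None``.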
_worker_info = None


class WorkerInfo(object):
    __initialized = False

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.__keys = tuple(kwargs.keys())
        self.__initialized = True

    def __setattr__(self, key, val):
        if self.__initialized:
            raise RuntimeError(
                "Cannot assign attributes to {} objects".format(self.__class__.__name__)
            )
        return super(WorkerInfo, self).__setattr__(key, val)

    def __repr__(self):
        items = []
        for k in self.__keys:
            items.append("{}={}".format(k, getattr(self, k)))
        return "{}({})".format(self.__class__.__name__, ", ".join(items))


def get_worker_info():
    return _worker_info


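# Iterators over stream-style datasets: samples come from ``iter(dataset)`` directly,
# so there is no sampler-driven indexing and the iterator has no defined length.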
class _BaseStreamDataLoaderIter(_PreLoader):
    def __init__(self, loader, preload):
        super().__init__(loader, preload)
        self.dataset_iter = iter(self.dataset)

    def __next__(self):
        if self.preload:
            if self.pre_load_device_cache is None:
                self._try_load_tensor(cached=False)  # load in current
            out = self._swap_out_cache()
            self._try_load_tensor()  # load in cached
            return out
        else:
            return self._get_next_batch()

    def _try_load_tensor(self, cached=True):
        batch = self._get_next_batch()
        self.pre_load_device_cache = self._load_tensor(batch, cached)


class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    def __init__(self, loader, preload):
        super().__init__(loader, preload)
        self.dataset_iter = iter(self.dataset)

    def _try_get_raw_data(self, start_time):
        raw_data = None
        while not raw_data:
            try:
                if self.timeout > 0:
                    timer = threading.Timer(self.timeout, thread.interrupt_main)
                    timer.start()
                raw_data = next(self.dataset_iter)
                if self.timeout > 0:
                    timer.cancel()
            except AttributeError as error:
                raise error
            except:
                if self.timeout > 0:
                    timer.cancel()
                waited_time = time.time() - start_time
                if waited_time > self.timeout:
                    _raise_timeout_error()
        return raw_data

    def _get_next_batch(self):
        ret = []
        start_time = time.time()
        while len(ret) < self.sampler.batch_size:
            raw_data = self._try_get_raw_data(start_time)
            ret.append(self.transform.apply(raw_data))
        return self.collator.apply(ret)


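# Parallel stream iterator. With ``parallel_stream=False`` the main process pulls the
# raw samples itself and ships one batch worth of data to a worker as the place holder;
# with ``parallel_stream=True`` only sampler indices are shipped and each worker pulls
# from its own ``iter(dataset)``, so raw-data fetching happens in parallel.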
class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter, _ParallelDataLoaderIter):
    def __init__(self, loader, preload):
        _BaseStreamDataLoaderIter.__init__(self, loader, preload)
        _ParallelDataLoaderIter.__init__(self)

    def _get_remaining_data(self, place_holder):
        num = self.sampler.batch_size
        for _ in range(num - 1):
            place_holder.append(next(self.dataset_iter))
        return place_holder

    def _try_put_index(self):
        try:
            if self.parallel_stream is False:
                start_time = time.time()
                place_holder = [next(self.dataset_iter)]
                waited_time = time.time() - start_time
                if self.timeout > 0 and waited_time > self.timeout:
                    _raise_timeout_error()
                place_holder = self._get_remaining_data(place_holder)
            else:
                place_holder = next(self._sampler_iter)
        except StopIteration:
            return

        for _ in range(self.num_workers):
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            return

        self._index_queues[worker_queue_idx].put((self._send_idx, place_holder))
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._send_idx += 1


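# Used inside worker processes to detect that the parent (manager) process has gone
# away: once ``os.getppid()`` no longer matches the pid recorded at start-up, the
# worker loop exits.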
class ManagerWatchdog(object):
    def __init__(self):
        self.manager_pid = os.getppid()
        self.manager_dead = False

    def is_alive(self):
        if not self.manager_dead:
            self.manager_dead = os.getppid() != self.manager_pid
        return not self.manager_dead


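# Fetchers run inside worker processes. For stream datasets, ``place_holder`` is either
# the raw samples themselves (``parallel_stream=False``) or dummy indices that only
# determine how many samples to pull from the worker-local ``dataset_iter``.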
def stream_fetcher(
    dataset_iter, place_holder, transform, collate, parallel_stream, batch_size
):
    data = []
    for idx in place_holder:
        try:
            if parallel_stream is False:
                raw_data = idx
            else:
                raw_data = next(dataset_iter)
            trans_items = transform.apply(raw_data)
            data.append(trans_items)
        except StopIteration:
            break

    if len(data) == 0:
        raise StopIteration
    data = collate.apply(data)
    return data


def map_fetcher(dataset, place_holder, transform, collate, parallel_stream, batch_size):
    items = [dataset[idx] for idx in place_holder]
    trans_items = transform.apply_batch(items)
    data = collate.apply(trans_items)
    return data


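# Worker process main loop: seed the RNGs, read (idx, place_holder) tasks from the
# index queue, build a batch with the appropriate fetcher, and push (idx, batch) to the
# shared result queue. For stream datasets a StopIteration is reported by sending the
# worker's id (an int) instead of a batch, which the main process interprets as
# "this worker is exhausted".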
def _worker_loop(
    dataset,
    index_queue,
    data_queue,
    done_event,
    transform,
    collate,
    batch_size,
    seed,
    worker_id,
    num_workers,
    datakind,
    parallel_stream,
):
    random.seed(seed)
    np.random.seed(seed)
    watchdog = ManagerWatchdog()
    iteration_end = False

    fetcher = map_fetcher
    if datakind == "stream":
        global _worker_info
        _worker_info = WorkerInfo(idx=worker_id, worker=num_workers, seed=seed)
        dataset = iter(dataset)
        fetcher = stream_fetcher

    while watchdog.is_alive():
        try:
            r = index_queue.get(timeout=GLOBAL_TIMEOUT)
        except queue.Empty:
            continue

        if r is None:
            assert done_event.is_set() or iteration_end
            break
        elif done_event.is_set() or iteration_end:
            continue

        idx, place_holder = r
        try:
            data = fetcher(
                dataset, place_holder, transform, collate, parallel_stream, batch_size
            )
        except Exception as e:
            if isinstance(e, StopIteration) and datakind == "stream":
                data = worker_id
                iteration_end = True
            else:
                raise e
        data_queue.put((idx, data))
        del data, idx, place_holder, r

    if done_event.is_set():
        data_queue.disconnect_client()
        data_queue.close()