dataloader.py 30 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import gc
import math
import multiprocessing
import platform
import queue
import random
import threading
import time
from typing import Callable, Union

import numpy as np

from ..device import _sh, get_default_device
from ..functional.tensor import copy
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from ..tensor import Tensor
from .collator import Collator
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

try:
    import thread
except ImportError:
    import _thread as thread

logger = get_logger(__name__)

GLOBAL_TIMEOUT = 5


def raise_timeout_error():
    raise RuntimeError("dataloader timeout")


class DataLoader:
    r"""Provides a convenient way to iterate on a given dataset.

    DataLoader combines a dataset with
    :class:`~.Sampler`, :class:`~.Transform` and :class:`~.Collator`,
    making it flexible to get minibatches continually from a dataset.

    Args:
        dataset: dataset from which to load the minibatch.
        sampler: defines the strategy to sample data from the dataset.
        transform: defines the transforming strategy for a sampled batch.
            Default: None
        collator: defines the merging strategy for a transformed batch.
            Default: None
        num_workers: the number of sub-processes to load, transform and collate
            the batch. ``0`` means using a single process. Default: 0
        timeout: if positive, the timeout value (in seconds) for collecting a
            batch from workers. Default: 0
        timeout_event: callback function triggered by timeout; defaults to
            raising a runtime error.
        divide: defines the parallelism strategy in multi-processing mode.
            ``True`` means one batch is divided into :attr:`num_workers` pieces,
            and the workers process these pieces in parallel. ``False`` means
            different sub-processes process different batches. Default: False
        preload: defines whether to apply the preloading strategy of the
            dataloader, which parallelizes the host-to-device copy with kernel
            execution to improve loading speed. Default: False. With preloading
            enabled, the output changes from np.ndarray to a tensor of the
            corresponding dtype; the supported types for preload are int, float,
            list[int, float] and tuple[int, float]; other types are not supported.
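
    Examples:
        A minimal usage sketch; the random-array dataset below is illustrative,
        assuming :class:`~.ArrayDataset` and :class:`~.RandomSampler` from
        ``megengine.data``:

        .. code-block:: python

            import numpy as np

            from megengine.data import DataLoader, RandomSampler
            from megengine.data.dataset import ArrayDataset

            data = np.random.random((100, 3, 32, 32)).astype(np.float32)
            label = np.random.randint(0, 10, (100,)).astype(np.int32)
            dataset = ArrayDataset(data, label)  # wraps in-memory arrays as a map-style Dataset
            sampler = RandomSampler(dataset, batch_size=8)
            dataloader = DataLoader(dataset, sampler=sampler)
            for batch_data, batch_label in dataloader:
                pass  # each iteration yields one collated minibatch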
  61. """
    __initialized = False

    def __init__(
        self,
        dataset: Dataset,
        sampler: Sampler = None,
        transform: Transform = None,
        collator: Collator = None,
        num_workers: int = 0,
        timeout: int = 0,
        timeout_event: Callable = raise_timeout_error,
        divide: bool = False,
        preload: bool = False,
    ):
        if num_workers < 0:
            raise ValueError("num_workers should not be negative")
        if timeout < 0:
            raise ValueError("timeout should not be negative")
        if divide and num_workers <= 1:
            raise ValueError("divide should not be set to True when num_workers <= 1")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout
        self.timeout_event = timeout_event
        self.divide = divide
        self.preload = preload

        if isinstance(dataset, StreamDataset):
            self.sampler = sampler if sampler else StreamSampler(batch_size=1)
            assert isinstance(
                self.sampler, StreamSampler
            ), "types of dataset and sampler do not match"
        else:
            assert isinstance(
                dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(dataset)
            self.sampler = (
                sampler
                if sampler
                else SequentialSampler(dataset, batch_size=1, drop_last=False)
            )
            assert isinstance(
                self.sampler, MapSampler
            ), "types of dataset and sampler do not match"

        if divide:
            if self.sampler.batch_size <= self.num_workers:
                raise ValueError(
                    "batch size must not be smaller than num_workers in divide mode."
                )
            elif self.sampler.batch_size % self.num_workers:
                logger.warning(
                    "batch size is not divisible by num_workers, may lose performance in divide mode."
                )

        if transform is None:
            self.transform = PseudoTransform()
        else:
            self.transform = transform

        if collator is None:
            self.collator = Collator()
        else:
            self.collator = collator

        self.__initialized = True

    def __iter__(self):
        if platform.system() == "Windows" and self.num_workers > 0:
            print(
                "pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to zero"
            )
            self.num_workers = 0
        if isinstance(self.dataset, StreamDataset):
            if not self.num_workers:
                return _SerialStreamDataLoaderIter(self, self.preload)
            else:
                return _ParallelStreamDataLoaderIter(self, self.preload)
        else:
            assert isinstance(
                self.dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
            if not self.num_workers:
                return _SerialMapDataLoaderIter(self, self.preload)
            else:
                return _ParallelMapDataLoaderIter(self, self.preload)

    def __len__(self):
        return len(self.sampler)


class PreLoader:
    def __init__(self, preload):
        if preload:
            self.default_device = get_default_device()
            self.pre_load_device = self.default_device + ":" + str(_sh.get_next())
            self.pre_load_device_cache = None
        self.preload = preload

    """
    strategy one: load from numpy data, and generate tensors of the proper dtype
    """

    def _load_tensor(self, batch, cached=True):
        if isinstance(batch, np.ndarray):
            device = self.pre_load_device if cached else self.default_device
            return Tensor(batch, device=device)
        elif isinstance(batch, collections.abc.Mapping):
            return {k: self._load_tensor(v, cached) for k, v in batch.items()}
        elif isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
            return type(batch)(*(self._load_tensor(value, cached) for value in batch))
        elif isinstance(batch, collections.abc.Sequence):
            return [self._load_tensor(value, cached) for value in batch]
        else:
            return batch

    """
    strategy two: load from a cache that already holds tensors, just do a d2d copy
    """

    def _load_cache(self, data):
        if isinstance(data, Tensor):
            if data.device == self.default_device:
                return data
            return copy(data, device=self.default_device)
        elif isinstance(data, collections.abc.Mapping):
            return {k: self._load_cache(v) for k, v in data.items()}
        elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
            return type(data)(*(self._load_cache(value) for value in data))
        elif isinstance(data, collections.abc.Sequence):
            return [self._load_cache(value) for value in data]
        else:
            return data

    def _swap_out_cache(self):
        out = self._load_cache(self.pre_load_device_cache)
        self.pre_load_device_cache = None  # clean the cache
        return out
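

# With ``preload=True`` the iterator keeps a one-batch lookahead:
# ``pre_load_device_cache`` holds the next batch already uploaded to the side
# device ``pre_load_device`` (strategy one above), and ``_swap_out_cache``
# hands it to the user via a device-to-device copy onto the default device
# (strategy two). This overlaps the host-to-device transfer of batch N+1 with
# computation on batch N.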


class _BaseMapDataLoaderIter(PreLoader):
    def __init__(self, loader, preload):
        super().__init__(preload)
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.seed = _random_seed_generator().__next__()
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event
        self.divide = loader.divide
        self.num_processed = 0

    def _get_next_batch(self):
        raise NotImplementedError

    def __len__(self):
        return len(self.sampler)

    def __iter__(self):
        return self

    def __next__(self):
        if self.preload:
            cached = self.pre_load_device_cache
            if cached is None:  # first and last
                if self.num_processed >= len(self):  # last
                    raise StopIteration
                elif self.num_processed == 0:  # first
                    self._try_load_tensor(cached=False)  # first do the h2d
            out = self._swap_out_cache()
            self._try_load_tensor()
            return out
        else:
            if self.num_processed >= len(self):
                raise StopIteration
            minibatch = self._get_next_batch()
            self.num_processed += 1
            return minibatch

    def _try_load_tensor(self, cached=True):
        if self.num_processed >= len(self):
            return
        else:
            self.num_processed += 1
            batch = self._get_next_batch()
            self.pre_load_device_cache = self._load_tensor(batch, cached)


class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
    def __init__(self, loader, preload):
        super(_SerialMapDataLoaderIter, self).__init__(loader, preload)
        self.indices_iter = iter(self.sampler)

    def _get_next_batch(self):
        indices = next(self.indices_iter)
        items = [self.dataset[idx] for idx in indices]
        trans_items = self.transform.apply_batch(items)
        return self.collator.apply(trans_items)


class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
    __initialized = False

    def __init__(self, loader, preload):
        super(_ParallelMapDataLoaderIter, self).__init__(loader, preload)

        self.task_queues = [
            multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
        ]

        self.feed_batch_idx = multiprocessing.Value("i", 0)
        self.target_batch_idx = multiprocessing.Value("i", 0)
        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # use shared-memory queue implemented by pyarrow plasma store.
        from .tools._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.task_feeding_worker = multiprocessing.Process(
            target=_task_feeding_loop,
            args=(
                iter(self.sampler),
                self.task_queues,
                self.num_workers,
                self.divide,
                self.shutdown_flag,
                self.feed_batch_idx,
            ),
            daemon=True,
        )
        gc.collect()
        self.task_feeding_worker.start()

        self.workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=_worker_loop,
                args=(
                    self.dataset,
                    self.task_queues[worker_id],
                    self.trans_data_queues[worker_id],
                    self.transform,
                    self.seed + worker_id + 1,
                    self.shutdown_flag,
                ),
                daemon=True,
            )
            gc.collect()
            worker.start()
            self.workers.append(worker)

        if self.divide:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_gathering_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        else:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_selecting_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        gc.collect()
        self.data_collecting_worker.start()

        self.__initialized = True

    def _check_workers(self):
        # Check the status of each worker.
        if not self.data_collecting_worker.is_alive():
            exitcode = self.data_collecting_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("data collecting worker died. {}".format(exitcode))

        if not self.task_feeding_worker.is_alive():
            exitcode = self.task_feeding_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("task feeding worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))

        logger.debug("all workers are alive.")

    def _get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0:
                if waited_time > self.timeout:
                    raise RuntimeError("get_next_batch timeout!")

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.task_feeding_worker.is_alive():
            self.task_feeding_worker.terminate()
        self.task_feeding_worker.join()

        if self.data_collecting_worker.is_alive():
            self.data_collecting_worker.terminate()
        self.data_collecting_worker.join()

        for worker in self.workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.task_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()


class _BaseStreamDataLoaderIter(PreLoader):
    def __init__(self, loader, preload):
        super().__init__(preload)
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event

    def _get_next_batch(self):
        raise NotImplementedError
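
    # _process_raw_data (below) normalizes StreamDataset output into a list of
    # per-sample tuples; illustrative shapes:
    #   (False, (img, label))            -> [(img, label)]
    #   (True, ([img0, img1], [l0, l1])) -> [(img0, l0), (img1, l1)]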
    def _process_raw_data(self, raw_data):
        assert len(raw_data) == 2 and isinstance(
            raw_data[0], bool
        ), "StreamDataset should provide a two-item tuple whose first item indicates whether the data is already batched."
        if not raw_data[0]:
            data = list((x,) for x in raw_data[1])
        else:
            data = raw_data[1]
        ret = []
        for idx in range(len(data[0])):
            ret.append(tuple(e[idx] for e in data))
        return ret

    def __iter__(self):
        return self

    def __next__(self):
        if self.preload:
            if self.pre_load_device_cache is None:
                self._try_load_tensor(cached=False)  # load on the current device
            out = self._swap_out_cache()
            self._try_load_tensor()  # load into the cache
            return out
        else:
            return self._get_next_batch()

    def _try_load_tensor(self, cached=True):
        batch = self._get_next_batch()
        self.pre_load_device_cache = self._load_tensor(batch, cached)


class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    def __init__(self, loader, preload):
        super().__init__(loader, preload)
        self.dataset_iter = iter(self.dataset)
        self.idx = 0
        self.unused = []

    def _try_get_raw_data(self, start_time):
        raw_data = None
        while not raw_data:
            try:
                if self.timeout > 0:
                    timer = threading.Timer(self.timeout, thread.interrupt_main)
                    timer.start()
                raw_data = next(self.dataset_iter)
                if self.timeout > 0:
                    timer.cancel()
            except KeyboardInterrupt:
                raw_data = self.timeout_event()
            except:
                if self.timeout > 0:
                    timer.cancel()
                    waited_time = time.time() - start_time
                    if waited_time > self.timeout:
                        raw_data = self.timeout_event()
        return raw_data

    def _get_next_batch(self):
        ret = []
        start_time = time.time()
        while len(ret) < self.sampler.batch_size:
            if len(self.unused) != 0:
                batch_data = self.unused
            else:
                raw_data = self._try_get_raw_data(start_time)
                batch_data = self._process_raw_data(raw_data)
            while len(batch_data) != 0 and len(ret) < self.sampler.batch_size:
                data = batch_data.pop()
                ret.append(self.transform.apply(data))
            self.unused = batch_data
        return self.collator.apply(ret)


class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    __initialized = False

    def __init__(self, loader, preload):
        super().__init__(loader, preload)

        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.raw_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # shared-memory queue implemented by pyarrow plasma store
        from .tools._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.recieve_worker = multiprocessing.Process(
            target=self._worker_to_raw_data_queues, daemon=True
        )
        gc.collect()
        self.recieve_worker.start()

        self.transform_workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True
            )
            gc.collect()
            worker.start()
            self.transform_workers.append(worker)

        self.collect_worker = multiprocessing.Process(
            target=self._worker_to_batch_queue, daemon=True
        )
        gc.collect()
        self.collect_worker.start()

        self.__initialized = True

    def _put_raw_data_queues(self, raw_data, qidx):
        batch_data = self._process_raw_data(raw_data)
        for data in batch_data:
            while True:
                qidx = qidx % self.num_workers
                try:
                    self.raw_data_queues[qidx].put(data)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("raw data queue %d is full" % qidx)
                finally:
                    qidx += 1
        return qidx

    def _worker_to_raw_data_queues(self):
        dataset_iter = iter(self.dataset)
        qidx = 0
        while True:
            if self.shutdown_flag.value == 1:
                break
            raw_data = next(dataset_iter)
            qidx = self._put_raw_data_queues(raw_data, qidx)

    def _worker_to_trans_data_queues(self, worker_id):
        while True:
            if self.shutdown_flag.value == 1:
                break
            try:
                data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT)
            except queue.Empty:
                continue
            trans_data = self.transform.apply(data)
            while True:
                try:
                    self.trans_data_queues[worker_id].put(trans_data)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("batch queue is full")

    def _worker_to_batch_queue(self):
        cnt = -1
        trans_items = []
        while True:
            if self.shutdown_flag.value == 1:
                break
            cnt += 1
            queue_id = cnt % self.num_workers
            try:
                trans_item = self.trans_data_queues[queue_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
            except queue.Empty:
                continue
            trans_items.append(trans_item)
            if len(trans_items) == self.sampler.batch_size:
                batch_data = self.collator.apply(trans_items)
                while True:
                    try:
                        self.batch_queue.put(batch_data, timeout=1)
                        break
                    except queue.Full:
                        if self.shutdown_flag.value == 1:
                            break
                        logger.debug("batch queue is full")
                trans_items = []

    def _check_workers(self):
        if not self.collect_worker.is_alive():
            exitcode = self.collect_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("collator worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.transform_workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError(
                        "worker: {} died. {}".format(worker_id, exitcode)
                    )

    def _get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0 and waited_time > self.timeout:
                self._put_raw_data_queues(self.timeout_event(), 0)

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.recieve_worker.is_alive():
            self.recieve_worker.terminate()
        self.recieve_worker.join()

        if self.collect_worker.is_alive():
            self.collect_worker.terminate()
        self.collect_worker.join()

        for worker in self.transform_workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.raw_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()
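

# _task_feeding_loop pushes (batch_idx, indices) tasks into the per-worker
# task queues. An illustrative example of divide-mode slicing: with
# batch_size=8 and num_workers=3, sub_num = ceil(8 / 3) = 3, so the workers
# receive indices[0:3], indices[3:6] and indices[6:8] of the same batch.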
def _task_feeding_loop(
    indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
):
    # Feed the indices into the task queues
    while True:
        if shutdown_flag.value == 1:
            break
        batch_idx = feed_batch_idx.value
        try:
            indices = next(indices_iter)
        except StopIteration:
            break
        if divide:
            # make sure all task_queues are ready for put
            while any([q.full() for q in task_queues]):
                if shutdown_flag.value == 1:
                    return
            # divide into small pieces, feed to different workers.
            sub_num = math.ceil(len(indices) / num_workers)
            for worker_id in range(num_workers):
                sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
                task_queues[worker_id].put((batch_idx, sub_indices))
        else:
            # distribute tasks to different workers uniformly.
            target_id = batch_idx % num_workers
            while task_queues[target_id].full():
                if shutdown_flag.value == 1:
                    return
            task_queues[target_id].put((batch_idx, indices))
        with feed_batch_idx.get_lock():
            feed_batch_idx.value += 1


def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
    # Get dataset items and do the transform
    random.seed(seed)
    np.random.seed(seed)
    while True:
        if shutdown_flag.value == 1:
            break
        try:
            batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT)
        except queue.Empty:
            continue
        if len(indices) > 0:
            items = [dataset[idx] for idx in indices]
            trans_items = transform.apply_batch(items)
        else:
            # in case of incomplete last batch
            trans_items = ()
        while True:
            try:
                trans_data_queue.put((batch_idx, trans_items), timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch part queue is full!")
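

# Each transformed piece carries its batch_idx so the gathering/selecting
# loops below can reassemble batches in exactly the order the sampler
# produced the indices.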
def _data_gathering_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Gather the small pieces of batch data into full batch data
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        full_trans_items = []
        for worker_id in range(num_workers):
            while True:
                try:
                    batch_idx, trans_items = trans_data_queues[worker_id].get(
                        timeout=GLOBAL_TIMEOUT
                    )
                    break
                except queue.Empty:
                    if shutdown_flag.value == 1:
                        break
                    logger.debug(
                        "worker:{} data queue get timeout! target batch idx:{}".format(
                            worker_id, target_batch_idx
                        )
                    )
            if batch_idx != target_batch_idx:
                raise RuntimeError(
                    "Unexpected batch_idx in data gathering loop. worker_id:{}.".format(
                        worker_id
                    )
                )
            else:
                full_trans_items.extend(trans_items)

        # Merge different parts into a whole batch.
        full_batch = collator.apply(full_trans_items)

        while True:
            try:
                batch_queue.put(full_batch, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()


def _data_selecting_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Make sure batches are generated in exactly the same order as the sampled indices
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        target_worker_id = target_batch_idx % num_workers
        while True:
            try:
                batch_idx, trans_items = trans_data_queues[target_worker_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
                batch_data = collator.apply(trans_items)
                break
            except queue.Empty:
                if shutdown_flag.value == 1:
                    break
                logger.debug(
                    "worker:{} data queue get timeout! target batch idx:{}".format(
                        target_worker_id, target_batch_idx
                    )
                )

        if batch_idx != target_batch_idx:
            raise RuntimeError(
                "batch_idx {} mismatches the target_batch_idx {}".format(
                    batch_idx, target_batch_idx
                )
            )

        while True:
            try:
                batch_queue.put(batch_data, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware with the driver installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.