
dataloader.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import math
import multiprocessing
import platform
import queue
import random
import threading
import time
from typing import Callable

import numpy as np

from ..logger import get_logger
from ..random.rng import _random_seed_generator
from .collator import Collator
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

try:
    import thread
except:
    import _thread as thread

logger = get_logger(__name__)

GLOBAL_TIMEOUT = 5


def raise_timeout_error():
    raise RuntimeError("dataloader timeout")


class DataLoader:
    r"""
    Provides a convenient way to iterate on a given dataset.
    """

    __initialized = False

    def __init__(
        self,
        dataset: Dataset,
        sampler: Sampler = None,
        transform: Transform = None,
        collator: Collator = None,
        num_workers: int = 0,
        timeout: int = 0,
        timeout_event: Callable = raise_timeout_error,
        divide: bool = False,
    ):
        r"""
        `DataLoader` combines a dataset with `sampler`, `transform` and `collator`,
        making it flexible to get minibatches continually from a dataset.

        :type dataset: Dataset
        :param dataset: dataset from which to load the minibatch.
        :type sampler: Sampler
        :param sampler: defines the strategy to sample data from the dataset.
        :type transform: Transform
        :param transform: defines the transforming strategy for a sampled batch.
            Default: None
        :type collator: Collator
        :param collator: defines the merging strategy for a transformed batch.
            Default: None
        :type num_workers: int
        :param num_workers: the number of sub-processes used to load, transform and
            collate the batch. ``0`` means using a single process. Default: 0
        :type timeout: int
        :param timeout: if positive, the timeout value (in seconds) for collecting a
            batch from workers. Default: 0
        :type timeout_event: Callable
        :param timeout_event: callback function triggered by timeout; defaults to
            raising a runtime error.
        :type divide: bool
        :param divide: defines the parallelism strategy in multi-processing mode.
            ``True`` means one batch is divided into :attr:`num_workers` pieces, and
            the workers process these pieces in parallel. ``False`` means
            different sub-processes process different batches. Default: False
        """
        if num_workers < 0:
            raise ValueError("num_workers should not be negative")

        if timeout < 0:
            raise ValueError("timeout should not be negative")

        if divide and num_workers <= 1:
            raise ValueError("divide should not be set to True when num_workers <= 1")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout
        self.timeout_event = timeout_event
        self.divide = divide

        if isinstance(dataset, StreamDataset):
            self.sampler = sampler if sampler else StreamSampler(batch_size=1)
            assert isinstance(
                self.sampler, StreamSampler
            ), "types of dataset and sampler do not match"
        else:
            assert isinstance(
                dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(dataset)
            self.sampler = (
                sampler
                if sampler
                else SequentialSampler(dataset, batch_size=1, drop_last=False)
            )
            assert isinstance(
                self.sampler, MapSampler
            ), "types of dataset and sampler do not match"

        if divide:
            if self.sampler.batch_size <= self.num_workers:
                raise ValueError(
                    "batch size must not be smaller than num_workers in divide mode."
                )
            elif self.sampler.batch_size % self.num_workers:
                logger.warning(
                    "batch size is not divisible by num_workers, may lose performance in divide mode."
                )

        if transform is None:
            self.transform = PseudoTransform()
        else:
            self.transform = transform

        if collator is None:
            self.collator = Collator()
        else:
            self.collator = collator

        self.__initialized = True
    def __iter__(self):
        if platform.system() == "Windows" and self.num_workers > 0:
            print(
                "pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero"
            )
            self.num_workers = 0

        if isinstance(self.dataset, StreamDataset):
            if not self.num_workers:
                return _SerialStreamDataLoaderIter(self)
            else:
                return _ParallelStreamDataLoaderIter(self)
        else:
            assert isinstance(
                self.dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
            if not self.num_workers:
                return _SerialMapDataLoaderIter(self)
            else:
                return _ParallelMapDataLoaderIter(self)

    def __len__(self):
        return len(self.sampler)
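
# NOTE: a minimal usage sketch of DataLoader (added for illustration, not part of
# the original module). It assumes a map-style dataset such as ArrayDataset and the
# RandomSampler provided by the upstream megengine.data package:
#
#     import numpy as np
#     from megengine.data import DataLoader, RandomSampler
#     from megengine.data.dataset import ArrayDataset
#
#     data = np.random.random((100, 3, 32, 32)).astype("float32")
#     label = np.random.randint(0, 10, size=(100,)).astype("int32")
#     dataset = ArrayDataset(data, label)
#     sampler = RandomSampler(dataset, batch_size=8, drop_last=True)
#     dataloader = DataLoader(dataset, sampler=sampler, num_workers=0)
#     for batch_data, batch_label in dataloader:
#         pass  # each iteration yields one collated minibatch of 8 samples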


class _BaseMapDataLoaderIter:
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.seed = _random_seed_generator().__next__()
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event
        self.divide = loader.divide
        self.num_processed = 0

    def _get_next_batch(self):
        raise NotImplementedError

    def __len__(self):
        return len(self.sampler)

    def __iter__(self):
        return self

    def __next__(self):
        if self.num_processed >= len(self):
            raise StopIteration
        minibatch = self._get_next_batch()
        self.num_processed += 1
        return minibatch


class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
    def __init__(self, loader):
        super(_SerialMapDataLoaderIter, self).__init__(loader)
        self.indices_iter = iter(self.sampler)

    def _get_next_batch(self):
        indices = next(self.indices_iter)
        items = [self.dataset[idx] for idx in indices]
        trans_items = self.transform.apply_batch(items)
        return self.collator.apply(trans_items)


class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
    __initialized = False

    def __init__(self, loader):
        super(_ParallelMapDataLoaderIter, self).__init__(loader)

        self.task_queues = [
            multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
        ]

        self.feed_batch_idx = multiprocessing.Value("i", 0)
        self.target_batch_idx = multiprocessing.Value("i", 0)
        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # use shared-memory queue implemented by pyarrow plasma store.
        from ._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.task_feeding_worker = multiprocessing.Process(
            target=_task_feeding_loop,
            args=(
                iter(self.sampler),
                self.task_queues,
                self.num_workers,
                self.divide,
                self.shutdown_flag,
                self.feed_batch_idx,
            ),
            daemon=True,
        )
        self.task_feeding_worker.start()

        self.workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=_worker_loop,
                args=(
                    self.dataset,
                    self.task_queues[worker_id],
                    self.trans_data_queues[worker_id],
                    self.transform,
                    self.seed + worker_id + 1,
                    self.shutdown_flag,
                ),
                daemon=True,
            )
            worker.start()
            self.workers.append(worker)

        if self.divide:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_gathering_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        else:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_selecting_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        self.data_collecting_worker.start()

        self.__initialized = True
    def _check_workers(self):
        # Check the status of each worker.
        if not self.data_collecting_worker.is_alive():
            exitcode = self.data_collecting_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("data collecting worker died. {}".format(exitcode))

        if not self.task_feeding_worker.is_alive():
            exitcode = self.task_feeding_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("task feeding worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))

        logger.debug("all workers are alive.")
    def _get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0:
                if waited_time > self.timeout:
                    raise RuntimeError("get_next_batch timeout!")

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.task_feeding_worker.is_alive():
            self.task_feeding_worker.terminate()
        self.task_feeding_worker.join()

        if self.data_collecting_worker.is_alive():
            self.data_collecting_worker.terminate()
        self.data_collecting_worker.join()

        for worker in self.workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.task_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()


class _BaseStreamDataLoaderIter:
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event

    def _get_next_batch(self):
        raise NotImplementedError

    def _process_raw_data(self, raw_data):
        assert len(raw_data) == 2 and isinstance(
            raw_data[0], bool
        ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
        if not raw_data[0]:
            data = list((x,) for x in raw_data[1])
        else:
            data = raw_data[1]
        ret = []
        for idx in range(len(data[0])):
            ret.append(tuple(e[idx] for e in data))
        return ret
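
    # Illustrative sketch (added note, not from the original source) of the tuple
    # contract _process_raw_data expects: a StreamDataset yields (batched_flag,
    # payload), and the method regroups the payload into per-sample tuples.
    # The variable names below are hypothetical:
    #
    #     (False, (image, label))               -> returns [(image, label)]
    #     (True, ([img0, img1], [lab0, lab1]))  -> returns [(img0, lab0), (img1, lab1)]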

    def __iter__(self):
        return self

    def __next__(self):
        return self._get_next_batch()


class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    def __init__(self, loader):
        super().__init__(loader)
        self.dataset_iter = iter(self.dataset)
        self.idx = 0
        self.unused = []

    def _try_get_raw_data(self, start_time):
        raw_data = None
        while not raw_data:
            try:
                if self.timeout > 0:
                    timer = threading.Timer(self.timeout, thread.interrupt_main)
                    timer.start()
                raw_data = next(self.dataset_iter)
                if self.timeout > 0:
                    timer.cancel()
            except KeyboardInterrupt:
                raw_data = self.timeout_event()
            except:
                if self.timeout > 0:
                    timer.cancel()
                    waited_time = time.time() - start_time
                    if waited_time > self.timeout:
                        raw_data = self.timeout_event()
        return raw_data

    def _get_next_batch(self):
        ret = []
        start_time = time.time()
        while len(ret) < self.sampler.batch_size:
            if len(self.unused) != 0:
                batch_data = self.unused
            else:
                raw_data = self._try_get_raw_data(start_time)
                batch_data = self._process_raw_data(raw_data)
            while len(batch_data) != 0 and len(ret) < self.sampler.batch_size:
                data = batch_data.pop()
                ret.append(self.transform.apply(data))
            self.unused = batch_data
        return self.collator.apply(ret)


class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    __initialized = False

    def __init__(self, loader):
        super().__init__(loader)

        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.raw_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # shared-memory queue implemented by pyarrow plasma store
        from ._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.recieve_worker = multiprocessing.Process(
            target=self._worker_to_raw_data_queues, daemon=True
        )
        self.recieve_worker.start()

        self.transform_workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True
            )
            worker.start()
            self.transform_workers.append(worker)

        self.collect_worker = multiprocessing.Process(
            target=self._worker_to_batch_queue, daemon=True
        )
        self.collect_worker.start()

        self.__initialized = True

    def _put_raw_data_queues(self, raw_data, qidx):
        batch_data = self._process_raw_data(raw_data)
        for data in batch_data:
            while True:
                qidx = qidx % self.num_workers
                try:
                    self.raw_data_queues[qidx].put(data)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("raw data queue %d is full" % qidx)
                finally:
                    qidx += 1
        return qidx

    def _worker_to_raw_data_queues(self):
        dataset_iter = iter(self.dataset)
        qidx = 0
        while True:
            if self.shutdown_flag.value == 1:
                break
            raw_data = next(dataset_iter)
            qidx = self._put_raw_data_queues(raw_data, qidx)

    def _worker_to_trans_data_queues(self, worker_id):
        while True:
            if self.shutdown_flag.value == 1:
                break
            try:
                data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT)
            except queue.Empty:
                continue
            trans_data = self.transform.apply(data)
            while True:
                try:
                    self.trans_data_queues[worker_id].put(trans_data)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("batch queue is full")
    def _worker_to_batch_queue(self):
        cnt = -1
        trans_items = []
        while True:
            if self.shutdown_flag.value == 1:
                break
            cnt += 1
            queue_id = cnt % self.num_workers
            try:
                trans_item = self.trans_data_queues[queue_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
            except queue.Empty:
                continue
            trans_items.append(trans_item)
            if len(trans_items) == self.sampler.batch_size:
                batch_data = self.collator.apply(trans_items)
                while True:
                    try:
                        self.batch_queue.put(batch_data, timeout=1)
                        break
                    except queue.Full:
                        if self.shutdown_flag.value == 1:
                            break
                        logger.debug("batch queue is full")
                trans_items = []

    def _check_workers(self):
        if not self.collect_worker.is_alive():
            exitcode = self.collect_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("collator worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.transform_workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError(
                        "worker: {} died. {}".format(worker_id, exitcode)
                    )

    def _get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0 and waited_time > self.timeout:
                self._put_raw_data_queues(self.timeout_event(), 0)

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.recieve_worker.is_alive():
            self.recieve_worker.terminate()
        self.recieve_worker.join()

        if self.collect_worker.is_alive():
            self.collect_worker.terminate()
        self.collect_worker.join()

        for worker in self.transform_workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.raw_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()


def _task_feeding_loop(
    indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
):
    # Feed the indices into the task queues
    while True:
        if shutdown_flag.value == 1:
            break
        batch_idx = feed_batch_idx.value
        try:
            indices = next(indices_iter)
        except StopIteration:
            break
        if divide:
            # make sure all task_queues are ready for put
            while any([q.full() for q in task_queues]):
                if shutdown_flag.value == 1:
                    return
            # divide into small pieces, feed to different workers.
            sub_num = math.ceil(len(indices) / num_workers)
            for worker_id in range(num_workers):
                sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
                task_queues[worker_id].put((batch_idx, sub_indices))
        else:
            # distribute tasks to different workers uniformly.
            target_id = batch_idx % num_workers
            while task_queues[target_id].full():
                if shutdown_flag.value == 1:
                    return
            task_queues[target_id].put((batch_idx, indices))
        with feed_batch_idx.get_lock():
            feed_batch_idx.value += 1
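
# A small worked sketch (illustrative numbers, not from the source) of the divide
# branch above: with batch_size = 10 and num_workers = 4, sub_num = ceil(10 / 4) = 3,
# so the batch indices are split into chunks of sizes 3, 3, 3 and 1, and worker k
# receives indices[k * sub_num:(k + 1) * sub_num] together with the current batch_idx.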


def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
    # Get dataset items and do the transform
    random.seed(seed)
    np.random.seed(seed)
    while True:
        if shutdown_flag.value == 1:
            break
        try:
            batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT)
        except queue.Empty:
            continue
        if len(indices) > 0:
            items = [dataset[idx] for idx in indices]
            trans_items = transform.apply_batch(items)
        else:
            # in case of incomplete last batch
            trans_items = ()
        while True:
            try:
                trans_data_queue.put((batch_idx, trans_items), timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch part queue is full!")


def _data_gathering_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Gather the small pieces of batch data into full batch data
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        full_trans_items = []
        for worker_id in range(num_workers):
            while True:
                try:
                    batch_idx, trans_items = trans_data_queues[worker_id].get(
                        timeout=GLOBAL_TIMEOUT
                    )
                    break
                except queue.Empty:
                    if shutdown_flag.value == 1:
                        break
                    logger.debug(
                        "worker:{} data queue get timeout! target batch idx:{}".format(
                            worker_id, target_batch_idx
                        )
                    )
            if batch_idx != target_batch_idx:
                raise RuntimeError(
                    "Unexpected batch_idx in data gathering loop. worker_id:{}.".format(
                        worker_id
                    )
                )
            else:
                full_trans_items.extend(trans_items)

        # Merge different parts into a batch.
        full_batch = collator.apply(full_trans_items)

        while True:
            try:
                batch_queue.put(full_batch, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()


def _data_selecting_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Make sure that batches are generated in exactly the same order as the generated indices
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        target_worker_id = target_batch_idx % num_workers
        while True:
            try:
                batch_idx, trans_items = trans_data_queues[target_worker_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
                batch_data = collator.apply(trans_items)
                break
            except queue.Empty:
                if shutdown_flag.value == 1:
                    break
                logger.debug(
                    "worker:{} data queue get timeout! target batch idx:{}".format(
                        target_worker_id, target_batch_idx
                    )
                )

        if batch_idx != target_batch_idx:
            raise RuntimeError(
                "batch_idx {} mismatch the target_batch_idx {}".format(
                    batch_idx, target_batch_idx
                )
            )

        while True:
            try:
                batch_queue.put(batch_data, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()
