
dataloader.py 26 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import gc
import math
import multiprocessing
import platform
import queue
import random
import threading
import time
from typing import Callable

import numpy as np

from ..logger import get_logger
from ..random.rng import _random_seed_generator
from .collator import Collator
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

try:
    import thread
except:
    import _thread as thread

logger = get_logger(__name__)

GLOBAL_TIMEOUT = 5


def raise_timeout_error():
    raise RuntimeError("dataloader timeout")
class DataLoader:
    r"""Provides a convenient way to iterate on a given dataset.

    DataLoader combines a dataset with :class:`~.Sampler`, :class:`~.Transform`
    and :class:`~.Collator`, making it flexible to continually get minibatches
    from a dataset.

    Args:
        dataset: dataset from which to load the minibatch.
        sampler: defines the strategy to sample data from the dataset.
        transform: defines the transforming strategy for a sampled batch.
            Default: None
        collator: defines the merging strategy for a transformed batch.
            Default: None
        num_workers: the number of sub-processes to load, transform and collate
            the batch. ``0`` means using a single process. Default: 0
        timeout: if positive, the timeout value (in seconds) for collecting a
            batch from workers. Default: 0
        timeout_event: callback function triggered by timeout; defaults to
            raising a runtime error.
        divide: defines the parallel strategy in multi-processing mode.
            ``True`` means one batch is divided into :attr:`num_workers` pieces, and
            the workers will process these pieces in parallel. ``False`` means
            different sub-processes will process different batches. Default: False
    """
    __initialized = False

    def __init__(
        self,
        dataset: Dataset,
        sampler: Sampler = None,
        transform: Transform = None,
        collator: Collator = None,
        num_workers: int = 0,
        timeout: int = 0,
        timeout_event: Callable = raise_timeout_error,
        divide: bool = False,
    ):
        if num_workers < 0:
            raise ValueError("num_workers should not be negative")
        if timeout < 0:
            raise ValueError("timeout should not be negative")
        if divide and num_workers <= 1:
            raise ValueError("divide should not be set to True when num_workers <= 1")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout
        self.timeout_event = timeout_event
        self.divide = divide

        if isinstance(dataset, StreamDataset):
            self.sampler = sampler if sampler else StreamSampler(batch_size=1)
            assert isinstance(
                self.sampler, StreamSampler
            ), "types of dataset and sampler do not match"
        else:
            assert isinstance(
                dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(dataset)
            self.sampler = (
                sampler
                if sampler
                else SequentialSampler(dataset, batch_size=1, drop_last=False)
            )
            assert isinstance(
                self.sampler, MapSampler
            ), "types of dataset and sampler do not match"

        if divide:
            if self.sampler.batch_size <= self.num_workers:
                raise ValueError(
                    "batch size must not be smaller than num_workers in divide mode."
                )
            elif self.sampler.batch_size % self.num_workers:
                logger.warning(
                    "batch size is not divisible by num_workers, may lose performance in divide mode."
                )

        if transform is None:
            self.transform = PseudoTransform()
        else:
            self.transform = transform

        if collator is None:
            self.collator = Collator()
        else:
            self.collator = collator

        self.__initialized = True
    def __iter__(self):
        if platform.system() == "Windows" and self.num_workers > 0:
            print(
                "pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero"
            )
            self.num_workers = 0

        if isinstance(self.dataset, StreamDataset):
            if not self.num_workers:
                return _SerialStreamDataLoaderIter(self)
            else:
                return _ParallelStreamDataLoaderIter(self)
        else:
            assert isinstance(
                self.dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
            if not self.num_workers:
                return _SerialMapDataLoaderIter(self)
            else:
                return _ParallelMapDataLoaderIter(self)

    def __len__(self):
        return len(self.sampler)
class _BaseMapDataLoaderIter:
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.seed = _random_seed_generator().__next__()
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event
        self.divide = loader.divide
        self.num_processed = 0

    def _get_next_batch(self):
        raise NotImplementedError

    def __len__(self):
        return len(self.sampler)

    def __iter__(self):
        return self

    def __next__(self):
        if self.num_processed >= len(self):
            raise StopIteration
        minibatch = self._get_next_batch()
        self.num_processed += 1
        return minibatch
class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
    def __init__(self, loader):
        super(_SerialMapDataLoaderIter, self).__init__(loader)
        self.indices_iter = iter(self.sampler)

    def _get_next_batch(self):
        indices = next(self.indices_iter)
        items = [self.dataset[idx] for idx in indices]
        trans_items = self.transform.apply_batch(items)
        return self.collator.apply(trans_items)
class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
    __initialized = False

    def __init__(self, loader):
        super(_ParallelMapDataLoaderIter, self).__init__(loader)

        self.task_queues = [
            multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
        ]

        self.feed_batch_idx = multiprocessing.Value("i", 0)
        self.target_batch_idx = multiprocessing.Value("i", 0)
        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # use shared-memory queue implemented by pyarrow plasma store.
        from .tools._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.task_feeding_worker = multiprocessing.Process(
            target=_task_feeding_loop,
            args=(
                iter(self.sampler),
                self.task_queues,
                self.num_workers,
                self.divide,
                self.shutdown_flag,
                self.feed_batch_idx,
            ),
            daemon=True,
        )
        gc.collect()
        self.task_feeding_worker.start()

        self.workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=_worker_loop,
                args=(
                    self.dataset,
                    self.task_queues[worker_id],
                    self.trans_data_queues[worker_id],
                    self.transform,
                    self.seed + worker_id + 1,
                    self.shutdown_flag,
                ),
                daemon=True,
            )
            gc.collect()
            worker.start()
            self.workers.append(worker)

        if self.divide:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_gathering_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        else:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_selecting_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        gc.collect()
        self.data_collecting_worker.start()

        self.__initialized = True
    def _check_workers(self):
        # Check the status of each worker.
        if not self.data_collecting_worker.is_alive():
            exitcode = self.data_collecting_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("data collecting worker died. {}".format(exitcode))

        if not self.task_feeding_worker.is_alive():
            exitcode = self.task_feeding_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("task feeding worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))

        logger.debug("all workers are alive.")

    def _get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0:
                if waited_time > self.timeout:
                    raise RuntimeError("get_next_batch timeout!")

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.task_feeding_worker.is_alive():
            self.task_feeding_worker.terminate()
        self.task_feeding_worker.join()

        if self.data_collecting_worker.is_alive():
            self.data_collecting_worker.terminate()
        self.data_collecting_worker.join()

        for worker in self.workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.task_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()
class _BaseStreamDataLoaderIter:
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event

    def _get_next_batch(self):
        raise NotImplementedError

    def _process_raw_data(self, raw_data):
        assert len(raw_data) == 2 and isinstance(
            raw_data[0], bool
        ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
        if not raw_data[0]:
            data = list((x,) for x in raw_data[1])
        else:
            data = raw_data[1]
        ret = []
        for idx in range(len(data[0])):
            ret.append(tuple(e[idx] for e in data))
        return ret

    def __iter__(self):
        return self

    def __next__(self):
        return self._get_next_batch()
class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    def __init__(self, loader):
        super().__init__(loader)
        self.dataset_iter = iter(self.dataset)
        self.idx = 0
        self.unused = []

    def _try_get_raw_data(self, start_time):
        raw_data = None
        while not raw_data:
            try:
                if self.timeout > 0:
                    timer = threading.Timer(self.timeout, thread.interrupt_main)
                    timer.start()
                raw_data = next(self.dataset_iter)
                if self.timeout > 0:
                    timer.cancel()
            except KeyboardInterrupt:
                raw_data = self.timeout_event()
            except:
                if self.timeout > 0:
                    timer.cancel()
                    waited_time = time.time() - start_time
                    if waited_time > self.timeout:
                        raw_data = self.timeout_event()
        return raw_data

    def _get_next_batch(self):
        ret = []
        start_time = time.time()
        while len(ret) < self.sampler.batch_size:
            if len(self.unused) != 0:
                batch_data = self.unused
            else:
                raw_data = self._try_get_raw_data(start_time)
                batch_data = self._process_raw_data(raw_data)
            while len(batch_data) != 0 and len(ret) < self.sampler.batch_size:
                data = batch_data.pop()
                ret.append(self.transform.apply(data))
            self.unused = batch_data
        return self.collator.apply(ret)
class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    __initialized = False

    def __init__(self, loader):
        super().__init__(loader)

        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.raw_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # shared-memory queue implemented by pyarrow plasma store
        from .tools._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.recieve_worker = multiprocessing.Process(
            target=self._worker_to_raw_data_queues, daemon=True
        )
        gc.collect()
        self.recieve_worker.start()

        self.transform_workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True
            )
            gc.collect()
            worker.start()
            self.transform_workers.append(worker)

        self.collect_worker = multiprocessing.Process(
            target=self._worker_to_batch_queue, daemon=True
        )
        gc.collect()
        self.collect_worker.start()

        self.__initialized = True
    def _put_raw_data_queues(self, raw_data, qidx):
        batch_data = self._process_raw_data(raw_data)
        for data in batch_data:
            while True:
                qidx = qidx % self.num_workers
                try:
                    self.raw_data_queues[qidx].put(data)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("raw data queue %d is full" % qidx)
                finally:
                    qidx += 1
        return qidx

    def _worker_to_raw_data_queues(self):
        dataset_iter = iter(self.dataset)
        qidx = 0
        while True:
            if self.shutdown_flag.value == 1:
                break
            raw_data = next(dataset_iter)
            qidx = self._put_raw_data_queues(raw_data, qidx)

    def _worker_to_trans_data_queues(self, worker_id):
        while True:
            if self.shutdown_flag.value == 1:
                break
            try:
                data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT)
            except queue.Empty:
                continue
            trans_data = self.transform.apply(data)
            while True:
                try:
                    self.trans_data_queues[worker_id].put(trans_data)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("batch queue is full")

    def _worker_to_batch_queue(self):
        cnt = -1
        trans_items = []
        while True:
            if self.shutdown_flag.value == 1:
                break
            cnt += 1
            queue_id = cnt % self.num_workers
            try:
                trans_item = self.trans_data_queues[queue_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
            except queue.Empty:
                continue
            trans_items.append(trans_item)
            if len(trans_items) == self.sampler.batch_size:
                batch_data = self.collator.apply(trans_items)
                while True:
                    try:
                        self.batch_queue.put(batch_data, timeout=1)
                        break
                    except queue.Full:
                        if self.shutdown_flag.value == 1:
                            break
                        logger.debug("batch queue is full")
                trans_items = []
    def _check_workers(self):
        if not self.collect_worker.is_alive():
            exitcode = self.collect_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("collator worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.transform_workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError(
                        "worker: {} died. {}".format(worker_id, exitcode)
                    )

    def _get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0 and waited_time > self.timeout:
                self._put_raw_data_queues(self.timeout_event(), 0)

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.recieve_worker.is_alive():
            self.recieve_worker.terminate()
        self.recieve_worker.join()

        if self.collect_worker.is_alive():
            self.collect_worker.terminate()
        self.collect_worker.join()

        for worker in self.transform_workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.raw_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()
def _task_feeding_loop(
    indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
):
    # Feed the indices into the task queues
    while True:
        if shutdown_flag.value == 1:
            break
        batch_idx = feed_batch_idx.value
        try:
            indices = next(indices_iter)
        except StopIteration:
            break
        if divide:
            # make sure all task_queues are ready for put
            while any([q.full() for q in task_queues]):
                if shutdown_flag.value == 1:
                    return
            # divide into small pieces, feed to different workers.
            sub_num = math.ceil(len(indices) / num_workers)
            for worker_id in range(num_workers):
                sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
                task_queues[worker_id].put((batch_idx, sub_indices))
        else:
            # distribute tasks to different workers uniformly.
            target_id = batch_idx % num_workers
            while task_queues[target_id].full():
                if shutdown_flag.value == 1:
                    return
            task_queues[target_id].put((batch_idx, indices))
        with feed_batch_idx.get_lock():
            feed_batch_idx.value += 1
def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
    # Get dataset items and do the transform
    random.seed(seed)
    np.random.seed(seed)
    while True:
        if shutdown_flag.value == 1:
            break
        try:
            batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT)
        except queue.Empty:
            continue
        if len(indices) > 0:
            items = [dataset[idx] for idx in indices]
            trans_items = transform.apply_batch(items)
        else:
            # in case of incomplete last batch
            trans_items = ()
        while True:
            try:
                trans_data_queue.put((batch_idx, trans_items), timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch part queue is full!")
def _data_gathering_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Gather the small pieces of batch data into full batch data
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        full_trans_items = []
        for worker_id in range(num_workers):
            while True:
                try:
                    batch_idx, trans_items = trans_data_queues[worker_id].get(
                        timeout=GLOBAL_TIMEOUT
                    )
                    break
                except queue.Empty:
                    if shutdown_flag.value == 1:
                        break
                    logger.debug(
                        "worker:{} data queue get timeout! target batch idx:{}".format(
                            worker_id, target_batch_idx
                        )
                    )
            if batch_idx != target_batch_idx:
                raise RuntimeError(
                    "Unexpected batch_idx in data gathering loop. worker_id:{}.".format(
                        worker_id
                    )
                )
            else:
                full_trans_items.extend(trans_items)

        # Merge different parts into a batch.
        full_batch = collator.apply(full_trans_items)

        while True:
            try:
                batch_queue.put(full_batch, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()
def _data_selecting_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Make sure batches are generated in exactly the same order as the generated indices
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        target_worker_id = target_batch_idx % num_workers
        while True:
            try:
                batch_idx, trans_items = trans_data_queues[target_worker_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
                batch_data = collator.apply(trans_items)
                break
            except queue.Empty:
                if shutdown_flag.value == 1:
                    break
                logger.debug(
                    "worker:{} data queue get timeout! target batch idx:{}".format(
                        target_worker_id, target_batch_idx
                    )
                )

        if batch_idx != target_batch_idx:
            raise RuntimeError(
                "batch_idx {} mismatch the target_batch_idx {}".format(
                    batch_idx, target_batch_idx
                )
            )

        while True:
            try:
                batch_queue.put(batch_data, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()
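
For reference, here is a minimal usage sketch of the DataLoader defined above. It assumes the usual `megengine.data` import paths; `ToyDataset` is a hypothetical map-style dataset written for illustration, while `Dataset` and `SequentialSampler` are the classes imported at the top of this file.

# Minimal usage sketch (ToyDataset is hypothetical; import paths assumed from megengine.data).
import numpy as np

from megengine.data import DataLoader, SequentialSampler
from megengine.data.dataset import Dataset


class ToyDataset(Dataset):
    """A toy map-style dataset returning (image, label) pairs."""

    def __init__(self):
        self.data = np.random.rand(100, 3, 32, 32).astype("float32")
        self.labels = np.random.randint(0, 10, size=(100,)).astype("int32")

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def __len__(self):
        return len(self.data)


dataset = ToyDataset()
# The batch size is carried by the sampler, not by the DataLoader itself.
sampler = SequentialSampler(dataset, batch_size=16, drop_last=True)
# num_workers=0 keeps everything in the main process (no pyarrow plasma queue needed).
dataloader = DataLoader(dataset, sampler=sampler, num_workers=0)

for images, labels in dataloader:
    print(images.shape, labels.shape)  # expected: (16, 3, 32, 32) (16,)
    break

Note that the transform and collator arguments default to `PseudoTransform` and `Collator`, so the sampled items are simply stacked into batch arrays here.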

The MegEngine installation package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. If you want to run GPU programs, make sure the machine has GPU hardware and the driver installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
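
To confirm that the GPU environment is actually usable before launching GPU programs, a quick check along the following lines can help; it assumes the `megengine.is_cuda_available` helper found in recent MegEngine releases.

# Quick GPU availability check (assumes megengine.is_cuda_available exists in your release).
import megengine as mge

# Prints True if a CUDA-capable GPU and its driver are visible to MegEngine.
print(mge.is_cuda_available())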