
dataloader.py 25 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import math
import multiprocessing
import platform
import queue
import random
import time

import numpy as np

from ..logger import get_logger
from ..random.rng import _random_seed_generator
from .collator import Collator
from .dataset import Dataset, MapDataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

logger = get_logger(__name__)

MP_QUEUE_GET_TIMEOUT = 5
class DataLoader:
    __initialized = False

    def __init__(
        self,
        dataset: Dataset,
        sampler: Sampler = None,
        transform: Transform = None,
        collator: Collator = None,
        num_workers: int = 0,
        timeout: int = 0,
        divide: bool = False,
    ):
        r"""
        Provides a convenient way to iterate on a given dataset.

        `DataLoader` combines a dataset with `sampler`, `transform` and `collator`,
        making it flexible to continually get minibatches from a dataset.

        :type dataset: Dataset
        :param dataset: dataset from which to load the minibatch.
        :type sampler: Sampler
        :param sampler: defines the strategy to sample data from the dataset.
        :type transform: Transform
        :param transform: defines the transforming strategy for a sampled batch.
            Default: None
        :type collator: Collator
        :param collator: defines the merging strategy for a transformed batch.
            Default: None
        :type num_workers: int
        :param num_workers: the number of sub-processes used to load, transform and
            collate the batch. ``0`` means using a single process. Default: 0
        :type timeout: int
        :param timeout: if positive, the timeout value (in seconds) for collecting a
            batch from workers. Default: 0
        :type divide: bool
        :param divide: defines the parallelism strategy in multi-processing mode.
            ``True`` means one batch is divided into :attr:`num_workers` pieces, and
            the workers process these pieces in parallel. ``False`` means
            different sub-processes process different batches. Default: False
        """
        if num_workers < 0:
            raise ValueError("num_workers should not be negative")

        if timeout < 0:
            raise ValueError("timeout should not be negative")

        if divide and num_workers <= 1:
            raise ValueError("divide should not be set to True when num_workers <= 1")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout
        self.divide = divide

        if isinstance(dataset, MapDataset):
            self.sampler = (
                sampler
                if sampler
                else SequentialSampler(dataset, batch_size=1, drop_last=False)
            )
            assert isinstance(
                self.sampler, MapSampler
            ), "types of dataset and sampler do not match"
        elif isinstance(dataset, StreamDataset):
            self.sampler = sampler if sampler else StreamSampler(batch_size=1)
            assert isinstance(
                self.sampler, StreamSampler
            ), "types of dataset and sampler do not match"
        else:
            raise TypeError(
                "can not recognize this kind of dataset: %s" % type(dataset)
            )

        if divide:
            if self.sampler.batch_size <= self.num_workers:
                raise ValueError(
                    "batch size must be larger than num_workers in divide mode."
                )
            elif self.sampler.batch_size % self.num_workers:
                logger.warning(
                    "batch size is not divisible by num_workers, may lose performance in divide mode."
                )

        if transform is None:
            self.transform = PseudoTransform()
        else:
            self.transform = transform

        if collator is None:
            self.collator = Collator()
        else:
            self.collator = collator

        self.__initialized = True
    def __iter__(self):
        if platform.system() == "Windows" and self.num_workers > 0:
            print(
                "pyarrow.plasma does not support ParallelDataLoader on Windows, setting num_workers to zero"
            )
            self.num_workers = 0

        if isinstance(self.dataset, StreamDataset):
            if not self.num_workers:
                return _SerialStreamDataLoaderIter(self)
            else:
                return _ParallelStreamDataLoaderIter(self)
        elif isinstance(self.dataset, MapDataset):
            if not self.num_workers:
                return _SerialMapDataLoaderIter(self)
            else:
                return _ParallelMapDataLoaderIter(self)
        else:
            raise TypeError(
                "can not recognize this kind of dataset: %s" % type(self.dataset)
            )

    def __len__(self):
        return len(self.sampler)
class _BaseMapDataLoaderIter:
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.seed = _random_seed_generator().__next__()
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.divide = loader.divide
        self.num_processed = 0

    def _get_next_batch(self):
        raise NotImplementedError

    def __len__(self):
        return len(self.sampler)

    def __iter__(self):
        return self

    def __next__(self):
        if self.num_processed >= len(self):
            raise StopIteration
        minibatch = self._get_next_batch()
        self.num_processed += 1
        return minibatch
class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
    def __init__(self, loader):
        super(_SerialMapDataLoaderIter, self).__init__(loader)
        self.indices_iter = iter(self.sampler)

    def _get_next_batch(self):
        indices = next(self.indices_iter)
        items = [self.dataset[idx] for idx in indices]
        trans_items = self.transform.apply_batch(items)
        return self.collator.apply(trans_items)
class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
    __initialized = False

    def __init__(self, loader):
        super(_ParallelMapDataLoaderIter, self).__init__(loader)

        self.task_queues = [
            multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
        ]

        self.feed_batch_idx = multiprocessing.Value("i", 0)
        self.target_batch_idx = multiprocessing.Value("i", 0)
        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # use shared-memory queue implemented by pyarrow plasma store.
        from ._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)
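
        # Pipeline overview (descriptive comment): the task feeding process
        # pushes batch indices into the per-worker task_queues; each worker
        # loads and transforms its items and puts the result into its
        # trans_data_queue; the collecting process gathers (divide mode) or
        # selects (non-divide mode) the transformed items, collates them, and
        # publishes full batches through the shared-memory batch_queue that
        # __next__ reads in the main process.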
        self.task_feeding_worker = multiprocessing.Process(
            target=_task_feeding_loop,
            args=(
                iter(self.sampler),
                self.task_queues,
                self.num_workers,
                self.divide,
                self.shutdown_flag,
                self.feed_batch_idx,
            ),
            daemon=True,
        )
        self.task_feeding_worker.start()

        self.workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=_worker_loop,
                args=(
                    self.dataset,
                    self.task_queues[worker_id],
                    self.trans_data_queues[worker_id],
                    self.transform,
                    self.seed + worker_id + 1,
                    self.shutdown_flag,
                ),
                daemon=True,
            )
            worker.start()
            self.workers.append(worker)

        if self.divide:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_gathering_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        else:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_selecting_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        self.data_collecting_worker.start()

        self.__initialized = True
    def _check_workers(self):
        # Check the status of each worker.
        if not self.data_collecting_worker.is_alive():
            # read the exit code of the collecting worker itself
            exitcode = self.data_collecting_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("data collecting worker died. {}".format(exitcode))

        if not self.task_feeding_worker.is_alive():
            exitcode = self.task_feeding_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("task feeding worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))

        logger.debug("all workers are alive.")
    def _try_get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0:
                if waited_time > self.timeout:
                    raise RuntimeError("get_next_batch timeout!")

    def _get_next_batch(self):
        batch_data = self._try_get_next_batch()
        return batch_data

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.task_feeding_worker.is_alive():
            self.task_feeding_worker.terminate()
        self.task_feeding_worker.join()

        if self.data_collecting_worker.is_alive():
            self.data_collecting_worker.terminate()
        self.data_collecting_worker.join()

        for worker in self.workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.task_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()
class _BaseStreamDataLoaderIter:
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout

    def _get_next_batch(self):
        raise NotImplementedError

    def __iter__(self):
        return self

    def __next__(self):
        return self._get_next_batch()
class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    def __init__(self, loader):
        super().__init__(loader)
        self.dataset_iter = iter(self.dataset)
        self.idx = 0
        self.data = None

    def _get_next_batch(self):
        ret = []
        start_time = time.time()
        while len(ret) != self.sampler.batch_size:
            waited_time = time.time() - start_time
            if self.timeout > 0 and waited_time > self.timeout:
                raise RuntimeError("get_next_batch timeout!")
            if self.idx != 0:
                # leftover data from the previous call
                data = self.data
            else:
                try:
                    raw_data = next(self.dataset_iter)
                except:
                    continue
                assert len(raw_data) == 2 and isinstance(
                    raw_data[0], bool
                ), "raw_data must be a 2-tuple whose first element is a bool"
                if not raw_data[0]:
                    data = list((x,) for x in raw_data[1])
                else:
                    data = raw_data[1]
            for idx in range(self.idx, len(data[0])):
                trans_data = self.transform.apply(tuple(e[idx] for e in data))
                ret.append(trans_data)
                if len(ret) == self.sampler.batch_size:
                    if idx + 1 == len(data[0]):
                        self.idx = 0
                        self.data = None
                    else:
                        # resume from the first unconsumed element on the next call
                        self.idx = idx + 1
                        self.data = data
                    break
            else:
                # the buffered data ran out before the batch was full;
                # fetch fresh data from the stream on the next pass
                self.idx = 0
                self.data = None
        return self.collator.apply(ret)
class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    __initialized = False

    def __init__(self, loader):
        super().__init__(loader)

        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.raw_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # shared-memory queue implemented by pyarrow plasma store
        from ._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)
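
        # Pipeline overview (descriptive comment): the receive process pulls
        # raw data from the stream dataset and scatters samples round-robin
        # across raw_data_queues; one transform process per worker applies the
        # transform and forwards results through trans_data_queues; the collect
        # process collates full batches and publishes them through the
        # shared-memory batch_queue read by the main process.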
        self.recieve_worker = multiprocessing.Process(target=self._recieve, daemon=True)
        self.recieve_worker.start()

        self.transform_workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=self._transform, args=(worker_id,), daemon=True
            )
            worker.start()
            self.transform_workers.append(worker)

        self.collect_worker = multiprocessing.Process(target=self._collect, daemon=True)
        self.collect_worker.start()

        self.__initialized = True
    def _recieve(self):
        dataset_iter = iter(self.dataset)
        cnt = -1
        while True:
            if self.shutdown_flag.value == 1:
                break
            raw_data = next(dataset_iter)
            assert len(raw_data) == 2 and isinstance(
                raw_data[0], bool
            ), "raw_data must be a 2-tuple whose first element is a bool"
            if not raw_data[0]:
                data = list((x,) for x in raw_data[1])
            else:
                data = raw_data[1]
            for idx in range(len(data[0])):
                while True:
                    cnt += 1
                    qid = cnt % self.num_workers
                    try:
                        self.raw_data_queues[qid].put(tuple(e[idx] for e in data))
                        break
                    except queue.Full:
                        if self.shutdown_flag.value == 1:
                            break
                        logger.debug("raw data queue is full")
    def _transform(self, worker_id):
        while True:
            if self.shutdown_flag.value == 1:
                break
            try:
                data = self.raw_data_queues[worker_id].get(timeout=MP_QUEUE_GET_TIMEOUT)
            except queue.Empty:
                continue
            trans_data = self.transform.apply(data)
            while True:
                try:
                    self.trans_data_queues[worker_id].put(trans_data)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("trans data queue is full")
    def _collect(self):
        cnt = -1
        trans_items = []
        while True:
            if self.shutdown_flag.value == 1:
                break
            cnt += 1
            queue_id = cnt % self.num_workers
            try:
                trans_item = self.trans_data_queues[queue_id].get(
                    timeout=MP_QUEUE_GET_TIMEOUT
                )
            except queue.Empty:
                continue
            trans_items.append(trans_item)
            if len(trans_items) == self.sampler.batch_size:
                batch_data = self.collator.apply(trans_items)
                while True:
                    try:
                        self.batch_queue.put(batch_data, timeout=1)
                        break
                    except queue.Full:
                        if self.shutdown_flag.value == 1:
                            break
                        logger.debug("batch queue is full")
                trans_items = []
    def _check_workers(self):
        if not self.collect_worker.is_alive():
            exitcode = self.collect_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("collect worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.transform_workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError(
                        "worker: {} died. {}".format(worker_id, exitcode)
                    )
    def _try_get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0 and waited_time > self.timeout:
                raise RuntimeError("get_next_batch timeout!")

    def _get_next_batch(self):
        batch_data = self._try_get_next_batch()
        return batch_data
    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.recieve_worker.is_alive():
            self.recieve_worker.terminate()
        self.recieve_worker.join()

        if self.collect_worker.is_alive():
            self.collect_worker.terminate()
        self.collect_worker.join()

        for worker in self.transform_workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.raw_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()
def _task_feeding_loop(
    indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
):
    # Feed the indices into the task queues
    while True:
        if shutdown_flag.value == 1:
            break
        batch_idx = feed_batch_idx.value
        try:
            indices = next(indices_iter)
        except StopIteration:
            break
        if divide:
            # make sure all task_queues are ready for put
            while any([q.full() for q in task_queues]):
                if shutdown_flag.value == 1:
                    return
            # divide into small pieces, feed to different workers.
            sub_num = math.ceil(len(indices) / num_workers)
            for worker_id in range(num_workers):
                sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
                task_queues[worker_id].put((batch_idx, sub_indices))
        else:
            # distribute tasks to different workers uniformly.
            target_id = batch_idx % num_workers
            while task_queues[target_id].full():
                if shutdown_flag.value == 1:
                    return
            task_queues[target_id].put((batch_idx, indices))
        with feed_batch_idx.get_lock():
            feed_batch_idx.value += 1
def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
    # Get dataset items and do the transform
    random.seed(seed)
    np.random.seed(seed)
    while True:
        if shutdown_flag.value == 1:
            break
        try:
            batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
        except queue.Empty:
            continue
        if len(indices) > 0:
            items = [dataset[idx] for idx in indices]
            trans_items = transform.apply_batch(items)
        else:
            # in case of incomplete last batch
            trans_items = ()
        while True:
            try:
                trans_data_queue.put((batch_idx, trans_items), timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch part queue is full!")
def _data_gathering_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Gather the small pieces of batch data into full batch data
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        full_trans_items = []
        for worker_id in range(num_workers):
            while True:
                try:
                    batch_idx, trans_items = trans_data_queues[worker_id].get(
                        timeout=MP_QUEUE_GET_TIMEOUT
                    )
                    break
                except queue.Empty:
                    if shutdown_flag.value == 1:
                        break
                    logger.debug(
                        "worker:{} data queue get timeout! target batch idx:{}".format(
                            worker_id, target_batch_idx
                        )
                    )
            if batch_idx != target_batch_idx:
                raise RuntimeError(
                    "Unexpected batch_idx in data gathering loop. worker_id:{}.".format(
                        worker_id
                    )
                )
            else:
                full_trans_items.extend(trans_items)

        # Merge different parts into a batch.
        full_batch = collator.apply(full_trans_items)

        while True:
            try:
                batch_queue.put(full_batch, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()
def _data_selecting_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Make sure batches are generated in exactly the same order as the generated indices
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        target_worker_id = target_batch_idx % num_workers
        while True:
            try:
                batch_idx, trans_items = trans_data_queues[target_worker_id].get(
                    timeout=MP_QUEUE_GET_TIMEOUT
                )
                batch_data = collator.apply(trans_items)
                break
            except queue.Empty:
                if shutdown_flag.value == 1:
                    break
                logger.debug(
                    "worker:{} data queue get timeout! target batch idx:{}".format(
                        target_worker_id, target_batch_idx
                    )
                )

        if batch_idx != target_batch_idx:
            raise RuntimeError(
                "batch_idx {} does not match target_batch_idx {}".format(
                    batch_idx, target_batch_idx
                )
            )

        while True:
            try:
                batch_queue.put(batch_data, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()
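
For reference, a minimal usage sketch of the map-style path described above. This is illustrative and not part of dataloader.py: ToyDataset is a hypothetical dataset, and the megengine.data import paths are assumed from the relative imports at the top of the file, so they may differ between MegEngine versions.

import numpy as np

from megengine.data import DataLoader              # assumed public re-export
from megengine.data.dataset import MapDataset      # path taken from this module's imports
from megengine.data.sampler import SequentialSampler


class ToyDataset(MapDataset):
    """A tiny map-style dataset returning (image, label) pairs."""

    def __init__(self, n=100):
        super().__init__()
        self.images = np.random.rand(n, 3, 8, 8).astype("float32")
        self.labels = np.random.randint(0, 10, size=(n,)).astype("int32")

    def __getitem__(self, index):
        return self.images[index], self.labels[index]

    def __len__(self):
        return len(self.images)


dataset = ToyDataset()
sampler = SequentialSampler(dataset, batch_size=4, drop_last=False)
# num_workers=0 keeps everything in the main process (_SerialMapDataLoaderIter);
# num_workers > 0 switches to the multi-process pipeline shown above.
dataloader = DataLoader(dataset, sampler=sampler, num_workers=0)

for images, labels in dataloader:
    print(images.shape, labels.shape)  # (4, 3, 8, 8) (4,)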

The MegEngine installation package already bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU and its driver installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.