
utils.py

# -*- coding: utf-8 -*-
# This file is part of MegEngine, a deep learning framework developed by
# Megvii.
#
# Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
import threading

import numpy as np

from .base import *
from .struct import *
from .tensor import *


class TensorBatchCollector:
    """
    A utility that collects sub-tensors one by one into a contiguous
    batch tensor, e.g. to assemble single samples into an inference batch.
    """

    def __init__(
        self,
        shape,
        dtype=LiteDataType.LITE_INT8,
        device_type=LiteDeviceType.LITE_CUDA,
        device_id=0,
        is_pinned_host=False,
        tensor=None,
    ):
        self._mutex = threading.Lock()
        self.dev_type = device_type
        self.is_pinned_host = is_pinned_host
        self.dev_id = device_id
        self.shape = shape
        self.dtype = LiteLayout(dtype=dtype).data_type
        # every batch index starts out free
        self._free_list = list(range(self.shape[0]))
        if tensor is not None:
            assert (
                list(tensor.layout.shapes[0 : tensor.layout.ndim]) == list(shape)
            ), "The tensor set to TensorBatchCollector is not right."
            self._tensor = tensor
            self.dtype = tensor.layout.data_type
            self.dev_type = tensor.device_type
            self.dev_id = tensor.device_id
        else:
            self._tensor = LiteTensor(
                LiteLayout(shape, dtype), device_type, device_id, is_pinned_host
            )

    def collect_id(self, array, batch_id):
        """Copy ``array`` into slot ``batch_id`` of the batch tensor."""
        if isinstance(array, np.ndarray):
            shape = array.shape
            assert list(shape) == self.shape[1:]
            in_dtype = ctype_to_lite_dtypes[np.ctypeslib.as_ctypes_type(array.dtype)]
            assert in_dtype == self.dtype
            # mark the batch index as used
            with self._mutex:
                if batch_id in self._free_list:
                    self._free_list.remove(batch_id)
            # get the subtensor for this batch index
            subtensor = self._tensor.slice([batch_id], [batch_id + 1])
            if subtensor.device_type == LiteDeviceType.LITE_CPU:
                subtensor.set_data_by_copy(array)
            else:
                # stage the host data in a pinned tensor, then copy to device
                pinned_tensor = LiteTensor(
                    subtensor.layout, self.dev_type, self.dev_id, True
                )
                pinned_tensor.set_data_by_share(array)
                subtensor.copy_from(pinned_tensor)
        else:
            assert isinstance(array, LiteTensor)
            ndim = array.layout.ndim
            shape = list(array.layout.shapes)[0:ndim]
            assert list(shape) == self.shape[1:]
            in_dtype = array.layout.data_type
            assert in_dtype == self.dtype
            # mark the batch index as used
            with self._mutex:
                if batch_id in self._free_list:
                    self._free_list.remove(batch_id)
            # get the subtensor and copy the input into it
            subtensor = self._tensor.slice([batch_id], [batch_id + 1])
            subtensor.copy_from(array)
        return batch_id

    def collect(self, array):
        """Copy ``array`` into the next free slot; return its index, or -1 if full."""
        with self._mutex:
            if len(self._free_list) == 0:
                return -1
            idx = self._free_list.pop(0)
        return self.collect_id(array, idx)

    def collect_by_ctypes(self, data, length):
        """
        Collect ``length`` bytes of ctypes data into the next free slot;
        return its index, or -1 if the batch is full.
        """
        with self._mutex:
            if len(self._free_list) == 0:
                return -1
            idx = self._free_list.pop(0)
        # get the subtensor for this batch index
        subtensor = self._tensor.slice([idx], [idx + 1])
        if subtensor.device_type == LiteDeviceType.LITE_CPU:
            subtensor.set_data_by_copy(data, length)
        else:
            # stage the host data in a pinned tensor, then copy to device
            pinned_tensor = LiteTensor(
                subtensor.layout, self.dev_type, self.dev_id, True
            )
            pinned_tensor.set_data_by_share(data, length)
            subtensor.copy_from(pinned_tensor)
        return idx

    def free(self, indexes):
        """Return the given batch indexes to the free list for reuse."""
        with self._mutex:
            self._free_list.extend(indexes)

    def get(self):
        return self._tensor

    def to_numpy(self):
        return self._tensor.to_numpy()

The MegEngine package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU/GPU build to choose between. To run GPU programs, make sure the machine has a GPU installed along with a working driver. If you would like to try deep-learning development on cloud GPUs, you are welcome to visit the MegStudio platform.
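A minimal sketch for checking whether the bundled CUDA environment can see a usable GPU, assuming the megengine package is installed:

import megengine

# True only if a GPU device and a working driver are present
print(megengine.is_cuda_available())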