You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 8.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. # -*- coding: utf-8 -*-
  2. import threading
  3. import warnings
  4. import numpy as np
  5. from .base import *
  6. from .struct import *
  7. from .tensor import *
  8. class TensorBatchCollector:
  9. """
  10. A tensor utils is used to collect many single batch tensor to a multi batch
  11. size tensor, when the multi batch size tensor collect finish, the result
  12. tensor can be get and send to the model input for forwarding.
  13. when collect single batch tensor, the single batch tensor is no need in the
  14. same device_type and device_id with the result tensor, however the dtype must
  15. match and the shape must match except the highest dimension.
  16. Args:
  17. shape: the multi batch size tensor shape, After collection, the result
  18. tensor shape.
  19. dtype(LiteDataType): the datatype of the single batch tensor and the
  20. result tensor, default value is LiteDataType.LITE_INT8.
  21. device_type(LiteDeviceType): the target device type the result tensor
  22. will allocate, default value is LiteDeviceType.LITE_CUDA.
  23. device_id: the device id the result tensor will allocate, default 0.
  24. is_pinned_host: Whether the memory is pinned memory, refer to CUDA
  25. pinned memory, default False.
  26. tensor(LiteTensor): the result tensor, user can also create the multi
  27. batch size tensor and then create the TensorBatchColletor, if tensor is
  28. not None, all the member, such as shape, dtype, device_type,
  29. device_id, is_pinned_host will get from the tensor, if the tensor is
  30. None and the result tensor will create by the TensorBatchCollector,
  31. default is None.
  32. Note:
  33. when collect tensor, the single batch tensor or array shape must match the
  34. result tensor shape except the batch size dimension (the highest dimension)
  35. Examples:
  36. .. code-block:: python
  37. import numpy as np
  38. batch_tensor = TensorBatchCollector([4, 8, 8])
  39. arr = np.ones([8, 8], "int8")
  40. for i in range(4):
  41. batch_tensor.collect(arr)
  42. arr += 1
  43. data = batch_tensor.to_numpy()
  44. assert data.shape[0] == 4
  45. assert data.shape[1] == 8
  46. assert data.shape[2] == 8
  47. for i in range(4):
  48. for j in range(64):
  49. assert data[i][j // 8][j % 8] == i + 1
  50. """
  51. def __init__(
  52. self,
  53. shape,
  54. dtype=LiteDataType.LITE_INT8,
  55. device_type=LiteDeviceType.LITE_CUDA,
  56. device_id=0,
  57. is_pinned_host=False,
  58. tensor=None,
  59. ):
  60. self._mutex = threading.Lock()
  61. self.dev_type = device_type
  62. self.is_pinned_host = is_pinned_host
  63. self.dev_id = device_id
  64. self.shape = shape
  65. self.dtype = LiteLayout(dtype=dtype).data_type
  66. self._free_list = list(range(self.shape[0]))
  67. if tensor is not None:
  68. assert (
  69. tensor.layout.shapes[0 : tensor.layout.ndim] == shape
  70. ), "The tensor set to TensorBatchCollector is not right."
  71. self._tensor = tensor
  72. self.dtype = tensor.layout.data_type
  73. self.device_type = tensor.device_type
  74. self.device_id = tensor.device_type
  75. else:
  76. self._tensor = LiteTensor(
  77. LiteLayout(shape, dtype), device_type, device_id, is_pinned_host
  78. )
  79. def collect_id(self, array, batch_id):
  80. """
  81. Collect a single batch through an array and store the array data to the
  82. specific batch_id.
  83. Args:
  84. array: an array maybe LiteTensor or numpy ndarray, the shape of
  85. array must match the result tensor shape except the highest
  86. dimension.
  87. batch_id: the batch id to store the array data to the result tensor,
  88. if the batch_id has already collected, a warning will generate.
  89. """
  90. # get the batch index
  91. with self._mutex:
  92. if batch_id in self._free_list:
  93. self._free_list.remove(batch_id)
  94. else:
  95. warnings.warn(
  96. "batch {} has been collected, please call free before collected it again.".format(
  97. batch_id
  98. )
  99. )
  100. self._collect_with_id(array, batch_id)
  101. def _collect_with_id(self, array, batch_id):
  102. if isinstance(array, np.ndarray):
  103. shape = array.shape
  104. assert list(shape) == self.shape[1:]
  105. in_dtype = ctype_to_lite_dtypes[np.ctypeslib.as_ctypes_type(array.dtype)]
  106. assert in_dtype == self.dtype
  107. # get the subtensor
  108. subtensor = self._tensor.slice([batch_id], [batch_id + 1])
  109. if subtensor.device_type == LiteDeviceType.LITE_CPU:
  110. subtensor.set_data_by_copy(array)
  111. else:
  112. pinned_tensor = LiteTensor(
  113. subtensor.layout, self.dev_type, self.dev_id, True
  114. )
  115. pinned_tensor.set_data_by_share(array)
  116. subtensor.copy_from(pinned_tensor)
  117. else:
  118. assert isinstance(array, LiteTensor)
  119. ndim = array.layout.ndim
  120. shape = list(array.layout.shapes)[0:ndim]
  121. assert list(shape) == self.shape[1:]
  122. in_dtype = array.layout.data_type
  123. assert in_dtype == self.dtype
  124. # get the subtensor
  125. subtensor = self._tensor.slice([batch_id], [batch_id + 1])
  126. subtensor.copy_from(array)
  127. return batch_id
  128. def collect(self, array):
  129. """
  130. Collect a single batch through an array and store the array data to an
  131. empty batch, the empty batch is the front batch id in free list.
  132. Args:
  133. array: an array maybe LiteTensor or numpy ndarray, the shape must
  134. match the result tensor shape except the highest dimension
  135. """
  136. with self._mutex:
  137. if len(self._free_list) == 0:
  138. warnings.warn(
  139. "all batch has been collected, please call free before collect again."
  140. )
  141. return -1
  142. idx = self._free_list.pop(0)
  143. return self._collect_with_id(array, idx)
  144. def collect_by_ctypes(self, data, length):
  145. """
  146. Collect a single batch through an ctypes memory buffer and store the
  147. ctypes memory data to an empty batch, the empty batch is the front
  148. batch id in free list.
  149. Args:
  150. array: an array maybe LiteTensor or numpy ndarray, the shape must
  151. match the result tensor shape except the highest dimension
  152. """
  153. with self._mutex:
  154. if len(self._free_list) == 0:
  155. return -1
  156. idx = self._free_list.pop(0)
  157. # get the subtensor
  158. subtensor = self._tensor.slice([idx], [idx + 1])
  159. if subtensor.device_type == LiteDeviceType.LITE_CPU:
  160. subtensor.set_data_by_copy(data, length)
  161. else:
  162. pinned_tensor = LiteTensor(
  163. subtensor.layout, self.dev_type, self.dev_id, True
  164. )
  165. pinned_tensor.set_data_by_share(data, length)
  166. subtensor.copy_from(pinned_tensor)
  167. def free(self, indexes):
  168. """
  169. free the batch ids in the indexes, after the batch id is freed, it can
  170. be collected again without warning.
  171. Args:
  172. indexes: a list of to be freed batch id
  173. """
  174. with self._mutex:
  175. for i in indexes:
  176. if i in self._free_list:
  177. warnings.warn(
  178. "batch id {} has not collected before free it.".format(i)
  179. )
  180. self._free_list.remove(i)
  181. self._free_list.extend(indexes)
  182. def get(self):
  183. """
  184. After finish collection, get the result tensor
  185. """
  186. return self._tensor
  187. def to_numpy(self):
  188. """
  189. Convert the result tensor to a numpy ndarray
  190. """
  191. return self._tensor.to_numpy()