
utils.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import threading
import warnings

import numpy as np

from .base import *
from .struct import *
from .tensor import *


class TensorBatchCollector:
    """
    A utility that collects sub-tensors one batch at a time into a single
    continuous batched tensor.
    """

    def __init__(
        self,
        shape,
        dtype=LiteDataType.LITE_INT8,
        device_type=LiteDeviceType.LITE_CUDA,
        device_id=0,
        is_pinned_host=False,
        tensor=None,
    ):
        self._mutex = threading.Lock()
        self.dev_type = device_type
        self.is_pinned_host = is_pinned_host
        self.dev_id = device_id
        self.shape = shape
        self.dtype = LiteLayout(dtype=dtype).data_type
        self._free_list = list(range(self.shape[0]))

        if tensor is not None:
            assert (
                tensor.layout.shapes[0 : tensor.layout.ndim] == shape
            ), "The tensor set to TensorBatchCollector is not right."
            self._tensor = tensor
            self.dtype = tensor.layout.data_type
            self.dev_type = tensor.device_type
            self.dev_id = tensor.device_id
        else:
            self._tensor = LiteTensor(
                LiteLayout(shape, dtype), device_type, device_id, is_pinned_host
            )

    def collect_id(self, array, batch_id):
        # take the requested batch index out of the free list
        with self._mutex:
            if batch_id in self._free_list:
                self._free_list.remove(batch_id)
            else:
                warnings.warn(
                    "batch {} has already been collected, please call free before "
                    "collecting it again.".format(batch_id)
                )
        self._collect_with_id(array, batch_id)

    def _collect_with_id(self, array, batch_id):
        if isinstance(array, np.ndarray):
            shape = array.shape
            assert list(shape) == self.shape[1:]
            in_dtype = ctype_to_lite_dtypes[np.ctypeslib.as_ctypes_type(array.dtype)]
            assert in_dtype == self.dtype
            # get the subtensor
            subtensor = self._tensor.slice([batch_id], [batch_id + 1])
            if subtensor.device_type == LiteDeviceType.LITE_CPU:
                subtensor.set_data_by_copy(array)
            else:
                pinned_tensor = LiteTensor(
                    subtensor.layout, self.dev_type, self.dev_id, True
                )
                pinned_tensor.set_data_by_share(array)
                subtensor.copy_from(pinned_tensor)
        else:
            assert isinstance(array, LiteTensor)
            ndim = array.layout.ndim
            shape = list(array.layout.shapes)[0:ndim]
            assert list(shape) == self.shape[1:]
            in_dtype = array.layout.data_type
            assert in_dtype == self.dtype
            # get the subtensor
            subtensor = self._tensor.slice([batch_id], [batch_id + 1])
            subtensor.copy_from(array)
        return batch_id

    def collect(self, array):
        with self._mutex:
            if len(self._free_list) == 0:
                warnings.warn(
                    "all batches have been collected, please call free before "
                    "collecting again."
                )
                return -1
            idx = self._free_list.pop(0)
        return self._collect_with_id(array, idx)

    def collect_by_ctypes(self, data, length):
        """
        collect a batch from a ctypes data pointer of the given byte length
        """
        with self._mutex:
            if len(self._free_list) == 0:
                return -1
            idx = self._free_list.pop(0)
        # get the subtensor
        subtensor = self._tensor.slice([idx], [idx + 1])
        if subtensor.device_type == LiteDeviceType.LITE_CPU:
            subtensor.set_data_by_copy(data, length)
        else:
            pinned_tensor = LiteTensor(
                subtensor.layout, self.dev_type, self.dev_id, True
            )
            pinned_tensor.set_data_by_share(data, length)
            subtensor.copy_from(pinned_tensor)
        # return the batch slot that was filled, for consistency with collect()
        return idx

    def free(self, indexes):
        with self._mutex:
            for i in indexes:
                if i in self._free_list:
                    warnings.warn(
                        "batch id {} has not been collected before freeing it.".format(i)
                    )
                    self._free_list.remove(i)
            self._free_list.extend(indexes)

    def get(self):
        return self._tensor

    def to_numpy(self):
        return self._tensor.to_numpy()