|
|
@@ -7,6 +7,7 @@ |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
import threading |
|
|
|
import warnings |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
@@ -51,15 +52,24 @@ class TensorBatchCollector: |
|
|
|
) |
|
|
|
|
|
|
|
def collect_id(self, array, batch_id): |
|
|
|
# get the batch index |
|
|
|
with self._mutex: |
|
|
|
if batch_id in self._free_list: |
|
|
|
self._free_list.remove(batch_id) |
|
|
|
else: |
|
|
|
warnings.warn( |
|
|
|
"batch {} has been collected, please call free before collected it again.".format( |
|
|
|
batch_id |
|
|
|
) |
|
|
|
) |
|
|
|
self._collect_with_id(array, batch_id) |
|
|
|
|
|
|
|
def _collect_with_id(self, array, batch_id): |
|
|
|
if isinstance(array, np.ndarray): |
|
|
|
shape = array.shape |
|
|
|
assert list(shape) == self.shape[1:] |
|
|
|
in_dtype = ctype_to_lite_dtypes[np.ctypeslib.as_ctypes_type(array.dtype)] |
|
|
|
assert in_dtype == self.dtype |
|
|
|
# get the batch index |
|
|
|
with self._mutex: |
|
|
|
if batch_id in self._free_list: |
|
|
|
self._free_list.remove(batch_id) |
|
|
|
# get the subtensor |
|
|
|
subtensor = self._tensor.slice([batch_id], [batch_id + 1]) |
|
|
|
if subtensor.device_type == LiteDeviceType.LITE_CPU: |
|
|
@@ -77,10 +87,6 @@ class TensorBatchCollector: |
|
|
|
assert list(shape) == self.shape[1:] |
|
|
|
in_dtype = array.layout.data_type |
|
|
|
assert in_dtype == self.dtype |
|
|
|
# get the batch index |
|
|
|
with self._mutex: |
|
|
|
if batch_id in self._free_list: |
|
|
|
self._free_list.remove(batch_id) |
|
|
|
# get the subtensor |
|
|
|
subtensor = self._tensor.slice([batch_id], [batch_id + 1]) |
|
|
|
subtensor.copy_from(array) |
|
|
@@ -90,9 +96,12 @@ class TensorBatchCollector: |
|
|
|
def collect(self, array): |
|
|
|
with self._mutex: |
|
|
|
if len(self._free_list) == 0: |
|
|
|
warnings.warn( |
|
|
|
"all batch has been collected, please call free before collect again." |
|
|
|
) |
|
|
|
return -1 |
|
|
|
idx = self._free_list.pop(0) |
|
|
|
return self.collect_id(array, idx) |
|
|
|
return self._collect_with_id(array, idx) |
|
|
|
|
|
|
|
def collect_by_ctypes(self, data, length): |
|
|
|
""" |
|
|
@@ -115,6 +124,12 @@ class TensorBatchCollector: |
|
|
|
|
|
|
|
def free(self, indexes): |
|
|
|
with self._mutex: |
|
|
|
for i in indexes: |
|
|
|
if i in self._free_list: |
|
|
|
warnings.warn( |
|
|
|
"batch id {} has not collected before free it.".format(i) |
|
|
|
) |
|
|
|
self._free_list.remove(i) |
|
|
|
self._free_list.extend(indexes) |
|
|
|
|
|
|
|
def get(self): |
|
|
|