|
|
@@ -32,7 +32,7 @@ class TensorBatchCollector: |
|
|
|
self._mutex = threading.Lock() |
|
|
|
self.dev_type = device_type |
|
|
|
self.is_pinned_host = is_pinned_host |
|
|
|
self.dev_id = 0 |
|
|
|
self.dev_id = device_id |
|
|
|
self.shape = shape |
|
|
|
self.dtype = LiteLayout(dtype=dtype).data_type |
|
|
|
self._free_list = list(range(self.shape[0])) |
|
|
|