diff --git a/lite/pylite/megenginelite/utils.py b/lite/pylite/megenginelite/utils.py index 7b611059..670aa910 100644 --- a/lite/pylite/megenginelite/utils.py +++ b/lite/pylite/megenginelite/utils.py @@ -32,7 +32,7 @@ class TensorBatchCollector: self._mutex = threading.Lock() self.dev_type = device_type self.is_pinned_host = is_pinned_host - self.dev_id = 0 + self.dev_id = device_id self.shape = shape self.dtype = LiteLayout(dtype=dtype).data_type self._free_list = list(range(self.shape[0]))