diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index 5eb64a62..0e0cb443 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -37,6 +37,12 @@ class _JittorDataset(Dataset): item = item.tolist() return (item, self.dataset[item]) + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__ = state + class JittorDataLoader: """ diff --git a/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py b/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py index 15fea5d4..e2882d75 100644 --- a/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py @@ -47,6 +47,12 @@ class _FDataSet: def __len__(self) -> int: return len(self.dataset) + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__ = state + class OneflowDataLoader(DataLoader): """ diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py index 529f23aa..12f00534 100644 --- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py @@ -43,6 +43,15 @@ class _PaddleDataset(Dataset): except Exception as e: raise e + def __len__(self) -> int: + return len(self.dataset) + + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__ = state + class PaddleDataLoader(DataLoader): """