From 1f4208b1c941bebc45766a720b757662e3c933cf Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 19 Sep 2022 15:03:46 +0800 Subject: [PATCH] fix bugs of _FDataset of paddle, jittor and oneflow --- fastNLP/core/dataloaders/jittor_dataloader/fdl.py | 6 ++++++ fastNLP/core/dataloaders/oneflow_dataloader/fdl.py | 6 ++++++ fastNLP/core/dataloaders/paddle_dataloader/fdl.py | 9 +++++++++ 3 files changed, 21 insertions(+) diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index 5eb64a62..0e0cb443 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -37,6 +37,12 @@ class _JittorDataset(Dataset): item = item.tolist() return (item, self.dataset[item]) + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__ = state + class JittorDataLoader: """ diff --git a/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py b/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py index 15fea5d4..e2882d75 100644 --- a/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py @@ -47,6 +47,12 @@ class _FDataSet: def __len__(self) -> int: return len(self.dataset) + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__ = state + class OneflowDataLoader(DataLoader): """ diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py index 529f23aa..12f00534 100644 --- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py @@ -43,6 +43,15 @@ class _PaddleDataset(Dataset): except Exception as e: raise e + def __len__(self) -> int: + return len(self.dataset) + + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__ = state + class PaddleDataLoader(DataLoader): """