@@ -37,6 +37,12 @@ class _JittorDataset(Dataset): | |||||
item = item.tolist() | item = item.tolist() | ||||
return (item, self.dataset[item]) | return (item, self.dataset[item]) | ||||
def __getstate__(self): | |||||
return self.__dict__ | |||||
def __setstate__(self, state): | |||||
self.__dict__ = state | |||||
class JittorDataLoader: | class JittorDataLoader: | ||||
""" | """ | ||||
@@ -47,6 +47,12 @@ class _FDataSet: | |||||
def __len__(self) -> int: | def __len__(self) -> int: | ||||
return len(self.dataset) | return len(self.dataset) | ||||
def __getstate__(self): | |||||
return self.__dict__ | |||||
def __setstate__(self, state): | |||||
self.__dict__ = state | |||||
class OneflowDataLoader(DataLoader): | class OneflowDataLoader(DataLoader): | ||||
""" | """ | ||||
@@ -43,6 +43,15 @@ class _PaddleDataset(Dataset): | |||||
except Exception as e: | except Exception as e: | ||||
raise e | raise e | ||||
def __len__(self) -> int: | |||||
return len(self.dataset) | |||||
def __getstate__(self): | |||||
return self.__dict__ | |||||
def __setstate__(self, state): | |||||
self.__dict__ = state | |||||
class PaddleDataLoader(DataLoader): | class PaddleDataLoader(DataLoader): | ||||
""" | """ | ||||