@@ -47,6 +47,16 @@ class _FDataSet: | |||||
def __len__(self) -> int: | def __len__(self) -> int: | ||||
return len(self.dataset) | return len(self.dataset) | ||||
    # These two methods must be spelled out explicitly because this object may go
    # through pickle dumps/loads (e.g. with multiprocessing "spawn"); without them,
    # pickle's loads would call __setstate__, which would fall through into
    # __getattr__ on the not-yet-initialized instance and recurse forever.
    # https://docs.python.org/3/library/pickle.html#pickling-class-instances
    # https://stackoverflow.com/questions/73662315/when-using-multiprocessing-and-spawn-in-python-use-self-a-in-getattr-cause?noredirect=1
    def __getstate__(self):
        # Hand pickle the instance __dict__ directly (by reference, no copy).
        return self.__dict__
    def __setstate__(self, state):
        # Rebind __dict__ wholesale instead of setting attributes one by one:
        # direct assignment bypasses the attribute-lookup machinery (__getattr__)
        # that would otherwise be triggered on a half-restored instance.
        self.__dict__ = state
class TorchDataLoader(DataLoader): | class TorchDataLoader(DataLoader): | ||||
""" | """ | ||||
@@ -190,6 +190,7 @@ class TorchDriver(Driver): | |||||
:param load_state_dict: 保存的内容是否只是权重 | :param load_state_dict: 保存的内容是否只是权重 | ||||
""" | """ | ||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
# TODO: torch.load can leave one (or even several) extra copies of the model's memory on GPU 0 while loading; this may cause errors when resuming checkpoint training on multiple GPUs. | |||||
res = torch.load(filepath, map_location='cpu') | res = torch.load(filepath, map_location='cpu') | ||||
if isinstance(res, dict) and only_state_dict is False: | if isinstance(res, dict) and only_state_dict is False: | ||||
logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use " | logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use " | ||||