|
|
@@ -190,6 +190,7 @@ class TorchDriver(Driver): |
|
|
|
:param load_state_dict: 保存的内容是否只是权重 |
|
|
|
""" |
|
|
|
model = self.unwrap_model() |
|
|
|
# todo torch.load 在加载时会使得卡 0 多出一个(甚至多个)model 的显存;因此在多卡断点重训时可能会出现错误; |
|
|
|
res = torch.load(filepath, map_location='cpu') |
|
|
|
if isinstance(res, dict) and only_state_dict is False: |
|
|
|
logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use " |
|
|
|