Browse Source

添加了 torch_driver.load_model 在 DDP 时出现显存爆炸的 TODO 事项

dev0.8.0
YWMditto 2 years ago
parent
commit
6db70df2a7
1 changed file with 1 addition and 0 deletions
  1. +1
    -0
      fastNLP/core/drivers/torch_driver/torch_driver.py

+ 1
- 0
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -190,6 +190,7 @@ class TorchDriver(Driver):
:param only_state_dict: whether the saved content contains only the weights
"""
model = self.unwrap_model()
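# TODO: torch.load makes GPU 0 hold the memory of one (or even several) extra copies of the model while loading, so resuming multi-GPU training from a checkpoint may fail.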
res = torch.load(filepath, map_location='cpu')
if isinstance(res, dict) and only_state_dict is False:
logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use "


Loading…
Cancel
Save