Browse Source

修改了 torch_driver.load_model,添加了 strict 参数的设定

tags/v1.0.0beta
YWMditto 2 years ago
parent
commit
74d8d66bef
1 changed files with 2 additions and 1 deletions
  1. +2
    -1
      fastNLP/core/drivers/torch_driver/torch_driver.py

+ 2
- 1
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -199,7 +199,8 @@ class TorchDriver(Driver):
f"`only_state_dict=False`") f"`only_state_dict=False`")
if not isinstance(res, dict): if not isinstance(res, dict):
res = res.state_dict() res = res.state_dict()
model.load_state_dict(res)
_strict = kwargs.get("strict", True)
model.load_state_dict(res, _strict)


@rank_zero_call @rank_zero_call
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):


Loading…
Cancel
Save