diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 3307e3c9..db011403 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -199,7 +199,8 @@ class TorchDriver(Driver): f"`only_state_dict=False`") if not isinstance(res, dict): res = res.state_dict() - model.load_state_dict(res) + _strict = kwargs.get("strict", True) + model.load_state_dict(res, _strict) @rank_zero_call def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):