From 74d8d66bef900cbf7ddb5c82aac697fad61b72bc Mon Sep 17 00:00:00 2001 From: YWMditto Date: Fri, 8 Jul 2022 12:53:28 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=20torch=5Fdriver.lo?= =?UTF-8?q?ad=5Fmodel=EF=BC=8C=E6=B7=BB=E5=8A=A0=E4=BA=86=20strict=20?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E7=9A=84=E8=AE=BE=E5=AE=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/torch_driver/torch_driver.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 3307e3c9..db011403 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -199,7 +199,8 @@ class TorchDriver(Driver): f"`only_state_dict=False`") if not isinstance(res, dict): res = res.state_dict() - model.load_state_dict(res) + _strict = kwargs.get("strict", True) + model.load_state_dict(res, _strict) @rank_zero_call def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):