From ddce6e960c52901da9534773e9ca6e84ed94fad4 Mon Sep 17 00:00:00 2001 From: yhcc Date: Fri, 8 Jul 2022 08:17:23 +0800 Subject: [PATCH 1/3] fix typo --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a0f973e0..8a72ee2a 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ from fastNLP.transformers.torch import BertTokenizer # 该文件还存在,将自动读取缓存文件,而不再次运行预处理代码。 @cache_results('caches/cache.pkl') def prepare_data(): - # 会自动下载 SST2 数据,并且可以通过文档看到返回的 dataset 应该是包含"raw_words"和"target"两个field的 + # 会自动下载数据,并且可以通过文档看到返回的 dataset 应该是包含"raw_words"和"target"两个field的 data_bundle = ChnSentiCorpLoader().load() # 使用tokenizer对数据进行tokenize tokenizer = BertTokenizer.from_pretrained('hfl/chinese-bert-wwm') @@ -130,7 +130,7 @@ evaluator.run() from fastNLP.io import ChnSentiCorpLoader from functools import partial -# 会自动下载 SST2 数据,并且可以通过文档看到返回的 dataset 应该是包含"raw_words"和"target"两个field的 +# 会自动下载数据,并且可以通过文档看到返回的 dataset 应该是包含"raw_words"和"target"两个field的 data_bundle = ChnSentiCorpLoader().load() # 使用tokenizer对数据进行tokenize From 74d8d66bef900cbf7ddb5c82aac697fad61b72bc Mon Sep 17 00:00:00 2001 From: YWMditto Date: Fri, 8 Jul 2022 12:53:28 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=20torch=5Fdrive?= =?UTF-8?q?r.load=5Fmodel=EF=BC=8C=E6=B7=BB=E5=8A=A0=E4=BA=86=20strict=20?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E7=9A=84=E8=AE=BE=E5=AE=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/torch_driver/torch_driver.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 3307e3c9..db011403 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -199,7 +199,8 @@ class TorchDriver(Driver): f"`only_state_dict=False`") if not isinstance(res, dict): res = res.state_dict() - model.load_state_dict(res) + _strict = kwargs.get("strict", True) + model.load_state_dict(res, _strict) @rank_zero_call def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): From 4724a7206af60c91145b056d09c67edd01556ae3 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Sun, 10 Jul 2022 17:51:17 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=20saver=20?= =?UTF-8?q?=E4=B8=AD=E5=AF=B9=E4=BA=8E=E4=BF=9D=E5=AD=98=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E5=A4=B9=E7=9A=84=E6=89=93=E5=8D=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/topk_saver.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fastNLP/core/callbacks/topk_saver.py b/fastNLP/core/callbacks/topk_saver.py index 1a4de8d5..21a8961f 100644 --- a/fastNLP/core/callbacks/topk_saver.py +++ b/fastNLP/core/callbacks/topk_saver.py @@ -50,6 +50,8 @@ class Saver: self.save_fn_name = 'save_checkpoint' if save_object == 'trainer' else 'save_model' self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) + # 打印这次运行时 checkpoint 所保存在的文件夹,因为这个文件夹是根据时间实时生成的,因此需要打印出来防止用户混淆; + logger.info(f"The checkpoint will be saved in this folder for this time: {self.timestamp_path}.") def save(self, trainer, folder_name): """