Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0beta
x54-729 2 years ago
parent
commit
189050d25d
3 changed files with 6 additions and 3 deletions
  1. +2
    -2
      README.md
  2. +2
    -0
      fastNLP/core/callbacks/topk_saver.py
  3. +2
    -1
      fastNLP/core/drivers/torch_driver/torch_driver.py

+ 2
- 2
README.md View File

@@ -51,7 +51,7 @@ from fastNLP.transformers.torch import BertTokenizer
# 该文件还存在,将自动读取缓存文件,而不再次运行预处理代码。 # 该文件还存在,将自动读取缓存文件,而不再次运行预处理代码。
@cache_results('caches/cache.pkl') @cache_results('caches/cache.pkl')
def prepare_data(): def prepare_data():
# 会自动下载 SST2 数据,并且可以通过文档看到返回的 dataset 应该是包含"raw_words"和"target"两个field的
# 会自动下载数据,并且可以通过文档看到返回的 dataset 应该是包含"raw_words"和"target"两个field的
data_bundle = ChnSentiCorpLoader().load() data_bundle = ChnSentiCorpLoader().load()
# 使用tokenizer对数据进行tokenize # 使用tokenizer对数据进行tokenize
tokenizer = BertTokenizer.from_pretrained('hfl/chinese-bert-wwm') tokenizer = BertTokenizer.from_pretrained('hfl/chinese-bert-wwm')
@@ -130,7 +130,7 @@ evaluator.run()
from fastNLP.io import ChnSentiCorpLoader from fastNLP.io import ChnSentiCorpLoader
from functools import partial from functools import partial


# 会自动下载 SST2 数据,并且可以通过文档看到返回的 dataset 应该是包含"raw_words"和"target"两个field的
# 会自动下载数据,并且可以通过文档看到返回的 dataset 应该是包含"raw_words"和"target"两个field的
data_bundle = ChnSentiCorpLoader().load() data_bundle = ChnSentiCorpLoader().load()


# 使用tokenizer对数据进行tokenize # 使用tokenizer对数据进行tokenize


+ 2
- 0
fastNLP/core/callbacks/topk_saver.py View File

@@ -50,6 +50,8 @@ class Saver:
self.save_fn_name = 'save_checkpoint' if save_object == 'trainer' else 'save_model' self.save_fn_name = 'save_checkpoint' if save_object == 'trainer' else 'save_model'


self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
# 打印这次运行时 checkpoint 所保存在的文件夹,因为这个文件夹是根据时间实时生成的,因此需要打印出来防止用户混淆;
logger.info(f"The checkpoint will be saved in this folder for this time: {self.timestamp_path}.")


def save(self, trainer, folder_name): def save(self, trainer, folder_name):
""" """


+ 2
- 1
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -199,7 +199,8 @@ class TorchDriver(Driver):
f"`only_state_dict=False`") f"`only_state_dict=False`")
if not isinstance(res, dict): if not isinstance(res, dict):
res = res.state_dict() res = res.state_dict()
model.load_state_dict(res)
_strict = kwargs.get("strict", True)
model.load_state_dict(res, _strict)


@rank_zero_call @rank_zero_call
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):


Loading…
Cancel
Save