@@ -93,6 +93,7 @@ __all__ = [ | |||||
"f_rich_progress", | "f_rich_progress", | ||||
"auto_param_call", | "auto_param_call", | ||||
"seq_len_to_mask", | "seq_len_to_mask", | ||||
"f_tqdm_progress", | |||||
# vocabulary.py | # vocabulary.py | ||||
'Vocabulary' | 'Vocabulary' | ||||
@@ -8,7 +8,10 @@ r""" | |||||
``Trainer`` 通过来自己内部内置一个 ``Evaluator`` 实例来支持在训练过程中进行验证的功能; | ``Trainer`` 通过来自己内部内置一个 ``Evaluator`` 实例来支持在训练过程中进行验证的功能; | ||||
""" | """ | ||||
from typing import Union, List, Optional, Dict, Callable | |||||
from typing import Union, List, Optional, Dict, Callable, BinaryIO | |||||
import os | |||||
from pathlib import Path | |||||
import io | |||||
from dataclasses import is_dataclass | from dataclasses import is_dataclass | ||||
__all__ = [ | __all__ = [ | ||||
@@ -25,6 +28,8 @@ from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metri | |||||
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader | from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader | ||||
from fastNLP.core.utils.utils import _check_valid_parameters_number | from fastNLP.core.utils.utils import _check_valid_parameters_number | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import FASTNLP_MODEL_FILENAME | |||||
class Evaluator: | class Evaluator: | ||||
@@ -174,6 +179,44 @@ class Evaluator: | |||||
self.driver.barrier() | self.driver.barrier() | ||||
def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = True,
               model_load_fn: Optional[Callable] = None, **kwargs):
    """
    Helper for loading a saved model through this evaluator's driver.

    :param folder: the folder that holds the model to load. When ``model_load_fn`` is None, the file named
        ``FASTNLP_MODEL_FILENAME`` inside that folder is handed to ``self.driver.load_model``. When
        ``model_load_fn`` is given, ``folder`` is passed to it unchanged. An ``io.BytesIO`` / ``BinaryIO``
        object is forwarded directly to ``self.driver.load_model``;
    :param only_state_dict: whether the file to read contains only the model weights. Ignored when
        ``model_load_fn`` is not None;
    :param model_load_fn: a callable that accepts ``folder`` as its only argument; its return value is ignored;
    :param kwargs: forwarded to ``self.driver.load_model``; normally not needed;
    :raises ValueError: if ``model_load_fn`` is neither None nor callable;
    :raises RuntimeError: if ``model_load_fn`` is given together with a binary stream ``folder``;
    :raises FileNotFoundError: re-raised from the underlying load when a file is missing;

    .. note::
        Call this only after the ``Evaluator`` has been initialised, through the ``evaluator`` instance. The
        ``driver`` used for saving and for loading must belong to the same framework, e.g. both ``pytorch``
        or both ``paddle``;
    """
    self.driver.barrier()
    if not isinstance(folder, (io.BytesIO, BinaryIO)):
        try:
            if model_load_fn is not None:
                if not callable(model_load_fn):
                    raise ValueError("Parameter `model_load_fn` should be `Callable` type when it is not None.")
                model_load_fn(folder)
            else:
                if isinstance(folder, str):
                    folder = Path(folder)
                self.driver.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs)
        except FileNotFoundError:
            # Check the checkpoint path directly instead of listing the folder: listing a folder that
            # does not exist would raise a second FileNotFoundError and skip the log message below.
            if not os.path.lexists(os.path.join(folder, FASTNLP_MODEL_FILENAME)):
                logger.error(f"fastNLP model checkpoint file:{FASTNLP_MODEL_FILENAME} is not found in {folder}.")
            raise
    else:
        if model_load_fn is not None:
            raise RuntimeError("It is not allowed to specify a `model_load_fn` parameter with `folder` being "
                               "`io.BytesIO` type.")
        self.driver.load_model(folder, only_state_dict, **kwargs)
    self.driver.barrier()
def run(self, num_eval_batch_per_dl: int = -1) -> Dict: | def run(self, num_eval_batch_per_dl: int = -1) -> Dict: | ||||
""" | """ | ||||
该函数是在 ``Evaluator`` 初始化后用于真正开始评测的函数; | 该函数是在 ``Evaluator`` 初始化后用于真正开始评测的函数; | ||||
@@ -45,7 +45,11 @@ class EvaluateBatchLoop(Loop): | |||||
except BaseException as e: | except BaseException as e: | ||||
if callable(getattr(dataloader, 'get_batch_indices', None)): | if callable(getattr(dataloader, 'get_batch_indices', None)): | ||||
indices = dataloader.get_batch_indices() | indices = dataloader.get_batch_indices() | ||||
logger.error(f"Exception happens when evaluating on samples: {indices}") | |||||
if evaluator.cur_dataloader_name is not None: | |||||
logger.error(f"Exception happens when evaluating on samples in dataloader:" | |||||
f"{evaluator.cur_dataloader_name}: {indices}") | |||||
else: | |||||
logger.error(f"Exception happens when evaluating on samples: {indices}") | |||||
raise e | raise e | ||||
# 获取metric结果。返回的dict内容示例为{'metric_name1': metric_results, 'metric_name2': metric_results, ...} | # 获取metric结果。返回的dict内容示例为{'metric_name1': metric_results, 'metric_name2': metric_results, ...} | ||||
results = evaluator.get_metric() | results = evaluator.get_metric() | ||||
@@ -220,7 +220,8 @@ def prepare_torch_dataloader(ds_or_db, | |||||
pin_memory: bool = False, drop_last: bool = False, | pin_memory: bool = False, drop_last: bool = False, | ||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | timeout: float = 0, worker_init_fn: Optional[Callable] = None, | ||||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | ||||
persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None, | |||||
persistent_workers: bool = False, | |||||
non_train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | |||||
non_train_batch_size: int = 16) \ | non_train_batch_size: int = 16) \ | ||||
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: | -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: | ||||
""" | """ | ||||
@@ -119,10 +119,10 @@ class DataSet: | |||||
self._collator = Collator() | self._collator = Collator() | ||||
if data is not None: | if data is not None: | ||||
if isinstance(data, Dict): | if isinstance(data, Dict): | ||||
length_set = set() | |||||
length_set = {} | |||||
for key, value in data.items(): | for key, value in data.items(): | ||||
length_set.add(len(value)) | |||||
assert len(length_set) == 1, "Arrays must all be same length." | |||||
length_set[key] = len(value) | |||||
assert len(set(length_set.values())) == 1, f"Fields must all be of same length, instead of {length_set}." | |||||
for key, value in data.items(): | for key, value in data.items(): | ||||
self.add_field(field_name=key, fields=value) | self.add_field(field_name=key, fields=value) | ||||
elif isinstance(data, List): | elif isinstance(data, List): | ||||
@@ -1497,7 +1497,7 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin): | |||||
else: | else: | ||||
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") | logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") | ||||
if len(missing_keys) > 0: | if len(missing_keys) > 0: | ||||
logger.warning( | |||||
logger.rank_zero_warning( | |||||
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " | f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " | ||||
f"and are newly initialized: {missing_keys}\n" | f"and are newly initialized: {missing_keys}\n" | ||||
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." | f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." | ||||
@@ -1515,7 +1515,7 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin): | |||||
for key, shape1, shape2 in mismatched_keys | for key, shape1, shape2 in mismatched_keys | ||||
] | ] | ||||
) | ) | ||||
logger.warning( | |||||
logger.rank_zero_warning( | |||||
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " | f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " | ||||
f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n" | f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n" | ||||
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." | f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." | ||||
@@ -1,3 +1,4 @@ | |||||
from .auto import * | |||||
from .bart import * | from .bart import * | ||||
from .bert import * | from .bert import * | ||||
from .cpt import * | from .cpt import * | ||||
@@ -8,6 +8,7 @@ import time | |||||
from fastNLP.core.callbacks.checkpoint_callback import CheckpointCallback | from fastNLP.core.callbacks.checkpoint_callback import CheckpointCallback | ||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from fastNLP import Evaluator | |||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
@@ -286,6 +287,13 @@ def test_model_checkpoint_callback_2( | |||||
trainer.load_model(folder, only_state_dict=only_state_dict) | trainer.load_model(folder, only_state_dict=only_state_dict) | ||||
trainer.run() | trainer.run() | ||||
evaluator = Evaluator(model=model_and_optimizers.model, driver='torch', device=0, | |||||
dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics) | |||||
evaluator.load_model(folder, only_state_dict=only_state_dict) | |||||
evaluator.run() | |||||
trainer.driver.barrier() | trainer.driver.barrier() | ||||
finally: | finally: | ||||