From 024fecfbf3e6837ff19ba9e128e62df0f881c8aa Mon Sep 17 00:00:00 2001
From: YWMditto
Date: Thu, 16 Jun 2022 22:30:02 +0800
Subject: [PATCH] Added the overfit feature
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/controllers/trainer.py                | 52 +++++++++++++---------
 fastNLP/core/dataloaders/__init__.py               |  7 ++-
 fastNLP/core/dataloaders/utils.py                  | 41 ++++++++++++++++-
 fastNLP/core/drivers/torch_driver/torch_driver.py  |  3 +-
 .../controllers/test_trainer_w_evaluator_torch.py  | 42 +++++++++++++++++
 .../controllers/test_trainer_wo_evaluator_torch.py | 27 +++++++++++
 6 files changed, 146 insertions(+), 26 deletions(-)

diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index f92611dd..00a18f1d 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -35,6 +35,7 @@ from fastNLP.envs import rank_zero_call
 from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.utils.exceptions import EarlyStopException
+from fastNLP.core.dataloaders import OverfitDataLoader
 
 
 class Trainer(TrainerEventTrigger):
@@ -356,6 +357,7 @@
             optimizers,
             device: Optional[Union[int, List[int], str]] = "cpu",
             n_epochs: int = 20,
+            overfit_batches: int = 0,
             evaluate_dataloaders=None,
             batch_step_fn: Optional[Callable] = None,
             evaluate_batch_step_fn: Optional[Callable] = None,
@@ -469,9 +471,6 @@
             n_batches=n_batches
         )
 
-        if metrics is None and evaluate_dataloaders is not None:
-            raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")
-
         if metrics is not None and evaluate_dataloaders is None:
             raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.")
 
@@ -495,33 +494,44 @@ class Trainer(TrainerEventTrigger):
         else:
             _dist_sampler = None
 
+        self.dataloader = self.train_dataloader
+        self.driver.set_deterministic_dataloader(self.dataloader)
+
+        self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
+                                                                reproducible=self.callback_manager._need_reproducible_sampler)
+        # Perform the overfit-related setup;
+        if overfit_batches != 0:
+            self.dataloader = OverfitDataLoader(self.dataloader, overfit_batches)
+        self.overfit_batches = overfit_batches
+
         self.evaluator = None
         self.monitor = monitor
         self.larger_better = larger_better
-        if metrics is not None and evaluate_dataloaders is not None:
-            check_evaluate_every(evaluate_every)
-            progress_bar = kwargs.get('progress_bar', 'auto')  # 如果不为
-            if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。
-                progress_bar = progress_bar.name
-            self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
-                                       driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
-                                       evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
-                                       output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
-                                       use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
-                                       progress_bar=progress_bar,
-                                       check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
+        if metrics is not None:
+            if overfit_batches != 0:
+                logger.warning("Notice that you are trying to overfit the model while also using 'metrics'; this may cause "
+                               "unexpected results, because 'metrics' will be computed on 'train_dataloader' rather than on 'evaluate_dataloaders'.")
+                evaluate_dataloaders = self.dataloader
+            if evaluate_dataloaders is not None:
+                check_evaluate_every(evaluate_every)
+                progress_bar = kwargs.get('progress_bar', 'auto')  # 如果不为
+                if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。
+                    progress_bar = progress_bar.name
+                self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
+                                           driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
+                                           evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
+                                           output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
+                                           use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
+                                           progress_bar=progress_bar,
+                                           check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
+            else:
+                raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.")
 
         if train_fn is not None and not isinstance(train_fn, str):
             raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
         self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn)
         self.train_fn = train_fn
 
-        self.dataloader = self.train_dataloader
-        self.driver.set_deterministic_dataloader(self.dataloader)
-
-        self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
-                                                                reproducible=self.callback_manager._need_reproducible_sampler)
-
         self.evaluate_batch_step_fn = evaluate_batch_step_fn
         self.kwargs = kwargs
 
diff --git a/fastNLP/core/dataloaders/__init__.py b/fastNLP/core/dataloaders/__init__.py
index 84f8b288..b18e371c 100644
--- a/fastNLP/core/dataloaders/__init__.py
+++ b/fastNLP/core/dataloaders/__init__.py
@@ -7,10 +7,13 @@ __all__ = [
     'prepare_paddle_dataloader',
     'prepare_torch_dataloader',
 
-    "prepare_dataloader"
+    "prepare_dataloader",
+
+    "OverfitDataLoader"
 ]
 
 from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
 from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader, MixDataLoader
 from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
-from .prepare_dataloader import prepare_dataloader
\ No newline at end of file
+from .prepare_dataloader import prepare_dataloader
+from .utils import OverfitDataLoader
\ No newline at end of file
diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py
index d905101f..9f8b608c 100644
--- a/fastNLP/core/dataloaders/utils.py
+++ b/fastNLP/core/dataloaders/utils.py
@@ -1,4 +1,4 @@
-from typing import Callable, Any, Union
+from typing import Callable, Any, Union, Sequence
 from abc import ABC
 import inspect
 import ast
@@ -6,7 +6,8 @@ import ast
 from ..log import logger
 from ..utils.cache_results import get_func_calls, truncate_start_blanks
 __all__ = [
-    "indice_collate_wrapper"
+    "indice_collate_wrapper",
+    "OverfitDataLoader"
 ]
 
 
@@ -111,6 +112,42 @@
         return NotImplemented
 
 
+class OverfitDataLoader:
+    """
+    A simple iterator that mimics a real dataloader; it caches part of the batches from the given dataloader so that the Trainer can overfit on them;
+    """
+
+    def __init__(self, dataloader, overfit_batches: int):
+        self.dataloader = dataloader  # keep a reference to the real dataloader so that operations on it can still be delegated;
+        self.batches = []
+
+        if isinstance(overfit_batches, int):
+            if overfit_batches < 0 and overfit_batches != -1:
+                raise ValueError("When parameter 'overfit_batches' is smaller than 0 it can only be '-1', which means "
+                                 "that all of the training data is used to check whether the model can be overfitted.")
+        else:
+            raise TypeError("Parameter 'overfit_batches' can only be of type 'int'; please check the value you passed to 'Trainer'.")
+
+        if overfit_batches > len(dataloader):
+            logger.warning("Parameter 'overfit_batches' is larger than the actual number of batches in the train dataloader.")
+
+        for idx, batch in enumerate(dataloader):
+
+            if idx < overfit_batches or overfit_batches == -1:
+                self.batches.append(batch)
+
+    def __len__(self):
+        return len(self.batches)
+
+    def __iter__(self):
+        for batch in self.batches:
+            yield batch
+
+    def __getattr__(self, item):
+        return getattr(self.dataloader, item)
+
+
+
 if __name__ == '__main__':
     def demo(*args, **kwargs):
         pass
diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index fb01b6c3..84e4aa70 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -31,6 +31,7 @@ from fastNLP.envs import rank_zero_call
 from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.log import logger
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler
+from fastNLP.core.dataloaders import OverfitDataLoader
 
 
 class TorchDriver(Driver):
@@ -92,7 +93,7 @@
         self.grad_scaler.update()
 
     def check_dataloader_legality(self, dataloader):
-        if not isinstance(dataloader, DataLoader):
+        if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader):
             raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
         if len(dataloader) == 0:
             logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "
diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py
index 752e06d8..a70766f5 100644
--- a/tests/core/controllers/test_trainer_w_evaluator_torch.py
+++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py
@@ -286,6 +286,9 @@ def test_trainer_specific_params_1(
     assert trainer.driver.non_blocking is False
     assert trainer.driver.wo_auto_param_call is True
 
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
 
 @pytest.mark.torch
 @pytest.mark.parametrize("driver,device", [("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
@@ -332,5 +335,44 @@ def test_trainer_specific_params_2(
     assert _ddp_kwargs.get("broadcast_buffers") is True
     assert _ddp_kwargs.get("find_unused_parameters") is True
 
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
 
+@pytest.mark.torch
+@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
+@magic_argv_env_context
+def test_trainer_w_evaluator_overfit_torch(
+        model_and_optimizers: TrainerParameters,
+        overfit_batches,
+        num_train_batch_per_epoch
+):
+    """
+    Test that 'overfit_batches' is handled correctly when the Trainer is used together with an evaluator;
+    """
+    trainer = Trainer(
+        model=model_and_optimizers.model,
+        driver="torch",
+        device=0,
+        overfit_batches=overfit_batches,
+        optimizers=model_and_optimizers.optimizers,
+        train_dataloader=model_and_optimizers.train_dataloader,
+        evaluate_dataloaders={"dl": model_and_optimizers.evaluate_dataloaders},
+        input_mapping=model_and_optimizers.input_mapping,
+        output_mapping=model_and_optimizers.output_mapping,
+        metrics=model_and_optimizers.metrics,
+        n_epochs=2,
+        output_from_new_proc="all",
+        evaluate_every=-1,
+
+        torch_kwargs={
+            "non_blocking": False,
+            "set_grad_to_none": True
+        }
+
+    )
+
+    trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
\ No newline at end of file
diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py
index be04bcd3..a7eeeda6 100644
--- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py
+++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py
@@ -361,5 +361,32 @@ def test_torch_wo_auto_param_call(
         dist.destroy_process_group()
 
 
+# Test overfit_batches;
+@pytest.mark.torch
+@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
+@magic_argv_env_context
+def test_trainer_overfit_torch(
+        model_and_optimizers: TrainerParameters,
+        overfit_batches,
+        num_train_batch_per_epoch
+):
+    trainer = Trainer(
+        model=model_and_optimizers.model,
+        driver="torch",
+        device=0,
+        overfit_batches=overfit_batches,
+        optimizers=model_and_optimizers.optimizers,
+        train_dataloader=model_and_optimizers.train_dataloader,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
+        input_mapping=model_and_optimizers.input_mapping,
+        output_mapping=model_and_optimizers.output_mapping,
+        metrics=model_and_optimizers.metrics,
+        output_from_new_proc="all",
+        n_epochs=2,
+    )
+
+    trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)
+    if dist.is_initialized():
+        dist.destroy_process_group()
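
Below is a minimal usage sketch of the 'overfit_batches' parameter introduced by this patch. It is illustrative only, not part of the patch: it assumes a fastNLP 0.8-style torch setup in which the model defines a train_step that returns a dict containing a 'loss' entry, and ToyModel, the random dataset, and the hyper-parameters are placeholder names invented for the example.

    # Illustrative sketch; ToyModel and the random dataset below are placeholders.
    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    from fastNLP.core.controllers.trainer import Trainer

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 2)

        def train_step(self, x, y):
            # fastNLP's torch driver calls train_step and reads the loss from the returned dict.
            logits = self.linear(x)
            return {"loss": nn.functional.cross_entropy(logits, y)}

    # Batches are dicts so that fastNLP can match their keys to train_step's parameters.
    dataset = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
    train_dl = DataLoader(dataset, batch_size=8,
                          collate_fn=lambda items: {"x": torch.stack([i[0] for i in items]),
                                                    "y": torch.stack([i[1] for i in items])})

    model = ToyModel()
    trainer = Trainer(
        model=model,
        driver="torch",
        device="cpu",
        optimizers=torch.optim.Adam(model.parameters(), lr=1e-2),
        train_dataloader=train_dl,
        overfit_batches=2,   # cache only the first 2 batches and iterate over them every epoch; -1 keeps all batches
        n_epochs=20,
    )
    trainer.run()

With a small positive 'overfit_batches', OverfitDataLoader caches the first few batches of 'train_dataloader' and replays them every epoch, which is a quick way to confirm that the model and training loop can drive the loss towards zero before launching a full training run.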