From 024fecfbf3e6837ff19ba9e128e62df0f881c8aa Mon Sep 17 00:00:00 2001
From: YWMditto
Date: Thu, 16 Jun 2022 22:30:02 +0800
Subject: [PATCH 1/7] Add the overfit feature
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/controllers/trainer.py                | 52 +++++++++++++---------
 fastNLP/core/dataloaders/__init__.py               |  7 ++-
 fastNLP/core/dataloaders/utils.py                  | 41 ++++++++++++++++-
 fastNLP/core/drivers/torch_driver/torch_driver.py  |  3 +-
 .../controllers/test_trainer_w_evaluator_torch.py  | 42 +++++++++++++++++
 .../controllers/test_trainer_wo_evaluator_torch.py | 27 +++++++++++
 6 files changed, 146 insertions(+), 26 deletions(-)

diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index f92611dd..00a18f1d 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -35,6 +35,7 @@ from fastNLP.envs import rank_zero_call
 from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.utils.exceptions import EarlyStopException
+from fastNLP.core.dataloaders import OverfitDataLoader


 class Trainer(TrainerEventTrigger):
@@ -356,6 +357,7 @@ class Trainer(TrainerEventTrigger):
             optimizers,
             device: Optional[Union[int, List[int], str]] = "cpu",
             n_epochs: int = 20,
+            overfit_batches: int = 0,
             evaluate_dataloaders=None,
             batch_step_fn: Optional[Callable] = None,
             evaluate_batch_step_fn: Optional[Callable] = None,
@@ -469,9 +471,6 @@ class Trainer(TrainerEventTrigger):
                 n_batches=n_batches
             )

-        if metrics is None and evaluate_dataloaders is not None:
-            raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")
-
         if metrics is not None and evaluate_dataloaders is None:
             raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.")

@@ -495,33 +494,44 @@ class Trainer(TrainerEventTrigger):
         else:
             _dist_sampler = None

+        self.dataloader = self.train_dataloader
+        self.driver.set_deterministic_dataloader(self.dataloader)
+
+        self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
+                                                                reproducible=self.callback_manager._need_reproducible_sampler)
+
+        # Settings related to overfit;
+        if overfit_batches != 0:
+            self.dataloader = OverfitDataLoader(self.dataloader, overfit_batches)
+        self.overfit_batches = overfit_batches
+
         self.evaluator = None
         self.monitor = monitor
         self.larger_better = larger_better
-        if metrics is not None and evaluate_dataloaders is not None:
-            check_evaluate_every(evaluate_every)
-            progress_bar = kwargs.get('progress_bar', 'auto')  # If it is not a str or None,
-            if not (isinstance(progress_bar, str) or progress_bar is None):  # it should be a ProgressCallback; get its name.
-                progress_bar = progress_bar.name
-            self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
-                                       driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
-                                       evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
-                                       output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
-                                       use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
-                                       progress_bar=progress_bar,
-                                       check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
+        if metrics is not None:
+            if overfit_batches != 0:
+                logger.warning("Notice you are trying to 'overfit' the model while also using 'metrics'; this may cause an error "
+                               "because 'metrics' was prepared for 'evaluate_dataloaders', which is now replaced by 'train_dataloader'.")
+                evaluate_dataloaders = self.dataloader
+            if evaluate_dataloaders is not None:
+                check_evaluate_every(evaluate_every)
+                progress_bar = kwargs.get('progress_bar', 'auto')  # If it is not a str or None,
+                if not (isinstance(progress_bar, str) or progress_bar is None):  # it should be a ProgressCallback; get its name.
+                    progress_bar = progress_bar.name
+                self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
+                                           driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
+                                           evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
+                                           output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
+                                           use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
+                                           progress_bar=progress_bar,
+                                           check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
+            else:
+                raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.")

         if train_fn is not None and not isinstance(train_fn, str):
             raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
         self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn)
         self.train_fn = train_fn

-        self.dataloader = self.train_dataloader
-        self.driver.set_deterministic_dataloader(self.dataloader)
-
-        self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
-                                                                reproducible=self.callback_manager._need_reproducible_sampler)
-
         self.evaluate_batch_step_fn = evaluate_batch_step_fn
         self.kwargs = kwargs
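For reviewers, a minimal usage sketch of the new parameter follows. It assumes fastNLP's torch-driver conventions (a `train_step` returning a dict with a 'loss' key, and dict-shaped batches that map onto `train_step`'s parameters); the toy model and data are illustrative placeholders, not part of the patch.

# Hedged sketch: a toy setup exercising overfit_batches; everything except
# Trainer and the overfit_batches argument itself is a placeholder.
import torch
from torch.utils.data import DataLoader, TensorDataset
from fastNLP import Trainer

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def train_step(self, x, y):
        # fastNLP's torch driver calls 'train_step' and expects a dict with 'loss'
        return {"loss": torch.nn.functional.cross_entropy(self.fc(x), y)}

def collate(samples):
    xs, ys = zip(*samples)
    # dict-shaped batches let fastNLP map keys onto train_step's parameters
    return {"x": torch.stack(xs), "y": torch.stack(ys)}

dataset = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
model = TinyModel()
trainer = Trainer(
    model=model,
    driver="torch",
    device="cpu",
    optimizers=torch.optim.SGD(model.parameters(), lr=0.1),
    train_dataloader=DataLoader(dataset, batch_size=8, collate_fn=collate),
    overfit_batches=2,   # cache the first two batches and train only on them
    n_epochs=20,
)
trainer.run()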
'train_dataloader'.") + evaluate_dataloaders = self.dataloader + if evaluate_dataloaders is not None: + check_evaluate_every(evaluate_every) + progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为 + if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。 + progress_bar = progress_bar.name + self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, + driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn, + evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, + output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, + use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler), + progress_bar=progress_bar, + check_dataloader_legality=kwargs.get('check_dataloader_legality', True)) + else: + raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.") if train_fn is not None and not isinstance(train_fn, str): raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn) self.train_fn = train_fn - self.dataloader = self.train_dataloader - self.driver.set_deterministic_dataloader(self.dataloader) - - self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, - reproducible=self.callback_manager._need_reproducible_sampler) - self.evaluate_batch_step_fn = evaluate_batch_step_fn self.kwargs = kwargs diff --git a/fastNLP/core/dataloaders/__init__.py b/fastNLP/core/dataloaders/__init__.py index 84f8b288..b18e371c 100644 --- a/fastNLP/core/dataloaders/__init__.py +++ b/fastNLP/core/dataloaders/__init__.py @@ -7,10 +7,13 @@ __all__ = [ 'prepare_paddle_dataloader', 'prepare_torch_dataloader', - "prepare_dataloader" + "prepare_dataloader", + + "OverfitDataLoader" ] from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader, MixDataLoader from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader -from .prepare_dataloader import prepare_dataloader \ No newline at end of file +from .prepare_dataloader import prepare_dataloader +from .utils import OverfitDataLoader \ No newline at end of file diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py index d905101f..9f8b608c 100644 --- a/fastNLP/core/dataloaders/utils.py +++ b/fastNLP/core/dataloaders/utils.py @@ -1,4 +1,4 @@ -from typing import Callable, Any, Union +from typing import Callable, Any, Union, Sequence from abc import ABC import inspect import ast @@ -6,7 +6,8 @@ import ast from ..log import logger from ..utils.cache_results import get_func_calls, truncate_start_blanks __all__ = [ - "indice_collate_wrapper" + "indice_collate_wrapper", + "OverfitDataLoader" ] @@ -111,6 +112,42 @@ class HasLenGetitemType(ABC): return NotImplemented +class OverfitDataLoader: + """ + 实现一个简单的迭代器来模拟实际的 dataloader,从给定的 dataloader 中取出部分数据,来让 Trainer 实现 overfit 的功能; + """ + + def __init__(self, dataloader, overfit_batches: int): + self.dataloader = dataloader # 需要将实际的 dataloader 挂载到该对象上,从而应付一些对于实际的 dataloader 的操作; + self.batches = [] + + if isinstance(overfit_batches, int): + if overfit_batches < 0 and overfit_batches != -1: + raise ValueError("Parameter 'overfit_batches' can only be '-1' when it is smaller than 0, and it means" + "that you use all the data to check whether it could be overfitted.") + else: + 
diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index fb01b6c3..84e4aa70 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -31,6 +31,7 @@ from fastNLP.envs import rank_zero_call
 from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.log import logger
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler
+from fastNLP.core.dataloaders import OverfitDataLoader


 class TorchDriver(Driver):
@@ -92,7 +93,7 @@ class TorchDriver(Driver):
         self.grad_scaler.update()

     def check_dataloader_legality(self, dataloader):
-        if not isinstance(dataloader, DataLoader):
+        if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader):
             raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
         if len(dataloader) == 0:
             logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "
diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py
index 752e06d8..a70766f5 100644
--- a/tests/core/controllers/test_trainer_w_evaluator_torch.py
+++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py
@@ -286,6 +286,9 @@ def test_trainer_specific_params_1(
         assert trainer.driver.non_blocking is False
         assert trainer.driver.wo_auto_param_call is True

+    if dist.is_initialized():
+        dist.destroy_process_group()
+

 @pytest.mark.torch
 @pytest.mark.parametrize("driver,device", [("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
@@ -332,5 +335,44 @@ def test_trainer_specific_params_2(
         assert _ddp_kwargs.get("broadcast_buffers") is True
         assert _ddp_kwargs.get("find_unused_parameters") is True

+    if dist.is_initialized():
+        dist.destroy_process_group()
+

+@pytest.mark.torch
+@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
+@magic_argv_env_context
+def test_trainer_w_evaluator_overfit_torch(
+        model_and_optimizers: TrainerParameters,
+        overfit_batches,
+        num_train_batch_per_epoch
+):
+    """
+    Test that the overfit-related parameters are passed and handled correctly;
+    """
+    trainer = Trainer(
+        model=model_and_optimizers.model,
+        driver="torch",
+        device=0,
+        overfit_batches=overfit_batches,
+        optimizers=model_and_optimizers.optimizers,
+        train_dataloader=model_and_optimizers.train_dataloader,
+        evaluate_dataloaders={"dl": model_and_optimizers.evaluate_dataloaders},
+        input_mapping=model_and_optimizers.input_mapping,
+        output_mapping=model_and_optimizers.output_mapping,
+        metrics=model_and_optimizers.metrics,
+        n_epochs=2,
+        output_from_new_proc="all",
+        evaluate_every=-1,
+
+        torch_kwargs={
+            "non_blocking": False,
+            "set_grad_to_none": True
+        }
+
+    )
+
+    trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
\ No newline at end of file
diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py
index be04bcd3..a7eeeda6 100644
--- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py
+++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py
@@ -361,5 +361,32 @@ def test_torch_wo_auto_param_call(
         dist.destroy_process_group()


+# Test overfit_batches;
+@pytest.mark.torch
+@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
+@magic_argv_env_context
+def test_trainer_overfit_torch(
+        model_and_optimizers: TrainerParameters,
+        overfit_batches,
+        num_train_batch_per_epoch
+):
+    trainer = Trainer(
+        model=model_and_optimizers.model,
+        driver="torch",
+        device=0,
+        overfit_batches=overfit_batches,
+        optimizers=model_and_optimizers.optimizers,
+        train_dataloader=model_and_optimizers.train_dataloader,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
+        input_mapping=model_and_optimizers.input_mapping,
+        output_mapping=model_and_optimizers.output_mapping,
+        metrics=model_and_optimizers.metrics,
+        output_from_new_proc="all",
+        n_epochs=2,
+    )
+
+    trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)
+
+    if dist.is_initialized():
+        dist.destroy_process_group()

From a6fc5225cd20638af77756456881c21f53e9cd44 Mon Sep 17 00:00:00 2001
From: YWMditto
Date: Thu, 16 Jun 2022 22:37:43 +0800
Subject: [PATCH 2/7] Add documentation for overfit_batches
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/controllers/trainer.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index 00a18f1d..1259a38e 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -119,6 +119,19 @@ class Trainer(TrainerEventTrigger):
         For more details on using ``TorchDDPDriver``, see :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`.

     :param n_epochs: The total number of training epochs, 20 by default; you can also set the total number of batches to iterate via the ``n_batches`` parameter.
+    :param overfit_batches: Use this parameter to support 'overfitting'; the supported values are ``-1``, ``0``, or an integer greater than 0, meaning that this many
+        batches of data are used for overfit training; 0 is the default and means no overfitting is performed; -1 means all the data are used for training;
+
+        .. note::
+
+            You can use this parameter to quickly check whether your model is 'correct', i.e. whether it can converge rapidly on a small amount of data, which shows
+            that the loss function, optimizer, etc. are working. When this parameter is used, batches of a fixed size are taken directly from ``train_dataloader`` and used for overfit training in all subsequent epochs;
+
+        .. warning::
+
+            When using this parameter you may still specify ``metrics`` for a simple evaluation; when it and ``metrics`` appear together, evaluate_dataloaders
+            is replaced directly by the training data used for overfitting, so you must make sure that your ``metrics`` can be used on ``train_dataloader``;
+
     :param evaluate_dataloaders: The evaluation datasets; this can be a single dataset or multiple datasets; when multiple datasets are given, note that it must be a Dict; None
         by default;
     :param batch_step_fn: A customized function executed to run one batch of data forward during training. The function should accept two parameters, ``trainer`` and ``batch``,
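To make the documented interplay with ``metrics`` concrete, here is a hedged sketch; the model, optimizer, and dataloader are placeholders, and Accuracy stands in for any metric that works on the training batches.

# Sketch only: with overfit_batches set, fastNLP evaluates on the cached
# training batches, so the metric must fit the training data's format.
from fastNLP import Trainer, Accuracy

trainer = Trainer(
    model=model,                        # placeholder: model with train/evaluate steps
    driver="torch",
    device="cpu",
    optimizers=optimizer,               # placeholder optimizer
    train_dataloader=train_dataloader,  # placeholder dataloader
    overfit_batches=4,
    metrics={"acc": Accuracy()},        # runs on the 4 cached training batches
    evaluate_every=-1,                  # evaluate once per epoch
    n_epochs=30,
)
trainer.run()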
From 399065ae04652b2ef8c71c91b78d1053b853067d Mon Sep 17 00:00:00 2001
From: yhcc
Date: Fri, 17 Jun 2022 00:12:53 +0800
Subject: [PATCH 3/7] update overfit_batches

---
 fastNLP/core/callbacks/progress_callback.py | 19 +++++++-----
 fastNLP/core/controllers/trainer.py         | 32 ++++++++++----------
 fastNLP/core/dataloaders/utils.py           | 15 +++-------
 fastNLP/core/drivers/torch_driver/ddp.py    |  3 --
 fastNLP/core/drivers/torch_driver/utils.py  | 46 +++++++++++++++--------------
 fastNLP/core/metrics/metric.py              | 10 ++++++-
 6 files changed, 64 insertions(+), 61 deletions(-)

diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py
index 890864ec..2f1d2b17 100644
--- a/fastNLP/core/callbacks/progress_callback.py
+++ b/fastNLP/core/callbacks/progress_callback.py
@@ -20,6 +20,7 @@ class ProgressCallback(HasMonitorCallback):
                          must_have_monitor=must_have_monitor)
         self.best_monitor_epoch = -1
         self.best_monitor_step = -1
+        self.best_results = None

     def record_better_monitor(self, trainer):
         self.best_monitor_step = trainer.global_forward_batches
@@ -29,6 +30,8 @@ class ProgressCallback(HasMonitorCallback):
         if self.best_monitor_epoch != -1:
             msg = f"The best performance for monitor {self._real_monitor}:{self.monitor_value} was achieved in" \
                   f" Epoch:{self.best_monitor_epoch}, Global Batch:{self.best_monitor_step}."
+            if self.best_results is not None:
+                msg = msg + ' The evaluation result: \n' + str(self.best_results)
             logger.info(msg)

     @property
@@ -147,9 +150,11 @@ class RichCallback(ProgressCallback):
         results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
                    not key.startswith('_')}
         if self.format_json:
-            self.progress_bar.console.print_json(json.dumps(results))
+            results = json.dumps(results)
+            self.progress_bar.console.print_json(results)
         else:
             self.progress_bar.print(results)
+        self.best_results = results

     def clear_tasks(self):
         for key, taskid in self.task2id.items():
@@ -227,9 +232,9 @@ class RawTextCallback(ProgressCallback):
         results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
                    not key.startswith('_')}
         if self.format_json:
-            logger.info(json.dumps(results))
-        else:
-            logger.info(results)
+            results = json.dumps(results)
+        logger.info(results)
+        self.best_results = results

     @property
     def name(self):  # the name of the progress bar
@@ -316,9 +321,9 @@ class TqdmCallback(ProgressCallback):
         results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
                    not key.startswith('_')}
         if self.format_json:
-            logger.info(json.dumps(results))
-        else:
-            logger.info(results)
+            results = json.dumps(results)
+        logger.info(results)
+        self.best_results = results

     def clear_tasks(self):
         for key, taskid in self.task2id.items():
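The progress-callback change boils down to: remember whatever evaluation results were last printed, and echo them together with the best-monitor summary when training ends. A reduced sketch of that bookkeeping pattern, independent of fastNLP's actual classes:

import json

class ResultTracker:
    """Minimal stand-in for the bookkeeping ProgressCallback now does (sketch)."""
    def __init__(self, format_json: bool = True):
        self.format_json = format_json
        self.best_results = None

    def on_evaluate_end(self, results: dict) -> None:
        # drop private keys, then store exactly what was printed
        results = {k: v for k, v in results.items() if not k.startswith('_')}
        if self.format_json:
            results = json.dumps(results)
        print(results)
        self.best_results = results

    def on_train_end(self) -> None:
        if self.best_results is not None:
            print('The evaluation result: \n' + str(self.best_results))

tracker = ResultTracker()
tracker.on_evaluate_end({"acc": 0.75, "_step": 10})
tracker.on_train_end()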
diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index 1259a38e..0f22e63c 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -119,19 +119,6 @@ class Trainer(TrainerEventTrigger):
         For more details on using ``TorchDDPDriver``, see :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`.

     :param n_epochs: The total number of training epochs, 20 by default; you can also set the total number of batches to iterate via the ``n_batches`` parameter.
-    :param overfit_batches: Use this parameter to support 'overfitting'; the supported values are ``-1``, ``0``, or an integer greater than 0, meaning that this many
-        batches of data are used for overfit training; 0 is the default and means no overfitting is performed; -1 means all the data are used for training;
-
-        .. note::
-
-            You can use this parameter to quickly check whether your model is 'correct', i.e. whether it can converge rapidly on a small amount of data, which shows
-            that the loss function, optimizer, etc. are working. When this parameter is used, batches of a fixed size are taken directly from ``train_dataloader`` and used for overfit training in all subsequent epochs;
-
-        .. warning::
-
-            When using this parameter you may still specify ``metrics`` for a simple evaluation; when it and ``metrics`` appear together, evaluate_dataloaders
-            is replaced directly by the training data used for overfitting, so you must make sure that your ``metrics`` can be used on ``train_dataloader``;
-
     :param evaluate_dataloaders: The evaluation datasets; this can be a single dataset or multiple datasets; when multiple datasets are given, note that it must be a Dict; None
         by default;
     :param batch_step_fn: A customized function executed to run one batch of data forward during training. The function should accept two parameters, ``trainer`` and ``batch``,
@@ -258,7 +245,20 @@ class Trainer(TrainerEventTrigger):

         Note that this parameter only takes effect when the ``Evaluator`` built into ``Trainer`` is not None and there are *callback* instances that need this parameter but do not set it;

-    :param n_batches: The number of batches to iterate before training ends. When this value is not -1, the value of ``n_epochs`` is ignored.
+    :param n_batches: The total number of batches to iterate before training ends. When this value is not -1, the value of ``n_epochs`` is ignored.
+    :param overfit_batches: Use this parameter to support 'overfitting'; the supported values are ``-1``, ``0``, or an integer greater than 0, meaning that this many
+        batches of data are used for overfit training; 0 means no such operation is performed; -1 means all the data are used for training;
+
+        .. note::
+
+            You can use this parameter to quickly check whether your model is 'correct', i.e. whether it can converge rapidly on a small amount of data, which shows
+            that the loss function, optimizer, etc. are working. When this parameter is used, a fixed number of batches is taken directly from ``train_dataloader`` and then used for training in all epochs;
+
+        .. warning::
+
+            When using this parameter you may still specify ``metrics`` for a simple evaluation; when it and ``metrics`` appear together, evaluate_dataloaders
+            is replaced directly by the training data used for overfitting, so you must make sure that your ``metrics`` can be used on ``train_dataloader``;

     :param marker: Used to mark a ``Trainer`` instance, so that when the user calls ``Trainer.on``, the function is associated with that specific ``Trainer``
         instance; None by default;
@@ -370,7 +370,6 @@ class Trainer(TrainerEventTrigger):
             optimizers,
             device: Optional[Union[int, List[int], str]] = "cpu",
             n_epochs: int = 20,
-            overfit_batches: int = 0,
             evaluate_dataloaders=None,
             batch_step_fn: Optional[Callable] = None,
             evaluate_batch_step_fn: Optional[Callable] = None,
@@ -387,6 +386,7 @@ class Trainer(TrainerEventTrigger):
             monitor: Union[str, Callable] = None,
             larger_better: bool = True,
             n_batches: int = -1,
+            overfit_batches: int = 0,
             marker: Optional[str] = None,
             **kwargs
     ):
@@ -522,8 +522,6 @@ class Trainer(TrainerEventTrigger):
         self.larger_better = larger_better
         if metrics is not None:
             if overfit_batches != 0:
-                logger.warning("Notice you are trying to 'overfit' the model while also using 'metrics'; this may cause an error "
-                               "because 'metrics' was prepared for 'evaluate_dataloaders', which is now replaced by 'train_dataloader'.")
                 evaluate_dataloaders = self.dataloader
             if evaluate_dataloaders is not None:
                 check_evaluate_every(evaluate_every)
diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py
index 9f8b608c..7ce9f153 100644
--- a/fastNLP/core/dataloaders/utils.py
+++ b/fastNLP/core/dataloaders/utils.py
@@ -120,20 +120,13 @@ class OverfitDataLoader:
     def __init__(self, dataloader, overfit_batches: int):
         self.dataloader = dataloader  # Attach the real dataloader to this object so that operations aimed at the real dataloader still work;
         self.batches = []
+        self.overfit_batches = int(overfit_batches)

-        if isinstance(overfit_batches, int):
-            if overfit_batches < 0 and overfit_batches != -1:
-                raise ValueError("Parameter 'overfit_batches' can only be '-1' when it is smaller than 0, and it means "
-                                 "that you use all the data to check whether it could be overfitted.")
-        else:
-            raise TypeError("Parameter 'overfit_batches' can only be 'int' type, check the parameter you input into 'Trainer'.")
-
-        if overfit_batches > len(dataloader):
-            logger.warning("Parameter 'overfit_batches' is bigger than the real length of 'train dataloader'.")
+        if self.overfit_batches > len(dataloader):
+            logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.")

         for idx, batch in enumerate(dataloader):
-
-            if idx < overfit_batches or overfit_batches == -1:
+            if idx < self.overfit_batches or self.overfit_batches < -1:
                 self.batches.append(batch)

     def __len__(self):
diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py
index 45a1a61a..affb5ded 100644
--- a/fastNLP/core/drivers/torch_driver/ddp.py
+++ b/fastNLP/core/drivers/torch_driver/ddp.py
@@ -140,9 +140,6 @@ if _NEED_IMPORT_TORCH:
     import torch.distributed as dist
     from torch.nn.parallel import DistributedDataParallel
     from torch.utils.data import BatchSampler
-    from torch.utils.data import RandomSampler as TorchRandomSampler
-    from torch.utils.data import SequentialSampler as TorchSequentialSampler
-    from torch.utils.data import BatchSampler as TorchBatchSampler

 __all__ = [
     'TorchDDPDriver'
diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py
index f5a76a9e..9bf0da2d 100644
--- a/fastNLP/core/drivers/torch_driver/utils.py
+++ b/fastNLP/core/drivers/torch_driver/utils.py
@@ -181,18 +181,16 @@ def replace_sampler(dataloader: "DataLoader", sampler):
     instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')}

     # 'multiprocessing_context' is a user-defined function;
-    instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context
+    if getattr(dataloader, 'multiprocessing_context', None) is not None:
+        instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context

     # Get the default signature of the dataloader's '__init__' function;
     init_params = dict(inspect.signature(dataloader.__init__).parameters)

-    # The reason this is handled separately is that a user customizing their own dataloader may, for convenience, set only a few parameters explicitly and pass
-    # the rest via **kwargs; if they then add some other new parameters when initializing their dataloader instance (first, this step is necessary, because it is
-    # the only way we can add a sampler; second, the user may indeed add new parameters via **kwargs), and assuming the user writes "super().__init__(**kwargs)",
-    # then we can only look them up in DataLoader;
+    # Guard against the user's DataLoader inheriting from pytorch's DataLoader while still passing arguments to the parent class via **kwargs
     has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
-    if has_variadic_kwargs:
-        # This is written this way because parameters with the same name in a user-customized Dataloader may have different default values, so we cannot simply overwrite them with update;
+    if has_variadic_kwargs and isinstance(dataloader, DataLoader):
+        # Guard against the user having written super().__init__(**kwargs)
         for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items():
             if key not in init_params and key != 'self':
                 init_params[key] = value
@@ -204,7 +202,8 @@ def replace_sampler(dataloader: "DataLoader", sampler):
         non_default_params.add("dataset")

     reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
-    reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})
+    if isinstance(dataloader, DataLoader):
+        reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})

     batch_sampler = getattr(dataloader, "batch_sampler")
     if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
@@ -218,35 +217,31 @@ def replace_sampler(dataloader: "DataLoader", sampler):
         and p.name not in reconstruct_args
     }

-    # This error targets __init__ parameters that were not attached to self under the same name;
+    # These parameters were not found among the attributes, so there is no way to reinitialize;
     if required_args:
         required_args = sorted(required_args)
         dataloader_self_name = dataloader.__class__.__name__
         raise Exception(
-            f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. "
-            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
-            f"The missing attributes are {required_args}. "
-            f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or "
-            "manually add the `DistributedSampler` as: "
-            f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`."
+            f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. "
+            f"But they are not found in the attributes of `{dataloader_self_name}`, so fastNLP cannot determine their "
+            f"values when trying to reinitialize `{dataloader_self_name}`; please make `{required_args}` "
+            f"attributes of `{dataloader_self_name}`."
         )

     # This error targets a dataloader that is not a plain DataLoader but a customized one whose __init__ has no **kwargs;
     if not has_variadic_kwargs:
-        # the dataloader signature does not allow keyword arguments that need to be passed
         missing_kwargs = reconstruct_args.keys() - init_params.keys()
         if missing_kwargs:
             missing_kwargs = sorted(missing_kwargs)
             dataloader_self_name = dataloader.__class__.__name__
             raise Exception(
-                f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. "
-                "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
-                f"The missing arguments are {missing_kwargs}. "
-                f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or "
-                "manually add the `DistributedSampler` as: "
-                f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`."
+                f"The parameters {missing_kwargs} needed to reinitialize `{dataloader_self_name}` are not found."
             )
+        # If there are no kwargs, make sure that only the needed parameters are passed
+        if not isinstance(dataloader, DataLoader):
+            reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params}

     return type(dataloader)(**reconstruct_args)
@@ -260,6 +255,13 @@ def replace_batch_sampler(dataloader, new_batch_sampler):
         params_keys.remove(k)
     params = {k: getattr(dataloader, k) for k in params_keys}
     params["batch_sampler"] = new_batch_sampler
+
+    if not isinstance(dataloader, DataLoader):
+        init_params = dict(inspect.signature(dataloader.__init__).parameters)
+        has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
+        if not has_variadic_kwargs:
+            params = {key:value for key,value in params.items() if key in init_params}
+
     return type(dataloader)(**params)
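The reinitialization logic above boils down to: read the constructor's signature, pull matching values from the instance's attributes, apply the overrides, and call the class again. A standalone sketch of that pattern (simplified; this is not fastNLP's actual helper):

import inspect

def rebuild_with(obj, **overrides):
    """Re-instantiate obj, filling __init__ parameters from attributes of
    the same name and applying overrides (simplified sketch)."""
    init_params = inspect.signature(type(obj).__init__).parameters
    kwargs = {}
    for name, param in init_params.items():
        if name == 'self' or param.kind is param.VAR_KEYWORD:
            continue
        if name in overrides:
            kwargs[name] = overrides[name]
        elif hasattr(obj, name):
            kwargs[name] = getattr(obj, name)
        elif param.default is param.empty:
            # mirrors the patch: a required argument we cannot recover from attributes
            raise Exception(f"Need to inject argument {name!r} into the __init__ of "
                            f"`{type(obj).__name__}`, but it is not found in its attributes.")
    return type(obj)(**kwargs)

class MyLoader:
    def __init__(self, dataset, sampler=None):
        self.dataset = dataset
        self.sampler = sampler

loader = rebuild_with(MyLoader([1, 2, 3]), sampler='new-sampler')
assert loader.sampler == 'new-sampler'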
" - "This would fail as some of the `__init__` arguments are not available as instance attributes. " - f"The missing attributes are {required_args}. " - f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or " - "manually add the `DistributedSampler` as: " - f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`." + f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. " + f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its " + f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be " + f"`{dataloader_self_name}`'s attribute." ) # 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; if not has_variadic_kwargs: - # the dataloader signature does not allow keyword arguments that need to be passed missing_kwargs = reconstruct_args.keys() - init_params.keys() if missing_kwargs: missing_kwargs = sorted(missing_kwargs) dataloader_self_name = dataloader.__class__.__name__ raise Exception( - f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. " - "This would fail as it doesn't expose all its attributes in the `__init__` signature. " - f"The missing arguments are {missing_kwargs}. " - f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or " - "manually add the `DistributedSampler` as: " - f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`." + f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found." ) + # 如果没有kwargs,则保证一下只传入需要的参数 + if not isinstance(dataloader, DataLoader): + reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params} + return type(dataloader)(**reconstruct_args) @@ -260,6 +255,13 @@ def replace_batch_sampler(dataloader, new_batch_sampler): params_keys.remove(k) params = {k: getattr(dataloader, k) for k in params_keys} params["batch_sampler"] = new_batch_sampler + + if not isinstance(dataloader, DataLoader): + init_params = dict(inspect.signature(dataloader.__init__).parameters) + has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) + if not has_variadic_kwargs: + params = {key:value for key,value in params.items() if key in init_params} + return type(dataloader)(**params) diff --git a/fastNLP/core/metrics/metric.py b/fastNLP/core/metrics/metric.py index 1a69e80c..e2dc3dda 100644 --- a/fastNLP/core/metrics/metric.py +++ b/fastNLP/core/metrics/metric.py @@ -98,7 +98,7 @@ class Metric: return _wrap_get_metric def __setattr__(self, key, value): - if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True: + if getattr(self, '_cannot_change_element', False): if key in self.elements and isinstance(value, (float, int, bool)): self.elements[key].fill_value(value) return @@ -109,6 +109,14 @@ class Metric: raise RuntimeError("Please use register_element() function to add Element.") object.__setattr__(self, key, value) + # 当调用 __getattribute__ 没有找到时才会触发这个, 保留这个的目的只是为了防止 ide 的 warning + def __getattr__(self, name: str) -> Element: + if 'elements' in self.__dict__: + elements = self.__dict__['elements'] + if name in elements: + return elements[name] + raise AttributeError("`{}` object has no attribute `{}`.".format(type(self).__name__, name)) + def _wrap_update(self, update): @functools.wraps(update) def _wrap_update(*args, **kwargs): From bb68856f85483cada933c8903f8264e8638c3e53 
Mon Sep 17 00:00:00 2001
From: YWMditto
Date: Fri, 17 Jun 2022 00:53:13 +0800
Subject: [PATCH 4/7] Add multi-GPU tests for overfit
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/dataloaders/utils.py                         | 3 ++-
 tests/core/controllers/test_trainer_w_evaluator_torch.py  | 7 +++++--
 tests/core/controllers/test_trainer_wo_evaluator_torch.py | 7 +++++--
 3 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py
index 7ce9f153..06f09da3 100644
--- a/fastNLP/core/dataloaders/utils.py
+++ b/fastNLP/core/dataloaders/utils.py
@@ -1,3 +1,4 @@
+import os
 from typing import Callable, Any, Union, Sequence
 from abc import ABC
 import inspect
@@ -126,7 +127,7 @@ class OverfitDataLoader:
             logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.")

         for idx, batch in enumerate(dataloader):
-            if idx < self.overfit_batches or self.overfit_batches < -1:
+            if idx < self.overfit_batches or self.overfit_batches <= -1:
                 self.batches.append(batch)

     def __len__(self):
diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py
index a70766f5..78eff36c 100644
--- a/tests/core/controllers/test_trainer_w_evaluator_torch.py
+++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py
@@ -340,10 +340,13 @@


 @pytest.mark.torch
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
 @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
 @magic_argv_env_context
 def test_trainer_w_evaluator_overfit_torch(
         model_and_optimizers: TrainerParameters,
+        driver,
+        device,
         overfit_batches,
         num_train_batch_per_epoch
 ):
@@ -352,8 +355,8 @@
     """
     trainer = Trainer(
         model=model_and_optimizers.model,
-        driver="torch",
-        device=0,
+        driver=driver,
+        device=device,
         overfit_batches=overfit_batches,
         optimizers=model_and_optimizers.optimizers,
diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py
index a7eeeda6..ce67814e 100644
--- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py
+++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py
@@ -363,17 +363,20 @@

 # Test overfit_batches;
 @pytest.mark.torch
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])
 @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
 @magic_argv_env_context
 def test_trainer_overfit_torch(
         model_and_optimizers: TrainerParameters,
+        driver,
+        device,
         overfit_batches,
         num_train_batch_per_epoch
 ):
     trainer = Trainer(
         model=model_and_optimizers.model,
-        driver="torch",
-        device=0,
+        driver=driver,
+        device=device,
         overfit_batches=overfit_batches,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
From dca3377129c834c1d1aeb7e8038dcda819d58a90 Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Fri, 17 Jun 2022 21:58:50 +0800
Subject: [PATCH 5/7] Set the RANK environment variable in ddp
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/drivers/torch_driver/ddp.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py
index 45a1a61a..ae85f0b6 100644
--- a/fastNLP/core/drivers/torch_driver/ddp.py
+++ b/fastNLP/core/drivers/torch_driver/ddp.py
@@ -421,6 +421,7 @@ class TorchDDPDriver(TorchDriver):

         os.environ['MASTER_ADDR'] = self.master_address
         os.environ['MASTER_PORT'] = self.master_port
+        os.environ["RANK"] = "0"

         os.environ["LOCAL_RANK"] = str(self.local_rank)
         os.environ["WORLD_SIZE"] = f"{self.world_size}"
@@ -433,6 +434,7 @@ class TorchDDPDriver(TorchDriver):
         for rank in range(1, len(self.parallel_device)):
             env_copy = os.environ.copy()
             env_copy["LOCAL_RANK"] = f"{rank}"
+            env_copy["RANK"] = f"{rank}"

             # In the multi-node case the user must launch the processes themselves, so for processes we start via open_subprocesses, FASTNLP_GLOBAL_RANK is always LOCAL_RANK;
             env_copy[FASTNLP_GLOBAL_RANK] = str(rank)
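Patch 5 fills in the RANK variable that torch.distributed's env:// rendezvous reads alongside MASTER_ADDR, MASTER_PORT, and WORLD_SIZE. A sketch of the environment each spawned process now sees (the address and port values are illustrative only):

import os

def ddp_env_for(rank: int, world_size: int,
                addr: str = "127.0.0.1", port: str = "29500") -> dict:
    """Illustrative: the env:// rendezvous variables set for each process."""
    env = os.environ.copy()
    env.update({
        "MASTER_ADDR": addr,
        "MASTER_PORT": port,
        "WORLD_SIZE": str(world_size),
        "RANK": str(rank),        # newly set by this patch
        "LOCAL_RANK": str(rank),  # single-node case: local rank == global rank
    })
    return env

for rank in range(2):
    env = ddp_env_for(rank, world_size=2)
    print(rank, env["RANK"], env["LOCAL_RANK"])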
From 64da46b613547a5768e6b56ffe83ab11ac1caf60 Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Fri, 17 Jun 2022 23:23:33 +0800
Subject: [PATCH 6/7] Follow up on replace_batch_sampler and check_dataloader for paddle
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../core/drivers/jittor_driver/jittor_driver.py |  3 ++-
 fastNLP/core/drivers/jittor_driver/utils.py     |  4 +++
 .../core/drivers/paddle_driver/paddle_driver.py |  3 ++-
 fastNLP/core/drivers/paddle_driver/utils.py     | 29 ++++++++++++++--------
 4 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py
index c2e338bb..312f0d83 100644
--- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py
+++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py
@@ -6,6 +6,7 @@ from dataclasses import dataclass
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
 from fastNLP.core.drivers.driver import Driver
 from fastNLP.core.dataloaders import JittorDataLoader
+from fastNLP.core.dataloaders import OverfitDataLoader
 from fastNLP.core.samplers import ReproducibleSampler, RandomSampler
 from fastNLP.core.log import logger
 from fastNLP.core.utils import apply_to_collection, nullcontext
@@ -69,7 +70,7 @@ class JittorDriver(Driver):
         self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)

     def check_dataloader_legality(self, dataloader):
-        if not isinstance(dataloader, (Dataset, JittorDataLoader)):
+        if not isinstance(dataloader, (Dataset, JittorDataLoader, OverfitDataLoader)):
             raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`")
         if len(dataloader) == 0:
             logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "
diff --git a/fastNLP/core/drivers/jittor_driver/utils.py b/fastNLP/core/drivers/jittor_driver/utils.py
index c75526df..af840a09 100644
--- a/fastNLP/core/drivers/jittor_driver/utils.py
+++ b/fastNLP/core/drivers/jittor_driver/utils.py
@@ -14,6 +14,7 @@ from fastNLP.envs import (
     FASTNLP_BACKEND_LAUNCH,
     FASTNLP_GLOBAL_SEED,
 )
+from fastNLP.core.samplers import ReproducibleBatchSampler
 from fastNLP.core.log import logger

 if _NEED_IMPORT_JITTOR:
@@ -63,6 +64,9 @@ def replace_batch_sampler(dataloader, batch_sampler):
                          "or report this bug to us.")

 def replace_sampler(dataloader: Union["Dataset", "JittorDataLoader"], sampler):
+    batch_sampler = getattr(dataloader, "sampler")
+    if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
+        raise RuntimeError("It should not be running here, please report a bug to us.")
     if isinstance(dataloader, JittorDataLoader):
         init_params = dict(inspect.signature(dataloader.__init__).parameters)
         reconstruct_args = {name: getattr(dataloader, name, p.default) for name, p in init_params.items()}
diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py
index 6ef0aaae..bfc26350 100644
--- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py
+++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py
@@ -19,6 +19,7 @@ from fastNLP.envs import (
     rank_zero_call,
 )
 from fastNLP.core.log import logger
+from fastNLP.core.dataloaders import OverfitDataLoader
 from fastNLP.core.samplers import (
     ReproducibleBatchSampler,
     ReproducibleSampler,
@@ -93,7 +94,7 @@ class PaddleDriver(Driver):
         self.grad_scaler.update()

     def check_dataloader_legality(self, dataloader):
-        if not isinstance(dataloader, DataLoader):
+        if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader):
             raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
         if dataloader.batch_size is None and dataloader.batch_sampler is None:
             raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler"
diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py
index 1191b60c..be83e5fe 100644
--- a/fastNLP/core/drivers/paddle_driver/utils.py
+++ b/fastNLP/core/drivers/paddle_driver/utils.py
@@ -15,6 +15,7 @@ from fastNLP.envs import (
     FASTNLP_BACKEND_LAUNCH,
     FASTNLP_GLOBAL_SEED,
 )
+from fastNLP.core.samplers import ReproducibleBatchSampler
 from fastNLP.core.utils import auto_param_call, paddle_to
 from fastNLP.core.log import logger

@@ -129,7 +130,7 @@ def _build_fp16_env(dummy=False):
                 "NOTE: your device does NOT support faster training with fp16, "
                 "please switch to FP32 which is likely to be faster"
             )
-    return auto_cast, GradScaler
+    return auto_cast, GradScaler

 def find_free_ports(num):
     """
@@ -189,10 +190,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"):
         non_default_params.add("dataset")

     reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
-    reconstruct_args.update({
-        "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1,
-        "persistent_workers": dataloader._persistent_workers,
-    })
+    if isinstance(dataloader, DataLoader):
+        reconstruct_args.update({
+            "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1,
+            "persistent_workers": dataloader._persistent_workers,
+        })

     # POSITIONAL_OR_KEYWORD represents ordinary parameters
     # Collect the parameters that appear in the init function in ordinary form, have no default value, and are not in reconstruct_args
@@ -210,9 +212,10 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"):
         required_args = sorted(required_args)
         dataloader_self_name = dataloader.__class__.__name__
         raise Exception(
-            f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. "
-            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
-            f"The missing attributes are {required_args}. "
+            f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. "
+            f"But they are not found in the attributes of `{dataloader_self_name}`, so fastNLP cannot determine their "
+            f"values when trying to reinitialize `{dataloader_self_name}`; please make `{required_args}` "
+            f"attributes of `{dataloader_self_name}`."
         )

     # This error targets a dataloader that is not a plain DataLoader but a customized one whose __init__ has no **kwargs;
@@ -224,10 +227,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"):
         missing_kwargs = sorted(missing_kwargs)
         dataloader_self_name = dataloader.__class__.__name__
         raise Exception(
-            f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. "
-            "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
-            f"The missing arguments are {missing_kwargs}. "
+            f"The parameters {missing_kwargs} needed to reinitialize `{dataloader_self_name}` are not found."
         )
+    # If there are no kwargs, make sure that only the needed parameters are passed
+    if not isinstance(dataloader, DataLoader):
+        reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params}

     return type(dataloader)(**reconstruct_args)

@@ -235,6 +239,9 @@ def replace_sampler(dataloader, new_sampler):
     """
     Rebuild a ``BatchSampler`` with ``new_sampler`` and substitute it into ``dataloader``
     """
+    batch_sampler = getattr(dataloader, "batch_sampler")
+    if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
+        raise RuntimeError("It should not be running here, please report a bug to us.")
     new_batch_sampler = deepcopy(dataloader.batch_sampler)
     new_batch_sampler.sampler = new_sampler
     return replace_batch_sampler(dataloader, new_batch_sampler)
" + f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its " + f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be " + f"`{dataloader_self_name}`'s attribute." ) # 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; @@ -224,10 +227,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler missing_kwargs = sorted(missing_kwargs) dataloader_self_name = dataloader.__class__.__name__ raise Exception( - f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " - "This would fail as it doesn't expose all its attributes in the `__init__` signature. " - f"The missing arguments are {missing_kwargs}. " + f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found." ) + # 如果没有kwargs,则保证一下只传入需要的参数 + if not isinstance(dataloader, DataLoader): + reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params} return type(dataloader)(**reconstruct_args) @@ -235,6 +239,9 @@ def replace_sampler(dataloader, new_sampler): """ 使用 ``new_sampler`` 重新构建一个 ``BatchSampler``,并替换到 ``dataloader`` 中 """ + batch_sampler = getattr(dataloader, "batch_sampler") + if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler): + raise RuntimeError("It should not be running here, please report a bug to us.") new_batch_sampler = deepcopy(dataloader.batch_sampler) new_batch_sampler.sampler = new_sampler return replace_batch_sampler(dataloader, new_batch_sampler) From b60621f3d1ddace2588535050f9855ba92fde068 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Fri, 17 Jun 2022 23:23:43 +0800 Subject: [PATCH 7/7] small --- .../core/drivers/paddle_driver/test_initialize_paddle_driver.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py b/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py index 7e567c84..63124cdc 100644 --- a/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py +++ b/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py @@ -1,3 +1,5 @@ +import os + import pytest from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver @@ -40,9 +42,14 @@ def test_get_fleet(device): """ 测试 fleet 多卡的初始化情况 """ - + flag = False + if "USER_CUDA_VISIBLE_DEVICES" not in os.environ: + os.environ["USER_CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + flag = True model = PaddleNormalModel_Classification_1(20, 10) driver = initialize_paddle_driver("paddle", device, model) + if flag: + del os.environ["USER_CUDA_VISIBLE_DEVICES"] assert isinstance(driver, PaddleFleetDriver)