Browse Source

添加了 overfit 的功能

tags/v1.0.0alpha
YWMditto 3 years ago
parent
commit
024fecfbf3
6 changed files with 146 additions and 26 deletions
  1. +31
    -21
      fastNLP/core/controllers/trainer.py
  2. +5
    -2
      fastNLP/core/dataloaders/__init__.py
  3. +39
    -2
      fastNLP/core/dataloaders/utils.py
  4. +2
    -1
      fastNLP/core/drivers/torch_driver/torch_driver.py
  5. +42
    -0
      tests/core/controllers/test_trainer_w_evaluator_torch.py
  6. +27
    -0
      tests/core/controllers/test_trainer_wo_evaluator_torch.py

+ 31
- 21
fastNLP/core/controllers/trainer.py View File

@@ -35,6 +35,7 @@ from fastNLP.envs import rank_zero_call
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.utils.exceptions import EarlyStopException
from fastNLP.core.dataloaders import OverfitDataLoader


class Trainer(TrainerEventTrigger):
@@ -356,6 +357,7 @@ class Trainer(TrainerEventTrigger):
optimizers,
device: Optional[Union[int, List[int], str]] = "cpu",
n_epochs: int = 20,
overfit_batches: int = 0,
evaluate_dataloaders=None,
batch_step_fn: Optional[Callable] = None,
evaluate_batch_step_fn: Optional[Callable] = None,
@@ -469,9 +471,6 @@ class Trainer(TrainerEventTrigger):
n_batches=n_batches
)

if metrics is None and evaluate_dataloaders is not None:
raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")

if metrics is not None and evaluate_dataloaders is None:
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.")

@@ -495,33 +494,44 @@ class Trainer(TrainerEventTrigger):
else:
_dist_sampler = None

self.dataloader = self.train_dataloader
self.driver.set_deterministic_dataloader(self.dataloader)

self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
reproducible=self.callback_manager._need_reproducible_sampler)
# 进行 overfit 相关的设置;
if overfit_batches != 0:
self.dataloader = OverfitDataLoader(self.dataloader, overfit_batches)
self.overfit_batches = overfit_batches

self.evaluator = None
self.monitor = monitor
self.larger_better = larger_better
if metrics is not None and evaluate_dataloaders is not None:
check_evaluate_every(evaluate_every)
progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为
if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。
progress_bar = progress_bar.name
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
progress_bar=progress_bar,
check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
if metrics is not None:
if overfit_batches != 0:
logger.warning("Notice you are trying to 'overfit' the model and also using 'metrics', it may cause error "
"because 'metrics' are prepared for 'evaluate_dataloaders', but now 'train_dataloader'.")
evaluate_dataloaders = self.dataloader
if evaluate_dataloaders is not None:
check_evaluate_every(evaluate_every)
progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为
if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。
progress_bar = progress_bar.name
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
progress_bar=progress_bar,
check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
else:
raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")

if train_fn is not None and not isinstance(train_fn, str):
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn)
self.train_fn = train_fn

self.dataloader = self.train_dataloader
self.driver.set_deterministic_dataloader(self.dataloader)

self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
reproducible=self.callback_manager._need_reproducible_sampler)

self.evaluate_batch_step_fn = evaluate_batch_step_fn
self.kwargs = kwargs



+ 5
- 2
fastNLP/core/dataloaders/__init__.py View File

@@ -7,10 +7,13 @@ __all__ = [
'prepare_paddle_dataloader',
'prepare_torch_dataloader',

"prepare_dataloader"
"prepare_dataloader",

"OverfitDataLoader"
]

from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader, MixDataLoader
from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
from .prepare_dataloader import prepare_dataloader
from .prepare_dataloader import prepare_dataloader
from .utils import OverfitDataLoader

+ 39
- 2
fastNLP/core/dataloaders/utils.py View File

@@ -1,4 +1,4 @@
from typing import Callable, Any, Union
from typing import Callable, Any, Union, Sequence
from abc import ABC
import inspect
import ast
@@ -6,7 +6,8 @@ import ast
from ..log import logger
from ..utils.cache_results import get_func_calls, truncate_start_blanks
__all__ = [
"indice_collate_wrapper"
"indice_collate_wrapper",
"OverfitDataLoader"
]


@@ -111,6 +112,42 @@ class HasLenGetitemType(ABC):
return NotImplemented


class OverfitDataLoader:
    """
    Minimal dataloader stand-in that replays a fixed list of batches.

    The first ``overfit_batches`` batches of ``dataloader`` are materialized once, at
    construction time, and the very same batch objects are yielded on every iteration,
    which lets a ``Trainer`` repeatedly fit identical data (an "overfit" check).

    :param dataloader: the real dataloader to take batches from. It is kept on the
        instance so that attributes not defined on this class are looked up on it.
    :param overfit_batches: how many batches to keep; ``-1`` keeps every batch.
    :raises TypeError: if ``overfit_batches`` is not an ``int``.
    :raises ValueError: if ``overfit_batches`` is negative and not ``-1``.
    """

    def __init__(self, dataloader, overfit_batches: int):
        # Function-scope import keeps this block self-contained.
        from itertools import islice

        if not isinstance(overfit_batches, int):
            raise TypeError("Parameter 'overfit_batches' can only be 'int' type, check the parameter you input into 'Trainer'.")
        if overfit_batches < -1:
            raise ValueError("Parameter 'overfit_batches' can only be '-1' when it is smaller than 0, and it means "
                             "that you use all the data to check whether it could be overfitted.")

        # The real dataloader stays attached so operations aimed at it keep working
        # through ``__getattr__``.
        self.dataloader = dataloader

        if overfit_batches > len(dataloader):
            logger.warning("Parameter 'overfit_batches' is bigger than the real length of 'train dataloader'.")

        # ``islice`` stops pulling from the dataloader as soon as enough batches have
        # been collected instead of walking through the remaining data for nothing.
        limit = None if overfit_batches == -1 else overfit_batches
        self.batches = list(islice(dataloader, limit))

    def __len__(self):
        """Return the number of cached batches."""
        return len(self.batches)

    def __iter__(self):
        """Yield the cached batches in the order they were collected."""
        yield from self.batches

    def __getattr__(self, item):
        """
        Look up attributes that are not defined here on the wrapped dataloader.

        ``__getattr__`` only runs after normal lookup has failed, so a request for
        ``dataloader`` itself means the instance is not initialized yet (e.g. during
        ``copy`` or unpickling). Delegating in that state would recurse forever, so
        an ``AttributeError`` is raised instead.
        """
        if item == "dataloader":
            raise AttributeError(item)
        return getattr(self.dataloader, item)



if __name__ == '__main__':
def demo(*args, **kwargs):
pass


+ 2
- 1
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -31,6 +31,7 @@ from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler
from fastNLP.core.dataloaders import OverfitDataLoader


class TorchDriver(Driver):
@@ -92,7 +93,7 @@ class TorchDriver(Driver):
self.grad_scaler.update()

def check_dataloader_legality(self, dataloader):
if not isinstance(dataloader, DataLoader):
if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader):
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
if len(dataloader) == 0:
logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "


+ 42
- 0
tests/core/controllers/test_trainer_w_evaluator_torch.py View File

@@ -286,6 +286,9 @@ def test_trainer_specific_params_1(
assert trainer.driver.non_blocking is False
assert trainer.driver.wo_auto_param_call is True

if dist.is_initialized():
dist.destroy_process_group()


@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1)
@@ -332,5 +335,44 @@ def test_trainer_specific_params_2(
assert _ddp_kwargs.get("broadcast_buffers") is True
assert _ddp_kwargs.get("find_unused_parameters") is True

if dist.is_initialized():
dist.destroy_process_group()


@pytest.mark.torch
@pytest.mark.parametrize(
    "overfit_batches,num_train_batch_per_epoch",
    [(-1, -1), (0, -1), (3, 10), (6, -1)],
)
@magic_argv_env_context
def test_trainer_w_evaluator_overfit_torch(
        model_and_optimizers: TrainerParameters,
        overfit_batches,
        num_train_batch_per_epoch,
):
    """
    Smoke-test ``Trainer`` with ``overfit_batches`` while evaluation is configured.

    A trainer is built for every parametrized ``overfit_batches`` value together with
    evaluate dataloaders and metrics, then run for two epochs; the test has no
    assertions, so it passes as long as nothing raises.
    """
    params = model_and_optimizers
    trainer_kwargs = dict(
        model=params.model,
        driver="torch",
        device=0,
        overfit_batches=overfit_batches,
        optimizers=params.optimizers,
        train_dataloader=params.train_dataloader,
        evaluate_dataloaders={"dl": params.evaluate_dataloaders},
        input_mapping=params.input_mapping,
        output_mapping=params.output_mapping,
        metrics=params.metrics,
        n_epochs=2,
        output_from_new_proc="all",
        evaluate_every=-1,
        torch_kwargs={
            "non_blocking": False,
            "set_grad_to_none": True,
        },
    )
    trainer = Trainer(**trainer_kwargs)

    trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)

    # Tear down the process group so later tests start from a clean state.
    if not dist.is_initialized():
        return
    dist.destroy_process_group()

+ 27
- 0
tests/core/controllers/test_trainer_wo_evaluator_torch.py View File

@@ -361,5 +361,32 @@ def test_torch_wo_auto_param_call(
dist.destroy_process_group()


# Test the `overfit_batches` option of Trainer (the previous comment mentioned
# accumulation_steps, which this test never sets).
@pytest.mark.torch
@pytest.mark.parametrize(
    "overfit_batches,num_train_batch_per_epoch",
    [(-1, -1), (0, -1), (3, 10), (6, -1)],
)
@magic_argv_env_context
def test_trainer_overfit_torch(
        model_and_optimizers: TrainerParameters,
        overfit_batches,
        num_train_batch_per_epoch,
):
    """
    Smoke-test ``Trainer`` with each parametrized ``overfit_batches`` value.

    The trainer is built from the fixture's model, optimizers, dataloaders, mappings
    and metrics, then run for two epochs; the test has no assertions, so it passes as
    long as nothing raises.
    """
    params = model_and_optimizers
    trainer_kwargs = dict(
        model=params.model,
        driver="torch",
        device=0,
        overfit_batches=overfit_batches,
        optimizers=params.optimizers,
        train_dataloader=params.train_dataloader,
        evaluate_dataloaders=params.evaluate_dataloaders,
        input_mapping=params.input_mapping,
        output_mapping=params.output_mapping,
        metrics=params.metrics,
        output_from_new_proc="all",
        n_epochs=2,
    )
    trainer = Trainer(**trainer_kwargs)

    trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)

    # Tear down the process group so later tests start from a clean state.
    if not dist.is_initialized():
        return
    dist.destroy_process_group()


Loading…
Cancel
Save