diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py
index 7ce9f153..06f09da3 100644
--- a/fastNLP/core/dataloaders/utils.py
+++ b/fastNLP/core/dataloaders/utils.py
@@ -1,3 +1,4 @@
+import os
 from typing import Callable, Any, Union, Sequence
 from abc import ABC
 import inspect
@@ -126,7 +127,7 @@ class OverfitDataLoader:
             logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.")
 
         for idx, batch in enumerate(dataloader):
-            if idx < self.overfit_batches or self.overfit_batches < -1:
+            if idx < self.overfit_batches or self.overfit_batches <= -1:
                 self.batches.append(batch)
 
     def __len__(self):
diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py
index a70766f5..78eff36c 100644
--- a/tests/core/controllers/test_trainer_w_evaluator_torch.py
+++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py
@@ -340,10 +340,13 @@ def test_trainer_specific_params_2(
 
 
 @pytest.mark.torch
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])
 @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
 @magic_argv_env_context
 def test_trainer_w_evaluator_overfit_torch(
         model_and_optimizers: TrainerParameters,
+        driver,
+        device,
         overfit_batches,
         num_train_batch_per_epoch
 ):
@@ -352,8 +355,8 @@ def test_trainer_w_evaluator_overfit_torch(
     """
     trainer = Trainer(
         model=model_and_optimizers.model,
-        driver="torch",
-        device=0,
+        driver=driver,
+        device=device,
         overfit_batches=overfit_batches,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py
index a7eeeda6..ce67814e 100644
--- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py
+++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py
@@ -363,17 +363,20 @@ def test_torch_wo_auto_param_call(
 
 # test accumulation_steps;
 @pytest.mark.torch
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])
 @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
 @magic_argv_env_context
 def test_trainer_overfit_torch(
         model_and_optimizers: TrainerParameters,
+        driver,
+        device,
         overfit_batches,
         num_train_batch_per_epoch
 ):
     trainer = Trainer(
         model=model_and_optimizers.model,
-        driver="torch",
-        device=0,
+        driver=driver,
+        device=device,
         overfit_batches=overfit_batches,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
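
The functional change above is the comparison operator in OverfitDataLoader: with the old `< -1` test, the sentinel value overfit_batches=-1 ("train on all batches") satisfied neither clause, so nothing was cached. Below is a minimal standalone sketch of the fixed selection logic, using a plain list in place of a real dataloader; the helper name collect_overfit_batches is illustrative, not fastNLP API.

def collect_overfit_batches(dataloader, overfit_batches):
    """Cache the first `overfit_batches` batches; any value <= -1 means cache all."""
    batches = []
    for idx, batch in enumerate(dataloader):
        # The old buggy test was `overfit_batches < -1`: the sentinel -1
        # matched neither clause, so "use all batches" collected nothing.
        if idx < overfit_batches or overfit_batches <= -1:
            batches.append(batch)
    return batches

fake_loader = list(range(10))                               # stands in for 10 batches
assert len(collect_overfit_batches(fake_loader, 3)) == 3    # first 3 batches only
assert len(collect_overfit_batches(fake_loader, -1)) == 10  # -1 -> every batch
assert len(collect_overfit_batches(fake_loader, 0)) == 0    # 0 -> overfitting disabled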
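The test changes stack a second @pytest.mark.parametrize on top of the existing one. pytest expands stacked parametrize decorators into a cross product, so each patched test now runs 2 (driver/device) x 4 (overfit settings) = 8 times, covering both the single-device case (device=1) and the multi-device case (device=[0, 1]). A toy sketch of that expansion, independent of fastNLP:

import pytest

# Stacked parametrize decorators multiply: pytest generates one test per
# combination, 2 x 4 = 8 cases here (the same shape as the patched tests).
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])
@pytest.mark.parametrize("overfit_batches", [-1, 0, 3, 6])
def test_cross_product(driver, device, overfit_batches):
    assert driver == "torch"
    assert overfit_batches in (-1, 0, 3, 6)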