
Add tests for overfit on multiple GPUs

tags/v1.0.0alpha
YWMditto 3 years ago
commit bb68856f85
3 changed files with 12 additions and 5 deletions:

  1. fastNLP/core/dataloaders/utils.py (+2, -1)
  2. tests/core/controllers/test_trainer_w_evaluator_torch.py (+5, -2)
  3. tests/core/controllers/test_trainer_wo_evaluator_torch.py (+5, -2)

fastNLP/core/dataloaders/utils.py (+2, -1)

@@ -1,3 +1,4 @@
+import os
 from typing import Callable, Any, Union, Sequence
 from abc import ABC
 import inspect
@@ -126,7 +127,7 @@ class OverfitDataLoader:
             logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.")
 
         for idx, batch in enumerate(dataloader):
-            if idx < self.overfit_batches or self.overfit_batches < -1:
+            if idx < self.overfit_batches or self.overfit_batches <= -1:
                 self.batches.append(batch)
 
     def __len__(self):
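
The one-character change above is the actual bug fix: with the old `< -1` test, passing `overfit_batches=-1` (the "use every batch" sentinel) collected nothing, since neither `idx < -1` nor `-1 < -1` is ever true. A minimal standalone sketch of the fixed logic (an illustration only, not fastNLP's actual OverfitDataLoader class):

# Sketch only: mirrors the fixed condition from the diff above, outside the class.
def collect_overfit_batches(dataloader, overfit_batches):
    batches = []
    for idx, batch in enumerate(dataloader):
        # overfit_batches <= -1 now matches the -1 sentinel, keeping all batches;
        # a positive value keeps only the first `overfit_batches` batches.
        if idx < overfit_batches or overfit_batches <= -1:
            batches.append(batch)
    return batches

assert len(collect_overfit_batches(range(10), -1)) == 10  # -1: keep everything
assert len(collect_overfit_batches(range(10), 3)) == 3    # N: first N batches
assert len(collect_overfit_batches(range(10), 0)) == 0    # 0: overfitting disabled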


tests/core/controllers/test_trainer_w_evaluator_torch.py (+5, -2)

@@ -340,10 +340,13 @@ def test_trainer_specific_params_2(
 
 
 @pytest.mark.torch
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
 @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
 @magic_argv_env_context
 def test_trainer_w_evaluator_overfit_torch(
         model_and_optimizers: TrainerParameters,
+        driver,
+        device,
         overfit_batches,
         num_train_batch_per_epoch
 ):
@@ -352,8 +355,8 @@ def test_trainer_w_evaluator_overfit_torch(
     """
     trainer = Trainer(
         model=model_and_optimizers.model,
-        driver="torch",
-        device=0,
+        driver=driver,
+        device=device,
         overfit_batches=overfit_batches,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
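
Worth noting how the new decorator composes with the existing one: stacked `@pytest.mark.parametrize` marks multiply, so each test above now runs 2 (driver/device) x 4 (overfit settings) = 8 times. A toy illustration of that cross product (hypothetical test, not from this repo):

import pytest

# Stacked parametrize marks yield the cross product of their value lists:
# this toy test collects 2 x 4 = 8 cases, mirroring the decorators above.
@pytest.mark.parametrize("device", [1, [0, 1]])
@pytest.mark.parametrize("overfit_batches", [-1, 0, 3, 6])
def test_cross_product(device, overfit_batches):
    assert isinstance(device, (int, list)) and overfit_batches >= -1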


tests/core/controllers/test_trainer_wo_evaluator_torch.py (+5, -2)

@@ -363,17 +363,20 @@ def test_torch_wo_auto_param_call(
 
 # Test accumulation_steps;
 @pytest.mark.torch
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])
 @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
 @magic_argv_env_context
 def test_trainer_overfit_torch(
         model_and_optimizers: TrainerParameters,
+        driver,
+        device,
         overfit_batches,
         num_train_batch_per_epoch
 ):
     trainer = Trainer(
         model=model_and_optimizers.model,
-        driver="torch",
-        device=0,
+        driver=driver,
+        device=device,
         overfit_batches=overfit_batches,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
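
The test body mirrors the change in the first test file: the hard-coded single device is replaced by the parametrized pair, so the same overfit settings are now exercised both on one GPU (`device=1`) and across two (`device=[0, 1]`). A hedged sketch of the resulting Trainer call (the builder function and its placeholder arguments are hypothetical; only the keyword arguments mirror the diff):

from fastNLP import Trainer  # assumed import path

def build_overfit_trainer(params, driver, device, overfit_batches):
    # `params` stands in for the model_and_optimizers fixture used in the tests;
    # device=1 targets a single CUDA device, device=[0, 1] a two-GPU run.
    return Trainer(
        model=params.model,
        driver=driver,
        device=device,
        overfit_batches=overfit_batches,
        optimizers=params.optimizers,
        train_dataloader=params.train_dataloader,
    )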

