|
|
@@ -340,10 +340,13 @@ def test_trainer_specific_params_2( |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
|
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1) |
|
|
|
@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)]) |
|
|
|
@magic_argv_env_context |
|
|
|
def test_trainer_w_evaluator_overfit_torch( |
|
|
|
model_and_optimizers: TrainerParameters, |
|
|
|
driver, |
|
|
|
device, |
|
|
|
overfit_batches, |
|
|
|
num_train_batch_per_epoch |
|
|
|
): |
|
|
@@ -352,8 +355,8 @@ def test_trainer_w_evaluator_overfit_torch( |
|
|
|
""" |
|
|
|
trainer = Trainer( |
|
|
|
model=model_and_optimizers.model, |
|
|
|
driver="torch", |
|
|
|
device=0, |
|
|
|
driver=driver, |
|
|
|
device=device, |
|
|
|
overfit_batches=overfit_batches, |
|
|
|
optimizers=model_and_optimizers.optimizers, |
|
|
|
train_dataloader=model_and_optimizers.train_dataloader, |
|
|
|