import os
from typing import List  # kept for compatibility with any external references
from dataclasses import dataclass

import pytest

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.callbacks.progress_callback import RichCallback
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES

if _NEED_IMPORT_PADDLE:
    from paddle.optimizer import Adam
    from paddle.io import DataLoader

from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset
from tests.helpers.utils import magic_argv_env_context


@dataclass
class TrainPaddleConfig:
    """Shared hyper-parameters for the paddle Trainer smoke test."""
    num_labels: int = 3
    feature_dimension: int = 3
    batch_size: int = 2
    shuffle: bool = True
    # Annotated so this is a real dataclass field like its siblings
    # (the original `evaluate_every = 2` was a bare class attribute).
    evaluate_every: int = 2


@pytest.mark.parametrize("device", ["cpu", 1, [0, 1]])
@pytest.mark.parametrize("callbacks", [[RichCallback(5)]])
@pytest.mark.paddledist
@magic_argv_env_context
def test_trainer_paddle(
    device,
    callbacks,
    n_epochs=2,
):
    """Smoke-test ``Trainer`` with the paddle driver on cpu, one gpu and fleet.

    :param device: driver device spec — ``"cpu"``, a gpu index, or a list of
        gpu indices (fleet / distributed run).
    :param callbacks: callbacks forwarded to the Trainer.
    :param n_epochs: number of training epochs.
    """
    # Fleet (list-of-devices) runs require USER_CUDA_VISIBLE_DEVICES, which
    # fastNLP only exports when the paddle backend is configured — skip otherwise.
    # ``list`` replaces the deprecated runtime use of ``typing.List``.
    if isinstance(device, list) and USER_CUDA_VISIBLE_DEVICES not in os.environ:
        pytest.skip("Skip test fleet if FASTNLP_BACKEND is not set to paddle.")

    model = PaddleNormalModel_Classification_1(
        num_labels=TrainPaddleConfig.num_labels,
        feature_dimension=TrainPaddleConfig.feature_dimension,
    )
    optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)

    # Honor the config's ``shuffle`` flag instead of hard-coding True, so the
    # dataclass remains the single source of truth for test hyper-parameters.
    train_dataloader = DataLoader(
        dataset=PaddleArgMaxDataset(20, TrainPaddleConfig.feature_dimension),
        batch_size=TrainPaddleConfig.batch_size,
        shuffle=TrainPaddleConfig.shuffle,
    )
    val_dataloader = DataLoader(
        dataset=PaddleArgMaxDataset(12, TrainPaddleConfig.feature_dimension),
        batch_size=TrainPaddleConfig.batch_size,
        shuffle=TrainPaddleConfig.shuffle,
    )
    metrics = {"acc": Accuracy(backend="paddle")}

    trainer = Trainer(
        model=model,
        driver="paddle",
        device=device,
        optimizers=optimizers,
        train_dataloader=train_dataloader,
        evaluate_dataloaders=val_dataloader,
        evaluate_every=TrainPaddleConfig.evaluate_every,
        input_mapping=None,
        output_mapping=None,
        metrics=metrics,
        n_epochs=n_epochs,
        callbacks=callbacks,
    )
    trainer.run()