You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trainer_paddle.py 2.5 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import os
  2. from typing import List
  3. import pytest
  4. from dataclasses import dataclass
  5. from fastNLP.core.controllers.trainer import Trainer
  6. from fastNLP.core.metrics.accuracy import Accuracy
  7. from fastNLP.core.callbacks.progress_callback import RichCallback
  8. from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
  9. from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES
  10. if _NEED_IMPORT_PADDLE:
  11. from paddle.optimizer import Adam
  12. from paddle.io import DataLoader
  13. from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
  14. from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset
  15. from tests.helpers.utils import magic_argv_env_context
  16. @dataclass
  17. class TrainPaddleConfig:
  18. num_labels: int = 3
  19. feature_dimension: int = 3
  20. batch_size: int = 2
  21. shuffle: bool = True
  22. evaluate_every = 2
  23. @pytest.mark.parametrize("device", ["cpu", 1, [0, 1]])
  24. @pytest.mark.parametrize("callbacks", [[RichCallback(5)]])
  25. @pytest.mark.paddledist
  26. @magic_argv_env_context
  27. def test_trainer_paddle(
  28. device,
  29. callbacks,
  30. n_epochs=2,
  31. ):
  32. if isinstance(device, List) and USER_CUDA_VISIBLE_DEVICES not in os.environ:
  33. pytest.skip("Skip test fleet if FASTNLP_BACKEND is not set to paddle.")
  34. model = PaddleNormalModel_Classification_1(
  35. num_labels=TrainPaddleConfig.num_labels,
  36. feature_dimension=TrainPaddleConfig.feature_dimension
  37. )
  38. optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)
  39. train_dataloader = DataLoader(
  40. dataset=PaddleArgMaxDataset(20, TrainPaddleConfig.feature_dimension),
  41. batch_size=TrainPaddleConfig.batch_size,
  42. shuffle=True
  43. )
  44. val_dataloader = DataLoader(
  45. dataset=PaddleArgMaxDataset(12, TrainPaddleConfig.feature_dimension),
  46. batch_size=TrainPaddleConfig.batch_size,
  47. shuffle=True
  48. )
  49. train_dataloader = train_dataloader
  50. evaluate_dataloaders = val_dataloader
  51. evaluate_every = TrainPaddleConfig.evaluate_every
  52. metrics = {"acc": Accuracy(backend="paddle")}
  53. trainer = Trainer(
  54. model=model,
  55. driver="paddle",
  56. device=device,
  57. optimizers=optimizers,
  58. train_dataloader=train_dataloader,
  59. evaluate_dataloaders=evaluate_dataloaders,
  60. evaluate_every=evaluate_every,
  61. input_mapping=None,
  62. output_mapping=None,
  63. metrics=metrics,
  64. n_epochs=n_epochs,
  65. callbacks=callbacks,
  66. )
  67. trainer.run()