You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trainer_deepspeed.py 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import pytest
  2. from dataclasses import dataclass
  3. from fastNLP.core.controllers.trainer import Trainer
  4. from fastNLP.core.metrics.accuracy import Accuracy
  5. from fastNLP.core.callbacks.progress_callback import RichCallback
  6. from fastNLP.core.drivers.torch_driver import DeepSpeedDriver
  7. from fastNLP.core.drivers.torch_driver.utils import _create_default_config
  8. from fastNLP.envs.imports import _NEED_IMPORT_TORCH
  9. if _NEED_IMPORT_TORCH:
  10. import torch
  11. from torch.optim import Adam
  12. from torch.utils.data import DataLoader
  13. from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
  14. from tests.helpers.datasets.torch_data import TorchArgMaxDataset
  15. from tests.helpers.utils import magic_argv_env_context
  16. @dataclass
  17. class TrainDeepSpeedConfig:
  18. num_labels: int = 3
  19. feature_dimension: int = 3
  20. batch_size: int = 2
  21. shuffle: bool = True
  22. evaluate_every = 2
  23. @pytest.mark.deepspeed
  24. class TestTrainer:
  25. @classmethod
  26. def setup_class(cls):
  27. # 不初始化的话从第二个测试例开始会因为环境变量报错。
  28. torch_model = TorchNormalModel_Classification_1(1, 1)
  29. torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
  30. device = [torch.device(i) for i in [0,1]]
  31. driver = DeepSpeedDriver(
  32. model=torch_model,
  33. parallel_device=device,
  34. )
  35. driver.set_optimizers(torch_opt)
  36. driver.setup()
  37. return driver
  38. @pytest.mark.parametrize("device", [[0, 1]])
  39. @pytest.mark.parametrize("callbacks", [[RichCallback(5)]])
  40. @pytest.mark.parametrize("strategy", ["deepspeed", "deepspeed_stage_1"])
  41. @pytest.mark.parametrize("config", [None, _create_default_config(stage=1)])
  42. @magic_argv_env_context
  43. def test_trainer_deepspeed(
  44. self,
  45. device,
  46. callbacks,
  47. strategy,
  48. config,
  49. n_epochs=2,
  50. ):
  51. model = TorchNormalModel_Classification_1(
  52. num_labels=TrainDeepSpeedConfig.num_labels,
  53. feature_dimension=TrainDeepSpeedConfig.feature_dimension
  54. )
  55. optimizers = Adam(params=model.parameters(), lr=0.0001)
  56. train_dataloader = DataLoader(
  57. dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 20),
  58. batch_size=TrainDeepSpeedConfig.batch_size,
  59. shuffle=True
  60. )
  61. val_dataloader = DataLoader(
  62. dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 12),
  63. batch_size=TrainDeepSpeedConfig.batch_size,
  64. shuffle=True
  65. )
  66. train_dataloader = train_dataloader
  67. evaluate_dataloaders = val_dataloader
  68. evaluate_every = TrainDeepSpeedConfig.evaluate_every
  69. metrics = {"acc": Accuracy()}
  70. if config is not None:
  71. config["train_micro_batch_size_per_gpu"] = TrainDeepSpeedConfig.batch_size
  72. trainer = Trainer(
  73. model=model,
  74. driver="torch",
  75. device=device,
  76. optimizers=optimizers,
  77. train_dataloader=train_dataloader,
  78. evaluate_dataloaders=evaluate_dataloaders,
  79. evaluate_every=evaluate_every,
  80. metrics=metrics,
  81. output_mapping={"preds": "pred"},
  82. n_epochs=n_epochs,
  83. callbacks=callbacks,
  84. deepspeed_kwargs={
  85. "strategy": strategy,
  86. "config": config
  87. }
  88. )
  89. trainer.run()