You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_trainer_oneflow.py 2.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import os
  2. import pytest
  3. from dataclasses import dataclass
  4. from fastNLP.core.controllers.trainer import Trainer
  5. from fastNLP.core.metrics.accuracy import Accuracy
  6. from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW
  7. if _NEED_IMPORT_ONEFLOW:
  8. from oneflow.optim import Adam
  9. from oneflow.utils.data import DataLoader
  10. from tests.helpers.models.oneflow_model import OneflowNormalModel_Classification_1
  11. from tests.helpers.datasets.oneflow_data import OneflowArgMaxDataset
  12. from tests.helpers.utils import magic_argv_env_context
  13. @dataclass
  14. class TrainOneflowConfig:
  15. num_labels: int = 3
  16. feature_dimension: int = 3
  17. batch_size: int = 2
  18. shuffle: bool = True
  19. evaluate_every = 2
  20. @pytest.mark.parametrize("device", ["cpu", 1])
  21. @pytest.mark.parametrize("callbacks", [[]])
  22. @pytest.mark.oneflow
  23. @magic_argv_env_context
  24. def test_trainer_oneflow(
  25. device,
  26. callbacks,
  27. n_epochs=2,
  28. ):
  29. model = OneflowNormalModel_Classification_1(
  30. num_labels=TrainOneflowConfig.num_labels,
  31. feature_dimension=TrainOneflowConfig.feature_dimension
  32. )
  33. optimizers = Adam(params=model.parameters(), lr=0.0001)
  34. train_dataloader = DataLoader(
  35. dataset=OneflowArgMaxDataset(20, TrainOneflowConfig.feature_dimension),
  36. batch_size=TrainOneflowConfig.batch_size,
  37. shuffle=True
  38. )
  39. val_dataloader = DataLoader(
  40. dataset=OneflowArgMaxDataset(12, TrainOneflowConfig.feature_dimension),
  41. batch_size=TrainOneflowConfig.batch_size,
  42. shuffle=True
  43. )
  44. train_dataloader = train_dataloader
  45. evaluate_dataloaders = val_dataloader
  46. evaluate_every = TrainOneflowConfig.evaluate_every
  47. metrics = {"acc": Accuracy()}
  48. trainer = Trainer(
  49. model=model,
  50. driver="oneflow",
  51. device=device,
  52. optimizers=optimizers,
  53. train_dataloader=train_dataloader,
  54. evaluate_dataloaders=evaluate_dataloaders,
  55. evaluate_every=evaluate_every,
  56. input_mapping=None,
  57. output_mapping=None,
  58. metrics=metrics,
  59. n_epochs=n_epochs,
  60. callbacks=callbacks,
  61. )
  62. trainer.run()