You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trainer.py 1.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. import unittest
  2. import numpy as np
  3. from fastNLP.core.dataset import DataSet
  4. from fastNLP.core.instance import Instance
  5. from fastNLP.core.losses import BCELoss
  6. from fastNLP.core.metrics import AccuracyMetric
  7. from fastNLP.core.optimizer import SGD
  8. from fastNLP.core.trainer import Trainer
  9. from fastNLP.models.base_model import NaiveClassifier
  class TrainerTestGround(unittest.TestCase):
      """End-to-end smoke test for ``Trainer`` on a toy binary-classification task."""

      def test_case(self):
          """Train ``NaiveClassifier`` on two Gaussian clusters centred at (-5, -5) and (5, 5).

          The test passes when ``trainer.train()`` returns without raising.
          NOTE(review): no accuracy threshold is asserted, so a silently
          non-learning model would still pass.
          """
          # Seeded, private RNG: the generated clusters are identical on every
          # run, and numpy's global RNG state (which other tests may use) is
          # left untouched. NOTE(review): the train/dev split and the model's
          # weight initialisation may still draw from their own RNGs.
          rng = np.random.RandomState(0)
          cov = np.eye(2)  # identity covariance, equal to [[1, 0], [0, 1]]
          class_A = rng.multivariate_normal(np.array([-5, -5]), cov, size=(1000,))
          class_B = rng.multivariate_normal(np.array([5, 5]), cov, size=(1000,))

          # Each instance has a 2-element float feature list ``x`` and a
          # one-element float label ``y``: 0.0 for cluster A, 1.0 for cluster B.
          data_set = DataSet(
              [Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A]
              + [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]
          )
          data_set.set_input("x", flag=True)
          data_set.set_target("y", flag=True)
          train_set, dev_set = data_set.split(0.3)

          # Presumably (input_dim=2, output_dim=1), matching ``x`` and ``y`` above — confirm.
          model = NaiveClassifier(2, 1)
          trainer = Trainer(train_set, model,
                            losser=BCELoss(input="predict", target="y"),
                            metrics=AccuracyMetric(pred="predict", target="y"),
                            n_epochs=10,
                            batch_size=32,
                            print_every=10,
                            validate_every=-1,
                            dev_data=dev_set,
                            optimizer=SGD(0.001),
                            check_code_level=2
                            )
          trainer.train()