You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_predictor.py 1.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import unittest
  2. from collections import defaultdict
  3. import numpy as np
  4. import torch
  5. from fastNLP.core.dataset import DataSet
  6. from fastNLP.core.instance import Instance
  7. from fastNLP.core.predictor import Predictor
  8. from fastNLP.modules.encoder.linear import Linear
  9. def prepare_fake_dataset():
  10. mean = np.array([-3, -3])
  11. cov = np.array([[1, 0], [0, 1]])
  12. class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
  13. mean = np.array([3, 3])
  14. cov = np.array([[1, 0], [0, 1]])
  15. class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
  16. data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
  17. [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
  18. return data_set
  19. class LinearModel(torch.nn.Module):
  20. def __init__(self):
  21. super(LinearModel, self).__init__()
  22. self.linear = Linear(2, 1)
  23. def forward(self, x):
  24. return {"predict": self.linear(x)}
  25. class TestPredictor(unittest.TestCase):
  26. def test_simple(self):
  27. model = LinearModel()
  28. predictor = Predictor(model)
  29. data = prepare_fake_dataset()
  30. data.set_input("x")
  31. ans = predictor.predict(data)
  32. self.assertTrue(isinstance(ans, defaultdict))
  33. self.assertTrue("predict" in ans)
  34. self.assertTrue(isinstance(ans["predict"], list))
  35. def test_sequence(self):
  36. # test sequence input/output
  37. pass