You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tester.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import unittest
  2. data_name = "pku_training.utf8"
  3. pickle_path = "data_for_tests"
  4. import numpy as np
  5. import torch.nn.functional as F
  6. from torch import nn
  7. import time
  8. from fastNLP.core.utils import CheckError
  9. from fastNLP.core.dataset import DataSet
  10. from fastNLP.core.instance import Instance
  11. from fastNLP.core.losses import BCELoss
  12. from fastNLP.core.losses import CrossEntropyLoss
  13. from fastNLP.core.metrics import AccuracyMetric
  14. from fastNLP.core.optimizer import SGD
  15. from fastNLP.core.tester import Tester
  16. from fastNLP.models.base_model import NaiveClassifier
  17. def prepare_fake_dataset():
  18. mean = np.array([-3, -3])
  19. cov = np.array([[1, 0], [0, 1]])
  20. class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
  21. mean = np.array([3, 3])
  22. cov = np.array([[1, 0], [0, 1]])
  23. class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
  24. data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
  25. [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
  26. return data_set
  27. def prepare_fake_dataset2(*args, size=100):
  28. ys = np.random.randint(4, size=100, dtype=np.int64)
  29. data = {'y': ys}
  30. for arg in args:
  31. data[arg] = np.random.randn(size, 5)
  32. return DataSet(data=data)
  33. class TestTester(unittest.TestCase):
  34. def test_case_1(self):
  35. # 检查报错提示能否正确提醒用户
  36. # 这里传入多余参数,让其duplicate
  37. dataset = prepare_fake_dataset2('x1', 'x_unused')
  38. dataset.rename_field('x_unused', 'x2')
  39. dataset.set_input('x1', 'x2')
  40. dataset.set_target('y', 'x1')
  41. class Model(nn.Module):
  42. def __init__(self):
  43. super().__init__()
  44. self.fc = nn.Linear(5, 4)
  45. def forward(self, x1, x2):
  46. x1 = self.fc(x1)
  47. x2 = self.fc(x2)
  48. x = x1 + x2
  49. time.sleep(0.1)
  50. # loss = F.cross_entropy(x, y)
  51. return {'preds': x}
  52. model = Model()
  53. tester = Tester(
  54. data=dataset,
  55. model=model,
  56. metrics=AccuracyMetric())
  57. tester.test()