You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cnn_text_classification.py 987 B

123456789101112131415161718192021222324252627282930313233
  1. import pytest
  2. from .model_runner import *
  3. from fastNLP.envs.imports import _NEED_IMPORT_TORCH
  4. if _NEED_IMPORT_TORCH:
  5. from fastNLP.models.torch.cnn_text_classification import CNNText
  @pytest.mark.torch
  class TestCNNText:
      """Tests for the torch ``CNNText`` text-classification model."""

      def init_model(self, kernel_sizes, kernel_nums=(1, 3, 5)):
          """Build a ``CNNText`` model for these tests.

          The first argument ``(VOCAB_SIZE, 30)`` presumably means
          (vocabulary size, embedding dim) -- TODO confirm against CNNText.

          Args:
              kernel_sizes: forwarded to ``CNNText`` as ``kernel_sizes``.
              kernel_nums: forwarded to ``CNNText`` as ``kernel_nums``.

          Returns:
              The constructed ``CNNText`` instance with ``NUM_CLS`` classes.
          """
          model = CNNText((VOCAB_SIZE, 30),
                          NUM_CLS,
                          kernel_nums=kernel_nums,
                          kernel_sizes=kernel_sizes)
          return model

      def test_case1(self):
          """Check that CNNText runs normally on the text-classification task."""
          model = self.init_model((1, 3, 5))
          RUNNER.run_model_with_task(TEXT_CLS, model)

      def test_init_model(self):
          """Invalid kernel configurations must raise at construction time."""
          # Plain ints instead of sequences for kernel_sizes / kernel_nums.
          with pytest.raises(Exception):
              self.init_model(2, 4)
          # One kernel size paired with the default three kernel counts.
          with pytest.raises(Exception):
              self.init_model((2,))

      def test_output(self, monkeypatch):
          """Run with MAX_LEN = 2, i.e. inputs shorter than the size-3 kernel."""
          model = self.init_model((3,), (1,))
          # A `global MAX_LEN` assignment here would only rebind this module's
          # star-imported copy. The data helpers live in model_runner, so the
          # override has to be applied on that module. monkeypatch restores the
          # original value when the test ends, so no global state leaks.
          # NOTE(review): assumes model_runner reads MAX_LEN at call time (not
          # captured as a default argument) -- confirm in model_runner.py.
          from . import model_runner
          monkeypatch.setattr(model_runner, "MAX_LEN", 2)
          RUNNER.run_model_with_task(TEXT_CLS, model)