You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_sequence_labeling.py 1.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import pytest
  2. from .model_runner import *
  3. from fastNLP.envs.imports import _NEED_IMPORT_TORCH
  4. if _NEED_IMPORT_TORCH:
  5. from fastNLP.models.torch.sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF
  6. @pytest.mark.torch
  7. class TestBiLSTM:
  8. def test_case1(self):
  9. # 测试能否正常运行CNN
  10. init_emb = (VOCAB_SIZE, 30)
  11. model = BiLSTMCRF(init_emb,
  12. hidden_size=30,
  13. num_classes=NUM_CLS)
  14. dl = RUNNER.prepare_pos_tagging_data()
  15. metric = Accuracy()
  16. RUNNER.run_model(model, dl, metric)
  17. @pytest.mark.torch
  18. class TestSeqLabel:
  19. def test_case1(self):
  20. # 测试能否正常运行CNN
  21. init_emb = (VOCAB_SIZE, 30)
  22. model = SeqLabeling(init_emb,
  23. hidden_size=30,
  24. num_classes=NUM_CLS)
  25. dl = RUNNER.prepare_pos_tagging_data()
  26. metric = Accuracy()
  27. RUNNER.run_model(model, dl, metric)
  28. @pytest.mark.torch
  29. class TestAdvSeqLabel:
  30. def test_case1(self):
  31. # 测试能否正常运行CNN
  32. init_emb = (VOCAB_SIZE, 30)
  33. model = AdvSeqLabel(init_emb,
  34. hidden_size=30,
  35. num_classes=NUM_CLS)
  36. dl = RUNNER.prepare_pos_tagging_data()
  37. metric = Accuracy()
  38. RUNNER.run_model(model, dl, metric)