You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_biaffine_parser.py 3.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import pytest
  2. from fastNLP.envs.imports import _NEED_IMPORT_TORCH
  3. if _NEED_IMPORT_TORCH:
  4. import torch
  5. from fastNLP.models.torch.biaffine_parser import BiaffineParser
  6. from fastNLP import Metric, seq_len_to_mask
  7. from .model_runner import *
  8. class ParserMetric(Metric):
  9. r"""
  10. 评估parser的性能
  11. """
  12. def __init__(self):
  13. super().__init__()
  14. self.num_arc = 0
  15. self.num_label = 0
  16. self.num_sample = 0
  17. def get_metric(self, reset=True):
  18. res = {'UAS': self.num_arc * 1.0 / self.num_sample, 'LAS': self.num_label * 1.0 / self.num_sample}
  19. if reset:
  20. self.num_sample = self.num_label = self.num_arc = 0
  21. return res
  22. def update(self, pred1, pred2, target1, target2, seq_len=None):
  23. r"""
  24. :param pred1: 边预测logits
  25. :param pred2: label预测logits
  26. :param target1: 真实边的标注
  27. :param target2: 真实类别的标注
  28. :param seq_len: 序列长度
  29. :return dict: 评估结果::
  30. UAS: 不带label时, 边预测的准确率
  31. LAS: 同时预测边和label的准确率
  32. """
  33. if seq_len is None:
  34. seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long)
  35. else:
  36. seq_mask = seq_len_to_mask(seq_len.long()).long()
  37. # mask out <root> tag
  38. seq_mask[:, 0] = 0
  39. head_pred_correct = (pred1 == target1).long() * seq_mask
  40. label_pred_correct = (pred2 == target2).long() * head_pred_correct
  41. self.num_arc += head_pred_correct.sum().item()
  42. self.num_label += label_pred_correct.sum().item()
  43. self.num_sample += seq_mask.sum().item()
  44. def prepare_parser_data():
  45. index = 'index'
  46. ds = DataSet({index: list(range(N_SAMPLES))})
  47. ds.apply_field(lambda x: RUNNER.gen_var_seq(MAX_LEN, VOCAB_SIZE),
  48. field_name=index, new_field_name='words1')
  49. ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS),
  50. field_name='words1', new_field_name='words2')
  51. # target1 is heads, should in range(0, len(words))
  52. ds.apply_field(lambda x: RUNNER.gen_seq(len(x), len(x)),
  53. field_name='words1', new_field_name='target1')
  54. ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS),
  55. field_name='words1', new_field_name='target2')
  56. ds.apply_field(len, field_name='words1', new_field_name='seq_len')
  57. dl = TorchDataLoader(ds, batch_size=BATCH_SIZE)
  58. return dl
  59. @pytest.mark.torch
  60. class TestBiaffineParser:
  61. def test_train(self):
  62. model = BiaffineParser(embed=(VOCAB_SIZE, 10),
  63. pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10,
  64. rnn_hidden_size=10,
  65. arc_mlp_size=10,
  66. label_mlp_size=10,
  67. num_label=NUM_CLS, encoder='var-lstm')
  68. ds = prepare_parser_data()
  69. RUNNER.run_model(model, ds, metrics=ParserMetric())
  70. def test_train2(self):
  71. model = BiaffineParser(embed=(VOCAB_SIZE, 10),
  72. pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10,
  73. rnn_hidden_size=16,
  74. arc_mlp_size=10,
  75. label_mlp_size=10,
  76. num_label=NUM_CLS, encoder='transformer')
  77. ds = prepare_parser_data()
  78. RUNNER.run_model(model, ds, metrics=ParserMetric())