
test_biaffine_parser.py

import unittest

import fastNLP
from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric

# Toy CoNLL-style dependency data; blank lines separate sentences.
data_file = """
1 The _ DET DT _ 3 det _ _
2 new _ ADJ JJ _ 3 amod _ _
3 rate _ NOUN NN _ 6 nsubj _ _
4 will _ AUX MD _ 6 aux _ _
5 be _ VERB VB _ 6 cop _ _
6 payable _ ADJ JJ _ 0 root _ _
7 mask _ ADJ JJ _ 6 punct _ _
8 mask _ ADJ JJ _ 6 punct _ _
9 cents _ NOUN NNS _ 4 nmod _ _
10 from _ ADP IN _ 12 case _ _
11 seven _ NUM CD _ 12 nummod _ _
12 cents _ NOUN NNS _ 4 nmod _ _
13 a _ DET DT _ 14 det _ _
14 share _ NOUN NN _ 12 nmod:npmod _ _
15 . _ PUNCT . _ 4 punct _ _

1 The _ DET DT _ 3 det _ _
2 new _ ADJ JJ _ 3 amod _ _
3 rate _ NOUN NN _ 6 nsubj _ _
4 will _ AUX MD _ 6 aux _ _
5 be _ VERB VB _ 6 cop _ _
6 payable _ ADJ JJ _ 0 root _ _
7 Feb. _ PROPN NNP _ 6 nmod:tmod _ _
8 15 _ NUM CD _ 7 nummod _ _
9 . _ PUNCT . _ 6 punct _ _

1 A _ DET DT _ 3 det _ _
2 record _ NOUN NN _ 3 compound _ _
3 date _ NOUN NN _ 7 nsubjpass _ _
4 has _ AUX VBZ _ 7 aux _ _
5 n't _ PART RB _ 7 neg _ _
6 been _ AUX VBN _ 7 auxpass _ _
7 set _ VERB VBN _ 0 root _ _
8 . _ PUNCT . _ 7 punct _ _
"""


def init_data():
    ds = fastNLP.DataSet()
    v = {'word_seq': fastNLP.Vocabulary(),
         'pos_seq': fastNLP.Vocabulary(),
         'label_true': fastNLP.Vocabulary()}
    data = []
    # Parse the CoNLL-style block: each blank line closes one sentence.
    for line in data_file.split('\n'):
        line = line.split()
        if len(line) == 0 and len(data) > 0:
            data = list(zip(*data))
            ds.append(fastNLP.Instance(word_seq=data[1],
                                       pos_seq=data[4],
                                       arc_true=data[6],
                                       label_true=data[7]))
            data = []
        elif len(line) > 0:
            data.append(line)

    # Prepend a start symbol and build vocabularies for words, POS tags and labels.
    for name in ['word_seq', 'pos_seq', 'label_true']:
        ds.apply(lambda x: ['<st>'] + list(x[name]), new_field_name=name)
        ds.apply(lambda x: v[name].add_word_lst(x[name]))

    # Convert tokens to vocabulary indices.
    for name in ['word_seq', 'pos_seq', 'label_true']:
        ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name)

    # Arcs get a dummy head 0 for the start symbol; record sequence lengths.
    ds.apply(lambda x: [0] + list(map(int, x['arc_true'])), new_field_name='arc_true')
    ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens')
    ds.set_input('word_seq', 'pos_seq', 'seq_lens', flag=True)
    ds.set_target('arc_true', 'label_true', 'seq_lens', flag=True)
    return ds, v['word_seq'], v['pos_seq'], v['label_true']


class TestBiaffineParser(unittest.TestCase):
    def test_train(self):
        # Smoke test: train the biaffine parser for a few epochs on the toy data above.
        ds, v1, v2, v3 = init_data()
        model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30,
                               pos_vocab_size=len(v2), pos_emb_dim=30,
                               num_label=len(v3), encoder='var-lstm')
        trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds,
                                  loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
                                  batch_size=1, validate_every=10,
                                  n_epochs=10, use_cuda=False, use_tqdm=False)
        trainer.train(load_best_model=False)


if __name__ == '__main__':
    unittest.main()