import pytest from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch from fastNLP.models.torch.biaffine_parser import BiaffineParser from fastNLP import Metric, seq_len_to_mask from .model_runner import * class ParserMetric(Metric): r""" 评估parser的性能 """ def __init__(self): super().__init__() self.num_arc = 0 self.num_label = 0 self.num_sample = 0 def get_metric(self, reset=True): res = {'UAS': self.num_arc * 1.0 / self.num_sample, 'LAS': self.num_label * 1.0 / self.num_sample} if reset: self.num_sample = self.num_label = self.num_arc = 0 return res def update(self, pred1, pred2, target1, target2, seq_len=None): r""" :param pred1: 边预测logits :param pred2: label预测logits :param target1: 真实边的标注 :param target2: 真实类别的标注 :param seq_len: 序列长度 :return dict: 评估结果:: UAS: 不带label时, 边预测的准确率 LAS: 同时预测边和label的准确率 """ if seq_len is None: seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long) else: seq_mask = seq_len_to_mask(seq_len.long()).long() # mask out tag seq_mask[:, 0] = 0 head_pred_correct = (pred1 == target1).long() * seq_mask label_pred_correct = (pred2 == target2).long() * head_pred_correct self.num_arc += head_pred_correct.sum().item() self.num_label += label_pred_correct.sum().item() self.num_sample += seq_mask.sum().item() def prepare_parser_data(): index = 'index' ds = DataSet({index: list(range(N_SAMPLES))}) ds.apply_field(lambda x: RUNNER.gen_var_seq(MAX_LEN, VOCAB_SIZE), field_name=index, new_field_name='words1') ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS), field_name='words1', new_field_name='words2') # target1 is heads, should in range(0, len(words)) ds.apply_field(lambda x: RUNNER.gen_seq(len(x), len(x)), field_name='words1', new_field_name='target1') ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS), field_name='words1', new_field_name='target2') ds.apply_field(len, field_name='words1', new_field_name='seq_len') dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) return dl @pytest.mark.torch class TestBiaffineParser: def test_train(self): model = BiaffineParser(embed=(VOCAB_SIZE, 10), pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10, rnn_hidden_size=10, arc_mlp_size=10, label_mlp_size=10, num_label=NUM_CLS, encoder='var-lstm') ds = prepare_parser_data() RUNNER.run_model(model, ds, metrics=ParserMetric()) def test_train2(self): model = BiaffineParser(embed=(VOCAB_SIZE, 10), pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10, rnn_hidden_size=16, arc_mlp_size=10, label_mlp_size=10, num_label=NUM_CLS, encoder='transformer') ds = prepare_parser_data() RUNNER.run_model(model, ds, metrics=ParserMetric())