You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_tutorial.py 3.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import unittest
  2. from fastNLP import DataSet
  3. from fastNLP import Instance
  4. from fastNLP import Tester
  5. from fastNLP import Vocabulary
  6. from fastNLP.core.losses import CrossEntropyLoss
  7. from fastNLP.core.metrics import AccuracyMetric
  8. from fastNLP.models import CNNText
  9. class TestTutorial(unittest.TestCase):
  10. def test_tutorial(self):
  11. # 从csv读取数据到DataSet
  12. sample_path = "test/data_for_tests/tutorial_sample_dataset.csv"
  13. dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'),
  14. sep='\t')
  15. print(len(dataset))
  16. print(dataset[0])
  17. dataset.append(Instance(raw_sentence='fake data', label='0'))
  18. dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')
  19. # label转int
  20. dataset.apply(lambda x: int(x['label']), new_field_name='label')
  21. # 使用空格分割句子
  22. def split_sent(ins):
  23. return ins['raw_sentence'].split()
  24. dataset.apply(split_sent, new_field_name='words')
  25. # 增加长度信息
  26. dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')
  27. print(len(dataset))
  28. print(dataset[0])
  29. # DataSet.drop(func)筛除数据
  30. dataset.drop(lambda x: x['seq_len'] <= 3)
  31. print(len(dataset))
  32. # 设置DataSet中,哪些field要转为tensor
  33. # set target,loss或evaluate中的golden,计算loss,模型评估时使用
  34. dataset.set_target("label")
  35. # set input,模型forward时使用
  36. dataset.set_input("words")
  37. # 分出测试集、训练集
  38. test_data, train_data = dataset.split(0.5)
  39. print(len(test_data))
  40. print(len(train_data))
  41. # 构建词表, Vocabulary.add(word)
  42. vocab = Vocabulary(min_freq=2)
  43. train_data.apply(lambda x: [vocab.add(word) for word in x['words']])
  44. vocab.build_vocab()
  45. # index句子, Vocabulary.to_index(word)
  46. train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')
  47. test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')
  48. print(test_data[0])
  49. model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)
  50. from fastNLP import Trainer
  51. from copy import deepcopy
  52. # 更改DataSet中对应field的名称,要以模型的forward等参数名一致
  53. train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致
  54. train_data.rename_field('label', 'label_seq')
  55. test_data.rename_field('words', 'word_seq')
  56. test_data.rename_field('label', 'label_seq')
  57. # 实例化Trainer,传入模型和数据,进行训练
  58. copy_model = deepcopy(model)
  59. overfit_trainer = Trainer(train_data=test_data, model=copy_model,
  60. loss=CrossEntropyLoss(pred="output", target="label_seq"),
  61. metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4,
  62. dev_data=test_data, save_path="./save")
  63. overfit_trainer.train()
  64. trainer = Trainer(train_data=train_data, model=model,
  65. loss=CrossEntropyLoss(pred="output", target="label_seq"),
  66. metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4,
  67. dev_data=test_data, save_path="./save")
  68. trainer.train()
  69. print('Train finished!')
  70. # 使用fastNLP的Tester测试脚本
  71. tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"),
  72. batch_size=4)
  73. acc = tester.test()
  74. print(acc)