FastNLP 1-Minute Quick-Start Tutorial
=====================================

The original tutorial notebook is at https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_1min_tutorial.ipynb

step 1
------

Read the dataset.

.. code:: ipython3

    from fastNLP import DataSet

    # linux_path = "../test/data_for_tests/tutorial_sample_dataset.csv"
    # Raw string, so the backslashes in the Windows path are not read as escape sequences
    win_path = r"C:\Users\zyfeng\Desktop\FudanNLP\fastNLP\test\data_for_tests\tutorial_sample_dataset.csv"
    # The file is tab-separated with two columns: the sentence and its label
    ds = DataSet.read_csv(win_path, headers=('raw_sentence', 'label'), sep='\t')

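
To make the shape of the loaded data concrete, here is a minimal sketch using only the standard library. It is not fastNLP's ``read_csv``; the helper ``read_tsv`` and the two sample rows are hypothetical, and the real ``tutorial_sample_dataset.csv`` is not reproduced here. It only illustrates how ``headers`` and ``sep`` map each tab-separated line to named fields.

```python
import csv
import io

# Hypothetical two-row sample standing in for the real data file
SAMPLE_TSV = "A fine movie .\t3\nNot worth watching .\t1\n"


def read_tsv(handle, headers, sep="\t"):
    """Return one dict per non-empty row, keyed by the given headers."""
    # QUOTE_NONE treats quote characters in the text as ordinary characters
    reader = csv.reader(handle, delimiter=sep, quoting=csv.QUOTE_NONE)
    return [dict(zip(headers, row)) for row in reader if row]


rows = read_tsv(io.StringIO(SAMPLE_TSV), headers=("raw_sentence", "label"))
print(rows[0])
```

Note that every field, including the label, is still a string at this point, which is why the next step converts the label to ``int``.
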
step 2
------

Data preprocessing:

1. Type conversion
2. Splitting off a validation set
3. Building the vocabulary

.. code:: ipython3

    # Convert all letters to lowercase
    ds.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')
    # Convert the label to int and mark it as the target
    ds.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True)

    def split_sent(ins):
        # Split the sentence on whitespace into a list of words
        return ins['raw_sentence'].split()

    ds.apply(split_sent, new_field_name='words', is_input=True)

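
The three ``apply`` calls above all follow one pattern: compute a value from each instance and store it under a field name. A rough plain-Python sketch of that pattern follows. The ``apply`` function and the sample instances are hypothetical stand-ins, not fastNLP's ``DataSet.apply``, and the ``is_input`` / ``is_target`` flags are left out.

```python
# Hypothetical instances standing in for rows of the DataSet
instances = [
    {"raw_sentence": "A Fine Movie .", "label": "3"},
    {"raw_sentence": "Not Worth Watching .", "label": "1"},
]


def apply(data, func, new_field_name):
    """Store func(instance) under new_field_name for every instance."""
    for ins in data:
        ins[new_field_name] = func(ins)


# Same three steps as the tutorial: lowercase, label to int, whitespace split
apply(instances, lambda x: x["raw_sentence"].lower(), "raw_sentence")
apply(instances, lambda x: int(x["label"]), "label_seq")
apply(instances, lambda x: x["raw_sentence"].split(), "words")
print(instances[0])
```

Reusing an existing field name (``raw_sentence``) overwrites that field, while a new name (``label_seq``, ``words``) adds a field alongside the originals.
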
.. code:: ipython3

    # Split into training and validation sets; 0.3 is the share held out for validation
    train_data, dev_data = ds.split(0.3)
    print("Train size: ", len(train_data))
    print("Test size: ", len(dev_data))

.. parsed-literal::

    Train size: 54
    Test size: 23

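
The sizes printed above add up to 77 instances, of which about 30% are held out. The sketch below shows one way such a ratio split could work. The ``split`` helper, its ``seed`` parameter, and the truncating ``int()`` rounding are assumptions for illustration and may differ from what ``DataSet.split`` does internally.

```python
import random


def split(data, dev_ratio, seed=0):
    """Shuffle the indices and hold out roughly dev_ratio of the items."""
    indices = list(range(len(data)))
    random.Random(seed).shuffle(indices)
    n_dev = int(len(data) * dev_ratio)  # truncation is an assumption
    dev = [data[i] for i in indices[:n_dev]]
    train = [data[i] for i in indices[n_dev:]]
    return train, dev


# 77 placeholder items, matching the total implied by the printed sizes
train, dev = split(list(range(77)), 0.3)
print(len(train), len(dev))
```

With 77 items and a ratio of 0.3 this yields 54 training and 23 held-out items, the same sizes as in the output above.
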
.. code:: ipython3

    from fastNLP import Vocabulary

    # Build the vocabulary from the training set only; min_freq=2 sets the frequency threshold
    vocab = Vocabulary(min_freq=2)
    train_data.apply(lambda x: [vocab.add(word) for word in x['words']])

    # Index the sentences with Vocabulary.to_index(word)
    train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)
    dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)

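
To illustrate what a frequency threshold and word-to-index lookup do, here is a small self-contained sketch. ``TinyVocab`` is a hypothetical class written for this example, not fastNLP's ``Vocabulary``; the ``<pad>`` / ``<unk>`` tokens and their indices 0 and 1 are assumptions.

```python
from collections import Counter


class TinyVocab:
    """Hypothetical min_freq vocabulary; not fastNLP's Vocabulary class."""

    def __init__(self, min_freq=1, padding="<pad>", unknown="<unk>"):
        self.min_freq = min_freq
        self.padding = padding
        self.unknown = unknown
        self.counts = Counter()
        self.word2idx = None

    def add(self, word):
        self.counts[word] += 1
        self.word2idx = None  # the index must be rebuilt after new words

    def build(self):
        # Reserve the first two indices for padding and unknown words
        self.word2idx = {self.padding: 0, self.unknown: 1}
        for word, freq in self.counts.most_common():
            if freq >= self.min_freq:
                self.word2idx[word] = len(self.word2idx)

    def to_index(self, word):
        if self.word2idx is None:
            self.build()
        # Words below the threshold or never seen map to the unknown index
        return self.word2idx.get(word, self.word2idx[self.unknown])

    def __len__(self):
        if self.word2idx is None:
            self.build()
        return len(self.word2idx)


vocab = TinyVocab(min_freq=2)
for sentence in (["a", "good", "movie"], ["a", "bad", "movie"]):
    for word in sentence:
        vocab.add(word)

print([vocab.to_index(w) for w in ["a", "great", "movie"]])
print(len(vocab))
```

Here "good" and "bad" occur only once, so with ``min_freq=2`` they share the unknown index with the unseen word "great". This is also why the vocabulary is built from the training set alone and then applied to the validation set.
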
step 3
------

Define the model.

.. code:: ipython3

    from fastNLP.models import CNNText

    # The embedding table has one row per vocabulary entry; the classifier has 5 output classes
    model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)

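
One way to see why a convolutional text classifier takes a ``padding`` argument is the standard output-length formula for a stride-1 1-D convolution. The sketch below is plain arithmetic, not fastNLP code; the kernel sizes 3, 4 and 5 and the two-word sentence are assumptions chosen for illustration, since the call above does not state which kernel sizes ``CNNText`` uses.

```python
def conv1d_output_length(seq_len, kernel_size, padding=0, stride=1):
    """Number of positions a 1-D convolution produces (dilation 1)."""
    return (seq_len + 2 * padding - kernel_size) // stride + 1


SEQ_LEN = 2  # a hypothetical two-word sentence
for kernel_size in (3, 4, 5):  # assumed kernel sizes
    without = conv1d_output_length(SEQ_LEN, kernel_size, padding=0)
    with_pad = conv1d_output_length(SEQ_LEN, kernel_size, padding=2)
    print(kernel_size, without, with_pad)
```

A result of zero or less means the kernel does not fit the sequence at all. With ``padding=2`` on each side, even the two-word sentence yields at least two positions for every kernel size in this example.
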
step 4
------

Start training.

.. code:: ipython3

    from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric

    # Train on train_data and report accuracy on dev_data
    trainer = Trainer(model=model,
                      train_data=train_data,
                      dev_data=dev_data,
                      loss=CrossEntropyLoss(),
                      metrics=AccuracyMetric()
                      )
    trainer.train()
    print('Train finished!')

.. parsed-literal::

    training epochs started 2018-12-07 14:03:41

.. parsed-literal::

    HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6), HTML(value='')), layout=Layout(display='i…

.. parsed-literal::

    Epoch 1/3. Step:2/6. AccuracyMetric: acc=0.26087
    Epoch 2/3. Step:4/6. AccuracyMetric: acc=0.347826
    Epoch 3/3. Step:6/6. AccuracyMetric: acc=0.608696
    Train finished!

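
The accuracy values in the log can be read as a fraction of the 23 validation instances. The sketch below defines accuracy in plain Python and checks that arithmetic. The counts 6, 8 and 14 are inferred from the reported values, not printed by the trainer, and the ``accuracy`` helper is a hypothetical illustration rather than fastNLP's ``AccuracyMetric``.

```python
def accuracy(predictions, targets):
    """Fraction of positions where the prediction equals the target."""
    correct = sum(p == t for p, t in zip(predictions, targets))
    return correct / len(targets)


DEV_SIZE = 23  # validation-set size printed in step 2

# Inferred numbers of correct predictions after each epoch
for n_correct in (6, 8, 14):
    print(n_correct, round(n_correct / DEV_SIZE, 6))

# Tiny usage example: two of three predictions match
print(accuracy([1, 2, 3], [1, 0, 3]))
```

Rounded to six decimals, 6/23, 8/23 and 14/23 give 0.26087, 0.347826 and 0.608696, which matches the three logged values.
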
This concludes the tutorial. See the advanced tutorials for more operations.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~