You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

extend_1_bert_embedding.rst 9.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. ==============================
  2. BertEmbedding的各种用法
  3. ==============================
  4. Bert自从在 `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding <https://arxiv.org/abs/1810.04805>`_
  5. 中被提出后,因其性能卓越受到了极大的关注,在这里我们展示一下在fastNLP中如何使用Bert进行各类任务。其中中文Bert我们使用的模型的权重来自于
  6. `中文Bert预训练 <https://github.com/ymcui/Chinese-BERT-wwm>`_ 。
  7. 为了方便大家的使用,fastNLP提供了预训练的Embedding权重及数据集的自动下载,支持自动下载的Embedding和数据集见
  8. `数据集 <https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?tab=fed5xh&c=D42A0AC0>`_ 。或您可从 :doc:`/tutorials/tutorial_3_embedding` 与
  9. :doc:`/tutorials/tutorial_4_load_dataset` 了解更多相关信息。
  10. ----------------------------------
  11. 中文任务
  12. ----------------------------------
  13. 下面我们将介绍通过使用Bert来进行文本分类, 中文命名实体识别, 文本匹配, 中文问答。
  14. .. note::
  15. 本教程必须使用 GPU 进行实验,并且会花费大量的时间
  16. 1. 使用Bert进行文本分类
  17. ----------------------------------
  18. 文本分类是指给定一段文字,判定其所属的类别。例如下面的文本情感分类
  19. .. code-block:: text
  20. 1, 商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!
  21. 这里我们使用fastNLP提供自动下载的微博分类进行测试
  22. .. code-block:: python
  23. from fastNLP.io import WeiboSenti100kPipe
  24. from fastNLP.embeddings import BertEmbedding
  25. from fastNLP.models import BertForSequenceClassification
  26. from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam
  27. import torch
  28. data_bundle =WeiboSenti100kPipe().process_from_file()
  29. data_bundle.rename_field('chars', 'words')
  30. # 载入BertEmbedding
  31. embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)
  32. # 载入模型
  33. model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))
  34. # 训练模型
  35. device = 0 if torch.cuda.is_available() else 'cpu'
  36. trainer = Trainer(data_bundle.get_dataset('train'), model,
  37. optimizer=Adam(model_params=model.parameters(), lr=2e-5),
  38. loss=CrossEntropyLoss(), device=device,
  39. batch_size=8, dev_data=data_bundle.get_dataset('dev'),
  40. metrics=AccuracyMetric(), n_epochs=2, print_every=1)
  41. trainer.train()
  42. # 测试结果
  43. from fastNLP import Tester
  44. tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric())
  45. tester.test()
  46. 输出结果::
  47. In Epoch:1/Step:12499, got best dev performance:
  48. AccuracyMetric: acc=0.9838
  49. Reloaded the best model.
  50. Evaluate data in 63.84 seconds!
  51. [tester]
  52. AccuracyMetric: acc=0.9815
  53. 2. 使用Bert进行命名实体识别
  54. ----------------------------------
  55. 命名实体识别是给定一句话,标记出其中的实体。一般序列标注的任务都使用conll格式,conll格式是至一行中通过制表符分隔不同的内容,使用空行分隔
  56. 两句话,例如下面的例子
  57. .. code-block:: text
  58. 中 B-ORG
  59. 共 I-ORG
  60. 中 I-ORG
  61. 央 I-ORG
  62. 致 O
  63. 中 B-ORG
  64. 国 I-ORG
  65. 致 I-ORG
  66. 公 I-ORG
  67. 党 I-ORG
  68. 十 I-ORG
  69. 一 I-ORG
  70. 大 I-ORG
  71. 的 O
  72. 贺 O
  73. 词 O
  74. 这部分内容请参考 :doc:`/tutorials/序列标注`
  75. 3. 使用Bert进行文本匹配
  76. ----------------------------------
  77. 文本匹配任务是指给定两句话判断他们的关系。比如,给定两句话判断前一句是否和后一句具有因果关系或是否是矛盾关系;或者给定两句话判断两句话是否
  78. 具有相同的意思。这里我们使用
  79. .. code-block:: python
  80. from fastNLP.io import CNXNLIBertPipe
  81. from fastNLP.embeddings import BertEmbedding
  82. from fastNLP.models import BertForSentenceMatching
  83. from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam
  84. from fastNLP.core.optimizer import AdamW
  85. from fastNLP.core.callback import WarmupCallback
  86. from fastNLP import Tester
  87. import torch
  88. data_bundle = CNXNLIBertPipe().process_from_file()
  89. data_bundle.rename_field('chars', 'words')
  90. print(data_bundle)
  91. # 载入BertEmbedding
  92. embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)
  93. # 载入模型
  94. model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))
  95. # 训练模型
  96. callbacks = [WarmupCallback(warmup=0.1, schedule='linear'), ]
  97. device = 0 if torch.cuda.is_available() else 'cpu'
  98. trainer = Trainer(data_bundle.get_dataset('train'), model,
  99. optimizer=AdamW(params=model.parameters(), lr=4e-5),
  100. loss=CrossEntropyLoss(), device=device,
  101. batch_size=8, dev_data=data_bundle.get_dataset('dev'),
  102. metrics=AccuracyMetric(), n_epochs=5, print_every=1,
  103. update_every=8, callbacks=callbacks)
  104. trainer.train()
  105. tester = Tester(data_bundle.get_dataset('test'), model, batch_size=8, metrics=AccuracyMetric())
  106. tester.test()
  107. 运行结果::
  108. In Epoch:3/Step:73632, got best dev performance:
  109. AccuracyMetric: acc=0.781928
  110. Reloaded the best model.
  111. Evaluate data in 18.54 seconds!
  112. [tester]
  113. AccuracyMetric: acc=0.783633
  114. 4. 使用Bert进行中文问答
  115. ----------------------------------
  116. 问答任务是给定一段内容,以及一个问题,需要从这段内容中找到答案。
  117. 例如::
  118. "context": "锣鼓经是大陆传统器乐及戏曲里面常用的打击乐记谱方法,以中文字的声音模拟敲击乐的声音,纪录打击乐的各种不同的演奏方法。常
  119. 用的节奏型称为「锣鼓点」。而锣鼓是戏曲节奏的支柱,除了加强演员身段动作的节奏感,也作为音乐的引子和尾声,提示音乐的板式和速度,以及
  120. 作为唱腔和念白的伴奏,令诗句的韵律更加抑扬顿锉,段落分明。锣鼓的运用有约定俗成的程式,依照角色行当的身份、性格、情绪以及环境,配合
  121. 相应的锣鼓点。锣鼓亦可以模仿大自然的音响效果,如雷电、波浪等等。戏曲锣鼓所运用的敲击乐器主要分为鼓、锣、钹和板四类型:鼓类包括有单
  122. 皮鼓(板鼓)、大鼓、大堂鼓(唐鼓)、小堂鼓、怀鼓、花盆鼓等;锣类有大锣、小锣(手锣)、钲锣、筛锣、马锣、镗锣、云锣;钹类有铙钹、大
  123. 钹、小钹、水钹、齐钹、镲钹、铰子、碰钟等;打拍子用的檀板、木鱼、梆子等。因为京剧的锣鼓通常由四位乐师负责,又称为四大件,领奏的师
  124. 傅称为:「鼓佬」,其职责有如西方乐队的指挥,负责控制速度以及利用各种手势提示乐师演奏不同的锣鼓点。粤剧吸收了部份京剧的锣鼓,但以木鱼
  125. 和沙的代替了京剧的板和鼓,作为打拍子的主要乐器。以下是京剧、昆剧和粤剧锣鼓中乐器对应的口诀用字:",
  126. "question": "锣鼓经是什么?",
  127. "answers": [
  128. {
  129. "text": "大陆传统器乐及戏曲里面常用的打击乐记谱方法",
  130. "answer_start": 4
  131. },
  132. {
  133. "text": "大陆传统器乐及戏曲里面常用的打击乐记谱方法",
  134. "answer_start": 4
  135. },
  136. {
  137. "text": "大陆传统器乐及戏曲里面常用的打击乐记谱方法",
  138. "answer_start": 4
  139. }
  140. ]
  141. 您可以通过以下的代码训练 (原文代码:`CMRC2018 <https://github.com/ymcui/cmrc2018>`_)
  142. .. code-block:: python
  143. from fastNLP.embeddings import BertEmbedding
  144. from fastNLP.models import BertForQuestionAnswering
  145. from fastNLP.core.losses import CMRC2018Loss
  146. from fastNLP.core.metrics import CMRC2018Metric
  147. from fastNLP.io.pipe.qa import CMRC2018BertPipe
  148. from fastNLP import Trainer, BucketSampler
  149. from fastNLP import WarmupCallback, GradientClipCallback
  150. from fastNLP.core.optimizer import AdamW
  151. import torch
  152. data_bundle = CMRC2018BertPipe().process_from_file()
  153. data_bundle.rename_field('chars', 'words')
  154. print(data_bundle)
  155. embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn', requires_grad=True, include_cls_sep=False, auto_truncate=True,
  156. dropout=0.5, word_dropout=0.01)
  157. model = BertForQuestionAnswering(embed)
  158. loss = CMRC2018Loss()
  159. metric = CMRC2018Metric()
  160. wm_callback = WarmupCallback(schedule='linear')
  161. gc_callback = GradientClipCallback(clip_value=1, clip_type='norm')
  162. callbacks = [wm_callback, gc_callback]
  163. optimizer = AdamW(model.parameters(), lr=5e-5)
  164. device = 0 if torch.cuda.is_available() else 'cpu'
  165. trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,
  166. sampler=BucketSampler(seq_len_field_name='context_len'),
  167. dev_data=data_bundle.get_dataset('dev'), metrics=metric,
  168. callbacks=callbacks, device=device, batch_size=6, num_workers=2, n_epochs=2, print_every=1,
  169. test_use_tqdm=False, update_every=10)
  170. trainer.train(load_best_model=False)
  171. 训练结果(和原论文中报道的基本一致)::
  172. In Epoch:2/Step:1692, got best dev performance:
  173. CMRC2018Metric: f1=85.61, em=66.08
  174. ----------------------------------
  175. 代码下载
  176. ----------------------------------
  177. `点击下载 IPython Notebook 文件 <https://sourcegraph.com/github.com/fastnlp/fastNLP@master/-/raw/tutorials/extend_1_bert_embedding.ipynb>`_)