==============================
Various Uses of BertEmbedding
==============================

fastNLP's BertEmbedding is built on the BertModel code from pytorch-transformers
and is an Embedding that uses BERT to encode words. Together with the models in
fastNLP.models.bert, BertEmbedding can be used to build models that apply BERT
to five kinds of downstream tasks.

For an introduction to the pretrained Embedding weights and datasets, and to the
automatic download feature, see :doc:`/tutorials/tutorial_3_embedding` and
:doc:`/tutorials/tutorial_4_load_dataset`.
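
Before the task-specific examples, it helps to see BertEmbedding in isolation:
it is built from a fastNLP Vocabulary and is then called like any other
Embedding on a batch of word indices. A minimal sketch (the toy sentence is
illustrative, and the printed shape assumes the base model's default single
output layer):

.. code-block:: python

    import torch

    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding

    # Build a tiny vocabulary and a BertEmbedding on top of it
    vocab = Vocabulary()
    vocab.add_word_lst("this is a demo sentence".split())
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', include_cls_sep=True)

    # Encode one sentence, given as a batch of word indices
    words = torch.LongTensor([[vocab.to_index(w) for w in "this is a demo".split()]])
    reps = embed(words)
    print(reps.shape)  # torch.Size([1, 6, 768]): 4 words plus [CLS] and [SEP]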

1. BERT for Sequence Classification
-----------------------------------

For the text classification task, we use the SST dataset as an example to show
how BertEmbedding is used.

.. code-block:: python

    import warnings

    import torch

    warnings.filterwarnings("ignore")

    # Load the dataset
    from fastNLP.io import SSTPipe

    data_bundle = SSTPipe(subtree=False, train_subtree=False, lower=False, tokenizer='raw').process_from_file()
    print(data_bundle)  # inspect the loaded splits and vocabularies

    # Load the BertEmbedding
    from fastNLP.embeddings import BertEmbedding

    embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='en-base-cased', include_cls_sep=True)

    # Load the model
    from fastNLP.models import BertForSequenceClassification

    model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))

    # Train the model
    from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam

    trainer = Trainer(data_bundle.get_dataset('train'), model,
                      optimizer=Adam(model_params=model.parameters(), lr=2e-5),
                      loss=CrossEntropyLoss(), device=[0],
                      batch_size=64, dev_data=data_bundle.get_dataset('dev'),
                      metrics=AccuracyMetric(), n_epochs=2, print_every=1)
    trainer.train()

    # Evaluate on the test set and delete the model
    from fastNLP import Tester

    tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric())
    tester.test()
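
With include_cls_sep=True, a sequence classifier only needs the representation
at the [CLS] position. The sketch below illustrates that idea; it is not
fastNLP's actual BertForSequenceClassification source, and the class name is
made up:

.. code-block:: python

    import torch
    from torch import nn

    class TinyBertClassifier(nn.Module):
        """Illustrative only: classify a sentence from its [CLS] vector."""

        def __init__(self, embed, num_labels):
            super().__init__()
            self.embed = embed                    # e.g. a BertEmbedding
            self.fc = nn.Linear(embed.embed_size, num_labels)

        def forward(self, words):
            reps = self.embed(words)              # (batch, seq_len + 2, embed_size)
            cls_rep = reps[:, 0]                  # the vector at the [CLS] position
            return {'pred': self.fc(cls_rep)}     # fastNLP models return dicts

Returning a dict keyed by 'pred' is what lets Trainer wire the model output
into CrossEntropyLoss and AccuracyMetric by name.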

2. BERT for Sentence Matching
-----------------------------

For the Matching task, we use the RTE dataset as an example to show how
BertEmbedding is used.
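
Unlike the single-sentence SST pipe, a matching pipe has to feed two sentences
through one encoder. Roughly, it packs both into a single token sequence (an
illustration of the preprocessing, not the pipe's actual source):

.. code-block:: python

    # The premise and hypothesis end up in one 'words' field, joined by [SEP],
    # so a single BertEmbedding call can encode the pair.
    premise = "A man inspects a uniform".split()
    hypothesis = "The man is sleeping".split()
    words = premise + ['[SEP]'] + hypothesis

With that single field in place, the recipe mirrors the classification example: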

.. code-block:: python

    # Load the dataset
    from fastNLP.io import RTEBertPipe

    data_bundle = RTEBertPipe(lower=False, tokenizer='raw').process_from_file()

    # Load the BertEmbedding
    from fastNLP.embeddings import BertEmbedding

    embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='en-base-cased', include_cls_sep=True)

    # Load the model
    from fastNLP.models import BertForSentenceMatching

    model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))

    # Train the model
    from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam

    trainer = Trainer(data_bundle.get_dataset('train'), model,
                      optimizer=Adam(model_params=model.parameters(), lr=2e-5),
                      loss=CrossEntropyLoss(), device=[0],
                      batch_size=16, dev_data=data_bundle.get_dataset('dev'),
                      metrics=AccuracyMetric(), n_epochs=2, print_every=1)
    trainer.train()
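
Evaluation works the same way as in the SST example. Because the labels of
RTE's test split are hidden in GLUE, the sketch below checks the dev split
instead (an assumption about how you would verify the result, mirroring the
Tester call above):

.. code-block:: python

    from fastNLP import Tester, AccuracyMetric

    tester = Tester(data_bundle.get_dataset('dev'), model, batch_size=128, metrics=AccuracyMetric())
    tester.test()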