|
|
@@ -0,0 +1,834 @@ |
|
|
|
{ |
|
|
|
"cells": [ |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"## 文本分类(Text classification)\n", |
|
|
|
"文本分类任务是将一句话或一段话划分到某个具体的类别。比如垃圾邮件识别,文本情绪分类等。\n", |
|
|
|
"\n", |
|
|
|
"Example:: \n", |
|
|
|
"1,商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"其中开头的1是只这条评论的标签,表示是正面的情绪。我们将使用到的数据可以通过http://dbcloud.irocn.cn:8989/api/public/dl/dataset/chn_senti_corp.zip 下载并解压,当然也可以通过fastNLP自动下载该数据。\n", |
|
|
|
"\n", |
|
|
|
"数据中的内容如下图所示。接下来,我们将用fastNLP在这个数据上训练一个分类网络。" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"## 步骤\n", |
|
|
|
"一共有以下的几个步骤 \n", |
|
|
|
"(1) 读取数据 \n", |
|
|
|
"(2) 预处理数据 \n", |
|
|
|
"(3) 选择预训练词向量 \n", |
|
|
|
"(4) 创建模型 \n", |
|
|
|
"(5) 训练模型 " |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### (1) 读取数据\n", |
|
|
|
"fastNLP提供多种数据的自动下载与自动加载功能,对于这里我们要用到的数据,我们可以用\\ref{Loader}自动下载并加载该数据。更多有关Loader的使用可以参考\\ref{Loader}" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 1, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"from fastNLP.io import ChnSentiCorpLoader\n", |
|
|
|
"\n", |
|
|
|
"loader = ChnSentiCorpLoader() # 初始化一个中文情感分类的loader\n", |
|
|
|
"data_dir = loader.download() # 这一行代码将自动下载数据到默认的缓存地址, 并将该地址返回\n", |
|
|
|
"data_bundle = loader.load(data_dir) # 这一行代码将从{data_dir}处读取数据至DataBundle" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"DataBundle的相关介绍,可以参考\\ref{}。我们可以打印该data_bundle的基本信息。" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 2, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"In total 3 datasets:\n", |
|
|
|
"\tdev has 1200 instances.\n", |
|
|
|
"\ttrain has 9600 instances.\n", |
|
|
|
"\ttest has 1200 instances.\n", |
|
|
|
"In total 0 vocabs:\n", |
|
|
|
"\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"print(data_bundle)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"可以看出,该data_bundle中一个含有三个\\ref{DataSet}。通过下面的代码,我们可以查看DataSet的基本情况" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 6, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"DataSet({'raw_chars': 选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般 type=str,\n", |
|
|
|
"'target': 1 type=str},\n", |
|
|
|
"{'raw_chars': 15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错 type=str,\n", |
|
|
|
"'target': 1 type=str})\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"print(data_bundle.get_dataset('train')[:2]) # 查看Train集前两个sample" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### (2) 预处理数据\n", |
|
|
|
"在NLP任务中,预处理一般包括: (a)将一整句话切分成汉字或者词; (b)将文本转换为index \n", |
|
|
|
"\n", |
|
|
|
"fastNLP中也提供了多种数据集的处理类,这里我们直接使用fastNLP的ChnSentiCorpPipe。更多关于Pipe的说明可以参考\\ref{Pipe}。" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 3, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"from fastNLP.io import ChnSentiCorpPipe\n", |
|
|
|
"\n", |
|
|
|
"pipe = ChnSentiCorpPipe()\n", |
|
|
|
"data_bundle = pipe.process(data_bundle) # 所有的Pipe都实现了process()方法,且输入输出都为DataBundle类型" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 4, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"In total 3 datasets:\n", |
|
|
|
"\tdev has 1200 instances.\n", |
|
|
|
"\ttrain has 9600 instances.\n", |
|
|
|
"\ttest has 1200 instances.\n", |
|
|
|
"In total 2 vocabs:\n", |
|
|
|
"\tchars has 4409 entries.\n", |
|
|
|
"\ttarget has 2 entries.\n", |
|
|
|
"\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"print(data_bundle) # 打印data_bundle,查看其变化" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"可以看到除了之前已经包含的3个\\ref{DataSet}, 还新增了两个\\ref{Vocabulary}。我们可以打印DataSet中的内容" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 5, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"DataSet({'raw_chars': 选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般 type=str,\n", |
|
|
|
"'target': 1 type=int,\n", |
|
|
|
"'chars': [338, 464, 1400, 784, 468, 739, 3, 289, 151, 21, 5, 88, 143, 2, 9, 81, 134, 2573, 766, 233, 196, 23, 536, 342, 297, 2, 405, 698, 132, 281, 74, 744, 1048, 74, 420, 387, 74, 412, 433, 74, 2021, 180, 8, 219, 1929, 213, 4, 34, 31, 96, 363, 8, 230, 2, 66, 18, 229, 331, 768, 4, 11, 1094, 479, 17, 35, 593, 3, 1126, 967, 2, 151, 245, 12, 44, 2, 6, 52, 260, 263, 635, 5, 152, 162, 4, 11, 336, 3, 154, 132, 5, 236, 443, 3, 2, 18, 229, 761, 700, 4, 11, 48, 59, 653, 2, 8, 230] type=list,\n", |
|
|
|
"'seq_len': 106 type=int},\n", |
|
|
|
"{'raw_chars': 15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错 type=str,\n", |
|
|
|
"'target': 1 type=int,\n", |
|
|
|
"'chars': [50, 133, 20, 135, 945, 520, 343, 24, 3, 301, 176, 350, 86, 785, 2, 456, 24, 461, 163, 443, 128, 109, 6, 47, 7, 2, 916, 152, 162, 524, 296, 44, 301, 176, 2, 1384, 524, 296, 259, 88, 143, 2, 92, 67, 26, 12, 277, 269, 2, 188, 223, 26, 228, 83, 6, 63] type=list,\n", |
|
|
|
"'seq_len': 56 type=int})\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"print(data_bundle.get_dataset('train')[:2])" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"新增了一列为数字列表的chars,以及变为数字的target列。可以看出这两列的名称和刚好与data_bundle中两个Vocabulary的名称是一致的,我们可以打印一下Vocabulary看一下里面的内容。" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 6, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Vocabulary(['选', '择', '珠', '江', '花']...)\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"char_vocab = data_bundle.get_vocab('chars')\n", |
|
|
|
"print(char_vocab)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"Vocabulary是一个记录着词语与index之间映射关系的类,比如" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 7, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"'选'的index是338\n", |
|
|
|
"index:338对应的汉字是选\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"index = char_vocab.to_index('选')\n", |
|
|
|
"print(\"'选'的index是{}\".format(index)) # 这个值与上面打印出来的第一个instance的chars的第一个index是一致的\n", |
|
|
|
"print(\"index:{}对应的汉字是{}\".format(index, char_vocab.to_word(index))) " |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### (3) 选择预训练词向量 \n", |
|
|
|
"由于Word2vec, Glove, Elmo, Bert等预训练模型可以增强模型的性能,所以在训练具体任务前,选择合适的预训练词向量非常重要。在fastNLP中我们提供了多种Embedding使得加载这些预训练模型的过程变得更加便捷。更多关于Embedding的说明可以参考\\ref{Embedding}。这里我们先给出一个使用word2vec的中文汉字预训练的示例,之后再给出一个使用Bert的文本分类。这里使用的预训练词向量为'cn-fastnlp-100d',fastNLP将自动下载该embedding至本地缓存,fastNLP支持使用名字指定的Embedding以及相关说明可以参见\\ref{Embedding}" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 8, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Found 4321 out of 4409 words in the pre-training embedding.\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"from fastNLP.embeddings import StaticEmbedding\n", |
|
|
|
"\n", |
|
|
|
"word2vec_embed = StaticEmbedding(char_vocab, model_dir_or_name='cn-char-fastnlp-100d')" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### (4) 创建模型\n", |
|
|
|
"这里我们使用到的模型结构如下所示,补图" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 9, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"from torch import nn\n", |
|
|
|
"from fastNLP.modules import LSTM\n", |
|
|
|
"import torch\n", |
|
|
|
"\n", |
|
|
|
"# 定义模型\n", |
|
|
|
"class BiLSTMMaxPoolCls(nn.Module):\n", |
|
|
|
" def __init__(self, embed, num_classes, hidden_size=400, num_layers=1, dropout=0.3):\n", |
|
|
|
" super().__init__()\n", |
|
|
|
" self.embed = embed\n", |
|
|
|
" \n", |
|
|
|
" self.lstm = LSTM(self.embed.embedding_dim, hidden_size=hidden_size//2, num_layers=num_layers, \n", |
|
|
|
" batch_first=True, bidirectional=True)\n", |
|
|
|
" self.dropout_layer = nn.Dropout(dropout)\n", |
|
|
|
" self.fc = nn.Linear(hidden_size, num_classes)\n", |
|
|
|
" \n", |
|
|
|
" def forward(self, chars, seq_len): # 这里的名称必须和DataSet中相应的field对应,比如之前我们DataSet中有chars,这里就必须为chars\n", |
|
|
|
" # chars:[batch_size, max_len]\n", |
|
|
|
" # seq_len: [batch_size, ]\n", |
|
|
|
" chars = self.embed(chars)\n", |
|
|
|
" outputs, _ = self.lstm(chars, seq_len)\n", |
|
|
|
" outputs = self.dropout_layer(outputs)\n", |
|
|
|
" outputs, _ = torch.max(outputs, dim=1)\n", |
|
|
|
" outputs = self.fc(outputs)\n", |
|
|
|
" \n", |
|
|
|
" return {'pred':outputs} # [batch_size,], 返回值必须是dict类型,且预测值的key建议设为pred\n", |
|
|
|
"\n", |
|
|
|
"# 初始化模型\n", |
|
|
|
"model = BiLSTMMaxPoolCls(word2vec_embed, len(data_bundle.get_vocab('target')))" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### (5) 训练模型\n", |
|
|
|
"fastNLP提供了Trainer对象来组织训练过程,包括完成loss计算(所以在初始化Trainer的时候需要指定loss类型),梯度更新(所以在初始化Trainer的时候需要提供优化器optimizer)以及在验证集上的性能验证(所以在初始化时需要提供一个Metric)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 10, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"input fields after batch(if batch size is 2):\n", |
|
|
|
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", |
|
|
|
"\tchars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106]) \n", |
|
|
|
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", |
|
|
|
"target fields after batch(if batch size is 2):\n", |
|
|
|
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", |
|
|
|
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", |
|
|
|
"\n", |
|
|
|
"Evaluate data in 0.01 seconds!\n", |
|
|
|
"training epochs started 2019-09-03-23-57-10\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3000), HTML(value='')), layout=Layout(display…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"\r", |
|
|
|
"Evaluate data in 0.43 seconds!\n", |
|
|
|
"\r", |
|
|
|
"Evaluation on dev at Epoch 1/10. Step:300/3000: \n", |
|
|
|
"\r", |
|
|
|
"AccuracyMetric: acc=0.81\n", |
|
|
|
"\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"\r", |
|
|
|
"Evaluate data in 0.44 seconds!\n", |
|
|
|
"\r", |
|
|
|
"Evaluation on dev at Epoch 2/10. Step:600/3000: \n", |
|
|
|
"\r", |
|
|
|
"AccuracyMetric: acc=0.8675\n", |
|
|
|
"\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"\r", |
|
|
|
"Evaluate data in 0.44 seconds!\n", |
|
|
|
"\r", |
|
|
|
"Evaluation on dev at Epoch 3/10. Step:900/3000: \n", |
|
|
|
"\r", |
|
|
|
"AccuracyMetric: acc=0.878333\n", |
|
|
|
"\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"\r", |
|
|
|
"Evaluate data in 0.43 seconds!\n", |
|
|
|
"\r", |
|
|
|
"Evaluation on dev at Epoch 4/10. Step:1200/3000: \n", |
|
|
|
"\r", |
|
|
|
"AccuracyMetric: acc=0.873333\n", |
|
|
|
"\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"\r", |
|
|
|
"Evaluate data in 0.44 seconds!\n", |
|
|
|
"\r", |
|
|
|
"Evaluation on dev at Epoch 5/10. Step:1500/3000: \n", |
|
|
|
"\r", |
|
|
|
"AccuracyMetric: acc=0.878333\n", |
|
|
|
"\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"\r", |
|
|
|
"Evaluate data in 0.42 seconds!\n", |
|
|
|
"\r", |
|
|
|
"Evaluation on dev at Epoch 6/10. Step:1800/3000: \n", |
|
|
|
"\r", |
|
|
|
"AccuracyMetric: acc=0.895833\n", |
|
|
|
"\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"\r", |
|
|
|
"Evaluate data in 0.44 seconds!\n", |
|
|
|
"\r", |
|
|
|
"Evaluation on dev at Epoch 7/10. Step:2100/3000: \n", |
|
|
|
"\r", |
|
|
|
"AccuracyMetric: acc=0.8975\n", |
|
|
|
"\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"\r", |
|
|
|
"Evaluate data in 0.43 seconds!\n", |
|
|
|
"\r", |
|
|
|
"Evaluation on dev at Epoch 8/10. Step:2400/3000: \n", |
|
|
|
"\r", |
|
|
|
"AccuracyMetric: acc=0.894167\n", |
|
|
|
"\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"\r", |
|
|
|
"Evaluate data in 0.48 seconds!\n", |
|
|
|
"\r", |
|
|
|
"Evaluation on dev at Epoch 9/10. Step:2700/3000: \n", |
|
|
|
"\r", |
|
|
|
"AccuracyMetric: acc=0.8875\n", |
|
|
|
"\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"\r", |
|
|
|
"Evaluate data in 0.43 seconds!\n", |
|
|
|
"\r", |
|
|
|
"Evaluation on dev at Epoch 10/10. Step:3000/3000: \n", |
|
|
|
"\r", |
|
|
|
"AccuracyMetric: acc=0.895833\n", |
|
|
|
"\n", |
|
|
|
"\r\n", |
|
|
|
"In Epoch:7/Step:2100, got best dev performance:\n", |
|
|
|
"AccuracyMetric: acc=0.8975\n", |
|
|
|
"Reloaded the best model.\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=19), HTML(value='')), layout=Layout(display='…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"\r", |
|
|
|
"Evaluate data in 0.34 seconds!\n", |
|
|
|
"[tester] \n", |
|
|
|
"AccuracyMetric: acc=0.8975\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"{'AccuracyMetric': {'acc': 0.8975}}" |
|
|
|
] |
|
|
|
}, |
|
|
|
"execution_count": 10, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "execute_result" |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"from fastNLP import Trainer\n", |
|
|
|
"from fastNLP import CrossEntropyLoss\n", |
|
|
|
"from torch.optim import Adam\n", |
|
|
|
"from fastNLP import AccuracyMetric\n", |
|
|
|
"\n", |
|
|
|
"loss = CrossEntropyLoss()\n", |
|
|
|
"optimizer = Adam(model.parameters(), lr=0.001)\n", |
|
|
|
"metric = AccuracyMetric()\n", |
|
|
|
"device = 0 if torch.cuda.is_available() else 'cpu' # 如果有gpu的话在gpu上运行,训练速度会更快\n", |
|
|
|
"\n", |
|
|
|
"trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss, \n", |
|
|
|
" optimizer=optimizer, batch_size=32, dev_data=data_bundle.get_dataset('dev'),\n", |
|
|
|
" metrics=metric, device=device)\n", |
|
|
|
"trainer.train() # 开始训练,训练完成之后默认会加载在dev上表现最好的模型\n", |
|
|
|
"\n", |
|
|
|
"# 在测试集上测试一下模型的性能\n", |
|
|
|
"from fastNLP import Tester\n", |
|
|
|
"print(\"Performance on test is:\")\n", |
|
|
|
"tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n", |
|
|
|
"tester.test()" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### 使用Bert进行文本分类" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 12, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"loading vocabulary file /home/yh/.fastNLP/embedding/bert-chinese-wwm/vocab.txt\n", |
|
|
|
"Load pre-trained BERT parameters from file /home/yh/.fastNLP/embedding/bert-chinese-wwm/chinese_wwm_pytorch.bin.\n", |
|
|
|
"Start to generating word pieces for word.\n", |
|
|
|
"Found(Or segment into word pieces) 4286 words out of 4409.\n", |
|
|
|
"input fields after batch(if batch size is 2):\n", |
|
|
|
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", |
|
|
|
"\tchars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106]) \n", |
|
|
|
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", |
|
|
|
"target fields after batch(if batch size is 2):\n", |
|
|
|
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", |
|
|
|
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", |
|
|
|
"\n", |
|
|
|
"Evaluate data in 0.05 seconds!\n", |
|
|
|
"training epochs started 2019-09-04-00-02-37\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3600), HTML(value='')), layout=Layout(display…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"\r", |
|
|
|
"Evaluate data in 15.89 seconds!\n", |
|
|
|
"\r", |
|
|
|
"Evaluation on dev at Epoch 1/3. Step:1200/3600: \n", |
|
|
|
"\r", |
|
|
|
"AccuracyMetric: acc=0.9\n", |
|
|
|
"\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"\r", |
|
|
|
"Evaluate data in 15.92 seconds!\n", |
|
|
|
"\r", |
|
|
|
"Evaluation on dev at Epoch 2/3. Step:2400/3600: \n", |
|
|
|
"\r", |
|
|
|
"AccuracyMetric: acc=0.904167\n", |
|
|
|
"\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"\r", |
|
|
|
"Evaluate data in 15.91 seconds!\n", |
|
|
|
"\r", |
|
|
|
"Evaluation on dev at Epoch 3/3. Step:3600/3600: \n", |
|
|
|
"\r", |
|
|
|
"AccuracyMetric: acc=0.918333\n", |
|
|
|
"\n", |
|
|
|
"\r\n", |
|
|
|
"In Epoch:3/Step:3600, got best dev performance:\n", |
|
|
|
"AccuracyMetric: acc=0.918333\n", |
|
|
|
"Reloaded the best model.\n", |
|
|
|
"Performance on test is:\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=19), HTML(value='')), layout=Layout(display='…" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"\r", |
|
|
|
"Evaluate data in 29.24 seconds!\n", |
|
|
|
"[tester] \n", |
|
|
|
"AccuracyMetric: acc=0.919167\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"{'AccuracyMetric': {'acc': 0.919167}}" |
|
|
|
] |
|
|
|
}, |
|
|
|
"execution_count": 12, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "execute_result" |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"# 只需要切换一下Embedding即可\n", |
|
|
|
"from fastNLP.embeddings import BertEmbedding\n", |
|
|
|
"\n", |
|
|
|
"# 这里为了演示一下效果,所以默认Bert不更新权重\n", |
|
|
|
"bert_embed = BertEmbedding(char_vocab, model_dir_or_name='cn', auto_truncate=True, requires_grad=False)\n", |
|
|
|
"model = BiLSTMMaxPoolCls(bert_embed, len(data_bundle.get_vocab('target')), )\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"import torch\n", |
|
|
|
"from fastNLP import Trainer\n", |
|
|
|
"from fastNLP import CrossEntropyLoss\n", |
|
|
|
"from torch.optim import Adam\n", |
|
|
|
"from fastNLP import AccuracyMetric\n", |
|
|
|
"\n", |
|
|
|
"loss = CrossEntropyLoss()\n", |
|
|
|
"optimizer = Adam(model.parameters(), lr=2e-5)\n", |
|
|
|
"metric = AccuracyMetric()\n", |
|
|
|
"device = 0 if torch.cuda.is_available() else 'cpu' # 如果有gpu的话在gpu上运行,训练速度会更快\n", |
|
|
|
"\n", |
|
|
|
"trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss, \n", |
|
|
|
" optimizer=optimizer, batch_size=16, dev_data=data_bundle.get_dataset('test'),\n", |
|
|
|
" metrics=metric, device=device, n_epochs=3)\n", |
|
|
|
"trainer.train() # 开始训练,训练完成之后默认会加载在dev上表现最好的模型\n", |
|
|
|
"\n", |
|
|
|
"# 在测试集上测试一下模型的性能\n", |
|
|
|
"from fastNLP import Tester\n", |
|
|
|
"print(\"Performance on test is:\")\n", |
|
|
|
"tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n", |
|
|
|
"tester.test()" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [] |
|
|
|
} |
|
|
|
], |
|
|
|
"metadata": { |
|
|
|
"kernelspec": { |
|
|
|
"display_name": "Python 3", |
|
|
|
"language": "python", |
|
|
|
"name": "python3" |
|
|
|
}, |
|
|
|
"language_info": { |
|
|
|
"codemirror_mode": { |
|
|
|
"name": "ipython", |
|
|
|
"version": 3 |
|
|
|
}, |
|
|
|
"file_extension": ".py", |
|
|
|
"mimetype": "text/x-python", |
|
|
|
"name": "python", |
|
|
|
"nbconvert_exporter": "python", |
|
|
|
"pygments_lexer": "ipython3", |
|
|
|
"version": "3.6.7" |
|
|
|
} |
|
|
|
}, |
|
|
|
"nbformat": 4, |
|
|
|
"nbformat_minor": 2 |
|
|
|
} |