{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Text Classification (文本分类)\n",
"A text classification task assigns a sentence or a paragraph to a specific category, for example spam detection or sentiment classification.\n",
"\n",
"Example: \n",
"1,商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!\n",
"\n",
"\n",
"The leading 1 is the label of this review and indicates positive sentiment. The data we will use can be downloaded and unzipped from http://dbcloud.irocn.cn:8989/api/public/dl/dataset/chn_senti_corp.zip, or fastNLP can download it automatically.\n",
"\n",
"The content of the data is shown in the figure below. Next, we will use fastNLP to train a classification network on this data."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![jupyter](./cn_cls_example.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Steps\n",
"The tutorial consists of the following steps \n",
"(1) Load the data \n",
"(2) Preprocess the data \n",
"(3) Choose pre-trained embeddings \n",
"(4) Build the model \n",
"(5) Train the model "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### (1) Load the data\n",
"fastNLP can automatically download and load many datasets. For the dataset used here, we can use a Loader to download and load it automatically. See the fastNLP documentation on Loader for more details on its usage."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.io import ChnSentiCorpLoader\n",
"\n",
"loader = ChnSentiCorpLoader()  # initialize a loader for the Chinese sentiment classification dataset\n",
"data_dir = loader.download()  # automatically download the data to the default cache directory and return that path\n",
"data_bundle = loader.load(data_dir)  # read the data from data_dir into a DataBundle"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For an introduction to DataBundle, see the fastNLP documentation. We can print the basic information of this data_bundle."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(data_bundle)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, this data_bundle contains three DataSets in total. With the code below we can take a look at one DataSet."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(data_bundle.get_dataset('train')[:2])  # show the first two samples of the train set"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### (2) Preprocess the data\n",
"In NLP tasks, preprocessing usually includes: (a) splitting a sentence into characters or words; (b) converting the text into indices \n",
"\n",
"fastNLP provides processing classes for many datasets; here we directly use fastNLP's ChnSentiCorpPipe. See the documentation on Pipe for more details. A toy sketch of what steps (a) and (b) mean is given right after this cell."
]
},
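{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before using the ready-made Pipe, here is a minimal hand-written sketch of what preprocessing steps (a) and (b) amount to on a toy example. This cell is not part of the original tutorial; toy_texts and char2idx are made-up names used only for illustration, and the real work is done by ChnSentiCorpPipe below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# a toy illustration of the two preprocessing steps; the actual processing below is done by ChnSentiCorpPipe\n",
"toy_texts = ['商务大床房', '房间很大']  # made-up examples, not the real dataset\n",
"\n",
"# (a) split each sentence into characters\n",
"tokenized = [list(text) for text in toy_texts]\n",
"\n",
"# (b) build a character-to-index mapping and convert the texts into index lists\n",
"char2idx = {}\n",
"for chars in tokenized:\n",
"    for ch in chars:\n",
"        char2idx.setdefault(ch, len(char2idx))\n",
"indexed = [[char2idx[ch] for ch in chars] for chars in tokenized]\n",
"print(indexed)"
]
},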
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.io import ChnSentiCorpPipe\n",
"\n",
"pipe = ChnSentiCorpPipe()\n",
"data_bundle = pipe.process(data_bundle)  # every Pipe implements process(); both its input and output are DataBundle objects"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(data_bundle)  # print data_bundle again to see how it has changed"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Besides the 3 DataSets it already contained, the data_bundle now also holds two new Vocabulary objects. We can print the contents of a DataSet."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(data_bundle.get_dataset('train')[:2])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A new chars field containing lists of indices has been added, and the target field has been converted to numbers. Note that the names of these two fields are exactly the names of the two Vocabulary objects in the data_bundle; let's print one of the Vocabulary objects to inspect its contents."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"char_vocab = data_bundle.get_vocab('chars')\n",
"print(char_vocab)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A Vocabulary records the mapping between tokens and indices, for example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"index = char_vocab.to_index('选')\n",
"print(\"The index of '选' is {}\".format(index))  # this value matches the first index in the chars field of the first instance printed above\n",
"print(\"The character for index {} is {}\".format(index, char_vocab.to_word(index)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### (3) Choose pre-trained embeddings \n",
"Pre-trained models such as word2vec, GloVe, ELMo and BERT can improve model performance, so choosing suitable pre-trained embeddings before training on the task is important. fastNLP provides several Embedding classes that make loading these pre-trained models convenient; see the documentation on Embedding for details. Here we first give an example using word2vec-style pre-trained Chinese character embeddings, and later an example using BERT. The pre-trained embedding used here is 'cn-char-fastnlp-100d'; fastNLP will automatically download it to the local cache. The names fastNLP supports for specifying an Embedding are also listed in the Embedding documentation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.embeddings import StaticEmbedding\n",
"\n",
"word2vec_embed = StaticEmbedding(char_vocab, model_dir_or_name='cn-char-fastnlp-100d')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### (4) Build the model\n",
"The model used here is a bidirectional LSTM followed by max pooling over time and a linear classification layer, as defined below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch import nn\n",
"from fastNLP.modules import LSTM\n",
"import torch\n",
"\n",
"# define the model\n",
"class BiLSTMMaxPoolCls(nn.Module):\n",
"    def __init__(self, embed, num_classes, hidden_size=400, num_layers=1, dropout=0.3):\n",
"        super().__init__()\n",
"        self.embed = embed\n",
"\n",
"        self.lstm = LSTM(self.embed.embedding_dim, hidden_size=hidden_size//2, num_layers=num_layers,\n",
"                         batch_first=True, bidirectional=True)\n",
"        self.dropout_layer = nn.Dropout(dropout)\n",
"        self.fc = nn.Linear(hidden_size, num_classes)\n",
"\n",
"    def forward(self, chars, seq_len):  # the parameter names must match the field names in the DataSet; our DataSet has a chars field, so the parameter here must be chars\n",
"        # chars: [batch_size, max_len]\n",
"        # seq_len: [batch_size, ]\n",
"        chars = self.embed(chars)\n",
"        outputs, _ = self.lstm(chars, seq_len)\n",
"        outputs = self.dropout_layer(outputs)\n",
"        outputs, _ = torch.max(outputs, dim=1)\n",
"        outputs = self.fc(outputs)\n",
"\n",
"        return {'pred': outputs}  # [batch_size, num_classes]; the return value must be a dict, and the key for the prediction should be 'pred'\n",
"\n",
"# instantiate the model\n",
"model = BiLSTMMaxPoolCls(word2vec_embed, len(data_bundle.get_vocab('target')))"
]
},
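{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training, it can help to sanity-check the model output shape on a small dummy batch. This cell is not part of the original tutorial; the dummy tensors below are made up purely for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# a quick shape check on a made-up dummy batch (illustration only)\n",
"dummy_chars = torch.randint(low=1, high=len(char_vocab), size=(4, 20))  # [batch_size=4, max_len=20]\n",
"dummy_seq_len = torch.LongTensor([20, 18, 15, 9])  # actual lengths of the 4 dummy sequences\n",
"with torch.no_grad():\n",
"    out = model(dummy_chars, dummy_seq_len)\n",
"print(out['pred'].shape)  # expected: [4, num_classes], i.e. one score per class for each sample"
]
},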
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### (5) Train the model\n",
"fastNLP provides a Trainer object to organize the training loop. It computes the loss (so a loss type must be specified when initializing the Trainer), updates the gradients (so an optimizer must be provided) and evaluates on the dev set (so a Metric must be provided)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import Trainer\n",
"from fastNLP import CrossEntropyLoss\n",
"from torch.optim import Adam\n",
"from fastNLP import AccuracyMetric\n",
"\n",
"loss = CrossEntropyLoss()\n",
"optimizer = Adam(model.parameters(), lr=0.001)\n",
"metric = AccuracyMetric()\n",
"device = 0 if torch.cuda.is_available() else 'cpu'  # run on GPU if one is available; training is much faster\n",
"\n",
"trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss,\n",
"                  optimizer=optimizer, batch_size=32, dev_data=data_bundle.get_dataset('dev'),\n",
"                  metrics=metric, device=device)\n",
"trainer.train()  # start training; when training finishes, the model that performed best on dev is loaded by default\n",
"\n",
"# evaluate the model on the test set\n",
"from fastNLP import Tester\n",
"print(\"Performance on test is:\")\n",
"tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
"tester.test()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Text classification with BERT"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# we only need to switch the Embedding\n",
"from fastNLP.embeddings import BertEmbedding\n",
"\n",
"# for demonstration purposes, the BERT weights are not updated here\n",
"bert_embed = BertEmbedding(char_vocab, model_dir_or_name='cn', auto_truncate=True, requires_grad=False)\n",
"model = BiLSTMMaxPoolCls(bert_embed, len(data_bundle.get_vocab('target')))\n",
"\n",
"\n",
"import torch\n",
"from fastNLP import Trainer\n",
"from fastNLP import CrossEntropyLoss\n",
"from torch.optim import Adam\n",
"from fastNLP import AccuracyMetric\n",
"\n",
"loss = CrossEntropyLoss()\n",
"optimizer = Adam(model.parameters(), lr=2e-5)\n",
"metric = AccuracyMetric()\n",
"device = 0 if torch.cuda.is_available() else 'cpu'  # run on GPU if one is available; training is much faster\n",
"\n",
"trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss,\n",
"                  optimizer=optimizer, batch_size=16, dev_data=data_bundle.get_dataset('test'),\n",
"                  metrics=metric, device=device, n_epochs=3)\n",
"trainer.train()  # start training; when training finishes, the model that performed best on dev is loaded by default\n",
"\n",
"# evaluate the model on the test set\n",
"from fastNLP import Tester\n",
"print(\"Performance on test is:\")\n",
"tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
"tester.test()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Word-based text classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since Chinese text has no explicit word boundaries, a word segmenter is usually needed to split each sentence into words first.\n",
"The example below shows how to do text classification without relying on fastNLP's built-in data loading and preprocessing code."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### (1) Load the data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We continue with the same dataset as before, but this time we do not use fastNLP's built-in data reading code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.io import ChnSentiCorpLoader\n",
"\n",
"loader = ChnSentiCorpLoader()  # initialize a loader for the Chinese sentiment classification dataset\n",
"data_dir = loader.download()  # automatically download the data to the default cache directory and return that path"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below we first define a read_file_to_dataset function that, given a file path, reads the file's contents and returns a DataSet. We then put all of the DataSets into a DataBundle object to make the subsequent preprocessing easier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from fastNLP import DataSet, Instance\n",
"from fastNLP.io import DataBundle\n",
"\n",
"\n",
"def read_file_to_dataset(fp):\n",
"    ds = DataSet()\n",
"    with open(fp, 'r') as f:\n",
"        f.readline()  # the first line is the header; skip it\n",
"        for line in f:\n",
"            line = line.strip()\n",
"            target, chars = line.split('\\t')\n",
"            ins = Instance(target=target, raw_chars=chars)\n",
"            ds.append(ins)\n",
"    return ds\n",
"\n",
"data_bundle = DataBundle()\n",
"for name in ['train.tsv', 'dev.tsv', 'test.tsv']:\n",
"    fp = os.path.join(data_dir, name)\n",
"    ds = read_file_to_dataset(fp)\n",
"    data_bundle.set_dataset(name=name.split('.')[0], dataset=ds)\n",
"\n",
"print(data_bundle)  # inspect the datasets\n",
"# In total 3 datasets:\n",
"#   train has 9600 instances.\n",
"#   dev has 1200 instances.\n",
"#   test has 1200 instances."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### (2) Preprocess the data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we first segment the sentences into words with [fastHan](http://gitee.com/fastnlp/fastHan), then build vocabularies and convert the words into indices."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastHan import FastHan\n",
"from fastNLP import Vocabulary\n",
"\n",
"model = FastHan()\n",
"# model.set_device('cuda')\n",
"\n",
"# define the word segmentation operation\n",
"def word_seg(ins):\n",
"    raw_chars = ins['raw_chars']\n",
"    # some sentences are rather long, so we only keep the first 128 characters\n",
"    raw_words = model(raw_chars[:128], target='CWS')[0]\n",
"    return raw_words\n",
"\n",
"for name, ds in data_bundle.iter_datasets():\n",
"    # apply() runs word_seg on every instance and stores the return value in the raw_words field\n",
"    ds.apply(word_seg, new_field_name='raw_words')\n",
"    # besides apply(), fastNLP also provides apply_field(), apply_more() (which can create several fields at once), etc.\n",
"    # we also add a seq_len field\n",
"    ds.add_seq_len('raw_words')\n",
"\n",
"vocab = Vocabulary()\n",
"\n",
"# build the vocabulary from the raw_words field; it is recommended to pass the non-training datasets via no_create_entry_dataset\n",
"# the vocabulary can also be built with add_word(), add_word_lst(), etc.; see http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_2_vocabulary.html\n",
"vocab.from_dataset(data_bundle.get_dataset('train'), field_name='raw_words',\n",
"                   no_create_entry_dataset=[data_bundle.get_dataset('dev'),\n",
"                                            data_bundle.get_dataset('test')])\n",
"\n",
"# use the built Vocabulary to convert the raw_words field into indices and store them in the words field\n",
"vocab.index_dataset(data_bundle.get_dataset('train'), data_bundle.get_dataset('dev'),\n",
"                    data_bundle.get_dataset('test'), field_name='raw_words', new_field_name='words')\n",
"\n",
"# build the target vocabulary; it usually needs neither padding nor unknown tokens\n",
"target_vocab = Vocabulary(padding=None, unknown=None)\n",
"# usually the target vocabulary can be built from the training set alone\n",
"target_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='target')\n",
"# if new_field_name is not given, the original field is overwritten by default\n",
"target_vocab.index_dataset(data_bundle.get_dataset('train'), data_bundle.get_dataset('dev'),\n",
"                           data_bundle.get_dataset('test'), field_name='target')\n",
"\n",
"# store the vocabularies in the data_bundle for later use\n",
"data_bundle.set_vocab(field_name='words', vocab=vocab)\n",
"data_bundle.set_vocab(field_name='target', vocab=target_vocab)\n",
"\n",
"# mark words and seq_len as input and target as target, so that they are fetched in the training loop and padded automatically; for more on this see\n",
"# http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_6_datasetiter.html\n",
"data_bundle.set_target('target')\n",
"data_bundle.set_input('words', 'seq_len')  # DataSet also provides these two methods\n",
"# if you want a field to be input or target but do not want fastNLP to pad it automatically, or you need a custom padding strategy, see\n",
"# http://www.fastnlp.top/docs/fastNLP/fastNLP.core.dataset.html\n",
"\n",
"print(data_bundle.get_dataset('train')[:2])  # inspect the current dataset\n",
"\n",
"# since we will reuse the BiLSTMMaxPoolCls model defined earlier, rename the words field to chars (the model's forward() takes a chars argument)\n",
"data_bundle.rename_field('words', 'chars')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### (3) Choose pre-trained embeddings"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we use Tencent's pre-trained Chinese word embeddings, which can be downloaded and unzipped from the [Tencent AI Lab embedding page](https://ai.tencent.com/ailab/nlp/en/embedding.html). We cannot directly use BERT here, because the Chinese BERT is pre-trained on characters rather than words."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.embeddings import StaticEmbedding\n",
"\n",
"word2vec_embed = StaticEmbedding(data_bundle.get_vocab('words'),\n",
"                                 model_dir_or_name='/path/to/Tencent_AILab_ChineseEmbedding.txt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import Trainer\n",
"from fastNLP import CrossEntropyLoss\n",
"from torch.optim import Adam\n",
"from fastNLP import AccuracyMetric\n",
"\n",
"# instantiate the model\n",
"model = BiLSTMMaxPoolCls(word2vec_embed, len(data_bundle.get_vocab('target')))\n",
"\n",
"# start training\n",
"loss = CrossEntropyLoss()\n",
"optimizer = Adam(model.parameters(), lr=0.001)\n",
"metric = AccuracyMetric()\n",
"device = 0 if torch.cuda.is_available() else 'cpu'  # run on GPU if one is available; training is much faster\n",
"\n",
"trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss,\n",
"                  optimizer=optimizer, batch_size=32, dev_data=data_bundle.get_dataset('dev'),\n",
"                  metrics=metric, device=device)\n",
"trainer.train()  # start training; when training finishes, the model that performed best on dev is loaded by default\n",
"\n",
"# evaluate the model on the test set\n",
"from fastNLP import Tester\n",
"print(\"Performance on test is:\")\n",
"tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
"tester.test()"
]
},
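{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, a small inference sketch that is not part of the original tutorial: it reuses the fastHan segmenter, the word vocabulary vocab, the target_vocab and the trained model built above to predict the label of a new sentence. The function predict_sentiment and the example sentence are made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# a minimal inference sketch; vocab, target_vocab and model were built in the cells above\n",
"from fastHan import FastHan\n",
"\n",
"segmenter = FastHan()  # re-create a segmenter, since the variable named model now holds the classifier\n",
"\n",
"def predict_sentiment(text):\n",
"    words = segmenter(text[:128], target='CWS')[0]  # segment the sentence into words\n",
"    indices = [vocab.to_index(w) for w in words]  # unknown words are mapped to the unk index\n",
"    chars = torch.LongTensor([indices])  # batch of size 1; the model's forward() expects a chars argument\n",
"    seq_len = torch.LongTensor([len(indices)])\n",
"    dev = next(model.parameters()).device  # put the inputs on the same device as the model\n",
"    model.eval()\n",
"    with torch.no_grad():\n",
"        pred = model(chars.to(dev), seq_len.to(dev))['pred'].argmax(dim=-1).item()\n",
"    return target_vocab.to_word(pred)\n",
"\n",
"print(predict_sentiment('房间很大,性价比不错'))  # prints the predicted label, e.g. '1' for positive"
]
},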
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}