You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

fastnlp_torch_tutorial.ipynb 54 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "id": "6011adf8",
  6. "metadata": {},
  7. "source": [
  8. "# 10 分钟快速上手 fastNLP torch\n",
  9. "\n",
  10. "在这个例子中,我们将使用BERT来解决conll2003数据集中的命名实体识别任务。"
  11. ]
  12. },
  13. {
  14. "cell_type": "code",
  15. "execution_count": 1,
  16. "id": "e166c051",
  17. "metadata": {},
  18. "outputs": [
  19. {
  20. "name": "stdout",
  21. "output_type": "stream",
  22. "text": [
  23. "--2022-07-07 10:12:29-- https://data.deepai.org/conll2003.zip\n",
  24. "Resolving data.deepai.org (data.deepai.org)... 138.201.36.183\n",
  25. "Connecting to data.deepai.org (data.deepai.org)|138.201.36.183|:443... connected.\n",
  26. "WARNING: cannot verify data.deepai.org's certificate, issued by ‘CN=R3,O=Let's Encrypt,C=US’:\n",
  27. " Issued certificate has expired.\n",
  28. "HTTP request sent, awaiting response... 200 OK\n",
  29. "Length: 982975 (960K) [application/x-zip-compressed]\n",
  30. "Saving to: ‘conll2003.zip’\n",
  31. "\n",
  32. "conll2003.zip 100%[===================>] 959.94K 653KB/s in 1.5s \n",
  33. "\n",
  34. "2022-07-07 10:12:32 (653 KB/s) - ‘conll2003.zip’ saved [982975/982975]\n",
  35. "\n",
  36. "Archive: conll2003.zip\n",
  37. " inflating: conll2003/metadata \n",
  38. " inflating: conll2003/test.txt \n",
  39. " inflating: conll2003/train.txt \n",
  40. " inflating: conll2003/valid.txt \n"
  41. ]
  42. }
  43. ],
  44. "source": [
  45. "# Linux/Mac 下载数据,并解压\n",
  46. "import platform\n",
  47. "if platform.system() != \"Windows\":\n",
  48. " !wget https://data.deepai.org/conll2003.zip --no-check-certificate -O conll2003.zip\n",
  49. " !unzip conll2003.zip -d conll2003\n",
  50. "# Windows用户请通过复制该url到浏览器下载该数据并解压"
  51. ]
  52. },
  53. {
  54. "cell_type": "markdown",
  55. "id": "f7acbf1f",
  56. "metadata": {},
  57. "source": [
  58. "## 目录\n",
  59. "接下来我们将按照以下的内容介绍如何通过fastNLP减少工程性代码的撰写 \n",
  60. "- 1. 数据加载\n",
  61. "- 2. 数据预处理、数据缓存\n",
  62. "- 3. DataLoader\n",
  63. "- 4. 模型准备\n",
  64. "- 5. Trainer的使用\n",
  65. "- 6. Evaluator的使用\n",
  66. "- 7. 其它【待补充】\n",
  67. " - 7.1 使用多卡进行训练、评测\n",
  68. " - 7.2 使用ZeRO优化\n",
  69. " - 7.3 通过overfit测试快速验证模型\n",
  70. " - 7.4 复杂Monitor的使用\n",
  71. " - 7.5 训练过程中,使用不同的测试函数\n",
  72. " - 7.6 更有效率的Sampler\n",
  73. " - 7.7 保存模型\n",
  74. " - 7.8 断点重训\n",
  75. " - 7.9 使用huggingface datasets\n",
  76. " - 7.10 使用torchmetrics来作为metric\n",
  77. " - 7.11 将预测结果写出到文件\n",
  78. " - 7.12 混合 dataset 训练\n",
  79. " - 7.13 logger的使用\n",
  80. " - 7.14 自定义分布式 Metric\n",
  81. " - 7.15 通过batch_step_fn实现R-Drop"
  82. ]
  83. },
  84. {
  85. "cell_type": "markdown",
  86. "id": "0657dfba",
  87. "metadata": {},
  88. "source": [
  89. "### 1. 数据加载\n",
  90. "目前在``conll2003``目录下有``train.txt``, ``test.txt``与``valid.txt``三个文件,文件的格式为[conll格式](https://universaldependencies.org/format.html),其标签采用 [BIO](https://blog.csdn.net/HappyRocking/article/details/79716212) 标注方式。可以通过继承 fastNLP.io.Loader 来简化加载过程,继承了 Loader 类后,只需要实现读取单个文件的 _load() 方法即可。"
  91. ]
  92. },
  93. {
  94. "cell_type": "code",
  95. "execution_count": 1,
  96. "id": "c557f0ba",
  97. "metadata": {},
  98. "outputs": [],
  99. "source": [
  100. "import sys\n",
  101. "sys.path.append('../..')"
  102. ]
  103. },
  104. {
  105. "cell_type": "code",
  106. "execution_count": 2,
  107. "id": "6f59e438",
  108. "metadata": {},
  109. "outputs": [
  110. {
  111. "data": {
  112. "text/html": [
  113. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  114. "</pre>\n"
  115. ],
  116. "text/plain": [
  117. "\n"
  118. ]
  119. },
  120. "metadata": {},
  121. "output_type": "display_data"
  122. },
  123. {
  124. "name": "stdout",
  125. "output_type": "stream",
  126. "text": [
  127. "In total 3 datasets:\n",
  128. "\ttrain has 14987 instances.\n",
  129. "\ttest has 3684 instances.\n",
  130. "\tdev has 3466 instances.\n",
  131. "\n"
  132. ]
  133. }
  134. ],
  135. "source": [
  136. "from fastNLP import DataSet, Instance\n",
  137. "from fastNLP.io import Loader\n",
  138. "\n",
  139. "\n",
  140. "# 继承Loader之后,我们只需要实现其中_load()方法,_load()方法传入一个文件路径,返回一个fastNLP DataSet对象,其目的是读取一个文件。\n",
  141. "class ConllLoader(Loader):\n",
  142. " def _load(self, path):\n",
  143. " ds = DataSet()\n",
  144. " with open(path, 'r') as f:\n",
  145. " segments = []\n",
  146. " for line in f:\n",
  147. " line = line.strip()\n",
  148. " if line == '': # 如果为空行,说明需要切换到下一句了。\n",
  149. " if segments:\n",
  150. " raw_words = [s[0] for s in segments]\n",
  151. " raw_target = [s[1] for s in segments]\n",
  152. " # 将一个 sample 插入到 DataSet中\n",
  153. " ds.append(Instance(raw_words=raw_words, raw_target=raw_target)) \n",
  154. " segments = []\n",
  155. " else:\n",
  156. " parts = line.split()\n",
  157. " assert len(parts)==4\n",
  158. " segments.append([parts[0], parts[-1]])\n",
  159. " return ds\n",
  160. " \n",
  161. "\n",
  162. "# 直接使用 load() 方法加载数据集, 返回的 data_bundle 是一个 fastNLP.io.DataBundle 对象,该对象相当于将多个 dataset 放置在一起,\n",
  163. "# 可以方便之后的预处理,DataBundle 支持的接口可以在 fastNLP.io.DataBundle 的文档中查看。\n",
  164. "data_bundle = ConllLoader().load({\n",
  165. " 'train': 'conll2003/train.txt',\n",
  166. " 'test': 'conll2003/test.txt',\n",
  167. " 'dev': 'conll2003/valid.txt'\n",
  168. "})\n",
  169. "\"\"\"\n",
  170. "也可以通过 ConllLoader().load('conll2003/') 来读取,其原理是load()函数将尝试从'conll2003/'文件夹下寻找文件名称中包含了\n",
  171. "'train'、'test'和'dev'的文件,并分别读取将其命名为'train'、'test'和'dev'(如文件夹中同一个关键字出现在了多个文件名中将导致报错,\n",
  172. "此时请通过dict的方式传入路径信息)。但在我们这里的数据里,没有文件包含dev,所以无法直接使用文件夹读取,转而通过dict的方式传入读取的路径,\n",
  173. "该dict的key也将作为读取的数据集的名称,value即对应的文件路径。\n",
  174. "\"\"\"\n",
  175. "\n",
  176. "print(data_bundle) # 打印 data_bundle 可以查看包含的 DataSet \n",
  177. "# data_bundle.get_dataset('train') # 可以获取单个 dataset"
  178. ]
  179. },
  180. {
  181. "cell_type": "markdown",
  182. "id": "57ae314d",
  183. "metadata": {},
  184. "source": [
  185. "### 2. 数据预处理、数据缓存\n",
  186. "接下来,我们将演示如何通过fastNLP提供的apply函数方便快捷地进行预处理。我们需要进行的预处理操作有: \n",
  187. "(1)使用BertTokenizer将文本转换为index;同时记录每个word被bpe之后第一个bpe的index,用于得到word的hidden state; \n",
  188. "(2)使用[Vocabulary](../fastNLP)来将raw_target转换为序号。 "
  189. ]
  190. },
  191. {
  192. "cell_type": "code",
  193. "execution_count": 3,
  194. "id": "96389988",
  195. "metadata": {},
  196. "outputs": [
  197. {
  198. "data": {
  199. "application/vnd.jupyter.widget-view+json": {
  200. "model_id": "",
  201. "version_major": 2,
  202. "version_minor": 0
  203. },
  204. "text/plain": [
  205. "Output()"
  206. ]
  207. },
  208. "metadata": {},
  209. "output_type": "display_data"
  210. },
  211. {
  212. "data": {
  213. "text/html": [
  214. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  215. ],
  216. "text/plain": []
  217. },
  218. "metadata": {},
  219. "output_type": "display_data"
  220. },
  221. {
  222. "data": {
  223. "application/vnd.jupyter.widget-view+json": {
  224. "model_id": "c3bd41a323c94a41b409d29a5d4079b6",
  225. "version_major": 2,
  226. "version_minor": 0
  227. },
  228. "text/plain": [
  229. "Output()"
  230. ]
  231. },
  232. "metadata": {},
  233. "output_type": "display_data"
  234. },
  235. {
  236. "name": "stderr",
  237. "output_type": "stream",
  238. "text": [
  239. "IOPub message rate exceeded.\n",
  240. "The notebook server will temporarily stop sending output\n",
  241. "to the client in order to avoid crashing it.\n",
  242. "To change this limit, set the config variable\n",
  243. "`--NotebookApp.iopub_msg_rate_limit`.\n",
  244. "\n",
  245. "Current values:\n",
  246. "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
  247. "NotebookApp.rate_limit_window=3.0 (secs)\n",
  248. "\n"
  249. ]
  250. },
  251. {
  252. "data": {
  253. "text/html": [
  254. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  255. ],
  256. "text/plain": []
  257. },
  258. "metadata": {},
  259. "output_type": "display_data"
  260. },
  261. {
  262. "data": {
  263. "text/html": [
  264. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[10:48:13] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Save cache to <span style=\"color: #800080; text-decoration-color: #800080\">/remote-home/hyan01/exps/fastNLP/fastN</span> <a href=\"file://../../fastNLP/core/utils/cache_results.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">cache_results.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/core/utils/cache_results.py#332\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">332</span></a>\n",
  265. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #800080; text-decoration-color: #800080\">LP/demo/torch_tutorial/caches/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">c7f74559_cache.pkl.</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  266. "</pre>\n"
  267. ],
  268. "text/plain": [
  269. "\u001b[2;36m[10:48:13]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Save cache to \u001b[35m/remote-home/hyan01/exps/fastNLP/fastN\u001b[0m \u001b]8;id=831330;file://../../fastNLP/core/utils/cache_results.py\u001b\\\u001b[2mcache_results.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=609545;file://../../fastNLP/core/utils/cache_results.py#332\u001b\\\u001b[2m332\u001b[0m\u001b]8;;\u001b\\\n",
  270. "\u001b[2;36m \u001b[0m \u001b[35mLP/demo/torch_tutorial/caches/\u001b[0m\u001b[95mc7f74559_cache.pkl.\u001b[0m \u001b[2m \u001b[0m\n"
  271. ]
  272. },
  273. "metadata": {},
  274. "output_type": "display_data"
  275. }
  276. ],
  277. "source": [
  278. "# fastNLP 中提供了BERT, RoBERTa, GPT, BART 模型,更多的预训练模型请直接使用transformers\n",
  279. "from fastNLP.transformers.torch import BertTokenizer\n",
  280. "from fastNLP import cache_results, Vocabulary\n",
  281. "\n",
  282. "# 使用cache_results来装饰函数,会将函数的返回结果缓存到'caches/{param_hash_id}_cache.pkl'路径中(其中{param_hash_id}是根据\n",
  283. "# 传递给 process_data 函数参数决定的,因此当函数的参数变化时,会再生成新的缓存文件。如果需要重新生成新的缓存,(a) 可以在调用process_data\n",
  284. "# 函数时,额外传入一个_refresh=True的参数; 或者(b)删除相应的缓存文件。此外,保存结果时,cache_results默认还会\n",
  285. "# 记录 process_data 函数源码的hash值,当其源码发生了变动,直接读取缓存会发出警告,以防止在修改预处理代码之后,忘记刷新缓存。)\n",
  286. "@cache_results('caches/cache.pkl')\n",
  287. "def process_data(data_bundle, model_name):\n",
  288. " tokenizer = BertTokenizer.from_pretrained(model_name)\n",
  289. " def bpe(raw_words):\n",
  290. " bpes = [tokenizer.cls_token_id]\n",
  291. " first = [0]\n",
  292. " first_index = 1 # 记录第一个bpe的位置\n",
  293. " for word in raw_words:\n",
  294. " bpe = tokenizer.encode(word, add_special_tokens=False)\n",
  295. " bpes.extend(bpe)\n",
  296. " first.append(first_index)\n",
  297. " first_index += len(bpe)\n",
  298. " bpes.append(tokenizer.sep_token_id)\n",
  299. " first.append(first_index)\n",
  300. " return {'input_ids': bpes, 'input_len': len(bpes), 'first': first, 'first_len': len(raw_words)}\n",
  301. " # 对data_bundle中每个dataset的每一条数据中的raw_words使用bpe函数,并且将返回的结果加入到每条数据中。\n",
  302. " data_bundle.apply_field_more(bpe, field_name='raw_words', num_proc=4)\n",
  303. " # 对应我们还有 apply_field() 函数,该函数和 apply_field_more() 的区别在于传入到 apply_field() 中的函数应该返回一个 field 的\n",
  304. " # 内容(即不需要用dict包裹了)。此外,我们还提供了 data_bundle.apply() ,传入 apply() 的函数需要支持传入一个Instance对象,\n",
  305. " # 更多信息可以参考对应的文档。\n",
  306. " \n",
  307. " # tag的词表,由于这是标签的词表而非单词的词表,所以不需要有padding和unk\n",
  308. " tag_vocab = Vocabulary(padding=None, unknown=None)\n",
  309. " # 从 train 数据的 raw_target 中获取建立词表\n",
  310. " tag_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='raw_target')\n",
  311. " # 使用词表将每个 dataset 中的raw_target转为数字,并且将结果写入到target这个field中\n",
  312. " tag_vocab.index_dataset(data_bundle.datasets.values(), field_name='raw_target', new_field_name='target')\n",
  313. " \n",
  314. " # 可以将 vocabulary 绑定到 data_bundle 上,方便之后使用。\n",
  315. " data_bundle.set_vocab(tag_vocab, field_name='target')\n",
  316. " \n",
  317. " return data_bundle, tokenizer\n",
  318. "\n",
  319. "data_bundle, tokenizer = process_data(data_bundle, 'bert-base-cased', _refresh=True) # 第一次调用耗时较长;注意这里传入了 _refresh=True,每次调用都会强制重新处理并刷新缓存,去掉该参数后,再次调用才会直接读取缓存的文件\n",
  320. "# data_bundle, tokenizer = process_data(data_bundle, 'bert-base-uncased') # 由于参数变化,fastNLP 会再次生成新的缓存文件。 "
  321. ]
  322. },
  323. {
  324. "cell_type": "markdown",
  325. "id": "80036fcd",
  326. "metadata": {},
  327. "source": [
  328. "### 3. DataLoader \n",
  329. "由于现在的深度学习算法大都基于 mini-batch 进行优化,因此需要将多个 sample 组合成一个 batch 再输入到模型之中。在自然语言处理中,不同的 sample 往往长度不一致,需要进行 padding 操作。在fastNLP中,我们使用 fastNLP.TorchDataLoader 帮助用户快速进行 padding ,并使用 fastNLP.Collator 对象来进行 pad ,Collator 会在迭代过程中根据第一个 batch 的数据自动判定每个 field 是否可以进行 pad ,可以通过 Collator.set_pad() 函数修改某个 field 的 pad 行为。"
  330. ]
  331. },
  332. {
  333. "cell_type": "code",
  334. "execution_count": 4,
  335. "id": "09494695",
  336. "metadata": {},
  337. "outputs": [],
  338. "source": [
  339. "from fastNLP import prepare_dataloader\n",
  340. "\n",
  341. "# 将 data_bundle 中每个 dataset 取出并构造出相应的 DataLoader 对象。返回的 dls 是一个 dict ,包含了 'train', 'test', 'dev' 三个\n",
  342. "# fastNLP.TorchDataLoader 对象。\n",
  343. "dls = prepare_dataloader(data_bundle, batch_size=24) \n",
  344. "\n",
  345. "\n",
  346. "# fastNLP 将默认尝试对所有 field 都进行 pad ,如果当前 field 是不可 pad 的类型,则不进行pad;如果是可以 pad 的类型\n",
  347. "# 默认使用 0 进行 pad 。\n",
  348. "for dl in dls.values():\n",
  349. " # 可以通过 set_pad 修改 padding 的行为。\n",
  350. " dl.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n",
  351. " # 如果希望忽略某个 field ,可以通过 set_ignore 方法。\n",
  352. " dl.set_ignore('raw_target')\n",
  353. " dl.set_pad('target', pad_val=-100)\n",
  354. "# 另一种设置的方法是,可以在 dls = prepare_dataloader(data_bundle, batch_size=24) 之前直接调用 \n",
  355. "# data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id); data_bundle.set_ignore('raw_target')来进行设置。\n",
  356. "# DataSet 也支持这两个方法。\n",
  357. "# 若此时调用 batch = next(iter(dls['train'])),则 batch 是一个 dict ,其中包含了\n",
  358. "# 'input_ids': torch.LongTensor([batch_size, max_len])\n",
  359. "# 'input_len': torch.LongTensor([batch_size])\n",
  360. "# 'first': torch.LongTensor([batch_size, max_len'])\n",
  361. "# 'first_len': torch.LongTensor([batch_size])\n",
  362. "# 'target': torch.LongTensor([batch_size, max_len'-2])\n",
  363. "# 'raw_words': List[List[str]] # 因为无法判断,所以 Collator 不会做任何处理"
  364. ]
  365. },
  366. {
  367. "cell_type": "markdown",
  368. "id": "3583df6d",
  369. "metadata": {},
  370. "source": [
  371. "### 4. 模型准备\n",
  372. "传入给fastNLP的模型,需要有两个特殊的方法``train_step``、``evaluate_step``,前者默认在 fastNLP.Trainer 中进行调用,后者默认在 fastNLP.Evaluator 中调用。如果模型中没有``train_step``方法,则Trainer会直接使用模型的``forward``函数;如果模型没有``evaluate_step``方法,则Evaluator会直接使用模型的``forward``函数。``train_step``方法(或当其不存在时,``forward``方法)的返回值必须为 dict 类型,并且必须包含``loss``这个 key 。\n",
  373. "\n",
  374. "此外fastNLP会使用形参名匹配的方式进行参数传递,例如以下模型\n",
  375. "```python\n",
  376. "class Model(nn.Module):\n",
  377. " def train_step(self, x, y):\n",
  378. " return {'loss': (x-y).abs().mean()}\n",
  379. "```\n",
  380. "fastNLP将尝试从 DataLoader 返回的 batch(假设包含的 key 为 input_ids, target) 中寻找 'x' 和 'y' 这两个 key ,如果没有找到则会报错。有以下的方法可以解决报错\n",
  381. "- 修改 train_step 的参数为(input_ids, target),以保证和 DataLoader 返回的 batch 中的 key 匹配\n",
  382. "- 修改 DataLoader 中返回 batch 的 key 的名字为 (x, y)\n",
  383. "- 在 Trainer 中传入参数 train_input_mapping={'input_ids': 'x', 'target': 'y'} 将输入进行映射,train_input_mapping 也可以是一个函数,更多 train_input_mapping 的介绍可以参考文档。\n",
  384. "\n",
  385. "``evaluate_step``也是使用同样的匹配方式,前两条解决方法是一致的,第三种解决方案中,需要在 Evaluator 中传入 evaluate_input_mapping={'input_ids': 'x', 'target': 'y'}。"
  386. ]
  387. },
  388. {
  389. "cell_type": "code",
  390. "execution_count": 5,
  391. "id": "f131c1a3",
  392. "metadata": {},
  393. "outputs": [
  394. {
  395. "data": {
  396. "text/html": [
  397. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[10:48:21] </span><span style=\"color: #800000; text-decoration-color: #800000\">WARNING </span> Some weights of the model checkpoint at <a href=\"file://../../fastNLP/transformers/torch/modeling_utils.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">modeling_utils.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/transformers/torch/modeling_utils.py#1490\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1490</span></a>\n",
  398. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> bert-base-uncased were not used when initializing <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  399. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertModel: <span style=\"font-weight: bold\">[</span><span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.bias'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  400. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.transform.LayerNorm.weight'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  401. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.seq_relationship.weight'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  402. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.decoder.weight'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  403. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.transform.dense.weight'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  404. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.transform.LayerNorm.bias'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  405. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.transform.dense.bias'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  406. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.seq_relationship.bias'</span><span style=\"font-weight: bold\">]</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  407. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> - This IS expected if you are initializing <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  408. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertModel from the checkpoint of a model trained <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  409. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> on another task or with another architecture <span style=\"font-weight: bold\">(</span>e.g. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  410. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> initializing a BertForSequenceClassification model <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  411. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> from a BertForPreTraining model<span style=\"font-weight: bold\">)</span>. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  412. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> - This IS NOT expected if you are initializing <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  413. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertModel from the checkpoint of a model that you <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  414. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> expect to be exactly identical <span style=\"font-weight: bold\">(</span>initializing a <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  415. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertForSequenceClassification model from a <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  416. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertForSequenceClassification model<span style=\"font-weight: bold\">)</span>. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  417. "</pre>\n"
  418. ],
  419. "text/plain": [
  420. "\u001b[2;36m[10:48:21]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m Some weights of the model checkpoint at \u001b]8;id=387614;file://../../fastNLP/transformers/torch/modeling_utils.py\u001b\\\u001b[2mmodeling_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=648168;file://../../fastNLP/transformers/torch/modeling_utils.py#1490\u001b\\\u001b[2m1490\u001b[0m\u001b]8;;\u001b\\\n",
  421. "\u001b[2;36m \u001b[0m bert-base-uncased were not used when initializing \u001b[2m \u001b[0m\n",
  422. "\u001b[2;36m \u001b[0m BertModel: \u001b[1m[\u001b[0m\u001b[32m'cls.predictions.bias'\u001b[0m, \u001b[2m \u001b[0m\n",
  423. "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.LayerNorm.weight'\u001b[0m, \u001b[2m \u001b[0m\n",
  424. "\u001b[2;36m \u001b[0m \u001b[32m'cls.seq_relationship.weight'\u001b[0m, \u001b[2m \u001b[0m\n",
  425. "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.decoder.weight'\u001b[0m, \u001b[2m \u001b[0m\n",
  426. "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.dense.weight'\u001b[0m, \u001b[2m \u001b[0m\n",
  427. "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.LayerNorm.bias'\u001b[0m, \u001b[2m \u001b[0m\n",
  428. "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.dense.bias'\u001b[0m, \u001b[2m \u001b[0m\n",
  429. "\u001b[2;36m \u001b[0m \u001b[32m'cls.seq_relationship.bias'\u001b[0m\u001b[1m]\u001b[0m \u001b[2m \u001b[0m\n",
  430. "\u001b[2;36m \u001b[0m - This IS expected if you are initializing \u001b[2m \u001b[0m\n",
  431. "\u001b[2;36m \u001b[0m BertModel from the checkpoint of a model trained \u001b[2m \u001b[0m\n",
  432. "\u001b[2;36m \u001b[0m on another task or with another architecture \u001b[1m(\u001b[0me.g. \u001b[2m \u001b[0m\n",
  433. "\u001b[2;36m \u001b[0m initializing a BertForSequenceClassification model \u001b[2m \u001b[0m\n",
  434. "\u001b[2;36m \u001b[0m from a BertForPreTraining model\u001b[1m)\u001b[0m. \u001b[2m \u001b[0m\n",
  435. "\u001b[2;36m \u001b[0m - This IS NOT expected if you are initializing \u001b[2m \u001b[0m\n",
  436. "\u001b[2;36m \u001b[0m BertModel from the checkpoint of a model that you \u001b[2m \u001b[0m\n",
  437. "\u001b[2;36m \u001b[0m expect to be exactly identical \u001b[1m(\u001b[0minitializing a \u001b[2m \u001b[0m\n",
  438. "\u001b[2;36m \u001b[0m BertForSequenceClassification model from a \u001b[2m \u001b[0m\n",
  439. "\u001b[2;36m \u001b[0m BertForSequenceClassification model\u001b[1m)\u001b[0m. \u001b[2m \u001b[0m\n"
  440. ]
  441. },
  442. "metadata": {},
  443. "output_type": "display_data"
  444. },
  445. {
  446. "data": {
  447. "text/html": [
  448. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> All the weights of BertModel were initialized from <a href=\"file://../../fastNLP/transformers/torch/modeling_utils.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">modeling_utils.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/transformers/torch/modeling_utils.py#1507\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1507</span></a>\n",
  449. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> the model checkpoint at bert-base-uncased. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  450. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> If your task is similar to the task the model of <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  451. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> the checkpoint was trained on, you can already use <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  452. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertModel for predictions without further <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  453. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> training. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  454. "</pre>\n"
  455. ],
  456. "text/plain": [
  457. "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m All the weights of BertModel were initialized from \u001b]8;id=544687;file://../../fastNLP/transformers/torch/modeling_utils.py\u001b\\\u001b[2mmodeling_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=934505;file://../../fastNLP/transformers/torch/modeling_utils.py#1507\u001b\\\u001b[2m1507\u001b[0m\u001b]8;;\u001b\\\n",
  458. "\u001b[2;36m \u001b[0m the model checkpoint at bert-base-uncased. \u001b[2m \u001b[0m\n",
  459. "\u001b[2;36m \u001b[0m If your task is similar to the task the model of \u001b[2m \u001b[0m\n",
  460. "\u001b[2;36m \u001b[0m the checkpoint was trained on, you can already use \u001b[2m \u001b[0m\n",
  461. "\u001b[2;36m \u001b[0m BertModel for predictions without further \u001b[2m \u001b[0m\n",
  462. "\u001b[2;36m \u001b[0m training. \u001b[2m \u001b[0m\n"
  463. ]
  464. },
  465. "metadata": {},
  466. "output_type": "display_data"
  467. }
  468. ],
  469. "source": [
  470. "import torch\n",
  471. "from torch import nn\n",
  472. "from torch.nn.utils.rnn import pad_sequence\n",
  473. "from fastNLP.transformers.torch import BertModel\n",
  474. "from fastNLP import seq_len_to_mask\n",
  475. "import torch.nn.functional as F\n",
  476. "\n",
  477. "\n",
  478. "class BertNER(nn.Module):\n",
  479. " def __init__(self, model_name, num_class, tag_vocab=None):\n",
  480. " super().__init__()\n",
  481. " self.bert = BertModel.from_pretrained(model_name)\n",
  482. " self.mlp = nn.Sequential(nn.Linear(self.bert.config.hidden_size, self.bert.config.hidden_size),\n",
  483. " nn.Dropout(0.3),\n",
  484. " nn.Linear(self.bert.config.hidden_size, num_class))\n",
  485. " self.tag_vocab = tag_vocab # 这里传入 tag_vocab 的目的是为了演示 constrained_decode \n",
  486. " if tag_vocab is not None:\n",
  487. " self._init_constrained_transition()\n",
  488. " \n",
  489. " def forward(self, input_ids, input_len, first):\n",
  490. " attention_mask = seq_len_to_mask(input_len)\n",
  491. " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
  492. " last_hidden_state = outputs.last_hidden_state\n",
  493. " first = first.unsqueeze(-1).repeat(1, 1, last_hidden_state.size(-1))\n",
  494. " first_bpe_state = last_hidden_state.gather(dim=1, index=first)\n",
  495. " first_bpe_state = first_bpe_state[:, 1:-1] # 删除 cls 和 sep\n",
  496. " \n",
  497. " pred = self.mlp(first_bpe_state)\n",
  498. " return {'pred': pred}\n",
  499. " \n",
  500. " def train_step(self, input_ids, input_len, first, target):\n",
  501. " pred = self(input_ids, input_len, first)['pred']\n",
  502. " loss = F.cross_entropy(pred.transpose(1, 2), target)\n",
  503. " return {'loss': loss}\n",
  504. " \n",
  505. " def evaluate_step(self, input_ids, input_len, first):\n",
  506. " pred = self(input_ids, input_len, first)['pred'].argmax(dim=-1)\n",
  507. " return {'pred': pred}\n",
  508. " \n",
  509. " def constrained_decode(self, input_ids, input_len, first, first_len):\n",
  510. " # 这个函数在推理时,将保证解码出来的 tag 一定不与前一个 tag 矛盾【例如一定不会出现 B-person 后面接着 I-Location 的情况】\n",
  511. " # 本身这个需求可以在 Metric 中实现,这里在模型中实现的目的是为了方便演示:如何在fastNLP中使用不同的评测函数\n",
  512. " pred = self(input_ids, input_len, first)['pred']\n",
  513. " cons_pred = []\n",
  514. " for _pred, _len in zip(pred, first_len):\n",
  515. " _pred = _pred[:_len]\n",
  516. " tags = [_pred[0].argmax(dim=-1).item()] # 这里就不考虑第一个位置非法的情况了\n",
  517. " for i in range(1, _len):\n",
  518. " tags.append((_pred[i] + self.transition[tags[-1]]).argmax().item())\n",
  519. " cons_pred.append(torch.LongTensor(tags))\n",
  520. " cons_pred = pad_sequence(cons_pred, batch_first=True)\n",
  521. " return {'pred': cons_pred}\n",
  522. " \n",
  523. " def _init_constrained_transition(self):\n",
  524. " from fastNLP.modules.torch import allowed_transitions\n",
  525. " allowed_trans = allowed_transitions(self.tag_vocab)\n",
  526. " transition = torch.ones((len(self.tag_vocab), len(self.tag_vocab)))*-100000.0\n",
  527. " for s, e in allowed_trans:\n",
  528. " transition[s, e] = 0\n",
  529. " self.register_buffer('transition', transition)\n",
  530. "\n",
  531. "model = BertNER('bert-base-uncased', len(data_bundle.get_vocab('target')), data_bundle.get_vocab('target'))"
  532. ]
  533. },
  534. {
  535. "cell_type": "markdown",
  536. "id": "5aeee1e9",
  537. "metadata": {},
  538. "source": [
  539. "### Trainer 的使用\n",
  540. "fastNLP 的 Trainer 是用于对模型进行训练的部件。"
  541. ]
  542. },
  543. {
  544. "cell_type": "code",
  545. "execution_count": 8,
  546. "id": "f4250f0b",
  547. "metadata": {
  548. "scrolled": false
  549. },
  550. "outputs": [
  551. {
  552. "data": {
  553. "text/html": [
  554. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[10:49:22] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../../fastNLP/core/controllers/trainer.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/core/controllers/trainer.py#661\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">661</span></a>\n",
  555. "</pre>\n"
  556. ],
  557. "text/plain": [
  558. "\u001b[2;36m[10:49:22]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=246773;file://../../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=639347;file://../../fastNLP/core/controllers/trainer.py#661\u001b\\\u001b[2m661\u001b[0m\u001b]8;;\u001b\\\n"
  559. ]
  560. },
  561. "metadata": {},
  562. "output_type": "display_data"
  563. },
  564. {
  565. "data": {
  566. "application/vnd.jupyter.widget-view+json": {
  567. "model_id": "",
  568. "version_major": 2,
  569. "version_minor": 0
  570. },
  571. "text/plain": [
  572. "Output()"
  573. ]
  574. },
  575. "metadata": {},
  576. "output_type": "display_data"
  577. },
  578. {
  579. "data": {
  580. "text/html": [
  581. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  582. ],
  583. "text/plain": []
  584. },
  585. "metadata": {},
  586. "output_type": "display_data"
  587. },
  588. {
  589. "data": {
  590. "application/vnd.jupyter.widget-view+json": {
  591. "model_id": "",
  592. "version_major": 2,
  593. "version_minor": 0
  594. },
  595. "text/plain": [
  596. "Output()"
  597. ]
  598. },
  599. "metadata": {},
  600. "output_type": "display_data"
  601. },
  602. {
  603. "data": {
  604. "text/html": [
  605. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  606. "</pre>\n"
  607. ],
  608. "text/plain": [
  609. "\n"
  610. ]
  611. },
  612. "metadata": {},
  613. "output_type": "display_data"
  614. },
  615. {
  616. "data": {
  617. "text/html": [
  618. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #00d75f; text-decoration-color: #00d75f\">+++++++++++++++++++++++++++++ </span><span style=\"font-weight: bold\">Eval. results on Epoch:</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span><span style=\"font-weight: bold\">, Batch:</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span><span style=\"color: #00d75f; text-decoration-color: #00d75f\"> +++++++++++++++++++++++++++++</span>\n",
  619. "</pre>\n"
  620. ],
  621. "text/plain": [
  622. "\u001b[38;5;41m+++++++++++++++++++++++++++++ \u001b[0m\u001b[1mEval. results on Epoch:\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1m, Batch:\u001b[0m\u001b[1;36m0\u001b[0m\u001b[38;5;41m +++++++++++++++++++++++++++++\u001b[0m\n"
  623. ]
  624. },
  625. "metadata": {},
  626. "output_type": "display_data"
  627. },
  628. {
  629. "data": {
  630. "text/html": [
  631. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  632. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#f\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.402447</span>,\n",
  633. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#f\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.447906</span>,\n",
  634. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#f\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.365365</span>\n",
  635. "<span style=\"font-weight: bold\">}</span>\n",
  636. "</pre>\n"
  637. ],
  638. "text/plain": [
  639. "\u001b[1m{\u001b[0m\n",
  640. " \u001b[1;34m\"f#f\"\u001b[0m: \u001b[1;36m0.402447\u001b[0m,\n",
  641. " \u001b[1;34m\"pre#f\"\u001b[0m: \u001b[1;36m0.447906\u001b[0m,\n",
  642. " \u001b[1;34m\"rec#f\"\u001b[0m: \u001b[1;36m0.365365\u001b[0m\n",
  643. "\u001b[1m}\u001b[0m\n"
  644. ]
  645. },
  646. "metadata": {},
  647. "output_type": "display_data"
  648. },
  649. {
  650. "data": {
  651. "text/html": [
  652. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[10:51:15] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> The best performance for monitor f#<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">f:0</span>.<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">402447</span> was <a href=\"file://../../fastNLP/core/callbacks/progress_callback.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">progress_callback.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/core/callbacks/progress_callback.py#37\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">37</span></a>\n",
  653. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> achieved in Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Global Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">625</span>. The <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  654. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> evaluation result: <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  655. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'f#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.402447</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'pre#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.447906</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'rec#f'</span>: <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  656. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.365365</span><span style=\"font-weight: bold\">}</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  657. "</pre>\n"
  658. ],
  659. "text/plain": [
  660. "\u001b[2;36m[10:51:15]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m The best performance for monitor f#\u001b[1;92mf:0\u001b[0m.\u001b[1;36m402447\u001b[0m was \u001b]8;id=192029;file://../../fastNLP/core/callbacks/progress_callback.py\u001b\\\u001b[2mprogress_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=994998;file://../../fastNLP/core/callbacks/progress_callback.py#37\u001b\\\u001b[2m37\u001b[0m\u001b]8;;\u001b\\\n",
  661. "\u001b[2;36m \u001b[0m achieved in Epoch:\u001b[1;36m1\u001b[0m, Global Batch:\u001b[1;36m625\u001b[0m. The \u001b[2m \u001b[0m\n",
  662. "\u001b[2;36m \u001b[0m evaluation result: \u001b[2m \u001b[0m\n",
  663. "\u001b[2;36m \u001b[0m \u001b[1m{\u001b[0m\u001b[32m'f#f'\u001b[0m: \u001b[1;36m0.402447\u001b[0m, \u001b[32m'pre#f'\u001b[0m: \u001b[1;36m0.447906\u001b[0m, \u001b[32m'rec#f'\u001b[0m: \u001b[2m \u001b[0m\n",
  664. "\u001b[2;36m \u001b[0m \u001b[1;36m0.365365\u001b[0m\u001b[1m}\u001b[0m \u001b[2m \u001b[0m\n"
  665. ]
  666. },
  667. "metadata": {},
  668. "output_type": "display_data"
  669. },
  670. {
  671. "data": {
  672. "text/html": [
  673. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  674. ],
  675. "text/plain": []
  676. },
  677. "metadata": {},
  678. "output_type": "display_data"
  679. },
  680. {
  681. "data": {
  682. "text/html": [
  683. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  684. "</pre>\n"
  685. ],
  686. "text/plain": [
  687. "\n"
  688. ]
  689. },
  690. "metadata": {},
  691. "output_type": "display_data"
  692. },
  693. {
  694. "data": {
  695. "text/html": [
  696. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Loading best model from buffer with f#f: <a href=\"file://../../fastNLP/core/callbacks/load_best_model_callback.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">load_best_model_callback.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/core/callbacks/load_best_model_callback.py#115\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">115</span></a>\n",
  697. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.402447</span><span style=\"color: #808000; text-decoration-color: #808000\">...</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  698. "</pre>\n"
  699. ],
  700. "text/plain": [
  701. "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from buffer with f#f: \u001b]8;id=654516;file://../../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=96586;file://../../fastNLP/core/callbacks/load_best_model_callback.py#115\u001b\\\u001b[2m115\u001b[0m\u001b]8;;\u001b\\\n",
  702. "\u001b[2;36m \u001b[0m \u001b[1;36m0.402447\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n"
  703. ]
  704. },
  705. "metadata": {},
  706. "output_type": "display_data"
  707. }
  708. ],
  709. "source": [
  710. "from torch import optim\n",
  711. "from fastNLP import Trainer, LoadBestModelCallback, TorchWarmupCallback\n",
  712. "from fastNLP import SpanFPreRecMetric\n",
  713. "\n",
  714. "optimizer = optim.AdamW(model.parameters(), lr=2e-5)\n",
  715. "callbacks = [\n",
  716. " LoadBestModelCallback(), # 用于在训练结束之后加载性能最好的model的权重\n",
  717. " TorchWarmupCallback()\n",
  718. "] \n",
  719. "\n",
  720. "trainer = Trainer(model=model, train_dataloader=dls['train'], optimizers=optimizer, \n",
  721. " evaluate_dataloaders=dls['dev'], \n",
  722. " metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))}, \n",
  723. " n_epochs=1, callbacks=callbacks, \n",
  724. " # 在评测时将 dataloader 中的 first_len 映射为 seq_len, 因为 SpanFPreRecMetric.update 接口需要输入一个名为 seq_len 的参数\n",
  725. " evaluate_input_mapping={'first_len': 'seq_len'}, overfit_batches=0,\n",
  726. " device=0, monitor='f#f', fp16=False) # fp16 为 True 的话,将使用 float16 进行训练。\n",
  727. "trainer.run()"
  728. ]
  729. },
  730. {
  731. "cell_type": "markdown",
  732. "id": "c600a450",
  733. "metadata": {},
  734. "source": [
  735. "### Evaluator 的使用\n",
  736. "fastNLP 的 Evaluator 是用于在给定数据上评测模型性能的部件。"
  737. ]
  738. },
  739. {
  740. "cell_type": "code",
  741. "execution_count": 9,
  742. "id": "1b19f0ba",
  743. "metadata": {},
  744. "outputs": [
  745. {
  746. "data": {
  747. "application/vnd.jupyter.widget-view+json": {
  748. "model_id": "",
  749. "version_major": 2,
  750. "version_minor": 0
  751. },
  752. "text/plain": [
  753. "Output()"
  754. ]
  755. },
  756. "metadata": {},
  757. "output_type": "display_data"
  758. },
  759. {
  760. "data": {
  761. "text/html": [
  762. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  763. ],
  764. "text/plain": []
  765. },
  766. "metadata": {},
  767. "output_type": "display_data"
  768. },
  769. {
  770. "data": {
  771. "text/html": [
  772. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'f#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.390326</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'pre#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.414741</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'rec#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.368626</span><span style=\"font-weight: bold\">}</span>\n",
  773. "</pre>\n"
  774. ],
  775. "text/plain": [
  776. "\u001b[1m{\u001b[0m\u001b[32m'f#f'\u001b[0m: \u001b[1;36m0.390326\u001b[0m, \u001b[32m'pre#f'\u001b[0m: \u001b[1;36m0.414741\u001b[0m, \u001b[32m'rec#f'\u001b[0m: \u001b[1;36m0.368626\u001b[0m\u001b[1m}\u001b[0m\n"
  777. ]
  778. },
  779. "metadata": {},
  780. "output_type": "display_data"
  781. },
  782. {
  783. "data": {
  784. "text/plain": [
  785. "{'f#f': 0.390326, 'pre#f': 0.414741, 'rec#f': 0.368626}"
  786. ]
  787. },
  788. "execution_count": 9,
  789. "metadata": {},
  790. "output_type": "execute_result"
  791. }
  792. ],
  793. "source": [
  794. "from fastNLP import Evaluator\n",
  795. "from fastNLP import SpanFPreRecMetric\n",
  796. "\n",
  797. "evaluator = Evaluator(model=model, dataloaders=dls['test'], \n",
  798. " metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))}, \n",
  799. " evaluate_input_mapping={'first_len': 'seq_len'}, \n",
  800. " device=0)\n",
  801. "evaluator.run()"
  802. ]
  803. },
  804. {
  805. "cell_type": "code",
  806. "execution_count": null,
  807. "id": "52f87770",
  808. "metadata": {},
  809. "outputs": [
  810. {
  811. "data": {
  812. "application/vnd.jupyter.widget-view+json": {
  813. "model_id": "f723fe399df34917875ad74c2542508c",
  814. "version_major": 2,
  815. "version_minor": 0
  816. },
  817. "text/plain": [
  818. "Output()"
  819. ]
  820. },
  821. "metadata": {},
  822. "output_type": "display_data"
  823. }
  824. ],
  825. "source": [
  826. "# 如果想评测一下使用 constrained decoding的性能,则可以通过传入 evaluate_fn 指定使用的函数\n",
  827. "def input_mapping(x):\n",
  828. " x['seq_len'] = x['first_len']\n",
  829. " return x\n",
  830. "evaluator = Evaluator(model=model, dataloaders=dls['test'], device=0,\n",
  831. " metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))},\n",
  832. " evaluate_fn='constrained_decode',\n",
  833. " # 如果用 dict 直接将 first_len 重命名为 seq_len, 将导致 constrained_decode 的输入缺少 first_len 参数,因此\n",
  834. " # 这里传入函数 input_mapping:它在新增 seq_len 的同时保留 first_len,使得这个参数不会消失。\n",
  835. " evaluate_input_mapping=input_mapping)\n",
  836. "evaluator.run()"
  837. ]
  838. },
  839. {
  840. "cell_type": "code",
  841. "execution_count": null,
  842. "id": "419e718b",
  843. "metadata": {},
  844. "outputs": [],
  845. "source": []
  846. }
  847. ],
  848. "metadata": {
  849. "kernelspec": {
  850. "display_name": "Python 3 (ipykernel)",
  851. "language": "python",
  852. "name": "python3"
  853. },
  854. "language_info": {
  855. "codemirror_mode": {
  856. "name": "ipython",
  857. "version": 3
  858. },
  859. "file_extension": ".py",
  860. "mimetype": "text/x-python",
  861. "name": "python",
  862. "nbconvert_exporter": "python",
  863. "pygments_lexer": "ipython3",
  864. "version": "3.7.13"
  865. }
  866. },
  867. "nbformat": 4,
  868. "nbformat_minor": 5
  869. }