fastnlp_tutorial_0.ipynb 26 kB

{
"cells": [
{
"cell_type": "markdown",
"id": "aec0fde7",
"metadata": {},
"source": [
"# T0. Basic usage of trainer and evaluator\n",
"\n",
"&emsp; 1 &emsp; The basic relationship between trainer and evaluator\n",
"\n",
"&emsp; &emsp; 1.1 &emsp; Initializing trainer and evaluator\n",
"\n",
"&emsp; &emsp; 1.2 &emsp; What a driver is and how it must be used\n",
"\n",
"&emsp; &emsp; 1.3 &emsp; Initializing an evaluator inside the trainer\n",
"\n",
"&emsp; 2 &emsp; Building an argmax model with fastNLP 0.8\n",
"\n",
"&emsp; &emsp; 2.1 &emsp; train_step and evaluate_step\n",
"\n",
"&emsp; &emsp; 2.2 &emsp; Parameter matching in trainer and evaluator\n",
"\n",
"&emsp; &emsp; 2.3 &emsp; A worked example: the argmax model\n",
"\n",
"&emsp; 3 &emsp; Training the argmax model with fastNLP 0.8\n",
"\n",
"&emsp; &emsp; 3.1 &emsp; An evaluator initialized outside the trainer\n",
"\n",
"&emsp; &emsp; 3.2 &emsp; An evaluator initialized inside the trainer"
]
},
{
"cell_type": "markdown",
"id": "09ea669a",
"metadata": {},
"source": [
"## 1. The basic relationship between trainer and evaluator\n",
"\n",
"### 1.1 Initializing trainer and evaluator\n",
"\n",
"In `fastNLP 0.8`, **the `Trainer` module handles training and the `Evaluator` module handles evaluation**\n",
"\n",
"&emsp; They correspond to the `Trainer` and `Tester` modules of earlier `fastNLP` versions and are defined as shown below\n",
"\n",
"In `fastNLP 0.8`, when a single `python` script first trains with `Trainer` and then evaluates with `Evaluator`,\n",
"\n",
"&emsp; the key question is **how to set the `driver` of both correctly**. This raises another question: what is a `driver`?\n",
"\n",
"\n",
"```python\n",
"trainer = Trainer(\n",
"    model=model,                        # model built on torch.nn.Module\n",
"    train_dataloader=train_dataloader,  # data loader built on torch.utils.data.DataLoader\n",
"    optimizers=optimizer,               # optimizer from torch.optim.*\n",
"    ...\n",
"    driver=\"torch\",                     # train with the pytorch backend\n",
"    device='cuda',                      # train on the GPU (GPU:0)\n",
"    ...\n",
")\n",
"...\n",
"evaluator = Evaluator(\n",
"    model=model,                        # model built on torch.nn.Module\n",
"    dataloaders=evaluate_dataloader,    # data loader built on torch.utils.data.DataLoader\n",
"    metrics={'acc': Accuracy()},        # metric: fastNLP.core.metrics.Accuracy\n",
"    ...\n",
"    driver=trainer.driver,              # keep the same driver as the trainer\n",
"    device=None,\n",
"    ...\n",
")\n",
"```"
]
},
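{
"cell_type": "markdown",
"id": "3c11fe1a",
"metadata": {},
"source": [
"### 1.2 What a driver is and how it must be used\n",
"\n",
"In `fastNLP 0.8`, the **`driver`** denotes **the component that ultimately executes each concrete step of training**\n",
"\n",
"&emsp; for example running the forward and backward passes of the network, updating the parameters, and moving data between devices\n",
"\n",
"In `fastNLP 0.8`, **both `Trainer` and `Evaluator` rely on a concrete `driver` to carry out their workflow**\n",
"\n",
"&emsp; For how the `driver` relates to `Trainer` and `Evaluator`, see the framework design of `fastNLP 0.8`\n",
"\n",
"Note: within one script, `Trainer` and `Evaluator` should use the same `driver`\n",
"\n",
"&emsp; One rule must not be broken: **do not use a single-GPU `driver` before a multi-GPU `driver`**(???), since doing so may cause many unexpected errors"
]
},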
{
"cell_type": "markdown",
"id": "2cac4a1a",
"metadata": {},
"source": [
"### 1.3 Initializing an evaluator inside the trainer\n",
"\n",
"In `fastNLP 0.8`, if **the arguments `evaluate_dataloaders` and `metrics` are passed when initializing `Trainer`**,\n",
"\n",
"&emsp; then `Trainer` also initializes a separate `Evaluator` internally to evaluate on the validation set during training\n",
"\n",
"```python\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    train_dataloader=train_dataloader,\n",
"    optimizers=optimizer,\n",
"    ...\n",
"    driver=\"torch\",\n",
"    device='cuda',\n",
"    ...\n",
"    evaluate_dataloaders=evaluate_dataloader,  # pass the evaluate_dataloaders argument\n",
"    metrics={'acc': Accuracy()},               # pass the metrics argument\n",
"    ...\n",
")\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "0c9c7dda",
"metadata": {},
"source": [
"## 2. Building an argmax model with fastNLP 0.8"
]
},
{
"cell_type": "markdown",
"id": "524ac200",
"metadata": {},
"source": [
"### 2.1 train_step and evaluate_step\n",
"\n",
"In `fastNLP 0.8`, the model to be trained is built with `torch.nn.Module`. Besides the `forward` method that `pytorch` requires,\n",
"\n",
"&emsp; the model also needs two more methods: **`train_step`** and **`evaluate_step`**\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"class Model(torch.nn.Module):\n",
"    def __init__(self):\n",
"        super(Model, self).__init__()\n",
"        self.loss_fn = torch.nn.CrossEntropyLoss()\n",
"        pass  # placeholder: define the layers here\n",
"\n",
"    def forward(self, x):\n",
"        pass  # placeholder: compute and return the prediction\n",
"\n",
"    def train_step(self, x, y):\n",
"        pred = self(x)\n",
"        return {\"loss\": self.loss_fn(pred, y)}\n",
"\n",
"    def evaluate_step(self, x, y):\n",
"        pred = self(x)\n",
"        pred = torch.max(pred, dim=-1)[1]  # index of the largest score\n",
"        return {\"pred\": pred, \"target\": y}\n",
"```\n",
"***\n",
"In `fastNLP 0.8`, **the method `train_step` is the default value of the `Trainer` argument `train_fn`**\n",
"\n",
"&emsp; During training, **`Trainer` obtains the loss of the current batch from the model method named by `train_fn`**\n",
"\n",
"&emsp; so `Trainer` first checks whether the model defines a `train_step` method\n",
"\n",
"&emsp; &emsp; If it does not, `Trainer` falls back to the model's `forward` function for the forward pass\n",
"\n",
"Note: in `fastNLP 0.8`, **`Trainer` requires `train_step` to return a dictionary** **of the form `{\"loss\": loss}`**\n",
"\n",
"&emsp; In addition, the `Trainer` argument `output_mapping` allows further customization; for details see this note(???)\n",
"\n",
"Likewise, in `fastNLP 0.8`, **the method `evaluate_step` is the default value of the `Evaluator` argument `evaluate_fn`**\n",
"\n",
"&emsp; During evaluation, **`Evaluator` obtains the evaluation result of the current batch from the model method named by `evaluate_fn`**\n",
"\n",
"&emsp; From the user's point of view, `evaluate_step` returns a dictionary whose contents match the `metrics` passed to `Evaluator`\n",
"\n",
"&emsp; From the module's point of view, the keys of that dictionary match the signature of the `update` function of the `metric`; this mechanism is called **parameter matching**\n",
"\n",
"<img src=\"./figures/T0-fig-trainer-and-evaluator.png\" width=\"80%\" height=\"80%\" align=\"center\"></img>"
]
},
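{
"cell_type": "markdown",
"id": "7a1e0c31",
"metadata": {},
"source": [
"The lookup order described in 2.1 can be sketched in plain Python. This is only an illustration under stated assumptions, not the code that `fastNLP` runs: `resolve_step_fn`, `TinyModel` and `TinyModelWithStep` are hypothetical names, and the real `Trainer` may resolve `train_fn` differently.\n",
"\n",
"```python\n",
"def resolve_step_fn(model, fn_name='train_step'):\n",
"    # Hypothetical helper: prefer the named step method, else fall back to forward.\n",
"    step_fn = getattr(model, fn_name, None)\n",
"    if callable(step_fn):\n",
"        return step_fn\n",
"    return model.forward\n",
"\n",
"\n",
"class TinyModel:\n",
"    # Stand-in for a torch.nn.Module that defines forward only.\n",
"    def forward(self, x):\n",
"        return {'loss': float(x)}\n",
"\n",
"\n",
"class TinyModelWithStep(TinyModel):\n",
"    # Same stand-in, but with an explicit train_step.\n",
"    def train_step(self, x):\n",
"        return {'loss': float(x) * 2}\n",
"\n",
"\n",
"out_a = resolve_step_fn(TinyModel())(3)          # no train_step, so forward is used\n",
"out_b = resolve_step_fn(TinyModelWithStep())(3)  # train_step is found and used\n",
"print(out_a, out_b)\n",
"```\n",
"\n",
"Both calls return a dictionary with a `loss` entry, which is the shape `Trainer` expects from the step function."
]
},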
{
"cell_type": "markdown",
"id": "fb3272eb",
"metadata": {},
"source": [
"### 2.2 Parameter matching in trainer and evaluator\n",
"\n",
"In `fastNLP 0.8`, parameter matching happens in two places:\n",
"\n",
"&emsp; first, **in the model's forward pass**, **when a `batch` from the `dataloader` is passed to `train_step` or `evaluate_step`**\n",
"\n",
"&emsp; second, **during evaluation**, **when the results for a `batch` from `evaluate_dataloader` are passed to the `update` function of the `metric`**\n",
"\n",
"For the former, when the `Trainer` and `Evaluator` argument `model_wo_auto_param_call` is set to `False`,\n",
"\n",
"&emsp; &emsp; **`fastNLP 0.8` requires every `batch` produced by the `dataloader`** **to have the form `{\"x\": x, \"y\": y}`**\n",
"\n",
"&emsp; At the same time, `fastNLP 0.8` inspects the signatures of the model's `train_step` and `evaluate_step` methods and passes each parameter the matching value\n",
"\n",
"&emsp; &emsp; **The dictionary form** **is defined in the `__getitem__` method of the `Dataset`**, as in `ArgMaxDataset` below\n",
"\n",
"&emsp; When `model_wo_auto_param_call` is set to `True` in `Trainer` and `Evaluator`,\n",
"\n",
"&emsp; &emsp; `fastNLP 0.8` passes the `batch` directly to the model's `train_step`, `evaluate_step` or `forward` function\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"class Dataset(torch.utils.data.Dataset):\n",
"    def __init__(self, x, y):\n",
"        self.x = x\n",
"        self.y = y\n",
"\n",
"    def __len__(self):\n",
"        return len(self.x)\n",
"\n",
"    def __getitem__(self, item):\n",
"        # the keys must match the parameter names of train_step / evaluate_step\n",
"        return {\"x\": self.x[item], \"y\": self.y[item]}\n",
"```\n",
"***\n",
"For the latter, note first that in `Trainer` and `Evaluator` a metric is computed in two steps, `update` and `get_metric`\n",
"\n",
"&emsp; &emsp; **the `update` function** **takes the predictions for one `batch`** and accumulates the evaluation statistics\n",
"\n",
"&emsp; &emsp; **the `get_metric` function** **aggregates the statistics accumulated by `update`** to compute the final result\n",
"\n",
"&emsp; For `Accuracy`, for example, `update` adds one `batch` to the number of correct predictions `right_num` and the total number of samples `total_num`\n",
"\n",
"&emsp; &emsp; and `get_metric` returns the value over all `batch`es, `right_num / total_num`\n",
"\n",
"&emsp; On this basis, **`fastNLP 0.8` requires that what is passed to each `metric` for every `batch` from `evaluate_dataloader`**\n",
"\n",
"&emsp; &emsp; **has the form `{\"pred\": y_pred, \"target\": y_true}`**, matching the signature of its `update` function"
]
},
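{
"cell_type": "markdown",
"id": "7a1e0c32",
"metadata": {},
"source": [
"Signature-based parameter matching, as described in 2.2, can be sketched with the standard `inspect` module. The helper `call_with_matching` and the toy `train_step` are hypothetical, and the sketch assumes the function takes only named parameters; how `fastNLP` implements the matching internally is not shown here.\n",
"\n",
"```python\n",
"import inspect\n",
"\n",
"\n",
"def call_with_matching(fn, batch):\n",
"    # Hypothetical helper: pass only the batch entries named in fn's signature.\n",
"    wanted = inspect.signature(fn).parameters\n",
"    missing = [name for name, p in wanted.items()\n",
"               if name not in batch and p.default is inspect.Parameter.empty]\n",
"    if missing:\n",
"        raise KeyError('batch lacks keys: ' + ', '.join(missing))\n",
"    return fn(**{name: batch[name] for name in wanted if name in batch})\n",
"\n",
"\n",
"def train_step(x, y):\n",
"    # Toy stand-in for a model method; returns a dictionary with a loss entry.\n",
"    return {'loss': abs(x - y)}\n",
"\n",
"\n",
"batch = {'x': 5, 'y': 3, 'extra': 'ignored'}\n",
"result = call_with_matching(train_step, batch)\n",
"print(result)\n",
"```\n",
"\n",
"Keys of the `batch` that the function does not name are dropped, and a missing required key raises an error early instead of failing inside the model."
]
},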
{
"cell_type": "markdown",
"id": "f62b7bb1",
"metadata": {},
"source": [
"### 2.3 A worked example: the argmax model\n",
"\n",
"The following trains an `argmax` model as a brief introduction to how the `Trainer` module is used\n",
"\n",
"&emsp; First, define the `argmax` model with `torch.nn.Module`; given a vector of fixed dimension, it outputs the index of the largest value"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5314482b",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class ArgMaxModel(nn.Module):\n",
"    def __init__(self, num_labels, feature_dimension):\n",
"        super(ArgMaxModel, self).__init__()\n",
"        self.num_labels = num_labels\n",
"\n",
"        self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)\n",
"        self.ac1 = nn.ReLU()\n",
"        self.linear2 = nn.Linear(in_features=10, out_features=10)\n",
"        self.ac2 = nn.ReLU()\n",
"        self.output = nn.Linear(in_features=10, out_features=num_labels)\n",
"        self.loss_fn = nn.CrossEntropyLoss()\n",
"\n",
"    def forward(self, x):\n",
"        pred = self.ac1(self.linear1(x))\n",
"        pred = self.ac2(self.linear2(pred))\n",
"        pred = self.output(pred)\n",
"        return pred\n",
"\n",
"    def train_step(self, x, y):\n",
"        pred = self(x)\n",
"        return {\"loss\": self.loss_fn(pred, y)}\n",
"\n",
"    def evaluate_step(self, x, y):\n",
"        pred = self(x)\n",
"        pred = torch.max(pred, dim=-1)[1]\n",
"        return {\"pred\": pred, \"target\": y}"
]
},
{
"cell_type": "markdown",
"id": "71f3fa6b",
"metadata": {},
"source": [
"&emsp; Next, define the `ArgMaxDataset` dataset with `torch.utils.data.Dataset`\n",
"\n",
"&emsp; &emsp; The dataset takes three parameters: the dimension `feature_dimension`, the number of samples `data_num`, and the random seed `seed`\n",
"\n",
"&emsp; &emsp; On initialization, the dataset generates vectors of the given dimension and labels each one with the index of its largest value as the prediction target"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fe612e61",
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [],
"source": [
"from torch.utils.data import Dataset\n",
"\n",
"class ArgMaxDataset(Dataset):\n",
"    def __init__(self, feature_dimension, data_num=1000, seed=0):\n",
"        self.num_labels = feature_dimension\n",
"        self.feature_dimension = feature_dimension\n",
"        self.data_num = data_num\n",
"        self.seed = seed\n",
"\n",
"        g = torch.Generator()\n",
"        # note: the generator is seeded with a fixed 1000; the seed argument is stored but not used here\n",
"        g.manual_seed(1000)\n",
"        self.x = torch.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()\n",
"        self.y = torch.max(self.x, dim=-1)[1]\n",
"\n",
"    def __len__(self):\n",
"        return self.data_num\n",
"\n",
"    def __getitem__(self, item):\n",
"        return {\"x\": self.x[item], \"y\": self.y[item]}"
]
},
{
"cell_type": "markdown",
"id": "2cb96332",
"metadata": {},
"source": [
"&emsp; Then, create a model instance from the `ArgMaxModel` class, keeping the input dimension `feature_dimension` equal to the number of output labels `num_labels`\n",
"\n",
"&emsp; &emsp; Next, create two dataset instances from the `ArgMaxDataset` class, one for training with 1000 samples and one for evaluation with 100 samples"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "76172ef8",
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [],
"source": [
"model = ArgMaxModel(num_labels=10, feature_dimension=10)\n",
"\n",
"train_dataset = ArgMaxDataset(feature_dimension=10, data_num=1000)\n",
"evaluate_dataset = ArgMaxDataset(feature_dimension=10, data_num=100)"
]
},
{
"cell_type": "markdown",
"id": "4e7d25ee",
"metadata": {},
"source": [
"&emsp; In addition, create two data loaders with `torch.utils.data.DataLoader`, both with batch size 8, one for training and one for evaluation"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "363b5b09",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)\n",
"evaluate_dataloader = DataLoader(evaluate_dataset, batch_size=8)"
]
},
{
"cell_type": "markdown",
"id": "c8d4443f",
"metadata": {},
"source": [
"&emsp; Finally, create an optimizer with `torch.optim.SGD`, which implements stochastic gradient descent"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dc28a2d9",
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [],
"source": [
"from torch.optim import SGD\n",
"\n",
"optimizer = SGD(model.parameters(), lr=0.001)"
]
},
{
"cell_type": "markdown",
"id": "eb8ca6cf",
"metadata": {},
"source": [
"## 3. Training the argmax model with fastNLP 0.8\n",
"\n",
"### 3.1 An evaluator initialized outside the trainer"
]
},
{
"cell_type": "markdown",
"id": "55145553",
"metadata": {},
"source": [
"Import the `Trainer` class from the `fastNLP` library and create a `trainer` instance to train the model\n",
"\n",
"&emsp; It needs the predefined model `model`, the matching data loader `train_dataloader`, and the optimizer `optimizer`\n",
"\n",
"&emsp; `progress_bar` sets the progress bar style; the default is `\"auto\"`, and `\"rich\"`, `\"raw\"` and `None` are also available\n",
"\n",
"&emsp; &emsp; Note that with `\"auto\"` and `\"rich\"` the progress bar is no longer shown once training ends(???)\n",
"\n",
"&emsp; `n_epochs` sets the number of training epochs, 20 by default; all attributes and methods of `Trainer` can be listed with `dir(trainer)`"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b51b7a2d",
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from fastNLP import Trainer\n",
"\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    driver=\"torch\",\n",
"    device='cuda',\n",
"    train_dataloader=train_dataloader,\n",
"    optimizers=optimizer,\n",
"    n_epochs=10,          # number of training epochs\n",
"    progress_bar=\"auto\"   # progress bar style\n",
")"
]
},
{
"cell_type": "markdown",
"id": "6e202d6e",
"metadata": {},
"source": [
"Call the `run` function of the `Trainer` class to train\n",
"\n",
"&emsp; The argument `num_train_batch_per_epoch` sets how many `batch`es each `epoch` runs before stopping; by default all of them\n",
"\n",
"&emsp; The full parameter list of `run` can be inspected with `inspect.getfullargspec(trainer.run)`"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ba047ead",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer.run()"
]
},
{
"cell_type": "markdown",
"id": "c16c5fa4",
"metadata": {},
"source": [
"Import the `Evaluator` class from the `fastNLP` library and create an `evaluator` instance to evaluate the model\n",
"\n",
"&emsp; It needs the predefined model `model` and the matching data loader `evaluate_dataloader`\n",
"\n",
"&emsp; Pay attention to the evaluation method `metrics`, which must be a dictionary of the form `{'acc': fastNLP.core.metrics.Accuracy()}`\n",
"\n",
"&emsp; Similarly, `progress_bar` sets the progress bar style, `\"auto\"` by default"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1c6b6b36",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"from fastNLP import Evaluator\n",
"from fastNLP.core.metrics import Accuracy\n",
"\n",
"evaluator = Evaluator(\n",
"    model=model,\n",
"    driver=trainer.driver,        # reuse the driver that the trainer has already started\n",
"    device=None,\n",
"    dataloaders=evaluate_dataloader,\n",
"    metrics={'acc': Accuracy()}   # the dictionary must have exactly this form\n",
")"
]
},
{
"cell_type": "markdown",
"id": "8157bb9b",
"metadata": {},
"source": [
"Call the `run` function of the `Evaluator` class to evaluate\n",
"\n",
"&emsp; The argument `num_eval_batch_per_dl` sets how many `batch`es are run for each `evaluate_dataloader` before stopping; by default all of them\n",
"\n",
"&emsp; The result is a dictionary of the form `{'acc#acc': acc}`; the progress bar shown in the meantime is discarded after the run finishes(???)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f7cb0165",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.29</span><span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.29\u001b[0m\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'acc#acc': 0.29}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"evaluator.run()"
]
},
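{
"cell_type": "markdown",
"id": "7a1e0c33",
"metadata": {},
"source": [
"The two-step metric computation behind an accuracy value like the one above, `update` per batch followed by `get_metric`, can be sketched as follows. `ToyAccuracy` is a hypothetical stand-in written for this illustration; it is not `fastNLP.core.metrics.Accuracy` and works on plain Python lists rather than tensors.\n",
"\n",
"```python\n",
"class ToyAccuracy:\n",
"    # Illustrative stand-in, not fastNLP.core.metrics.Accuracy.\n",
"    def __init__(self):\n",
"        self.right_num = 0   # correct predictions seen so far\n",
"        self.total_num = 0   # all samples seen so far\n",
"\n",
"    def update(self, pred, target):\n",
"        # Accumulate the counts for one batch.\n",
"        self.right_num += sum(int(p == t) for p, t in zip(pred, target))\n",
"        self.total_num += len(target)\n",
"\n",
"    def get_metric(self):\n",
"        # Combine the accumulated counts into the final value.\n",
"        return {'acc': self.right_num / self.total_num}\n",
"\n",
"\n",
"metric = ToyAccuracy()\n",
"metric.update(pred=[1, 0, 2], target=[1, 1, 2])  # 2 of 3 correct\n",
"metric.update(pred=[0, 0], target=[0, 1])        # 1 of 2 correct\n",
"final = metric.get_metric()\n",
"print(final)\n",
"```\n",
"\n",
"The parameter names `pred` and `target` of `update` are the keys that `evaluate_step` returns, which is what parameter matching relies on."
]
},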
{
"cell_type": "markdown",
"id": "dd9f68fa",
"metadata": {},
"source": [
"### 3.2 An evaluator initialized inside the trainer\n",
"\n",
"Passing `evaluate_dataloaders` and `metrics` when creating the `trainer` instance enables evaluation during training\n",
"\n",
"&emsp; `progress_bar` sets the style of both the training and the evaluation progress bars; the bars are no longer shown once training ends(???)\n",
"\n",
"&emsp; **`evaluate_every` sets the evaluation frequency**; it can be a negative number, a positive number or a function:\n",
"\n",
"&emsp; &emsp; **a negative value** **means evaluating once every that many `epoch`s**; **a positive value** **means evaluating once every that many `batch`es**"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "183c7d19",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"trainer = Trainer(\n",
"    model=model,\n",
"    driver=trainer.driver,  # same script, so the existing driver must be reused here as well\n",
"    train_dataloader=train_dataloader,\n",
"    evaluate_dataloaders=evaluate_dataloader,\n",
"    metrics={'acc': Accuracy()},\n",
"    optimizers=optimizer,\n",
"    n_epochs=10,\n",
"    evaluate_every=-1,      # evaluate at the end of every epoch\n",
")"
]
},
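{
"cell_type": "markdown",
"id": "7a1e0c34",
"metadata": {},
"source": [
"The sign convention of `evaluate_every` described in 3.2 can be sketched as a small decision function. `should_evaluate` is a hypothetical helper, not part of `fastNLP`; it assumes that epochs and batches are counted from 1 and that the batch counter runs across epochs, and it covers only the integer forms, not the function form.\n",
"\n",
"```python\n",
"def should_evaluate(evaluate_every, epoch_idx, batch_idx, at_epoch_end):\n",
"    # Hypothetical helper; epoch_idx and batch_idx count from 1.\n",
"    if evaluate_every < 0:\n",
"        # negative: once every abs(evaluate_every) epochs, at the end of the epoch\n",
"        return at_epoch_end and epoch_idx % (-evaluate_every) == 0\n",
"    if evaluate_every > 0:\n",
"        # positive: once every evaluate_every batches\n",
"        return batch_idx % evaluate_every == 0\n",
"    return False\n",
"\n",
"\n",
"# evaluate_every=-1: at the end of every epoch\n",
"epoch_hits = [e for e in range(1, 5) if should_evaluate(-1, e, 0, True)]\n",
"# evaluate_every=-2: at the end of every second epoch\n",
"epoch_hits_2 = [e for e in range(1, 5) if should_evaluate(-2, e, 0, True)]\n",
"# evaluate_every=3: after every third batch\n",
"batch_hits = [b for b in range(1, 10) if should_evaluate(3, 1, b, False)]\n",
"print(epoch_hits, epoch_hits_2, batch_hits)\n",
"```\n",
"\n",
"With `evaluate_every=-1`, as in the cell above, the check fires once per epoch, so ten epochs give ten evaluations."
]
},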
{
"cell_type": "markdown",
"id": "714cc404",
"metadata": {},
"source": [
"Call the `run` function of the `Trainer` class to train\n",
"\n",
"&emsp; The argument `num_eval_sanity_batch` additionally sets how many evaluation `batch`es are run as a check before each training run; the default is 2"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2e4daa2c",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer.run()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}