
fastnlp_tutorial_0.ipynb

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "id": "aec0fde7",
  6. "metadata": {},
  7. "source": [
  8. "# T0. trainer 和 evaluator 的基本使用\n",
  9. "\n",
  10. "  1   trainer 和 evaluator 的基本关系\n",
  11. " \n",
  12. "    1.1   trainer 和 evaluater 的初始化\n",
  13. "\n",
  14. "    1.2   driver 的含义与使用要求\n",
  15. "\n",
  16. "    1.3   trainer 内部初始化 evaluater\n",
  17. "\n",
  18. "  2   使用 fastNLP 搭建 argmax 模型\n",
  19. "\n",
  20. "    2.1   trainer_step 和 evaluator_step\n",
  21. "\n",
  22. "    2.2   trainer 和 evaluator 的参数匹配\n",
  23. "\n",
  24. "    2.3   示例:argmax 模型的搭建\n",
  25. "\n",
  26. "  3   使用 fastNLP 训练 argmax 模型\n",
  27. " \n",
  28. "    3.1   trainer 外部初始化的 evaluator\n",
  29. "\n",
  30. "    3.2   trainer 内部初始化的 evaluator "
  31. ]
  32. },
  33. {
  34. "cell_type": "markdown",
  35. "id": "09ea669a",
  36. "metadata": {},
  37. "source": [
  38. "## 1. trainer 和 evaluator 的基本关系\n",
  39. "\n",
  40. "### 1.1 trainer 和 evaluator 的初始化\n",
  41. "\n",
  42. "在`fastNLP 0.8`中,**`Trainer`模块和`Evaluator`模块分别表示“训练器”和“评测器”**\n",
  43. "\n",
  44. "  对应于之前的`fastNLP`版本中的`Trainer`模块和`Tester`模块,其定义方法如下所示\n",
  45. "\n",
  46. "在`fastNLP 0.8`中,需要注意,在同个`python`脚本中先使用`Trainer`训练,然后使用`Evaluator`评测\n",
  47. "\n",
  48. "  非常关键的问题在于**如何正确设置二者的`driver`**。这就引入了另一个问题:什么是 `driver`?\n",
  49. "\n",
  50. "\n",
  51. "```python\n",
  52. "trainer = Trainer(\n",
  53. " model=model, # 模型基于 torch.nn.Module\n",
  54. " train_dataloader=train_dataloader, # 加载模块基于 torch.utils.data.DataLoader \n",
  55. " optimizers=optimizer, # 优化模块基于 torch.optim.*\n",
  56. " ...\n",
  57. " driver=\"torch\", # 使用 pytorch 模块进行训练 \n",
  58. " device='cuda', # 使用 GPU:0 显卡执行训练\n",
  59. " ...\n",
  60. " )\n",
  61. "...\n",
  62. "evaluator = Evaluator(\n",
  63. " model=model, # 模型基于 torch.nn.Module\n",
  64. " dataloaders=evaluate_dataloader, # 加载模块基于 torch.utils.data.DataLoader\n",
  65. " metrics={'acc': Accuracy()}, # 测评方法使用 fastNLP.core.metrics.Accuracy \n",
  66. " ...\n",
  67. " driver=trainer.driver, # 保持同 trainer 的 driver 一致\n",
  68. " device=None,\n",
  69. " ...\n",
  70. " )\n",
  71. "```"
  72. ]
  73. },
  74. {
  75. "cell_type": "markdown",
  76. "id": "3c11fe1a",
  77. "metadata": {},
  78. "source": [
  79. "### 1.2 driver 的含义与使用要求\n",
  80. "\n",
  81. "在`fastNLP 0.8`中,**`driver`**这一概念被用来表示**控制具体训练的各个步骤的最终执行部分**\n",
  82. "\n",
  83. "  例如神经网络前向、后向传播的具体执行、网络参数的优化和数据在设备间的迁移等\n",
  84. "\n",
  85. "在`fastNLP 0.8`中,**`Trainer`和`Evaluator`都依赖于具体的`driver`来完成整体的工作流程**\n",
  86. "\n",
  87. "  具体`driver`与`Trainer`以及`Evaluator`之间的关系之后`tutorial 4`中的详细介绍\n",
  88. "\n",
  89. "注:这里给出一条建议:**在同一脚本中**,**所有的`Trainer`和`Evaluator`使用的`driver`应当保持一致**\n",
  90. "\n",
  91. "  尽量不出现,之前使用单卡的`driver`,后面又使用多卡的`driver`,这是因为,当脚本执行至\n",
  92. "\n",
  93. "  多卡`driver`处时,会重启一个进程执行之前所有内容,如此一来可能会造成一些意想不到的麻烦"
  94. ]
  95. },
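{
"cell_type": "markdown",
"id": "driver-reuse-sketch",
"metadata": {},
"source": [
"&emsp; A minimal sketch of that recommendation, reusing the names from the snippet in 1.1 (illustrative only, not an additional fastNLP API):\n",
"\n",
"```python\n",
"trainer = Trainer(model=model, train_dataloader=train_dataloader,\n",
"                  optimizers=optimizer, driver='torch', device='cuda')\n",
"trainer.run()\n",
"\n",
"# reuse the driver the trainer has already started instead of configuring a new one\n",
"evaluator = Evaluator(model=model, dataloaders=evaluate_dataloader,\n",
"                      metrics={'acc': Accuracy()}, driver=trainer.driver, device=None)\n",
"evaluator.run()\n",
"```"
]
},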
  96. {
  97. "cell_type": "markdown",
  98. "id": "2cac4a1a",
  99. "metadata": {},
  100. "source": [
  101. "### 1.3 Trainer 内部初始化 Evaluator\n",
  102. "\n",
  103. "在`fastNLP 0.8`中,如果在**初始化`Trainer`时**,**传入参数`evaluator_dataloaders`和`metrics`**\n",
  104. "\n",
  105. "  则在`Trainer`内部,也会初始化单独的`Evaluator`来帮助训练过程中对验证集的评测\n",
  106. "\n",
  107. "```python\n",
  108. "trainer = Trainer(\n",
  109. " model=model,\n",
  110. " train_dataloader=train_dataloader,\n",
  111. " optimizers=optimizer,\n",
  112. " ...\n",
  113. " driver=\"torch\",\n",
  114. " device='cuda',\n",
  115. " ...\n",
  116. " evaluate_dataloaders=evaluate_dataloader, # 传入参数 evaluator_dataloaders\n",
  117. " metrics={'acc': Accuracy()}, # 传入参数 metrics\n",
  118. " ...\n",
  119. " )\n",
  120. "```"
  121. ]
  122. },
  123. {
  124. "cell_type": "markdown",
  125. "id": "0c9c7dda",
  126. "metadata": {},
  127. "source": [
  128. "## 2. argmax 模型的搭建实例"
  129. ]
  130. },
  131. {
  132. "cell_type": "markdown",
  133. "id": "524ac200",
  134. "metadata": {},
  135. "source": [
  136. "### 2.1 trainer_step 和 evaluator_step\n",
  137. "\n",
  138. "在`fastNLP 0.8`中,使用`pytorch.nn.Module`搭建需要训练的模型,在搭建模型过程中,除了\n",
  139. "\n",
  140. "  添加`pytorch`要求的`forward`方法外,还需要添加 **`train_step`** 和 **`evaluate_step`** 这两个方法\n",
  141. "\n",
  142. "```python\n",
  143. "class Model(torch.nn.Module):\n",
  144. " def __init__(self):\n",
  145. " super(Model, self).__init__()\n",
  146. " self.loss_fn = torch.nn.CrossEntropyLoss()\n",
  147. " pass\n",
  148. "\n",
  149. " def forward(self, x):\n",
  150. " pass\n",
  151. "\n",
  152. " def train_step(self, x, y):\n",
  153. " pred = self(x)\n",
  154. " return {\"loss\": self.loss_fn(pred, y)}\n",
  155. "\n",
  156. " def evaluate_step(self, x, y):\n",
  157. " pred = self(x)\n",
  158. " pred = torch.max(pred, dim=-1)[1]\n",
  159. " return {\"pred\": pred, \"target\": y}\n",
  160. "```\n",
  161. "***\n",
  162. "在`fastNLP 0.8`中,**函数`train_step`是`Trainer`中参数`train_fn`的默认值**\n",
  163. "\n",
  164. "  由于,在`Trainer`训练时,**`Trainer`通过参数`train_fn`对应的模型方法获得当前数据批次的损失值**\n",
  165. "\n",
  166. "  因此,在`Trainer`训练时,`Trainer`首先会寻找模型是否定义了`train_step`这一方法\n",
  167. "\n",
  168. "    如果没有找到,那么`Trainer`会默认使用模型的`forward`函数来进行训练的前向传播过程\n",
  169. "\n",
  170. "注:在`fastNLP 0.8`中,**`Trainer`要求模型通过`train_step`来返回一个字典**,**满足如`{\"loss\": loss}`的形式**\n",
  171. "\n",
  172. "  此外,这里也可以通过传入`Trainer`的参数`output_mapping`来实现输出的转换,详见(trainer的详细讲解,待补充)\n",
  173. "\n",
  174. "同样,在`fastNLP 0.8`中,**函数`evaluate_step`是`Evaluator`中参数`evaluate_fn`的默认值**\n",
  175. "\n",
  176. "  在`Evaluator`测试时,**`Evaluator`通过参数`evaluate_fn`对应的模型方法获得当前数据批次的评测结果**\n",
  177. "\n",
  178. "  从用户角度,模型通过`evaluate_step`方法来返回一个字典,内容与传入`Evaluator`的`metrics`一致\n",
  179. "\n",
  180. "  从模块角度,该字典的键值和`metric`中的`update`函数的签名一致,这样的机制在传参时被称为“**参数匹配**”\n",
  181. "\n",
  182. "<img src=\"./figures/T0-fig-training-structure.png\" width=\"68%\" height=\"68%\" align=\"center\"></img>"
  183. ]
  184. },
  185. {
  186. "cell_type": "markdown",
  187. "id": "fb3272eb",
  188. "metadata": {},
  189. "source": [
  190. "### 2.2 trainer 和 evaluator 的参数匹配\n",
  191. "\n",
  192. "在`fastNLP 0.8`中,参数匹配涉及到两个方面,分别是在\n",
  193. "\n",
  194. "&emsp; 一方面,**在模型的前向传播中**,**`dataloader`向`train_step`或`evaluate_step`函数传递`batch`**\n",
  195. "\n",
  196. "&emsp; 另方面,**在模型的评测过程中**,**`evaluate_dataloader`向`metric`的`update`函数传递`batch`**\n",
  197. "\n",
  198. "对于前者,在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`False`时\n",
  199. "\n",
  200. "&emsp; &emsp; **`fastNLP 0.8`要求`dataloader`生成的每个`batch`**,**满足如`{\"x\": x, \"y\": y}`的形式**\n",
  201. "\n",
  202. "&emsp; 同时,`fastNLP 0.8`会查看模型的`train_step`和`evaluate_step`方法的参数签名,并为对应参数传入对应数值\n",
  203. "\n",
  204. "&emsp; &emsp; **字典形式的定义**,**对应在`Dataset`定义的`__getitem__`方法中**,例如下方的`ArgMaxDatset`\n",
  205. "\n",
  206. "&emsp; 而在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`True`时\n",
  207. "\n",
  208. "&emsp; &emsp; `fastNLP 0.8`会将`batch`直接传给模型的`train_step`、`evaluate_step`或`forward`函数\n",
  209. "\n",
  210. "```python\n",
  211. "class Dataset(torch.utils.data.Dataset):\n",
  212. " def __init__(self, x, y):\n",
  213. " self.x = x\n",
  214. " self.y = y\n",
  215. "\n",
  216. " def __len__(self):\n",
  217. " return len(self.x)\n",
  218. "\n",
  219. " def __getitem__(self, item):\n",
  220. " return {\"x\": self.x[item], \"y\": self.y[item]}\n",
  221. "```"
  222. ]
  223. },
  224. {
  225. "cell_type": "markdown",
  226. "id": "f5f1a6aa",
  227. "metadata": {},
  228. "source": [
  229. "对于后者,首先要明确,在`Trainer`和`Evaluator`中,`metrics`的计算分为`update`和`get_metric`两步\n",
  230. "\n",
  231. "&emsp; &emsp; **`update`函数**,**针对一个`batch`的预测结果**,计算其累计的评价指标\n",
  232. "\n",
  233. "&emsp; &emsp; **`get_metric`函数**,**统计`update`函数累计的评价指标**,来计算最终的评价结果\n",
  234. "\n",
  235. "&emsp; 例如对于`Accuracy`来说,`update`函数会更新一个`batch`的正例数量`right_num`和负例数量`total_num`\n",
  236. "\n",
  237. "&emsp; &emsp; 而`get_metric`函数则会返回所有`batch`的评测值`right_num / total_num`\n",
  238. "\n",
  239. "&emsp; 在此基础上,**`fastNLP 0.8`要求`evaluate_dataloader`生成的每个`batch`传递给对应的`metric`**\n",
  240. "\n",
  241. "&emsp; &emsp; **以`{\"pred\": y_pred, \"target\": y_true}`的形式**,对应其`update`函数的函数签名\n",
  242. "\n",
  243. "<img src=\"./figures/T0-fig-parameter-matching.png\" width=\"75%\" height=\"75%\" align=\"center\"></img>"
  244. ]
  245. },
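{
"cell_type": "markdown",
"id": "metric-sketch-note",
"metadata": {},
"source": [
"&emsp; To make the `update` / `get_metric` split concrete, here is a minimal sketch of an accuracy-style metric written as plain `python` (it does **not** use fastNLP's own `Metric` base class; the name `ToyAccuracy` is purely illustrative)\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"class ToyAccuracy:\n",
"    def __init__(self):\n",
"        self.right_num = 0   # correct predictions accumulated over batches\n",
"        self.total_num = 0   # samples accumulated over batches\n",
"\n",
"    def update(self, pred, target):\n",
"        # called once per batch; the parameter names match the keys\n",
"        # returned by evaluate_step, i.e. {'pred': ..., 'target': ...}\n",
"        self.right_num += (pred == target).sum().item()\n",
"        self.total_num += target.numel()\n",
"\n",
"    def get_metric(self):\n",
"        # called once at the end, over the accumulated counts\n",
"        return {'acc': self.right_num / max(self.total_num, 1)}\n",
"```"
]
},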
  246. {
  247. "cell_type": "markdown",
  248. "id": "f62b7bb1",
  249. "metadata": {},
  250. "source": [
  251. "### 2.3 示例:argmax 模型的搭建\n",
  252. "\n",
  253. "下文将通过训练`argmax`模型,简单介绍如何`Trainer`模块的使用方式\n",
  254. "\n",
  255. "&emsp; 首先,使用`pytorch.nn.Module`定义`argmax`模型,目标是输入一组固定维度的向量,输出其中数值最大的数的索引"
  256. ]
  257. },
  258. {
  259. "cell_type": "code",
  260. "execution_count": 1,
  261. "id": "5314482b",
  262. "metadata": {
  263. "pycharm": {
  264. "is_executing": true
  265. }
  266. },
  267. "outputs": [],
  268. "source": [
  269. "import torch\n",
  270. "import torch.nn as nn\n",
  271. "\n",
  272. "class ArgMaxModel(nn.Module):\n",
  273. " def __init__(self, num_labels, feature_dimension):\n",
  274. " nn.Module.__init__(self)\n",
  275. " self.num_labels = num_labels\n",
  276. "\n",
  277. " self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)\n",
  278. " self.ac1 = nn.ReLU()\n",
  279. " self.linear2 = nn.Linear(in_features=10, out_features=10)\n",
  280. " self.ac2 = nn.ReLU()\n",
  281. " self.output = nn.Linear(in_features=10, out_features=num_labels)\n",
  282. " self.loss_fn = nn.CrossEntropyLoss()\n",
  283. "\n",
  284. " def forward(self, x):\n",
  285. " pred = self.ac1(self.linear1(x))\n",
  286. " pred = self.ac2(self.linear2(pred))\n",
  287. " pred = self.output(pred)\n",
  288. " return pred\n",
  289. "\n",
  290. " def train_step(self, x, y):\n",
  291. " pred = self(x)\n",
  292. " return {\"loss\": self.loss_fn(pred, y)}\n",
  293. "\n",
  294. " def evaluate_step(self, x, y):\n",
  295. " pred = self(x)\n",
  296. " pred = torch.max(pred, dim=-1)[1]\n",
  297. " return {\"pred\": pred, \"target\": y}"
  298. ]
  299. },
  300. {
  301. "cell_type": "markdown",
  302. "id": "71f3fa6b",
  303. "metadata": {},
  304. "source": [
  305. "&emsp; 接着,使用`torch.utils.data.Dataset`定义`ArgMaxDataset`数据集\n",
  306. "\n",
  307. "&emsp; &emsp; 数据集包含三个参数:维度`feature_dimension`、数据量`data_num`和随机种子`seed`\n",
  308. "\n",
  309. "&emsp; &emsp; 数据及初始化是,自动生成指定维度的向量,并为每个向量标注出其中最大值的索引作为预测标签"
  310. ]
  311. },
  312. {
  313. "cell_type": "code",
  314. "execution_count": 2,
  315. "id": "fe612e61",
  316. "metadata": {
  317. "pycharm": {
  318. "is_executing": false
  319. }
  320. },
  321. "outputs": [],
  322. "source": [
  323. "from torch.utils.data import Dataset\n",
  324. "\n",
  325. "class ArgMaxDataset(Dataset):\n",
  326. " def __init__(self, feature_dimension, data_num=1000, seed=0):\n",
  327. " self.num_labels = feature_dimension\n",
  328. " self.feature_dimension = feature_dimension\n",
  329. " self.data_num = data_num\n",
  330. " self.seed = seed\n",
  331. "\n",
  332. " g = torch.Generator()\n",
  333. " g.manual_seed(1000)\n",
  334. " self.x = torch.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()\n",
  335. " self.y = torch.max(self.x, dim=-1)[1]\n",
  336. "\n",
  337. " def __len__(self):\n",
  338. " return self.data_num\n",
  339. "\n",
  340. " def __getitem__(self, item):\n",
  341. " return {\"x\": self.x[item], \"y\": self.y[item]}"
  342. ]
  343. },
  344. {
  345. "cell_type": "markdown",
  346. "id": "2cb96332",
  347. "metadata": {},
  348. "source": [
  349. "&emsp; 然后,根据`ArgMaxModel`类初始化模型实例,保持输入维度`feature_dimension`和输出标签数量`num_labels`一致\n",
  350. "\n",
  351. "&emsp; &emsp; 再根据`ArgMaxDataset`类初始化两个数据集实例,分别用来模型测试和模型评测,数据量各1000笔"
  352. ]
  353. },
  354. {
  355. "cell_type": "code",
  356. "execution_count": 3,
  357. "id": "76172ef8",
  358. "metadata": {
  359. "pycharm": {
  360. "is_executing": false
  361. }
  362. },
  363. "outputs": [],
  364. "source": [
  365. "model = ArgMaxModel(num_labels=10, feature_dimension=10)\n",
  366. "\n",
  367. "train_dataset = ArgMaxDataset(feature_dimension=10, data_num=1000)\n",
  368. "evaluate_dataset = ArgMaxDataset(feature_dimension=10, data_num=100)"
  369. ]
  370. },
  371. {
  372. "cell_type": "markdown",
  373. "id": "4e7d25ee",
  374. "metadata": {},
  375. "source": [
  376. "&emsp; 此外,使用`torch.utils.data.DataLoader`初始化两个数据加载模块,批量大小同为8,分别用于训练和测评"
  377. ]
  378. },
  379. {
  380. "cell_type": "code",
  381. "execution_count": 4,
  382. "id": "363b5b09",
  383. "metadata": {},
  384. "outputs": [],
  385. "source": [
  386. "from torch.utils.data import DataLoader\n",
  387. "\n",
  388. "train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)\n",
  389. "evaluate_dataloader = DataLoader(evaluate_dataset, batch_size=8)"
  390. ]
  391. },
  392. {
  393. "cell_type": "markdown",
  394. "id": "c8d4443f",
  395. "metadata": {},
  396. "source": [
  397. "&emsp; 最后,使用`torch.optim.SGD`初始化一个优化模块,基于随机梯度下降法"
  398. ]
  399. },
  400. {
  401. "cell_type": "code",
  402. "execution_count": 5,
  403. "id": "dc28a2d9",
  404. "metadata": {
  405. "pycharm": {
  406. "is_executing": false
  407. }
  408. },
  409. "outputs": [],
  410. "source": [
  411. "from torch.optim import SGD\n",
  412. "\n",
  413. "optimizer = SGD(model.parameters(), lr=0.001)"
  414. ]
  415. },
  416. {
  417. "cell_type": "markdown",
  418. "id": "eb8ca6cf",
  419. "metadata": {},
  420. "source": [
  421. "## 3. 使用 fastNLP 0.8 训练 argmax 模型\n",
  422. "\n",
  423. "### 3.1 trainer 外部初始化的 evaluator"
  424. ]
  425. },
  426. {
  427. "cell_type": "markdown",
  428. "id": "55145553",
  429. "metadata": {},
  430. "source": [
  431. "通过从`fastNLP`库中导入`Trainer`类,初始化`trainer`实例,对模型进行训练\n",
  432. "\n",
  433. "&emsp; 需要导入预先定义好的模型`model`、对应的数据加载模块`train_dataloader`、优化模块`optimizer`\n",
  434. "\n",
  435. "&emsp; 通过`progress_bar`设定进度条格式,默认为`\"auto\"`,此外还有`\"rich\"`、`\"raw\"`和`None`\n",
  436. "\n",
  437. "&emsp; &emsp; 但对于`\"auto\"`和`\"rich\"`格式,在`jupyter`中,进度条会在训练结束后会被丢弃\n",
  438. "\n",
  439. "&emsp; 通过`n_epochs`设定优化迭代轮数,默认为20;全部`Trainer`的全部变量与函数可以通过`dir(trainer)`查询"
  440. ]
  441. },
  442. {
  443. "cell_type": "code",
  444. "execution_count": 6,
  445. "id": "b51b7a2d",
  446. "metadata": {
  447. "pycharm": {
  448. "is_executing": false
  449. }
  450. },
  451. "outputs": [
  452. {
  453. "data": {
  454. "text/html": [
  455. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  456. "</pre>\n"
  457. ],
  458. "text/plain": [
  459. "\n"
  460. ]
  461. },
  462. "metadata": {},
  463. "output_type": "display_data"
  464. }
  465. ],
  466. "source": [
  467. "import sys\n",
  468. "sys.path.append('..')\n",
  469. "\n",
  470. "from fastNLP import Trainer\n",
  471. "\n",
  472. "trainer = Trainer(\n",
  473. " model=model,\n",
  474. " driver=\"torch\",\n",
  475. " device='cuda',\n",
  476. " train_dataloader=train_dataloader,\n",
  477. " optimizers=optimizer,\n",
  478. " n_epochs=10, # 设定迭代轮数 \n",
  479. " progress_bar=\"auto\" # 设定进度条格式\n",
  480. ")"
  481. ]
  482. },
  483. {
  484. "cell_type": "markdown",
  485. "id": "6e202d6e",
  486. "metadata": {},
  487. "source": [
  488. "通过使用`Trainer`类的`run`函数,进行训练\n",
  489. "\n",
  490. "&emsp; 其中,可以通过参数`num_train_batch_per_epoch`决定每个`epoch`运行多少个`batch`后停止,默认全部\n",
  491. "\n",
  492. "&emsp; `run`函数完成后在`jupyter`中没有输出保留,此外,通过`help(trainer.run)`可以查询`run`函数的详细内容"
  493. ]
  494. },
  495. {
  496. "cell_type": "code",
  497. "execution_count": 7,
  498. "id": "ba047ead",
  499. "metadata": {
  500. "pycharm": {
  501. "is_executing": true
  502. }
  503. },
  504. "outputs": [
  505. {
  506. "data": {
  507. "application/vnd.jupyter.widget-view+json": {
  508. "model_id": "",
  509. "version_major": 2,
  510. "version_minor": 0
  511. },
  512. "text/plain": [
  513. "Output()"
  514. ]
  515. },
  516. "metadata": {},
  517. "output_type": "display_data"
  518. },
  519. {
  520. "data": {
  521. "text/html": [
  522. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  523. ],
  524. "text/plain": []
  525. },
  526. "metadata": {},
  527. "output_type": "display_data"
  528. },
  529. {
  530. "data": {
  531. "text/html": [
  532. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  533. "</pre>\n"
  534. ],
  535. "text/plain": [
  536. "\n"
  537. ]
  538. },
  539. "metadata": {},
  540. "output_type": "display_data"
  541. }
  542. ],
  543. "source": [
  544. "trainer.run()"
  545. ]
  546. },
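{
"cell_type": "markdown",
"id": "run-limit-sketch",
"metadata": {},
"source": [
"&emsp; For illustration only, a hypothetical call using the `num_train_batch_per_epoch` parameter mentioned above (the value 10 is arbitrary; the call above used the default, i.e. all batches):\n",
"\n",
"```python\n",
"# hypothetical: run at most 10 batches per epoch instead of the full dataloader\n",
"trainer.run(num_train_batch_per_epoch=10)\n",
"```"
]
},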
  547. {
  548. "cell_type": "markdown",
  549. "id": "c16c5fa4",
  550. "metadata": {},
  551. "source": [
  552. "通过从`fastNLP`库中导入`Evaluator`类,初始化`evaluator`实例,对模型进行评测\n",
  553. "\n",
  554. "&emsp; 需要导入预先定义好的模型`model`、对应的数据加载模块`evaluate_dataloader`\n",
  555. "\n",
  556. "&emsp; 需要注意的是评测方法`metrics`,设定为形如`{'acc': fastNLP.core.metrics.Accuracy()}`的字典\n",
  557. "\n",
  558. "&emsp; 类似地,也可以通过`progress_bar`限定进度条格式,默认为`\"auto\"`"
  559. ]
  560. },
  561. {
  562. "cell_type": "code",
  563. "execution_count": 8,
  564. "id": "1c6b6b36",
  565. "metadata": {
  566. "pycharm": {
  567. "is_executing": true
  568. }
  569. },
  570. "outputs": [],
  571. "source": [
  572. "from fastNLP import Evaluator\n",
  573. "from fastNLP import Accuracy\n",
  574. "\n",
  575. "evaluator = Evaluator(\n",
  576. " model=model,\n",
  577. " driver=trainer.driver, # 需要使用 trainer 已经启动的 driver\n",
  578. " device=None,\n",
  579. " dataloaders=evaluate_dataloader,\n",
  580. " metrics={'acc': Accuracy()} # 需要严格使用此种形式的字典\n",
  581. ")"
  582. ]
  583. },
  584. {
  585. "cell_type": "markdown",
  586. "id": "8157bb9b",
  587. "metadata": {},
  588. "source": [
  589. "通过使用`Evaluator`类的`run`函数,进行训练\n",
  590. "\n",
  591. "&emsp; 其中,可以通过参数`num_eval_batch_per_dl`决定每个`evaluate_dataloader`运行多少个`batch`停止,默认全部\n",
  592. "\n",
  593. "&emsp; 最终,输出形如`{'acc#acc': acc}`的字典,在`jupyter`中,进度条会在评测结束后会被丢弃"
  594. ]
  595. },
  596. {
  597. "cell_type": "code",
  598. "execution_count": 9,
  599. "id": "f7cb0165",
  600. "metadata": {
  601. "pycharm": {
  602. "is_executing": true
  603. }
  604. },
  605. "outputs": [
  606. {
  607. "data": {
  608. "application/vnd.jupyter.widget-view+json": {
  609. "model_id": "",
  610. "version_major": 2,
  611. "version_minor": 0
  612. },
  613. "text/plain": [
  614. "Output()"
  615. ]
  616. },
  617. "metadata": {},
  618. "output_type": "display_data"
  619. },
  620. {
  621. "data": {
  622. "text/html": [
  623. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  624. ],
  625. "text/plain": []
  626. },
  627. "metadata": {},
  628. "output_type": "display_data"
  629. },
  630. {
  631. "data": {
  632. "text/html": [
  633. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.31</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'total#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'correct#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31.0</span><span style=\"font-weight: bold\">}</span>\n",
  634. "</pre>\n"
  635. ],
  636. "text/plain": [
  637. "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.31\u001b[0m, \u001b[32m'total#acc'\u001b[0m: \u001b[1;36m100.0\u001b[0m, \u001b[32m'correct#acc'\u001b[0m: \u001b[1;36m31.0\u001b[0m\u001b[1m}\u001b[0m\n"
  638. ]
  639. },
  640. "metadata": {},
  641. "output_type": "display_data"
  642. },
  643. {
  644. "data": {
  645. "text/plain": [
  646. "{'acc#acc': 0.31, 'total#acc': 100.0, 'correct#acc': 31.0}"
  647. ]
  648. },
  649. "execution_count": 9,
  650. "metadata": {},
  651. "output_type": "execute_result"
  652. }
  653. ],
  654. "source": [
  655. "evaluator.run()"
  656. ]
  657. },
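{
"cell_type": "markdown",
"id": "eval-limit-sketch",
"metadata": {},
"source": [
"&emsp; For illustration only, a hypothetical call using the `num_eval_batch_per_dl` parameter mentioned above (the value 5 is arbitrary; the call above used the default, i.e. all batches):\n",
"\n",
"```python\n",
"# hypothetical: evaluate only the first 5 batches of evaluate_dataloader\n",
"evaluator.run(num_eval_batch_per_dl=5)\n",
"```"
]
},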
  658. {
  659. "cell_type": "markdown",
  660. "id": "dd9f68fa",
  661. "metadata": {},
  662. "source": [
  663. "### 3.2 trainer 内部初始化的 evaluator \n",
  664. "\n",
  665. "通过在初始化`trainer`实例时加入`evaluate_dataloaders`和`metrics`,可以实现在训练过程中进行评测\n",
  666. "\n",
  667. "&emsp; 通过`progress_bar`同时设定训练和评估进度条格式,在`jupyter`中,在进度条训练结束后会被丢弃\n",
  668. "\n",
  669. "&emsp; 但是中间的评估结果仍会保留;**通过`evaluate_every`设定评估频率**,可以为负数、正数或者函数:\n",
  670. "\n",
  671. "&emsp; &emsp; **为负数时**,**表示每隔几个`epoch`评估一次**;**为正数时**,**则表示每隔几个`batch`评估一次**"
  672. ]
  673. },
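{
"cell_type": "markdown",
"id": "evaluate-every-sketch",
"metadata": {},
"source": [
"&emsp; A sketch of the positive-integer form, mirroring the construction in the next cell (the value 100 is arbitrary):\n",
"\n",
"```python\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    driver=trainer.driver,\n",
"    train_dataloader=train_dataloader,\n",
"    evaluate_dataloaders=evaluate_dataloader,\n",
"    metrics={'acc': Accuracy()},\n",
"    optimizers=optimizer,\n",
"    n_epochs=10,\n",
"    evaluate_every=100,   # positive: evaluate once every 100 batches\n",
")\n",
"# evaluate_every=-2 would instead evaluate once every 2 epochs;\n",
"# a callable is also accepted, but its exact signature is not covered in this tutorial\n",
"```"
]
},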
  674. {
  675. "cell_type": "code",
  676. "execution_count": 10,
  677. "id": "183c7d19",
  678. "metadata": {
  679. "pycharm": {
  680. "is_executing": true
  681. }
  682. },
  683. "outputs": [],
  684. "source": [
  685. "trainer = Trainer(\n",
  686. " model=model,\n",
  687. " driver=trainer.driver, # 因为是在同个脚本中,这里的 driver 同样需要重用\n",
  688. " train_dataloader=train_dataloader,\n",
  689. " evaluate_dataloaders=evaluate_dataloader,\n",
  690. " metrics={'acc': Accuracy()},\n",
  691. " optimizers=optimizer,\n",
  692. " n_epochs=10, \n",
  693. " evaluate_every=-1, # 表示每个 epoch 的结束进行评估\n",
  694. ")"
  695. ]
  696. },
  697. {
  698. "cell_type": "markdown",
  699. "id": "714cc404",
  700. "metadata": {},
  701. "source": [
  702. "通过使用`Trainer`类的`run`函数,进行训练\n",
  703. "\n",
  704. "&emsp; 还可以通过**参数`num_eval_sanity_batch`决定每次训练前运行多少个`evaluate_batch`进行评测**,**默认为`2`**\n",
  705. "\n",
  706. "&emsp; 之所以“先评测后训练”,是为了保证训练很长时间的数据,不会在评测阶段出问题,故作此**试探性评测**"
  707. ]
  708. },
  709. {
  710. "cell_type": "code",
  711. "execution_count": 11,
  712. "id": "2e4daa2c",
  713. "metadata": {
  714. "pycharm": {
  715. "is_executing": true
  716. }
  717. },
  718. "outputs": [
  719. {
  720. "data": {
  721. "text/html": [
  722. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[18:28:25] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#592\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">592</span></a>\n",
  723. "</pre>\n"
  724. ],
  725. "text/plain": [
  726. "\u001b[2;36m[18:28:25]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=549287;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=645362;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
  727. ]
  728. },
  729. "metadata": {},
  730. "output_type": "display_data"
  731. },
  732. {
  733. "data": {
  734. "application/vnd.jupyter.widget-view+json": {
  735. "model_id": "",
  736. "version_major": 2,
  737. "version_minor": 0
  738. },
  739. "text/plain": [
  740. "Output()"
  741. ]
  742. },
  743. "metadata": {},
  744. "output_type": "display_data"
  745. },
  746. {
  747. "data": {
  748. "text/html": [
  749. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  750. ],
  751. "text/plain": []
  752. },
  753. "metadata": {},
  754. "output_type": "display_data"
  755. },
  756. {
  757. "data": {
  758. "application/vnd.jupyter.widget-view+json": {
  759. "model_id": "",
  760. "version_major": 2,
  761. "version_minor": 0
  762. },
  763. "text/plain": [
  764. "Output()"
  765. ]
  766. },
  767. "metadata": {},
  768. "output_type": "display_data"
  769. },
  770. {
  771. "data": {
  772. "text/html": [
  773. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  774. "</pre>\n"
  775. ],
  776. "text/plain": [
  777. "\n"
  778. ]
  779. },
  780. "metadata": {},
  781. "output_type": "display_data"
  782. },
  783. {
  784. "data": {
  785. "text/html": [
  786. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  787. "</pre>\n"
  788. ],
  789. "text/plain": [
  790. "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  791. ]
  792. },
  793. "metadata": {},
  794. "output_type": "display_data"
  795. },
  796. {
  797. "data": {
  798. "text/html": [
  799. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  800. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.31</span>,\n",
  801. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
  802. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31.0</span>\n",
  803. "<span style=\"font-weight: bold\">}</span>\n",
  804. "</pre>\n"
  805. ],
  806. "text/plain": [
  807. "\u001b[1m{\u001b[0m\n",
  808. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.31\u001b[0m,\n",
  809. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
  810. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m31.0\u001b[0m\n",
  811. "\u001b[1m}\u001b[0m\n"
  812. ]
  813. },
  814. "metadata": {},
  815. "output_type": "display_data"
  816. },
  817. {
  818. "data": {
  819. "text/html": [
  820. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  821. "</pre>\n"
  822. ],
  823. "text/plain": [
  824. "\n"
  825. ]
  826. },
  827. "metadata": {},
  828. "output_type": "display_data"
  829. },
  830. {
  831. "data": {
  832. "text/html": [
  833. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  834. "</pre>\n"
  835. ],
  836. "text/plain": [
  837. "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  838. ]
  839. },
  840. "metadata": {},
  841. "output_type": "display_data"
  842. },
  843. {
  844. "data": {
  845. "text/html": [
  846. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  847. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.33</span>,\n",
  848. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
  849. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">33.0</span>\n",
  850. "<span style=\"font-weight: bold\">}</span>\n",
  851. "</pre>\n"
  852. ],
  853. "text/plain": [
  854. "\u001b[1m{\u001b[0m\n",
  855. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.33\u001b[0m,\n",
  856. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
  857. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m33.0\u001b[0m\n",
  858. "\u001b[1m}\u001b[0m\n"
  859. ]
  860. },
  861. "metadata": {},
  862. "output_type": "display_data"
  863. },
  864. {
  865. "data": {
  866. "text/html": [
  867. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  868. "</pre>\n"
  869. ],
  870. "text/plain": [
  871. "\n"
  872. ]
  873. },
  874. "metadata": {},
  875. "output_type": "display_data"
  876. },
  877. {
  878. "data": {
  879. "text/html": [
  880. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  881. "</pre>\n"
  882. ],
  883. "text/plain": [
  884. "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  885. ]
  886. },
  887. "metadata": {},
  888. "output_type": "display_data"
  889. },
  890. {
  891. "data": {
  892. "text/html": [
  893. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  894. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.34</span>,\n",
  895. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
  896. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">34.0</span>\n",
  897. "<span style=\"font-weight: bold\">}</span>\n",
  898. "</pre>\n"
  899. ],
  900. "text/plain": [
  901. "\u001b[1m{\u001b[0m\n",
  902. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.34\u001b[0m,\n",
  903. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
  904. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m34.0\u001b[0m\n",
  905. "\u001b[1m}\u001b[0m\n"
  906. ]
  907. },
  908. "metadata": {},
  909. "output_type": "display_data"
  910. },
  911. {
  912. "data": {
  913. "text/html": [
  914. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  915. "</pre>\n"
  916. ],
  917. "text/plain": [
  918. "\n"
  919. ]
  920. },
  921. "metadata": {},
  922. "output_type": "display_data"
  923. },
  924. {
  925. "data": {
  926. "text/html": [
  927. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  928. "</pre>\n"
  929. ],
  930. "text/plain": [
  931. "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  932. ]
  933. },
  934. "metadata": {},
  935. "output_type": "display_data"
  936. },
  937. {
  938. "data": {
  939. "text/html": [
  940. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  941. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n",
  942. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
  943. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n",
  944. "<span style=\"font-weight: bold\">}</span>\n",
  945. "</pre>\n"
  946. ],
  947. "text/plain": [
  948. "\u001b[1m{\u001b[0m\n",
  949. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
  950. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
  951. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
  952. "\u001b[1m}\u001b[0m\n"
  953. ]
  954. },
  955. "metadata": {},
  956. "output_type": "display_data"
  957. },
  958. {
  959. "data": {
  960. "text/html": [
  961. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  962. "</pre>\n"
  963. ],
  964. "text/plain": [
  965. "\n"
  966. ]
  967. },
  968. "metadata": {},
  969. "output_type": "display_data"
  970. },
  971. {
  972. "data": {
  973. "text/html": [
  974. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  975. "</pre>\n"
  976. ],
  977. "text/plain": [
  978. "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  979. ]
  980. },
  981. "metadata": {},
  982. "output_type": "display_data"
  983. },
  984. {
  985. "data": {
  986. "text/html": [
  987. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  988. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n",
  989. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
  990. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n",
  991. "<span style=\"font-weight: bold\">}</span>\n",
  992. "</pre>\n"
  993. ],
  994. "text/plain": [
  995. "\u001b[1m{\u001b[0m\n",
  996. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
  997. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
  998. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
  999. "\u001b[1m}\u001b[0m\n"
  1000. ]
  1001. },
  1002. "metadata": {},
  1003. "output_type": "display_data"
  1004. },
  1005. {
  1006. "data": {
  1007. "text/html": [
  1008. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1009. "</pre>\n"
  1010. ],
  1011. "text/plain": [
  1012. "\n"
  1013. ]
  1014. },
  1015. "metadata": {},
  1016. "output_type": "display_data"
  1017. },
  1018. {
  1019. "data": {
  1020. "text/html": [
  1021. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1022. "</pre>\n"
  1023. ],
  1024. "text/plain": [
  1025. "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1026. ]
  1027. },
  1028. "metadata": {},
  1029. "output_type": "display_data"
  1030. },
  1031. {
  1032. "data": {
  1033. "text/html": [
  1034. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1035. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n",
  1036. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
  1037. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n",
  1038. "<span style=\"font-weight: bold\">}</span>\n",
  1039. "</pre>\n"
  1040. ],
  1041. "text/plain": [
  1042. "\u001b[1m{\u001b[0m\n",
  1043. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
  1044. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
  1045. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
  1046. "\u001b[1m}\u001b[0m\n"
  1047. ]
  1048. },
  1049. "metadata": {},
  1050. "output_type": "display_data"
  1051. },
  1052. {
  1053. "data": {
  1054. "text/html": [
  1055. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1056. "</pre>\n"
  1057. ],
  1058. "text/plain": [
  1059. "\n"
  1060. ]
  1061. },
  1062. "metadata": {},
  1063. "output_type": "display_data"
  1064. },
  1065. {
  1066. "data": {
  1067. "text/html": [
  1068. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1069. "</pre>\n"
  1070. ],
  1071. "text/plain": [
  1072. "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1073. ]
  1074. },
  1075. "metadata": {},
  1076. "output_type": "display_data"
  1077. },
  1078. {
  1079. "data": {
  1080. "text/html": [
  1081. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1082. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n",
  1083. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
  1084. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n",
  1085. "<span style=\"font-weight: bold\">}</span>\n",
  1086. "</pre>\n"
  1087. ],
  1088. "text/plain": [
  1089. "\u001b[1m{\u001b[0m\n",
  1090. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
  1091. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
  1092. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
  1093. "\u001b[1m}\u001b[0m\n"
  1094. ]
  1095. },
  1096. "metadata": {},
  1097. "output_type": "display_data"
  1098. },
  1099. {
  1100. "data": {
  1101. "text/html": [
  1102. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1103. "</pre>\n"
  1104. ],
  1105. "text/plain": [
  1106. "\n"
  1107. ]
  1108. },
  1109. "metadata": {},
  1110. "output_type": "display_data"
  1111. },
  1112. {
  1113. "data": {
  1114. "text/html": [
  1115. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1116. "</pre>\n"
  1117. ],
  1118. "text/plain": [
  1119. "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1120. ]
  1121. },
  1122. "metadata": {},
  1123. "output_type": "display_data"
  1124. },
  1125. {
  1126. "data": {
  1127. "text/html": [
  1128. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1129. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n",
  1130. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
  1131. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n",
  1132. "<span style=\"font-weight: bold\">}</span>\n",
  1133. "</pre>\n"
  1134. ],
  1135. "text/plain": [
  1136. "\u001b[1m{\u001b[0m\n",
  1137. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
  1138. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
  1139. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
  1140. "\u001b[1m}\u001b[0m\n"
  1141. ]
  1142. },
  1143. "metadata": {},
  1144. "output_type": "display_data"
  1145. },
  1146. {
  1147. "data": {
  1148. "text/html": [
  1149. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1150. "</pre>\n"
  1151. ],
  1152. "text/plain": [
  1153. "\n"
  1154. ]
  1155. },
  1156. "metadata": {},
  1157. "output_type": "display_data"
  1158. },
  1159. {
  1160. "data": {
  1161. "text/html": [
  1162. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1163. "</pre>\n"
  1164. ],
  1165. "text/plain": [
  1166. "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1167. ]
  1168. },
  1169. "metadata": {},
  1170. "output_type": "display_data"
  1171. },
  1172. {
  1173. "data": {
  1174. "text/html": [
  1175. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1176. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.37</span>,\n",
  1177. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
  1178. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">37.0</span>\n",
  1179. "<span style=\"font-weight: bold\">}</span>\n",
  1180. "</pre>\n"
  1181. ],
  1182. "text/plain": [
  1183. "\u001b[1m{\u001b[0m\n",
  1184. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.37\u001b[0m,\n",
  1185. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
  1186. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m37.0\u001b[0m\n",
  1187. "\u001b[1m}\u001b[0m\n"
  1188. ]
  1189. },
  1190. "metadata": {},
  1191. "output_type": "display_data"
  1192. },
  1193. {
  1194. "data": {
  1195. "text/html": [
  1196. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1197. "</pre>\n"
  1198. ],
  1199. "text/plain": [
  1200. "\n"
  1201. ]
  1202. },
  1203. "metadata": {},
  1204. "output_type": "display_data"
  1205. },
  1206. {
  1207. "data": {
  1208. "text/html": [
  1209. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1210. "</pre>\n"
  1211. ],
  1212. "text/plain": [
  1213. "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1214. ]
  1215. },
  1216. "metadata": {},
  1217. "output_type": "display_data"
  1218. },
  1219. {
  1220. "data": {
  1221. "text/html": [
  1222. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1223. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.4</span>,\n",
  1224. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
  1225. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">40.0</span>\n",
  1226. "<span style=\"font-weight: bold\">}</span>\n",
  1227. "</pre>\n"
  1228. ],
  1229. "text/plain": [
  1230. "\u001b[1m{\u001b[0m\n",
  1231. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.4\u001b[0m,\n",
  1232. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
  1233. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m40.0\u001b[0m\n",
  1234. "\u001b[1m}\u001b[0m\n"
  1235. ]
  1236. },
  1237. "metadata": {},
  1238. "output_type": "display_data"
  1239. },
  1240. {
  1241. "data": {
  1242. "text/html": [
  1243. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  1244. ],
  1245. "text/plain": []
  1246. },
  1247. "metadata": {},
  1248. "output_type": "display_data"
  1249. },
  1250. {
  1251. "data": {
  1252. "text/html": [
  1253. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1254. "</pre>\n"
  1255. ],
  1256. "text/plain": [
  1257. "\n"
  1258. ]
  1259. },
  1260. "metadata": {},
  1261. "output_type": "display_data"
  1262. }
  1263. ],
  1264. "source": [
  1265. "trainer.run()"
  1266. ]
  1267. },
  1268. {
  1269. "cell_type": "code",
  1270. "execution_count": 12,
  1271. "id": "c4e9c619",
  1272. "metadata": {},
  1273. "outputs": [
  1274. {
  1275. "data": {
  1276. "application/vnd.jupyter.widget-view+json": {
  1277. "model_id": "",
  1278. "version_major": 2,
  1279. "version_minor": 0
  1280. },
  1281. "text/plain": [
  1282. "Output()"
  1283. ]
  1284. },
  1285. "metadata": {},
  1286. "output_type": "display_data"
  1287. },
  1288. {
  1289. "data": {
  1290. "text/html": [
  1291. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  1292. ],
  1293. "text/plain": []
  1294. },
  1295. "metadata": {},
  1296. "output_type": "display_data"
  1297. },
  1298. {
  1299. "data": {
  1300. "text/plain": [
  1301. "{'acc#acc': 0.4, 'total#acc': 100.0, 'correct#acc': 40.0}"
  1302. ]
  1303. },
  1304. "execution_count": 12,
  1305. "metadata": {},
  1306. "output_type": "execute_result"
  1307. }
  1308. ],
  1309. "source": [
  1310. "trainer.evaluator.run()"
  1311. ]
  1312. },
  1313. {
  1314. "cell_type": "code",
  1315. "execution_count": null,
  1316. "id": "1bc7cb4a",
  1317. "metadata": {},
  1318. "outputs": [],
  1319. "source": []
  1320. }
  1321. ],
  1322. "metadata": {
  1323. "kernelspec": {
  1324. "display_name": "Python 3 (ipykernel)",
  1325. "language": "python",
  1326. "name": "python3"
  1327. },
  1328. "language_info": {
  1329. "codemirror_mode": {
  1330. "name": "ipython",
  1331. "version": 3
  1332. },
  1333. "file_extension": ".py",
  1334. "mimetype": "text/x-python",
  1335. "name": "python",
  1336. "nbconvert_exporter": "python",
  1337. "pygments_lexer": "ipython3",
  1338. "version": "3.7.13"
  1339. },
  1340. "pycharm": {
  1341. "stem_cell": {
  1342. "cell_type": "raw",
  1343. "metadata": {
  1344. "collapsed": false
  1345. },
  1346. "source": []
  1347. }
  1348. }
  1349. },
  1350. "nbformat": 4,
  1351. "nbformat_minor": 5
  1352. }