{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fdd7ff16",
   "metadata": {},
   "source": [
    "# T5. An in-depth look at the trainer and evaluator\n",
    "\n",
    "&emsp; 1 &emsp; More on drivers in fastNLP\n",
    "\n",
    "&emsp; &emsp; 1.1 &emsp; The design of trainer and driver\n",
    "\n",
    "&emsp; &emsp; 1.2 &emsp; device and multi-GPU training\n",
    "\n",
    "&emsp; 2 &emsp; More metric types in fastNLP\n",
    "\n",
    "&emsp; &emsp; 2.1 &emsp; Predefined metric types\n",
    "\n",
    "&emsp; &emsp; 2.2 &emsp; Custom metric types\n",
    "\n",
    "&emsp; 3 &emsp; More on the trainer in fastNLP\n",
    "\n",
    "&emsp; &emsp; 3.1 &emsp; Internal structure of the trainer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08752c5a",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 1. More on drivers in fastNLP\n",
    "\n",
    "### 1.1 The design of trainer and driver\n",
    "\n",
    "In `fastNLP 1.0`, the most important modules for model training are **the training module trainer, the evaluation module evaluator and the driver module driver**.\n",
    "\n",
    "`tutorial 0` already introduced these three modules briefly. **The driver controls how the model is ultimately run during training and evaluation**, **the evaluator wraps the evaluation metrics**, and **the trainer wraps the training optimizers** and **may also contain an evaluator**.\n",
    "\n",
    "The fundamental reason for this division is to **achieve compatibility with several Python deep-learning frameworks**, **such as pytorch, paddle and jittor**.\n",
    "\n",
    "For training, the pseudocode is shown in the left column of the table below. Because **different frameworks define models, losses and tensors differently**, the training procedure is split into two parts. The first is **framework-independent loop control and batch dispatch**, **implemented by the trainer** (middle column). The second is **framework-dependent model calls and numerical optimization**, **implemented by the driver** (right column).\n",
    "\n",
    "|Training procedure|Framework-independent: `Trainer`|Framework-dependent: `Driver`|\n",
    "|----|----|----|\n",
    "| try: | try: | |\n",
    "| for epoch in 1:n_epochs: | for epoch in 1:n_epochs: | |\n",
    "| for step in 1:total_steps: | for step in 1:total_steps: | |\n",
    "| batch = fetch_batch() | batch = fetch_batch() | |\n",
    "| loss = model.forward(batch) | | loss = model.forward(batch) |\n",
    "| loss.backward() | | loss.backward() |\n",
    "| model.update() | | model.update() |\n",
    "| model.clear_grad() | | model.clear_grad() |\n",
    "| if need_save: | if need_save: | |\n",
    "| model.save() | | model.save() |\n",
    "| except: | except: | |\n",
    "| process_exception() | process_exception() | |"
   ]
  },
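  {
   "cell_type": "markdown",
   "id": "5b0e3f1a",
   "metadata": {},
   "source": [
    "The division of labour in the table above can be illustrated with a toy example in plain Python. This is only a sketch under simplifying assumptions, not fastNLP's implementation. `ToyDriver` and `ToyTrainer` are hypothetical classes, the 'model' is a single weight fitted to data generated by `y = 2x`, and gradients are written out by hand instead of using autograd. The trainer owns the epoch/batch loop and the exception handling, and every call that touches the model goes through the driver. In this sketch, supporting another framework would therefore only require another driver class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b0e3f1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch of the Trainer/Driver split; these are NOT fastNLP classes.\n",
    "class ToyDriver:\n",
    "    # framework-dependent part: model call, backward pass, parameter update\n",
    "    def __init__(self, lr=0.1):\n",
    "        self.weight = 0.0  # one-parameter 'model': pred = weight * x\n",
    "        self.grad = 0.0\n",
    "        self.lr = lr\n",
    "\n",
    "    def forward(self, batch):\n",
    "        xs, ys = batch\n",
    "        errors = [self.weight * x - y for x, y in zip(xs, ys)]\n",
    "        self._batch = (xs, errors)  # kept for the manual backward pass\n",
    "        return sum(e ** 2 for e in errors) / len(xs)  # mean squared error\n",
    "\n",
    "    def backward(self, loss):\n",
    "        # a real framework would use autograd; d(loss)/d(weight) is written by hand here\n",
    "        xs, errors = self._batch\n",
    "        self.grad += sum(2 * e * x for e, x in zip(errors, xs)) / len(xs)\n",
    "\n",
    "    def update(self):\n",
    "        self.weight -= self.lr * self.grad\n",
    "\n",
    "    def clear_grad(self):\n",
    "        self.grad = 0.0\n",
    "\n",
    "\n",
    "class ToyTrainer:\n",
    "    # framework-independent part: epochs, batches, exception handling\n",
    "    def __init__(self, driver, batches, n_epochs):\n",
    "        self.driver, self.batches, self.n_epochs = driver, batches, n_epochs\n",
    "\n",
    "    def run(self):\n",
    "        try:\n",
    "            for epoch in range(self.n_epochs):\n",
    "                for batch in self.batches:\n",
    "                    loss = self.driver.forward(batch)\n",
    "                    self.driver.backward(loss)\n",
    "                    self.driver.update()\n",
    "                    self.driver.clear_grad()\n",
    "        except KeyboardInterrupt:\n",
    "            print('training interrupted')  # stands in for process_exception()\n",
    "        return self.driver.weight\n",
    "\n",
    "\n",
    "toy_batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]  # data generated by y = 2x\n",
    "toy_weight = ToyTrainer(ToyDriver(), toy_batches, n_epochs=50).run()\n",
    "print(round(toy_weight, 3))  # 2.0"
   ]
  },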
  {
   "cell_type": "markdown",
   "id": "3e55f07b",
   "metadata": {},
   "source": [
    "For evaluation, the pseudocode is likewise shown in the left column of the table below. Again, because frameworks define models, losses and tensors differently, evaluation is split into two parts. The first is **framework-independent loop control, dispatch and aggregation**, **implemented by the evaluator** (middle column). The second is **framework-dependent model calls and metric computation**, again **implemented by the driver** (right column).\n",
    "\n",
    "|Evaluation procedure|Framework-independent: `Evaluator`|Framework-dependent: `Driver`|\n",
    "|----|----|----|\n",
    "| try: | try: | |\n",
    "| model.set_eval() | model.set_eval() | |\n",
    "| for step in 1:total_steps: | for step in 1:total_steps: | |\n",
    "| batch = fetch_batch() | batch = fetch_batch() | |\n",
    "| outputs = model.evaluate(batch) | | outputs = model.evaluate(batch) |\n",
    "| metric.compute(batch, outputs) | | metric.compute(batch, outputs) |\n",
    "| results = metric.get_metric() | results = metric.get_metric() | |\n",
    "| except: | except: | |\n",
    "| process_exception() | process_exception() | |"
   ]
  },
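  {
   "cell_type": "markdown",
   "id": "6c1f4a2b",
   "metadata": {},
   "source": [
    "The evaluation split can be sketched in the same hypothetical style. `ToyAccuracy`, `ToyEvalDriver` and `toy_evaluate` are illustrative names, not fastNLP APIs. The framework-independent function only loops over batches and collects the final result, and the driver runs a toy threshold 'model'. The metric accumulates its counts batch by batch, with `update` and `get_metric` playing the roles described in `tutorial-0`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c1f4a2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch of the Evaluator/Driver split; these are NOT fastNLP classes.\n",
    "class ToyAccuracy:\n",
    "    # metric: accumulates per-batch counts, computes the final value at the end\n",
    "    def __init__(self):\n",
    "        self.right, self.total = 0, 0\n",
    "\n",
    "    def update(self, pred, target):\n",
    "        self.right += sum(int(p == t) for p, t in zip(pred, target))\n",
    "        self.total += len(target)\n",
    "\n",
    "    def get_metric(self):\n",
    "        return {'acc': self.right / self.total}\n",
    "\n",
    "\n",
    "class ToyEvalDriver:\n",
    "    # framework-dependent part: run the 'model' on one batch\n",
    "    def __init__(self, threshold):\n",
    "        self.threshold = threshold  # toy model: predict 1 if x > threshold\n",
    "\n",
    "    def evaluate(self, xs):\n",
    "        return [int(x > self.threshold) for x in xs]\n",
    "\n",
    "\n",
    "def toy_evaluate(driver, batches, metric):\n",
    "    # framework-independent part: loop over batches, then collect the result\n",
    "    for xs, ys in batches:\n",
    "        outputs = driver.evaluate(xs)  # model call, delegated to the driver\n",
    "        metric.update(outputs, ys)     # per-batch statistics\n",
    "    return metric.get_metric()\n",
    "\n",
    "\n",
    "toy_eval_batches = [([0.2, 0.7, 0.9], [0, 1, 1]), ([0.4, 0.6], [1, 1])]\n",
    "print(toy_evaluate(ToyEvalDriver(threshold=0.5), toy_eval_batches, ToyAccuracy()))  # {'acc': 0.8}"
   ]
  },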
  {
   "cell_type": "markdown",
   "id": "94ba11c6",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "So, from the programmer's point of view, `fastNLP v1.0` **uses a driver so that models built on pytorch, paddle, jittor or oneflow** **can all run on the same trainer and evaluator**. This **is one of the main improvements of fastNLP v1.0 over earlier versions**.\n",
    "\n",
    "From the `driver`'s point of view, `fastNLP v1.0` defines a `driver` base class that **converts all tensors to `numpy.ndarray`**. From it, `fastNLP v1.0` derives three subclasses, `torch_driver`, `paddle_driver` and `jittor_driver`, which provide compatibility with `pytorch`, `paddle` and `jittor`. For hands-on use of the latter two, see the next tutorial, `tutorial-6`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab1cea7d",
   "metadata": {},
   "source": [
    "### 1.2 device and multi-GPU training\n",
    "\n",
    "**fastNLP v1.0 supports multi-GPU training**. To enable it, **set the trainer's `device` to the list of GPU indices to use**.\n",
    "\n",
    "Moving from one GPU to several affects data, model and evaluation alike, and `fastNLP v1.0` provides the following guarantees. When the data is split, the GPUs coordinate so that every sample is used for training and no sample is used by two GPUs. During training, the model replicas exchange gradients. During evaluation, each GPU first computes its own statistics and the results are then aggregated.\n",
    "\n",
    "For example, when `get_metric` runs during evaluation, `fastNLP v1.0` automatically combines the per-GPU values of `self.right` and `self.total` using the **aggregate_method** registered for each of them (default `sum`). As a result, the `Accuracy` class returns statistics over all the data when `get_metric` is called, as in the code below:\n",
    "\n",
    "```python\n",
    "from fastNLP import Metric, Trainer\n",
    "\n",
    "# simplified view of how Accuracy registers its counters\n",
    "class Accuracy(Metric):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # 'sum': the values from all GPUs are added up when get_metric is called\n",
    "        self.register_element(name='total', value=0, aggregate_method='sum')\n",
    "        self.register_element(name='right', value=0, aggregate_method='sum')\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,  # a model implemented in pytorch\n",
    "    train_dataloader=train_dataloader,\n",
    "    optimizers=optimizer,\n",
    "    # ...\n",
    "    driver='torch',  # use torch_driver\n",
    "    device=[0, 1],  # use GPUs cuda:0 and cuda:1\n",
    "    # ...\n",
    "    evaluate_dataloaders=evaluate_dataloader,\n",
    "    metrics={'acc': Accuracy()},\n",
    "    # ...\n",
    ")\n",
    "```\n"
   ]
  },
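  {
   "cell_type": "markdown",
   "id": "7d2a5b3c",
   "metadata": {},
   "source": [
    "A quick calculation shows why per-GPU counters are summed before dividing instead of averaging the per-GPU accuracies. The counts below are made up, and the plain dictionaries merely stand in for the `right`/`total` elements registered above. The point is only that `sum` aggregation reproduces the accuracy over the whole evaluation set, even when the GPUs end up with different numbers of samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d2a5b3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration only: made-up per-GPU counts (an intentionally uneven split)\n",
    "per_card = [\n",
    "    {'right': 90, 'total': 100},  # cuda:0\n",
    "    {'right': 10, 'total': 20},   # cuda:1\n",
    "]\n",
    "\n",
    "# aggregate_method='sum': add up the raw counts, then divide once\n",
    "right = sum(c['right'] for c in per_card)\n",
    "total = sum(c['total'] for c in per_card)\n",
    "global_acc = right / total\n",
    "\n",
    "# naive alternative: average the per-GPU accuracies\n",
    "mean_of_accs = sum(c['right'] / c['total'] for c in per_card) / len(per_card)\n",
    "\n",
    "print(f'sum then divide: {global_acc:.4f}')        # 0.8333 (100 / 120)\n",
    "print(f'mean of accuracies: {mean_of_accs:.4f}')  # 0.7000 ((0.9 + 0.5) / 2)"
   ]
  },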
  {
   "cell_type": "markdown",
   "id": "e2e0a210",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "Note: in `fastNLP v1.0`, multi-GPU training is not supported inside `jupyter`, only single-GPU training. For this reason, none of the tutorials demonstrate it."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d19220c",
   "metadata": {},
   "source": [
    "## 2. More metric types in fastNLP\n",
    "\n",
    "### 2.1 Predefined metric types\n",
    "\n",
    "Besides **accuracy, `Accuracy`**, which appeared frequently in the previous tutorials, `fastNLP 1.0` provides other **predefined evaluation metrics**:\n",
    "\n",
    "**`Metric`, the base class of all metrics**; `TransformersAccuracy`, an accuracy metric adapted to the corresponding models in `Transformers`; **`ClassifyFPreRecMetric`, the F1 score for classification** (which also reports precision `Pre` and recall `Rec`); and **`SpanFPreRecMetric`, the F1 score for span extraction**.\n",
    "\n",
    "Their basic information is summarized in the table below, followed by a detailed discussion.\n",
    "\n",
    "|Class|Description|Source path|\n",
    "|----|----|----|\n",
    "| `Metric` | Base class to inherit from when defining `metrics` | `/core/metrics/metric.py` |\n",
    "| `Accuracy` | Accuracy, the most commonly used metric | `/core/metrics/accuracy.py` |\n",
    "| `TransformersAccuracy` | Accuracy, compatible with the corresponding models in `Transformers` | `/core/metrics/accuracy.py` |\n",
    "| `ClassifyFPreRecMetric` | Recall, precision and F1 for **classification problems** | `/core/metrics/classify_f1_pre_rec_metric.py` |\n",
    "| `SpanFPreRecMetric` | Recall, precision and F1 for **span extraction problems** | `/core/metrics/span_f1_pre_rec_metric.py` |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdc083a3",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "As described in `tutorial-0`, every `metric` provides an `update` function and a `get_metric` function. **update accumulates the statistics of a single batch**, and **get_metric returns the final result**, which is then printed.\n",
    "\n",
    "\n",
    "#### 2.1.1 Accuracy and TransformersAccuracy\n",
    "\n",
    "`Accuracy` is the proportion of correctly predicted items `right_num` among all items `total_num`, i.e. `right_num / total_num`.\n",
    "\n",
    "Its `get_metric` output has the form **{\"acc#xx\": float, 'total#xx': float, 'correct#xx': float}**, where `xx` is the key under which the metric is passed to `metrics`.\n",
    "\n",
    "Usually no arguments are needed at initialization, because `fastNLP` infers the backend framework (`backend`) from the arguments passed to `update`.\n",
    "\n",
    "**The parameters of `update` are `pred`, `target` and `seq_len`**. **`seq_len` records the length of each sample in the batch**, so that padded positions are not counted.\n",
    "\n",
    "`TransformersAccuracy` inherits from `Accuracy` and exists only for compatibility with models from the `Transformers` library. Its `update` function converts the `attention_mask` produced by `Transformers` models into a `seq_len` argument.\n",
    "\n",
    "\n",
    "#### 2.1.2 ClassifyFPreRecMetric and SpanFPreRecMetric\n",
    "\n",
    "`ClassifyFPreRecMetric` evaluates classification, and `SpanFPreRecMetric` evaluates span extraction; the latter already appeared in `tutorial-4`.\n",
    "\n",
    "The two have three things in common. **First**, **both report recall `Rec`**, **precision `Pre`** and **the F1 score**. Their `get_metric` output has the form **{\"f#xx\": float, 'pre#xx': float, 'rec#xx': float}**. The formulas are given below. `beta` defaults to `1`, in which case `F1` is the harmonic mean of recall `Rec` and precision `Pre`.\n",
    "\n",
    "$$\\text{Recall}\\ Rec=\\dfrac{\\text{number of samples correctly predicted as positive}}{\\text{number of actual positives}}\\qquad \\text{Precision}\\ Pre=\\dfrac{\\text{number of samples correctly predicted as positive}}{\\text{number of predicted positives}}$$\n",
    "\n",
    "$$F_{\\beta} = \\frac{(1 + \\beta^{2}) \\cdot Pre \\cdot Rec}{\\beta^{2} \\cdot Pre + Rec}$$\n",
    "\n",
    "**Second**, setting `only_gross=False` returns `Rec-Pre-F1` for each individual class. The parameter `f_type` selects how F1 is computed:\n",
    "\n",
    "- **micro F1** (**pool the counts of all classes, then compute Rec-Pre-F1 once**)\n",
    "- **macro F1** (**compute Rec-Pre-F1 for each class, then take the arithmetic mean**)\n",
    "\n",
    "**Third**, both accept a **`tag_vocab` argument of type `fastNLP.Vocabulary` at initialization**, which **records the mapping between label indices and label names in the dataset**. A list of strings, `ignore_labels`, names the labels to exclude from the `Rec-Pre-F1` computation.\n",
    "\n",
    "The two differ in the kind of task they target. `ClassifyFPreRecMetric` targets simple classification, where every label is independent and labels do not form pairs. **`SpanFPreRecMetric` targets the more complex extraction task**, where **labels such as B-xx and I-xx, or B-xx and E-xx, form pairs**.\n",
    "\n",
    "When computing `Rec-Pre-F1`, `ClassifyFPreRecMetric` only checks whether each label itself is correct. `SpanFPreRecMetric`, by contrast, counts a prediction as correct **only if the tags are well-formed and the span they cover coincides with the gold span**.\n",
    "\n",
    "Consider again the `NER` task on `CoNLL-2003` from `tutorial-4`. If `ClassifyFPreRecMetric` or `Accuracy` is chosen as the metric there, the scores will look high. This is only because these criteria are too lenient for span extraction.\n",
    "\n",
    "Finally, the `POS` (part-of-speech tagging) task on `CoNLL-2003` briefly demonstrates how to use `ClassifyFPreRecMetric`:\n",
    "\n",
    "```python\n",
    "from fastNLP import Vocabulary\n",
    "from fastNLP import ClassifyFPreRecMetric\n",
    "\n",
    "tag_vocab = Vocabulary(padding=None, unknown=None)  # maps label indices to label names\n",
    "tag_vocab.add_word_lst(['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``',\n",
    "    'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS',\n",
    "    'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$',\n",
    "    'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG',\n",
    "    'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB'])  # pos_tags of CoNLL-2003\n",
    "ignore_labels = ['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``']\n",
    "\n",
    "FPreRec = ClassifyFPreRecMetric(tag_vocab=tag_vocab,\n",
    "                                ignore_labels=ignore_labels,  # excluded from the Rec-Pre-F1 computation\n",
    "                                only_gross=True,  # default True: report the aggregate result over all classes\n",
    "                                f_type='micro')  # default 'micro': pool the counts of all classes\n",
    "metrics = {'F1': FPreRec}\n",
    "```"
   ]
  },
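  {
   "cell_type": "markdown",
   "id": "8e3b6c4d",
   "metadata": {},
   "source": [
    "The formulas above, the micro/macro distinction and the span-level criterion can be checked by hand with a few lines of plain Python. This is an illustrative re-implementation on made-up labels, not the code of `ClassifyFPreRecMetric` or `SpanFPreRecMetric`. `f_beta`, `prf` and `bio_spans` are hypothetical helpers, and the BIO decoding follows one simple convention: a stray `I-` tag that does not continue an open span is ignored. In the first part, micro F1 pools all counts while macro F1 averages per-class scores. In the second part, one wrong token barely hurts token accuracy but breaks a whole entity span."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e3b6c4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative re-implementation on toy data; not fastNLP's metric code.\n",
    "from collections import Counter\n",
    "\n",
    "def f_beta(pre, rec, beta=1.0):\n",
    "    # F_beta = (1 + beta^2) * Pre * Rec / (beta^2 * Pre + Rec)\n",
    "    if pre == 0 and rec == 0:\n",
    "        return 0.0\n",
    "    return (1 + beta ** 2) * pre * rec / (beta ** 2 * pre + rec)\n",
    "\n",
    "def prf(tp, fp, fn, beta=1.0):\n",
    "    pre = tp / (tp + fp) if tp + fp else 0.0\n",
    "    rec = tp / (tp + fn) if tp + fn else 0.0\n",
    "    return pre, rec, f_beta(pre, rec, beta)\n",
    "\n",
    "# 1) micro vs. macro F1 for a 3-class classification problem\n",
    "gold = ['A', 'A', 'A', 'A', 'B', 'B', 'C', 'C']\n",
    "pred = ['A', 'A', 'A', 'B', 'B', 'C', 'C', 'A']\n",
    "tp, fp, fn = Counter(), Counter(), Counter()\n",
    "for g, p in zip(gold, pred):\n",
    "    if g == p:\n",
    "        tp[g] += 1\n",
    "    else:\n",
    "        fp[p] += 1  # wrongly predicted as p\n",
    "        fn[g] += 1  # gold label g was missed\n",
    "labels = sorted(set(gold) | set(pred))\n",
    "micro_f = prf(sum(tp.values()), sum(fp.values()), sum(fn.values()))[2]  # pool counts first\n",
    "macro_f = sum(prf(tp[c], fp[c], fn[c])[2] for c in labels) / len(labels)  # mean of per-class F1\n",
    "print('micro F1:', round(micro_f, 4), '| macro F1:', round(macro_f, 4))  # 0.625 | 0.5833\n",
    "\n",
    "# 2) token accuracy vs. span F1 for BIO tags\n",
    "def bio_spans(tags):\n",
    "    spans, start, kind = set(), None, None\n",
    "    for i, tag in enumerate(list(tags) + ['O']):  # trailing 'O' closes an open span\n",
    "        continues = tag.startswith('I-') and start is not None and tag[2:] == kind\n",
    "        if not continues:\n",
    "            if start is not None:\n",
    "                spans.add((kind, start, i))  # span covers positions [start, i)\n",
    "            start, kind = (i, tag[2:]) if tag.startswith('B-') else (None, None)\n",
    "    return spans\n",
    "\n",
    "gold_tags = ['B-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC', 'I-LOC']\n",
    "pred_tags = ['B-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC', 'O']\n",
    "token_acc = sum(g == p for g, p in zip(gold_tags, pred_tags)) / len(gold_tags)\n",
    "gold_spans, pred_spans = bio_spans(gold_tags), bio_spans(pred_tags)\n",
    "hits = len(gold_spans & pred_spans)  # spans with identical type and boundaries\n",
    "span_f = prf(hits, len(pred_spans) - hits, len(gold_spans) - hits)[2]\n",
    "print('token accuracy:', round(token_acc, 4), '| span F1:', round(span_f, 4))  # 0.8333 | 0.5"
   ]
  },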
  {
   "cell_type": "markdown",
   "id": "8a22f522",
   "metadata": {},
   "source": [
    "### 2.2 Custom metric types\n",
    "\n",
    "As mentioned above, `Metric` is the base class of all metrics, and `Accuracy` and the others are its subclasses. Likewise, a **custom metric type** must **inherit from `Metric`** and **define its own `__init__`, `update` and `get_metric` functions**.\n",
    "\n",
    "In `__init__`, define the variables needed for evaluation. Here we reuse `total_num` and `right_num` from `Accuracy`.\n",
    "\n",
    "In `update`, define how these variables are updated. As noted in `tutorial-0`, **the parameter names of `update`** **must match the output names of the model's `evaluate_step`, or the corresponding field names in the dataset**. This requirement is called **parameter matching**.\n",
    "\n",
    "In `fastNLP v1.0`, the default parameters of `update` are `pred`, the predictions, and `target`, the ground truth. We keep these names here because we will next use a predefined `fastNLP` model whose parameters follow exactly this convention.\n",
    "\n",
    "In `get_metric`, define how the final metric is computed. Here we compute accuracy directly. The function must return a dictionary, whose key `'prefix'` names the result and is shown in the `trainer`'s `progress bar`.\n",
    "\n",
    "Following these requirements, we define a simple metric named `MyMetric` for classification and use it in a worked example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "08a872e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')  # add the parent directory to the import path\n",
    "\n",
    "from fastNLP import Metric\n",
    "\n",
    "class MyMetric(Metric):\n",
    "\n",
    "    def __init__(self):\n",
    "        Metric.__init__(self)\n",
    "        self.total_num = 0  # number of evaluated samples\n",
    "        self.right_num = 0  # number of correctly predicted samples\n",
    "\n",
    "    def update(self, pred, target):\n",
    "        # called once per batch with 1-D tensors of predicted and gold class labels\n",
    "        self.total_num += target.size(0)\n",
    "        self.right_num += target.eq(pred).sum().item()\n",
    "\n",
    "    def get_metric(self, reset=True):\n",
    "        acc = self.right_num / self.total_num\n",
    "        if reset:  # clear the counters for the next evaluation round\n",
    "            self.total_num = 0\n",
    "            self.right_num = 0\n",
    "        return {'prefix': acc}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0155f447",
   "metadata": {},
   "source": [
    "&emsp; For the data, we again use the `load_dataset` function from the `datasets` module to load the binary classification dataset `SST-2`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5ad81ac7",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ef923b90b19847f4916cccda5d33fc36",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       " 0%| | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "sst2data = load_dataset('glue', 'sst2')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9d81760",
   "metadata": {},
   "source": [
    "&emsp; Note that during preprocessing we would normally rename the field holding the prediction target to `target`, to match the parameter names of the `metric` and the `model`. This step is deliberately skipped here; later sections explain why it matters and how to make up for it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cfb28b1b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Processing: 0%| | 0/6000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from fastNLP import DataSet\n",
    "\n",
    "# convert the SST-2 training split to a fastNLP DataSet and keep the first 6000 samples\n",
    "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n",
    "\n",
    "# lowercase each sentence and split it into words; the 'label' field is left unchanged\n",
    "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split()}, progress_bar=\"tqdm\")\n",
    "dataset.delete_field('sentence')\n",
    "dataset.delete_field('idx')\n",
    "\n",
    "from fastNLP import Vocabulary\n",
    "\n",
    "# build the vocabulary and replace words by their indices\n",
    "vocab = Vocabulary()\n",
    "vocab.from_dataset(dataset, field_name='words')\n",
    "vocab.index_dataset(dataset, field_name='words')\n",
    "\n",
    "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n",
    "\n",
    "from fastNLP import prepare_torch_dataloader\n",
    "\n",
    "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
    "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af3f8c63",
   "metadata": {},
   "source": [
    "&emsp; For the model, we again use the predefined `CNNText` model introduced in `tutorial-4` to perform `SST-2` binary classification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2fd210c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP.models.torch import CNNText\n",
    "\n",
    "# embedding of shape (vocabulary size, 100), two output classes\n",
    "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
    "\n",
    "from torch.optim import AdamW\n",
    "\n",
    "optimizers = AdamW(params=model.parameters(), lr=5e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e723b87",
   "metadata": {},
   "source": [
    "## 3. More on the trainer in fastNLP\n",
    "\n",
    "### 3.1 Internal structure of the trainer\n",
    "\n",
    "`tutorial-0` introduced the basic use of the `trainer`, and `tutorial-1` through `tutorial-4` showed many examples of it.\n",
    "\n",
    "&emsp; The table below gives a fairly complete list of the attributes and initialization parameters of the `trainer` module (names in bold are required parameters).\n",
    "\n",
    "\n",
    "|Name|Init parameter|Attribute|Purpose|Details|\n",
    "|----|----|----|----|----|\n",
    "| **model** | √ | √ | Model controlled by the `trainer` | Framework-dependent, e.g. `torch.nn.Module` |\n",
    "| `device` | √ | | Device(s) the `trainer` runs on | e.g. `'cpu'`, `'cuda'`, `0`, `[0, 1]` |\n",
    "| | | √ | Records the device(s) the `trainer` runs on | Of type `Device`, created during initialization |\n",
    "| **driver** | √ | | Framework the `trainer` drives | One of `'torch'`, `'paddle'`, `'jittor'` |\n",
    "| | | √ | Records the driver of the `trainer` | Of type `Driver`, created during initialization |\n",
    "| `n_epochs` | √ | - | Number of training epochs | Default `20`, stored in `driver.n_epochs` |\n",
    "| **optimizers** | √ | √ | Optimizer(s) used by the `trainer` | Framework-dependent, e.g. `torch.optim.Adam` |\n",
    "| `metrics` | √ | √ | Evaluation metrics used by the `trainer` | A dict, e.g. `{'acc': Metric()}` |\n",
    "| `evaluator` | | √ | Built-in evaluation module of the `trainer` | Of type `Evaluator`, created during initialization |\n",
    "| `input_mapping` | √ | √ | Fixes parameter mismatches in the `dataloader` output | A function whose output dict matches the inputs of `forward` |\n",
    "| `output_mapping` | √ | √ | Fixes parameter mismatches in the `forward` output | A function whose output dict matches the inputs of `xx_step` |\n",
    "| **train_dataloader** | √ | √ | Training data of the `trainer` | A `DataLoader`; how it is built depends on the framework |\n",
    "| `evaluate_dataloaders` | √ | √ | Evaluation data of the `trainer` | A `DataLoader`; how it is built depends on the framework |\n",
    "| `train_fn` | √ | √ | How the `trainer` obtains the loss of a batch | A function, default `model.train_step` |\n",
    "| `evaluate_fn` | √ | √ | How the `trainer` obtains the evaluation outputs of a batch | A function, default `model.evaluate_step` |\n",
    "| `batch_step_fn` | √ | √ | How the `trainer` runs one batch forward during training | A function, default `TrainBatchLoop.batch_step_fn` |\n",
    "| `evaluate_batch_step_fn` | √ | √ | How the `trainer` runs one batch forward during evaluation | A function, default `EvaluateBatchLoop.batch_step_fn` |\n",
    "| `accumulation_steps` | √ | √ | Gradient accumulation: number of batches between parameter updates | Default `1`, i.e. parameters are updated after every batch |\n",
    "| `evaluate_every` | √ | √ | How often the `evaluator` runs | Default `-1`, i.e. once per epoch; `1` means once per batch |\n",
    "| `progress_bar` | √ | √ | Progress bar style for training and evaluation | One of `'auto'`, `'tqdm'`, `'raw'`, `'rich'` |\n",
    "| `callbacks` | √ | | Functions triggered during training | A list of `Callback`, see `tutorial-7` |\n",
    "| `callback_manager` | | √ | Records and manages the `callbacks` | Of type `CallbackManager`, see `tutorial-7` |\n",
    "| `monitor` | √ | √ | Result watched by some `callbacks` | A string or function, see `tutorial-7` |\n",
    "| `marker` | √ | √ | Labels the `trainer` instance for use by `callbacks` | A string, see `tutorial-7` |\n",
    "| `trainer_state` | | √ | Records the `trainer` state for use by `callbacks` | Of type `TrainerState`, see `tutorial-7` |\n",
    "| `state` | | √ | Records the `trainer` state for use by `callbacks` | Of type `State`, see `tutorial-7` |\n",
    "| `fp16` | √ | √ | Whether the `trainer` uses mixed-precision training | A bool, default `False` |\n",
    "\n",
    "Here, **input_mapping and output_mapping** take the following form: they receive data as a dict and adjust its format to satisfy parameter matching. This resolves the issue left open earlier, when the label field was not renamed during preprocessing. **In short, parameter matching is mandatory**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "de96c1d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def input_mapping(data):\n",
    "    # expose the 'label' field under the name 'target' expected by CNNText and MyMetric\n",
    "    data['target'] = data['label']\n",
    "    return data"
   ]
  },
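  {
   "cell_type": "markdown",
   "id": "9f4c7d5e",
   "metadata": {},
   "source": [
    "To see what such a mapping does, the sketch below applies the same renaming idea to a plain dictionary. It also shows a hypothetical `output_mapping` that turns raw scores into a `pred` entry. Real batches hold tensors rather than lists. The names `rename_label_to_target` and `logits_to_pred` and the `'logits'` key are illustrative assumptions, not part of fastNLP or `CNNText`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f4c7d5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain-Python illustration; real batches contain tensors.\n",
    "def rename_label_to_target(batch):\n",
    "    # same idea as input_mapping above: make 'label' available as 'target'\n",
    "    batch['target'] = batch['label']\n",
    "    return batch\n",
    "\n",
    "def logits_to_pred(outputs):\n",
    "    # hypothetical output_mapping: index of the highest score for each sample\n",
    "    outputs['pred'] = [max(range(len(row)), key=row.__getitem__) for row in outputs['logits']]\n",
    "    return outputs\n",
    "\n",
    "toy_batch = {'words': [[4, 8, 15], [16, 23, 0]], 'label': [1, 0]}\n",
    "print(sorted(rename_label_to_target(toy_batch)))  # ['label', 'target', 'words']\n",
    "print(logits_to_pred({'logits': [[0.1, 0.9], [0.7, 0.3]]})['pred'])  # [1, 0]"
   ]
  },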
  {
   "cell_type": "markdown",
   "id": "2fc8b9f3",
   "metadata": {},
   "source": [
    "&emsp; The basic methods of the `trainer` module are listed below. For advanced usage such as the `on` family of functions and `callback` control, see `tutorial-7`.\n",
    "\n",
    "|Name|Purpose|Main parameters|\n",
    "|----|----|----|\n",
    "| `run` | Runs training and evaluation of the model in the `trainer` | See below |\n",
    "| `train_step` | Forward pass of one batch during training | Input `batch` |\n",
    "| `backward` | Backward pass of one loss during training | Input `output` |\n",
    "| `zero_grad` | Zeroes the gradients held by the `optimizers` | No input |\n",
    "| `step` | Updates the parameters through the `optimizers` | No input |\n",
    "| `epoch_evaluate` | Per-epoch evaluation during training; whether it actually runs depends on the evaluation frequency | No input |\n",
    "| `step_evaluate` | Per-batch evaluation during training; whether it actually runs depends on the evaluation frequency | No input |\n",
    "| `save_model` | Saves the model parameters/state dict of the `trainer` to `fastnlp_model.pkl.tar` | `folder`: target path; `only_state_dict`: whether to save only the state dict, default `False` |\n",
    "| `load_model` | Loads the model parameters/state dict of the `trainer` from `fastnlp_model.pkl.tar` | `folder`: source path; `only_state_dict`: whether to load only the state dict, default `True` |\n",
    "| `save_checkpoint` | Saves the model parameters/state dict together with the state of the `callback`s, `sampler` and `optimizer` to `fastnlp_model/checkpoint.pkl.tar` | `folder`: target path; `only_state_dict`: whether to save only the state dict, default `True` |\n",
    "| `load_checkpoint` | Loads the model parameters/state dict together with the state of the `callback`s, `sampler` and `optimizer` from `fastnlp_model/checkpoint.pkl.tar` | `folder`: source path; `only_state_dict`: whether to load only the state dict, default `True`; `resume_training`: whether to resume exactly from the last trained batch, default `True` |\n",
    "| `add_callback_fn` | Adds a `callback` function after the `trainer` has been initialized | `event`: when the callback fires; `fn`: the callback function |\n",
    "| `on` | Function decorator that turns a function into a `callback` | See `tutorial-7` |\n",
    "\n",
    "<!-- ```python\n",
    "Trainer.__init__():\n",
    "\ton_after_trainer_initialized(trainer, driver)\n",
    "Trainer.run():\n",
    "\tif num_eval_sanity_batch > 0: # if num_eval_sanity_batch is set\n",
    "\t\ton_sanity_check_begin(trainer)\n",
    "\t\ton_sanity_check_end(trainer, sanity_check_res)\n",
    "\ttry:\n",
    "\t\ton_train_begin(trainer)\n",
    "\t\twhile cur_epoch_idx < n_epochs:\n",
    "\t\t\ton_train_epoch_begin(trainer)\n",
    "\t\t\twhile batch_idx_in_epoch<=num_batches_per_epoch:\n",
    "\t\t\t\ton_fetch_data_begin(trainer)\n",
    "\t\t\t\tbatch = next(dataloader)\n",
    "\t\t\t\ton_fetch_data_end(trainer)\n",
    "\t\t\t\ton_train_batch_begin(trainer, batch, indices)\n",
    "\t\t\t\ton_before_backward(trainer, outputs) # outputs have already passed through output_mapping\n",
    "\t\t\t\ton_after_backward(trainer)\n",
    "\t\t\t\ton_before_zero_grad(trainer, optimizers) # actual calls depend on accumulation_steps\n",
    "\t\t\t\ton_after_zero_grad(trainer, optimizers) # actual calls depend on accumulation_steps\n",
    "\t\t\t\ton_before_optimizers_step(trainer, optimizers) # actual calls depend on accumulation_steps\n",
    "\t\t\t\ton_after_optimizers_step(trainer, optimizers) # actual calls depend on accumulation_steps\n",
    "\t\t\t\ton_train_batch_end(trainer)\n",
    "\t\t\ton_train_epoch_end(trainer)\n",
    "\texcept BaseException:\n",
    "\t\ton_exception(trainer, exception)\n",
    "\tfinally:\n",
    "\t\ton_train_end(trainer)\n",
    "``` -->"
   ]
  },
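  {
   "cell_type": "markdown",
   "id": "a05d8e6f",
   "metadata": {},
   "source": [
    "The `accumulation_steps` parameter from the attribute table above decides when the `zero_grad` and `step` methods listed here actually take effect. The hypothetical helper below simulates a common gradient-accumulation schedule without any framework: backward runs on every batch, and the parameter update and gradient reset run once every `accumulation_steps` batches. It only illustrates the idea; fastNLP's internal call order may differ in detail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a05d8e70",
   "metadata": {},
   "outputs": [],
   "source": [
    "def accumulation_schedule(num_batches, accumulation_steps):\n",
    "    # hypothetical helper: (batch index, action) pairs for one epoch\n",
    "    events = []\n",
    "    for batch_idx in range(1, num_batches + 1):\n",
    "        events.append((batch_idx, 'backward'))  # gradients are added up on every batch\n",
    "        if batch_idx % accumulation_steps == 0:\n",
    "            events.append((batch_idx, 'step'))       # update with the accumulated gradients\n",
    "            events.append((batch_idx, 'zero_grad'))  # start a new accumulation window\n",
    "    return events\n",
    "\n",
    "toy_schedule = accumulation_schedule(num_batches=6, accumulation_steps=3)\n",
    "print([idx for idx, action in toy_schedule if action == 'step'])  # [3, 6]"
   ]
  },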
  {
   "cell_type": "markdown",
   "id": "1e21df35",
   "metadata": {},
   "source": [
    "Next, we initialize the `trainer` and continue with `SST-2` classification. In the key-value pair passed to `metrics`, the key `'suffix'` is joined with the key `'prefix'` returned by `MyMetric.get_metric` and shown in the `progress bar`. The full output therefore has the form `{'prefix#suffix': float}`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "926a9c50",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP import Trainer\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    driver='torch',\n",
    "    device=0,  # first GPU; 'cuda' also works\n",
    "    n_epochs=10,\n",
    "    optimizers=optimizers,\n",
    "    input_mapping=input_mapping,  # adds 'target' as a copy of 'label' in every batch\n",
    "    train_dataloader=train_dataloader,\n",
    "    evaluate_dataloaders=evaluate_dataloader,\n",
    "    metrics={'suffix': MyMetric()}\n",
    ")"
   ]
  },
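  {
   "cell_type": "markdown",
   "id": "b16e9f70",
   "metadata": {},
   "source": [
    "The `prefix#suffix` naming described before the cell above can be mimicked in a few lines. `display_names` is a hypothetical helper that only reproduces the displayed format; it is not fastNLP's code. Each key returned by `get_metric` is joined by `#` with the key given in `metrics`, which is also why the built-in `Accuracy` appears as `acc#xx`, `total#xx` and `correct#xx`. The values below are made up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b16e9f71",
   "metadata": {},
   "outputs": [],
   "source": [
    "def display_names(results_per_metric):\n",
    "    # results_per_metric: {key in Trainer(metrics=...): dict returned by get_metric}\n",
    "    flat = {}\n",
    "    for metric_key, result in results_per_metric.items():\n",
    "        for result_key, value in result.items():\n",
    "            flat[f'{result_key}#{metric_key}'] = value\n",
    "    return flat\n",
    "\n",
    "print(display_names({'suffix': {'prefix': 0.85}}))  # {'prefix#suffix': 0.85}\n",
    "print(display_names({'acc': {'acc': 0.9, 'total': 100, 'correct': 90}}))"
   ]
  },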
  {
   "cell_type": "markdown",
   "id": "b1b2e8b7",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "Finally, we use the `run` function. Its parameters are listed in the table below, which also explains the meaning of `num_eval_batch_per_dl=10`.\n",
    "\n",
    "|Name|Purpose|Default|\n",
    "|----|----|----|\n",
    "| `num_train_batch_per_epoch` | Number of batches trained in each epoch | Integer, default `-1`: train on all batches in each epoch |\n",
    "| `num_eval_batch_per_dl` | Number of batches evaluated from each evaluation `dataloader` | Integer, default `-1`: evaluate all batches |\n",
    "| `num_eval_sanity_batch` | Number of batches evaluated as a sanity check before training starts | Integer, default `2`: evaluate two batches before training |\n",
    "| `resume_from` | Folder from which the `trainer` restores its state | String, default `None`; see `CheckpointCallback` for usage |\n",
    "| `resume_training` | How much state the `trainer` restores | Bool, default `True` restores all state; `False` restores only the `model` and `optimizers` states |"
   ]
  },
  646. {
  647. "cell_type": "code",
  648. "execution_count": 7,
  649. "id": "43be274f",
  650. "metadata": {
  651. "pycharm": {
  652. "name": "#%%\n"
  653. }
  654. },
  655. "outputs": [
  656. {
  657. "data": {
  658. "text/html": [
  659. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[09:30:35] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#596\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">596</span></a>\n",
  660. "</pre>\n"
  661. ],
  662. "text/plain": [
  663. "\u001b[2;36m[09:30:35]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=954293;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=366534;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n"
  664. ]
  665. },
  666. "metadata": {},
  667. "output_type": "display_data"
  668. },
  669. {
  670. "data": {
  671. "application/vnd.jupyter.widget-view+json": {
  672. "model_id": "",
  673. "version_major": 2,
  674. "version_minor": 0
  675. },
  676. "text/plain": [
  677. "Output()"
  678. ]
  679. },
  680. "metadata": {},
  681. "output_type": "display_data"
  682. },
  721. {
  722. "data": {
  723. "text/html": [
  724. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  725. ],
  726. "text/plain": []
  727. },
  728. "metadata": {},
  729. "output_type": "display_data"
  730. },
  731. {
  732. "data": {
  733. "application/vnd.jupyter.widget-view+json": {
  734. "model_id": "",
  735. "version_major": 2,
  736. "version_minor": 0
  737. },
  738. "text/plain": [
  739. "Output()"
  740. ]
  741. },
  742. "metadata": {},
  743. "output_type": "display_data"
  744. },
  745. {
  746. "data": {
  747. "text/html": [
  748. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  749. "</pre>\n"
  750. ],
  751. "text/plain": [
  752. "\n"
  753. ]
  754. },
  755. "metadata": {},
  756. "output_type": "display_data"
  757. },
  758. {
  759. "data": {
  760. "text/html": [
  761. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  762. "</pre>\n"
  763. ],
  764. "text/plain": [
  765. "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  766. ]
  767. },
  768. "metadata": {},
  769. "output_type": "display_data"
  770. },
  771. {
  772. "data": {
  773. "text/html": [
  774. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  775. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6875</span>\n",
  776. "<span style=\"font-weight: bold\">}</span>\n",
  777. "</pre>\n"
  778. ],
  779. "text/plain": [
  780. "\u001b[1m{\u001b[0m\n",
  781. " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.6875\u001b[0m\n",
  782. "\u001b[1m}\u001b[0m\n"
  783. ]
  784. },
  785. "metadata": {},
  786. "output_type": "display_data"
  787. },
  788. {
  789. "data": {
  790. "text/html": [
  791. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  792. "</pre>\n"
  793. ],
  794. "text/plain": [
  795. "\n"
  796. ]
  797. },
  798. "metadata": {},
  799. "output_type": "display_data"
  800. },
  801. {
  802. "data": {
  803. "text/html": [
  804. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  805. "</pre>\n"
  806. ],
  807. "text/plain": [
  808. "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  809. ]
  810. },
  811. "metadata": {},
  812. "output_type": "display_data"
  813. },
  814. {
  815. "data": {
  816. "text/html": [
  817. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  818. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8125</span>\n",
  819. "<span style=\"font-weight: bold\">}</span>\n",
  820. "</pre>\n"
  821. ],
  822. "text/plain": [
  823. "\u001b[1m{\u001b[0m\n",
  824. " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n",
  825. "\u001b[1m}\u001b[0m\n"
  826. ]
  827. },
  828. "metadata": {},
  829. "output_type": "display_data"
  830. },
  831. {
  832. "data": {
  833. "text/html": [
  834. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  835. "</pre>\n"
  836. ],
  837. "text/plain": [
  838. "\n"
  839. ]
  840. },
  841. "metadata": {},
  842. "output_type": "display_data"
  843. },
  844. {
  845. "data": {
  846. "text/html": [
  847. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  848. "</pre>\n"
  849. ],
  850. "text/plain": [
  851. "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  852. ]
  853. },
  854. "metadata": {},
  855. "output_type": "display_data"
  856. },
  857. {
  858. "data": {
  859. "text/html": [
  860. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  861. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n",
  862. "<span style=\"font-weight: bold\">}</span>\n",
  863. "</pre>\n"
  864. ],
  865. "text/plain": [
  866. "\u001b[1m{\u001b[0m\n",
  867. " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
  868. "\u001b[1m}\u001b[0m\n"
  869. ]
  870. },
  871. "metadata": {},
  872. "output_type": "display_data"
  873. },
  874. {
  875. "data": {
  876. "text/html": [
  877. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  878. "</pre>\n"
  879. ],
  880. "text/plain": [
  881. "\n"
  882. ]
  883. },
  884. "metadata": {},
  885. "output_type": "display_data"
  886. },
  887. {
  888. "data": {
  889. "text/html": [
  890. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  891. "</pre>\n"
  892. ],
  893. "text/plain": [
  894. "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  895. ]
  896. },
  897. "metadata": {},
  898. "output_type": "display_data"
  899. },
  900. {
  901. "data": {
  902. "text/html": [
  903. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  904. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.825</span>\n",
  905. "<span style=\"font-weight: bold\">}</span>\n",
  906. "</pre>\n"
  907. ],
  908. "text/plain": [
  909. "\u001b[1m{\u001b[0m\n",
  910. " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.825\u001b[0m\n",
  911. "\u001b[1m}\u001b[0m\n"
  912. ]
  913. },
  914. "metadata": {},
  915. "output_type": "display_data"
  916. },
  917. {
  918. "data": {
  919. "text/html": [
  920. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  921. "</pre>\n"
  922. ],
  923. "text/plain": [
  924. "\n"
  925. ]
  926. },
  927. "metadata": {},
  928. "output_type": "display_data"
  929. },
  930. {
  931. "data": {
  932. "text/html": [
  933. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  934. "</pre>\n"
  935. ],
  936. "text/plain": [
  937. "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  938. ]
  939. },
  940. "metadata": {},
  941. "output_type": "display_data"
  942. },
  943. {
  944. "data": {
  945. "text/html": [
  946. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  947. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8125</span>\n",
  948. "<span style=\"font-weight: bold\">}</span>\n",
  949. "</pre>\n"
  950. ],
  951. "text/plain": [
  952. "\u001b[1m{\u001b[0m\n",
  953. " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n",
  954. "\u001b[1m}\u001b[0m\n"
  955. ]
  956. },
  957. "metadata": {},
  958. "output_type": "display_data"
  959. },
  960. {
  961. "data": {
  962. "text/html": [
  963. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  964. "</pre>\n"
  965. ],
  966. "text/plain": [
  967. "\n"
  968. ]
  969. },
  970. "metadata": {},
  971. "output_type": "display_data"
  972. },
  973. {
  974. "data": {
  975. "text/html": [
  976. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  977. "</pre>\n"
  978. ],
  979. "text/plain": [
  980. "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  981. ]
  982. },
  983. "metadata": {},
  984. "output_type": "display_data"
  985. },
  986. {
  987. "data": {
  988. "text/html": [
  989. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  990. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n",
  991. "<span style=\"font-weight: bold\">}</span>\n",
  992. "</pre>\n"
  993. ],
  994. "text/plain": [
  995. "\u001b[1m{\u001b[0m\n",
  996. " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
  997. "\u001b[1m}\u001b[0m\n"
  998. ]
  999. },
  1000. "metadata": {},
  1001. "output_type": "display_data"
  1002. },
  1003. {
  1004. "data": {
  1005. "text/html": [
  1006. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1007. "</pre>\n"
  1008. ],
  1009. "text/plain": [
  1010. "\n"
  1011. ]
  1012. },
  1013. "metadata": {},
  1014. "output_type": "display_data"
  1015. },
  1016. {
  1017. "data": {
  1018. "text/html": [
  1019. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1020. "</pre>\n"
  1021. ],
  1022. "text/plain": [
  1023. "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1024. ]
  1025. },
  1026. "metadata": {},
  1027. "output_type": "display_data"
  1028. },
  1029. {
  1030. "data": {
  1031. "text/html": [
  1032. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1033. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n",
  1034. "<span style=\"font-weight: bold\">}</span>\n",
  1035. "</pre>\n"
  1036. ],
  1037. "text/plain": [
  1038. "\u001b[1m{\u001b[0m\n",
  1039. " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
  1040. "\u001b[1m}\u001b[0m\n"
  1041. ]
  1042. },
  1043. "metadata": {},
  1044. "output_type": "display_data"
  1045. },
  1046. {
  1047. "data": {
  1048. "text/html": [
  1049. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1050. "</pre>\n"
  1051. ],
  1052. "text/plain": [
  1053. "\n"
  1054. ]
  1055. },
  1056. "metadata": {},
  1057. "output_type": "display_data"
  1058. },
  1059. {
  1060. "data": {
  1061. "text/html": [
  1062. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1063. "</pre>\n"
  1064. ],
  1065. "text/plain": [
  1066. "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1067. ]
  1068. },
  1069. "metadata": {},
  1070. "output_type": "display_data"
  1071. },
  1072. {
  1073. "data": {
  1074. "text/html": [
  1075. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1076. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8</span>\n",
  1077. "<span style=\"font-weight: bold\">}</span>\n",
  1078. "</pre>\n"
  1079. ],
  1080. "text/plain": [
  1081. "\u001b[1m{\u001b[0m\n",
  1082. " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8\u001b[0m\n",
  1083. "\u001b[1m}\u001b[0m\n"
  1084. ]
  1085. },
  1086. "metadata": {},
  1087. "output_type": "display_data"
  1088. },
  1089. {
  1090. "data": {
  1091. "text/html": [
  1092. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1093. "</pre>\n"
  1094. ],
  1095. "text/plain": [
  1096. "\n"
  1097. ]
  1098. },
  1099. "metadata": {},
  1100. "output_type": "display_data"
  1101. },
  1102. {
  1103. "data": {
  1104. "text/html": [
  1105. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1106. "</pre>\n"
  1107. ],
  1108. "text/plain": [
  1109. "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1110. ]
  1111. },
  1112. "metadata": {},
  1113. "output_type": "display_data"
  1114. },
  1115. {
  1116. "data": {
  1117. "text/html": [
  1118. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1119. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n",
  1120. "<span style=\"font-weight: bold\">}</span>\n",
  1121. "</pre>\n"
  1122. ],
  1123. "text/plain": [
  1124. "\u001b[1m{\u001b[0m\n",
  1125. " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
  1126. "\u001b[1m}\u001b[0m\n"
  1127. ]
  1128. },
  1129. "metadata": {},
  1130. "output_type": "display_data"
  1131. },
  1132. {
  1133. "data": {
  1134. "text/html": [
  1135. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1136. "</pre>\n"
  1137. ],
  1138. "text/plain": [
  1139. "\n"
  1140. ]
  1141. },
  1142. "metadata": {},
  1143. "output_type": "display_data"
  1144. },
  1145. {
  1146. "data": {
  1147. "text/html": [
  1148. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1149. "</pre>\n"
  1150. ],
  1151. "text/plain": [
  1152. "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1153. ]
  1154. },
  1155. "metadata": {},
  1156. "output_type": "display_data"
  1157. },
  1158. {
  1159. "data": {
  1160. "text/html": [
  1161. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1162. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n",
  1163. "<span style=\"font-weight: bold\">}</span>\n",
  1164. "</pre>\n"
  1165. ],
  1166. "text/plain": [
  1167. "\u001b[1m{\u001b[0m\n",
  1168. " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
  1169. "\u001b[1m}\u001b[0m\n"
  1170. ]
  1171. },
  1172. "metadata": {},
  1173. "output_type": "display_data"
  1174. },
  1175. {
  1176. "data": {
  1177. "text/html": [
  1178. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  1179. ],
  1180. "text/plain": []
  1181. },
  1182. "metadata": {},
  1183. "output_type": "display_data"
  1184. },
  1185. {
  1186. "data": {
  1187. "text/html": [
  1188. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1189. "</pre>\n"
  1190. ],
  1191. "text/plain": [
  1192. "\n"
  1193. ]
  1194. },
  1195. "metadata": {},
  1196. "output_type": "display_data"
  1197. }
  1198. ],
  1199. "source": [
  1200. "trainer.run(num_eval_batch_per_dl=10)"
  1201. ]
  1202. }
  1211. ],
  1212. "metadata": {
  1213. "kernelspec": {
  1214. "display_name": "Python 3 (ipykernel)",
  1215. "language": "python",
  1216. "name": "python3"
  1217. },
  1218. "language_info": {
  1219. "codemirror_mode": {
  1220. "name": "ipython",
  1221. "version": 3
  1222. }
  1232. "cell_type": "raw",
  1233. "metadata": {
  1234. "collapsed": false
  1235. },
  1236. "source": []
  1237. }
  1238. }
  1239. },
  1240. "nbformat": 4,
  1241. "nbformat_minor": 5
  1242. }