{ "cells": [ { "cell_type": "markdown", "id": "fdd7ff16", "metadata": {}, "source": [ "# T5. trainer 和 evaluator 的深入介绍\n", "\n", "  1   fastNLP 中 driver 的补充介绍\n", " \n", "    1.1   trainer 和 driver 的构想 \n", "\n", "    1.2   device 与 多卡训练\n", "\n", "  2   fastNLP 中的更多 metric 类型\n", "\n", "    2.1   预定义的 metric 类型\n", "\n", "    2.2   自定义的 metric 类型\n", "\n", "  3   fastNLP 中 trainer 的补充介绍\n", "\n", "    3.1   trainer 的内部结构" ] }, { "cell_type": "markdown", "id": "08752c5a", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 1. fastNLP 中 driver 的补充介绍\n", "\n", "### 1.1 trainer 和 driver 的构想\n", "\n", "在`fastNLP 1.0`中,模型训练最关键的模块便是**训练模块 trainer 、评测模块 evaluator 、驱动模块 driver**,\n", "\n", "  在`tutorial 0`中,已经简单介绍过上述三个模块:**driver 用来控制训练评测中的 model 的最终运行**\n", "\n", "    **evaluator 封装评测的 metric**,**trainer 封装训练的 optimizer**,**也可以包括 evaluator**\n", "\n", "之所以做出上述的划分,其根本目的在于要**达成对于多个 python 学习框架**,**例如 pytorch 、 paddle 、 jittor 的兼容**\n", "\n", "  对于训练环节,其伪代码如下方左边紫色一栏所示,由于**不同框架对模型、损失、张量的定义各有不同**,所以将训练环节\n", "\n", "    划分为**框架无关的循环控制、批量分发部分**,**由 trainer 模块负责**实现,对应的伪代码如下方中间一栏所示\n", "\n", "    以及**随框架不同的模型调用、数值优化部分**,**由 driver 模块负责**实现,对应的伪代码如下方右边一栏所示\n", "\n", "|训练过程|框架无关 对应`Trainer`|框架相关 对应`Driver`\n", "|----|----|----|\n", "| try: | try: | |\n", "| for epoch in 1:n_eoochs: | for epoch in 1:n_eoochs: | |\n", "| for step in 1:total_steps: | for step in 1:total_steps: | |\n", "| batch = fetch_batch() | batch = fetch_batch() | |\n", "| loss = model.forward(batch)  | | loss = model.forward(batch)  |\n", "| loss.backward() | | loss.backward() |\n", "| model.clear_grad() | | model.clear_grad() |\n", "| model.update() | | model.update() |\n", "| if need_save: | if need_save: | |\n", "| model.save() | | model.save() |\n", "| except: | except: | |\n", "| process_exception() | process_exception() | |" ] }, { "cell_type": "markdown", "id": "3e55f07b", "metadata": {}, "source": [ "  对于评测环节,其伪代码如下方左边紫色一栏所示,同样由于不同框架对模型、损失、张量的定义各有不同,所以将评测环节\n", "\n", "    划分为**框架无关的循环控制、分发汇总部分**,**由 evaluator 模块负责**实现,对应的伪代码如下方中间一栏所示\n", "\n", "    以及**随框架不同的模型调用、评测计算部分**,同样**由 driver 模块负责**实现,对应的伪代码如下方右边一栏所示\n", "\n", "|评测过程|框架无关 对应`Evaluator`|框架相关 对应`Driver`\n", "|----|----|----|\n", "| try: | try: | |\n", "| model.set_eval() | model.set_eval() | |\n", "| for step in 1:total_steps: | for step in 1:total_steps: | |\n", "| batch = fetch_batch() | batch = fetch_batch() | |\n", "| outputs = model.evaluate(batch)  | | outputs = model.evaluate(batch)  |\n", "| metric.compute(batch, outputs) | | metric.compute(batch, outputs) |\n", "| results = metric.get_metric() | results = metric.get_metric() | |\n", "| except: | except: | |\n", "| process_exception() | process_exception() | |" ] }, { "cell_type": "markdown", "id": "94ba11c6", "metadata": { "pycharm": { "name": "#%%\n" } }, "source": [ "由此,从程序员的角度,`fastNLP v1.0` **通过一个 driver 让基于 pytorch 、 paddle 、 jittor 、 oneflow 框架的模型**\n", "\n", "    **都能在相同的 trainer 和 evaluator 上运行**,这也**是 fastNLP v1.0 相比于之前版本的一大亮点**\n", "\n", "  而从`driver`的角度,`fastNLP v1.0`通过定义一个`driver`基类,**将所有张量转化为 numpy.tensor**\n", "\n", "    并由此泛化出`torch_driver`、`paddle_driver`、`jittor_driver`三个子类,从而实现了\n", "\n", "    对`pytorch`、`paddle`、`jittor`的兼容,有关后两者的实践请参考接下来的`tutorial-6`" ] }, { "cell_type": "markdown", "id": "ab1cea7d", "metadata": {}, "source": [ "### 1.2 device 与 多卡训练\n", "\n", "**fastNLP v1.0 支持多卡训练**,实现方法则是**通过将 trainer 中的 device 设置为对应显卡的序号列表**\n", "\n", "  由单卡切换成多卡,无论是数据、模型还是评测都会面临一定的调整,`fastNLP v1.0`保证:\n", "\n", "    数据拆分时,不同卡之间相互协调,所有数据都可以被训练,且不会使用到相同的数据\n", "\n", "    模型训练时,模型之间需要交换梯度;评测计算时,每张卡先各自计算,再汇总结果\n", "\n", "  例如,在评测计算运行`get_metric`函数时,`fastNLP v1.0`将自动按照`self.right`和`self.total`\n", "\n", "    指定的 **aggregate_method 方法**,默认为`sum`,将每张卡上结果汇总起来,因此最终\n", "\n", "    在调用`get_metric`方法时,`Accuracy`类能够返回全部的统计结果,代码如下\n", " \n", "```python\n", "trainer = Trainer(\n", " model=model, # model 基于 pytorch 实现 \n", " train_dataloader=train_dataloader,\n", " optimizers=optimizer,\n", " ...\n", " driver='torch', # driver 使用 torch_driver \n", " device=[0, 1], # gpu 选择 cuda:0 + cuda:1\n", " ...\n", " evaluate_dataloaders=evaluate_dataloader,\n", " metrics={'acc': Accuracy()},\n", " ...\n", " )\n", "\n", "class Accuracy(Metric):\n", " def __init__(self):\n", " super().__init__()\n", " self.register_element(name='total', value=0, aggregate_method='sum')\n", " self.register_element(name='right', value=0, aggregate_method='sum')\n", "```\n" ] }, { "cell_type": "markdown", "id": "e2e0a210", "metadata": { "pycharm": { "name": "#%%\n" } }, "source": [ "注:`fastNLP v1.0`中要求`jupyter`不能多卡,仅能单卡,故在所有`tutorial`中均不作相关演示" ] }, { "cell_type": "markdown", "id": "8d19220c", "metadata": {}, "source": [ "## 2. fastNLP 中的更多 metric 类型\n", "\n", "### 2.1 预定义的 metric 类型\n", "\n", "在`fastNLP 1.0`中,除了前几篇`tutorial`中经常见到的**正确率 Accuracy**,还有其他**预定义的评测标准 metric**\n", "\n", "  包括**所有 metric 的基类 Metric**、适配`Transformers`中相关模型的正确率`TransformersAccuracy`\n", "\n", "    **适用于分类语境下的 F1 值 ClassifyFPreRecMetric**(其中也包括召回率`Pre`、精确率`Rec`\n", "\n", "    **适用于抽取语境下的 F1 值 SpanFPreRecMetric**;相关基本信息内容见下表,之后是详细分析\n", "\n", "代码名称|简要介绍|代码路径\n", "----|----|----|\n", " `Metric` | 定义`metrics`时继承的基类 | `/core/metrics/metric.py` |\n", " `Accuracy` | 正确率,最为常用 | `/core/metrics/accuracy.py` |\n", " `TransformersAccuracy` | 正确率,为了兼容`Transformers`中相关模型 | `/core/metrics/accuracy.py` |\n", " `ClassifyFPreRecMetric` | 召回率、精确率、F1值,适用于**分类问题** | `/core/metrics/classify_f1_pre_rec_metric.py` |\n", " `SpanFPreRecMetric` | 召回率、精确率、F1值,适用于**抽取问题** | `/core/metrics/span_f1_pre_rec_metric.py` |" ] }, { "cell_type": "markdown", "id": "fdc083a3", "metadata": { "pycharm": { "name": "#%%\n" } }, "source": [ "  如`tutorial-0`中所述,所有的`metric`都包含`get_metric`和`update`函数,其中\n", "\n", "    **update 函数更新单个 batch 的统计量**,**get_metric 函数返回最终结果**,并打印显示\n", "\n", "\n", "### 2.1.1 Accuracy 与 TransformersAccuracy\n", "\n", "`Accuracy`,正确率,预测正确的数据`right_num`在总数据`total_num`,中的占比(公式就不用列了\n", "\n", "  `get_metric`函数打印格式为 **{\"acc#xx\": float, 'total#xx': float, 'correct#xx': float}**\n", "\n", "  一般在初始化时不需要传参,`fastNLP`会根据`update`函数的传入参数确定对应后台框架`backend`\n", "\n", "  **update 函数的参数包括 pred 、 target 、 seq_len**,**后者用来标记批次中每笔数据的长度**\n", "\n", "`TransformersAccuracy`,继承自`Accuracy`,只是为了兼容`Transformers`框架中相关模型\n", "\n", "  在`update`函数中,将`Transformers`框架输出的`attention_mask`参数转化为`seq_len`参数\n", "\n", "\n", "### 2.1.2 ClassifyFPreRecMetric 与 SpanFPreRecMetric\n", "\n", "`ClassifyFPreRecMetric`,分类评价,`SpanFPreRecMetric`,抽取评价,后者在`tutorial-4`中已出现\n", "\n", "  两者的相同之处在于:**第一**,**都包括召回率/查全率 ec**、**精确率/查准率 Pre**、**F1 值**这三个指标\n", "\n", "    `get_metric`函数打印格式为 **{\"f#xx\": float, 'pre#xx': float, 'rec#xx': float}**\n", "\n", "    三者的计算公式如下,其中`beta`默认为`1`,即`F1`值是召回率`Rec`和精确率`Pre`的调和平均数\n", "\n", "$$\\text{召回率}\\ Rec=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有本来是正例的数量}}\\qquad \\text{精确率}\\ Pre=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有预测为正例的数量}}$$\n", "\n", "$$F_{beta} = \\frac{(1 + {beta}^{2})*(Pre*Rec)}{({beta}^{2}*Pre + Rec)}$$\n", "\n", "  **第二**,可以通过参数`only_gross`为`False`,要求返回所有类别的`Rec-Pre-F1`,同时`F1`值又根据参数`f_type`又分为\n", "\n", "    **micro F1**(**直接统计所有类别的 Rec-Pre-F1**)、**macro F1**(**统计各类别的 Rec-Pre-F1 再算术平均**)\n", "\n", "  **第三**,两者在初始化时还可以**传入基于 fastNLP.Vocabulary 的 tag_vocab 参数记录数据集中的标签序号**\n", "\n", "    **与标签名称之间的映射**,通过字符串列表`ignore_labels`参数,指定若干标签不用于`Rec-Pre-F1`的计算\n", "\n", "两者的不同之处在于:`ClassifyFPreRecMetric`针对简单的分类问题,每个分类标签之间彼此独立,不构成标签对\n", "\n", "    **SpanFPreRecMetric 针对更复杂的抽取问题**,**规定标签 B-xx 和 I-xx 或 B-xx 和 E-xx 构成标签对**\n", "\n", "  在计算`Rec-Pre-F1`时,`ClassifyFPreRecMetric`只需要考虑标签本身是否正确这就足够了,但是\n", "\n", "    对于`SpanFPreRecMetric`,需要保证**标签符合规则且覆盖的区间与正确结果重合才算正确**\n", "\n", "    因此回到`tutorial-4`中`CoNLL-2003`的`NER`任务,如果评测方法选择`ClassifyFPreRecMetric`\n", "\n", "      或者`Accuracy`,会发现虽然评测结果显示很高,这是因为选择的评测方法要求太低\n", "\n", "    最后通过`CoNLL-2003`的词性标注`POS`任务简单演示下`ClassifyFPreRecMetric`相关的使用\n", "\n", "```python\n", "from fastNLP import Vocabulary\n", "from fastNLP import ClassifyFPreRecMetric\n", "\n", "tag_vocab = Vocabulary(padding=None, unknown=None) # 记录序号与标签之间的映射\n", "tag_vocab.add_word_lst(['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', \n", " 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', \n", " 'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', \n", " 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', \n", " 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP+', 'WRB', ]) # CoNLL-2003 中的 pos_tags\n", "ignore_labels = ['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', ]\n", "\n", "FPreRec = ClassifyFPreRecMetric(tag_vocab=tag_vocab, \n", " ignore_labels=ignore_labels, # 表示评测/优化中不考虑上述标签的正误/损失\n", " only_gross=True, # 默认为 True 表示输出所有类别的综合统计结果\n", " f_type='micro') # 默认为 'micro' 表示统计所有类别的 Rec-Pre-F1\n", "metrics = {'F1': FPreRec}\n", "```" ] }, { "cell_type": "markdown", "id": "8a22f522", "metadata": {}, "source": [ "### 2.2 自定义的 metric 类型\n", "\n", "如上文所述,`Metric`作为所有`metric`的基类,`Accuracy`等都是其子类,同样地,对于**自定义的 metric 类型**\n", "\n", "    也**需要继承自 Metric 类**,同时**内部自定义好 __init__ 、 update 和 get_metric 函数**\n", "\n", "  在`__init__`函数中,根据需求定义评测时需要用到的变量,此处沿用`Accuracy`中的`total_num`和`right_num`\n", "\n", "  在`update`函数中,根据需求定义评测变量的更新方式,需要注意的是如`tutorial-0`中所述,**update`的参数名**\n", "\n", "    **需要待评估模型在 evaluate_step 中的输出名称一致**,由此**和数据集中对应字段名称一致**,即**参数匹配**\n", "\n", "    在`fastNLP v1.0`中,`update`函数的默认输入参数:`pred`,对应预测值;`target`,对应真实值\n", "\n", "    此处仍然沿用,因为接下来会需要使用`fastNLP`函数的与定义模型,其输入参数格式即使如此\n", "\n", "  在`get_metric`函数中,根据需求定义评测指标最终的计算,此处直接计算准确率,该函数必须返回一个字典\n", "\n", "    其中,字串`'prefix'`表示该`metric`的名称,会对应显示到`trainer`的`progress bar`中\n", "\n", "根据上述要求,这里简单定义了一个名为`MyMetric`的评测模块,用于分类问题的评测,以此展开一个实例展示" ] }, { "cell_type": "code", "execution_count": 1, "id": "08a872e9", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import sys\n", "sys.path.append('..')\n", "\n", "from fastNLP import Metric\n", "\n", "class MyMetric(Metric):\n", "\n", " def __init__(self):\n", " Metric.__init__(self)\n", " self.total_num = 0\n", " self.right_num = 0\n", "\n", " def update(self, pred, target):\n", " self.total_num += target.size(0)\n", " self.right_num += target.eq(pred).sum().item()\n", "\n", " def get_metric(self, reset=True):\n", " acc = self.right_num / self.total_num\n", " if reset:\n", " self.total_num = 0\n", " self.right_num = 0\n", " return {'prefix': acc}" ] }, { "cell_type": "markdown", "id": "0155f447", "metadata": {}, "source": [ "  数据使用方面,此处仍然使用`datasets`模块中的`load_dataset`函数,加载`SST-2`二分类数据集" ] }, { "cell_type": "code", "execution_count": 2, "id": "5ad81ac7", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ef923b90b19847f4916cccda5d33fc36", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00 0: # 如果设置了 num_eval_sanity_batch\n", "\t\ton_sanity_check_begin(trainer)\n", "\t\ton_sanity_check_end(trainer, sanity_check_res)\n", "\ttry:\n", "\t\ton_train_begin(trainer)\n", "\t\twhile cur_epoch_idx < n_epochs:\n", "\t\t\ton_train_epoch_begin(trainer)\n", "\t\t\twhile batch_idx_in_epoch<=num_batches_per_epoch:\n", "\t\t\t\ton_fetch_data_begin(trainer)\n", "\t\t\t\tbatch = next(dataloader)\n", "\t\t\t\ton_fetch_data_end(trainer)\n", "\t\t\t\ton_train_batch_begin(trainer, batch, indices)\n", "\t\t\t\ton_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping 后的\n", "\t\t\t\ton_after_backward(trainer)\n", "\t\t\t\ton_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n", "\t\t\t\ton_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n", "\t\t\t\ton_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n", "\t\t\t\ton_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n", "\t\t\t\ton_train_batch_end(trainer)\n", "\t\t\ton_train_epoch_end(trainer)\n", "\texcept BaseException:\n", "\t\tself.on_exception(trainer, exception)\n", "\tfinally:\n", "\t\ton_train_end(trainer)\n", "``` -->" ] }, { "cell_type": "markdown", "id": "1e21df35", "metadata": {}, "source": [ "紧接着,初始化`trainer`实例,继续完成`SST-2`分类,其中`metrics`输入的键值对,字串`'suffix'`和之前定义的\n", "\n", "  字串`'prefix'`将拼接在一起显示到`progress bar`中,故完整的输出形式为`{'prefix#suffix': float}`" ] }, { "cell_type": "code", "execution_count": 6, "id": "926a9c50", "metadata": {}, "outputs": [], "source": [ "from fastNLP import Trainer\n", "\n", "trainer = Trainer(\n", " model=model,\n", " driver='torch',\n", " device=0, # 'cuda'\n", " n_epochs=10,\n", " optimizers=optimizers,\n", " input_mapping=input_mapping,\n", " train_dataloader=train_dataloader,\n", " evaluate_dataloaders=evaluate_dataloader,\n", " metrics={'suffix': MyMetric()}\n", ")" ] }, { "cell_type": "markdown", "id": "b1b2e8b7", "metadata": { "pycharm": { "name": "#%%\n" } }, "source": [ "最后就是`run`函数的使用,关于其参数,这里也以表格形式列出,由此就解答了`num_eval_batch_per_dl=10`的含义\n", "\n", "|名称|功能|默认值|\n", "|----|----|----|\n", "| `num_train_batch_per_epoch` | 指定`trainer`训练时,每个循环计算批量数目 | 整数类型,默认`-1`,表示训练时,每个循环计算所有批量 |\n", "| `num_eval_batch_per_dl` | 指定`trainer`评测时,每个循环计算批量数目 | 整数类型,默认`-1`,表示评测时,每个循环计算所有批量 |\n", "| `num_eval_sanity_batch` | 指定`trainer`训练开始前,试探性评测批量数目 | 整数类型,默认`2`,表示训练开始前评估两个批量 |\n", "| `resume_from` | 指定`trainer`恢复状态的路径,需要是文件夹 | 字符串型,默认`None`,使用可参考`CheckpointCallback` |\n", "| `resume_training` | 指定`trainer`恢复状态的程度 | 布尔类型,默认`True`恢复所有状态,`False`仅恢复`model`和`optimizers`状态 |" ] }, { "cell_type": "code", "execution_count": 7, "id": "43be274f", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/html": [ "
[09:30:35] INFO     Running evaluator sanity check for 2 batches.              trainer.py:596\n",
       "
\n" ], "text/plain": [ "\u001b[2;36m[09:30:35]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=954293;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=366534;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
       "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
       ".get_parent()\n",
       "  if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
       "
\n" ], "text/plain": [ "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", ".get_parent()\n", " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
       "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
       ".get_parent()\n",
       "  self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
       "
\n" ], "text/plain": [ "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", ".get_parent()\n", " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"prefix#suffix\": 0.6875\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.6875\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"prefix#suffix\": 0.8125\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"prefix#suffix\": 0.80625\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"prefix#suffix\": 0.825\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.825\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"prefix#suffix\": 0.8125\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"prefix#suffix\": 0.80625\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"prefix#suffix\": 0.80625\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"prefix#suffix\": 0.8\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"prefix#suffix\": 0.80625\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"prefix#suffix\": 0.80625\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] }, { "cell_type": "code", "execution_count": null, "id": "f1abfa0a", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } } }, "nbformat": 4, "nbformat_minor": 5 }