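{ "cells": [ { "cell_type": "markdown", "id": "fdd7ff16", "metadata": {}, "source": [ "# T5. An in-depth look at trainer and evaluator\n", "\n", " 1 More on the driver in fastNLP\n", " \n", " 1.1 The design of trainer and driver\n", "\n", " 1.2 device and multi-GPU training\n", "\n", " 2 More metric types in fastNLP\n", "\n", " 2.1 Predefined metric types\n", "\n", " 2.2 Custom metric types\n", "\n", " 3 More on the trainer in fastNLP\n", "\n", " 3.1 Internal structure of the trainer" ] }, { "cell_type": "markdown", "id": "08752c5a", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 1. More on the driver in fastNLP\n", "\n", "### 1.1 The design of trainer and driver\n", "\n", "In `fastNLP 0.8`, the key modules for model training are **the training module `trainer`, the evaluation module `evaluator`, and the driver module `driver`**.\n", "\n", " `tutorial-0` briefly introduced these three modules: **`driver` controls the actual execution of the `model` during training and evaluation**;\n", "\n", " **`evaluator` wraps the evaluation `metric`**, while **`trainer` wraps the training `optimizer`** and **can also contain an `evaluator`**.\n", "\n", "The fundamental purpose of this division is **compatibility with multiple `python` deep learning frameworks**, **such as `pytorch`, `paddle`, and `jittor`**.\n", "\n", " The pseudocode of the training process is shown in the purple column on the left below. Since **frameworks define models, losses, and tensors differently**, the training process is\n", "\n", " split into **a framework-agnostic part for loop control and batch dispatching**, **implemented by the `trainer` module** (pseudocode in the blue middle column below),\n", "\n", " and **a framework-specific part for model invocation and numerical optimization**, **implemented by the `driver` module** (pseudocode in the red right column below)\n", "\n", "|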
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import sys\n", "sys.path.append('..')\n", "\n", "from fastNLP import Metric\n", "\n", "class MyMetric(Metric):\n", "\n", "    def __init__(self):\n", "        Metric.__init__(self)\n", "        # counters accumulated across batches\n", "        self.total_num = 0\n", "        self.right_num = 0\n", "\n", "    def update(self, pred, target):\n", "        # count the samples in this batch and how many predictions match the target\n", "        self.total_num += target.size(0)\n", "        self.right_num += target.eq(pred).sum().item()\n", "\n", "    def get_metric(self, reset=True):\n", "        # accuracy over all batches seen so far; optionally reset the counters\n", "        acc = self.right_num / self.total_num\n", "        if reset:\n", "            self.total_num = 0\n", "            self.right_num = 0\n", "        return {'prefix': acc}" ] }, { "cell_type": "markdown", "id": "0155f447", "metadata": {}, "source": [ " For data, we again use the `load_dataset` function from the `datasets` module to load the `SST-2` binary classification dataset" ] }, { "cell_type": "code", "execution_count": 2, "id": "5ad81ac7", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ef923b90b19847f4916cccda5d33fc36", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from datasets import load_dataset\n", "\n", "sst2data = load_dataset('glue', 'sst2')" ] }, { "cell_type": "markdown", "id": "e9d81760", "metadata": {}, "source": [ " Note that during preprocessing we would normally rename the field holding the prediction target to `target`, to match the input parameters of the `metric` and the `model`.\n", "\n", " This step is deliberately skipped here; a later section explains why it matters and how to work around it" ] }, { "cell_type": "code", "execution_count": 3, "id": "cfb28b1b", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Processing: 0%| | 0/6000 [00:00, ?it/s]" ] }, "metadata": {}, 
"output_type": "display_data" } ], "source": [ "from fastNLP import DataSet\n", "\n", "# convert the SST-2 training split into a fastNLP DataSet and keep the first 6000 samples\n", "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", "\n", "# lowercase each sentence and split it on whitespace into a 'words' field\n", "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split()}, progress_bar=\"tqdm\")\n", "dataset.delete_field('sentence')\n", "dataset.delete_field('idx')\n", "\n", "from fastNLP import Vocabulary\n", "\n", "# build the vocabulary from the 'words' field and replace each word with its index\n", "vocab = Vocabulary()\n", "vocab.from_dataset(dataset, field_name='words')\n", "vocab.index_dataset(dataset, field_name='words')\n", "\n", "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n", "\n", "from fastNLP import prepare_torch_dataloader\n", "\n", "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" ] }, { "cell_type": "markdown", "id": "af3f8c63", "metadata": {}, "source": [ " For the model, we again use the predefined `CNNText` model introduced in `tutorial-4` for `SST-2` binary classification" ] }, { "cell_type": "code", "execution_count": 4, "id": "2fd210c5", "metadata": {}, "outputs": [], "source": [ "from fastNLP.models.torch import CNNText\n", "\n", "# 100-dimensional embeddings over the vocabulary, 2 output classes\n", "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n", "\n", "from torch.optim import AdamW\n", "\n", "optimizers = AdamW(params=model.parameters(), lr=5e-4)" ] }, { "cell_type": "markdown", "id": "6e723b87", "metadata": {}, "source": [ "## 3. More on the trainer in fastNLP\n", "\n", "### 3.1 Internal structure of the trainer\n", "\n", "`tutorial-0` introduced the basic usage of `trainer`, and `tutorial-1` through `tutorial-4` showed\n", "\n", " many examples of using it. The table below gives a fairly complete overview of the attributes and initialization parameters of the `trainer` module (required parameters are in bold)\n", "\n", "|
[09:30:35] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", "\n" ], "text/plain": [ "\u001b[2;36m[09:30:35]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=954293;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=366534;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.6875\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.6875\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.8125\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.80625\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.825\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.825\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.8125\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.80625\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.80625\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.8\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.80625\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.80625\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } } }, "nbformat": 4, "nbformat_minor": 5 }