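{ "cells": [ { "cell_type": "markdown", "id": "fdd7ff16", "metadata": {}, "source": [ "# T5. A deeper look at trainer and evaluator\n", "\n", " 1 More on driver in fastNLP\n", " \n", " 1.1 The design behind trainer and driver \n", "\n", " 1.2 device and multi-GPU training\n", "\n", " 2 More metric types in fastNLP\n", "\n", " 2.1 Predefined metric types\n", "\n", " 2.2 Custom metric types\n", "\n", " 3 More on trainer in fastNLP\n", "\n", " 3.1 The internal structure of trainer" ] }, { "cell_type": "markdown", "id": "08752c5a", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 1. More on driver in fastNLP\n", "\n", "### 1.1 The design behind trainer and driver\n", "\n", "In `fastNLP 0.8`, the key modules for model training are **the training module `trainer`, the evaluation module `evaluator`, and the driver module `driver`**.\n", "\n", " `tutorial 0` briefly introduced these three modules: **`driver` controls how the `model` is ultimately run during training and evaluation**,\n", "\n", " **`evaluator` wraps the `metric` used for evaluation**, and **`trainer` wraps the `optimizer` used for training** and **can also contain an `evaluator`**.\n", "\n", "The fundamental purpose of this division is to **stay compatible with multiple `python` deep learning frameworks**, **such as `pytorch`, `paddle`, and `jittor`**.\n", "\n", " The pseudocode for the training process is shown in the purple column on the left below. Because **each framework defines models, losses, and tensors differently**, the training process is\n", "\n", " split into **a framework-independent part that handles loop control and batch dispatch**, **implemented by the `trainer` module**, with pseudocode in the blue column in the middle below,\n", "\n", " and **a framework-specific part that handles model invocation and numerical optimization**, **implemented by the `driver` module**, with pseudocode in the red column on the right below.\n", "\n", "|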