@@ -42,7 +42,7 @@ class Callback: | |||
on_train_end(trainer) | |||
其它 callback 例如 **on_evaluate_begin(trainer)** / **on_evaluate_end(trainer, results)** / **on_save_model(trainer)** / | |||
**on_load_model(trainer)** / **on_save_checkpoint(trainer)** / **on_load_checkpoint(trainer)** 将根据需要在 :meth:`fastNLP.Trainer.run` | |||
**on_load_model(trainer)** / **on_save_checkpoint(trainer)** / **on_load_checkpoint(trainer)** 将根据需要在 :meth:`Trainer.run <fastNLP.core.controllers.Trainer.run>` | |||
中特定的时间调用。 | |||
""" | |||
@@ -50,8 +50,8 @@ class Callback: | |||
r""" | |||
在 ``Trainer`` 初始化后会被触发; | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param driver: :class:`~fastNLP.Trainer` 中的 ``driver`` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
:param driver: :class:`~fastNLP.core.controllers.Trainer` 中的 ``driver`` 实例; | |||
""" | |||
pass | |||
@@ -59,7 +59,7 @@ class Callback: | |||
r""" | |||
在 '预跑'检测 开始前会被触发; | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
""" | |||
pass | |||
@@ -67,7 +67,7 @@ class Callback: | |||
r""" | |||
在 '预跑'检测 开始后会被触发; | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
:param sanity_check_res: 预跑得到的评测结果,关于对于 **预跑** 的解释,请见 :meth:`~fastNLP.core.controllers.trainer.Trainer.run`; | |||
""" | |||
pass | |||
@@ -76,7 +76,7 @@ class Callback: | |||
r""" | |||
在训练开始前会被触发; | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
""" | |||
pass | |||
@@ -84,7 +84,7 @@ class Callback: | |||
r""" | |||
在训练完成后会被触发; | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
""" | |||
pass | |||
@@ -92,7 +92,7 @@ class Callback: | |||
r""" | |||
在训练过程中的每一个 epoch 开始前会被触发; | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
""" | |||
pass | |||
@@ -100,7 +100,7 @@ class Callback: | |||
r""" | |||
在训练过程中的每一个 epoch 完成后会被触发;此时 trainer.cur_epoch_idx 已经完成加 1 操作。 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
""" | |||
pass | |||
@@ -108,7 +108,7 @@ class Callback: | |||
r""" | |||
在训练过程中准备取出下一个 batch 的数据时触发 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
""" | |||
pass | |||
@@ -116,17 +116,17 @@ class Callback: | |||
r""" | |||
在训练过程中拿到当前的 batch 数据后会被触发; | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
""" | |||
pass | |||
def on_train_batch_begin(self, trainer, batch, indices): | |||
r""" | |||
在取得数据,执行完 ``input_mapping`` (如果 :class:`~fastNLP.Trainer` 传有该参数),并且移动 ``batch`` 中的张量到了指定设备之后会被触发。 | |||
在取得数据,执行完 ``input_mapping`` (如果 :class:`~fastNLP.core.controllers.Trainer` 传有该参数),并且移动 ``batch`` 中的张量到了指定设备之后会被触发。 | |||
其中 ``batch`` 中的数据格式要么是 ``Dataloader`` 返回的每个 ``batch`` 的格式;要么是 ``input_mapping`` 之后的内容。 | |||
如果 ``batch`` 是 ``dict`` 类型,直接增删其中的 key 或 修改其中的 value 会影响到输入模型的中的 ``batch`` 数据。 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
:param batch: batch 的数据,已经经过 ``input_mapping`` (如果有) 以及移动到指定设备 。 | |||
:param list[int] indices: 当前的 ``batch`` 是数据集中的哪些数据。仅在 ``DataLoader`` 支持得到当前 ``batch index`` 的时候有值, | |||
其它时候为 ``None`` 。 | |||
@@ -139,7 +139,7 @@ class Callback: | |||
global_forward_batches 累计加1操作之后会被触发。其中梯度更新、梯度置零操作会考虑 **accumulation_steps** ,所以不一定在当前 batch 会 | |||
执行。 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
""" | |||
pass | |||
@@ -147,41 +147,42 @@ class Callback: | |||
r""" | |||
在训练过程遇到异常时调用。 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
:param exception: 遭遇的异常; | |||
""" | |||
pass | |||
def on_save_model(self, trainer): | |||
r""" | |||
当调用 :meth:`fastNLP.Trainer.save_model` 时调用,此刻模型还未保存。 | |||
当调用 :meth:`Trainer.save_model() <fastNLP.core.controllers.Trainer.save_model>` 时调用,此刻模型还未保存。 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
""" | |||
pass | |||
def on_load_model(self, trainer): | |||
r""" | |||
当调用 :meth:`fastNLP.Trainer.load_model` 加载模型时调用,此刻模型还未加载。 | |||
当调用 :meth:`Trainer.load_model() <fastNLP.core.controllers.Trainer.load_model>` 加载模型时调用,此刻模型还未加载。 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
""" | |||
pass | |||
def on_save_checkpoint(self, trainer) -> Dict: | |||
r""" | |||
当 Trainer 将要保存 checkpoint 的时候触发 (即调用 :meth:`Trainer.save_checkpoint`()` 函数时),该函数用于保存当前 callback 在恢复时需要的相关数据。 | |||
当 Trainer 将要保存 checkpoint 的时候触发 (即调用 :meth:`Trainer.save_checkpoint() <fastNLP.core.controllers.Trainer.save_checkpoint>` | |||
函数时),该函数用于保存当前 callback 在恢复时需要的相关数据。 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
""" | |||
pass | |||
def on_load_checkpoint(self, trainer, states: Optional[Dict]): | |||
r""" | |||
当 Trainer 要恢复 checkpoint 的时候触发(即调用 :meth:`Trainer.load_checkpoint` 函数时, 此刻 Trainer 与 Driver 已经加载好自身 | |||
的状态), 参数 states 为 Callback 在调用 :meth:`on_save_checkpoint` 的返回值。 | |||
当 Trainer 要恢复 checkpoint 的时候触发(即调用 :meth:`Trainer.load_checkpoint() <fastNLP.core.controllers.Trainer.load_checkpoint>` | |||
函数时, 此刻 Trainer 与 Driver 已经加载好自身的状态), 参数 states 为 Callback 在调用 :meth:`on_save_checkpoint` 的返回值。 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
:param states: | |||
""" | |||
pass | |||
@@ -190,7 +191,7 @@ class Callback: | |||
r""" | |||
在 backward 前执行。 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
:param outputs: ``model`` 的返回内容。如果有 ``output_mapping``,则 ``outputs`` 中的内容为已经执行了 ``output_mapping`` 后的结果。 | |||
""" | |||
pass | |||
@@ -200,7 +201,7 @@ class Callback: | |||
在 ``backward`` 后执行。在多卡场景下,由于 ``accumulation_steps`` 的影响,仅在需要真正 ``update`` 参数那次梯度回传才会触发梯度同步, | |||
因此在多卡且使用 ``accumulation_steps`` 时,可能存在某些 step 各卡上梯度不一致的问题。 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
""" | |||
pass | |||
@@ -208,8 +209,8 @@ class Callback: | |||
r""" | |||
在进行 optimizer 优化进行前调用。该接口不一定每次前向计算都会触发,实际调用会受到 ``accumulation_steps`` 的影响。 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param optimizers: 优化器,内容为在 :class:`~fastNLP.Trainer` 初始化时传入的值。 | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
:param optimizers: 优化器,内容为在 :class:`~fastNLP.core.controllers.Trainer` 初始化时传入的值。 | |||
""" | |||
pass | |||
@@ -217,8 +218,8 @@ class Callback: | |||
r""" | |||
在进行 optimizer 优化进行后调用。该接口不一定每次前向计算都会触发,实际调用会受到 ``accumulation_steps`` 的影响。 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param optimizers: 优化器,内容为在 :class:`~fastNLP.Trainer` 初始化时传入的值。 | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
:param optimizers: 优化器,内容为在 :class:`~fastNLP.core.controllers.Trainer` 初始化时传入的值。 | |||
""" | |||
pass | |||
@@ -226,8 +227,8 @@ class Callback: | |||
r""" | |||
在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 ``accumulation_steps`` 的影响。 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param optimizers: 优化器,内容为在 :class:`~fastNLP.Trainer` 初始化时传入的值。 | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
:param optimizers: 优化器,内容为在 :class:`~fastNLP.core.controllers.Trainer` 初始化时传入的值。 | |||
""" | |||
pass | |||
@@ -235,8 +236,8 @@ class Callback: | |||
r""" | |||
在进行模型梯度置零后调用。该接口不一定每次前向计算都会触发,实际调用会受到 ``accumulation_steps`` 的影响。 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param optimizers: 优化器,内容为在 :class:`~fastNLP.Trainer` 初始化时传入的值。 | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
:param optimizers: 优化器,内容为在 :class:`~fastNLP.core.controllers.Trainer` 初始化时传入的值。 | |||
""" | |||
pass | |||
@@ -245,7 +246,7 @@ class Callback: | |||
在将要进行 ``evaluate`` 时调用。如果是设置的以 step 数量或自定义地决定 evaluate 的频率,该接口是在 :meth:`on_train_batch_end` 之后 | |||
进行调用。如果是以 epoch 数量决定调用时机,该接口是在 :meth:`on_train_epoch_end` 之后调用。 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
""" | |||
pass | |||
@@ -253,8 +254,8 @@ class Callback: | |||
r""" | |||
结束 evaluate 时调用,并把 evaluate 的结果传入。 | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param results: :class:`~fastNLP.Trainer` 内置的 ``Evaluator`` 评测的结果,通常是个 ``dict``; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
:param results: :class:`~fastNLP.core.controllers.Trainer` 内置的 ``Evaluator`` 评测的结果,通常是个 ``dict``; | |||
""" | |||
pass | |||
@@ -146,7 +146,7 @@ class CallbackManager: | |||
断点重训应当保存的状态; | |||
2. 每一个具体的 callback 函数的 filter 的状态; | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
:return: 一个包含上述内容的字典,格式如下: | |||
.. code-block:: | |||
@@ -196,7 +196,7 @@ class CallbackManager: | |||
r""" | |||
用于断点重训的加载函数,对应于断点重训的保存函数; | |||
:param trainer: :class:`~fastNLP.Trainer` 实例; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
:param states: 同 :func:`on_save_checkpoint` 函数的返回值; | |||
""" | |||
@@ -24,20 +24,21 @@ class CheckpointCallback(Callback): | |||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||
``model_save_fn`` 为 ``Non``e ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。若 ``model_save_fn`` 不为 ``None``, | |||
``model_save_fn`` 为 ``None`` ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。若 ``model_save_fn`` 不为 ``None``, | |||
则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。默认情况下,本 checkpoint 只保存了 model | |||
的状态;如还需保存 Trainer 的状态以断点重训的话,请使用 ``save_object='trainer'`` 。 | |||
:param monitor: 监控的 metric 值。 | |||
* 为 ``None`` | |||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||
* 为 ``str`` | |||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||
* 为 ``Callable`` | |||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||
的 ``monitor`` 值请返回 ``None`` 。 | |||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||
* 为 :class:`Callable` | |||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||
的 ``monitor`` 值请返回 ``None`` 。 | |||
:param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
:param every_n_epochs: 多少个 epoch 保存一次。 | |||
@@ -50,10 +51,10 @@ class CheckpointCallback(Callback): | |||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
如果传入了 ``model_save_fn`` 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||
:param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||
保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||
保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.core.controllers.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 ``True`` ,在保存 topk 模型的 folder 中还将额外保存一个 | |||
``*``fastnlp_evaluate_results.json``文件,记录当前的 results。仅在设置了 ``topk`` 的场景下有用,默认为 ``True`` 。 | |||
``fastnlp_evaluate_results.json`` 文件,记录当前的 results。仅在设置了 ``topk`` 的场景下有用,默认为 ``True`` 。 | |||
:param kwargs: | |||
""" | |||
def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, | |||
@@ -10,16 +10,16 @@ from fastNLP.core.utils.exceptions import EarlyStopException | |||
class EarlyStopCallback(HasMonitorCallback): | |||
""" | |||
用于 early stop 的 callback 。当监控的结果连续多少次没有变好便 raise 一个 EarlyStopException 。 | |||
用于 early stop 的 callback 。当监控的结果连续多少次没有变好便 raise 一个 :class:`EarlyStopException` 。 | |||
:param monitor: 监控的 metric 值。 | |||
* 为 ``None`` | |||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||
* 为 ``str`` | |||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||
* 为 ``Callable`` | |||
* 为 :class:`Callable` | |||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||
的 ``monitor`` 值请返回 ``None`` 。 | |||
:param larger_better: monitor 的值是否是越大越好。 | |||
@@ -16,18 +16,19 @@ class FitlogCallback(HasMonitorCallback): | |||
自动记录 ``evaluation`` 结果到 ``fitlog`` 中。会自动记录每一次 ``evaluate`` 后的结果;同时会根据 | |||
``monitor`` 记录最好的结果。另外,会自动将非 ``rank 0`` 上的 ``fitlog`` 设置为 ``debug`` 状态。同时还会在 ``fitlog`` 的 | |||
``other`` 列中记录一个 ``launch_time`` ,可以通过这个数值找到当前这个脚本的在 save_folder (如果有使用其它需要保存模型的 | |||
``Callback`` ,例如 :class:`~fastNLP.CheckpointCallback` )下的文件夹名称。 | |||
``Callback`` ,例如 :class:`~fastNLP.core.callbacks.CheckpointCallback` )下的文件夹名称。 | |||
:param monitor: 监控的 metric 值。 | |||
* 为 ``None`` | |||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||
* 为 ``str`` | |||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||
* 为 ``Callable`` | |||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||
的 ``monitor`` 值请返回 ``None`` 。 | |||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||
* 为 :class:`Callable` | |||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||
的 ``monitor`` 值请返回 ``None`` 。 | |||
:param larger_better: 是否是越大越好。 | |||
:param log_exception: 是否记录 ``exception`` 。 | |||
:param log_loss_every: 多少个 ``batch`` 记录一次 loss 到 ``fitlog`` 中。 | |||
@@ -31,11 +31,11 @@ class ResultsMonitor: | |||
:param monitor: 监控的 metric 值: | |||
* 为 ``None`` | |||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置); | |||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置); | |||
* 为 ``str`` | |||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` ; | |||
* 为 ``Callable`` | |||
* 为 :class:`Callable` | |||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||
的 ``monitor`` 值请返回 ``None`` ; | |||
:param larger_better: monitor 是否为越大越好; | |||
@@ -190,11 +190,11 @@ class HasMonitorCallback(ResultsMonitor, Callback): | |||
:param monitor: 监控的 metric 值: | |||
* 为 ``None`` | |||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置); | |||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置); | |||
* 为 ``str`` | |||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` ; | |||
* 为 ``Callable`` | |||
* 为 :class:`Callable` | |||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||
的 ``monitor`` 值请返回 ``None`` ; | |||
:param larger_better: monitor 是否为越大越好; | |||
@@ -235,11 +235,11 @@ class ExecuteOnceBetterMonitor(HasMonitorCallback): | |||
:param monitor: 监控的 metric 值: | |||
* 为 ``None`` | |||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 ``monitor`` 值(如果有设置); | |||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 ``monitor`` 值(如果有设置); | |||
* 为 ``str`` | |||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` ; | |||
* 为 ``Callable`` | |||
* 为 :class:`Callable` | |||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||
的 ``monitor`` 值请返回 ``None`` ; | |||
:param larger_better: monitor 是否是越大越好; | |||
@@ -22,11 +22,11 @@ class LoadBestModelCallback(HasMonitorCallback): | |||
:param monitor: 监控的 metric 值: | |||
* 为 ``None`` | |||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置); | |||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置); | |||
* 为 ``str`` | |||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` ; | |||
* 为 ``Callable`` | |||
* 为 :class:`Callable` | |||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||
的 ``monitor`` 值请返回 ``None`` ; | |||
:param larger_better: 该 metric 值是否是越大越好; | |||
@@ -72,14 +72,14 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||
时间戳文件夹中。如果为 ``None`` ,默认使用当前文件夹。 | |||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 ``model_save_fn`` 不为 ``None`` 时,该参数无效。 | |||
:param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||
保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||
保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.core.controllers.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
如果传入了 ``model_save_fn`` 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 ``True`` ,在保存 topk 模型的 folder 中还将额外保存一个 | |||
``fastnlp_evaluate_results.json`` 文件,记录当前的 results。仅在设置了 ``topk`` 的场景下有效,默认为 True 。 | |||
:param save_kwargs: 一个字典,表示更多的保存相关的参数。 | |||
:param kwargs: 其它与 :class:`~fastNLP.Evaluator` 相关的初始化参数,如果不传入,将从 :class:`~fastNLP.Trainer` 中获取。 | |||
:param kwargs: 其它与 :class:`~fastNLP.core.controllers.Evaluator` 相关的初始化参数,如果不传入,将从 :class:`~fastNLP.core.controllers.Trainer` 中获取。 | |||
""" | |||
def __init__(self, dataloaders, metrics:Dict, evaluate_every:Optional[Union[int, Callable]]=-1, | |||
watch_monitor:Union[str, Callable]=None, watch_monitor_larger_better:bool=True, | |||
@@ -68,13 +68,14 @@ class RichCallback(ProgressCallback): | |||
:param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||
* 为 ``None`` | |||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||
* 为 ``str`` | |||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||
* 为 ``Callable`` | |||
* 为 :class:`Callable` | |||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||
的 ``monitor`` 值请返回 ``None`` 。 | |||
:param larger_better: 是否是 monitor 的结果越大越好。 | |||
:param format_json: 是否格式化 json 再打印 | |||
""" | |||
@@ -177,11 +178,11 @@ class RawTextCallback(ProgressCallback): | |||
:param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||
* 为 ``None`` | |||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||
* 为 ``str`` | |||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||
* 为 ``Callable`` | |||
* 为 :class:`Callable` | |||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||
的 ``monitor`` 值请返回 ``None`` 。 | |||
:param larger_better: 是否是monitor的结果越大越好。 | |||
@@ -251,11 +252,11 @@ class TqdmCallback(ProgressCallback): | |||
:param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||
* 为 ``None`` | |||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||
* 为 ``str`` | |||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||
* 为 ``Callable`` | |||
* 为 :class:`Callable` | |||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||
的 ``monitor`` 值请返回 ``None`` 。 | |||
:param larger_better: 是否是 monitor 的结果越大越好。 | |||
@@ -25,7 +25,7 @@ class Saver: | |||
:param folder: 保存在哪个文件夹下,默认为当前 folder 下。 | |||
:param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||
保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.core.controllers.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||
:param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。 | |||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
@@ -182,17 +182,17 @@ class TopkSaver(ResultsMonitor, Saver): | |||
:param monitor: 监控的 metric 值: | |||
* 为 ``None`` | |||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||
* 为 ``str`` | |||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||
* 为 ``Callable`` | |||
* 为 :class:`Callable` | |||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||
的 ``monitor`` 值请返回 ``None`` 。 | |||
:param larger_better: 该 monitor 是否越大越好。 | |||
:param folder: 保存在哪个文件夹下,默认为当前 folder 下。 | |||
:param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||
保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||
保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.core.controllers.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||
:param only_state_dict: 保存时是否仅保存权重,在 ``model_save_fn`` 不为 None 时无意义。 | |||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
@@ -220,7 +220,7 @@ class TopkSaver(ResultsMonitor, Saver): | |||
@rank_zero_call | |||
def save_topk(self, trainer, results: Dict) -> Optional[str]: | |||
""" | |||
根据 results 是否满足 topk 的相关设定决定是否保存,如果发生了保存,将返回保存的文件夹。如果返回为 None ,则说明此次没有满足 | |||
根据 ``results`` 是否满足 topk 的相关设定决定是否保存,如果发生了保存,将返回保存的文件夹。如果返回为 ``None`` ,则说明此次没有满足 | |||
topk 要求,没有发生保存。 | |||
:param trainer: | |||
@@ -205,7 +205,7 @@ class Collator: | |||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, | |||
该值无意义。 | |||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 ``dtype`` 应该是什么。 | |||
:param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 | |||
:param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, | |||
:class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 | |||
若 ``pad_val`` 为 ``None`` ,该值无意义 。 | |||
@@ -291,7 +291,7 @@ class Collator: | |||
def set_ignore(self, *field_names) -> "Collator": | |||
""" | |||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略:: | |||
>>> collator = Collator().set_ignore('field1', 'field2') | |||
@@ -2,6 +2,7 @@ from collections import defaultdict | |||
from functools import reduce | |||
from typing import Sequence, Mapping, Dict | |||
__all__ = [] | |||
class MappingPackerUnpacker: | |||
@staticmethod | |||
@@ -1,7 +1,7 @@ | |||
class Padder: | |||
""" | |||
所有 **Padder** 对象的父类,所有的 Padder 对象都会实现静态函数 *pad(batch_field, pad_val=0, dtype=None)* 。 | |||
所有 **Padder** 对象的父类,所有的 Padder 对象都会实现静态函数 ``pad(batch_field, pad_val=0, dtype=None)`` 。 | |||
""" | |||
def __init__(self, pad_val, dtype): | |||
@@ -80,7 +80,7 @@ def get_padded_numpy_array(batch_field: List, dtype=None, pad_val=0) -> np.ndarr | |||
""" | |||
将输入 pad 为 :class:`numpy.arraay` 类型,如:``[[1,2], [3]] -> np.array([[1, 2], [3, 0]])`` | |||
:param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 **1d**(多为句子长度)/ **2d**(多为文本序列)/ **3d**(多为字符序列) | |||
:param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 **1d** (多为句子长度)/ **2d** (多为文本序列)/ **3d** (多为字符序列) | |||
/4d(多为图片); | |||
:param dtype: 输出数据的 dtype 类型; | |||
:param pad_val: 填充值; | |||
@@ -1,5 +1,5 @@ | |||
r""" | |||
``Evaluator`` 是新版 fastNLP 中用来进行评测模型的评测器,其与 ``Trainer`` 相对应,二者共同构建起了 fastNLP 中**训练**和**评测**的框架。 | |||
``Evaluator`` 是新版 **fastNLP** 中用来进行评测模型的评测器,其与 ``Trainer`` 相对应,二者共同构建起了 **fastNLP** 中 **训练** 和 **评测** 的框架。 | |||
``Evaluator`` 的整体架构与 ``Trainer`` 类似,也是利用 ``Driver`` 来负责底层的评测逻辑。通过使用 ``Evaluator``,您可以快速、方便、准确地 | |||
对您的模型进行全方位地评测。 | |||
@@ -75,11 +75,11 @@ class Evaluator: | |||
:param device: 等价于 ``Trainer`` 中的 ``device`` 参数; | |||
:param evaluate_batch_step_fn: 您可以传入该参数来定制每次评测一个 batch 的数据时所执行的函数。该函数应接受的两个参数为 ``evaluator`` 和 ``batch``, | |||
不需要有返回值;可以参考 :meth:`~fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop.batch_step_fn`; | |||
:param evaluate_fn: 用来控制 ``Evaluator`` 在评测的前向传播过程中调用的是哪一个函数,例如对于 pytorch 而言,通过该参数确定使用的是 ``model.evaluate_step`` 还是 | |||
``model.forward``(不同训练框架所使用的的前向传播函数的方法名称不同); | |||
:param evaluate_fn: 用来控制 ``Evaluator`` 在评测的前向传播过程中调用的是哪一个函数,例如对于 pytorch 而言,通过该参数确定使用的是 :meth:`model.evaluate_step` 还是 | |||
:meth:`model.forward` (不同训练框架所使用的的前向传播函数的方法名称不同); | |||
1. 如果该值是 ``None``,那么我们会默认使用 ``evaluate_step`` 当做前向传播的函数,如果在模型中没有找到该方法,则使用训练框架默认的前向传播函数; | |||
2. 如果为 ``str`` 类型,例如为 ``my_evaluate_step_fn``,则尝试寻找 ``model.my_evaluate_step_fn``,如果找不到则直接报错; | |||
2. 如果为 ``str`` 类型,例如为 ``'my_evaluate_step_fn'``,则尝试寻找 :meth:`model.my_evaluate_step_fn`,如果找不到则直接报错; | |||
:param input_mapping: 等价于 ``Trainer`` 中的 ``input_mapping`` 参数;对具体的用于评测一个 batch 的数据使用 ``input_mapping`` 处理之后再输入到 ``model`` 以及 ``metric`` 中。如果针对 | |||
``model`` 和 ``metric`` 需要不同的 ``mapping``,请考虑使用 ``evaluate_batch_step_fn`` 参数定制; | |||
@@ -97,20 +97,27 @@ class Evaluator: | |||
``metric`` 的计算都是自动化的,因此其一定需要参数匹配:根据 ``metric.update`` 的函数签名直接从字典数据中抽取其需要的参数传入进去; | |||
:param fp16: 是否在评测时使用 fp16; | |||
:param fp16: 是否在评测时使用 fp16 混合精度; | |||
:param verbose: 是否打印 evaluate 的结果; | |||
:kwargs: | |||
* *torch_kwargs* -- 等价于 ``Trainer`` 中的 ``torch_kwargs`` 参数; | |||
* *paddle_kwargs* -- 等价于 ``Trainer`` 中的 ``paddle_kwargs`` 参数; | |||
* *fairscale_kwargs* -- 等价于 ``Trainer`` 中的 ``fairscale_kwargs`` 参数; | |||
* *deepspeed_kwargs* -- 等价于 ``Trainer`` 中的 ``deepspeed_kwargs`` 参数; | |||
* *oneflow_kwargs* -- 等价于 ``Trainer`` 中的 ``oneflow_kwargs`` 参数; | |||
* *data_device* -- 等价于 ``Trainer`` 中的 ``data_device`` 参数; | |||
* *model_use_eval_mode* (``bool``) -- | |||
是否在评测的时候将 ``model`` 的状态设置成 ``eval`` 状态。在 ``eval`` 状态下,``model`` 的 | |||
``dropout`` 与 ``batch normalization`` 将会关闭。默认为 ``True``。如果为 ``False``,``fastNLP`` 不会对 ``model`` 的 ``evaluate`` 状态做任何设置。无论 | |||
该值是什么,``fastNLP`` 都会在评测后将 ``model`` 的状态设置为 ``train``; | |||
是否在评测的时候将 ``model`` 的状态设置成 ``eval`` 状态。在 ``eval`` 状态下,``model`` 的 | |||
``dropout`` 与 ``batch normalization`` 将会关闭。默认为 ``True``。如果为 ``False``,``fastNLP`` 不会对 ``model`` 的 ``evaluate`` 状态做任何设置。无论 | |||
该值是什么,``fastNLP`` 都会在评测后将 ``model`` 的状态设置为 ``train``; | |||
* *use_dist_sampler* -- | |||
表示在 ``Evaluator`` 中在使用分布式的时候是否将保证 dataloader 的 ``sampler`` 替换为 | |||
分布式的 ``sampler``,其特点是每个卡上的数据之间不重叠,所有卡上数据的加起来是整个数据集。若传入的 dataloader | |||
的 sampler 为 (a) 深度学习框架自带的默认 sampler ; (b) fastNLP 的 Sampler 等,则将替换为 | |||
:class:`~fastNLP.UnrepeatedSequentialSampler`,如果这个行为不是期待的,请本参数设置为 ``False``,并针对每个卡控制其可以 | |||
表示在 ``Evaluator`` 中在使用分布式的时候是否将保证 dataloader 的 ``sampler`` 替换为 | |||
分布式的 ``sampler``,其特点是每个卡上的数据之间不重叠,所有卡上数据的加起来是整个数据集。若传入的 dataloader | |||
的 sampler 为: | |||
- 深度学习框架自带的默认 sampler ; | |||
- fastNLP 的 Sampler ; | |||
则将替换为 :class:`~fastNLP.UnrepeatedSequentialSampler`,如果这个行为不是期待的,请本参数设置为 ``False``,并针对每个卡控制其可以 | |||
用到的数据。如果不是以上两类 sampler ,fastNLP 将报错。 | |||
* *output_from_new_proc* -- 等价于 ``Trainer`` 中的 ``output_from_new_proc`` 参数; | |||
* *progress_bar* -- 等价于 ``Trainer`` 中的 ``progress_bar`` 参数; | |||
@@ -123,7 +130,7 @@ class Evaluator: | |||
def __init__(self, model, dataloaders, metrics: Optional[Dict] = None, | |||
driver: Union[str, Driver] = 'auto', device: Optional[Union[int, List[int], str]] = None, | |||
evaluate_batch_step_fn: Optional[callable] = None, evaluate_fn: Optional[str] = None, | |||
evaluate_batch_step_fn: Optional[Callable] = None, evaluate_fn: Optional[str] = None, | |||
input_mapping: Optional[Union[Callable, Dict]] = None, | |||
output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False, | |||
fp16: bool = False, verbose: int = 1, **kwargs): | |||
@@ -203,16 +210,16 @@ class Evaluator: | |||
""" | |||
用于帮助您加载模型的辅助函数; | |||
:param folder: 存放着您需要加载的 model 的文件夹,默认会尝试读取该文件夹下的 fastnlp_model.pkl.tar 文件。在 model_load_fn 不为空时, | |||
直接将该 folder 传递到 model_load_fn 中; | |||
:param only_state_dict: 要读取的文件中是否仅包含模型权重。在 ``model_load_fn 不为 None`` 时,该参数无意义; | |||
:param model_load_fn: ``callable`` 的函数,接受一个 folder 作为参数,需要注意该函数不需要返回任何内容; | |||
:param folder: 存放着您需要加载的 model 的文件夹,默认会尝试读取该文件夹下的 ``fastnlp_model.pkl.tar`` 文件。在 ``model_load_fn`` 不为空时, | |||
直接将该 folder 传递到 ``model_load_fn`` 中; | |||
:param only_state_dict: 要读取的文件中是否仅包含模型权重。在 ``model_load_fn`` 不为 ``None`` 时,该参数无意义; | |||
:param model_load_fn: :class:`Callable` 的函数,接受一个 folder 作为参数,需要注意该函数不需要返回任何内容; | |||
:param kwargs: 理论上您不需要使用到该参数; | |||
.. note:: | |||
注意您需要在初始化 ``Evaluator`` 后再通过 ``evaluator`` 实例来调用该函数;这意味着您需要保证在保存和加载时使用的 ``driver`` 是属于同一个 | |||
训练框架的,例如都是 ``pytorch`` 或者 ``paddle``; | |||
训练框架的,例如都是 **pytorch** 或者 **PaddlePaddle** ; | |||
""" | |||
self.driver.barrier() | |||
if not isinstance(folder, (io.BytesIO, BinaryIO)): | |||
@@ -240,15 +247,14 @@ class Evaluator: | |||
""" | |||
该函数是在 ``Evaluator`` 初始化后用于真正开始评测的函数; | |||
返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 | |||
返回一个字典类型的数据,其中 key 为 metric 的名字,value 为对应 metric 的结果。 | |||
1. 如果存在多个metric,一个dataloader的情况,key的命名规则是 | |||
``metric_indicator_name#metric_name`` | |||
1. 如果存在多个 metric ,一个 dataloader 的情况,key 的命名规则是 | |||
``metric_indicator_name#metric_name``; | |||
2. 如果存在多个数据集,一个metric的情况,key的命名规则是 | |||
``metric_indicator_name#metric_name#dataloader_name`` (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 | |||
如果存在多个metric,多个dataloader的情况,key的命名规则是 | |||
``metric_indicator_name#metric_name#dataloader_name`` | |||
其中 metric_indicator_name 可能不存在; | |||
``metric_indicator_name#metric_name#dataloader_name`` (其中 **#** 是默认的 separator ,可以通过 Evaluator 初始化参数修改); | |||
3. 如果存在多个metric,多个dataloader的情况,key的命名规则是 | |||
``metric_indicator_name#metric_name#dataloader_name``,其中 metric_indicator_name 可能不存在; | |||
:param num_eval_batch_per_dl: 每个 dataloader 测试前多少个 batch 的数据,-1 为测试所有数据。 | |||
:return: 返回评测得到的结果,是一个没有嵌套的字典; | |||
@@ -360,7 +366,7 @@ class Evaluator: | |||
def reset(self): | |||
""" | |||
调用所有 metric 的 reset() 方法,清除累积的状态。 | |||
调用所有 metric 的 :meth:`reset` 方法,清除累积的状态。 | |||
:return: | |||
""" | |||
@@ -368,7 +374,7 @@ class Evaluator: | |||
def update(self, batch, outputs): | |||
""" | |||
自动调用所有 metric 的 update 方法,会根据不同 metric 的参数列表进行匹配传参。 | |||
自动调用所有 metric 的 :meth:`update` 方法,会根据不同 metric 的参数列表进行匹配传参。 | |||
:param batch: 一般是来自于 DataLoader 的输出,如果不为 dict 类型的话,该值将被忽略。 | |||
:param outputs: 一般是来自于模型的输出。类别应为 dict 或者 dataclass 类型。 | |||
@@ -378,7 +384,7 @@ class Evaluator: | |||
def get_metric(self) -> Dict: | |||
""" | |||
调用所有 metric 的 get_metric 方法,并返回结果。其中 key 为 metric 的名称,value 是各个 metric 的结果。 | |||
调用所有 metric 的 :meth:`get_metric` 方法,并返回结果。其中 key 为 metric 的名称,value 是各个 metric 的结果。 | |||
:return: | |||
""" | |||
@@ -387,11 +393,9 @@ class Evaluator: | |||
@property | |||
def metrics_wrapper(self): | |||
""" | |||
由于需要保持 Evaluator 中 metrics 对象与用户传入的 metrics 保持完全一致(方便他在 evaluate_batch_step_fn )中使用,同时也为了支持 | |||
由于需要保持 Evaluator 中 ``metrics`` 对象与用户传入的 ``metrics`` 保持完全一致(方便在 ``evaluate_batch_step_fn`` )中使用,同时也为了支持 | |||
不同形式的 metric( fastNLP 的 metric/torchmetrics 等),所以 Evaluator 在进行 metric 操作的时候都调用 metrics_wrapper | |||
进行操作。 | |||
Returns: | |||
""" | |||
if self._metric_wrapper is None: | |||
self._metric_wrapper = _MetricsWrapper(self.metrics, evaluator=self) | |||
@@ -399,11 +403,12 @@ class Evaluator: | |||
def evaluate_step(self, batch): | |||
""" | |||
将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 。会将返回结果经过 output_mapping 处理后再 | |||
返回。 | |||
将 ``batch`` 传递到 model 中进行处理,根据当前 ``evaluate_fn`` 选择进行 evaluate 。会将返回结果经过 ``output_mapping`` | |||
处理后再 | |||
返回。 | |||
:param batch: {evaluate_fn} 函数支持的输入类型 | |||
:return: {evaluate_fn} 函数的输出结果,如果有设置 output_mapping ,将是 output_mapping 之后的结果。 | |||
:param batch: ``evaluate_fn`` 函数支持的输入类型 | |||
:return: ``evaluate_fn`` 函数的输出结果,如果有设置 ``output_mapping`` ,将是 ``output_mapping`` 之后的结果。 | |||
""" | |||
outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn) | |||
outputs = match_and_substitute_params(self.output_mapping, outputs) | |||
@@ -412,7 +417,7 @@ class Evaluator: | |||
@property | |||
def metrics(self): | |||
""" | |||
返回用户传入的 metrics 对象。 | |||
返回用户传入的 ``metrics`` 对象。 | |||
:return: | |||
""" | |||
@@ -13,7 +13,7 @@ class EvaluateBatchLoop(Loop): | |||
r""" | |||
``EvaluateBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的评测迭代过程; | |||
:param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn; | |||
:param batch_step_fn: 您可以传入该参数来替换默认的 ``bath_step_fn``; | |||
""" | |||
def __init__(self, batch_step_fn:Optional[Callable]=None): | |||
if batch_step_fn is not None: | |||
@@ -21,10 +21,10 @@ class EvaluateBatchLoop(Loop): | |||
def run(self, evaluator, dataloader) -> Dict: | |||
r""" | |||
需要返回在传入的 dataloader 中的 evaluation 结果 | |||
需要返回在传入的 ``dataloader`` 中的 evaluation 结果 | |||
:param evaluator: Evaluator 对象 | |||
:param dataloader: 当前需要进行评测的dataloader | |||
:param evaluator: :class:`~fastNLP.core.controllers.Evaluator` 对象 | |||
:param dataloader: 当前需要进行评测的 ``dataloader`` | |||
:return: | |||
""" | |||
iterator = iter(dataloader) | |||
@@ -58,10 +58,10 @@ class EvaluateBatchLoop(Loop): | |||
@staticmethod | |||
def batch_step_fn(evaluator, batch): | |||
r""" | |||
针对一个 batch 的数据的评测过程; | |||
针对一个 ``batch`` 的数据的评测过程; | |||
:param evaluator: Evaluator 对象 | |||
:param batch: 当前需要评测的一个 batch 的数据; | |||
:param evaluator: :class:`~fastNLP.core.controllers.Evaluator` 对象 | |||
:param batch: 当前需要评测的一个 ``batch`` 的数据; | |||
""" | |||
outputs = evaluator.evaluate_step(batch) # 将batch输入到model中得到结果 | |||
evaluator.update(batch, outputs) # evaluator将根据metric的形参名字从batch/outputs中取出对应的值进行赋值 |
@@ -1,5 +1,5 @@ | |||
r""" | |||
``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,为了在实现 fastNLP 主要功能的同时保证 fastNLP 的易用性和代码的易读性,我们只对 | |||
``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,为了在实现 **fastNLP** 主要功能的同时保证 **fastNLP** 的易用性和代码的易读性,我们只对 | |||
训练中的循环做了非常简单的抽象,``Loop`` 表示的是在训练或者评测的过程中针对单独一个 ``dataloader`` 的一个 ``epoch`` 的运算过程; | |||
更为具体的使用详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` 和 | |||
@@ -24,7 +24,7 @@ class Loop: | |||
.. note:: | |||
``Trainer`` 和 ``Evaluator`` 中都提供了方便您进行定制 ``Loop`` 的接口函数,例如 ``Trainer.train_step``,``Trainer.backward``等; | |||
``Trainer`` 和 ``Evaluator`` 中都提供了方便您进行定制 ``Loop`` 的接口函数,例如 ``Trainer.train_step``, ``Trainer.backward`` 等; | |||
在定制您自己的 ``TrainBatchLoop`` 时,请务必记得在正确的时机调用对应的 callback 函数,详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` | |||
中对于 callback 函数的调用; | |||
@@ -34,5 +34,5 @@ class Loop: | |||
@staticmethod | |||
def batch_step_fn(controller: Union["Trainer", "Evaluator"], batch): | |||
r""" | |||
对于具体的一个 batch 的数据,实现训练或者评测过程中的一步; | |||
对于具体的一个 ``batch`` 的数据,实现训练或者评测过程中的一步; | |||
""" |
@@ -14,7 +14,7 @@ class TrainBatchLoop(Loop): | |||
r""" | |||
``TrainBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的训练迭代过程; | |||
:param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn; | |||
:param batch_step_fn: 您可以传入该参数来替换默认的 ``bath_step_fn``; | |||
""" | |||
def __init__(self, batch_step_fn: Optional[Callable] = None): | |||
@@ -23,14 +23,14 @@ class TrainBatchLoop(Loop): | |||
def run(self, trainer, dataloader): | |||
r""" | |||
对传入的 dataloader 进行一个 epoch 的主要的训练的循环过程; | |||
对传入的 ``dataloader`` 进行一个 epoch 的主要的训练的循环过程; | |||
.. note:: | |||
您不需要自己主动地调用该方法,``Trainer`` 会负责调用该方法来完成训练过程; | |||
:param trainer: ``Trainer`` 实例; | |||
:param dataloader: 当前训练所使用的 dataloader; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
:param dataloader: 当前训练所使用的 ``dataloader``; | |||
""" | |||
get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\ | |||
else lambda *args, **kwargs: None | |||
@@ -68,10 +68,10 @@ class TrainBatchLoop(Loop): | |||
@staticmethod | |||
def batch_step_fn(trainer, batch): | |||
r""" | |||
针对一个 batch 的数据的训练过程; | |||
针对一个 ``batch`` 的数据的训练过程; | |||
:param trainer: ``Trainer`` 实例; | |||
:param batch: 一个 batch 的数据; | |||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||
:param batch: 一个 ``batch`` 的数据; | |||
""" | |||
outputs = trainer.train_step(batch) | |||
trainer.backward(outputs) | |||
@@ -1,7 +1,7 @@ | |||
""" | |||
``Trainer`` 是 fastNLP 用于训练模型的专门的训练器,其支持多种不同的驱动模式 ``Driver``,不仅包括最为经常使用的 DDP,而且还支持 jittor 等国产 | |||
的训练框架;新版的 fastNLP 新加入了方便的 callback 函数修饰器,并且支持定制用户自己特定的训练循环过程;通过使用该训练器,用户只需要自己实现 | |||
模型部分,而将训练层面的逻辑完全地交给 fastNLP; | |||
``Trainer`` 是 **fastNLP** 用于训练模型的专门的训练器,其支持多种不同的驱动模式 ``Driver``,不仅包括最为经常使用的 DDP,而且还支持 jittor 等国产 | |||
的训练框架;新版的 **fastNLP** 新加入了方便的 callback 函数修饰器,并且支持定制用户自己特定的训练循环过程;通过使用该训练器,用户只需要自己实现 | |||
模型部分,而将训练层面的逻辑完全地交给 **fastNLP**; | |||
""" | |||
from typing import Union, Optional, List, Callable, Dict, BinaryIO | |||
@@ -42,7 +42,7 @@ class Trainer(TrainerEventTrigger): | |||
r""" | |||
用于支持快速训练的训练器。 | |||
:param model: 训练所需要的模型,例如 ``torch.nn.Module``; | |||
:param model: 训练所需要的模型,例如 :class:`torch.nn.Module`; | |||
.. note:: | |||
@@ -55,10 +55,17 @@ class Trainer(TrainerEventTrigger): | |||
您应当使用 ``TorchDDPDriver``,意味着您需要通过 ``python -m torch.distributed.launch`` 的方式来启动训练,此时参数 ``device`` | |||
应当设置为 None(此时我们会忽略该参数),具体见下面对于参数 ``device`` 的更详细的解释。 | |||
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["auto", "torch", "paddle", "jittor", "fairscale"]。其值为 ``"auto"`` 时, | |||
**FastNLP** 会根据传入模型的类型自行判断使用哪一种模式;其值为 "torch" 时,表示使用 ``TorchSingleDriver`` 或者 ``TorchDDPDriver``; | |||
其值为 "paddle" 时,表示使用 ``PaddleSingleDriver`` 或者 ``PaddleFleetDriver``;其值为 "jittor" 时,表示使用 ``JittorSingleDriver`` | |||
或者 ``JittorMPIDriver``;其值为 "fairscale" 时,表示使用 ``FairScaleDriver``。在指定了框架的情况下,具体使用哪一种取决于参数 ``device`` 的设置; | |||
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:``["auto", "torch", "paddle", "jittor", "fairscale", "deepspeed", "oneflow"]``: | |||
1. 值为 ``"auto"`` 时,**FastNLP** 会根据传入模型的类型自行判断使用哪一种模式; | |||
2. 其值为 ``"torch"`` 时,表示使用 :class:`~fastNLP.core.drivers.TorchSingleDriver` 或者 :class:`~fastNLP.core.drivers.TorchDDPDriver`; | |||
3. 其值为 ``"paddle"`` 时,表示使用 :class:`~fastNLP.core.drivers.PaddleSingleDriver` 或者 :class:`~fastNLP.core.drivers.PaddleFleetDriver`; | |||
4. 其值为 ``"jittor"`` 时,表示使用 :class:`~fastNLP.core.drivers.JittorSingleDriver` 或者 :class:`~fastNLP.core.drivers.JittorMPIDriver`; | |||
5. 其值为 ``"fairscale"`` 时,表示使用 :class:`~fastNLP.core.drivers.FairScaleDriver`; | |||
6. 其值为 ``"deepspeed"`` 时,表示使用 :class:`~fastNLP.core.drivers.DeepSpeedDriver`; | |||
7. 其值为 ``"oneflow"`` 时,表示使用 :class:`~fastNLP.core.drivers.OneflowSingleDriver` 或者 :class:`~fastNLP.core.drivers.OneflowDDPDriver`; | |||
在指定了框架的情况下,具体使用哪一种取决于参数 ``device`` 的设置; | |||
.. warning:: | |||
@@ -66,26 +73,28 @@ class Trainer(TrainerEventTrigger): | |||
这意味着当您传入一个 ``Driver`` 实例时,您传入给 ``Trainer`` 的 ``model`` 参数将会被忽略;也就是说模型在训练时使用的真正的模型是 | |||
您传入的 ``Driver`` 实例中的模型; | |||
:param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 List 或者 Dict; | |||
:param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 :class:`List` 或者 :class:`Dict`; | |||
.. warning:: | |||
当使用分布式训练时, ``fastNLP`` 会默认将 ``dataloader`` 中的 ``Sampler`` 进行处理,以使得在一个 ``epcoh`` 中,不同卡 | |||
当使用分布式训练时, **fastNLP** 会默认将 ``dataloader`` 中的 ``Sampler`` 进行处理,以使得在一个 epoch 中,不同卡 | |||
用以训练的数据是不重叠的。如果你对 sampler 有特殊处理,那么请将 ``use_dist_sampler`` 参数设置为 ``False`` ,此刻需要由 | |||
你自身保证每张卡上所使用的数据是不同的。 | |||
:param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List; | |||
:param device: 该参数用来指定具体训练时使用的机器;注意当该参数仅当您通过 `torch.distributed.launch/run` 启动时可以为 None, | |||
此时 fastNLP 不会对模型和数据进行设备之间的移动处理,但是你可以通过参数 `input_mapping` 和 `output_mapping` 来实现设备之间 | |||
数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据 | |||
:param device: 该参数用来指定具体训练时使用的机器;注意当该参数仅当您通过 ``torch.distributed.launch/run`` 启动时可以为 ``None``, | |||
此时 fastNLP 不会对模型和数据进行设备之间的移动处理,但是你可以通过参数 ``input_mapping`` 和 ``output_mapping`` 来实现设备之间 | |||
数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也可以通过在 kwargs 添加参数 ``data_device`` 来让我们帮助您将数据 | |||
迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前自己构造 DDP 的场景); | |||
device 的可选输入如下所示: | |||
* *str*: 例如 'cpu', 'cuda', 'cuda:0', 'cuda:1', 'gpu:0' 等; | |||
* *torch.device*: 例如 'torch.device("cuda:0")'; | |||
* *int*: 将使用 ``device_id`` 为该值的 ``gpu`` 进行训练;如果值为 -1,那么默认使用全部的显卡,此时使用的 driver 实例是 `TorchDDPDriver`; | |||
* *list(int)*: 如果多于 1 个device,应当通过该种方式进行设定;注意此时我们一定会使用 ``TorchDDPDriver``,不管您传入的列表的长度是 1 还是其它值; | |||
* *str*: 例如 ``'cpu'``, ``'cuda'``, ``'cuda:0'``, ``'cuda:1'``, ``'gpu:0'`` 等; | |||
* *torch.device*: 例如 ``torch.device("cuda:0")``; | |||
* *oneflow.device*:例如 ``oneflow.device("cuda", 0)``; | |||
* *int*: 将使用 ``device_id`` 为该值的 ``gpu`` 进行训练;如果值为 -1,那么默认使用全部的显卡,此时使用的 driver 实例是 `TorchDDPDriver` 这类 | |||
执行分布式训练的 Driver | |||
* *list(int)*: 如果多于 1 个device,应当通过该种方式进行设定;注意此时我们一定会使用分布式训练的 Driver ,不管您传入的列表的长度是 1 还是其它值; | |||
* *None*: 仅当用户自己通过训练框架提供的并行训练启动脚本开启 ddp 进程时为 None; | |||
.. note:: | |||
@@ -121,7 +130,7 @@ class Trainer(TrainerEventTrigger): | |||
:param n_epochs: 训练总共的 epoch 的数量,默认为 20;也可以通过 ``n_batches`` 参数设置总共迭代多少个 ``batch`` 。 | |||
:param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | |||
为 None; | |||
为 ``None``; | |||
:param batch_step_fn: 定制每次训练时前向运行一个 batch 的数据所执行的函数。该函数应接受两个参数为 ``trainer`` 和 ``batch``, | |||
不需要要返回值;更详细的使用位置和说明请见 :meth:`~fastNLP.core.controllers.TrainBatchLoop.batch_step_fn`; | |||
:param evaluate_batch_step_fn: 定制每次验证时前向运行一个 batch 的数据所执行的函数。该函数应接受的两个参数为 ``evaluator`` 和 ``batch``, | |||
@@ -133,8 +142,8 @@ class Trainer(TrainerEventTrigger): | |||
.. note:: | |||
在 fastNLP 中,对于训练时使用的前向传播函数的查找逻辑如下所示: | |||
1. 如果 ``train_fn`` 为 None,那么在 model 的类 Model 中寻找方法 ``Model.train_step``;如果没有找到,那么默认使用 ``Model.forward``; | |||
2. 如果 ``train_fn`` 为一个字符串,例如 'my_step_fn',那么我们首先会在 model 的类 Model 中寻找方法 ``Model.my_step_fn``, | |||
1. 如果 ``train_fn`` 为 None,那么在 model 的类 Model 中寻找方法 :meth:`Model.train_step` ;如果没有找到,那么默认使用 :meth:`Model.forward`; | |||
2. 如果 ``train_fn`` 为一个字符串,例如 ``'my_step_fn'``,那么我们首先会在 model 的类 Model 中寻找方法 :meth:`Model.my_step_fn`, | |||
如果没有找到,那么会直接报错; | |||
:param evaluate_fn: 用来控制 ``Trainer`` 中内置的 ``Evaluator`` 在验证的前向传播过程中是调用模型的哪一个函数,应当为 ``None`` | |||
@@ -142,7 +151,7 @@ class Trainer(TrainerEventTrigger): | |||
:param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 ``Callback`` 类;具体可见 | |||
:class:`~fastNLP.core.callbacks.Callback`; | |||
:param metrics: 用于传给 ``Trainer`` 内部的 ``Evaluator`` 实例来进行训练过程中的验证。其应当为一个字典,其中 key 表示 monitor, | |||
例如 {"acc1": AccMetric(), "acc2": AccMetric()}; | |||
例如 ``{"acc1": AccMetric(), "acc2": AccMetric()}``; | |||
目前我们支持的 ``metric`` 的种类有以下几种: | |||
@@ -156,7 +165,7 @@ class Trainer(TrainerEventTrigger): | |||
1. 为负数时表示每隔几个 ``epoch`` evaluate 一次; | |||
2. 为正数则表示每隔几个 ``batch`` evaluate 一次; | |||
3. 为函数时表示用户自己传入的用于控制 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | |||
返回一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 ``batch`` 结束后调用该函数判断是否需要 evaluate; | |||
返回一个 bool 值,返回为 ``True`` 说明需要进行 evaluate ;将在每个 ``batch`` 结束后调用该函数判断是否需要 evaluate; | |||
.. note:: | |||
@@ -208,7 +217,7 @@ class Trainer(TrainerEventTrigger): | |||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为; | |||
1. 如果该值为 ``False``,并且当 batch 为字典时,我们会根据**前向函数**所需要的参数从 batch 中提取对应的对象,然后传入到**前向函数**中; | |||
1. 如果该值为 ``False``,并且当 batch 为字典时,我们会根据 **前向函数** 所需要的参数从 batch 中提取对应的对象,然后传入到 **前向函数** 中; | |||
2. 如果该值为 ``True``,那么我们会将 batch 直接透传给模型; | |||
.. todo:: | |||
@@ -219,8 +228,8 @@ class Trainer(TrainerEventTrigger): | |||
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 才让优化器迭代一次,默认为 1; | |||
:param fp16: 是否开启混合精度训练,默认为 False; | |||
:param monitor: 对于一些特殊的 ``Callback``,例如 :class:`~fastNLP.core.callbacks.CheckpointCallback`,它们需要参数 ``monitor`` | |||
来从 ``Evaluator`` 的验证结果中获取当前评测的值,从而来判断是否执行一些特殊的操作。例如,对于 ``CheckpointCallback`` 而言,如果我们 | |||
想要每隔一个 epoch 让 ``Evaluator`` 进行一次验证,然后保存训练以来的最好的结果;那么我们需要这样设置: | |||
来从 ``Evaluator`` 的验证结果中获取当前评测的值,从而来判断是否执行一些特殊的操作。例如,对于 :class:`~fastNLP.core.callbacks.CheckpointCallback` | |||
而言,如果我们想要每隔一个 epoch 让 ``Evaluator`` 进行一次验证,然后保存训练以来的最好的结果;那么我们需要这样设置: | |||
.. code-block:: | |||
@@ -234,7 +243,7 @@ class Trainer(TrainerEventTrigger): | |||
)] | |||
) | |||
这意味着对于 ``CheckpointCallback`` 来说,*'acc'* 就是一个监测的指标,用于在 ``Evaluator`` 验证后取出其需要监测的那个指标的值。 | |||
这意味着对于 :class:`~fastNLP.core.callbacks.CheckpointCallback` 来说,*'acc'* 就是一个监测的指标,用于在 ``Evaluator`` 验证后取出其需要监测的那个指标的值。 | |||
``Trainer`` 中的参数 ``monitor`` 的作用在于为没有设置 ``monitor`` 参数但是需要该参数的 *callback* 实例设置该值。关于 ``monitor`` | |||
参数更详细的说明,请见 :class:`~fastNLP.core.callbacks.CheckpointCallback`; | |||
@@ -247,7 +256,7 @@ class Trainer(TrainerEventTrigger): | |||
注意该参数仅当 ``Trainer`` 内置的 ``Evaluator`` 不为 None 时且有需要该参数但是没有设置该参数的 *callback* 实例才有效; | |||
:param n_batches: 总共迭代多少个 ``batch`` 的训练结束。当该值不为 -1 时,将直接忽略 ``n_epochs`` 的值。 | |||
:param overfit_batches: 使用该参数来支持 '过拟合' 的功能;支持的值为 ``-1``、``0`` 或者 大于 0 的整数,表示使用多少个 batch 的数据 | |||
:param overfit_batches: 使用该参数来支持 **'过拟合'** 的功能;支持的值为 ``-1``、``0`` 或者 大于 0 的整数,表示使用多少个 batch 的数据 | |||
来进行过拟合训练;其中 0 为表示不进行任何操作;-1 表示使用所有的数据进行训练; | |||
.. note:: | |||
@@ -258,10 +267,10 @@ class Trainer(TrainerEventTrigger): | |||
.. warning:: | |||
在使用该参数时,您同样可以指定 ``metrics`` 参数来进行简单的验证,当该参数和 ``metrics`` 同时出现时,我们会将 evaluate_dataloaders | |||
在使用该参数时,您同样可以指定 ``metrics`` 参数来进行简单的验证,当该参数和 ``metrics`` 同时出现时,我们会将 ``evaluate_dataloaders`` | |||
直接替换为在过拟合中所使用的训练数据;因此您需要保证您的 ``metrics`` 是能够在 ``train_dataloader`` 上使用的; | |||
:param marker: 用于标记一个 ``Trainer`` 实例,从而在用户调用 ``Trainer.on`` 函数时,标记该函数属于哪一个具体的 ``Trainer`` 实例;默认为 None; | |||
:param marker: 用于标记一个 ``Trainer`` 实例,从而在用户调用 ``Trainer.on`` 函数时,标记该函数属于哪一个具体的 ``Trainer`` 实例;默认为 ``None``; | |||
.. note:: | |||
@@ -283,7 +292,7 @@ class Trainer(TrainerEventTrigger): | |||
) | |||
另一点需要说明的是,如果一个被 ``Trainer.on`` 修饰的函数,其修饰时没有指明 ``marker``,那么会将该函数传给代码位于其之后的 | |||
第一个 ``Trainer`` 实例,即使该 ``Trainer`` 实例的 marker 不为 None;这一点详见 :meth:`~fastNLP.core.controllers.Trainer.on` | |||
第一个 ``Trainer`` 实例,即使该 ``Trainer`` 实例的 marker 不为 ``None``;这一点详见 :meth:`~fastNLP.core.controllers.Trainer.on` | |||
:kwargs: | |||
* *torch_kwargs* -- ``TorchDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.TorchSingleDriver` 和 | |||
@@ -291,35 +300,43 @@ class Trainer(TrainerEventTrigger): | |||
* *paddle_kwargs* -- ``PaddleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver` 和 | |||
:class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver`; | |||
* *fairscale_kwargs* -- ``FairScaleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.FairScaleDriver`; | |||
* *deepspeed_kwargs* -- ``DeepSpeedDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.DeepSpeedDriver`; | |||
* *torch_kwargs* -- ``OneflowDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.oneflow_driver.OneflowSingleDriver` 和 | |||
:class:`~fastNLP.core.drivers.oneflow_driver.OneflowDDPDriver`; | |||
* *data_device* -- 一个具体的 driver 实例中,有 ``model_device`` 和 ``data_device``,前者表示模型所在的设备,后者表示 | |||
当 ``model_device`` 为 None 时应当将数据迁移到哪个设备; | |||
.. note:: | |||
注意您在绝大部分情况下不会用到该参数! | |||
**注意您在绝大部分情况下不会用到该参数!** | |||
1. 当 driver 实例的 ``model_device`` 不为 None 时,该参数无效; | |||
2. 对于 pytorch,仅当用户自己通过 ``python -m torch.distributed.launch`` 并且自己初始化 ``init_process_group`` 时, | |||
driver 实例的 ``model_device`` 才会为 None; | |||
3. 对于 paddle,该参数无效; | |||
2. 对于 **pytorch**,仅当用户自己通过 ``python -m torch.distributed.launch`` 并且自己初始化 ``init_process_group`` 时, | |||
driver 实例的 ``model_device`` 才会为 None; | |||
2. 对于 **deepspeed**,仅当用户自己通过 ``deepspeed xxx.py`` 并且自己初始化 ``model.initialize`` 时, | |||
driver 实例的 ``model_device`` 才会为 None; | |||
3. 对于 **paddle** 和 **oneflow**,该参数无效; | |||
* *use_dist_sampler* -- 表示是否使用分布式的 ``sampler``。在多卡时,分布式 ``sampler`` 将自动决定每张卡上读取的 sample ,使得一个 epoch | |||
内所有卡的 sample 加起来为一整个数据集的 sample,同时为了保证所有卡上拥有相同数量的 sample ,有的卡上可能会有重复的 sample ,例如 | |||
8卡训练,只有9个sample,如果batch_size为1,那么第二个batch时,有7张卡将没有 sample 可用,因此只有重复使用 sample 来 pad 到第二个 | |||
batch 中。如果不希望 fastNLP 对 dataloader 的sampler 做特殊设置,请将该值设置为 False ,若确实需要分布式的训练,请在 Trainer 外 | |||
对 train_dataloader 做的数据做特殊处理使得其在不同的卡之间 sample 是 | |||
8卡训练,只有9个 sample ,如果 batch_size 为 1,那么第二个 batch 时,有7张卡将没有 sample 可用,因此只有 **重复** 使用 sample 来 pad 到第二个 | |||
batch 中。如果不希望 fastNLP 对 dataloader 的 sampler 做特殊设置,请将该值设置为 False ,若确实需要分布式的训练,请在 Trainer 外 | |||
对 ``train_dataloader`` 做的数据做特殊处理使得其在不同的卡之间 sample 是不同的。 | |||
* *evaluate_use_dist_sampler* -- 表示在 ``Evaluator`` 中在使用分布式的时候是否将保证 dataloader 的 ``sampler`` 替换为 | |||
evaluate 时使用的分布式的 ``sampler``,其特点是每个卡上的数据之间不重叠,所有卡上数据的加起来是整个数据集。若传入的 dataloader | |||
的 sampler 为 (a) 深度学习框架自带的默认 sampler ; (b) fastNLP 的 Sampler 等,则将替换为 | |||
:class:`~fastNLP.UnrepeatedSequentialSampler`,如果这个行为不是期待的,请本参数设置为 ``False``,并针对每个卡控制其可以 | |||
的 sampler 为: | |||
- 深度学习框架自带的默认 sampler ; | |||
- fastNLP 的 Sampler ; | |||
则将替换为 :class:`~fastNLP.UnrepeatedSequentialSampler`,如果这个行为不是期待的,请本参数设置为 ``False``,并针对每个卡控制其可以 | |||
用到的数据。 | |||
* *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | |||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | |||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | |||
``["all", "ignore", "only_error"]``;当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | |||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 ``"only_error"``; | |||
注意该参数仅当使用分布式的 ``driver`` 时才有效,例如 ``TorchDDPDriver``; | |||
* *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto', 'tqdm'] 或者 :class:`~fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback`等对象, | |||
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 :class:`~fastNLP.RichCallback`,否则使用 :class:`~fastNLP.RawTextCallback` 对象。如果 | |||
* *progress_bar* -- 显示进度条的方式,目前支持 ``[None, 'raw', 'rich', 'auto', 'tqdm']`` 或者 :class:`~fastNLP.RichCallback` 、 :class:`~fastNLP.RawTextCallback` 等对象, | |||
默认为 ``'auto'`` , ``'auto'`` 表示如果检测到当前 terminal 为交互型则使用 :class:`~fastNLP.RichCallback`,否则使用 :class:`~fastNLP.RawTextCallback` 对象。如果 | |||
需要定制 progress bar 的参数,例如打印频率等,可以传入 :class:`~fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback` 等对象。 | |||
* *train_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Trainer`` 中。与 input_mapping 互斥。 | |||
* *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。 | |||
@@ -331,19 +348,19 @@ class Trainer(TrainerEventTrigger): | |||
``Trainer`` 是通过在内部直接初始化一个 ``Evaluator`` 来进行验证; | |||
``Trainer`` 内部的 ``Evaluator`` 默认是 None,如果您需要在训练过程中进行验证,你需要保证这几个参数得到正确的传入: | |||
必须的参数:1. ``metrics``;2. ``evaluate_dataloaders``; | |||
必须的参数:``metrics`` 与 ``evaluate_dataloaders``; | |||
可选的其它参数:1. ``evaluate_batch_step_fn;2. ``evaluate_fn``;3. ``evaluate_every``;4. ``input_mapping``; | |||
5. ``output_mapping``; 6. ``model_wo_auto_param_call``;7. ``fp16``;8. ``monitor``;9. ``larger_better``; | |||
可选的其它参数:``evaluate_batch_step_fn``、 ``evaluate_fn``、``evaluate_every``、``input_mapping``、 | |||
``output_mapping``、``model_wo_auto_param_call``、``fp16``、``monitor``、``larger_better``; | |||
.. warning:: | |||
如果 ``Trainer`` 中内置的 ``Evaluator`` 实例不为 ``None``,那么需要注意 ``Trainer`` 中的一些参数是与 ``Evaluator`` 一致的,它们分别为: | |||
1. ``Evaluator`` 在初始化时的 ``driver`` 参数是 ``Trainer`` 中已经实例化过的 driver;这一点使得一些参数对于 ``Trainer`` 内部的 | |||
``Evaluator`` 没有用处,例如 ``device``,``torch_kwargs``,``data_device`` 和 ``output_from_new_proc`` 等; | |||
``Evaluator`` 没有用处,例如 ``device``,``torch_kwargs``,``data_device`` 和 ``output_from_new_proc`` 等; | |||
2. ``input_mapping``,``output_mapping``,``model_wo_auto_param_call`` 和 ``fp16`` 是 ``Trainer`` 和其内部默认的 | |||
``Evaluator`` 是一致的; | |||
``Evaluator`` 是一致的; | |||
当然,对于 ``input_mapping`` 和 ``output_mapping``,您可以通过添加 ``kwargs`` 中的参数 ``evaluate_input_mapping`` 和 | |||
``evaluate_output_mapping`` 来单独为 ``Evaluator`` 进行更细致的订制。 | |||
@@ -547,17 +564,17 @@ class Trainer(TrainerEventTrigger): | |||
r""" | |||
该函数是在 ``Trainer`` 初始化后用于真正开始训练的函数; | |||
注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ``CheckpointCallback`` | |||
注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 ``resume_from`` 为 ``None``,并且使用 ``CheckpointCallback`` | |||
去保存断点重训的文件; | |||
:param num_train_batch_per_epoch: 每个 epoch 训练多少个 batch 后停止,*-1* 表示使用 train_dataloader 本身的长度; | |||
:param num_eval_batch_per_dl: 每个 evaluate_dataloader 验证多少个 batch 停止,*-1* 表示使用 evaluate_dataloader 本身的长度; | |||
:param num_train_batch_per_epoch: 每个 epoch 训练多少个 batch 后停止,*-1* 表示使用 ``train_dataloader`` 本身的长度; | |||
:param num_eval_batch_per_dl: 每个 ``evaluate_dataloader`` 验证多少个 batch 停止,*-1* 表示使用 ``evaluate_dataloader`` 本身的长度; | |||
:param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 的过程是否有错误。为 0 表示不检测; | |||
:param resume_from: 从哪个路径下恢复 trainer 的状态,注意该值需要为一个文件夹,例如使用 ``CheckpointCallback`` 时帮助您创建的保存的子文件夹; | |||
:param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态;该参数如果为 ``True``, | |||
在下一次断点重训的时候我们会精确到上次训练截止的具体的 sample 进行训练;否则我们只会恢复 model 和 optimizers 的状态,而 ``Trainer`` 中的 | |||
其余状态都是保持初始化时的状态不会改变; | |||
:param catch_KeyboardInterrupt: 是否捕获 KeyboardInterrupt;如果该参数为 ``True``,在训练时如果您使用 ``ctrl+c`` 来终止程序, | |||
:param catch_KeyboardInterrupt: 是否捕获 :class:`KeyboardInterrupt`;如果该参数为 ``True``,在训练时如果您使用 ``ctrl+c`` 来终止程序, | |||
``Trainer`` 不会抛出异常,但是会提前退出,然后 ``trainer.run()`` 之后的代码会继续运行。注意该参数在您使用分布式训练的 ``Driver`` | |||
时无效,例如 ``TorchDDPDriver``;非分布式训练的 ``Driver`` 下该参数默认为 True; | |||
@@ -578,7 +595,7 @@ class Trainer(TrainerEventTrigger): | |||
整体的验证流程是否正确; | |||
``num_eval_sanity_batch`` 的作用可能会让人产生迷惑,其本质和 ``num_eval_batch_per_dl`` 作用一致,但是其只被 ``Trainer`` 使用; | |||
并且其只会在训练的一开始使用,意思为:我们在训练的开始时会先使用 ``Evaluator``(如果其不为 ``None``) 进行验证,此时验证的 batch 的 | |||
并且其只会在训练的一开始使用,意思为:我们在训练的开始时会先使用 ``Evaluator`` (如果其不为 ``None``) 进行验证,此时验证的 batch 的 | |||
数量只有 ``num_eval_sanity_batch`` 个;但是对于 ``num_eval_batch_per_dl`` 而言,其表示在实际的整体的训练过程中,每次 ``Evaluator`` | |||
进行验证时会验证的 batch 的数量。 | |||
@@ -724,7 +741,7 @@ class Trainer(TrainerEventTrigger): | |||
.. note:: | |||
对于训练一个神经网络的整体的流程来说,其可以分为很多个时间点,例如 **"整体的训练前"**,**"训练具体的一个 epoch 前"**, | |||
**"反向传播前"**,**"整体的训练结束后"**等;一个 ``callback`` 时机指的就是这些一个个具体的时间点; | |||
**"反向传播前"**,**"整体的训练结束后"** 等;一个 ``callback`` 时机指的就是这些一个个具体的时间点; | |||
该函数的参数 ``event`` 需要是一个 ``Event`` 实例,其使用方式见下方的例子; | |||
@@ -1014,10 +1031,11 @@ class Trainer(TrainerEventTrigger): | |||
r""" | |||
用于帮助您保存模型的辅助函数; | |||
:param folder: 保存模型的文件夹。如果没有传入 model_save_fn 参数,则我们会在这个文件夹下保存 fastnlp_model.pkl.tar 文件; | |||
:param only_state_dict: 仅在 model_save_fn 为空时,有效。是否只保存模型的 ``state_dict``; | |||
:param folder: 保存模型的文件夹。如果没有传入 ``model_save_fn`` 参数,则我们会在这个文件夹下保存 ``fastnlp_model.pkl.tar`` 文件; | |||
:param only_state_dict: 仅在 ``model_save_fn`` 为空时,有效。是否只保存模型的 ``state_dict``; | |||
:param model_save_fn: 您自己定制的用来替换该保存函数本身保存逻辑的函数,当您传入了该参数后,我们会实际调用该函数,而不会去调用 ``driver`` 的 ``save_model`` 函数; | |||
:param kwargs: 理论上您不需要使用到该参数; | |||
:kwargs: | |||
* *input_spec* -- 该参数详见 **PaddlePaddle** 框架的保存函数 :meth:`~fastNLP.core.drivers.PaddleDriver.save_model` 中的说明; | |||
.. note:: | |||
@@ -1056,10 +1074,10 @@ class Trainer(TrainerEventTrigger): | |||
""" | |||
用于帮助您加载模型的辅助函数; | |||
:param folder: 存放着您需要加载的 model 的文件夹,默认会尝试读取该文件夹下的 fastnlp_model.pkl.tar 文件。在 model_load_fn 不为空时, | |||
直接将该 folder 传递到 model_load_fn 中; | |||
:param only_state_dict: 要读取的文件中是否仅包含模型权重。在 ``model_load_fn 不为 None`` 时,该参数无意义; | |||
:param model_load_fn: ``callable`` 的函数,接受一个 folder 作为参数,需要注意该函数不需要返回任何内容; | |||
:param folder: 存放着您需要加载的 model 的文件夹,默认会尝试读取该文件夹下的 ``fastnlp_model.pkl.tar`` 文件。在 ``model_load_fn`` | |||
不为空时,直接将该 folder 传递到 ``model_load_fn`` 中; | |||
:param only_state_dict: 要读取的文件中是否仅包含模型权重。在 ``model_load_fn`` 不为 ``None`` 时,该参数无意义; | |||
:param model_load_fn: :class:`Callable` 的函数,接受一个 folder 作为参数,需要注意该函数不需要返回任何内容; | |||
:param kwargs: 理论上您不需要使用到该参数; | |||
.. note:: | |||
@@ -1099,12 +1117,13 @@ class Trainer(TrainerEventTrigger): | |||
用于帮助您实现断点重训功能的保存函数;保存内容包括:callback 状态、Trainer 的状态、Sampler 的状态【在恢复的时候才能恢复到特定 batch 】、 | |||
模型参数、optimizer的状态、fp16 Scaler的状态【如果有】。 | |||
:param folder: 保存在哪个文件夹下,会在该文件下声称两个文件:fastnlp_checkpoint.pkl.tar 与 fastnlp_model.pkl.tar 。 | |||
如果 model_save_fn 不为空,则没有 fastnlp_model.pkl.tar 文件; | |||
:param only_state_dict: 当 model_save_fn 为空时有效,表明是否仅保存模型的权重; | |||
:param folder: 保存在哪个文件夹下,会在该文件下生成两个文件:``fastnlp_checkpoint.pkl.tar`` 与 ``fastnlp_model.pkl.tar`` 。 | |||
如果 ``model_save_fn`` 不为空,则没有 ``fastnlp_model.pkl.tar`` 文件; | |||
:param only_state_dict: 当 ``model_save_fn`` 为空时有效,表明是否仅保存模型的权重; | |||
:param model_save_fn: 如果模型保存比较特殊,可以传入该函数自定义模型的保存过程,输入应该接受一个文件夹(实际上就是接受上面的 folder | |||
参数),不需要返回值;这意味着您可以通过该函数来自己负责模型的保存过程,而我们则会将 ``trainer`` 的状态保存好; | |||
:param kwargs: 理论上您不需要使用到该参数; | |||
:kwargs: | |||
* *input_spec* -- 该参数详见 **PaddlePaddle** 框架的保存函数 :meth:`~fastNLP.core.drivers.PaddleDriver.save_model` 中的说明; | |||
.. note:: | |||
@@ -1123,7 +1142,7 @@ class Trainer(TrainerEventTrigger): | |||
为了支持断点重训功能,我们会在调用该函数时保存以下内容: | |||
1. 各个 ``callback`` 的状态,这主要涉及到一些带有运行状态的 ``callback``; | |||
2. 控制训练流程的变量 ``trainer_state``,具体详见 :class:`~fastNLP.core.controllers.utils.states.TrainerState`; | |||
2. 控制训练流程的变量 ``trainer_state``,具体详见 :class:`~fastNLP.core.controllers.utils.state.TrainerState`; | |||
3. 一个特殊的变量 ``num_consumed_batches``,表示在这次训练过程中总共训练了多少个 batch 的数据;您不需要关心这个变量; | |||
4. sampler 的状态,为了支持断点重训功能,我们会在 trainer 初始化的时候,将您的 ``trainer_dataloader`` 的 ``sampler`` 替换为 | |||
我们专门用于断点重训功能的 ``ReproducibleSampler``,详见 :class:`~fastNLP.core.samplers.reproducible_sampler.ReproducibleSampler`; | |||
@@ -1335,6 +1354,11 @@ class Trainer(TrainerEventTrigger): | |||
用于在使用梯度累积并且进行分布式训练时,由于在前 ``accumulation_steps - 1`` 的时间内不需要进行梯度的同步,因此通过使用该 context 上下文 | |||
环境来避免梯度的同步; | |||
.. note:: | |||
部分深度学习框架的梯度累积并不需要通过提供上下文环境实现,关于这点需要您深入了解您正在使用的框架的机制;而对于这些框架,fastNLP 会返回一个 | |||
空的上下文环境。 | |||
:return: 一个支持 ``no_sync`` 的 ``context``; | |||
""" | |||
@@ -1420,7 +1444,7 @@ class Trainer(TrainerEventTrigger): | |||
def model_device(self): | |||
r""" | |||
:return: 返回当前模型所在的设备;注意该值在当且仅当在少数情况下为 ``None``,例如当使用 ``pytorch`` 时,仅当用户自己初始化 ``init_progress_group`` 时 | |||
``model_device`` 才为 None; | |||
``model_device`` 才为 None; | |||
""" | |||
return self.driver.model_device | |||
@@ -42,7 +42,7 @@ class State(dict): | |||
class TrainerState: | |||
r""" | |||
该类用于我们 fastNLP 自己内部为了训练流程所记录的一些状态,当然是要暴露给用户给用户使用的; | |||
我们保存的state大部分上是 trainer 断点重训 需要重新加载的; | |||
我们保存的 state 大部分上是 trainer 断点重训 需要重新加载的; | |||
专属于 `Trainer` 的状态记载的类; | |||
:param n_epochs: 训练过程中总共的 epoch 的数量; | |||
@@ -50,7 +50,7 @@ class TrainerState: | |||
:param global_forward_batches: 当前模型总共 forward 了多少个 step; | |||
:param batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | |||
:param num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step; | |||
:param n_batches: 完整训练过程会 forward 的 step 数量,注意 n_batches = n_batches * n_epochs; | |||
:param n_batches: 完整训练过程会 forward 的 step 数量,注意 ``n_batches = num_batches_per_epoch * n_epochs`` ; | |||
""" | |||
n_epochs: Optional[int] = None # 无论如何重新算 | |||
@@ -73,6 +73,7 @@ class TrainerState: | |||
def load_state_dict(self, state_dict: Dict): | |||
r""" | |||
用于断点重训来重新加载保存的状态字典; | |||
:param state_dict: 用于加载的状态字典; | |||
""" | |||
for key in state_dict: | |||
@@ -4,11 +4,12 @@ from fastNLP.core.callbacks import CallbackManager | |||
from .state import TrainerState | |||
from fastNLP.core.utils.utils import _check_valid_parameters_number | |||
__all__ = [] | |||
class TrainerEventTrigger: | |||
r""" | |||
为了避免在训练流程中调用 callback 函数中写成类似 'trainer.callback_manager.on_train_begin' 的形式,我们选择单独抽象为 'Trainer' | |||
抽象一层,然后一些特殊的操作可以在这里进行,例如我们通过 `on_validate_end` 来通知所有的 'CheckpointCallback' 实例在当前的 step 后保存 | |||
为了避免在训练流程中调用 callback 函数中写成类似 `'trainer.callback_manager.on_train_begin'` 的形式,我们选择单独为 ``Trainer`` | |||
抽象一层,然后一些特殊的操作可以在这里进行,例如我们通过 :meth:`on_validate_end` 来通知所有的 ``CheckpointCallback`` 实例在当前的 step 后保存 | |||
模型。 | |||
""" | |||
callback_manager: CallbackManager | |||
@@ -138,7 +139,7 @@ def check_evaluate_every(evaluate_every): | |||
``evaluate_every`` 的使用详见 ``Trainer`` 的 ``evaluate_every`` 参数; | |||
主要在于当参数 ``evaluate_every`` 是一个 callable 的函数时,需要保证其参数的正确性; | |||
主要在于当参数 ``evaluate_every`` 是一个 Callable 的函数时,需要保证其参数的正确性; | |||
""" | |||
if not callable(evaluate_every) and (not isinstance(evaluate_every, int) or evaluate_every == 0): | |||
raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") | |||
@@ -47,38 +47,35 @@ class JittorDataLoader: | |||
* callate_fn 为 ``'auto'`` 时,``JittorDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。 | |||
此时可以配套使用 ``JittorDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* callate_fn 为 ``None`` 时, ``JittorDataLoader`` 默认使用 Jittor DataLoader 自带的 collate_fn | |||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||
:param shuffle: 是否打乱数据集, 默认为 ``False``。 | |||
:param drop_last: 当 ``drop_last=True`` 时,``JittorDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``JittorDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
:param buffer_size: 每个进程占用的内存空间,默认为 512M。主要是配合 ``num_workers`` 使用,用户可以自定义每个进程的内存大小。 | |||
:param stop_grad: 是否不使用梯度, 默认 ``True`` 。 | |||
:param keep_numpy_array: 返回的数据是 ``np.array`` 类型而不是 ``jittor.Var`` 类型,默认为 ``False`` | |||
:param endless: 是否让 ``JittorDataLoader`` 无限返回数据,也就是将 dataset 循环使用使得返回数据是没有限制的。默认为 ``False``. | |||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||
``JittorDataLoader`` 调用默认的 Jittor 框架的 ``DataLoader`` 自带的 ``collate_batch`` 作为 callate_fn 的默认值, 其无法处理 | |||
:class:`~fastNLP.core.dataset.DataSet` 的 dataset 对象。 | |||
* callate_fn 为 ``'auto'`` 时,``JittorDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||
此时可以配套使用 ``JittorDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
""" | |||
def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False, | |||
drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, | |||
stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, | |||
collate_fn: Union[None, str, Callable] = "auto") -> None: | |||
""" | |||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||
:param shuffle: 是否打乱数据集, 默认为 ``False``。 | |||
:param drop_last: 当 ``drop_last=True`` 时,``JittorDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``JittorDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
:param buffer_size: 每个进程占用的内存空间,默认为512M。主要是配合num_workers使用,用户可以自定义每个进程的内存大小。 | |||
:param stop_grad: 是否不使用梯度, 默认 ``True`` 。 | |||
:param keep_numpy_array: 返回的数据是 ``np.array`` 类型而不是 ``jittor.Var`` 类型,默认为 ``False`` | |||
:param endless: 是否让 ``JittorDataLoader`` 无限返回数据,也就是将 dataset 循环使用使得返回数据是没有限制的。默认为 ``False``. | |||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||
``JittorDataLoader`` 调用默认的 Jittor 框架的 ``DataLoader`` 自带的 ``collate_batch`` 作为 callate_fn 的默认值, 其无法处理 | |||
:class:`~fastNLP.core.dataset.DataSet` 的 dataset 对象。 | |||
* callate_fn 为 ``'auto'`` 时,``JittorDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||
此时可以配套使用 ``JittorDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
""" | |||
# TODO 验证支持replacesampler (以后完成) 增加Sampler | |||
# 将内部dataset批次设置为1 | |||
if isinstance(dataset, Dataset): | |||
@@ -136,20 +133,20 @@ class JittorDataLoader: | |||
""" | |||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||
:param field_name: 需要调整的 field 的名称。如果 :class:`~fastNLP.core.Dataset` 的 :class:`~fastNLP.core.Dataset.__getitem__` | |||
方法返回的是 dict 类型的,则可以直接使用对应的 field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 | |||
``{'a': {'b': 1}}`` 中的使用 ``('a', 'b')`` 如果 ``__getitem__`` 返回的是 Sequence 类型的,则可以使用 *_0*, *_1* 表示序列中 | |||
第 **0** 或 **1** 个元素。如果该 field 在数据中没有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 ``_single`` 。 | |||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 | |||
无意义。 | |||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||
:param backend: 可选 ``['raw', 'numpy', 'Jittor', 'paddle', 'jittor', 'auto']`` ,分别代表,输出为 ``list`` , ``numpy.ndarray`` , | |||
``Jittor.Tensor`` , ``paddle.Tensor`` , ``jittor.Var`` 类型。若 ``pad_val`` 为 ``None`` ,该值无意义 。 | |||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||
形式,输出将被直接作为结果输出。 | |||
:return: 返回 Collator 自身 | |||
:param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||
如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 ``"_single"`` 。 | |||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, | |||
该值无意义。 | |||
:param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 | |||
:param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, | |||
:class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 | |||
若 ``pad_val`` 为 ``None`` ,该值无意义 。 | |||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 | |||
batch 形式。 collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 | |||
:return: 返回使用的 collator | |||
""" | |||
collator = self._get_collator() | |||
if isinstance(collator, Collator): | |||
@@ -173,16 +170,14 @@ class JittorDataLoader: | |||
def set_ignore(self, *field_names) -> Collator: | |||
""" | |||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||
Example:: | |||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略:: | |||
collator.set_ignore('field1', 'field2') | |||
dataloader.set_ignore('field1', 'field2') | |||
:param field_name: 需要调整的 field 的名称。如果 :class:`~fastNLP.core.Dataset` 的 :class:`~fastNLP.core.Dataset.__getitem__` | |||
方法返回的是 dict 类型的,则可以直接使用对应的 field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 | |||
``{'a': {'b': 1}}`` 中的使用 ``('a', 'b')`` 如果 ``__getitem__`` 返回的是 Sequence 类型的,则可以使用 *_0*, *_1* 表示序列中 | |||
第 **0** 或 **1** 个元素。 | |||
:return: 返回 Collator 自身 | |||
:param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||
:return: 返回使用的 collator | |||
""" | |||
collator = self._get_collator() | |||
if isinstance(collator, Collator): | |||
@@ -193,9 +188,9 @@ class JittorDataLoader: | |||
def get_batch_indices(self) -> List[int]: | |||
""" | |||
获取当前 batch 的 idx | |||
获取当前 ``batch`` 中每条数据对应的索引。 | |||
:return: | |||
:return: 当前 ``batch`` 数据的索引; | |||
""" | |||
return self.cur_batch_indices | |||
@@ -208,37 +203,37 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = No | |||
-> Union[Dict[str, JittorDataLoader], JittorDataLoader]: | |||
""" | |||
``prepare_jittor_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 :class:`JittorDataLoader` 对象, 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`。 | |||
根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||
* 当 ds_or_db 为 ``DataSet`` 时,``prepare_jittor_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | |||
帮你实例化一个 :class:`JittorDataLoader` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`。 | |||
* 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_Jittor_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value | |||
来创建不同的 :class:`JittorDataLoader` 对象;当 key 中包含'train'字符串时,``prepare_jittor_dataloader`` 默认该 value 为 train 数据集, | |||
会将 batch_size 和 sampler 作为参数,其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。 | |||
最终根据 ``key: JittorDataLoader`` 组成 ``Dict[key, JittorDataLoader]`` 的字典返回。 | |||
根据 ``ds_or_db`` 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||
* 当 ds_or_db 为 :class:`~fastNLP.io.DataSet` 时,``prepare_jittor_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | |||
帮你实例化一个 :class:`JittorDataLoader` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`; | |||
* 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_jittor_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value | |||
来创建不同的 :class:`JittorDataLoader` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_jittor_dataloader`` 默认该 value 为训练数据集, | |||
会将 ``batch_size`` 和 ``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。 | |||
最终根据 ``key: JittorDataLoader`` 组成 ``Dict[key, JittorDataLoader]`` 的字典返回; | |||
* 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_jittor_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的 | |||
:class:`JittorDataLoader` 对象;当 key 中包含'train'字符串时,``prepare_Jittor_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | |||
其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: JittorDataLoader`` 组成 | |||
``Dict[key, JittorDataLoader]`` 的字典返回。 | |||
:class:`JittorDataLoader` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_jittor_dataloader`` 默认该 value 为训练数据集,会将 ``batch_size`` 和 | |||
``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。最终根据 ``key: JittorDataLoader`` 组成 | |||
``Dict[key, JittorDataLoader]`` 的字典返回; | |||
:param ds_or_db: 可以有以下三种取值, | |||
:param ds_or_db: 可以有以下三种取值: | |||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 | |||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 | |||
* ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为:class:`~fastNLP.TorchDataLoader` | |||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; | |||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; | |||
* ds_or_db 为实现了 :meth:`__getitem__` 和 :meth:`__len__` 的对象 ,返回值为 :class:`~fastNLP.core.dataloaders.JittorDataLoader`; | |||
:param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 :class:`Dict` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数 | |||
设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 默认为 ``16``。 | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 'train' 则设置其 shuffle 为 True , | |||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 True , | |||
其它的为 False 。 | |||
:param drop_last: 当 ``drop_last=True`` 时,:class:`JittorDataLoader` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, :class:`JittorDataLoader` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
:param buffer_size: 每个进程占用的内存空间,默认为512M。主要是配合num_workers使用,用户可以自定义每个进程的内存大小。 | |||
:param buffer_size: 每个进程占用的内存空间,默认为512M。主要是配合 ``num_workers`` 使用,用户可以自定义每个进程的内存大小。 | |||
:param stop_grad: 是否不使用梯度, 默认 ``True`` 。 | |||
:param keep_numpy_array: 返回的数据是 ``np.array`` 类型而不是 ``jittor.Var`` 类型,默认为 ``False`` | |||
:param keep_numpy_array: 返回的数据是 :class:`np.array` 类型而不是 :class:`ittor.Var` 类型,默认为 ``False`` | |||
:param endless: 是否让 :class:`JittorDataLoader` 无限返回数据,也就是将 dataset 循环使用使得返回数据是没有限制的。默认为 ``False``. | |||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||
@@ -247,11 +242,8 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = No | |||
:class:`~fastNLP.core.dataset.DataSet` 的 dataset 对象。 | |||
* callate_fn 为 ``'auto'`` 时,:class:`JittorDataLoader` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||
此时可以配套使用 :class:`JittorDataLoader` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
:return: 返回数据类型为 :class:`Dict[str, JittorDataLoader]`, :class:`JittorDataLoader` 其中之一,根据输入 | |||
``ds_or_db`` 变化而变化。 | |||
""" | |||
from fastNLP.io.data_bundle import DataBundle | |||
@@ -57,10 +57,42 @@ class OneflowDataLoader(DataLoader): | |||
* callate_fn 为 ``'auto'`` 时,``OneflowDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。 | |||
此时可以配套使用 ``OneflowDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* callate_fn 为 ``None`` 时, ``OneflowDataLoadr`` 默认使用 oneflow DataLoader 自带的 collate_fn | |||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
* callate_fn 为 ``None`` 时, ``OneflowDataLoadr`` 默认使用 :class:`oneflow.utils.data.DataLoader` 自带的 collate_fn | |||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||
:param non_train_batch_size: 非训练数据集的 ``OneflowDataLoader`` 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 ``None`` 有效。 | |||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , | |||
其它的为 False 。 | |||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||
默认为 ``None``, 当其不为 ``None`` 时, shuffle 参数无效。 | |||
:param non_train_sampler: 非训练数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||
默认为None, 当其不为 None 时, shuffle 参数无效。 | |||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||
dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``sampler``, ``shuffle`` 参数均失效。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``OneflowDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 | |||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||
``OneflowDataLoader`` 调用默认的 oneflow 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||
* callate_fn 为 ``'auto'`` 时,``OneflowDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||
此时可以配套使用 ``OneflowDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
:param pin_memory: 如果其为 ``True``, 那么 ``OneflowDataLoader`` 会在返回数据张量之前将其 copy 到 cuda 的 pin memory 中。 | |||
:param drop_last: 当 ``drop_last=True`` 时,``OneflowDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
:param timeout: 子进程的输出队列获取数据的超时值 | |||
:param worker_init_fn: init 函数,如果不设置为 ``None``,则将会在每个子进程初始化时调用该函数。 | |||
:param multiprocessing_context: 多进程的上下文环境 | |||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` | |||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2`` 意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . | |||
:param persistent_workers: 如果其为 ``True``, ``OneflowDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||
""" | |||
def __init__(self, dataset, batch_size: int = 16, | |||
@@ -70,38 +102,7 @@ class OneflowDataLoader(DataLoader): | |||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | |||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | |||
persistent_workers: bool = False, **kwargs) -> None: | |||
""" | |||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||
:param shuffle: 是否打乱数据集, 默认为 ``False``。 | |||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||
默认为None, 当其不为 None 时, shuffle 参数无效。 | |||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||
dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, sampler, shuffle 参数均失效。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``OneflowDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||
``OneflowDataLoader`` 调用默认的 oneflow 框架的 ``DataLoader`` 自带的 ``default_collate_fn`` 作为 callate_fn 的默认值, 其无法处理 | |||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||
* callate_fn 为 ``'auto'`` 时,``OneflowDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||
此时可以配套使用 ``OneflowDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
:param pin_memory: 如果其为 ``True``, 那么 ``OneflowDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。 | |||
:param drop_last: 当 ``drop_last=True`` 时,``OneflowDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
:param timeout: 子进程的输出队列获取数据的超时值 | |||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 | |||
:param multiprocessing_context: 多进程的上下文环境 | |||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` | |||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2``意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . | |||
:param persistent_workers: 如果其为 ``True``, ``OneflowDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||
""" | |||
if isinstance(dataset, DataSet) and collate_fn is None: | |||
raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | |||
@@ -153,20 +154,20 @@ class OneflowDataLoader(DataLoader): | |||
""" | |||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 | |||
无意义。 | |||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto'],分别代表,输出为 list, numpy.ndarray, | |||
torch.Tensor, paddle.Tensor, jittor.Var, oneflow.Tensor 类型。若 pad_val 为 None ,该值无意义 。 | |||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||
形式,输出将被直接作为结果输出。 | |||
:return: 返回 Collator | |||
:param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||
如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。 | |||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, | |||
该值无意义。 | |||
:param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 | |||
:param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, | |||
:class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 | |||
若 ``pad_val`` 为 ``None`` ,该值无意义 。 | |||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 | |||
batch 形式。 collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 | |||
:return: 返回使用的 collator | |||
""" | |||
collator = self._get_collator() | |||
if isinstance(collator, Collator): | |||
@@ -190,15 +191,14 @@ class OneflowDataLoader(DataLoader): | |||
def set_ignore(self, *field_names) -> Collator: | |||
""" | |||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||
Example:: | |||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略:: | |||
collator.set_ignore('field1', 'field2') | |||
dataloader.set_ignore('field1', 'field2') | |||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 | |||
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||
:return: 返回 Collator 自身 | |||
:param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||
:return: 返回使用的 collator | |||
""" | |||
collator = self._get_collator() | |||
if isinstance(collator, Collator): | |||
@@ -230,56 +230,56 @@ def prepare_oneflow_dataloader(ds_or_db, | |||
non_train_batch_size: int = None) \ | |||
-> Union[OneflowDataLoader, Dict[str, OneflowDataLoader]]: | |||
""" | |||
``prepare_oneflow_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``OneflowDataloader``对象, 详见 :class:`~fastNLP.OneflowDataLoader`。 | |||
``prepare_oneflow_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``OneflowDataloader`` 对象, 详见 :class:`~fastNLP.OneflowDataLoader`。 | |||
根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||
* 当 ds_or_db 为 ``DataSet``时,``prepare_oneflow_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | |||
帮你实例化一个 ``OneflowDataLoader`` 对象并返回该对象。 详见:class:`~fastNLP.core.dataloaders.OneflowDataLoader`。 | |||
* 当 ds_or_db 为 ``DataSet`` 时,``prepare_oneflow_dataloader`` 会将使用的除了 ``non_train_batch_size`` 和 ``non_train_sampler`` 以外的参数来 | |||
帮你实例化一个 ``OneflowDataLoader`` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.OneflowDataLoader`。 | |||
* 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_oneflow_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value | |||
来创建不同的 ``OneflowDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_oneflow_dataloader`` 默认该 value 为 train 数据集, | |||
会将 batch_size 和 sampler 作为参数,其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。 | |||
最终根据 ``key: OneflowDataLoader`` 组成 ``Dict[key, OneflowDataLoader]`` 的字典返回。 | |||
来创建不同的 ``OneflowDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_oneflow_dataloader`` 默认该 value 为训练数据集, | |||
会将 ``batch_size`` 和 ``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。 | |||
最终根据 ``key: OneflowDataLoader`` 组成 ``Dict[key, OneflowDataLoader]`` 的字典返回。 | |||
* 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_oneflow_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的 | |||
``OneflowDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_oneflow_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | |||
其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: OneflowDataLoader`` 组成 | |||
``Dict[key, OneflowDataLoader]`` 的字典返回。 | |||
``OneflowDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_oneflow_dataloader`` 默认该 value 为训练数据集,会将 ``batch_size`` 和 ``sampler`` 作为参数, | |||
其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。最终根据 ``key: OneflowDataLoader`` 组成 | |||
``Dict[key, OneflowDataLoader]`` 的字典返回。 | |||
:param ds_or_db: 可以有以下三种取值, | |||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, OneflowDataLoader]`` 的字典 | |||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, OneflowDataLoader]`` 的字典 | |||
* ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为:class:`~fastNLP.OneflowDataLoader` | |||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, OneflowDataLoader]`` 的字典; | |||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, OneflowDataLoader]`` 的字典; | |||
* ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为 :class:`~fastNLP.core.dataloaders.OneflowDataLoader`; | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||
:param non_train_batch_size: 非 'train' 数据集的 ``OneflowDataLoader`` 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 'train' 则设置其 shuffle 为 True , | |||
:param non_train_batch_size: 非训练数据集的 ``OneflowDataLoader`` 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 ``None`` 有效。 | |||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , | |||
其它的为 False 。 | |||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||
默认为 ``None``, 当其不为 ``None`` 时, shuffle 参数无效。 | |||
:param non_train_sampler: 非训练数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||
默认为None, 当其不为 None 时, shuffle 参数无效。 | |||
:param non_train_sampler: 非 'train' 数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||
默认为None, 当其不为 None 时, shuffle 参数无效。 | |||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||
dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, sampler, shuffle 参数均失效。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``OneflowDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List 中的值为 | |||
dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``sampler``, ``shuffle`` 参数均失效。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``OneflowDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 | |||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||
* callate_fn 为 'None' 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||
``OneflowDataLoader`` 调用默认的 oneflow 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||
* callate_fn 为 ``'auto'`` 时,`OneflowDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||
此时可以配套使用 ``OneflowDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* `collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||
``OneflowDataLoader`` 调用默认的 oneflow 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||
* callate_fn 为 ``'auto'`` 时,``OneflowDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||
此时可以配套使用 ``OneflowDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
:param pin_memory: 如果其为 ``True``, 那么 ``OneflowDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。 | |||
:param pin_memory: 如果其为 ``True``, 那么 ``OneflowDataLoader`` 会在返回数据张量之前将其 copy 到 cuda 的 pin memory 中。 | |||
:param drop_last: 当 ``drop_last=True`` 时,``OneflowDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
:param timeout: 子进程的输出队列获取数据的超时值 | |||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 | |||
:param worker_init_fn: init 函数,如果不设置为 ``None``,则将会在每个子进程初始化时调用该函数。 | |||
:param multiprocessing_context: 多进程的上下文环境 | |||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个``base_seed`` | |||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2``意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . | |||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` | |||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。 ``2`` 意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` 。 | |||
:param persistent_workers: 如果其为 ``True``, ``OneflowDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||
""" | |||
@@ -1,6 +1,6 @@ | |||
__all__ = [ | |||
'PaddleDataLoader', | |||
'prepare_paddle_dataloader', | |||
'PaddleDataLoader' | |||
] | |||
from .fdl import PaddleDataLoader, prepare_paddle_dataloader |
@@ -53,6 +53,7 @@ class PaddleDataLoader(DataLoader): | |||
1. ``PaddleDataLoader`` 支持输入的 dataset 是无框架的,只要实现了 __getitem__() 和 __len__() 的对象即可, | |||
当不使用 :class:`~fastNLP.core.dataset.DataSet` 时也不需要传入 collate_fn, 只要只需要将 ``collate_fn='auto'`` 就能够自动 | |||
探测数据的类型并判断能否 pad 。此时可以调用 ``set_pad`` 和 ``set_ignore`` 方法来设置 field 的 pad_val 或者忽略某个 field 的 pad 操作。 | |||
Example:: | |||
from fastNLP import PaddleDataLoader | |||
@@ -76,9 +77,46 @@ class PaddleDataLoader(DataLoader): | |||
.. note:: | |||
当传入的dataset为fastNLP的DataSet时,collate_fn不能为None。默认可以是"auto"或者自定义callable函数。 | |||
3. 当 collate_fn 为 ``Callable`` 时,该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
3. 当 collate_fn 为 :class:`Callable` 时,该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||
:param feed_list: feed Tensor list. | |||
这个张量能被 ``paddle.static.data`` 创建。 如果 :attr:`return_list` 是 ``False``, 那么 :attr:`feed_list` | |||
应该被设置。 默认为 ``None `` 。 | |||
:param places: 将数据放进的一个 list 的 place。 :attr:`places` 能为 None. | |||
如果 :attr:`places` 为 None, 默认放在 CPUPlace 或者 CUDAPlace(0) 设备上。 如果 ``places`` 是一个 list 类型的 字符串, 那么字符串 | |||
可以是 ``cpu`` , ``gpu:x`` 或者 ``gpu_pinned`` , 其中 ``x`` 是 gpu 的下标。 | |||
:param return_list: 每个设备上的返回值是否为以列表形式显示。 如果 :attr:`return_list=False`, | |||
每个设备上的返回值值为 str -> Tensor 的 dict, 其中 dict 的 key 为每个 fed Tensors 的名字。 | |||
如果 :attr:`return_list=True`, 每个设备上的返回值值为 list(Tensor)。 :attr:`return_list` 只能在动态图情况下设置为 ``True`` . | |||
默认值为 ``True`` 。 | |||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||
dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``shuffle`` 参数均失效。 | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 None 有效。 | |||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , | |||
其它的为 False 。 | |||
:param drop_last: 当 ``drop_last=True`` 时,``PaddleDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||
``PaddleDataLoader`` 调用默认的 Paddle 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||
* callate_fn 为 ``'auto'`` 时,``PaddleDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||
此时可以配套使用 ``PaddleDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``PaddleDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 | |||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
:param use_buffer_reader: 是否开启 buffer_reader 。如果 ``use_buffer_reader=True`` ,那么 ``PaddleDataLoader`` 会异步地预取下一个 batch 的 | |||
数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是 ``True``。 | |||
:param use_shared_memory: 是否使用共享内存。当 ``use_shared_memory=True`` 时,将采用共享内存来加快将数据放进进程队列。建议仅当计算机上的 | |||
共享空间足够大时。(例如 Linux 上的 /dev/shm/ 空间足够大)共享内存仅在多进程模式( ``num_workers>0`` )下生效。 | |||
:param timeout: 从子进程的输出队列获取数据的超时值 | |||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 | |||
:param persistent_workers: 如果其为 ``True``, ``PaddleDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||
""" | |||
def __init__(self, dataset, feed_list=None, places=None, | |||
@@ -88,45 +126,7 @@ class PaddleDataLoader(DataLoader): | |||
num_workers: int = 0, use_buffer_reader: bool = True, | |||
use_shared_memory: bool = True, timeout: int = 0, | |||
worker_init_fn: Callable = None, persistent_workers=False) -> None: | |||
""" | |||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||
:param feed_list: feed Tensor list。 | |||
这个张量能被 :code:`paddle.static.data()` 创建。 如果 :attr:`return_list` 是 ``False``, 那么 :attr:`feed_list` | |||
应该被设置。 默认为 ``None`` | |||
:param places: 将数据放进的一个 list 的 place。 :attr:`places` 能为 None。 | |||
如果 :attr:`places` 为 None, 默认放在 CPUPlace 或者 CUDAPlace(0) 设备上。 如果 ``places`` 是一个 list 类型的 字符串, 那么字符串 | |||
可以是 ``cpu`` , ``gpu:x`` 或者 ``gpu_pinned`` , 其中 ``x`` 是 gpu 的下标。 | |||
:param return_list: 每个设备上的返回值是否为以列表形式显示。 如果 :attr:`return_list=False`, 每个设备上的返回值值为 str -> Tensor 的 dict, | |||
其中 dict 的 key 为每个 fed Tensors 的名字。如果 :attr:`return_list` 为 ``True`` , 每个设备上的返回值值为 list(Tensor)。 :attr:`return_list` | |||
只能在动态图情况下设置为 ``True`` 。默认值为 ``True`` 。 | |||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||
dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, shuffle 参数均失效。 | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||
:param shuffle: 是否将数据打乱,若``shuffle=True`` 则会将dataset打乱;若否则什么也不做。 | |||
:param drop_last: 当 ``drop_last=True`` 时,``PaddleDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||
``PaddleDataLoader`` 调用默认的 Paddle 框架的 ``DataLoader`` 自带的 ``default_collate_fn`` 作为 callate_fn 的默认值, 其无法处理 | |||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||
* callate_fn 为 ``'auto'`` 时,``PaddleDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||
此时可以配套使用 ``PaddleDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``PaddleDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
:param use_buffer_reader: 是否开启 buffer_reader 。如果 ``use_buffer_reader=True`` ,那么 ``PaddleDataLoader`` 会异步地预取下一个 batch 的 | |||
数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是 ``True``。 | |||
:param use_shared_memory: 是否使用共享内存。当 ``use_shared_memory=True`` 时,将采用共享内存来加快将数据放进进程队列。建议仅当计算机上的 | |||
共享空间足够大时。(例如 Linux 上的 /dev/shm/ 空间足够大)共享内存仅在多进程模式( num_workers>0 )下生效。 | |||
:param timeout: 从子进程的输出队列获取数据的超时值 | |||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 | |||
:param persistent_workers: 如果其为 ``True``, ``PaddleDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||
""" | |||
# FastNLP Datset, collate_fn not None | |||
if isinstance(dataset, FDataSet) and collate_fn is None: | |||
raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | |||
@@ -186,20 +186,20 @@ class PaddleDataLoader(DataLoader): | |||
""" | |||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 | |||
无意义。 | |||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto'],分别代表,输出为 list, numpy.ndarray, | |||
torch.Tensor, paddle.Tensor, jittor.Var, oneflow.Tensor 类型。若 pad_val 为 None ,该值无意义 。 | |||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||
形式,输出将被直接作为结果输出。 | |||
:return: 返回 Collator 自身 | |||
:param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||
如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。 | |||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, | |||
该值无意义。 | |||
:param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 | |||
:param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, | |||
:class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 | |||
若 ``pad_val`` 为 ``None`` ,该值无意义 。 | |||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 | |||
batch 形式。 collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 | |||
:return: 返回使用的 collator | |||
""" | |||
collator = self._get_collator() | |||
if isinstance(collator, Collator): | |||
@@ -223,15 +223,14 @@ class PaddleDataLoader(DataLoader): | |||
def set_ignore(self, *field_names) -> Collator: | |||
""" | |||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||
Example:: | |||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略:: | |||
collator.set_ignore('field1', 'field2') | |||
dataloader.set_ignore('field1', 'field2') | |||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 | |||
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||
:return: 返回 Collator 自身 | |||
:param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||
:return: 返回使用的 collator | |||
""" | |||
collator = self._get_collator() | |||
if isinstance(collator, Collator): | |||
@@ -260,59 +259,59 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||
non_train_batch_size: int = None) \ | |||
-> Union[Dict[str, PaddleDataLoader], PaddleDataLoader]: | |||
""" | |||
``prepare_paddle_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``PaddleDataloader``对象, 详见 :class:`~fastNLP.PaddleDataLoader`。 | |||
``prepare_paddle_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``PaddleDataloader`` 对象, 详见 :class:`~fastNLP.PaddleDataLoader`。 | |||
根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||
* 当 ds_or_db 为 ``DataSet``时,``prepare_paddle_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | |||
帮你实例化一个 ``PaddleDataLoader`` 对象并返回该对象。 详见:class:`~fastNLP.core.dataloaders.PaddleDataLoader`。 | |||
* 当 ds_or_db 为 ``DataSet`` 时,``prepare_paddle_dataloader`` 会将除了 ``non_train_batch_size`` 和 ``non_train_sampler`` 以外的参数来 | |||
帮你实例化一个 ``PaddleDataLoader`` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.PaddleDataLoader`。 | |||
* 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_paddle_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value | |||
来创建不同的 ``PaddleDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_Paddle_dataloader`` 默认该 value 为 train 数据集, | |||
会将 batch_size 和 sampler 作为参数,其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。 | |||
最终根据 ``key: PaddleDataLoader`` 组成 ``Dict[key, PaddleDataLoader]`` 的字典返回。 | |||
来创建不同的 ``PaddleDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_Paddle_dataloader`` 默认该 value 为训练数据集, | |||
会将 ``batch_size`` 和 ``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。 | |||
最终根据 ``key: PaddleDataLoader`` 组成 ``Dict[key, PaddleDataLoader]`` 的字典返回。 | |||
* 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_paddle_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的 | |||
``PaddleDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_paddle_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | |||
其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: PaddleDataLoader`` 组成 | |||
``Dict[key, PaddleDataLoader]`` 的字典返回。 | |||
``PaddleDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_paddle_dataloader`` 默认该 value 为训练数据集,会将 ``batch_size`` 和 ``sampler`` 作为参数, | |||
其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。最终根据 ``key: PaddleDataLoader`` 组成 | |||
``Dict[key, PaddleDataLoader]`` 的字典返回。 | |||
:param ds_or_db: 可以有以下三种取值, | |||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 | |||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 | |||
* ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为:class:`~fastNLP.TorchDataLoader` | |||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; | |||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; | |||
* ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为 :class:`~fastNLP.TorchDataLoader`; | |||
:param feed_list: (list(Tensor)|tuple(Tensor)): feed Tensor list. | |||
这个张量能被 :code:`paddle.static.data()` 创建。 如果:attr:`return_list` 是 ``False``, 那么 :attr:`feed_list` | |||
应该被设置。 默认为 ``None `` | |||
:param places: (list(Place)|tuple(Place)|list(str)|optional): 将数据放进的一个 list 的 place。 :attr:`places` 能为 None. | |||
如果 :attr:`places` 为 None, 默认放在 CPUPlace 或者 CUDAPlace(0) 设备上。 如果 ``places`` 是一个 list 类型的 字符串, 那么字符串 | |||
可以是 ``cpu`` , ``gpu:x`` 或者 ``gpu_pinned`` , 其中 ``x`` 是 gpu 的下标。 | |||
:param feed_list: feed Tensor list. | |||
这个张量能被 ``paddle.static.data`` 创建。 如果 :attr:`return_list` 是 ``False``, 那么 :attr:`feed_list` | |||
应该被设置。 默认为 ``None `` 。 | |||
:param places: 将数据放进的一个 list 的 place。 :attr:`places` 能为 None. | |||
如果 :attr:`places` 为 None, 默认放在 CPUPlace 或者 CUDAPlace(0) 设备上。 如果 ``places`` 是一个 list 类型的 字符串, 那么字符串 | |||
可以是 ``cpu`` , ``gpu:x`` 或者 ``gpu_pinned`` , 其中 ``x`` 是 gpu 的下标。 | |||
:param return_list: 每个设备上的返回值是否为以列表形式显示。 如果 :attr:`return_list=False`, | |||
每个设备上的返回值值为 str -> Tensor 的 dict, 其中 dict 的 key 为每个 fed Tensors 的名字。 | |||
如果 :attr:`return_list=True`, 每个设备上的返回值值为 list(Tensor)。 :attr:`return_list` 只能在动态图情况下设置为 ``True`` . | |||
默认值为 ``True`` 。 | |||
每个设备上的返回值值为 str -> Tensor 的 dict, 其中 dict 的 key 为每个 fed Tensors 的名字。 | |||
如果 :attr:`return_list=True`, 每个设备上的返回值值为 list(Tensor)。 :attr:`return_list` 只能在动态图情况下设置为 ``True`` . | |||
默认值为 ``True`` 。 | |||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||
dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, shuffle 参数均失效。 | |||
dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``shuffle`` 参数均失效。 | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 'train' 则设置其 shuffle 为 True , | |||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , | |||
其它的为 False 。 | |||
:param drop_last: 当 ``drop_last=True`` 时,``PaddleDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||
``PaddleDataLoader`` 调用默认的 Paddle 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||
``PaddleDataLoader`` 调用默认的 Paddle 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||
* callate_fn 为 ``'auto'`` 时,``PaddleDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||
此时可以配套使用 ``PaddleDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* `collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``PaddleDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
:param use_buffer_reader: 是否开启 buffer_reader 。如果 `use_buffer_reader=True`` ,那么 ``PaddleDataLoader` `会异步的预取下一个 batch 的 | |||
数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是 ``True``。 | |||
此时可以配套使用 ``PaddleDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``PaddleDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 | |||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
:param use_buffer_reader: 是否开启 buffer_reader 。如果 ``use_buffer_reader=True`` ,那么 ``PaddleDataLoader`` 会异步地预取下一个 batch 的 | |||
数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是 ``True``。 | |||
:param use_shared_memory: 是否使用共享内存。当 ``use_shared_memory=True`` 时,将采用共享内存来加快将数据放进进程队列。建议仅当计算机上的 | |||
共享空间足够大时。(例如 Linux 上的 /dev/shm/ 空间足够大)共享内存仅在多进程模式( num_workers>0 )下生效。 | |||
共享空间足够大时。(例如 Linux 上的 /dev/shm/ 空间足够大)共享内存仅在多进程模式( ``num_workers>0`` )下生效。 | |||
:param timeout: 从子进程的输出队列获取数据的超时值 | |||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 | |||
:param persistent_workers: 如果其为 ``True``, ``PaddleDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||
@@ -20,33 +20,33 @@ def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = None, drop | |||
""" | |||
自动创建合适的 ``DataLoader`` 对象。例如,检测当当前环境是 ``torch`` 的,则返回 ``TorchDataLoader`` , 是 ``paddle`` 的则 | |||
返回 ``PaddleDataLoader`` 。如果有更多需要定制的参数,请直接使用对应的 ``prepare`` 函数,例如 | |||
:func:`~fastNLP.prepare_torch_dataloader` 或 :func:`~fastNLP.prepare_paddle_dataloader` 等。 | |||
:func:`~fastNLP.core.dataloaders.prepare_torch_dataloader` 或 :func:`~fastNLP.core.dataloaders.prepare_paddle_dataloader` 等。 | |||
:param dataset: 实现 __getitem__() 和 __len__() 的对象;或这种对象的序列;或字典。 | |||
* 为单个数据集对象时,返回一个 DataLoader 。 | |||
* 为数据集对象序列时,返回一个序列的 DataLoader 。 | |||
* 为字典型 或 :class:`~fastNLP.io.DataBundle` 数据时,返回 `Dict` 类型的数据。 | |||
* 为字典型 或 :class:`~fastNLP.io.DataBundle` 数据时,返回 :class:`Dict` 类型的数据。 | |||
:param batch_size: 批次大小。 | |||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 'train' 则设置其 shuffle 为 True , | |||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , | |||
其它的为 False 。 | |||
:param drop_last: 当最后一个 batch 不足 batch_size 数量的是否,是否丢弃。 | |||
:param drop_last: 当最后一个 batch 不足 ``batch_size`` 数量的是否,是否丢弃。 | |||
:param collate_fn: 用于处理一个 batch 的函数,一般包括 padding 和转为 tensor。有以下三种取值: | |||
* 为 ``auto`` 时,使用 :class:`~fastNLP.Collator` 进行 padding 和 转tensor 。 | |||
* 为 ``Callable`` 时,应当接受一个 ``batch`` 的数据作为参数,同时输出一个对象 。 | |||
* 为 :class:`Callable` 时,应当接受一个 ``batch`` 的数据作为参数,同时输出一个对象 。 | |||
* 为 ``None`` 时,使用各个框架的 DataLoader 的默认 ``collate_fn`` 。 | |||
:param num_workers: 使用多少进程进行数据的 fetch 。 | |||
:param backend: 当前支持 ``["auto", "torch", "paddle", "jittor", "oneflow"]`` 四种类型。 | |||
* 为 ``auto`` 时,首先(1) 根据环境变量 "FASTNLP_BACKEND" 进行判断;如果没有设置则,(2)通过当前 | |||
* 为 ``auto`` 时,首先根据环境变量 ``"FASTNLP_BACKEND"`` 进行判断;如果没有设置则通过当前 | |||
``sys.modules`` 中已经 import 的 ``backend`` 进行判定。如果以上均无法判定,则报错。如果找到了 | |||
``backend`` ,则按照下述的方式处理。 | |||
* 为 ``torch`` 时,使用 :func:`~fastNLP.prepare_torch_dataloader` 。 | |||
* 为 ``paddle`` 时,使用 :func:`~fastNLP.prepare_paddle_dataloader` 。 | |||
* 为 ``jittor`` 时,使用 :func:`~fastNLP.prepare_jittor_dataloader` 。 | |||
* 为 ``oneflow`` 时,使用 :func:`~fastNLP.prepare_oneflow_dataloader` 。 | |||
* 为 ``torch`` 时,使用 :func:`~fastNLP.core.dataloaders.prepare_torch_dataloader` 。 | |||
* 为 ``paddle`` 时,使用 :func:`~fastNLP.core.dataloaders.prepare_paddle_dataloader` 。 | |||
* 为 ``jittor`` 时,使用 :func:`~fastNLP.core.dataloaders.prepare_jittor_dataloader` 。 | |||
* 为 ``oneflow`` 时,使用 :func:`~fastNLP.core.dataloaders.prepare_oneflow_dataloader` 。 | |||
:return | |||
""" | |||
@@ -58,9 +58,41 @@ class TorchDataLoader(DataLoader): | |||
* callate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。 | |||
此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* callate_fn 为 ``None`` 时, ``TorchDataLoadr`` 默认使用 torch DataLoader 自带的 collate_fn | |||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||
:param non_train_batch_size: 非训练数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 ``None`` 有效。 | |||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , | |||
其它的为 False 。 | |||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||
默认为 ``None``, 当其不为 ``None`` 时, shuffle 参数无效。 | |||
:param non_train_sampler: 非训练数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||
默认为None, 当其不为 None 时, shuffle 参数无效。 | |||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||
dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``sampler``, ``shuffle`` 参数均失效。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 | |||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||
``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||
* callate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||
此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
:param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。 | |||
:param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
:param timeout: 子进程的输出队列获取数据的超时值 | |||
:param worker_init_fn: init 函数,如果不设置为 ``None``,则将会在每个子进程初始化时调用该函数。 | |||
:param multiprocessing_context: 多进程的上下文环境 | |||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` | |||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2`` 意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . | |||
:param persistent_workers: 如果其为 ``True``, ``TorchDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||
""" | |||
def __init__(self, dataset, batch_size: int = 16, | |||
@@ -70,38 +102,7 @@ class TorchDataLoader(DataLoader): | |||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | |||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | |||
persistent_workers: bool = False, **kwargs) -> None: | |||
""" | |||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||
:param shuffle: 是否打乱数据集, 默认为 ``False``。 | |||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||
默认为None, 当其不为 None 时, shuffle 参数无效。 | |||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||
dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, sampler, shuffle 参数均失效。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||
``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 ``default_collate_fn`` 作为 callate_fn 的默认值, 其无法处理 | |||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||
* callate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||
此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
:param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。 | |||
:param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
:param timeout: 子进程的输出队列获取数据的超时值 | |||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 | |||
:param multiprocessing_context: 多进程的上下文环境 | |||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` | |||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2``意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . | |||
:param persistent_workers: 如果其为 ``True``, ``TorchDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||
""" | |||
if isinstance(dataset, DataSet) and collate_fn is None: | |||
raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | |||
@@ -153,20 +154,20 @@ class TorchDataLoader(DataLoader): | |||
""" | |||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 | |||
无意义。 | |||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto'],分别代表,输出为 list, numpy.ndarray, | |||
torch.Tensor, paddle.Tensor, jittor.Var, oneflow.Tensor 类型。若 pad_val 为 None ,该值无意义 。 | |||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||
形式,输出将被直接作为结果输出。 | |||
:return: 返回 Collator | |||
:param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||
如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。 | |||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, | |||
该值无意义。 | |||
:param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 | |||
:param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, | |||
:class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 | |||
若 ``pad_val`` 为 ``None`` ,该值无意义 。 | |||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 | |||
batch 形式。 collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 | |||
:return: 返回使用的 collator | |||
""" | |||
collator = self._get_collator() | |||
if isinstance(collator, Collator): | |||
@@ -190,15 +191,14 @@ class TorchDataLoader(DataLoader): | |||
def set_ignore(self, *field_names) -> Collator: | |||
""" | |||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||
Example:: | |||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略:: | |||
collator.set_ignore('field1', 'field2') | |||
dataloader.set_ignore('field1', 'field2') | |||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 | |||
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||
:return: 返回 Collator 自身 | |||
:param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||
:return: 返回使用的 collator | |||
""" | |||
collator = self._get_collator() | |||
if isinstance(collator, Collator): | |||
@@ -230,56 +230,56 @@ def prepare_torch_dataloader(ds_or_db, | |||
non_train_batch_size: int = None) \ | |||
-> Union[TorchDataLoader, Dict[str, TorchDataLoader]]: | |||
""" | |||
``prepare_torch_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``TorchDataloader``对象, 详见 :class:`~fastNLP.TorchDataLoader`。 | |||
``prepare_torch_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``TorchDataloader`` 对象, 详见 :class:`~fastNLP.TorchDataLoader`。 | |||
根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||
* 当 ds_or_db 为 ``DataSet``时,``prepare_torch_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | |||
帮你实例化一个 ``TorchDataLoader`` 对象并返回该对象。 详见:class:`~fastNLP.core.dataloaders.TorchDataLoader`。 | |||
* 当 ds_or_db 为 ``DataSet`` 时,``prepare_torch_dataloader`` 会将使用的除了 ``non_train_batch_size`` 和 ``non_train_sampler`` 以外的参数来 | |||
帮你实例化一个 ``TorchDataLoader`` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.TorchDataLoader`。 | |||
* 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_torch_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value | |||
来创建不同的 ``TorchDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_torch_dataloader`` 默认该 value 为 train 数据集, | |||
会将 batch_size 和 sampler 作为参数,其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。 | |||
最终根据 ``key: TorchDataLoader`` 组成 ``Dict[key, TorchDataLoader]`` 的字典返回。 | |||
来创建不同的 ``TorchDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_torch_dataloader`` 默认该 value 为训练数据集, | |||
会将 ``batch_size`` 和 ``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。 | |||
最终根据 ``key: TorchDataLoader`` 组成 ``Dict[key, TorchDataLoader]`` 的字典返回。 | |||
* 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_torch_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的 | |||
``TorchDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_torch_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | |||
其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: TorchDataLoader`` 组成 | |||
``Dict[key, TorchDataLoader]`` 的字典返回。 | |||
``TorchDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_torch_dataloader`` 默认该 value 为训练数据集,会将 ``batch_size`` 和 ``sampler`` 作为参数, | |||
其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。最终根据 ``key: TorchDataLoader`` 组成 | |||
``Dict[key, TorchDataLoader]`` 的字典返回。 | |||
:param ds_or_db: 可以有以下三种取值, | |||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 | |||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 | |||
* ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为:class:`~fastNLP.TorchDataLoader` | |||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; | |||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; | |||
* ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为 :class:`~fastNLP.TorchDataLoader`; | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||
:param non_train_batch_size: 非 'train' 数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 'train' 则设置其 shuffle 为 True , | |||
:param non_train_batch_size: 非训练数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 ``None`` 有效。 | |||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , | |||
其它的为 False 。 | |||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||
默认为None, 当其不为 None 时, shuffle 参数无效。 | |||
:param non_train_sampler: 非 'train' 数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||
默认为 ``None``, 当其不为 ``None`` 时, shuffle 参数无效。 | |||
:param non_train_sampler: 非训练数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||
默认为None, 当其不为 None 时, shuffle 参数无效。 | |||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||
dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, sampler, shuffle 参数均失效。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||
dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``sampler``, ``shuffle`` 参数均失效。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 | |||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||
* callate_fn 为 'None' 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||
``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||
* callate_fn 为 ``'auto'`` 时,`TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||
此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* `collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||
``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||
* callate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||
此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
:param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。 | |||
:param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
:param timeout: 子进程的输出队列获取数据的超时值 | |||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 | |||
:param worker_init_fn: init 函数,如果不设置为 ``None``,则将会在每个子进程初始化时调用该函数。 | |||
:param multiprocessing_context: 多进程的上下文环境 | |||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个``base_seed`` | |||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2``意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . | |||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` | |||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2`` 意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . | |||
:param persistent_workers: 如果其为 ``True``, ``TorchDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||
""" | |||
@@ -101,6 +101,19 @@ class MixDataLoader(DataLoader): | |||
""" | |||
针对以下四种情况提供的 ``MixDataLoader``, 目前只支持 ``torch`` 框架的版本, 其中 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: | |||
* 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 ``datasets`` 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 | |||
接一个的 sample 完所有数据。 | |||
* 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 ``datasets`` 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机 sample | |||
混合数据集 datasets 的数据组成一个 batch 序列返回。 | |||
* 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 ``datasets`` 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回, | |||
再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询的采样完。 | |||
* 当 mode 为 ``"Sampler"`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int]; | |||
且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0<idx<len(datasets[0])+...+len(datasets[x])``, 然后参数 | |||
``sampler``, ``drop_last``, ``ds_ratio`` 均无效。 | |||
:param datasets: 实现了 __getitem__() 和 __len__() 对象的序列或者字典。 | |||
:param mode: mode 控制 ``MixDataLoader`` 运行模式。 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: | |||
* 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 | |||
接一个的 sample 完所有数据。 | |||
* 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机sample | |||
@@ -111,6 +124,40 @@ class MixDataLoader(DataLoader): | |||
且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0<idx<len(datasets[0])+...+len(datasets[x])``, 然后参数 | |||
sampler, drop_last, ds_ratio 均无效。 | |||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数。 其取值可以为 ``['auto', Callable, List[Callable], Dict[str, Callable]]``: | |||
* collate_fn 为 ``'auto'`` 时, ``MixDataLoader`` datasets 序列或者dict 初始化一个 :class:`~fastNLP.core.collators.Collator` 作为其默认值, | |||
需要注意的是只有当 datasets 包含的所以 dataset 的数据都为 ``List`` 或者 ``Dict`` 类型时才能使用。否则只能用户自己定义 collate_fn . | |||
* collate_fn 为 :class:`Callable` 时, 该 collate_fn 会被 datasets 序列或者dict 的所有数据所共享。该 Callable 函数应当接受一个 batch 参数作为输入, | |||
batch 是一个 List 对象且 List 中的每一条数据都是 dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
* collate_fn 为 ``Dict[str, Callable]`` 时, datasets 的 key 必须和 callable_fn 的 key 一致。 ``MixDataLoader`` 会将 ``collate_fn[key]`` | |||
用到 ``datasets[key]`` 的数据集上。 ``collate_fn[key]`` 是一个 Callable 对象。 | |||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,其取值范围为 | |||
``[None, str, Dict[str, "Sampler"]]``: | |||
* sampler 为 ``None`` 时, ``MixDataLoader`` 默认初始化 ``torch`` 的 ``SequentialSampler`` 作为默认值。其功能时顺序返回 dataset 的下标。 | |||
* sampler 为 ``str`` 时, sampler 选择范围为 ``[rand, seq]``。当 sampler 为 ``rand`` 时,``MixDataLoader`` 默认初始化 ``torch`` 的 ``RandomSampler`` | |||
作为默认值, 其功能时随机采样 dataset 的下标并返回。 当 sampler 为 ``seq`` 时, ``MixDataLoader`` 默认初始化 ``torch`` 的 ``SequentialSampler`` 作为默认值。其功能时顺序返回 dataset 的下标。 | |||
* sampler 为 ``Dict[str, "Sampler"]`` 时, ``Sampler`` 为用户定义的实现了 __len__() 和 __iter__() 的实例化对象。 其每次 iter 必须返回一个 int 下标。 | |||
Dict 的 str 必须和 datasets 的 key 一致。 也即是 ``Dict[str, Sampler]`` 为 datasets 字典的每个 dataset 初始化了一个 Sampler。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``MixDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快数据处理速度,但同时 | |||
也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 ``None`` 有效。 且 datasets 上所有 dataset 的 batch_size 一致。 | |||
:param drop_last: 当 ``drop_last=True`` 时,``MixDataLoader`` 会扔掉 datasets 中 每个 dataset 最后一个长度小于 ``batch_size`` 的 batch 数据; | |||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
:param ds_ratio: ``ds_ratio`` 是控制 datasets 怎么组成一个混合大数据集的重要参数, 其取值为 ``[None, 'truncate_to_least', 'pad_to_most', List[float], Dict[str, float]]``: | |||
* ds_ratio 为 ``None``, datasets 数据集序列或字典不进行数据扩充处理。 | |||
* ds_ratio 为 ``'truncate_to_least'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最断长度 ``mix_len``, 其他数据集会被切断 | |||
到最短长度 ``mix_len``。这种切断不是物理上切断,``MixDataLoader`` 会根据 sampler 不同来采样数据集到指定的最短长度 ``mix_len``。 | |||
* ds_ratio 为 ``'pad_to_most'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最大长度 ``max_len``, 其他其他数据集会扩充 | |||
到最大长度 ``mix_len``。这种扩充不是物理上扩充, ``MixDataLoader`` 会根据 sampler 不同来重采样 dataset 到指定的最大长度 ``max_len``。 | |||
* ds_ratio 为 ``Dict[str, float]`` 时, datasets 类型也必须为 ``Dict[str, DataSet]``, 其 key 一一对应。 ds_ratio 的 value 是任意大于 0 的浮点数, | |||
代表着 datasets 的 value 数据进行扩充或者缩减的倍数。 | |||
""" | |||
def __init__(self, datasets: Dict = None, mode: str = 'sequential', | |||
@@ -119,55 +166,6 @@ class MixDataLoader(DataLoader): | |||
num_workers: int = 0, batch_size: int = 16, drop_last=False, | |||
ds_ratio: Union[None, str, Dict[str, float]] = None, | |||
pin_memory: bool = False) -> None: | |||
""" | |||
:param datasets: 实现了 __getitem__() 和 __len__() 对象的序列或者字典。 | |||
:param mode: mode 控制 ``MixDataLoader`` 运行模式。 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: | |||
* 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 | |||
接一个的 sample 完所有数据。 | |||
* 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机sample | |||
混合数据集 datasets 的数据组成一个 batch 序列返回。 | |||
* 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 datasets 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回, | |||
再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询的采样完。 | |||
* 当 mode 为 ``"Sampler"`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int]; | |||
且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0<idx<len(datasets[0])+...+len(datasets[x])``, 然后参数 | |||
sampler, drop_last, ds_ratio 均无效。 | |||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数。 其取值可以为 ``['auto', Callable, List[Callable], Dict[str, Callable]]``: | |||
* collate_fn 为 ``'auto'`` 时, ``MixDataLoader`` datasets 序列或者dict 初始化一个 :class:`~fastNLP.core.collators.Collator` 作为其默认值, | |||
需要注意的是只有当 datasets 包含的所以 dataset 的数据都为 ``List`` 或者 ``Dict`` 类型时才能使用。否则只能用户自己定义 collate_fn . | |||
* collate_fn 为 ``Callable`` 时, 该 collate_fn 会被 datasets 序列或者dict 的所有数据所共享。该 Callable 函数应当接受一个 batch 参数作为输入, | |||
batch 是一个 List 对象且 List 中的每一条数据都是 dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
* collate_fn 为 ``Dict[str, Callable]`` 时, datasets 的 key 必须和 callable_fn 的 key 一致。 ``MixDataLoader`` 会将 ``collate_fn[key]`` | |||
用到 ``datasets[key]`` 的数据集上。 ``collate_fn[key]`` 是一个 Callable 对象。 | |||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,其取值范围为 | |||
``[None, str, Dict[str, "Sampler"]]``: | |||
* sampler 为 ``None`` 时, ``MixDataLoader`` 默认初始化 ``torch`` 的 ``SequentialSampler`` 作为默认值。其功能时顺序返回 dataset 的下标。 | |||
* sampler 为 ``str`` 时, sampler 选择范围为 ``[rand, seq]``。当 sampler 为 ``rand`` 时,``MixDataLoader`` 默认初始化 ``torch`` 的 ``RandomSampler`` | |||
作为默认值, 其功能时随机采样 dataset 的下标并返回。 当 sampler 为 ``seq`` 时, ``MixDataLoader`` 默认初始化 ``torch`` 的 ``SequentialSampler`` 作为默认值。其功能时顺序返回 dataset 的下标。 | |||
* sampler 为 ``Dict[str, "Sampler"]`` 时, ``Sampler`` 为用户定义的实现了 __len__() 和 __iter__() 的实例化对象。 其每次 iter 必须返回一个 int 下标。 | |||
Dict 的 str 必须和 datasets 的 key 一致。 也即是 ``Dict[str, Sampler]`` 为 datasets 字典的每个 dataset 初始化勒一个 Sampler。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``MixDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快数据处理速度,但同时 | |||
也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 且 datasets 上所有 dataset 的 batch_size 一致。 | |||
:param drop_last: 当 ``drop_last=True`` 时,``MixDataLoader`` 会扔掉 datasets 中 每个 dataset 最后一个长度小于 ``batch_size`` 的 batch 数据; | |||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
:param ds_ratio: ``ds_ratio`` 是控制 datasets 怎么组成一个混合大数据集的重要参数, 其取值为 ``[None, 'truncate_to_least', 'pad_to_most', List[float], Dict[str, float]]``: | |||
* ds_ratio 为 ``None``, datasets 数据集序列或字典不进行数据扩充处理。 | |||
* ds_ratio 为 ``'truncate_to_least'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最断长度 ``mix_len``, 其他数据集会被切断 | |||
到最短长度 ``mix_len``。这种切断不是物理上切断,``MixDataLoader`` 会根据 sampler 不同来采样数据集到指定的最短长度 ``mix_len``。 | |||
* ds_ratio 为 ``'pad_to_most'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最大长度 ``max_len``, 其他其他数据集会扩充 | |||
到最大长度 ``mix_len``。这种扩充不是物理上扩充, ``MixDataLoader`` 会根据 sampler 不同来重采样 dataset 到指定的最大长度``max_len``。 | |||
* ds_ratio 为 ``Dict[str, float]`` 时, datasets 类型也必须为 ``Dict[str, DataSet]``, 其 key 一一对应。 ds_ratio 的 value 是任意大于 0 的浮点数, | |||
代表着 datasets 的 value 数据进行扩充或者缩减的倍数。 | |||
""" | |||
# sampler 为 dict,则判断是否与 datasets 的 key 相同 | |||
if isinstance(sampler, Dict): | |||
for key in datasets.keys(): | |||
@@ -14,7 +14,7 @@ __all__ = [ | |||
def indice_collate_wrapper(func:Callable): | |||
""" | |||
其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 | |||
其功能是封装一层 collate_fn,将 dataset 取到的 tuple 数据分离开,将 idx 打包为 indices。 | |||
:param func: 需要修饰的函数 | |||
:return: | |||
@@ -115,7 +115,7 @@ class HasLenGetitemType(ABC): | |||
class OverfitDataLoader: | |||
""" | |||
实现一个简单的迭代器来模拟实际的 dataloader,从给定的 dataloader 中取出部分数据,来让 Trainer 实现 overfit 的功能; | |||
实现一个简单的迭代器来模拟实际的 dataloader,从给定的 ``dataloader`` 中取出部分数据,来让 Trainer 实现 overfit 的功能; | |||
""" | |||
def __init__(self, dataloader, overfit_batches: int, batches=None): | |||
@@ -1,7 +1,7 @@ | |||
r""" | |||
:class:`~fastNLP.core.dataset.DataSet` 是 fastNLP 中用于承载数据的容器。可以将 DataSet 看做是一个表格, | |||
每一行是一个 sample (在 fastNLP 中被称为 :mod:`~fastNLP.core.instance` ), | |||
每一列是一个 feature (在 fastNLP 中称为 :mod:`~fastNLP.core.field` )。 | |||
每一行是一个 sample (在 fastNLP 中被称为 :mod:`~fastNLP.core.dataset.instance` ), | |||
每一列是一个 feature (在 fastNLP 中称为 :mod:`~fastNLP.core.dataset.field` )。 | |||
.. csv-table:: Following is a demo layout of DataSet | |||
:header: "sentence", "words", "seq_len" | |||
@@ -11,7 +11,7 @@ r""" | |||
"Third instance .", "[Third, instance, .]", 3 | |||
"...", "[...]", "..." | |||
在 fastNLP 内部每一行是一个 :class:`~fastNLP.Instance` 对象; 每一列是一个 :class:`~fastNLP.FieldArray` 对象。 | |||
在 fastNLP 内部每一行是一个 :class:`~fastNLP.core.dataset.Instance` 对象; 每一列是一个 :class:`~fastNLP.core.dataset.FieldArray` 对象。 | |||
---------------------------- | |||
1.DataSet的创建 | |||
@@ -65,7 +65,7 @@ r""" | |||
2.DataSet 与预处理 | |||
-------------------------------------- | |||
常见的预处理有如下几种 | |||
常见的预处理有如下几种: | |||
2.1 从某个文本文件读取内容 | |||
-------------------------------------- | |||
@@ -97,10 +97,10 @@ r""" | |||
# 将句子分成单词形式, 详见DataSet.apply()方法, 可以开启多进程来加快处理, 也可以更改展示的bar,目前支持 ``['rich', 'tqdm', None]``, | |||
# 详细内容可以见 :class:`~fastNLP.core.dataset.DataSet`, 需要注意的时匿名函数不支持多进程 | |||
dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words', | |||
progress_des='Main',progress_bar='rich') | |||
progress_des='Main',progress_bar='rich') | |||
# 或使用DataSet.apply_field() | |||
dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words', | |||
progress_des='Main',progress_bar='rich') | |||
progress_des='Main',progress_bar='rich') | |||
# 除了匿名函数,也可以定义函数传递进去 | |||
def get_words(instance): | |||
sentence = instance['sentence'] | |||
@@ -145,8 +145,8 @@ r""" | |||
# DataSet 的长度 | |||
len(dataset) | |||
""" | |||
__all__ = [ | |||
"DataSet", | |||
"ApplyResultException" | |||
@@ -255,34 +255,31 @@ def _multi_proc(ds, _apply_field, func, counter, queue): | |||
class DataSet: | |||
r""" | |||
fastNLP的数据容器,详细的使用方法见文档 :mod:`fastNLP.core.dataset` | |||
""" | |||
fastNLP的数据容器。 | |||
def __init__(self, data: Union[List[Instance], Dict[str, List[Any]], None] = None): | |||
r""" | |||
初始化 ``DataSet``, fastNLP的 DataSet 是 key-value 存储形式, 目前支持两种初始化方式,输入 data 分别为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 和 | |||
``Dict[str, List[Any]]``。 | |||
* 当 data 为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。 | |||
Instance 详见 :class:`~fastNLP.core.dataset.Instance` 。 | |||
* 当 data 为 ``Dict[str, List[Any]] 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。 | |||
Example:: | |||
:param data: 初始化的内容, 其只能为两种类型,分别为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 和 | |||
``Dict[str, List[Any]]``。 | |||
from fastNLP.core.dataset import DataSet, Instance | |||
data = {'x': [[1, 0, 1], [0, 1, 1], 'y': [0, 1]} | |||
data1 = [Instance(x=[1,0,1],y=0), Instance(x=[0,1,1],y=1)] | |||
ds = DataSet(data) | |||
ds = DataSet(data1) | |||
* 当 data 为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。 | |||
Instance 详见 :class:`~fastNLP.core.dataset.Instance` 。 | |||
* 当 data 为 ``Dict[str, List[Any]] 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。 | |||
fastNLP的 DataSet 是 key-value 存储形式, 目前支持两种初始化方式,输入 data 分别为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 和 | |||
``Dict[str, List[Any]]``。 | |||
Example:: | |||
* 当 data 为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。 | |||
Instance 详见 :class:`~fastNLP.core.dataset.Instance` 。 | |||
* 当 data 为 ``Dict[str, List[Any]]`` 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。 | |||
from fastNLP.core.dataset import DataSet, Instance | |||
data = {'x': [[1, 0, 1], [0, 1, 1], 'y': [0, 1]} | |||
data1 = [Instance(x=[1,0,1],y=0), Instance(x=[0,1,1],y=1)] | |||
ds = DataSet(data) | |||
ds = DataSet(data1) | |||
:param data: 初始化的内容,其只能为两种类型,分别为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 和 | |||
``Dict[str, List[Any]]``。 | |||
""" | |||
* 当 data 为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。 | |||
Instance 详见 :class:`~fastNLP.core.dataset.Instance` 。 | |||
* 当 data 为 ``Dict[str, List[Any]] 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。 | |||
""" | |||
def __init__(self, data: Union[List[Instance], Dict[str, List[Any]], None] = None): | |||
self.field_arrays = {} | |||
self._collator = Collator() | |||
if data is not None: | |||
@@ -429,10 +426,9 @@ class DataSet: | |||
def append(self, instance: Instance) -> None: | |||
r""" | |||
将一个 instance 对象 append 到 DataSet 后面。详见 :class:`~fastNLP.Instance` | |||
:param instance: 若 DataSet 不为空,则 instance 应该拥有和 DataSet 完全一样的 field。 | |||
将一个 ``instance`` 对象 append 到 DataSet 后面。详见 :class:`~fastNLP.core.dataset.Instance` | |||
:param instance: 若 DataSet 不为空,则 instance 应该拥有和 DataSet 完全一样的 field; | |||
""" | |||
if len(self.field_arrays) == 0: | |||
# DataSet has no field yet | |||
@@ -454,10 +450,10 @@ class DataSet: | |||
def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None: | |||
r""" | |||
将 fieldarray 添加到 DataSet 中. | |||
将 ``fieldarray`` 添加到 DataSet 中. | |||
:param field_name: 新加入的 field 的名称 | |||
:param fieldarray: 需要加入 DataSet 的 field 的内容, 详见 :class:`~fastNLP.core.dataset.FieldArray` | |||
:param field_name: 新加入的 field 的名称; | |||
:param fieldarray: 需要加入 DataSet 的 field 的内容, 详见 :class:`~fastNLP.core.dataset.FieldArray` ; | |||
:return: | |||
""" | |||
if not isinstance(fieldarray, FieldArray): | |||
@@ -472,8 +468,8 @@ class DataSet: | |||
r""" | |||
新增一个 field, 需要注意的是 fields 的长度跟 DataSet 长度一致 | |||
:param field_name: 新增的 field 的名称 | |||
:param fields: 需要新增的 field 的内容 | |||
:param field_name: 新增的 field 的名称; | |||
:param fields: 需要新增的 field 的内容; | |||
""" | |||
if len(self.field_arrays) != 0: | |||
@@ -484,9 +480,9 @@ class DataSet: | |||
def delete_instance(self, index: int): | |||
r""" | |||
删除第 ``index `` 个 Instance | |||
删除第 ``index`` 个 Instance | |||
:param index: 需要删除的 instanc e的 index,序号从 `0` 开始。 | |||
:param index: 需要删除的 instance 的 index,序号从 `0` 开始。 | |||
""" | |||
assert isinstance(index, int), "Only integer supported." | |||
if len(self) <= index: | |||
@@ -500,9 +496,9 @@ class DataSet: | |||
def delete_field(self, field_name: str): | |||
r""" | |||
删除名为 field_name 的 field | |||
删除名为 ``field_name`` 的 field | |||
:param field_name: 需要删除的 field 的名称. | |||
:param field_name: 需要删除的 field 的名称; | |||
""" | |||
if self.has_field(field_name): | |||
self.field_arrays.pop(field_name) | |||
@@ -512,11 +508,11 @@ class DataSet: | |||
def copy_field(self, field_name: str, new_field_name: str): | |||
r""" | |||
深度 copy 名为 field_name 的 field 到 new_field_name | |||
深度 copy 名为 ``field_name`` 的 field 到 ``new_field_name`` | |||
:param field_name: 需要 copy 的 field。 | |||
:param new_field_name: copy 生成的 field 名称 | |||
:return: self | |||
:param field_name: 需要 copy 的 field; | |||
:param new_field_name: copy 生成的 field 名称; | |||
:return: 数据集自身; | |||
""" | |||
if not self.has_field(field_name): | |||
raise KeyError(f"Field:{field_name} not found in DataSet.") | |||
@@ -527,10 +523,10 @@ class DataSet: | |||
def has_field(self, field_name: str) -> bool: | |||
r""" | |||
判断 DataSet 中是否有名为 field_name 这个 field | |||
判断 DataSet 中是否有名为 ``field_name`` 这个 field | |||
:param field_name: field 的名称 | |||
:return: 表示是否有名为 field_name 这个 field | |||
:param field_name: field 的名称; | |||
:return: 表示是否有名为 ``field_name`` 这个 field; | |||
""" | |||
if isinstance(field_name, str): | |||
return field_name in self.field_arrays | |||
@@ -538,10 +534,10 @@ class DataSet: | |||
def get_field(self, field_name: str) -> FieldArray: | |||
r""" | |||
获取 field_name 这个 field | |||
获取名为 ``field_name`` 的 field | |||
:param field_name: field 的名称 | |||
:return: :class:`~fastNLP.FieldArray` | |||
:param field_name: field 的名称; | |||
:return: 一个 :class:`~fastNLP.core.dataset.FieldArray` 对象; | |||
""" | |||
if field_name not in self.field_arrays: | |||
raise KeyError("Field name {} not found in DataSet".format(field_name)) | |||
@@ -549,17 +545,13 @@ class DataSet: | |||
def get_all_fields(self) -> dict: | |||
r""" | |||
返回一个 dict,key 为 field_name, value为对应的 :class:`~fastNLP.FieldArray` | |||
:return: 返回如上所述的字典 | |||
:return: 一个 dict,key 为 field_name, value为对应的 :class:`~fastNLP.core.dataset.FieldArray` 对象。 | |||
""" | |||
return self.field_arrays | |||
def get_field_names(self) -> list: | |||
r""" | |||
返回一个 list,包含所有 field 的名字 | |||
:return: 返回如上所述的列表 | |||
:return: 一个 list,包含所有 field 的名字 | |||
""" | |||
return sorted(self.field_arrays.keys()) | |||
@@ -575,8 +567,8 @@ class DataSet: | |||
r""" | |||
将某个 field 重新命名. | |||
:param field_name: 原来的 field 名称。 | |||
:param new_field_name: 修改为 new_name。 | |||
:param field_name: 原来的 field 名称; | |||
:param new_field_name: 修改为 new_name; | |||
""" | |||
if field_name in self.field_arrays: | |||
self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) | |||
@@ -589,13 +581,13 @@ class DataSet: | |||
new_field_name: str = None, num_proc: int = 0, | |||
progress_desc: str = None, progress_bar: str = 'rich'): | |||
r""" | |||
将 :class:`DataSet` 每个 ``instance`` 中为 ``field_name`` 的 ``field`` 传给函数 ``func``,并写入到 ``new_field_name`` | |||
将 :class:`DataSet` 每个 ``instance`` 中为 ``field_name`` 的 field 传给函数 ``func``,并写入到 ``new_field_name`` | |||
中。 | |||
:param field_name: 传入 ``func`` 的 ``field`` 名称; | |||
:param func: 对指定 ``field`` 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 ``field`` 的内容; | |||
:param func: 对指定 fiel` 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; | |||
:param field_name: 传入 ``func`` 的 field 名称; | |||
:param new_field_name: 函数执行结果写入的 ``field`` 名称。该函数会将 ``func`` 返回的内容放入到 ``new_field_name`` 对 | |||
应的 ``field`` 中,注意如果名称与已有的 ``field`` 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 ``field`` ; | |||
应的 ``field`` 中,注意如果名称与已有的 field 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 field ; | |||
:param num_proc: 使用进程的数量。 | |||
.. note:: | |||
@@ -603,8 +595,8 @@ class DataSet: | |||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | |||
``func`` 函数中的打印将不会输出。 | |||
:param progress_desc: 进度条的描述字符,默认为 ``Processing``; | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param progress_desc: 如果不为 ``None``,则会显示当前正在处理的进度条的名称; | |||
:param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。 | |||
:return: 从函数 ``func`` 中得到的返回值; | |||
""" | |||
assert len(self) != 0, "Null DataSet cannot use apply_field()." | |||
@@ -625,26 +617,27 @@ class DataSet: | |||
modify_fields: bool = True, num_proc: int = 0, | |||
progress_desc: str = None, progress_bar: str = 'rich'): | |||
r""" | |||
将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的field 传给 func,并获取它的返回值。 | |||
func 可以返回一个或多个 field 上的结果。 | |||
将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的 field 传给 ``func``,并获取它的返回值。 | |||
``func`` 可以返回一个或多个 field 上的结果。 | |||
.. note:: | |||
``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||
``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`~fastNLP.core.dataset.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||
``apply`` 区别的介绍。 | |||
:param field_name: 传入func的是哪个field。 | |||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | |||
:param func: 对指定 fiel` 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; | |||
:param field_name: 传入 ``func`` 的 fiel` 名称; | |||
:param new_field_name: 函数执行结果写入的 ``field`` 名称。该函数会将 ``func`` 返回的内容放入到 ``new_field_name`` 对 | |||
应的 ``field`` 中,注意如果名称与已有的 field 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 field ; | |||
:param num_proc: 使用进程的数量。 | |||
.. note:: | |||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | |||
``func`` 函数中的打印将不会输出。 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param progress_desc: 当显示 progress_bar 时,显示当前正在处理的进度条描述字符 | |||
:return Dict[str:Field]: 返回一个字典 | |||
:param progress_desc: 如果不为 ``None``,则会显示当前正在处理的进度条的名称; | |||
:param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。 | |||
:return: 返回一个字典 | |||
""" | |||
assert len(self) != 0, "Null DataSet cannot use apply_field()." | |||
if not self.has_field(field_name=field_name): | |||
@@ -747,7 +740,7 @@ class DataSet: | |||
def apply_more(self, func: Callable = None, modify_fields: bool = True, | |||
num_proc: int = 0, progress_desc: str = '', progress_bar: str = 'rich'): | |||
r""" | |||
将 ``DataSet`` 中每个 ``Instance`` 传入到func中,并获取它的返回值。func可以返回一个或多个 field 上的结果。 | |||
将 ``DataSet`` 中每个 ``Instance`` 传入到 ``func`` 中,并获取它的返回值。``func``可以返回一个或多个 field 上的结果。 | |||
.. note:: | |||
``apply_more`` 与 ``apply`` 的区别: | |||
@@ -767,9 +760,9 @@ class DataSet: | |||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | |||
``func`` 函数中的打印将不会输出。 | |||
:param progress_desc: 当 progress_bar 不为 None 时,可以显示当前正在处理的进度条名称 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:return Dict[str:Field]: 返回一个字典 | |||
:param progress_desc: 当 progress_bar 不为 ``None`` 时,可以显示当前正在处理的进度条名称 | |||
:param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。 | |||
:return: 返回一个字典 | |||
""" | |||
assert callable(func), "The func is not callable." | |||
assert len(self) != 0, "Null DataSet cannot use apply()." | |||
@@ -808,10 +801,11 @@ class DataSet: | |||
def apply(self, func: Callable = None, new_field_name: str = None, | |||
num_proc: int = 0, progress_bar: str = 'rich', progress_desc: str = ''): | |||
""" | |||
将 ``DataSet`` 中每个 ``Instance`` 传入到 ``func`` 中,并获取它的返回值。``func`` 仅能返回一个结果。 | |||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
盖之前的field。如果为None则不创建新的field。 | |||
:param new_field_name: 将 ``func`` 返回的内容放入到 ``new_field_name`` 这个 field中 ,如果名称与已有的 field 相同,则覆 | |||
盖之前的 field。如果为 ``None`` 则不创建新的 field。 | |||
:param num_proc: 使用进程的数量。 | |||
.. note:: | |||
@@ -819,8 +813,8 @@ class DataSet: | |||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | |||
``func`` 函数中的打印将不会输出。 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param progress_desc: progress bar 显示的值,默认为空。 | |||
:param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。 | |||
:param progress_desc: 如果不为 ``None``,则会显示当前正在处理的进度条的名称。 | |||
""" | |||
assert callable(func), "The func you provide is not callable." | |||
assert len(self) != 0, "Null DataSet cannot use apply()." | |||
@@ -838,10 +832,10 @@ class DataSet: | |||
def add_seq_len(self, field_name: str, new_field_name='seq_len'): | |||
r""" | |||
将使用 len() 直接对 field_name 中每个元素作用,将其结果作为 sequence length, 并放入 seq_len 这个 field。 | |||
将使用 :func:`len` 直接对 ``field_name`` 中每个元素作用,将其结果作为 sequence length, 并放入 ``new_field_name`` 这个 field。 | |||
:param field_name: 需要处理的 field_name | |||
:param new_field_name: str. 新的 field_name | |||
:param new_field_name: 新的 field_name | |||
:return: | |||
""" | |||
if self.has_field(field_name=field_name): | |||
@@ -852,10 +846,10 @@ class DataSet: | |||
def drop(self, func: Callable, inplace=True): | |||
r""" | |||
删除某些 Instance。 需要注意的时func 接受一个 Instance ,返回 bool 值。返回值为 True 时, | |||
删除某些 Instance。 需要注意的是 ``func`` 接受一个 Instance ,返回 bool 值。返回值为 ``True`` 时, | |||
该 Instance 会被移除或者不会包含在返回的 DataSet 中。 | |||
:param func: 接受一个 Instance 作为参数,返回 bool 值。为 True 时删除该 instance | |||
:param func: 接受一个 Instance 作为参数,返回 bool 值。为 ``True`` 时删除该 instance | |||
:param inplace: 是否在当前 DataSet 中直接删除 instance;如果为 False,将返回一个新的 DataSet。 | |||
:return: DataSet | |||
@@ -875,11 +869,11 @@ class DataSet: | |||
def split(self, ratio: float, shuffle=True): | |||
r""" | |||
将 DataSet 按照 ratio 的比例拆分,返回两个 DataSet | |||
将 DataSet 按照 ``ratio`` 的比例拆分,返回两个 DataSet | |||
:param ratio: 0<ratio<1, 返回的第一个 DataSet 拥有 `ratio` 这么多数据,第二个 DataSet 拥有 `(1-ratio)` 这么多数据 | |||
:param shuffle: 在 split 前是否 shuffle 一下。为 False,返回的第一个 dataset 就是当前 dataset 中前 `ratio` 比例的数据, | |||
:return: [ :class:`~fastNLP.读取后的DataSet` , :class:`~fastNLP.读取后的DataSet` ] | |||
:param ratio: 0<ratio<1, 返回的第一个 DataSet 拥有 ``ratio`` 比例的数据,第二个 DataSet 拥有 ``1-ratio`` 的数据; | |||
:param shuffle: 在拆分前是否进行排序。为 False,返回的第一个 dataset 就是当前 dataset 中前 ``ratio`` 比例的数据; | |||
:return: 拆分后的两个 DataSet; | |||
""" | |||
assert len(self) > 1, f'DataSet with {len(self)} instance cannot be split.' | |||
assert isinstance(ratio, float) | |||
@@ -906,9 +900,9 @@ class DataSet: | |||
def save(self, path: str) -> None: | |||
r""" | |||
保存DataSet. | |||
保存 DataSet。 | |||
:param path: 将DataSet存在哪个路径 | |||
:param path: 保存路径; | |||
""" | |||
with open(path, 'wb') as f: | |||
pickle.dump(self, f) | |||
@@ -916,10 +910,10 @@ class DataSet: | |||
@staticmethod | |||
def load(path: str): | |||
r""" | |||
从保存的 DataSet pickle文件的路径中读取DataSet | |||
从保存的 DataSet pickle 文件的路径中读取 DataSet | |||
:param path: 从哪里读取 DataSet | |||
:return: 读取后的 :class:`~fastNLP.读取后的DataSet`。 | |||
:param path: 读取路径; | |||
:return: 读取出的 DataSet | |||
""" | |||
with open(path, 'rb') as f: | |||
d = pickle.load(f) | |||
@@ -928,16 +922,16 @@ class DataSet: | |||
def concat(self, dataset: 'DataSet', inplace:bool=True, field_mapping:Dict=None) -> 'DataSet': | |||
""" | |||
将当前 dataset 与输入的 dataset 结合成一个更大的 dataset,需要保证两个 dataset 都包含了相同的 field。结合后的 dataset | |||
的 field_name 和 _collator 以当前 dataset 为准。当 dataset 中包含的 field 多于当前的 dataset,则多余的 field 会被忽略; | |||
若 dataset 中未包含所有当前 dataset 含有 field,则会报错。 | |||
将当前 DataSet 与输入的 ``dataset`` 结合成一个更大的 dataset,需要保证两个 dataset 都包含了相同的 field。结合后的 dataset | |||
的 field_name 和 _collator 以当前 dataset 为准。若 ``dataset`` 中包含的 field 多于当前的 DataSet,则多余的 field 会被忽略; | |||
若 ``dataset`` 中未包含所有当前 DataSet 含有 field,则会报错。 | |||
:param dataset: 需要和当前 dataset concat的 dataset | |||
:param inplace: 是否直接将 dataset 组合到当前 dataset 中 | |||
:param field_mapping: 当传入的 dataset 中的 field 名称和当前 dataset 不一致时,需要通过 field_mapping 把输入的 dataset 中的 | |||
field 名称映射到当前 field. field_mapping 为 dict 类型,key 为 dataset 中的 field 名称,value 是需要映射成的名称 | |||
:param dataset: 需要和当前 DataSet 拼接的 ``dataset``; | |||
:param inplace: 是否直接将 ``dataset`` 组合到当前 DataSet 中; | |||
:param field_mapping: 当传入的 ``dataset`` 中的 field 名称和当前 dataset 不一致时,需要通过 ``field_mapping`` 把输入的 ``dataset`` | |||
中的 field 名称映射到当前 field。``field_mapping`` 为 dict 类型,key 为 11dataset`` 中的 field 名称,value 是需要映射成的名称 | |||
:return: :class:`~fastNLP.core.dataset.DataSet`` | |||
:return: :class:`~fastNLP.core.dataset.DataSet` | |||
""" | |||
assert isinstance(dataset, DataSet), "Can only concat two datasets." | |||
@@ -966,7 +960,8 @@ class DataSet: | |||
@classmethod | |||
def from_pandas(cls, df): | |||
""" | |||
从 ``pandas.DataFrame`` 中读取数据转为 DataSet | |||
从 :class:`pandas.DataFrame` 中读取并数据转化为 DataSet | |||
:param df: 使用 pandas 读取的数据 | |||
:return: | |||
""" | |||
@@ -975,7 +970,7 @@ class DataSet: | |||
def to_pandas(self): | |||
""" | |||
将 DataSet 数据转为 ``pandas.DataFrame`` 类型的数据 | |||
将 DataSet 数据转为 :class:`pandas.DataFrame` 类型的数据 | |||
:return: | |||
""" | |||
@@ -1003,23 +998,22 @@ class DataSet: | |||
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, | |||
pad_fn: Callable = None) -> Collator: | |||
""" | |||
``DataSet`` 中想要对绑定的 collator 进行调整可以调用此函数。 ``collator`` 为 :class:`~fastNLP.core.collators.Collator` | |||
时该函数才有效。调用该函数可以对 field 内容的 pad_val, dtype, backend 等进行调整。 | |||
:param field_name: 需要调整的 field 的名称。如果 DataSet 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 | |||
无意义。 | |||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||
:param backend: 可选['raw', 'numpy', 'torch', 'torch', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, | |||
torch.Tensor, torch.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 | |||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||
形式,输出将被直接作为结果输出。 | |||
:return: 返回 Collator | |||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||
:param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||
如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。 | |||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, | |||
该值无意义。 | |||
:param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 | |||
:param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, | |||
:class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 | |||
若 ``pad_val`` 为 ``None`` ,该值无意义 。 | |||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 | |||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 | |||
:return: 返回自身的 collator; | |||
""" | |||
if isinstance(self.collator, Collator): | |||
self.collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) | |||
@@ -1030,16 +1024,14 @@ class DataSet: | |||
def set_ignore(self, *field_names) -> Collator: | |||
""" | |||
``DataSet`` 中想要对绑定的 collator 进行调整可以调用此函数。 ``collator`` 为 :class:`~fastNLP.core.collators.Collator` | |||
时该函数才有效。调用该函数可以设置忽略输出某些 field 的内容,被设置的 field 将在 batch 的输出中被忽略。 | |||
Example:: | |||
时该函数才有效。调用该函数可以设置忽略输出某些 field 的内容,被设置的 field 将在 batch 的输出中被忽略:: | |||
collator.set_ignore('field1', 'field2') | |||
dataset.set_ignore('field1', 'field2') | |||
:param field_names: 需要忽略的 field 的名称。如果 DataSet 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 | |||
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||
:return: 返回 Collator 自身 | |||
:param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||
:return: 返回自身的 collator; | |||
""" | |||
if isinstance(self.collator, Collator): | |||
self.collator.set_ignore(*field_names) | |||
@@ -14,15 +14,14 @@ import numpy as np | |||
class FieldArray: | |||
""" | |||
:class:`~fastNLP.core.dataset.DatSet` 中用于表示列的数据类型。 | |||
def __init__(self, name: str, content): | |||
""" | |||
初始化 FieldArray | |||
:param name: 字符串的名称 | |||
:param content: 任意类型的数据 | |||
:param name: 字符串的名称 | |||
:param content: 任意类型的数据 | |||
""" | |||
""" | |||
def __init__(self, name: str, content): | |||
if len(content) == 0: | |||
raise RuntimeError("Empty fieldarray is not allowed.") | |||
_content = content | |||
@@ -36,18 +35,15 @@ class FieldArray: | |||
def append(self, val: Any) -> None: | |||
r""" | |||
:param val: 把该 val append 到 fieldarray。 | |||
:return: | |||
:param val: 把该 ``val`` 添加到 fieldarray 中。 | |||
""" | |||
self.content.append(val) | |||
def pop(self, index: int) -> None: | |||
r""" | |||
删除该 field 中 index 处的元素 | |||
删除该 field 中 ``index`` 处的元素 | |||
:param index: 从 ``0`` 开始的数据下标。 | |||
:return: | |||
""" | |||
self.content.pop(index) | |||
@@ -60,10 +56,10 @@ class FieldArray: | |||
def get(self, indices: Union[int, List[int]]): | |||
r""" | |||
根据给定的 indices 返回内容。 | |||
根据给定的 ``indices`` 返回内容。 | |||
:param indices: 获取 indices 对应的内容。 | |||
:return: 根据给定的 indices 返回的内容,可能是单个值 或 ``ndarray`` | |||
:param indices: 获取 ``indices`` 对应的内容。 | |||
:return: 根据给定的 ``indices`` 返回的内容,可能是单个值 或 :class:`numpy.ndarray` | |||
""" | |||
if isinstance(indices, int): | |||
if indices == -1: | |||
@@ -80,16 +76,16 @@ class FieldArray: | |||
r""" | |||
返回长度 | |||
:return length: | |||
:return: | |||
""" | |||
return len(self.content) | |||
def split(self, sep: str = None, inplace: bool = True): | |||
r""" | |||
依次对自身的元素使用 ``.split()`` 方法,应该只有当本 field 的元素为 ``str`` 时,该方法才有用。 | |||
依次对自身的元素使用 ``.split()`` 方法,应该只有当本 field 的元素为 :class:`str` 时,该方法才有用。 | |||
:param sep: 分割符,如果为 ``None`` 则直接调用 ``str.split()``。 | |||
:param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 ``list``。 | |||
:param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 :class:`list`。 | |||
:return: List[List[str]] or self | |||
""" | |||
new_contents = [] | |||
@@ -104,10 +100,11 @@ class FieldArray: | |||
def int(self, inplace: bool = True): | |||
r""" | |||
将本 field 中的值调用 ``int(cell)``. 支持 field 中内容为以下两种情况: | |||
* ['1', '2', ...](即 field 中每个值为 ``str`` 的), | |||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list`` ,``list`` 中的值会被依次转换。) | |||
:param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 ``list``。 | |||
* ['1', '2', ...](即 field 中每个值为 :class:`str` 的), | |||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。) | |||
:param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。 | |||
:return: List[int], List[List[int]], self | |||
""" | |||
new_contents = [] | |||
@@ -126,10 +123,10 @@ class FieldArray: | |||
r""" | |||
将本 field 中的值调用 ``float(cell)``. 支持 field 中内容为以下两种情况: | |||
* ['1', '2', ...](即 field 中每个值为 ``str`` 的), | |||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list``,``list`` 中的值会被依次转换。) | |||
* ['1', '2', ...](即 field 中每个值为 :class:`str` 的), | |||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。) | |||
:param inplace: 如果为 ``True``,则将新生成值替换本 ``field``。否则返回 ``list``。 | |||
:param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。 | |||
:return: | |||
""" | |||
new_contents = [] | |||
@@ -148,10 +145,10 @@ class FieldArray: | |||
r""" | |||
将本field中的值调用 ``bool(cell)``. 支持 field 中内容为以下两种情况 | |||
* ['1', '2', ...](即 field 中每个值为 ``str`` 的), | |||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list``,``list`` 中的值会被依次转换。) | |||
* ['1', '2', ...](即 field 中每个值为 :class:`str` 的), | |||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。) | |||
:param inplace: 如果为 ``True``,则将新生成值替换本 ``field``。否则返回 ``list``。 | |||
:param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。 | |||
:return: | |||
""" | |||
new_contents = [] | |||
@@ -169,12 +166,12 @@ class FieldArray: | |||
def lower(self, inplace=True): | |||
r""" | |||
将本 field 中的值调用 ``cell.lower()``. 支持 field 中内容为以下两种情况 | |||
将本 field 中的值调用 ``cell.lower()``, 支持 field 中内容为以下两种情况 | |||
* ['1', '2', ...](即 ``field`` 中每个值为 ``str`` 的), | |||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list``,``list``中的值会被依次转换。) | |||
* ['1', '2', ...](即 field 中每个值为 :class:`str` 的), | |||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。) | |||
:param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 ``list``。 | |||
:param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。 | |||
:return: List[int], List[List[int]], self | |||
""" | |||
new_contents = [] | |||
@@ -191,12 +188,12 @@ class FieldArray: | |||
def upper(self, inplace=True): | |||
r""" | |||
将本 field 中的值调用 ``cell.lower()``. 支持 field 中内容为以下两种情况 | |||
将本 field 中的值调用 ``cell.upper()``, 支持 field 中内容为以下两种情况 | |||
* ['1', '2', ...](即 field 中每个值为 ``str`` 的), | |||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list``,``list`` 中的值会被依次转换。) | |||
* ['1', '2', ...](即 field 中每个值为 :class:`str` 的), | |||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。) | |||
:param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 ``list``。 | |||
:param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。 | |||
:return: List[int], List[List[int]], self | |||
""" | |||
new_contents = [] | |||
@@ -211,11 +208,11 @@ class FieldArray: | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def value_count(self): | |||
def value_count(self) -> Counter: | |||
r""" | |||
返回该 field 下不同 value的 数量。多用于统计 label 数量 | |||
返回该 field 下不同 value 的数量。多用于统计 label 数量 | |||
:return: Counter, key 是 label,value 是出现次数 | |||
:return: 计数结果,key 是 label,value 是出现次数 | |||
""" | |||
count = Counter() | |||
@@ -1,7 +1,6 @@ | |||
r""" | |||
instance 模块实现了 Instance 类在 fastNLP 中对应 sample。一个 sample 可以认为是一个 Instance 类型的对象。 | |||
便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset` 。 | |||
instance 模块实现了 Instance 类,即在 fastNLP 中 sample 对应的类型。一个 sample 可以认为是一个 Instance 类型的对象。 | |||
便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset.dataset` 。 | |||
""" | |||
__all__ = [ | |||
@@ -15,9 +14,9 @@ from fastNLP.core.utils.utils import pretty_table_printer | |||
class Instance(Mapping): | |||
r""" | |||
Instance 是 fastNLP 中对应一个 sample 的类。每个 sample 在 fastNLP 中是一个 Instance 对象。 | |||
Instance 一般与 :class:`~fastNLP.DataSet` 一起使用, Instance 的初始化如下面的 Example 所示:: | |||
Instance 一般与 :class:`~fastNLP.DataSet` 一起使用, Instance 的初始化如下面的代码所示:: | |||
>>> instance = Instance(input="this is a demo sentence", label='good') # 请补充完整 | |||
>>> instance = Instance(input="this is a demo sentence", label='good') | |||
""" | |||
@@ -159,7 +159,7 @@ class PaddleDriver(Driver): | |||
:param only_state_dict: 是否只保存模型的 ``state_dict``;如果为 ``False``,则会调用 ``paddle.jit.save`` | |||
函数保存整个模型的参数,此时需要传入 ``input_spec`` 参数; | |||
:kwargs: | |||
* input_spec -- 描述存储模型 ``forward`` 方法的输入; | |||
* *input_spec* -- 描述存储模型 ``forward`` 方法的输入; | |||
当 ``only_state_dict`` 为 ``False`` 时必须传入,否则加载时会报错。您可以通过 ``InputSpec`` 或者示例 ``Tensor`` | |||
进行描述。详细的使用方法可以参考 **PaddlePaddle** `关于 paddle.jit.save 函数的文档 <https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html#save>`_; | |||
""" | |||
@@ -20,7 +20,7 @@ class Metric: | |||
:param backend: 目前支持四种类型的 backend, ``[torch, paddle, jittor, auto]``。其中 ``auto`` 表示根据实际调用 | |||
Metric.update() 函数时传入的参数决定具体的 ``backend`` ,大部分情况下直接使用 ``auto`` 即可。 | |||
:param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 :class:`~fastNLP.Evaluator` 中根据 sampler 是否使用分布式 | |||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 :class:`~fastNLP.core.controllers.Evaluator` 中根据 sampler 是否使用分布式 | |||
进行自动设置。 | |||
""" | |||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None): | |||
@@ -212,7 +212,7 @@ class SpanFPreRecMetric(Metric): | |||
:param backend: 目前支持四种类型的 backend, ``[torch, paddle, jittor, auto]``。其中 ``auto`` 表示根据实际调用 | |||
Metric.update() 函数时传入的参数决定具体的 ``backend`` ,大部分情况下直接使用 ``auto`` 即可。 | |||
:param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 :class:`~fastNLP.Evaluator` 中根据 sampler 是否使用分布式 | |||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 :class:`~fastNLP.core.controllers.Evaluator` 中根据 sampler 是否使用分布式 | |||
进行自动设置。 | |||
""" | |||
def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | |||