diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index 4f5ee3d8..300a342f 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -16,6 +16,7 @@ __all__ = [ "ResultsMonitor", 'HasMonitorCallback', "FitlogCallback", + "TimerCallback", # collators 'Collator', @@ -45,9 +46,11 @@ __all__ = [ 'TorchDataLoader', 'PaddleDataLoader', 'JittorDataLoader', + 'OneflowDataLoader', 'prepare_jittor_dataloader', 'prepare_paddle_dataloader', 'prepare_torch_dataloader', + 'prepare_oneflow_dataloader', "prepare_dataloader", # dataset @@ -58,10 +61,13 @@ __all__ = [ # drivers "TorchSingleDriver", "TorchDDPDriver", + "DeepSpeedDriver", "PaddleSingleDriver", "PaddleFleetDriver", "JittorSingleDriver", "JittorMPIDriver", + "OneflowSingleDriver", + "OneflowDDPDriver", # log "logger", diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py index 48699b68..d1f19b96 100644 --- a/fastNLP/core/callbacks/__init__.py +++ b/fastNLP/core/callbacks/__init__.py @@ -21,7 +21,9 @@ __all__ = [ "ResultsMonitor", 'HasMonitorCallback', - "FitlogCallback" + "FitlogCallback", + + "TimerCallback" ] @@ -37,4 +39,4 @@ from .torch_callbacks import * from .more_evaluate_callback import MoreEvaluateCallback from .has_monitor_callback import ResultsMonitor, HasMonitorCallback from .fitlog_callback import FitlogCallback - +from .timer_callback import TimerCallback diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py index d3679572..9584aba5 100644 --- a/fastNLP/core/callbacks/callback.py +++ b/fastNLP/core/callbacks/callback.py @@ -10,7 +10,7 @@ from .callback_event import Event, Filter class Callback: r""" - 实际使用的 callback 类,不管是我们 fastNLP 默认提供的一些 callback 类,还是用户自己定制的 callback 类,都应该继承该基类; + 实际使用的 callback 类,不管是 **fastNLP** 默认提供的一些 callback 实例,还是用户自己定制的 callback 类,都应该继承该基类; callback 调用时机顺序大概如下:: Trainer.__init__(): @@ -41,17 +41,17 @@ class Callback: finally: on_train_end(trainer) - 其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/ - on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中特定 - 的时间调用。 + 其它 callback 例如 **on_evaluate_begin(trainer)** / **on_evaluate_end(trainer, results)** / **on_save_model(trainer)** / + **on_load_model(trainer)** / **on_save_checkpoint(trainer)** / **on_load_checkpoint(trainer)** 将根据需要在 :meth:`Trainer.run ` + 中特定的时间调用。 """ def on_after_trainer_initialized(self, trainer, driver): r""" - 在 `Trainer` 初始化后会被触发; + 在 ``Trainer`` 初始化后会被触发; - :param trainer: ``Trainer`` 实例; - :param driver: ``Trainer`` 中的 ``driver`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; + :param driver: :class:`~fastNLP.core.controllers.Trainer` 中的 ``driver`` 实例; """ pass @@ -59,7 +59,7 @@ class Callback: r""" 在 '预跑'检测 开始前会被触发; - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; """ pass @@ -67,7 +67,7 @@ class Callback: r""" 在 '预跑'检测 开始后会被触发; - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; :param sanity_check_res: 预跑得到的评测结果,关于对于 **预跑** 的解释,请见 :meth:`~fastNLP.core.controllers.trainer.Trainer.run`; """ pass @@ -76,7 +76,7 @@ class Callback: r""" 在训练开始前会被触发; - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; """ pass @@ -84,7 +84,7 @@ class Callback: r""" 在训练完成后会被触发; - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; """ pass @@ -92,7 +92,7 @@ class Callback: r""" 在训练过程中的每一个 epoch 开始前会被触发; - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; """ pass @@ -100,7 +100,7 @@ class Callback: r""" 在训练过程中的每一个 epoch 完成后会被触发;此时 trainer.cur_epoch_idx 已经完成加 1 操作。 - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; """ pass @@ -108,7 +108,7 @@ class Callback: r""" 在训练过程中准备取出下一个 batch 的数据时触发 - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; """ pass @@ -116,30 +116,30 @@ class Callback: r""" 在训练过程中拿到当前的 batch 数据后会被触发; - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; """ pass def on_train_batch_begin(self, trainer, batch, indices): r""" - 在取得数据,执行完 ``input_mapping`` (如果 ``Trainer`` 传有该参数),并且移动 ``batch`` 中的 ``tensor`` 到了指定设备。 + 在取得数据,执行完 ``input_mapping`` (如果 :class:`~fastNLP.core.controllers.Trainer` 传有该参数),并且移动 ``batch`` 中的张量到了指定设备之后会被触发。 其中 ``batch`` 中的数据格式要么是 ``Dataloader`` 返回的每个 ``batch`` 的格式;要么是 ``input_mapping`` 之后的内容。 - 如果 ``batch`` 是 ``dict`` 类型,直接增删其中的 ``key`` 或 修改其中的 ``value`` 会影响到输入到 ``model`` 的中的 ``batch`` 数据。 + 如果 ``batch`` 是 ``dict`` 类型,直接增删其中的 key 或 修改其中的 value 会影响到输入模型的中的 ``batch`` 数据。 - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; :param batch: batch 的数据,已经经过 ``input_mapping`` (如果有) 以及移动到指定设备 。 - :param list[int] indices: 当前的 ``batch`` 是 ``dataset`` 中的哪些数据。仅在 ``DataLoader`` 支持得到当前 ``batch index`` 的时候有值, - 其它时候为 None 。 + :param list[int] indices: 当前的 ``batch`` 是数据集中的哪些数据。仅在 ``DataLoader`` 支持得到当前 ``batch index`` 的时候有值, + 其它时候为 ``None`` 。 """ pass def on_train_batch_end(self, trainer): r""" - 完成一个 batch 的训练(forward)、梯度回传(backward)、梯度更新(step)、梯度置零、batch_idx_in_epoch与 - global_forward_batches累计加1操作。其中梯度更新】梯度置零操作会考虑 accumulation_steps ,所以不一定在当前 batch 会 + 完成一个 batch 的训练(forward)、梯度回传(backward)、梯度更新(step)、梯度置零、batch_idx_in_epoch 与 + global_forward_batches 累计加1操作之后会被触发。其中梯度更新、梯度置零操作会考虑 **accumulation_steps** ,所以不一定在当前 batch 会 执行。 - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; """ pass @@ -147,41 +147,42 @@ class Callback: r""" 在训练过程遇到异常时调用。 - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; :param exception: 遭遇的异常; """ pass def on_save_model(self, trainer): r""" - 当调用 Trainer.save_model() 时调用,此刻模型还未保存。 + 当调用 :meth:`Trainer.save_model() ` 时调用,此刻模型还未保存。 - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; """ pass def on_load_model(self, trainer): r""" - 当调用 Trainer.load_model() 加载模型时调用,此刻模型还未加载。 + 当调用 :meth:`Trainer.load_model() ` 加载模型时调用,此刻模型还未加载。 - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; """ pass def on_save_checkpoint(self, trainer) -> Dict: r""" - 当 Trainer 将要保存 checkpoint 的时候触发 (即调用 Trainer.save_checkpoint() 函数时),该函数用于保存当前 callback 在恢复需要的相关数据。 + 当 Trainer 将要保存 checkpoint 的时候触发 (即调用 :meth:`Trainer.save_checkpoint() ` + 函数时),该函数用于保存当前 callback 在恢复时需要的相关数据。 - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; """ pass def on_load_checkpoint(self, trainer, states: Optional[Dict]): r""" - 当 Trainer 要恢复 checkpoint 的时候触发(即调用 Trainer.load_checkpoint() 函数时, 此刻 Trainer 与 Driver 已经加载好自身 - 的状态), 参数 states 为 Callback 在调用 on_save_checkpoint() 的返回值。 + 当 Trainer 要恢复 checkpoint 的时候触发(即调用 :meth:`Trainer.load_checkpoint() ` + 函数时, 此刻 Trainer 与 Driver 已经加载好自身的状态), 参数 states 为 Callback 在调用 :meth:`on_save_checkpoint` 的返回值。 - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; :param states: """ pass @@ -190,7 +191,7 @@ class Callback: r""" 在 backward 前执行。 - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; :param outputs: ``model`` 的返回内容。如果有 ``output_mapping``,则 ``outputs`` 中的内容为已经执行了 ``output_mapping`` 后的结果。 """ pass @@ -198,54 +199,54 @@ class Callback: def on_after_backward(self, trainer): r""" 在 ``backward`` 后执行。在多卡场景下,由于 ``accumulation_steps`` 的影响,仅在需要真正 ``update`` 参数那次梯度回传才会触发梯度同步, - 因此在多卡且使用 ``accumulation_steps`` 时,可能存在某些 ``step`` 各卡上梯度不一致的问题。 + 因此在多卡且使用 ``accumulation_steps`` 时,可能存在某些 step 各卡上梯度不一致的问题。 - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; """ pass def on_before_optimizers_step(self, trainer, optimizers): r""" - 在进行 optimizer 优化进行前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 + 在进行 optimizer 优化进行前调用。该接口不一定每次前向计算都会触发,实际调用会受到 ``accumulation_steps`` 的影响。 - :param trainer: ``Trainer`` 实例; - :param optimizers: 优化器,内容为在 ``Trainer`` 初始化时传入的值。 + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; + :param optimizers: 优化器,内容为在 :class:`~fastNLP.core.controllers.Trainer` 初始化时传入的值。 """ pass def on_after_optimizers_step(self, trainer, optimizers): r""" - 在进行 optimizer 优化进行后调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 + 在进行 optimizer 优化进行后调用。该接口不一定每次前向计算都会触发,实际调用会受到 ``accumulation_steps`` 的影响。 - :param trainer: ``Trainer`` 实例; - :param optimizers: 优化器,内容为在 ``Trainer`` 初始化时传入的值。 + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; + :param optimizers: 优化器,内容为在 :class:`~fastNLP.core.controllers.Trainer` 初始化时传入的值。 """ pass def on_before_zero_grad(self, trainer, optimizers): r""" - 在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 + 在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 ``accumulation_steps`` 的影响。 - :param trainer: ``Trainer`` 实例; - :param optimizers: 优化器,内容为在 ``Trainer`` 初始化时传入的值。 + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; + :param optimizers: 优化器,内容为在 :class:`~fastNLP.core.controllers.Trainer` 初始化时传入的值。 """ pass def on_after_zero_grad(self, trainer, optimizers): r""" - 在进行模型梯度置零后调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 + 在进行模型梯度置零后调用。该接口不一定每次前向计算都会触发,实际调用会受到 ``accumulation_steps`` 的影响。 - :param trainer: ``Trainer`` 实例; - :param optimizers: 优化器,内容为在 ``Trainer`` 初始化时传入的值。 + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; + :param optimizers: 优化器,内容为在 :class:`~fastNLP.core.controllers.Trainer` 初始化时传入的值。 """ pass def on_evaluate_begin(self, trainer): r""" - 在将要进行 evaluate 时调用。如果是设置的以 step 数量 或 自定义地 决定 evaluate 的频率,该接口是在 on_train_batch_end 之后 - 进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。 + 在将要进行 ``evaluate`` 时调用。如果是设置的以 step 数量或自定义地决定 evaluate 的频率,该接口是在 :meth:`on_train_batch_end` 之后 + 进行调用。如果是以 epoch 数量决定调用时机,该接口是在 :meth:`on_train_epoch_end` 之后调用。 - :param trainer: ``Trainer`` 实例; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; """ pass @@ -253,17 +254,17 @@ class Callback: r""" 结束 evaluate 时调用,并把 evaluate 的结果传入。 - :param trainer: ``Trainer`` 实例; - :param results: ``Trainer`` 内置的 ``Evaluator`` 评测的结果,通常是个 ``dict``; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; + :param results: :class:`~fastNLP.core.controllers.Trainer` 内置的 ``Evaluator`` 评测的结果,通常是个 ``dict``; """ pass @property def callback_name(self): r""" - ``callback`` 的名称,我们会使用该名称从 ``checkpoint`` 中读取的相应的 ``state`` 并传递给 ``on_load_checkpoint()`` 函数。 + ``callback`` 的名称,我们会使用该名称从 ``checkpoint`` 中读取的相应的 ``state`` 并传递给 :meth:`on_load_checkpoint` 函数。 - :return: 返回用于区分该 ``callback`` 实例的 ``name``; + :return: 返回用于区分该 ``callback`` 实例的名称; """ return self.__class__.__name__ diff --git a/fastNLP/core/callbacks/callback_event.py b/fastNLP/core/callbacks/callback_event.py index e7657a25..8a51b6de 100644 --- a/fastNLP/core/callbacks/callback_event.py +++ b/fastNLP/core/callbacks/callback_event.py @@ -31,13 +31,13 @@ def check_legality(fn): class Event: """ - 与 Trainer.on 函数配合使用,达到控制 callback 函数运行时机的目的。 + 与 :meth:`Trainer.on` 函数配合使用,达到控制 callback 函数运行时机的目的。 - :param value: Trainer 的 callback 时机。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变量分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param value: Trainer 的 callback 时机; + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; """ every: Optional[int] once: Optional[int] @@ -53,416 +53,416 @@ class Event: return "".format(self.value, self.every, self.once, self.filter_fn) @staticmethod - @check_legality def on_after_trainer_initialized(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_after_trainer_initialized 时 + 当 Trainer 运行到 :func:`on_after_trainer_initialized` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。默认为 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_after_trainer_initialized', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_sanity_check_begin(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_sanity_check_begin 时 + 当 Trainer 运行到 :func:`on_sanity_check_begin` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; + :return: :return: """ return Event(value='on_sanity_check_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_sanity_check_end(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_sanity_check_end 时 + 当 Trainer 运行到 :func:`on_sanity_check_end` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_sanity_check_end', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_train_begin(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_train_begin 时 + 当 Trainer 运行到 :func:`on_train_begin` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_train_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_train_end(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_train_end 时 + 当 Trainer 运行到 :func:`on_train_end` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_train_end', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_train_epoch_begin(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_train_epoch_begin 时 + 当 Trainer 运行到 :func:`on_train_epoch_begin` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_train_epoch_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_train_epoch_end(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_train_epoch_end 时 + 当 Trainer 运行到 :func:`on_train_epoch_end` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_train_epoch_end', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_fetch_data_begin(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_fetch_data_begin 时 + 当 Trainer 运行到 :func:`on_fetch_data_begin` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_fetch_data_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_fetch_data_end(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_fetch_data_end 时 + 当 Trainer 运行到 :func:`on_fetch_data_end` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_fetch_data_end', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_train_batch_begin(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_train_batch_begin 时 + 当 Trainer 运行到 :func:`on_train_batch_begin` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_train_batch_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_train_batch_end(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_train_batch_end 时 + 当 Trainer 运行到 :func:`on_train_batch_end` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_train_batch_end', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_exception(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_exception 时 + 当 Trainer 运行到 :func:`on_exception` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_exception', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_save_model(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_save_model 时 + 当 Trainer 运行到 :func:`on_save_model` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_save_model', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_load_model(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_load_model 时 + 当 Trainer 运行到 :func:`on_load_model` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_load_model', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_save_checkpoint(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_save_checkpoint 时 + 当 Trainer 运行到 :func:`on_save_checkpoint` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_save_checkpoint', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_load_checkpoint(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_load_checkpoint 时 + 当 Trainer 运行到 :func:`on_load_checkpoint` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_load_checkpoint(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_load_checkpoint 时 + 当 Trainer 运行到 :func:`on_load_checkpoint` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_before_backward(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_before_backward 时 + 当 Trainer 运行到 :func:`on_before_backward` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_before_backward', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_after_backward(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_after_backward 时 + 当 Trainer 运行到 :func:`on_after_backward` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_after_backward', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_before_optimizers_step(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_before_optimizers_step 时 + 当 Trainer 运行到 :func:`on_before_optimizers_step` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_before_optimizers_step', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_after_optimizers_step(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_after_optimizers_step 时 + 当 Trainer 运行到 :func:`on_after_optimizers_step` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_after_optimizers_step', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_before_zero_grad(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_before_zero_grad 时 + 当 Trainer 运行到 :func:`on_before_zero_grad` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_before_zero_grad', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_after_zero_grad(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_after_zero_grad 时 + 当 Trainer 运行到 :func:`on_after_zero_grad` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_after_zero_grad', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_evaluate_begin(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_evaluate_begin 时 + 当 Trainer 运行到 :func:`on_evaluate_begin` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_evaluate_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - @check_legality + def on_evaluate_end(every=None, once=None, filter_fn=None): """ - 当 Trainer 运行到 on_evaluate_end 时 + 当 Trainer 运行到 :func:`on_evaluate_end` 时触发; - 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 - :param int every: 触发了多少次,才真正运行一次。 - :param bool once: 是否只在第一次运行后就不再执行了。 - :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 - filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :param every: 每触发多少次才真正运行一次; + :param once: 在第一次运行后时候再次执行; + :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 + `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; :return: """ return Event(value='on_evaluate_end', every=every, once=once, filter_fn=filter_fn) class Filter: - def __init__(self, every: Optional[int] = None, once: Optional[bool] = None, filter_fn: Optional[Callable] = None): - r""" - 通过该 `Filter` 作为函数修饰器来控制一个函数的实际的运行频率。 + r""" + 可以控制一个函数实际的运行频率的函数修饰器。 - :param every: 表示一个函数隔多少次运行一次; - :param once: 表示一个函数只运行一次; - :param filter_fn: 用户定制的频率控制函数;注意该函数内部的频率判断应当是无状态的,除了参数 `self.num_called` 和 - `self.num_executed` 外,因为我们会在预跑后重置这两个参数的状态; - """ + :param every: 表示一个函数隔多少次运行一次; + :param once: 表示一个函数是否只运行一次; + :param filter_fn: 用户定制的频率控制函数;注意该函数内部的频率判断应当是无状态的,除了参数 `self.num_called` 和 + `self.num_executed` 外,因为我们会在预跑后重置这两个参数的状态; + """ + def __init__(self, every: Optional[int] = None, once: Optional[bool] = None, filter_fn: Optional[Callable] = None): # check legality check_legality(lambda *args,**kwargs:...)(every, once, filter_fn) if (every is None) and (once is None) and (filter_fn is None): diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index d3d8ae75..bf1de884 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -75,12 +75,13 @@ class CallbackManager: def __init__(self, callbacks: Optional[List[Callback]]): r""" - 注意 callback 的调用顺序: + 注意 callback 的调用顺序为: + 1. 通过函数修饰器 `Trainer.on` 添加的 callback 函数; 2. 通过 `Trainer` 的参数 `callbacks` 添加的 callback 类; 3. 通过 `Trainer.add_callback_fn` 添加的 callback 函数; - :param callbacks: 初始化时可以传入的一系列 callback 类,通常为用户在初始化 ``Trainer`` 时直接传入的 callback 类; + :param callbacks: 初始化时可以传入的一系列 :class:`~fastNLP.Callback` 类,通常为用户在初始化 ``Trainer`` 时直接传入的 callback 列表; """ self._need_reproducible_sampler = False @@ -106,12 +107,9 @@ class CallbackManager: def initialize_class_callbacks(self): r""" - 在实际的运行过程中,我们是将具体的一个 callback 实例拆分为单独的一个个 callback 函数,然后将它们加在一个字典里,该字典的键值就是 + 在实际的运行过程中,我们会将具体的一个 callback 实例拆分为单独的一个个 callback 函数,然后将它们加在一个字典里,该字典的键值就是 一个个 callback 时机,也就是 `Event` 的类别; 如果一个 callback 类的 callback 函数并不具备任何作用,我们实际并不会将其加在字典当中; - - :param callbacks: - :return: """ for each_callback in self.class_callbacks: self._need_reproducible_sampler |= each_callback.need_reproducible_sampler @@ -144,11 +142,12 @@ class CallbackManager: 用于断点重训的 callback 的保存函数; 该函数主要涉及两个方面: - 1. callback 的状态的保存;我们会调用每一个 callback 的 `on_save_checkpoint` 方法,该方法应当返回一个字典,其中包含着 - 断点重训应当保存的状态; + 1. callback 的状态的保存;我们会调用每一个 callback 的 :func:`on_save_checkpoint` 方法,该方法应当返回一个字典,其中包含着 + 断点重训应当保存的状态; 2. 每一个具体的 callback 函数的 filter 的状态; - :return: 一个包含上述内容的字典: + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; + :return: 一个包含上述内容的字典,格式如下: .. code-block:: { @@ -195,11 +194,10 @@ class CallbackManager: def on_load_checkpoint(self, trainer, states: Dict): r""" - 用于断点重训的加载函数; - 对应于断点重训的保存函数; + 用于断点重训的加载函数,对应于断点重训的保存函数; - :param trainer: `Trainer` - :param states: 见 `on_save_checkpoint` 函数的返回值; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; + :param states: 同 :func:`on_save_checkpoint` 函数的返回值; """ # 1. 先恢复每一个具体的 callback 函数的 filter 的状态; diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index 0cc3021b..44bd9c03 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -24,36 +24,37 @@ class CheckpointCallback(Callback): - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 - model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。若 model_save_fn 不为 None, + ``model_save_fn`` 为 ``None`` ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。若 ``model_save_fn`` 不为 ``None``, 则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。默认情况下,本 checkpoint 只保存了 model 的状态;如还需保存 Trainer 的状态以断点重训的话,请使用 ``save_object='trainer'`` 。 :param monitor: 监控的 metric 值。 * 为 ``None`` - 将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 + 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 * 为 ``str`` - 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 - 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 - * 为 ``Callable`` - 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 - 的 ``monitor`` 值请返回 ``None`` 。 + 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 + 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 + * 为 :class:`Callable` + 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 + 的 ``monitor`` 值请返回 ``None`` 。 + :param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 :param every_n_epochs: 多少个 epoch 保存一次。 :param every_n_batches: 多少个 batch 保存一次。 - :param last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。如果为 False 则不会保存 {save_object}-last 文件 - :param topk: 保存 monitor 结果 topK 个。 - :param on_exceptions: 在出异常信息时,是否保存。传入需要捕获的异常的类。默认将捕获 EarlyStopException 。 + :param last: 如果为 ``True`` ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。如果为 ``False`` 则不会保存 ``{save_object}-last`` 文件 + :param topk: 保存 monitor 结果中的 ``topk`` 个。 + :param on_exceptions: 在出异常信息时,是否保存。传入需要捕获的异常的类。默认将捕获 :class:`~fastNLP.core.callbacks.EarlyStopException` 。 :param larger_better: monitor 的值是否时越大越好。 - :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 + :param only_state_dict: 保存模型时是否只保存 state_dict 。当 ``model_save_fn`` 不为 ``None`` 时,该参数无效。 :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 - 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 - :param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 - 保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 + 如果传入了 ``model_save_fn`` 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 + :param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 + 保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.core.controllers.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 - :param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 - fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 + :param save_evaluate_results: 是否保存 evaluate 的结果。如果为 ``True`` ,在保存 topk 模型的 folder 中还将额外保存一个 + ``fastnlp_evaluate_results.json`` 文件,记录当前的 results。仅在设置了 ``topk`` 的场景下有用,默认为 ``True`` 。 :param kwargs: """ def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, diff --git a/fastNLP/core/callbacks/early_stop_callback.py b/fastNLP/core/callbacks/early_stop_callback.py index db9b6493..8e542d56 100644 --- a/fastNLP/core/callbacks/early_stop_callback.py +++ b/fastNLP/core/callbacks/early_stop_callback.py @@ -10,16 +10,16 @@ from fastNLP.core.utils.exceptions import EarlyStopException class EarlyStopCallback(HasMonitorCallback): """ - 用于 early stop 的 callback 。当监控的结果连续多少次没有变好边 raise 一个 EarlyStopException 。 + 用于 early stop 的 callback 。当监控的结果连续多少次没有变好便 raise 一个 :class:`EarlyStopException` 。 :param monitor: 监控的 metric 值。 * 为 ``None`` - 将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 + 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 * 为 ``str`` 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 - * 为 ``Callable`` + * 为 :class:`Callable` 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 的 ``monitor`` 值请返回 ``None`` 。 :param larger_better: monitor 的值是否是越大越好。 diff --git a/fastNLP/core/callbacks/fitlog_callback.py b/fastNLP/core/callbacks/fitlog_callback.py index 35662539..a7716fa6 100644 --- a/fastNLP/core/callbacks/fitlog_callback.py +++ b/fastNLP/core/callbacks/fitlog_callback.py @@ -14,20 +14,21 @@ if _module_available('fitlog'): class FitlogCallback(HasMonitorCallback): """ 自动记录 ``evaluation`` 结果到 ``fitlog`` 中。会自动记录每一次 ``evaluate`` 后的结果;同时会根据 - ``monitor`` 记录最好的结果。另外,会自动将非 ``rank 0`` 上的 ``fitlog`` 设置为 ``debug`` 状态。同时还会在 ``fitlog`` 的 - ``other`` 列中记录一个 ``launch_time`` ,可以通过这个数值找到当前这个脚本的在 save_folder (如果有使用其它需要保存模型的 - ``Callback`` ,例如 :class:`~fastNLP.CheckpointCallback` )下的文件夹名称。 + ``monitor`` 记录最好的结果。另外,会自动将非 ``rank 0`` 上的 ``fitlog`` 设置为 ``debug`` 状态。同时还会在 ``fitlog`` 的 + ``other`` 列中记录一个 ``launch_time`` ,可以通过这个数值找到当前这个脚本的在 save_folder (如果有使用其它需要保存模型的 + ``Callback`` ,例如 :class:`~fastNLP.core.callbacks.CheckpointCallback` )下的文件夹名称。 :param monitor: 监控的 metric 值。 * 为 ``None`` - 将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 + 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 * 为 ``str`` - 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 - 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 - * 为 ``Callable`` - 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 - 的 ``monitor`` 值请返回 ``None`` 。 + 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 + 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 + * 为 :class:`Callable` + 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 + 的 ``monitor`` 值请返回 ``None`` 。 + :param larger_better: 是否是越大越好。 :param log_exception: 是否记录 ``exception`` 。 :param log_loss_every: 多少个 ``batch`` 记录一次 loss 到 ``fitlog`` 中。 @@ -44,7 +45,7 @@ class FitlogCallback(HasMonitorCallback): if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog fitlog.debug() super().on_after_trainer_initialized(trainer, driver) - fitlog.add_other('launch_time', os.environ['FASTNLP_LAUNCH_TIME']) + fitlog.add_other(name='launch_time', value=os.environ['FASTNLP_LAUNCH_TIME']) def on_sanity_check_end(self, trainer, sanity_check_res): super(FitlogCallback, self).on_sanity_check_end(trainer, sanity_check_res) diff --git a/fastNLP/core/callbacks/has_monitor_callback.py b/fastNLP/core/callbacks/has_monitor_callback.py index 0b57bf53..bb865bf8 100644 --- a/fastNLP/core/callbacks/has_monitor_callback.py +++ b/fastNLP/core/callbacks/has_monitor_callback.py @@ -26,19 +26,19 @@ class CanItemDataType(ABC): class ResultsMonitor: """ - 可用于监控某个数值,并通过 is_better_results() 等接口实现检测结果是否变得更好了。 + 可用于监控某个数值,并通过 :meth:`is_better_results` 等接口检测结果是否变得更好。 - :param monitor: 监控的 metric 值。 + :param monitor: 监控的 metric 值: * 为 ``None`` - 将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 + 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置); * 为 ``str`` 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 - 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 - * 为 ``Callable`` + 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` ; + * 为 :class:`Callable` 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 - 的 ``monitor`` 值请返回 ``None`` 。 - :param larger_better: monitor 是否时越大越好 + 的 ``monitor`` 值请返回 ``None`` ; + :param larger_better: monitor 是否为越大越好; """ def __init__(self, monitor:Union[Callback, str], larger_better:bool=True): self.set_monitor(monitor, larger_better) @@ -60,7 +60,7 @@ class ResultsMonitor: def itemize_results(self, results): """ - 将结果中有 .item() 方法的都调用一下,使得 tensor 类型的数据转为 python 内置类型。 + 执行结果中所有对象的 :meth:`item` 方法(如果没有则忽略),使得 Tensor 类型的数据转为 python 内置类型。 :param results: :return: @@ -69,10 +69,10 @@ class ResultsMonitor: def get_monitor_value(self, results:Dict)->Union[float, None]: """ - 获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用 最长公共字符串算法 匹配的方式寻找。 + 获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用 **最长公共字符串算法** 匹配的方式寻找。 - :param results: 评测结果。 - :return: 如果为 None ,表明此次没有找到合适的monitor + :param results: 评测结果; + :return: monitor 的值;如果为 ``None`` ,表明此次没有找到合适的monitor; """ if len(results) == 0 or self.monitor is None: return None @@ -100,10 +100,10 @@ class ResultsMonitor: def is_better_monitor_value(self, monitor_value: float, keep_if_better=True): """ - 检测 monitor_value 是否是更好的 + 检测 ``monitor_value`` 是否是更好的 - :param monitor_value: 待检查的 monitor_value 。如果为 None ,返回 False - :param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 + :param monitor_value: 待检查的 ``monitor_value`` 。如果为 ``None`` ,返回 False; + :param keep_if_better: 如果传入的 ``monitor_value`` 值更好,则将其保存下来; :return: """ if monitor_value is None: @@ -115,10 +115,10 @@ class ResultsMonitor: def is_better_results(self, results, keep_if_better=True): """ - 检测给定的 results 是否比上一次更好,如果本次 results 中没有找到相关的monitor 返回 False。 + 检测给定的 ``results`` 是否比上一次更好,如果本次 results 中没有找到相关的 monitor 返回 ``False``。 - :param results: evaluation 结果。 - :param keep_if_better: 当返回为 True 时,是否保存到 self.monitor_value 中。 + :param results: evaluation 结果; + :param keep_if_better: 当返回为 ``True`` 时,是否保存到 ``self.monitor_value`` 中; :return: """ monitor_value = self.get_monitor_value(results) @@ -128,7 +128,7 @@ class ResultsMonitor: def is_former_monitor_value_better(self, monitor_value1, monitor_value2): """ - 传入的两个值中,是否monitor_value1的结果更好。 + 传入的两个值中,是否 ``monitor_value1`` 的结果更好。 :param monitor_value1: :param monitor_value2: @@ -149,7 +149,7 @@ class ResultsMonitor: @property def monitor_name(self): """ - 返回 monitor 的名字,如果 monitor 是个 callable 的函数,则返回该函数的名称。 + 返回 monitor 的名字,如果 monitor 是个 Callable 的函数,则返回该函数的名称。 :return: """ @@ -171,7 +171,7 @@ class ResultsMonitor: @property def log_name(self) -> str: """ - 内部用于打印信息使用 + 内部用于打印当前类别信息使用 :return: """ @@ -185,20 +185,20 @@ class ResultsMonitor: class HasMonitorCallback(ResultsMonitor, Callback): """ 该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 - (1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 + (1)判断 monitor 合法性;(2)在需要时, 根据 trainer 的 monitor 设置自己的 monitor 名称。 - :param monitor: 监控的 metric 值。 + :param monitor: 监控的 metric 值: * 为 ``None`` - 将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 + 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置); * 为 ``str`` 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 - 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 - * 为 ``Callable`` + 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` ; + * 为 :class:`Callable` 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 - 的 ``monitor`` 值请返回 ``None`` 。 - :param larger_better: monitor 是否时越大越好 - :param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。 + 的 ``monitor`` 值请返回 ``None`` ; + :param larger_better: monitor 是否为越大越好; + :param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 ``True`` ,且没检测到设置 monitor 会报错; """ def __init__(self, monitor, larger_better, must_have_monitor=False): super().__init__(monitor, larger_better) @@ -230,20 +230,20 @@ class HasMonitorCallback(ResultsMonitor, Callback): class ExecuteOnceBetterMonitor(HasMonitorCallback): """ - 当监控的 monitor 结果更好的时候,调用 execute_fn 函数。 + 当监控的 ``monitor`` 结果更好的时候,调用 ``execute_fn`` 函数。 - :param monitor: 监控的 metric 值。 + :param monitor: 监控的 metric 值: * 为 ``None`` - 将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 + 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 ``monitor`` 值(如果有设置); * 为 ``str`` 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 - 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 - * 为 ``Callable`` + 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` ; + * 为 :class:`Callable` 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 - 的 ``monitor`` 值请返回 ``None`` 。 - :param larger_better: monitor 是否时越大越好 - :param execute_fn: 一个可执行的函数,不接受任何参数,不反回值。在 monitor 取得更好结果的时候会调用。 + 的 ``monitor`` 值请返回 ``None`` ; + :param larger_better: monitor 是否是越大越好; + :param execute_fn: 一个可执行的函数,不接受任何参数,没有返回值。在 monitor 取得更好结果的时候会调用; """ def __init__(self, monitor, larger_better, execute_fn): super().__init__(monitor, larger_better, must_have_monitor=True) diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index 4f52720f..2bd41b5a 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -19,25 +19,25 @@ class LoadBestModelCallback(HasMonitorCallback): 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型,默认会在加载之后删除权重文件。仅在训练正常结束的时候才能加载 最好的模型。 - :param monitor: 监控的 metric 值。 + :param monitor: 监控的 metric 值: * 为 ``None`` - 将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 + 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置); * 为 ``str`` 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 - 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 - * 为 ``Callable`` + 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` ; + * 为 :class:`Callable` 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 - 的 ``monitor`` 值请返回 ``None`` 。 - :param larger_better: 该 metric 值是否是越大越好。 + 的 ``monitor`` 值请返回 ``None`` ; + :param larger_better: 该 metric 值是否是越大越好; :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 - 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 - :param only_state_dict: 是否只保存模型的参数。当 model_save_fn 不为空时,该值无效。 - :param model_save_fn: 保存 model 的函数,与 model_load_fn 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, - 请在函数内完成对模型的保存。 - :param model_load_fn: 加载 model 的函数,与 model_save_fn 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, - 请在函数内完成对模型的加载。 - :param delete_after_train: 在训练结束后是否删掉模型。 + 不同的机器均可访问当该路径。当 ``model_save_fn`` 不为 None 时该值一定不能为空; + :param only_state_dict: 是否只保存模型的参数。当 ``model_save_fn`` 不为空时,该值无效; + :param model_save_fn: 保存 model 的函数,与 ``model_load_fn`` 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, + 请在函数内完成对模型的保存; + :param model_load_fn: 加载 model 的函数,与 ``model_save_fn`` 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, + 请在函数内完成对模型的加载; + :param delete_after_train: 在训练结束后是否删掉模型; """ def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool = True, only_state_dict:bool = True, save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, @@ -105,14 +105,16 @@ class LoadBestModelCallback(HasMonitorCallback): def on_train_end(self, trainer): if abs(self.monitor_value) != float('inf'): # 如果是 inf 说明从来没有运行过。 - if self.real_save_folder: - logger.info(f"Loading best model from {self.real_save_folder} with {self.monitor_name}: {self.monitor_value}...") - trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, - model_load_fn=self.model_load_fn) - else: - logger.info(f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...") - self.buffer.seek(0) - trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) + # 如果是分布式且报错了,就不要加载了,防止barrier的问题 + if not (trainer.driver.is_distributed() and self.encounter_exception): + if self.real_save_folder: + logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...") + trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, + model_load_fn=self.model_load_fn) + else: + logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...") + self.buffer.seek(0) + trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) if self.delete_after_after: if not self.encounter_exception: # 防止出现死锁。 trainer.driver.barrier() diff --git a/fastNLP/core/callbacks/lr_scheduler_callback.py b/fastNLP/core/callbacks/lr_scheduler_callback.py index 37d089bd..3d3f4a0f 100644 --- a/fastNLP/core/callbacks/lr_scheduler_callback.py +++ b/fastNLP/core/callbacks/lr_scheduler_callback.py @@ -7,11 +7,11 @@ __all__ = [ class LRSchedCallback(Callback): """ - 根据 step_on 参数在合适的时机调用 scheduler 的 step 函数。 + 根据 ``step_on`` 参数在合适的时机调用 scheduler 的 step 函数。 - :param scheduler: 实现了 step() 函数的对象 - :param step_on: 可选 ['batch', 'epoch'] 表示在何时调用 scheduler 的 step 函数。如果为 batch 的话在每次更新参数 - 之前调用;如果为 epoch 则是在一个 epoch 运行结束后调用。 + :param scheduler: 实现了 :meth:`step` 函数的对象; + :param step_on: 可选 ``['batch', 'epoch']`` 表示在何时调用 scheduler 的 step 函数。如果为 ``batch`` 的话在每次更新参数 + 之前调用;如果为 ``epoch`` 则是在一个 epoch 运行结束后调用; """ def __init__(self, scheduler, step_on:str='batch'): assert hasattr(scheduler, 'step') and callable(scheduler.step), "The scheduler object should have a " \ @@ -19,7 +19,7 @@ class LRSchedCallback(Callback): self.scheduler = scheduler self.step_on = 0 if step_on == 'batch' else 1 - def on_before_optimizers_step(self, trainer, optimizers): + def on_after_optimizers_step(self, trainer, optimizers): if self.step_on == 0: self.scheduler.step() diff --git a/fastNLP/core/callbacks/more_evaluate_callback.py b/fastNLP/core/callbacks/more_evaluate_callback.py index 690146a2..04b35984 100644 --- a/fastNLP/core/callbacks/more_evaluate_callback.py +++ b/fastNLP/core/callbacks/more_evaluate_callback.py @@ -12,10 +12,10 @@ from .topk_saver import TopkSaver class MoreEvaluateCallback(HasMonitorCallback): """ - 当评测时需要调用不同的 evaluate_fn (例如在大部分生成任务中,一般使用训练 loss 作为训练过程中的 evaluate ;但同时在训练到 - 一定 epoch 数量之后,会让 model 生成的完整的数据评测 bleu 等。此刻就可能需要两种不同的 evaluate_fn ),只使用 Trainer - 无法满足需求,可以通过调用本 callback 进行。如果需要根据本 callback 中的评测结果进行模型保存,请传入 topk 以及 - topk_monitor 等相关参数。可以通过 evaluate_every 或 watch_monitor 控制触发进行 evaluate 的条件。 + 当评测时需要调用不同的 ``evaluate_fn`` (例如在大部分生成任务中,一般使用训练 loss 作为训练过程中的 evaluate ;但同时在训练到 + 一定 epoch 数量之后,会让 model 生成的完整的数据评测 bleu 等。此刻就可能需要两种不同的 ``evaluate_fn`` ),只使用 Trainer + 无法满足需求,可以通过调用本 callback 进行。如果需要根据本 callback 中的评测结果进行模型保存,请传入 ``topk`` 以及 + ``topk_monitor`` 等相关参数。可以通过 ``evaluate_every`` 或 ``watch_monitor`` 控制触发进行 evaluate 的条件。 如果设置了 evaluate 结果更好就保存的话,将按如下文件结构进行保存:: @@ -30,7 +30,7 @@ class MoreEvaluateCallback(HasMonitorCallback): 1. 为负数时表示每隔几个 ``epoch`` evaluate 一次; 2. 为正数则表示每隔几个 ``batch`` evaluate 一次; 3. 为函数时表示用户自己传入的用于控制 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 - 返回一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 ``batch`` 结束后调用该函数判断是否需要 evaluate; + 返回一个 bool 值,返回为 ``True`` 说明需要进行 evaluate ;将在每个 ``batch`` 结束后调用该函数判断是否需要 evaluate; .. note:: @@ -45,32 +45,41 @@ class MoreEvaluateCallback(HasMonitorCallback): 该函数表示当每经过 1000 个 batch,``Trainer`` 中内置的 ``Evaluator`` 就会验证一次; 另一个需要注意的事情在于该函数会在每一次 batch 的结尾进行调用,当该函数返回 ``True`` 时,``Evaluator`` 才会进行验证; - :param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 None ,evaluate_every 失效。本参数的 - 意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 - 取值: (1) str 类型,监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最 - 匹配的那个作为 monitor ; (2) 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor - 的结果,如果当前结果中没有相关的monitor 值请返回 None 。 - :param watch_monitor_larger_better: watch_monitor 是否越大越好。 - :param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 - `model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 - 找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 + :param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 ``None`` ,``evaluate_every`` 失效。本参数的 + 意义是,当检测到 Trainer 中 evaluate results 的 ``{watch_monitor}`` 的结果更好时,则进行一次 evaluate 。该参数有两种 + 取值: + + 1. ``str`` 类型,含义为监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 **最长公共字符串算法** 找到最 + 匹配的那个作为 monitor ; + 2. 一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor + 的结果,如果当前结果中没有相关的monitor 值请返回 ``None`` ; + :param watch_monitor_larger_better: ``watch_monitor`` 是否越大越好; + :param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 :meth:`model.evaluate_step` 还是 + :meth:`model.forward`: + + 1. 如果该值是 ``None``,那么我们会默认使用 :meth:`model.evaluate_step` 当做前向传播的函数,如果 + 在模型中没有找到该方法,则使用 :meth:`model.forward` 函数; + 2. 如果为 ``str`` 类型,则尝试从 model 中寻找该方法,找不到则报错; :param num_eval_sanity_batch: 在初始化 Evaluator 后运行多少个 sanity check 的 batch ,检测一下。 - :param topk: 如果需要根据当前 callback 中的 evaluate 结果保存模型或 Trainer ,可以通过设置 tokp 实现。(1)为 -1 表示每次 - evaluate 后都保存;(2)为 0 (默认),表示不保存;(3)为整数,表示保存性能最 topk 个。 + :param topk: 如果需要根据当前 callback 中的 evaluate 结果保存模型或 Trainer ,可以通过设置 topk 实现: + + 1. 为 ``-1`` 表示每次 evaluate 后都保存; + 2. 为 ``0`` (默认),表示不保存; + 3. 为整数,表示保存性能最好的 ``topk`` 个。 :param topk_monitor: 如果需要根据当前 callback 中的 evaluate 结果保存。这个参数是指在当前 callback 中的 evaluate 结果寻找 - :param topk_larger_better: topk_monitor 的值是否时越大越好。 + :param topk_larger_better: ``topk_monitor`` 的值是否是越大越好。 :param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 - 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 - :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 - :param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 - 保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 + 时间戳文件夹中。如果为 ``None`` ,默认使用当前文件夹。 + :param only_state_dict: 保存模型时是否只保存 state_dict 。当 ``model_save_fn`` 不为 ``None`` 时,该参数无效。 + :param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 + 保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.core.controllers.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 - 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 - :param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 - ``fastnlp_evaluate_results.json`` 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 - :param save_kwargs: dict。更多的保存相关的参数。 - :param kwargs: 其它与 Evaluator 相关的初始化参数,如果不传入,将从 Trainer 中获取。 + 如果传入了 ``model_save_fn`` 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 + :param save_evaluate_results: 是否保存 evaluate 的结果。如果为 ``True`` ,在保存 topk 模型的 folder 中还将额外保存一个 + ``fastnlp_evaluate_results.json`` 文件,记录当前的 results。仅在设置了 ``topk`` 的场景下有效,默认为 True 。 + :param save_kwargs: 一个字典,表示更多的保存相关的参数。 + :param kwargs: 其它与 :class:`~fastNLP.core.controllers.Evaluator` 相关的初始化参数,如果不传入,将从 :class:`~fastNLP.core.controllers.Trainer` 中获取。 """ def __init__(self, dataloaders, metrics:Dict, evaluate_every:Optional[Union[int, Callable]]=-1, watch_monitor:Union[str, Callable]=None, watch_monitor_larger_better:bool=True, diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index 36524a6b..681ea4d3 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -1,5 +1,4 @@ import json -import sys from typing import Union __all__ = [ @@ -16,8 +15,25 @@ from fastNLP.core.log import logger class ProgressCallback(HasMonitorCallback): + def __init__(self, monitor, larger_better, must_have_monitor=False): + super(ProgressCallback, self).__init__(monitor=monitor, larger_better=larger_better, + must_have_monitor=must_have_monitor) + self.best_monitor_epoch = -1 + self.best_monitor_step = -1 + self.best_results = None + + def record_better_monitor(self, trainer, results): + self.best_monitor_step = trainer.global_forward_batches + self.best_monitor_epoch = trainer.cur_epoch_idx + self.best_results = self.itemize_results(results) + def on_train_end(self, trainer): - f_rich_progress.stop() + if self.best_monitor_epoch != -1: + msg = f"The best performance for monitor {self._real_monitor}:{self.monitor_value} was achieved in" \ + f" Epoch:{self.best_monitor_epoch}, Global Batch:{self.best_monitor_step}." + if self.best_results is not None: + msg = msg + ' The evaluation result: \n' + str(self.best_results) + logger.info(msg) @property def name(self): # progress bar的名称 @@ -44,21 +60,22 @@ def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> Prog class RichCallback(ProgressCallback): """ - 在训练过程中打印 rich progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 - 参数,请通过实例化本 Callback 并传入到 Trainer 中实现。 + 在训练过程中打印 *rich* progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 + 参数,请通过实例化本 Callback 并传入到 Trainer 中实现。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 :param print_every: 多少个 batch 更新一次显示。 :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 :param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 * 为 ``None`` - 将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 + 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 * 为 ``str`` 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 - * 为 ``Callable`` + * 为 :class:`Callable` 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 的 ``monitor`` 值请返回 ``None`` 。 + :param larger_better: 是否是 monitor 的结果越大越好。 :param format_json: 是否格式化 json 再打印 """ @@ -97,6 +114,7 @@ class RichCallback(ProgressCallback): advance=None, completed=trainer.cur_epoch_idx, refresh=True) def on_train_end(self, trainer): + super(RichCallback, self).on_train_end(trainer) self.clear_tasks() def on_before_backward(self, trainer, outputs): @@ -121,8 +139,8 @@ class RichCallback(ProgressCallback): text_style = '' characters = '-' if self.monitor is not None: - monitor_value = self.get_monitor_value(results) - if self.is_better_monitor_value(monitor_value, keep_if_better=True): + if self.is_better_results(results, keep_if_better=True): + self.record_better_monitor(trainer, results) if abs(self.monitor_value) != float('inf'): rule_style = 'spring_green3' text_style = '[bold]' @@ -131,8 +149,11 @@ class RichCallback(ProgressCallback): self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " f"Batch:{trainer.batch_idx_in_epoch}", style=rule_style, characters=characters) + results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if + not key.startswith('_')} if self.format_json: - self.progress_bar.console.print_json(json.dumps(trainer.driver.tensor_to_numeric(results))) + results = json.dumps(results) + self.progress_bar.console.print_json(results) else: self.progress_bar.print(results) @@ -149,26 +170,26 @@ class RichCallback(ProgressCallback): class RawTextCallback(ProgressCallback): + """ + 通过向命令行打印进度的方式显示。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 + + :param print_every: 多少个 batch 更新一次显示。 + :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 + :param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 + + * 为 ``None`` + 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 + * 为 ``str`` + 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 + 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 + * 为 :class:`Callable` + 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 + 的 ``monitor`` 值请返回 ``None`` 。 + :param larger_better: 是否是monitor的结果越大越好。 + :param format_json: 是否format json再打印 + """ def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, format_json=True): - """ - 通过向命令行打印进度的方式显示 - - :param print_every: 多少个 batch 更新一次显示。 - :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 - :param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 - - * 为 ``None`` - 将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 - * 为 ``str`` - 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 - 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 - * 为 ``Callable`` - 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 - 的 ``monitor`` 值请返回 ``None`` 。 - :param larger_better: 是否是monitor的结果越大越好。 - :param format_json: 是否format json再打印 - """ super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) self.print_every = print_every self.task2id = {} @@ -201,18 +222,19 @@ class RawTextCallback(ProgressCallback): base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' text = '' if self.monitor is not None: - monitor_value = self.get_monitor_value(results) - if self.is_better_monitor_value(monitor_value, keep_if_better=True): + if self.is_better_results(results, keep_if_better=True): + self.record_better_monitor(trainer, results) if abs(self.monitor_value) != float('inf'): text = '+'*self.num_signs + base_text + '+'*self.num_signs if len(text) == 0: text = '-'*self.num_signs + base_text + '-'*self.num_signs logger.info(text) + results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if + not key.startswith('_')} if self.format_json: - logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) - else: - logger.info(results) + results = json.dumps(results) + logger.info(results) @property def name(self): # progress bar的名称 @@ -221,19 +243,20 @@ class RawTextCallback(ProgressCallback): class TqdmCallback(ProgressCallback): """ - 在训练过程中打印 tqdm progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 - 参数,请通过实例化本 Callback 并传入到 Trainer 中实现。 + 在训练过程中打印 *tqdm* progress bar 的 callback 。在 Trainer 中,如果设置了 ``progress_bar='tqdm'`` 就会使用 + 这个 callback 来显示进度。如果需要定制这个 Callback 的参数,请通过实例化本 Callback 并传入到 Trainer 中实现。在 + 打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 :param print_every: 多少个 batch 更新一次显示。 :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 :param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 * 为 ``None`` - 将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 + 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 * 为 ``str`` 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 - * 为 ``Callable`` + * 为 :class:`Callable` 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 的 ``monitor`` 值请返回 ``None`` 。 :param larger_better: 是否是 monitor 的结果越大越好。 @@ -266,6 +289,7 @@ class TqdmCallback(ProgressCallback): self.progress_bar.set_description_str(self.task2id['epoch'], f'Epoch:{trainer.cur_epoch_idx}', refresh=True) def on_train_end(self, trainer): + super(TqdmCallback, self).on_train_end(trainer) self.clear_tasks() def on_before_backward(self, trainer, outputs): @@ -287,18 +311,19 @@ class TqdmCallback(ProgressCallback): base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' text = '' if self.monitor is not None: - monitor_value = self.get_monitor_value(results) - if self.is_better_monitor_value(monitor_value, keep_if_better=True): + if self.is_better_results(results, keep_if_better=True): + self.record_better_monitor(trainer, results) if abs(self.monitor_value) != float('inf'): text = '+'*self.num_signs + base_text + '+'*self.num_signs if len(text) == 0: text = '-'*self.num_signs + base_text + '-'*self.num_signs logger.info(text) + results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if + not key.startswith('_')} if self.format_json: - logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) - else: - logger.info(results) + results = json.dumps(results) + logger.info(results) def clear_tasks(self): for key, taskid in self.task2id.items(): diff --git a/fastNLP/core/callbacks/timer_callback.py b/fastNLP/core/callbacks/timer_callback.py new file mode 100644 index 00000000..27dbe538 --- /dev/null +++ b/fastNLP/core/callbacks/timer_callback.py @@ -0,0 +1,152 @@ +import time +from .callback import Callback +from ..log import logger +__all__ = ['TimerCallback'] + + +class _Timer: + """Timer.""" + + def __init__(self, name): + self.name_ = name + self.elapsed_ = 0.0 + self.started_ = False + self.start_time = time.time() + + def start(self): + """Start the timer.""" + assert not self.started_, f'{self.name_} timer has already been started' + self.start_time = time.time() + self.started_ = True + + def stop(self): + """Stop the timer.""" + assert self.started_, f'{self.name_} timer is not started' + self.elapsed_ += (time.time() - self.start_time) + self.started_ = False + + def reset(self): + """Reset timer.""" + self.elapsed_ = 0.0 + self.started_ = False + + def elapsed(self, reset=True): + """Calculate the elapsed time.""" + started_ = self.started_ + # If the timing in progress, end it first. + if self.started_: + self.stop() + # Get the elapsed time. + elapsed_ = self.elapsed_ + # Reset the elapsed time + if reset: + self.reset() + # If timing was in progress, set it back. + if started_: + self.start() + return elapsed_ + + +class Timers: + """Group of timers.""" + + def __init__(self): + self.timers = {} + + def __call__(self, name): + if name not in self.timers: + self.timers[name] = _Timer(name) + return self.timers[name] + + def __contains__(self, item): + return item in self.timers + + def reset(self): + for timer in self.timers.values(): + timer.reset() + + +class TimerCallback(Callback): + """ + 这个 callback 的作用是打印训练过程中的相关时间信息,例如训练时长、评测时长、总时长等 + + """ + def __init__(self, print_every=-1, time_ndigit=3): + """ + + :param print_every: 在哪个时候打印时间信息。 + + * *负数*: 表示每隔多少 epoch 结束打印一次; + * *0*: 表示整个训练结束才打印; + * *正数*: 每隔多少个 step 打印一次; + + :param time_ndigit: 保留多少位的小数 + """ + assert isinstance(print_every, int), "print_every must be an int number." + self.timers = Timers() + self.print_every = print_every + self.time_ndigit = time_ndigit + + def on_train_begin(self, trainer): + self.timers('total').start() + self.timers('train').start() + + def on_fetch_data_begin(self, trainer): + self.timers('fetch-data').start() + + def on_fetch_data_end(self, trainer): + self.timers('fetch-data').stop() + + def on_train_batch_begin(self, trainer, batch, indices): + self.timers('forward').start() + + def on_before_backward(self, trainer, outputs): + self.timers('forward').stop() + self.timers('backward').start() + + def on_after_backward(self, trainer): + self.timers('backward').stop() + + def on_before_optimizers_step(self, trainer, optimizers): + self.timers('optimize').start() + + def on_after_optimizers_step(self, trainer, optimizers): + self.timers('optimize').stop() + + def on_evaluate_begin(self, trainer): + self.timers('train').stop() + self.timers('evaluate').start() + + def on_evaluate_end(self, trainer, results): + self.timers('evaluate').stop() + self.timers('train').start() + + def format_timer(self, reset=True): + line = '' + timers = ['fetch-data', 'forward', 'backward', 'optimize', 'evaluate', 'train', 'total'] + for timer_name in timers: + if not timer_name in self.timers: + continue + timer = self.timers(timer_name) + elapsed = round(timer.elapsed(reset=reset), self.time_ndigit) + if elapsed != 0: + line = line + f', {timer_name}: {elapsed}s' + return line + + def on_train_batch_end(self, trainer): + if self.print_every>0 and trainer.global_forward_batches % self.print_every == 0: + line = self.format_timer() + logger.info(f"Running {self.print_every} batches{line}") + + def on_train_epoch_end(self, trainer): + if self.print_every < 0 and trainer.cur_epoch_idx % abs(self.print_every) == 0: + line = self.format_timer() + logger.info(f"Running {abs(self.print_every)} epochs{line}") + + def on_train_end(self, trainer): + if self.print_every == 0: + line = self.format_timer() + logger.info(f"Training finished{line}") + + + diff --git a/fastNLP/core/callbacks/topk_saver.py b/fastNLP/core/callbacks/topk_saver.py index 98e5c269..1ac23b77 100644 --- a/fastNLP/core/callbacks/topk_saver.py +++ b/fastNLP/core/callbacks/topk_saver.py @@ -24,8 +24,8 @@ class Saver: - folder_name # 由 save() 调用时传入。 :param folder: 保存在哪个文件夹下,默认为当前 folder 下。 - :param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 - 保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 + :param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 + 保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.core.controllers.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 :param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。 :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 @@ -178,28 +178,28 @@ class TopkSaver(ResultsMonitor, Saver): - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名 - :param topk: 保存 topk 多少的模型,-1 为保存所有模型;0 为都不保存;大于 0 的数为保存 topk 个。 - :param monitor: 监控的 metric 值。 + :param topk: 保存表现最好的 ``topk`` 个模型,-1 为保存所有模型;0 为都不保存;大于 0 的数为保存 ``topk`` 个; + :param monitor: 监控的 metric 值: * 为 ``None`` - 将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 + 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 * 为 ``str`` 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 - * 为 ``Callable`` + * 为 :class:`Callable` 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 的 ``monitor`` 值请返回 ``None`` 。 :param larger_better: 该 monitor 是否越大越好。 :param folder: 保存在哪个文件夹下,默认为当前 folder 下。 - :param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 - 保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 + :param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 + 保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.core.controllers.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 - :param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。 + :param only_state_dict: 保存时是否仅保存权重,在 ``model_save_fn`` 不为 None 时无意义。 :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 - 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 + 如果传入了 ``model_save_fn`` 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 :param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 - ``fastnlp_evaluate_results.json`` 文件,记录当前的 metric results 。仅在设置了 topk 的场景下有用,默认为 True 。 - :param kwargs: 更多需要传递给 Trainer.save_checkpoint() 或者 Trainer.save_model() 接口的参数。 + ``fastnlp_evaluate_results.json`` 文件,记录当前的 metric results 。仅在设置了 ``topk`` 的场景下有用,默认为 True 。 + :param kwargs: 更多需要传递给 :meth:`Trainer.save_checkpoint` 或者 :meth:`Trainer.save_model` 接口的参数。 """ def __init__(self, topk:int=0, monitor:str=None, larger_better:bool=True, folder:str=None, save_object:str='model', only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True, @@ -220,7 +220,7 @@ class TopkSaver(ResultsMonitor, Saver): @rank_zero_call def save_topk(self, trainer, results: Dict) -> Optional[str]: """ - 根据 results 是否满足 topk 的相关设定决定是否保存,如果发生了保存,将返回保存的文件夹。如果返回为 None ,则说明此次没有满足 + 根据 ``results`` 是否满足 topk 的相关设定决定是否保存,如果发生了保存,将返回保存的文件夹。如果返回为 ``None`` ,则说明此次没有满足 topk 要求,没有发生保存。 :param trainer: diff --git a/fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py b/fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py index c986e4e4..10ef7894 100644 --- a/fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py +++ b/fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py @@ -11,17 +11,17 @@ if _NEED_IMPORT_FAIRSCALE: class TorchGradClipCallback(Callback): r""" - 在每次 optimizer update 之前将 parameter 进行 clip 。 + 在每次 :func:`optimizer.step` 之前对参数的梯度进行截断。 - :param clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 - :param clip_type: 支持'norm', 'value'两种: + :param clip_value: 将梯度限制到 [-clip_value, clip_value] 之间。``clip_value`` 应该为正数; + :param clip_type: 应为 ``'norm'``, ``'value'`` 中的一个: - 1. 'norm', 将gradient的norm rescale到[-clip_value, clip_value] - 2. 'value', 将gradient限制在[-clip_value, clip_value], - 小于-clip_value的gradient被赋值为-clip_value;大于clip_value的gradient被赋值为clip_value. + 1. 为 ``'norm'`` 时, 将梯度的范数限制在 [-clip_value, clip_value] 之间; + 2. 为 ``'value'`` 时,, 将梯度限制在 [-clip_value, clip_value] 之间,小于 ``-clip_value`` + 的梯度被赋值为 ``-clip_value``,大于 ``clip_value`` 的梯度被赋值为 ``clip_value``; - :param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。 - 如果为None则默认对 Trainer 的 optimizers 中所有参数进行梯度裁剪。 + :param parameters: 参数,一般通过 :func:`model.parameters` 获得。 + 如果为 ``None`` 则默认对 Trainer 的 optimizers 中所有参数进行梯度裁剪。 """ def __init__(self, clip_value:int=1, clip_type:str='norm', parameters:Union["torch.Tensor", List["torch.Tensor"]]=None): diff --git a/fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py b/fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py index 29c1aa2b..24474b64 100644 --- a/fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py +++ b/fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py @@ -9,14 +9,14 @@ from ..callback import Callback class TorchWarmupCallback(Callback): r""" - 调整 learning rate 的 callback 。 + 调整学习率的 **callback** 。 - :param warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float, - 如0.1, 则前10%的step是按照schedule策略调整learning rate。 - :param schedule: 以哪种方式调整。 + :param warmup: 如果 ``warmup`` 为整数,则在该 step 之前,学习率根据 ``schedule`` 的策略变化; 如果 ``warmup`` 为 ``float``, + 如 0.1, 则前 10% 的 step 是按照 ``schedule`` 策略调整。 + :param schedule: 对学习率进行调整的策略: - 1. linear: 前warmup的step上升到指定的learning rate(从Trainer中的optimizer处获取的), 后warmup的step下降到0; - 2. constant前warmup的step上升到指定learning rate,后面的step保持learning rate. + 1. *linear* -- 前 ``warmup`` 的 step 上升到指定的学习率(从 Trainer 中 optimizer 处获取), 在剩下的 step 中下降到 0; + 2. *constant* -- 前 ``warmup`` 的 step 上升到指定的学习率,余下的 step 保持不变。 """ def __init__(self, warmup:Union[int, float]=0.1, schedule:str='constant'): super().__init__() diff --git a/fastNLP/core/collators/collator.py b/fastNLP/core/collators/collator.py index dab5028c..dde3e2af 100644 --- a/fastNLP/core/collators/collator.py +++ b/fastNLP/core/collators/collator.py @@ -18,7 +18,7 @@ from .packer_unpacker import SequencePackerUnpacker, SinglePackerUnpacker, Mappi NestedMappingPackerUnpacker sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 -SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] +SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'oneflow', 'numpy', 'raw', 'auto', None] # 由于 jittor DataLoader 存在自动的 to_jittor 的转换,所以只需要 collate 成为 numpy 就行 AUTO_BACKEND_MAPPING = {'jittor': 'numpy'} @@ -85,27 +85,33 @@ def _get_backend() -> str: class Collator: """ 用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 - 哦安定一个 field 是否可以 pad 的方式为:(1)当前这个 field 是否所有对象都是一样的数据类型;(因此,如果某 field 的数据有些是float - 有些是 int 将知道该 field 被判定为不可 pad 类型。)(2)当前这个 field 是否每个 sample 都具有一样的深度;(因此,例如有个 field 的 - 数据转为 batch 类型后为 [1, [1,2]], 会被判定为不可 pad ,因为第一个 sample 与 第二个 sample 深度不同)(3)当前这个 field 的类 - 型是否是可以 pad (例如 str 类型的数据)。可以通过设置 logger.setLevel('debug') 来打印是判定不可 pad 的原因。 + 判定一个 field 是否可以 pad 的方式为: + + 1. 当前这个 field 是否所有对象都是一样的数据类型;比如,如果某 field 的数据有些是 float ,有些是 int ,则该 field 将被 + 判定为不可 pad 类型; + 2. 当前这个 field 是否每个 sample 都具有一样的深度;比如,如果某 field 的数据转为 batch 类型后为 ``[1, [1,2]]``, 则会 + 被判定为不可 pad ,因为第一个 sample 与 第二个 sample 深度不同; + 3. 当前这个 field 的类型是否是可以 pad (例如 str 类型的数据)。可以通过设置 ``logger.setLevel('debug')`` 来打印是判定不可 + pad 的原因。 .. note:: - ``Collator`` 的原理是使用第一个 ``batch`` 的数据尝试推断每个``field``应该使用哪种类型的 ``Padder``,如果第一个 ``batch`` - 的数据刚好比较特殊,可能导致在之后的 pad 中遭遇失败,这种情况请通过 ``set_pad()`` 函数手动设置一下。 + ``Collator`` 的原理是使用第一个 ``batch`` 的数据尝试推断每个 ``field`` 应该使用哪种类型的 ``Padder``,如果第一个 ``batch`` + 的数据刚好比较特殊,可能导致在之后的 pad 中遭遇失败,这种情况请通过 :meth:`set_pad` 函数手动设置一下。 - todo 补充 code example 。 + .. todo:: + + 补充 code example 。 - 如果需要将某个本可以 pad 的 field 设置为不可 pad ,则可以通过 :meth:`~fastNLP.Collator.set_pad` 的 pad_val 设置为 None 实现。 + 如果需要将某个本可以 pad 的 field 设置为不可 pad ,则可以通过 :meth:`~fastNLP.Collator.set_pad` 的 ``pad_val`` 设置为 ``None`` 实现。 如果需要某些 field 不要包含在 pad 之后的结果中,可以使用 :meth:`~fastNLP.Collator.set_ignore` 进行设置。 Collator 在第一次进行 pad 的时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应 的 Padder 给对应的 field 。 - :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 - 若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对不能进行 pad 的数据没用影响,不能 pad - 的数据返回一定是 list 。 + :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ``['torch','jittor','paddle','oneflow','numpy','raw', 'auto', None]``。 + 若为 ``'auto'`` ,则在进行 pad 的时候会根据调用的环境决定其 ``backend`` 。该参数对不能进行 pad 的数据没有影响,无法 pad 的数据返回一定 + 是 :class:`list` 。 """ def __init__(self, backend='auto'): self.unpack_batch_func = None @@ -192,20 +198,20 @@ class Collator: """ 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 - :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); - 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 - 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 - :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 - field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 - 无意义。 - :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 - :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, - torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 - :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 - batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch - 形式,输出将被直接作为结果输出。 - :return: 返回 Collator 自身 + :param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 + field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; + 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 + 如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, + 该值无意义。 + :param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 + :param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, + :class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 + 若 ``pad_val`` 为 ``None`` ,该值无意义 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 + batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 + :return: 返回 Collator 自身; """ self._renew() @@ -275,8 +281,8 @@ class Collator: """ 设置可以 pad 的 field 默认 pad 为什么类型的 tensor - :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', 'auto', None], - 若为 auto ,则在进行 pad 的时候会自动根据调用的环境决定其 backend 。 + :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ``['torch','jittor','paddle','oneflow','numpy','raw', 'auto', None]``, + 若为 ``'auto'`` ,则在进行 pad 的时候会自动根据调用的环境决定其 ``backend`` ; :return: """ assert backend in SUPPORTED_BACKENDS @@ -285,14 +291,14 @@ class Collator: def set_ignore(self, *field_names) -> "Collator": """ - 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 + 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略:: >>> collator = Collator().set_ignore('field1', 'field2') - :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 - __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 - :return: 返回 Collator 自身 + :param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 + field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; + 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 + :return: 返回 Collator 自身; """ self._renew() input_field_names = [(field, field) if isinstance(field, tuple) else ((field,), field) diff --git a/fastNLP/core/collators/packer_unpacker.py b/fastNLP/core/collators/packer_unpacker.py index 2b78ea0a..7d9c23cd 100644 --- a/fastNLP/core/collators/packer_unpacker.py +++ b/fastNLP/core/collators/packer_unpacker.py @@ -2,6 +2,7 @@ from collections import defaultdict from functools import reduce from typing import Sequence, Mapping, Dict +__all__ = [] class MappingPackerUnpacker: @staticmethod @@ -70,7 +71,7 @@ class SequencePackerUnpacker: @staticmethod def unpack_batch(batch:Sequence[Sequence], ignore_fields, input_fields)->Dict: """ - 将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]} + 将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [2, 2]} :param batch: 需要 unpack 的 batch 数据。 :param ignore_fields: 需要忽略的 field 。 diff --git a/fastNLP/core/collators/padders/get_padder.py b/fastNLP/core/collators/padders/get_padder.py index b0a82849..41bcd8c0 100644 --- a/fastNLP/core/collators/padders/get_padder.py +++ b/fastNLP/core/collators/padders/get_padder.py @@ -10,18 +10,19 @@ from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPad from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder from .jittor_padder import JittorTensorPadder, JittorSequencePadder, JittorNumberPadder +from .oneflow_padder import OneflowTensorPadder, OneflowSequencePadder, OneflowNumberPadder from .exceptions import * def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->Padder: """ - 根据 参数 与 batch_field ,返回适合于当前 batch_field 的 padder 。 + 根据 参数 与 ``batch_field`` ,返回适合于当前 ``batch_field`` 的 *padder* 。 - :param batch_field: 将某 field 的内容组合成一个 batch 传入。 - :param pad_val: + :param batch_field: 将某 field 的内容组合成一个 batch 传入; + :param pad_val: :param backend: :param dtype: - :param field_name: 方便报错的。 + :param field_name: field 名称,方便在报错时显示; :return: """ try: @@ -91,6 +92,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) elif backend == 'jittor': return JittorNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) + elif backend == 'oneflow': + return OneflowNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) else: raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).") @@ -105,6 +108,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) elif backend == 'jittor': return JittorSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) + elif backend == 'oneflow': + return OneflowSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) else: raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).") @@ -121,6 +126,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) elif backend == 'jittor': return JittorTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) + elif backend == 'oneflow': + return OneflowTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) else: raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).") diff --git a/fastNLP/core/collators/padders/jittor_padder.py b/fastNLP/core/collators/padders/jittor_padder.py index c9b36b89..6b37d61c 100644 --- a/fastNLP/core/collators/padders/jittor_padder.py +++ b/fastNLP/core/collators/padders/jittor_padder.py @@ -84,14 +84,14 @@ def _get_dtype(ele_dtype, dtype, class_name): class JittorNumberPadder(Padder): - def __init__(self, pad_val=0, ele_dtype=None, dtype=None): - """ - 可以将形如 [1, 2, 3] 这类的数据转为 jittor.Var([1, 2, 3]) + """ + 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``jittor.Var([1, 2, 3])`` - :param pad_val: 该值无意义 - :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型。 - :param dtype: 输出的数据的 dtype 是什么。如 jittor.long, jittor.float32, int, float 等 - """ + :param pad_val: 该值无意义 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`jittor.Var` 类型; + :param dtype: 输出的数据的 dtype 是什么。如 :class:`jittor.long`, :class:`jittor.float32`, :class:`int`, :class:`float` 等; + """ + def __init__(self, pad_val=0, ele_dtype=None, dtype=None): dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) super().__init__(pad_val=pad_val, dtype=dtype) @@ -106,23 +106,23 @@ class JittorNumberPadder(Padder): class JittorSequencePadder(Padder): - def __init__(self, pad_val=0, ele_dtype=None, dtype=None): - """ - 将类似于 [[1], [1, 2]] 的内容 pad 为 jittor.Var([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 + """ + 可以将形如 ``[[1], [1, 2]]`` 这类的数据转为 ``jittor.Var([[1], [1, 2]])`` - :param pad_val: 需要 pad 的值。 - :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型。 - :param dtype: 输出的数据的 dtype 是什么。如 jittor.long, jittor.float32, int, float 等 - """ + :param pad_val: 该值无意义 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`jittor.Var` 类型; + :param dtype: 输出的数据的 dtype 是什么。如 :class:`jittor.long`, :class:`jittor.float32`, :class:`int`, :class:`float` 等; + """ + def __init__(self, pad_val=0, ele_dtype=None, dtype=None): dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) super().__init__(pad_val=pad_val, dtype=dtype) @staticmethod def pad(batch_field, pad_val=0, dtype=None): """ - :param batch_field 输入的某个 field 的 batch 数据。 - :param pad_val 需要填充的值 - :dtype 数据的类型 + :param batch_field: 输入的某个 field 的 batch 数据。 + :param pad_val: 需要填充的值 + :param dtype: 数据的类型 """ tensor = get_padded_jittor_tensor(batch_field, dtype=dtype, pad_val=pad_val) return tensor @@ -131,11 +131,11 @@ class JittorSequencePadder(Padder): class JittorTensorPadder(Padder): def __init__(self, pad_val=0, ele_dtype=None, dtype=None): """ - 目前支持 [jittor.Var([3, 2], jittor.Var([1])] 类似的。若内部元素不为 jittor.Var ,则必须含有 tolist() 方法。 + 目前支持 ``[jittor.Var([3, 2], jittor.Var([1])]`` 类似的输入。若内部元素不为 :class:`jittor.Var` ,则必须含有 :meth:`tolist` 方法。 - :param pad_val: 需要 pad 的值。 - :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型。 - :param dtype: 输出的数据的 dtype 是什么。如 jittor.long, jittor.float32, int, float 等 + :param pad_val: 需要 pad 的值; + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`jittor.Var` 类型; + :param dtype: 输出的数据的 dtype 是什么。如 :class:`jittor.long`, :class:`jittor.float32`, :class:`int`, :class:`float` 等 """ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) super().__init__(pad_val=pad_val, dtype=dtype) @@ -143,11 +143,11 @@ class JittorTensorPadder(Padder): @staticmethod def pad(batch_field, pad_val=0, dtype=None): """ - 将 batch_field 数据 转为 jittor.Var 并 pad 到相同长度。 + 将 ``batch_field`` 数据 转为 :class:`jittor.Var` 并 pad 到相同长度。 - :param batch_field 输入的某个 field 的 batch 数据。 - :param pad_val 需要填充的值 - :dtype 数据的类型 + :param batch_field: 输入的某个 field 的 batch 数据。 + :param pad_val: 需要填充的值 + :param dtype: 数据的类型 """ try: if not isinstance(batch_field[0], jittor.Var): diff --git a/fastNLP/core/collators/padders/numpy_padder.py b/fastNLP/core/collators/padders/numpy_padder.py index b6edba04..2f386978 100644 --- a/fastNLP/core/collators/padders/numpy_padder.py +++ b/fastNLP/core/collators/padders/numpy_padder.py @@ -18,9 +18,9 @@ def _get_dtype(ele_dtype, dtype, class_name): """ 用于检测数据的 dtype 类型, 根据内部和外部数据判断。 - :param ele_dtype 内部数据的类型 - :param dtype 数据外部类型 - :param class_name 类的名称 + :param ele_dtype: 内部数据的类型 + :param dtype: 数据外部类型 + :param class_name: 类的名称 """ if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " @@ -38,15 +38,15 @@ def _get_dtype(ele_dtype, dtype, class_name): class NumpyNumberPadder(Padder): """ - 可以将形如 [1, 2, 3] 这类的数据转为 np.array([1, 2, 3]) 。可以通过: + 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``np.array([1, 2, 3])`` 。可以通过:: >>> NumpyNumberPadder.pad([1, 2, 3]) 使用。 - :param pad_val: 该值无意义 - :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 - :param dtype: 输出的数据的 dtype 是什么 + :param pad_val: 该值无意义; + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型; + :param dtype: 输出的数据的 dtype ; """ def __init__(self, pad_val=0, ele_dtype=None, dtype=None): dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) @@ -54,21 +54,28 @@ class NumpyNumberPadder(Padder): @staticmethod def pad(batch_field, pad_val=0, dtype=None): + """ + 将 ``batch_field`` 数据 转为 :class:`numpy.array` 并 pad 到相同长度。 + + :param batch_field: 输入的某个 field 的 batch 数据。 + :param pad_val: 需要填充的值 + :param dtype: 数据的类型 + """ return np.array(batch_field, dtype=dtype) class NumpySequencePadder(Padder): """ - 将类似于 [[1], [1, 2]] 的内容 pad 为 np.array([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 + 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``np.array([[1, 0], [1, 2]])``, 可以 pad 多重嵌套的数据。 可以通过以下的方式直接使用: >>> NumpySequencePadder.pad([[1], [1, 2]], pad_val=-100, dtype=float) [[ 1. -100.] [ 1. 2.]] - :param pad_val: pad 的值是多少。 - :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 - :param dtype: 输出的数据的 dtype 是什么 + :param pad_val: pad 的值是多少; + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型; + :param dtype: 输出的数据的 dtype ; """ def __init__(self, pad_val=0, ele_dtype=None, dtype=None): dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) @@ -76,18 +83,25 @@ class NumpySequencePadder(Padder): @staticmethod def pad(batch_field, pad_val=0, dtype=None): + """ + 将 ``batch_field`` 数据 转为 :class:`numpy.array` 并 pad 到相同长度。 + + :param batch_field: 输入的某个 field 的 batch 数据。 + :param pad_val: 需要填充的值 + :param dtype: 数据的类型 + """ return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val) class NumpyTensorPadder(Padder): """ - pad 类似于 [np.array([3, 4]), np.array([1])] 的 field 。若内部元素不为 np.ndarray ,则必须含有 tolist() 方法。 + pad 类似于 ``[np.array([3, 4]), np.array([1])]`` 的 field 。若内部元素不为 :class:`np.ndarray` ,则必须含有 :meth:`tolist` 方法。 >>> NumpyTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100) [[ 3. 4.] [ 1. -100.]] :param pad_val: pad 的值是多少。 - :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型。 :param dtype: 输出的数据的 dtype 是什么 """ def __init__(self, pad_val=0, ele_dtype=None, dtype=None): @@ -96,6 +110,13 @@ class NumpyTensorPadder(Padder): @staticmethod def pad(batch_field, pad_val=0, dtype=None): + """ + 将 ``batch_field`` 数据 转为 :class:`numpy.array` 并 pad 到相同长度。 + + :param batch_field: 输入的某个 field 的 batch 数据。 + :param pad_val: 需要填充的值 + :param dtype: 数据的类型 + """ try: if not isinstance(batch_field[0], np.ndarray): batch_field = [np.array(field.tolist(), dtype=dtype) for field in batch_field] diff --git a/fastNLP/core/collators/padders/oneflow_padder.py b/fastNLP/core/collators/padders/oneflow_padder.py new file mode 100644 index 00000000..5e235a0f --- /dev/null +++ b/fastNLP/core/collators/padders/oneflow_padder.py @@ -0,0 +1,225 @@ +__all__ = [ + 'OneflowNumberPadder', + 'OneflowSequencePadder', + 'OneflowTensorPadder' +] +from inspect import isclass +import numpy as np + +from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW + +if _NEED_IMPORT_ONEFLOW: + import oneflow + numpy_to_oneflow_dtype_dict = { + np.bool_: oneflow.bool, + np.uint8: oneflow.uint8, + np.int8: oneflow.int8, + np.int32: oneflow.int32, + np.int64: oneflow.int64, + np.float16: oneflow.float16, + np.float32: oneflow.float32, + np.float64: oneflow.float32, # 这里都统一为到 float32 吧,这是由于 numpy 大部分时候都默认 float64 了 + } + number_to_oneflow_dtype_dict = { + float: oneflow.float32, # 因为 oneflow.tensor([1], dtype=float)是oneflow.float64 + int: oneflow.int64, + bool: oneflow.bool + } + +from .padder import Padder +from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class +from .exceptions import * + + +def is_oneflow_tensor(dtype): + """ + 判断是否为 oneflow 的 tensor + + :param dtype 数据的 dtype 类型 + """ + if not isclass(dtype) and isinstance(dtype, oneflow.dtype): + return True + return False + + +def _get_dtype(ele_dtype, dtype, class_name): + """ + 用于检测数据的 dtype 类型, 根据内部和外部数据判断。 + + :param ele_dtype: 内部数据的类型 + :param dtype: 数据外部类型 + :param class_name: 类的名称 + """ + if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_oneflow_tensor(ele_dtype))): + raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " + f"or numpy numbers or oneflow.Tensor but get `{ele_dtype}`.") + + if dtype is not None: + if not (is_oneflow_tensor(dtype) or is_number(dtype)): + raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " + f"or oneflow.dtype but get `{dtype}`.") + dtype = number_to_oneflow_dtype_dict.get(dtype, dtype) + else: + if ele_dtype is not None: + if (is_number(ele_dtype) or is_oneflow_tensor(ele_dtype)): + ele_dtype = number_to_oneflow_dtype_dict.get(ele_dtype, ele_dtype) + dtype = ele_dtype + elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 + dtype = numpy_to_oneflow_dtype_dict.get(ele_dtype.type) + elif is_numpy_generic_class(ele_dtype): + dtype = numpy_to_oneflow_dtype_dict.get(ele_dtype) + + return dtype + + +class OneflowNumberPadder(Padder): + """ + 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``oneflow.Tensor([1, 2, 3])``。 + + :param pad_val: 该值无意义; + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`oneflow.Tensor` 类型; + :param dtype: 输出的数据的 dtype,。如 :class:`oneflow.long`, :class:`oneflow.float32`, :class:`int`, :class:`float` 等; + """ + def __init__(self, pad_val=0, ele_dtype=None, dtype=None): + dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val=0, dtype=None): + """ + 将 ``batch_field`` 数据 转为 :class:`oneflow.Tensor` 并 pad 到相同长度。 + + :param batch_field: 输入的某个 field 的 batch 数据。 + :param pad_val: 需要填充的值 + :param dtype: 数据的类型 + """ + return oneflow.tensor(batch_field, dtype=dtype) + + +class OneflowSequencePadder(Padder): + """ + 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``oneflow.Tensor([[1, 0], [1, 2]])``, 可以 pad 多重嵌套的数据。 + + :param pad_val: 需要 pad 的值; + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`oneflow.Tensor` 类型; + :param type: 输出的数据的 dtype,。如 :class:`oneflow.long`, :class:`oneflow.float32`, :class:`int`, :class:`float` 等; + """ + def __init__(self, pad_val=0, ele_dtype=None, dtype=None): + dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val=0, dtype=None): + """ + 将 ``batch_field`` 数据 转为 :class:`oneflow.Tensor` 并 pad 到相同长度。 + + :param batch_field: 输入的某个 field 的 batch 数据。 + :param pad_val: 需要填充的值 + :param dtype: 数据的类型 + """ + tensor = get_padded_oneflow_tensor(batch_field, dtype=dtype, pad_val=pad_val) + return tensor + + +class OneflowTensorPadder(Padder): + """ + 目前支持 ``[oneflow.tensor([3, 2], oneflow.tensor([1])]`` 类似的输入,若内部元素不为 :class:`oneflow.Tensor` ,则必须含有 :meth:`tolist` 方法。 + + >>> OneflowTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100) + [[ 3. 4.] + [ 1. -100.]] + >>> OneflowTensorPadder.pad([oneflow.LongTensor([3, 4]), oneflow.LongTensor([1])], pad_val=-100) + tensor([[ 3, 4], + [ 1, -100]]) + + :param pad_val: 需要 pad 的值。 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`oneflow.Tensor` 类型。 + :param dtype: 输出的数据的 dtype,。如 :class:`oneflow.long`, :class:`oneflow.float32`, :class:`int`, :class:`float` 等; + """ + def __init__(self, pad_val=0, ele_dtype=None, dtype=None): + dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val=0, dtype=None): + """ + 将 ``batch_field`` 数据 转为 :class:`oneflow.Tensor` 并 pad 到相同长度。 + + :param batch_field: 输入的某个 field 的 batch 数据。 + :param pad_val: 需要填充的值 + :param dtype: 数据的类型 + """ + device = None + try: + if not isinstance(batch_field[0], oneflow.Tensor): + batch_field = [oneflow.tensor(field.tolist(), dtype=dtype) for field in batch_field] + else: + batch_field = [field.to(dtype) for field in batch_field] + device = batch_field[0].device + if dtype is None: + dtype = batch_field[0].dtype + except AttributeError: + raise RuntimeError(f"If the field is not a oneflow.Tensor (it is {type(batch_field[0])}), " + f"it must have tolist() method.") + + shapes = [field.shape for field in batch_field] + if len(batch_field) < 2: + max_shape = [len(batch_field)] + list(shapes[0]) + else: + max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] + + tensor = oneflow.full(max_shape, value=pad_val, dtype=dtype, device=device) + for i, field in enumerate(batch_field): + slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) + tensor[slices] = field + return tensor + + +def fill_tensor(batch_field, padded_batch, dtype): + """ + 将 batch_field 中的值填入到 tensor 中。 + + :param batch_field: 需要填充进入 array 中的内容 + :param padded_batch: 待填充的 tensor + :param dtype: 数据的类别 + + :return: + """ + if padded_batch.ndim == 2: + for i, content_i in enumerate(batch_field): + padded_batch[i, :len(content_i)] = oneflow.tensor(content_i, dtype=dtype) + elif padded_batch.ndim == 3: + for i, content_i in enumerate(batch_field): + for j, content_ii in enumerate(content_i): + padded_batch[i, j, :len(content_ii)] = oneflow.tensor(content_ii, dtype=dtype) + elif padded_batch.ndim == 4: + try: # 应该是图像,所以直接应该就 ok 了。 + padded_batch = oneflow.tensor(batch_field) + except: + for i, content_i in enumerate(batch_field): + for j, content_ii in enumerate(content_i): + for k, content_iii in enumerate(content_ii): + padded_batch[i, j, k, :len(content_iii)] = oneflow.tensor(content_iii, dtype=dtype) + elif padded_batch.ndim == 1: + padded_batch[:] = oneflow.tensor(batch_field, dtype=dtype) + else: + raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please " + "report.") + return padded_batch + + +def get_padded_oneflow_tensor(batch_field, dtype=None, pad_val=0): + """ + 例如: + [[1,2], [3]] -> oneflow.LongTensor([[1, 2], [3, 0]]) + + :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) + /4d(多为图片)。 + :param dtype: 目标类别是什么 + :param pad_val: pad 的 value + :return: + """ + shapes = get_shape(batch_field) + tensor = oneflow.full(shapes, dtype=dtype, value=pad_val) + tensor = fill_tensor(batch_field, tensor, dtype=dtype) + return tensor diff --git a/fastNLP/core/collators/padders/padder.py b/fastNLP/core/collators/padders/padder.py index 6a75b634..783d8fa2 100644 --- a/fastNLP/core/collators/padders/padder.py +++ b/fastNLP/core/collators/padders/padder.py @@ -1,7 +1,7 @@ class Padder: """ - 所有 Padder 对象父类,所有的 Padder 对象都会实现 pad(batch_field, pad_val=0, dtype=None) 的静态函数。 + 所有 **Padder** 对象的父类,所有的 Padder 对象都会实现静态函数 ``pad(batch_field, pad_val=0, dtype=None)`` 。 """ def __init__(self, pad_val, dtype): diff --git a/fastNLP/core/collators/padders/paddle_padder.py b/fastNLP/core/collators/padders/paddle_padder.py index ab287b58..57d31967 100644 --- a/fastNLP/core/collators/padders/paddle_padder.py +++ b/fastNLP/core/collators/padders/paddle_padder.py @@ -99,11 +99,11 @@ def _get_dtype(ele_dtype, dtype, class_name): class PaddleNumberPadder(Padder): """ - 可以将形如 [1, 2, 3] 这类的数据转为 paddle.Tensor([1, 2, 3]) + 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``paddle.Tensor([1, 2, 3])`` - :param pad_val: 该值无意义 - :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 - :param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 + :param pad_val: 该值无意义; + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`paddle.Tensor` 类型; + :param dtype: 输出的数据的 dtype 是什么。如 :class:`int`, :class:`float`, :class:`int32` 等; """ def __init__(self, pad_val=0, ele_dtype=None, dtype=None): # 仅当 ele_dtype 是 python number/ numpy number 或者 tensor @@ -112,16 +112,23 @@ class PaddleNumberPadder(Padder): @staticmethod def pad(batch_field, pad_val=0, dtype=None): + """ + 将 ``batch_field`` 数据 转为 :class:`paddle.Tensor` 并 pad 到相同长度。 + + :param batch_field: 输入的某个 field 的 batch 数据。 + :param pad_val: 需要填充的值 + :param dtype: 数据的类型 + """ return paddle.to_tensor(batch_field, dtype=dtype) class PaddleSequencePadder(Padder): """ - 将类似于 [[1], [1, 2]] 的内容 pad 为 paddle.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 + 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``paddle.Tensor([[1, 0], [1, 2]])`` 可以 pad 多重嵌套的数据。 :param pad_val: pad 的值。 - :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 - :param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`paddle.Tensor` 类型; + :param dtype: 输出的数据的 dtype 是什么。如 :class:`int`, :class:`float`, :class:`int32` 等; """ def __init__(self, ele_dtype=None, pad_val=0, dtype=None): dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) @@ -129,17 +136,30 @@ class PaddleSequencePadder(Padder): @staticmethod def pad(batch_field, pad_val=0, dtype=None): + """ + 将 ``batch_field`` 数据 转为 :class:`paddle.Tensor` 并 pad 到相同长度。 + + :param batch_field: 输入的某个 field 的 batch 数据。 + :param pad_val: 需要填充的值 + :param dtype: 数据的类型 + """ tensor = get_padded_paddle_tensor(batch_field, dtype=dtype, pad_val=pad_val) return tensor class PaddleTensorPadder(Padder): """ - 目前支持 [paddle.tensor([3, 2], paddle.tensor([2, 1])] 类似的,若内部元素不为 paddle.tensor ,则必须含有 tolist() 方法。 - + 目前支持 ``[paddle.tensor([3, 2], paddle.tensor([2, 1])]`` 类似的输入,若内部元素不为 :class:`paddle.Tensor` ,则必须含有 :meth:`tolist` 方法。 + + >>> PaddleTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100) + [[ 3. 4.] + [ 1. -100.]] + >>> PaddleTensorPadder.pad([paddle.to_tensor([3, 4]), paddle.to_tensor([1])], pad_val=-100) + tensor([[ 3, 4], + [ 1, -100]]) :param pad_val: pad 的值。 - :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 - :param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`paddle.Tensor` 类型; + :param dtype: 输出的数据的 dtype 是什么。如 :class:`int`, :class:`float`, :class:`int32` 等; """ def __init__(self, pad_val=0, ele_dtype=None, dtype=None): dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) @@ -147,6 +167,13 @@ class PaddleTensorPadder(Padder): @staticmethod def pad(batch_field, pad_val=0, dtype=None): + """ + 将 ``batch_field`` 数据 转为 :class:`paddle.Tensor` 并 pad 到相同长度。 + + :param batch_field: 输入的某个 field 的 batch 数据。 + :param pad_val: 需要填充的值 + :param dtype: 数据的类型 + """ try: if not isinstance(batch_field[0], paddle.Tensor): batch_field = [np.array(field.tolist()) for field in batch_field] diff --git a/fastNLP/core/collators/padders/raw_padder.py b/fastNLP/core/collators/padders/raw_padder.py index 645c145c..52ba6617 100644 --- a/fastNLP/core/collators/padders/raw_padder.py +++ b/fastNLP/core/collators/padders/raw_padder.py @@ -13,9 +13,9 @@ def _get_dtype(ele_dtype, dtype, class_name): """ 用于检测数据的 dtype 类型, 根据内部和外部数据判断。 - :param ele_dtype 内部数据的类型 - :param dtype 数据外部类型 - :param class_name 类的名称 + :param ele_dtype: 内部数据的类型 + :param dtype: 数据外部类型 + :param class_name: 类的名称 """ if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " @@ -34,11 +34,11 @@ def _get_dtype(ele_dtype, dtype, class_name): class RawNumberPadder(Padder): """ - 可以将形如 [1, 2, 3] 这类的数据转为 [1, 2, 3] 。实际上该 padder 无意义。 + 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``[1, 2, 3]`` 。实际上该 padder 无意义。 - :param pad_val: 该值无意义 - :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 - :param dtype: 输出的数据的 dtype 是什么 + :param pad_val: + :param ele_dtype: + :param dtype: """ def __init__(self, pad_val=0, ele_dtype=None, dtype=None): dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) @@ -54,11 +54,11 @@ class RawNumberPadder(Padder): class RawSequencePadder(Padder): """ - 将类似于 [[1], [1, 2]] 的内容 pad 为 [[1, 0], [1, 2]] 。可以 pad 多重嵌套的数据。 + 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``[[1, 0], [1, 2]]`` 。可以 pad 多重嵌套的数据。 - :param pad_val: pad 的值 - :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 - :param dtype: 输出的数据的 dtype 是什么 + :param pad_val: pad 的值; + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型; + :param dtype: 输出的数据的 dtype ; """ def __init__(self, pad_val=0, ele_dtype=None, dtype=None): dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) @@ -68,8 +68,8 @@ class RawSequencePadder(Padder): def pad(batch_field, pad_val=0, dtype=None): """ - :param batch_field: - :param pad_val: + :param batch_field: 输入的某个 field 的 batch 数据。 + :param pad_val: 需要填充的值 :param dtype: 该参数无意义。 :return: """ @@ -78,11 +78,11 @@ class RawSequencePadder(Padder): class RawTensorPadder(Padder): """ - 将类似于 [[1], [1, 2]] 的内容 pad 为 [[1, 0], [1, 2]] 。可以 pad 多重嵌套的数据。 + 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``[[1, 0], [1, 2]]`` 。可以 pad 多重嵌套的数据。 - :param pad_val: pad 的值 - :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 - :param dtype: 输出的数据的 dtype 是什么 + :param pad_val: pad 的值; + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型; + :param dtype: 输出的数据的 dtype ; """ def __init__(self, pad_val=0, ele_dtype=None, dtype=None): dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) @@ -92,8 +92,8 @@ class RawTensorPadder(Padder): def pad(batch_field, pad_val=0, dtype=None): """ - :param batch_field: - :param pad_val: + :param batch_field: 输入的某个 field 的 batch 数据。 + :param pad_val: 需要填充的值 :param dtype: 该参数无意义。 :return: """ diff --git a/fastNLP/core/collators/padders/torch_padder.py b/fastNLP/core/collators/padders/torch_padder.py index 911c7d8c..a5ab9149 100644 --- a/fastNLP/core/collators/padders/torch_padder.py +++ b/fastNLP/core/collators/padders/torch_padder.py @@ -38,7 +38,7 @@ def is_torch_tensor(dtype): """ 判断是否为 torch 的 tensor - :param dtype 数据的 dtype 类型 + :param dtype: 数据的 dtype 类型 """ if not isclass(dtype) and isinstance(dtype, torch.dtype): return True @@ -49,9 +49,9 @@ def _get_dtype(ele_dtype, dtype, class_name): """ 用于检测数据的 dtype 类型, 根据内部和外部数据判断。 - :param ele_dtype 内部数据的类型 - :param dtype 数据外部类型 - :param class_name 类的名称 + :param ele_dtype: 内部数据的类型 + :param dtype: 数据外部类型 + :param class_name: 类的名称 """ if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))): raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " @@ -77,11 +77,11 @@ def _get_dtype(ele_dtype, dtype, class_name): class TorchNumberPadder(Padder): """ - 可以将形如 [1, 2, 3] 这类的数据转为 torch.Tensor([1, 2, 3]) + 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``torch.Tensor([1, 2, 3])`` - :param pad_val: 该值无意义 - :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 - :param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 + :param pad_val: 该值无意义; + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`torch.Tensor` 类型; + :param dtype: 输出的数据的 dtype 是什么。如 :class:`torch.long`, :class:`torch.float32`, :class:`int`, :class:`float` 等; """ def __init__(self, pad_val=0, ele_dtype=None, dtype=None): dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) @@ -94,11 +94,11 @@ class TorchNumberPadder(Padder): class TorchSequencePadder(Padder): """ - 将类似于 [[1], [1, 2]] 的内容 pad 为 torch.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 + 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``torch.Tensor([[1, 0], [1, 2]])`` 可以 pad 多重嵌套的数据。 - :param pad_val: 需要 pad 的值。 - :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 - :param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 + :param pad_val: 需要 pad 的值; + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`torch.Tensor` 类型; + :param dtype: 输出的数据的 dtype 是什么。如 :class:`torch.long`, :class:`torch.float32`, :class:`int`, :class:`float` 等; """ def __init__(self, pad_val=0, ele_dtype=None, dtype=None): dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) @@ -112,7 +112,7 @@ class TorchSequencePadder(Padder): class TorchTensorPadder(Padder): """ - 目前支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的。若内部元素不为 torch.tensor ,则必须含有 tolist() 方法。 + 目前支持 ``[torch.tensor([3, 2], torch.tensor([1])]`` 类似的输入。若内部元素不为 :class:`torch.Tensor` ,则必须含有 :meth:`tolist` 方法。 >>> TorchTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100) [[ 3. 4.] @@ -121,9 +121,9 @@ class TorchTensorPadder(Padder): tensor([[ 3, 4], [ 1, -100]]) - :param pad_val: 需要 pad 的值。 - :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 - :param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 + :param pad_val: 需要 pad 的值; + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`torch.Tensor` 类型; + :param dtype: 输出的数据的 dtype 是什么。如 :class:`torch.long`, :class:`torch.float32`, :class:`int`, :class:`float` 等; """ def __init__(self, pad_val=0, ele_dtype=None, dtype=None): dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) diff --git a/fastNLP/core/collators/padders/torch_utils.py b/fastNLP/core/collators/padders/torch_utils.py index 3f21333b..d1887b36 100644 --- a/fastNLP/core/collators/padders/torch_utils.py +++ b/fastNLP/core/collators/padders/torch_utils.py @@ -5,6 +5,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch +__all__ = [] def is_torch_tensor_dtype(dtype) -> bool: """ diff --git a/fastNLP/core/collators/padders/utils.py b/fastNLP/core/collators/padders/utils.py index e4a258a8..6a50b33d 100644 --- a/fastNLP/core/collators/padders/utils.py +++ b/fastNLP/core/collators/padders/utils.py @@ -78,13 +78,12 @@ def fill_array(batch_field:List, padded_batch:np.ndarray): def get_padded_numpy_array(batch_field: List, dtype=None, pad_val=0) -> np.ndarray: """ - 例如: - [[1,2], [3]] -> np.array([[1, 2], [3, 0]]) + 将输入 pad 为 :class:`numpy.arraay` 类型,如:``[[1,2], [3]] -> np.array([[1, 2], [3, 0]])`` - :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) - /4d(多为图片)。 - :param dtype: 目标类别是什么 - :param pad_val: pad 的 value + :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 **1d** (多为句子长度)/ **2d** (多为文本序列)/ **3d** (多为字符序列) + /4d(多为图片); + :param dtype: 输出数据的 dtype 类型; + :param pad_val: 填充值; :return: """ shapes = get_shape(batch_field) diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index a1d4adf8..2b749b29 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -1,5 +1,5 @@ r""" -``Evaluator`` 是新版 fastNLP 中用来进行评测模型的评测器,其与 ``Trainer`` 相对应,二者共同构建起了 fastNLP 中**训练**和**评测**的框架。 +``Evaluator`` 是新版 **fastNLP** 中用来进行评测模型的评测器,其与 ``Trainer`` 相对应,二者共同构建起了 **fastNLP** 中 **训练** 和 **评测** 的框架。 ``Evaluator`` 的整体架构与 ``Trainer`` 类似,也是利用 ``Driver`` 来负责底层的评测逻辑。通过使用 ``Evaluator``,您可以快速、方便、准确地 对您的模型进行全方位地评测。 @@ -75,11 +75,11 @@ class Evaluator: :param device: 等价于 ``Trainer`` 中的 ``device`` 参数; :param evaluate_batch_step_fn: 您可以传入该参数来定制每次评测一个 batch 的数据时所执行的函数。该函数应接受的两个参数为 ``evaluator`` 和 ``batch``, 不需要有返回值;可以参考 :meth:`~fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop.batch_step_fn`; - :param evaluate_fn: 用来控制 ``Evaluator`` 在评测的前向传播过程中调用的是哪一个函数,例如对于 pytorch 而言,通过该参数确定使用的是 ``model.evaluate_step`` 还是 - ``model.forward``(不同训练框架所使用的的前向传播函数的方法名称不同); + :param evaluate_fn: 用来控制 ``Evaluator`` 在评测的前向传播过程中调用的是哪一个函数,例如对于 pytorch 而言,通过该参数确定使用的是 :meth:`model.evaluate_step` 还是 + :meth:`model.forward` (不同训练框架所使用的的前向传播函数的方法名称不同); 1. 如果该值是 ``None``,那么我们会默认使用 ``evaluate_step`` 当做前向传播的函数,如果在模型中没有找到该方法,则使用训练框架默认的前向传播函数; - 2. 如果为 ``str`` 类型,例如为 ``my_evaluate_step_fn``,则尝试寻找 ``model.my_evaluate_step_fn``,如果找不到则直接报错; + 2. 如果为 ``str`` 类型,例如为 ``'my_evaluate_step_fn'``,则尝试寻找 :meth:`model.my_evaluate_step_fn`,如果找不到则直接报错; :param input_mapping: 等价于 ``Trainer`` 中的 ``input_mapping`` 参数;对具体的用于评测一个 batch 的数据使用 ``input_mapping`` 处理之后再输入到 ``model`` 以及 ``metric`` 中。如果针对 ``model`` 和 ``metric`` 需要不同的 ``mapping``,请考虑使用 ``evaluate_batch_step_fn`` 参数定制; @@ -97,18 +97,28 @@ class Evaluator: ``metric`` 的计算都是自动化的,因此其一定需要参数匹配:根据 ``metric.update`` 的函数签名直接从字典数据中抽取其需要的参数传入进去; - :param fp16: 是否在评测时使用 fp16; + :param fp16: 是否在评测时使用 fp16 混合精度; :param verbose: 是否打印 evaluate 的结果; :kwargs: * *torch_kwargs* -- 等价于 ``Trainer`` 中的 ``torch_kwargs`` 参数; + * *paddle_kwargs* -- 等价于 ``Trainer`` 中的 ``paddle_kwargs`` 参数; + * *fairscale_kwargs* -- 等价于 ``Trainer`` 中的 ``fairscale_kwargs`` 参数; + * *deepspeed_kwargs* -- 等价于 ``Trainer`` 中的 ``deepspeed_kwargs`` 参数; + * *oneflow_kwargs* -- 等价于 ``Trainer`` 中的 ``oneflow_kwargs`` 参数; * *data_device* -- 等价于 ``Trainer`` 中的 ``data_device`` 参数; * *model_use_eval_mode* (``bool``) -- - 是否在评测的时候将 ``model`` 的状态设置成 ``eval`` 状态。在 ``eval`` 状态下,``model`` 的 - ``dropout`` 与 ``batch normalization`` 将会关闭。默认为 ``True``。如果为 ``False``,``fastNLP`` 不会对 ``model`` 的 ``evaluate`` 状态做任何设置。无论 - 该值是什么,``fastNLP`` 都会在评测后将 ``model`` 的状态设置为 ``train``; + 是否在评测的时候将 ``model`` 的状态设置成 ``eval`` 状态。在 ``eval`` 状态下,``model`` 的 + ``dropout`` 与 ``batch normalization`` 将会关闭。默认为 ``True``。如果为 ``False``,``fastNLP`` 不会对 ``model`` 的 ``evaluate`` 状态做任何设置。无论 + 该值是什么,``fastNLP`` 都会在评测后将 ``model`` 的状态设置为 ``train``; * *use_dist_sampler* -- - 是否使用分布式评测的方式。仅当 ``driver`` 为分布式类型时,该参数才有效。默认为根据 ``driver`` 是否支持 - 分布式进行设置。如果为 ``True``,将使得每个进程上的 ``dataloader`` 自动使用不同数据,所有进程的数据并集是整个数据集; + 表示在 ``Evaluator`` 中在使用分布式的时候是否将保证 dataloader 的 ``sampler`` 替换为 + 分布式的 ``sampler``,其特点是每个卡上的数据之间不重叠,所有卡上数据的加起来是整个数据集。若传入的 dataloader + 的 sampler 为: + + - 深度学习框架自带的默认 sampler ; + - fastNLP 的 Sampler ; + 则将替换为 :class:`~fastNLP.UnrepeatedSequentialSampler`,如果这个行为不是期待的,请本参数设置为 ``False``,并针对每个卡控制其可以 + 用到的数据。如果不是以上两类 sampler ,fastNLP 将报错。 * *output_from_new_proc* -- 等价于 ``Trainer`` 中的 ``output_from_new_proc`` 参数; * *progress_bar* -- 等价于 ``Trainer`` 中的 ``progress_bar`` 参数; * *check_dataloader_legality* -- 是否检查 ``DataLoader`` 是否合法,默认为 ``True`` 。 @@ -119,8 +129,8 @@ class Evaluator: _evaluate_batch_loop: Loop def __init__(self, model, dataloaders, metrics: Optional[Dict] = None, - driver: Union[str, Driver] = 'torch', device: Optional[Union[int, List[int], str]] = None, - evaluate_batch_step_fn: Optional[callable] = None, evaluate_fn: Optional[str] = None, + driver: Union[str, Driver] = 'auto', device: Optional[Union[int, List[int], str]] = None, + evaluate_batch_step_fn: Optional[Callable] = None, evaluate_fn: Optional[str] = None, input_mapping: Optional[Union[Callable, Dict]] = None, output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False, fp16: bool = False, verbose: int = 1, **kwargs): @@ -200,16 +210,16 @@ class Evaluator: """ 用于帮助您加载模型的辅助函数; - :param folder: 存放着您需要加载的 model 的文件夹,默认会尝试读取该文件夹下的 fastnlp_model.pkl.tar 文件。在 model_load_fn 不为空时, - 直接将该 folder 传递到 model_load_fn 中; - :param only_state_dict: 要读取的文件中是否仅包含模型权重。在 ``model_load_fn 不为 None`` 时,该参数无意义; - :param model_load_fn: ``callable`` 的函数,接受一个 folder 作为参数,需要注意该函数不需要返回任何内容; + :param folder: 存放着您需要加载的 model 的文件夹,默认会尝试读取该文件夹下的 ``fastnlp_model.pkl.tar`` 文件。在 ``model_load_fn`` 不为空时, + 直接将该 folder 传递到 ``model_load_fn`` 中; + :param only_state_dict: 要读取的文件中是否仅包含模型权重。在 ``model_load_fn`` 不为 ``None`` 时,该参数无意义; + :param model_load_fn: :class:`Callable` 的函数,接受一个 folder 作为参数,需要注意该函数不需要返回任何内容; :param kwargs: 理论上您不需要使用到该参数; .. note:: 注意您需要在初始化 ``Evaluator`` 后再通过 ``evaluator`` 实例来调用该函数;这意味着您需要保证在保存和加载时使用的 ``driver`` 是属于同一个 - 训练框架的,例如都是 ``pytorch`` 或者 ``paddle``; + 训练框架的,例如都是 **pytorch** 或者 **PaddlePaddle** ; """ self.driver.barrier() if not isinstance(folder, (io.BytesIO, BinaryIO)): @@ -237,15 +247,14 @@ class Evaluator: """ 该函数是在 ``Evaluator`` 初始化后用于真正开始评测的函数; - 返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 + 返回一个字典类型的数据,其中 key 为 metric 的名字,value 为对应 metric 的结果。 - 1. 如果存在多个metric,一个dataloader的情况,key的命名规则是 - ``metric_indicator_name#metric_name`` + 1. 如果存在多个 metric ,一个 dataloader 的情况,key 的命名规则是 + ``metric_indicator_name#metric_name``; 2. 如果存在多个数据集,一个metric的情况,key的命名规则是 - ``metric_indicator_name#metric_name#dataloader_name`` (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 - 如果存在多个metric,多个dataloader的情况,key的命名规则是 - ``metric_indicator_name#metric_name#dataloader_name`` - 其中 metric_indicator_name 可能不存在; + ``metric_indicator_name#metric_name#dataloader_name`` (其中 **#** 是默认的 separator ,可以通过 Evaluator 初始化参数修改); + 3. 如果存在多个metric,多个dataloader的情况,key的命名规则是 + ``metric_indicator_name#metric_name#dataloader_name``,其中 metric_indicator_name 可能不存在; :param num_eval_batch_per_dl: 每个 dataloader 测试前多少个 batch 的数据,-1 为测试所有数据。 :return: 返回评测得到的结果,是一个没有嵌套的字典; @@ -276,8 +285,9 @@ class Evaluator: raise e finally: self.finally_progress_bar() + metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False) if len(metric_results) > 0: # 如果 metric 不为 None 需要 print 。 - metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False) + # metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False) if self.verbose: if self.progress_bar == 'rich': f_rich_progress.print(metric_results) @@ -356,7 +366,7 @@ class Evaluator: def reset(self): """ - 调用所有 metric 的 reset() 方法,清除累积的状态。 + 调用所有 metric 的 :meth:`reset` 方法,清除累积的状态。 :return: """ @@ -364,7 +374,7 @@ class Evaluator: def update(self, batch, outputs): """ - 自动调用所有 metric 的 update 方法,会根据不同 metric 的参数列表进行匹配传参。 + 自动调用所有 metric 的 :meth:`update` 方法,会根据不同 metric 的参数列表进行匹配传参。 :param batch: 一般是来自于 DataLoader 的输出,如果不为 dict 类型的话,该值将被忽略。 :param outputs: 一般是来自于模型的输出。类别应为 dict 或者 dataclass 类型。 @@ -374,7 +384,7 @@ class Evaluator: def get_metric(self) -> Dict: """ - 调用所有 metric 的 get_metric 方法,并返回结果。其中 key 为 metric 的名称,value 是各个 metric 的结果。 + 调用所有 metric 的 :meth:`get_metric` 方法,并返回结果。其中 key 为 metric 的名称,value 是各个 metric 的结果。 :return: """ @@ -383,11 +393,9 @@ class Evaluator: @property def metrics_wrapper(self): """ - 由于需要保持 Evaluator 中 metrics 对象与用户传入的 metrics 保持完全一致(方便他在 evaluate_batch_step_fn )中使用,同时也为了支持 + 由于需要保持 Evaluator 中 ``metrics`` 对象与用户传入的 ``metrics`` 保持完全一致(方便在 ``evaluate_batch_step_fn`` )中使用,同时也为了支持 不同形式的 metric( fastNLP 的 metric/torchmetrics 等),所以 Evaluator 在进行 metric 操作的时候都调用 metrics_wrapper 进行操作。 - - Returns: """ if self._metric_wrapper is None: self._metric_wrapper = _MetricsWrapper(self.metrics, evaluator=self) @@ -395,11 +403,12 @@ class Evaluator: def evaluate_step(self, batch): """ - 将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 。会将返回结果经过 output_mapping 处理后再 - 返回。 + 将 ``batch`` 传递到 model 中进行处理,根据当前 ``evaluate_fn`` 选择进行 evaluate 。会将返回结果经过 ``output_mapping`` + 处理后再 +返回。 - :param batch: {evaluate_fn} 函数支持的输入类型 - :return: {evaluate_fn} 函数的输出结果,如果有设置 output_mapping ,将是 output_mapping 之后的结果。 + :param batch: ``evaluate_fn`` 函数支持的输入类型 + :return: ``evaluate_fn`` 函数的输出结果,如果有设置 ``output_mapping`` ,将是 ``output_mapping`` 之后的结果。 """ outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn) outputs = match_and_substitute_params(self.output_mapping, outputs) @@ -408,7 +417,7 @@ class Evaluator: @property def metrics(self): """ - 返回用户传入的 metrics 对象。 + 返回用户传入的 ``metrics`` 对象。 :return: """ diff --git a/fastNLP/core/controllers/loops/evaluate_batch_loop.py b/fastNLP/core/controllers/loops/evaluate_batch_loop.py index fb936236..c31cfa0e 100644 --- a/fastNLP/core/controllers/loops/evaluate_batch_loop.py +++ b/fastNLP/core/controllers/loops/evaluate_batch_loop.py @@ -13,7 +13,7 @@ class EvaluateBatchLoop(Loop): r""" ``EvaluateBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的评测迭代过程; - :param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn; + :param batch_step_fn: 您可以传入该参数来替换默认的 ``bath_step_fn``; """ def __init__(self, batch_step_fn:Optional[Callable]=None): if batch_step_fn is not None: @@ -21,10 +21,10 @@ class EvaluateBatchLoop(Loop): def run(self, evaluator, dataloader) -> Dict: r""" - 需要返回在传入的 dataloader 中的 evaluation 结果 + 需要返回在传入的 ``dataloader`` 中的 evaluation 结果 - :param evaluator: Evaluator 对象 - :param dataloader: 当前需要进行评测的dataloader + :param evaluator: :class:`~fastNLP.core.controllers.Evaluator` 对象 + :param dataloader: 当前需要进行评测的 ``dataloader`` :return: """ iterator = iter(dataloader) @@ -58,10 +58,10 @@ class EvaluateBatchLoop(Loop): @staticmethod def batch_step_fn(evaluator, batch): r""" - 针对一个 batch 的数据的评测过程; + 针对一个 ``batch`` 的数据的评测过程; - :param evaluator: Evaluator 对象 - :param batch: 当前需要评测的一个 batch 的数据; + :param evaluator: :class:`~fastNLP.core.controllers.Evaluator` 对象 + :param batch: 当前需要评测的一个 ``batch`` 的数据; """ outputs = evaluator.evaluate_step(batch) # 将batch输入到model中得到结果 evaluator.update(batch, outputs) # evaluator将根据metric的形参名字从batch/outputs中取出对应的值进行赋值 diff --git a/fastNLP/core/controllers/loops/loop.py b/fastNLP/core/controllers/loops/loop.py index b1952236..dc149587 100644 --- a/fastNLP/core/controllers/loops/loop.py +++ b/fastNLP/core/controllers/loops/loop.py @@ -1,5 +1,5 @@ r""" -``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,为了在实现 fastNLP 主要功能的同时保证 fastNLP 的易用性和代码的易读性,我们只对 +``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,为了在实现 **fastNLP** 主要功能的同时保证 **fastNLP** 的易用性和代码的易读性,我们只对 训练中的循环做了非常简单的抽象,``Loop`` 表示的是在训练或者评测的过程中针对单独一个 ``dataloader`` 的一个 ``epoch`` 的运算过程; 更为具体的使用详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` 和 @@ -24,7 +24,7 @@ class Loop: .. note:: - ``Trainer`` 和 ``Evaluator`` 中都提供了方便您进行定制 ``Loop`` 的接口函数,例如 ``Trainer.train_step``,``Trainer.backward``等; + ``Trainer`` 和 ``Evaluator`` 中都提供了方便您进行定制 ``Loop`` 的接口函数,例如 ``Trainer.train_step``, ``Trainer.backward`` 等; 在定制您自己的 ``TrainBatchLoop`` 时,请务必记得在正确的时机调用对应的 callback 函数,详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` 中对于 callback 函数的调用; @@ -34,5 +34,5 @@ class Loop: @staticmethod def batch_step_fn(controller: Union["Trainer", "Evaluator"], batch): r""" - 对于具体的一个 batch 的数据,实现训练或者评测过程中的一步; + 对于具体的一个 ``batch`` 的数据,实现训练或者评测过程中的一步; """ \ No newline at end of file diff --git a/fastNLP/core/controllers/loops/train_batch_loop.py b/fastNLP/core/controllers/loops/train_batch_loop.py index 645f4224..ca97fe9e 100644 --- a/fastNLP/core/controllers/loops/train_batch_loop.py +++ b/fastNLP/core/controllers/loops/train_batch_loop.py @@ -14,7 +14,7 @@ class TrainBatchLoop(Loop): r""" ``TrainBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的训练迭代过程; - :param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn; + :param batch_step_fn: 您可以传入该参数来替换默认的 ``bath_step_fn``; """ def __init__(self, batch_step_fn: Optional[Callable] = None): @@ -23,14 +23,14 @@ class TrainBatchLoop(Loop): def run(self, trainer, dataloader): r""" - 对传入的 dataloader 进行一个 epoch 的主要的训练的循环过程; + 对传入的 ``dataloader`` 进行一个 epoch 的主要的训练的循环过程; .. note:: 您不需要自己主动地调用该方法,``Trainer`` 会负责调用该方法来完成训练过程; - :param trainer: ``Trainer`` 实例; - :param dataloader: 当前训练所使用的 dataloader; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; + :param dataloader: 当前训练所使用的 ``dataloader``; """ get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\ else lambda *args, **kwargs: None @@ -41,10 +41,12 @@ class TrainBatchLoop(Loop): batch = next(dataloader) indices = get_batch_indices() except StopIteration: + trainer.on_fetch_data_end() break + trainer.on_fetch_data_end() + try: - trainer.on_fetch_data_end() batch = match_and_substitute_params(trainer.input_mapping, batch) batch = trainer.move_data_to_device(batch) @@ -66,10 +68,10 @@ class TrainBatchLoop(Loop): @staticmethod def batch_step_fn(trainer, batch): r""" - 针对一个 batch 的数据的训练过程; + 针对一个 ``batch`` 的数据的训练过程; - :param trainer: ``Trainer`` 实例; - :param batch: 一个 batch 的数据; + :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; + :param batch: 一个 ``batch`` 的数据; """ outputs = trainer.train_step(batch) trainer.backward(outputs) diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 41fca6ba..af227380 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -1,7 +1,7 @@ """ -``Trainer`` 是 fastNLP 用于训练模型的专门的训练器,其支持多种不同的驱动模式 ``Driver``,不仅包括最为经常使用的 DDP,而且还支持 jittor 等国产 -的训练框架;新版的 fastNLP 新加入了方便的 callback 函数修饰器,并且支持定制用户自己特定的训练循环过程;通过使用该训练器,用户只需要自己实现 -模型部分,而将训练层面的逻辑完全地交给 fastNLP; +``Trainer`` 是 **fastNLP** 用于训练模型的专门的训练器,其支持多种不同的驱动模式 ``Driver``,不仅包括最为经常使用的 DDP,而且还支持 jittor 等国产 +的训练框架;新版的 **fastNLP** 新加入了方便的 callback 函数修饰器,并且支持定制用户自己特定的训练循环过程;通过使用该训练器,用户只需要自己实现 +模型部分,而将训练层面的逻辑完全地交给 **fastNLP**; """ from typing import Union, Optional, List, Callable, Dict, BinaryIO @@ -35,13 +35,14 @@ from fastNLP.envs import rank_zero_call from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.utils.exceptions import EarlyStopException +from fastNLP.core.dataloaders import OverfitDataLoader class Trainer(TrainerEventTrigger): r""" 用于支持快速训练的训练器。 - :param model: 训练所需要的模型,例如 ``torch.nn.Module``; + :param model: 训练所需要的模型,例如 :class:`torch.nn.Module`; .. note:: @@ -54,9 +55,17 @@ class Trainer(TrainerEventTrigger): 您应当使用 ``TorchDDPDriver``,意味着您需要通过 ``python -m torch.distributed.launch`` 的方式来启动训练,此时参数 ``device`` 应当设置为 None(此时我们会忽略该参数),具体见下面对于参数 ``device`` 的更详细的解释。 - :param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch"],之后我们会加入 jittor、paddle 等 - 国产框架的训练模式;其中 "torch" 表示使用 ``TorchSingleDriver`` 或者 ``TorchDDPDriver``,具体使用哪一种取决于参数 ``device`` - 的设置; + :param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:``["auto", "torch", "paddle", "jittor", "fairscale", "deepspeed", "oneflow"]``: + + 1. 值为 ``"auto"`` 时,**FastNLP** 会根据传入模型的类型自行判断使用哪一种模式; + 2. 其值为 ``"torch"`` 时,表示使用 :class:`~fastNLP.core.drivers.TorchSingleDriver` 或者 :class:`~fastNLP.core.drivers.TorchDDPDriver`; + 3. 其值为 ``"paddle"`` 时,表示使用 :class:`~fastNLP.core.drivers.PaddleSingleDriver` 或者 :class:`~fastNLP.core.drivers.PaddleFleetDriver`; + 4. 其值为 ``"jittor"`` 时,表示使用 :class:`~fastNLP.core.drivers.JittorSingleDriver` 或者 :class:`~fastNLP.core.drivers.JittorMPIDriver`; + 5. 其值为 ``"fairscale"`` 时,表示使用 :class:`~fastNLP.core.drivers.FairScaleDriver`; + 6. 其值为 ``"deepspeed"`` 时,表示使用 :class:`~fastNLP.core.drivers.DeepSpeedDriver`; + 7. 其值为 ``"oneflow"`` 时,表示使用 :class:`~fastNLP.core.drivers.OneflowSingleDriver` 或者 :class:`~fastNLP.core.drivers.OneflowDDPDriver`; + + 在指定了框架的情况下,具体使用哪一种取决于参数 ``device`` 的设置; .. warning:: @@ -64,19 +73,28 @@ class Trainer(TrainerEventTrigger): 这意味着当您传入一个 ``Driver`` 实例时,您传入给 ``Trainer`` 的 ``model`` 参数将会被忽略;也就是说模型在训练时使用的真正的模型是 您传入的 ``Driver`` 实例中的模型; - :param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 List 或者 Dict; + :param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 :class:`List` 或者 :class:`Dict`; + + .. warning:: + + 当使用分布式训练时, **fastNLP** 会默认将 ``dataloader`` 中的 ``Sampler`` 进行处理,以使得在一个 epoch 中,不同卡 + 用以训练的数据是不重叠的。如果你对 sampler 有特殊处理,那么请将 ``use_dist_sampler`` 参数设置为 ``False`` ,此刻需要由 + 你自身保证每张卡上所使用的数据是不同的。 + :param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List; - :param device: 该参数用来指定具体训练时使用的机器;注意当该参数仅当您通过 `torch.distributed.launch/run` 启动时可以为 None, - 此时 fastNLP 不会对模型和数据进行设备之间的移动处理,但是你可以通过参数 `input_mapping` 和 `output_mapping` 来实现设备之间 - 数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据 + :param device: 该参数用来指定具体训练时使用的机器;注意当该参数仅当您通过 ``torch.distributed.launch/run`` 启动时可以为 ``None``, + 此时 fastNLP 不会对模型和数据进行设备之间的移动处理,但是你可以通过参数 ``input_mapping`` 和 ``output_mapping`` 来实现设备之间 + 数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也可以通过在 kwargs 添加参数 ``data_device`` 来让我们帮助您将数据 迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前自己构造 DDP 的场景); device 的可选输入如下所示: - * *str*: 例如 'cpu', 'cuda', 'cuda:0', 'cuda:1' 等; - * *torch.device*: 例如 'torch.device("cuda:0")'; - * *int*: 将使用 ``device_id`` 为该值的 ``gpu`` 进行训练;如果值为 -1,那么默认使用全部的显卡,此时使用的 driver 实例是 `TorchDDPDriver`; - * *list(int)*: 如果多于 1 个device,应当通过该种方式进行设定;注意此时我们一定会使用 ``TorchDDPDriver``,不管您传入的列表的长度是 1 还是其它值; + * *str*: 例如 ``'cpu'``, ``'cuda'``, ``'cuda:0'``, ``'cuda:1'``, ``'gpu:0'`` 等; + * *torch.device*: 例如 ``torch.device("cuda:0")``; + * *oneflow.device*:例如 ``oneflow.device("cuda", 0)``; + * *int*: 将使用 ``device_id`` 为该值的 ``gpu`` 进行训练;如果值为 -1,那么默认使用全部的显卡,此时使用的 driver 实例是 `TorchDDPDriver` 这类 + 执行分布式训练的 Driver + * *list(int)*: 如果多于 1 个device,应当通过该种方式进行设定;注意此时我们一定会使用分布式训练的 Driver ,不管您传入的列表的长度是 1 还是其它值; * *None*: 仅当用户自己通过训练框架提供的并行训练启动脚本开启 ddp 进程时为 None; .. note:: @@ -93,9 +111,9 @@ class Trainer(TrainerEventTrigger): .. warning:: - 注意参数 ``device`` 仅当您通过 pytorch 或者其它训练框架自身的并行训练启动脚本启动 ddp 训练时才允许为 ``None``! + 注意参数 ``device`` 仅当您通过训练框架自身的并行训练启动脚本启动 ddp 训练时才允许为 ``None``! - 例如,当您使用:: + 例如,在 pytorch 中,当您使用:: python -m torch.distributed.launch --nproc_per_node 2 train.py @@ -112,7 +130,7 @@ class Trainer(TrainerEventTrigger): :param n_epochs: 训练总共的 epoch 的数量,默认为 20;也可以通过 ``n_batches`` 参数设置总共迭代多少个 ``batch`` 。 :param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 - 为 None; + 为 ``None``; :param batch_step_fn: 定制每次训练时前向运行一个 batch 的数据所执行的函数。该函数应接受两个参数为 ``trainer`` 和 ``batch``, 不需要要返回值;更详细的使用位置和说明请见 :meth:`~fastNLP.core.controllers.TrainBatchLoop.batch_step_fn`; :param evaluate_batch_step_fn: 定制每次验证时前向运行一个 batch 的数据所执行的函数。该函数应接受的两个参数为 ``evaluator`` 和 ``batch``, @@ -124,8 +142,8 @@ class Trainer(TrainerEventTrigger): .. note:: 在 fastNLP 中,对于训练时使用的前向传播函数的查找逻辑如下所示: - 1. 如果 ``train_fn`` 为 None,那么在 model 的类 Model 中寻找方法 ``Model.train_step``;如果没有找到,那么默认使用 ``Model.forward``; - 2. 如果 ``train_fn`` 为一个字符串,例如 'my_step_fn',那么我们首先会在 model 的类 Model 中寻找方法 ``Model.my_step_fn``, + 1. 如果 ``train_fn`` 为 None,那么在 model 的类 Model 中寻找方法 :meth:`Model.train_step` ;如果没有找到,那么默认使用 :meth:`Model.forward`; + 2. 如果 ``train_fn`` 为一个字符串,例如 ``'my_step_fn'``,那么我们首先会在 model 的类 Model 中寻找方法 :meth:`Model.my_step_fn`, 如果没有找到,那么会直接报错; :param evaluate_fn: 用来控制 ``Trainer`` 中内置的 ``Evaluator`` 在验证的前向传播过程中是调用模型的哪一个函数,应当为 ``None`` @@ -133,7 +151,7 @@ class Trainer(TrainerEventTrigger): :param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 ``Callback`` 类;具体可见 :class:`~fastNLP.core.callbacks.Callback`; :param metrics: 用于传给 ``Trainer`` 内部的 ``Evaluator`` 实例来进行训练过程中的验证。其应当为一个字典,其中 key 表示 monitor, - 例如 {"acc1": AccMetric(), "acc2": AccMetric()}; + 例如 ``{"acc1": AccMetric(), "acc2": AccMetric()}``; 目前我们支持的 ``metric`` 的种类有以下几种: @@ -147,7 +165,7 @@ class Trainer(TrainerEventTrigger): 1. 为负数时表示每隔几个 ``epoch`` evaluate 一次; 2. 为正数则表示每隔几个 ``batch`` evaluate 一次; 3. 为函数时表示用户自己传入的用于控制 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 - 返回一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 ``batch`` 结束后调用该函数判断是否需要 evaluate; + 返回一个 bool 值,返回为 ``True`` 说明需要进行 evaluate ;将在每个 ``batch`` 结束后调用该函数判断是否需要 evaluate; .. note:: @@ -199,7 +217,7 @@ class Trainer(TrainerEventTrigger): :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为; - 1. 如果该值为 ``False``,并且当 batch 为字典时,我们会根据**前向函数**所需要的参数从 batch 中提取对应的对象,然后传入到**前向函数**中; + 1. 如果该值为 ``False``,并且当 batch 为字典时,我们会根据 **前向函数** 所需要的参数从 batch 中提取对应的对象,然后传入到 **前向函数** 中; 2. 如果该值为 ``True``,那么我们会将 batch 直接透传给模型; .. todo:: @@ -210,8 +228,8 @@ class Trainer(TrainerEventTrigger): :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 才让优化器迭代一次,默认为 1; :param fp16: 是否开启混合精度训练,默认为 False; :param monitor: 对于一些特殊的 ``Callback``,例如 :class:`~fastNLP.core.callbacks.CheckpointCallback`,它们需要参数 ``monitor`` - 来从 ``Evaluator`` 的验证结果中获取当前评测的值,从而来判断是否执行一些特殊的操作。例如,对于 ``CheckpointCallback`` 而言,如果我们 - 想要每隔一个 epoch 让 ``Evaluator`` 进行一次验证,然后保存训练以来的最好的结果;那么我们需要这样设置: + 来从 ``Evaluator`` 的验证结果中获取当前评测的值,从而来判断是否执行一些特殊的操作。例如,对于 :class:`~fastNLP.core.callbacks.CheckpointCallback` + 而言,如果我们想要每隔一个 epoch 让 ``Evaluator`` 进行一次验证,然后保存训练以来的最好的结果;那么我们需要这样设置: .. code-block:: @@ -225,7 +243,7 @@ class Trainer(TrainerEventTrigger): )] ) - 这意味着对于 ``CheckpointCallback`` 来说,*'acc'* 就是一个监测的指标,用于在 ``Evaluator`` 验证后取出其需要监测的那个指标的值。 + 这意味着对于 :class:`~fastNLP.core.callbacks.CheckpointCallback` 来说,*'acc'* 就是一个监测的指标,用于在 ``Evaluator`` 验证后取出其需要监测的那个指标的值。 ``Trainer`` 中的参数 ``monitor`` 的作用在于为没有设置 ``monitor`` 参数但是需要该参数的 *callback* 实例设置该值。关于 ``monitor`` 参数更详细的说明,请见 :class:`~fastNLP.core.callbacks.CheckpointCallback`; @@ -237,9 +255,22 @@ class Trainer(TrainerEventTrigger): 注意该参数仅当 ``Trainer`` 内置的 ``Evaluator`` 不为 None 时且有需要该参数但是没有设置该参数的 *callback* 实例才有效; - :param n_batches: 迭代多少个 ``batch`` 的训练结束。当该值不为 -1 时,将直接忽略 ``n_epochs`` 的值。 + :param n_batches: 总共迭代多少个 ``batch`` 的训练结束。当该值不为 -1 时,将直接忽略 ``n_epochs`` 的值。 + :param overfit_batches: 使用该参数来支持 **'过拟合'** 的功能;支持的值为 ``-1``、``0`` 或者 大于 0 的整数,表示使用多少个 batch 的数据 + 来进行过拟合训练;其中 0 为表示不进行任何操作;-1 表示使用所有的数据进行训练; - :param marker: 用于标记一个 ``Trainer`` 实例,从而在用户调用 ``Trainer.on`` 函数时,标记该函数属于哪一个具体的 ``Trainer`` 实例;默认为 None; + .. note:: + + 您可以使用该参数来简单地查看您的模型是否是 '正确的',即您的模型是否能够在少量的数据上快速进行收敛,从而说明损失函数以及优化器等 + 没有问题。当使用该参数时,我们会直接从 ``train_dataloader`` 中提取固定数量的 batch,然后在所有 epoch 中都是用这些数据 + 来进行训练; + + .. warning:: + + 在使用该参数时,您同样可以指定 ``metrics`` 参数来进行简单的验证,当该参数和 ``metrics`` 同时出现时,我们会将 ``evaluate_dataloaders`` + 直接替换为在过拟合中所使用的训练数据;因此您需要保证您的 ``metrics`` 是能够在 ``train_dataloader`` 上使用的; + + :param marker: 用于标记一个 ``Trainer`` 实例,从而在用户调用 ``Trainer.on`` 函数时,标记该函数属于哪一个具体的 ``Trainer`` 实例;默认为 ``None``; .. note:: @@ -261,47 +292,52 @@ class Trainer(TrainerEventTrigger): ) 另一点需要说明的是,如果一个被 ``Trainer.on`` 修饰的函数,其修饰时没有指明 ``marker``,那么会将该函数传给代码位于其之后的 - 第一个 ``Trainer`` 实例,即使该 ``Trainer`` 实例的 marker 不为 None;这一点详见 :meth:`~fastNLP.core.controllers.Trainer.on` + 第一个 ``Trainer`` 实例,即使该 ``Trainer`` 实例的 marker 不为 ``None``;这一点详见 :meth:`~fastNLP.core.controllers.Trainer.on` :kwargs: - * *torch_kwargs* -- 用于在指定 ``driver`` 为 'torch' 时设定具体 driver 实例的一些参数: - - * ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入 - {'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等; - * set_grad_to_none -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; - * non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; - * gradscaler_kwargs -- 用于 fp16=True 时,提供给 ``torch.amp.cuda.GradScaler`` 的参数。 - * *paddle_kwargs* -- 用于在指定 ``driver`` 为 'paddle' 时设定具体 driver 实例的一些参数: - - * fleet_kwargs -- 用于在使用 ``PaddleFleetDriver`` 时指定 ``DataParallel`` 和 ``fleet`` 初始化时的参数,包括: - - * is_collective -- 是否使用 paddle 集群式的分布式训练方法,目前仅支持为 ``True`` 的情况; - * role_maker -- 初始化 ``fleet`` 分布式训练 API 时使用的 ``RoleMaker`` - * 其它用于初始化 ``DataParallel`` 的参数; + * *torch_kwargs* -- ``TorchDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.TorchSingleDriver` 和 + :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`; + * *paddle_kwargs* -- ``PaddleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver` 和 + :class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver`; + * *fairscale_kwargs* -- ``FairScaleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.FairScaleDriver`; + * *deepspeed_kwargs* -- ``DeepSpeedDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.DeepSpeedDriver`; + * *torch_kwargs* -- ``OneflowDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.oneflow_driver.OneflowSingleDriver` 和 + :class:`~fastNLP.core.drivers.oneflow_driver.OneflowDDPDriver`; * *data_device* -- 一个具体的 driver 实例中,有 ``model_device`` 和 ``data_device``,前者表示模型所在的设备,后者表示 当 ``model_device`` 为 None 时应当将数据迁移到哪个设备; .. note:: - 注意您在绝大部分情况下不会用到该参数! + **注意您在绝大部分情况下不会用到该参数!** 1. 当 driver 实例的 ``model_device`` 不为 None 时,该参数无效; - 2. 对于 pytorch,仅当用户自己通过 ``python -m torch.distributed.launch`` 并且自己初始化 ``init_process_group`` 时, - driver 实例的 ``model_device`` 才会为 None; - 3. 对于 paddle,该参数无效; + 2. 对于 **pytorch**,仅当用户自己通过 ``python -m torch.distributed.launch`` 并且自己初始化 ``init_process_group`` 时, + driver 实例的 ``model_device`` 才会为 None; + 2. 对于 **deepspeed**,仅当用户自己通过 ``deepspeed xxx.py`` 并且自己初始化 ``model.initialize`` 时, + driver 实例的 ``model_device`` 才会为 None; + 3. 对于 **paddle** 和 **oneflow**,该参数无效; * *use_dist_sampler* -- 表示是否使用分布式的 ``sampler``。在多卡时,分布式 ``sampler`` 将自动决定每张卡上读取的 sample ,使得一个 epoch - 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 - * *evaluate_use_dist_sampler* -- 表示在 ``Evaluator`` 中在使用分布式的时候是否将 dataloader 的 ``sampler`` 替换为分布式的 ``sampler``; - 不传入该值时,该值与 ``use_dist_sampler`` 参数保持一致; + 内所有卡的 sample 加起来为一整个数据集的 sample,同时为了保证所有卡上拥有相同数量的 sample ,有的卡上可能会有重复的 sample ,例如 + 8卡训练,只有9个 sample ,如果 batch_size 为 1,那么第二个 batch 时,有7张卡将没有 sample 可用,因此只有 **重复** 使用 sample 来 pad 到第二个 + batch 中。如果不希望 fastNLP 对 dataloader 的 sampler 做特殊设置,请将该值设置为 False ,若确实需要分布式的训练,请在 Trainer 外 + 对 ``train_dataloader`` 做的数据做特殊处理使得其在不同的卡之间 sample 是不同的。 + * *evaluate_use_dist_sampler* -- 表示在 ``Evaluator`` 中在使用分布式的时候是否将保证 dataloader 的 ``sampler`` 替换为 + evaluate 时使用的分布式的 ``sampler``,其特点是每个卡上的数据之间不重叠,所有卡上数据的加起来是整个数据集。若传入的 dataloader + 的 sampler 为: + + - 深度学习框架自带的默认 sampler ; + - fastNLP 的 Sampler ; + 则将替换为 :class:`~fastNLP.UnrepeatedSequentialSampler`,如果这个行为不是期待的,请本参数设置为 ``False``,并针对每个卡控制其可以 + 用到的数据。 * *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: - ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 - log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; + ``["all", "ignore", "only_error"]``;当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 + log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 ``"only_error"``; 注意该参数仅当使用分布式的 ``driver`` 时才有效,例如 ``TorchDDPDriver``; - * *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto', 'tqdm'] 或者 :class:`~.fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback`等对象, - 默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 :class:`~fastNLP.RichCallback`,否则使用 :class:`~fastNLP.RawTextCallback` 对象。如果 - 需要定制 progress bar 的参数,例如打印频率等,可以传入 :class:`~fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback` 等对象。 + * *progress_bar* -- 显示进度条的方式,目前支持 ``[None, 'raw', 'rich', 'auto', 'tqdm']`` 或者 :class:`~fastNLP.RichCallback` 、 :class:`~fastNLP.RawTextCallback` 等对象, + 默认为 ``'auto'`` , ``'auto'`` 表示如果检测到当前 terminal 为交互型则使用 :class:`~fastNLP.RichCallback`,否则使用 :class:`~fastNLP.RawTextCallback` 对象。如果 + 需要定制 progress bar 的参数,例如打印频率等,可以传入 :class:`~fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback` 等对象。 * *train_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Trainer`` 中。与 input_mapping 互斥。 * *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。 * *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Evaluator`` 中。与 input_mapping 互斥。 @@ -312,19 +348,19 @@ class Trainer(TrainerEventTrigger): ``Trainer`` 是通过在内部直接初始化一个 ``Evaluator`` 来进行验证; ``Trainer`` 内部的 ``Evaluator`` 默认是 None,如果您需要在训练过程中进行验证,你需要保证这几个参数得到正确的传入: - 必须的参数:1. ``metrics``;2. ``evaluate_dataloaders``; + 必须的参数:``metrics`` 与 ``evaluate_dataloaders``; - 可选的其它参数:1. ``evaluate_batch_step_fn;2. ``evaluate_fn``;3. ``evaluate_every``;4. ``input_mapping``; - 5. ``output_mapping``; 6. ``model_wo_auto_param_call``;7. ``fp16``;8. ``monitor``;9. ``larger_better``; + 可选的其它参数:``evaluate_batch_step_fn``、 ``evaluate_fn``、``evaluate_every``、``input_mapping``、 + ``output_mapping``、``model_wo_auto_param_call``、``fp16``、``monitor``、``larger_better``; .. warning:: 如果 ``Trainer`` 中内置的 ``Evaluator`` 实例不为 ``None``,那么需要注意 ``Trainer`` 中的一些参数是与 ``Evaluator`` 一致的,它们分别为: 1. ``Evaluator`` 在初始化时的 ``driver`` 参数是 ``Trainer`` 中已经实例化过的 driver;这一点使得一些参数对于 ``Trainer`` 内部的 - ``Evaluator`` 没有用处,例如 ``device``,``torch_kwargs``,``data_device`` 和 ``output_from_new_proc`` 等; + ``Evaluator`` 没有用处,例如 ``device``,``torch_kwargs``,``data_device`` 和 ``output_from_new_proc`` 等; 2. ``input_mapping``,``output_mapping``,``model_wo_auto_param_call`` 和 ``fp16`` 是 ``Trainer`` 和其内部默认的 - ``Evaluator`` 是一致的; + ``Evaluator`` 是一致的; 当然,对于 ``input_mapping`` 和 ``output_mapping``,您可以通过添加 ``kwargs`` 中的参数 ``evaluate_input_mapping`` 和 ``evaluate_output_mapping`` 来单独为 ``Evaluator`` 进行更细致的订制。 @@ -338,9 +374,9 @@ class Trainer(TrainerEventTrigger): def __init__( self, model, - driver, train_dataloader, optimizers, + driver: str = "auto", device: Optional[Union[int, List[int], str]] = "cpu", n_epochs: int = 20, evaluate_dataloaders=None, @@ -359,6 +395,7 @@ class Trainer(TrainerEventTrigger): monitor: Union[str, Callable] = None, larger_better: bool = True, n_batches: int = -1, + overfit_batches: int = 0, marker: Optional[str] = None, **kwargs ): @@ -456,9 +493,6 @@ class Trainer(TrainerEventTrigger): n_batches=n_batches ) - if metrics is None and evaluate_dataloaders is not None: - raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.") - if metrics is not None and evaluate_dataloaders is None: raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.") @@ -482,33 +516,42 @@ class Trainer(TrainerEventTrigger): else: _dist_sampler = None + self.dataloader = self.train_dataloader + self.driver.set_deterministic_dataloader(self.dataloader) + + self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, + reproducible=self.callback_manager._need_reproducible_sampler) + # 进行 overfit 相关的设置; + if overfit_batches != 0: + self.dataloader = OverfitDataLoader(self.dataloader, overfit_batches) + self.overfit_batches = overfit_batches + self.evaluator = None self.monitor = monitor self.larger_better = larger_better - if metrics is not None and evaluate_dataloaders is not None: - check_evaluate_every(evaluate_every) - progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为 - if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。 - progress_bar = progress_bar.name - self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, - driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn, - evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, - output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, - use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler), - progress_bar=progress_bar, - check_dataloader_legality=kwargs.get('check_dataloader_legality', True)) + if metrics is not None: + if overfit_batches != 0: + evaluate_dataloaders = self.dataloader + if evaluate_dataloaders is not None: + check_evaluate_every(evaluate_every) + progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为 + if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。 + progress_bar = progress_bar.name + self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, + driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn, + evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, + output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, + use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler), + progress_bar=progress_bar, + check_dataloader_legality=kwargs.get('check_dataloader_legality', True)) + else: + raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.") if train_fn is not None and not isinstance(train_fn, str): raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn) self.train_fn = train_fn - self.dataloader = self.train_dataloader - self.driver.set_deterministic_dataloader(self.dataloader) - - self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, - reproducible=self.callback_manager._need_reproducible_sampler) - self.evaluate_batch_step_fn = evaluate_batch_step_fn self.kwargs = kwargs @@ -521,17 +564,17 @@ class Trainer(TrainerEventTrigger): r""" 该函数是在 ``Trainer`` 初始化后用于真正开始训练的函数; - 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ``CheckpointCallback`` + 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 ``resume_from`` 为 ``None``,并且使用 ``CheckpointCallback`` 去保存断点重训的文件; - :param num_train_batch_per_epoch: 每个 epoch 训练多少个 batch 后停止,*-1* 表示使用 train_dataloader 本身的长度; - :param num_eval_batch_per_dl: 每个 evaluate_dataloader 验证多少个 batch 停止,*-1* 表示使用 evaluate_dataloader 本身的长度; + :param num_train_batch_per_epoch: 每个 epoch 训练多少个 batch 后停止,*-1* 表示使用 ``train_dataloader`` 本身的长度; + :param num_eval_batch_per_dl: 每个 ``evaluate_dataloader`` 验证多少个 batch 停止,*-1* 表示使用 ``evaluate_dataloader`` 本身的长度; :param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 的过程是否有错误。为 0 表示不检测; :param resume_from: 从哪个路径下恢复 trainer 的状态,注意该值需要为一个文件夹,例如使用 ``CheckpointCallback`` 时帮助您创建的保存的子文件夹; :param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态;该参数如果为 ``True``, 在下一次断点重训的时候我们会精确到上次训练截止的具体的 sample 进行训练;否则我们只会恢复 model 和 optimizers 的状态,而 ``Trainer`` 中的 其余状态都是保持初始化时的状态不会改变; - :param catch_KeyboardInterrupt: 是否捕获 KeyboardInterrupt;如果该参数为 ``True``,在训练时如果您使用 ``ctrl+c`` 来终止程序, + :param catch_KeyboardInterrupt: 是否捕获 :class:`KeyboardInterrupt`;如果该参数为 ``True``,在训练时如果您使用 ``ctrl+c`` 来终止程序, ``Trainer`` 不会抛出异常,但是会提前退出,然后 ``trainer.run()`` 之后的代码会继续运行。注意该参数在您使用分布式训练的 ``Driver`` 时无效,例如 ``TorchDDPDriver``;非分布式训练的 ``Driver`` 下该参数默认为 True; @@ -552,7 +595,7 @@ class Trainer(TrainerEventTrigger): 整体的验证流程是否正确; ``num_eval_sanity_batch`` 的作用可能会让人产生迷惑,其本质和 ``num_eval_batch_per_dl`` 作用一致,但是其只被 ``Trainer`` 使用; - 并且其只会在训练的一开始使用,意思为:我们在训练的开始时会先使用 ``Evaluator``(如果其不为 ``None``) 进行验证,此时验证的 batch 的 + 并且其只会在训练的一开始使用,意思为:我们在训练的开始时会先使用 ``Evaluator`` (如果其不为 ``None``) 进行验证,此时验证的 batch 的 数量只有 ``num_eval_sanity_batch`` 个;但是对于 ``num_eval_batch_per_dl`` 而言,其表示在实际的整体的训练过程中,每次 ``Evaluator`` 进行验证时会验证的 batch 的数量。 @@ -698,7 +741,7 @@ class Trainer(TrainerEventTrigger): .. note:: 对于训练一个神经网络的整体的流程来说,其可以分为很多个时间点,例如 **"整体的训练前"**,**"训练具体的一个 epoch 前"**, - **"反向传播前"**,**"整体的训练结束后"**等;一个 ``callback`` 时机指的就是这些一个个具体的时间点; + **"反向传播前"**,**"整体的训练结束后"** 等;一个 ``callback`` 时机指的就是这些一个个具体的时间点; 该函数的参数 ``event`` 需要是一个 ``Event`` 实例,其使用方式见下方的例子; @@ -988,10 +1031,11 @@ class Trainer(TrainerEventTrigger): r""" 用于帮助您保存模型的辅助函数; - :param folder: 保存模型的文件夹。如果没有传入 model_save_fn 参数,则我们会在这个文件夹下保存 fastnlp_model.pkl.tar 文件; - :param only_state_dict: 仅在 model_save_fn 为空时,有效。是否只保存模型的 ``state_dict``; + :param folder: 保存模型的文件夹。如果没有传入 ``model_save_fn`` 参数,则我们会在这个文件夹下保存 ``fastnlp_model.pkl.tar`` 文件; + :param only_state_dict: 仅在 ``model_save_fn`` 为空时,有效。是否只保存模型的 ``state_dict``; :param model_save_fn: 您自己定制的用来替换该保存函数本身保存逻辑的函数,当您传入了该参数后,我们会实际调用该函数,而不会去调用 ``driver`` 的 ``save_model`` 函数; - :param kwargs: 理论上您不需要使用到该参数; + :kwargs: + * *input_spec* -- 该参数详见 **PaddlePaddle** 框架的保存函数 :meth:`~fastNLP.core.drivers.PaddleDriver.save_model` 中的说明; .. note:: @@ -1030,10 +1074,10 @@ class Trainer(TrainerEventTrigger): """ 用于帮助您加载模型的辅助函数; - :param folder: 存放着您需要加载的 model 的文件夹,默认会尝试读取该文件夹下的 fastnlp_model.pkl.tar 文件。在 model_load_fn 不为空时, - 直接将该 folder 传递到 model_load_fn 中; - :param only_state_dict: 要读取的文件中是否仅包含模型权重。在 ``model_load_fn 不为 None`` 时,该参数无意义; - :param model_load_fn: ``callable`` 的函数,接受一个 folder 作为参数,需要注意该函数不需要返回任何内容; + :param folder: 存放着您需要加载的 model 的文件夹,默认会尝试读取该文件夹下的 ``fastnlp_model.pkl.tar`` 文件。在 ``model_load_fn`` + 不为空时,直接将该 folder 传递到 ``model_load_fn`` 中; + :param only_state_dict: 要读取的文件中是否仅包含模型权重。在 ``model_load_fn`` 不为 ``None`` 时,该参数无意义; + :param model_load_fn: :class:`Callable` 的函数,接受一个 folder 作为参数,需要注意该函数不需要返回任何内容; :param kwargs: 理论上您不需要使用到该参数; .. note:: @@ -1073,12 +1117,13 @@ class Trainer(TrainerEventTrigger): 用于帮助您实现断点重训功能的保存函数;保存内容包括:callback 状态、Trainer 的状态、Sampler 的状态【在恢复的时候才能恢复到特定 batch 】、 模型参数、optimizer的状态、fp16 Scaler的状态【如果有】。 - :param folder: 保存在哪个文件夹下,会在该文件下声称两个文件:fastnlp_checkpoint.pkl.tar 与 fastnlp_model.pkl.tar 。 - 如果 model_save_fn 不为空,则没有 fastnlp_model.pkl.tar 文件; - :param only_state_dict: 当 model_save_fn 为空时有效,表明是否仅保存模型的权重; + :param folder: 保存在哪个文件夹下,会在该文件下生成两个文件:``fastnlp_checkpoint.pkl.tar`` 与 ``fastnlp_model.pkl.tar`` 。 + 如果 ``model_save_fn`` 不为空,则没有 ``fastnlp_model.pkl.tar`` 文件; + :param only_state_dict: 当 ``model_save_fn`` 为空时有效,表明是否仅保存模型的权重; :param model_save_fn: 如果模型保存比较特殊,可以传入该函数自定义模型的保存过程,输入应该接受一个文件夹(实际上就是接受上面的 folder 参数),不需要返回值;这意味着您可以通过该函数来自己负责模型的保存过程,而我们则会将 ``trainer`` 的状态保存好; - :param kwargs: 理论上您不需要使用到该参数; + :kwargs: + * *input_spec* -- 该参数详见 **PaddlePaddle** 框架的保存函数 :meth:`~fastNLP.core.drivers.PaddleDriver.save_model` 中的说明; .. note:: @@ -1097,7 +1142,7 @@ class Trainer(TrainerEventTrigger): 为了支持断点重训功能,我们会在调用该函数时保存以下内容: 1. 各个 ``callback`` 的状态,这主要涉及到一些带有运行状态的 ``callback``; - 2. 控制训练流程的变量 ``trainer_state``,具体详见 :class:`~fastNLP.core.controllers.utils.states.TrainerState`; + 2. 控制训练流程的变量 ``trainer_state``,具体详见 :class:`~fastNLP.core.controllers.utils.state.TrainerState`; 3. 一个特殊的变量 ``num_consumed_batches``,表示在这次训练过程中总共训练了多少个 batch 的数据;您不需要关心这个变量; 4. sampler 的状态,为了支持断点重训功能,我们会在 trainer 初始化的时候,将您的 ``trainer_dataloader`` 的 ``sampler`` 替换为 我们专门用于断点重训功能的 ``ReproducibleSampler``,详见 :class:`~fastNLP.core.samplers.reproducible_sampler.ReproducibleSampler`; @@ -1309,6 +1354,11 @@ class Trainer(TrainerEventTrigger): 用于在使用梯度累积并且进行分布式训练时,由于在前 ``accumulation_steps - 1`` 的时间内不需要进行梯度的同步,因此通过使用该 context 上下文 环境来避免梯度的同步; + .. note:: + + 部分深度学习框架的梯度累积并不需要通过提供上下文环境实现,关于这点需要您深入了解您正在使用的框架的机制;而对于这些框架,fastNLP 会返回一个 + 空的上下文环境。 + :return: 一个支持 ``no_sync`` 的 ``context``; """ @@ -1394,7 +1444,7 @@ class Trainer(TrainerEventTrigger): def model_device(self): r""" :return: 返回当前模型所在的设备;注意该值在当且仅当在少数情况下为 ``None``,例如当使用 ``pytorch`` 时,仅当用户自己初始化 ``init_progress_group`` 时 - ``model_device`` 才为 None; + ``model_device`` 才为 None; """ return self.driver.model_device diff --git a/fastNLP/core/controllers/utils/state.py b/fastNLP/core/controllers/utils/state.py index 676b548c..8c1bfde8 100644 --- a/fastNLP/core/controllers/utils/state.py +++ b/fastNLP/core/controllers/utils/state.py @@ -42,7 +42,7 @@ class State(dict): class TrainerState: r""" 该类用于我们 fastNLP 自己内部为了训练流程所记录的一些状态,当然是要暴露给用户给用户使用的; - 我们保存的state大部分上是 trainer 断点重训 需要重新加载的; + 我们保存的 state 大部分上是 trainer 断点重训 需要重新加载的; 专属于 `Trainer` 的状态记载的类; :param n_epochs: 训练过程中总共的 epoch 的数量; @@ -50,7 +50,7 @@ class TrainerState: :param global_forward_batches: 当前模型总共 forward 了多少个 step; :param batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; :param num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step; - :param n_batches: 完整训练过程会 forward 的 step 数量,注意 n_batches = n_batches * n_epochs; + :param n_batches: 完整训练过程会 forward 的 step 数量,注意 ``n_batches = num_batches_per_epoch * n_epochs`` ; """ n_epochs: Optional[int] = None # 无论如何重新算 @@ -73,6 +73,7 @@ class TrainerState: def load_state_dict(self, state_dict: Dict): r""" 用于断点重训来重新加载保存的状态字典; + :param state_dict: 用于加载的状态字典; """ for key in state_dict: diff --git a/fastNLP/core/controllers/utils/utils.py b/fastNLP/core/controllers/utils/utils.py index ef3cf98c..0a351354 100644 --- a/fastNLP/core/controllers/utils/utils.py +++ b/fastNLP/core/controllers/utils/utils.py @@ -4,11 +4,12 @@ from fastNLP.core.callbacks import CallbackManager from .state import TrainerState from fastNLP.core.utils.utils import _check_valid_parameters_number +__all__ = [] class TrainerEventTrigger: r""" - 为了避免在训练流程中调用 callback 函数中写成类似 'trainer.callback_manager.on_train_begin' 的形式,我们选择单独抽象为 'Trainer' - 抽象一层,然后一些特殊的操作可以在这里进行,例如我们通过 `on_validate_end` 来通知所有的 'CheckpointCallback' 实例在当前的 step 后保存 + 为了避免在训练流程中调用 callback 函数中写成类似 `'trainer.callback_manager.on_train_begin'` 的形式,我们选择单独为 ``Trainer`` + 抽象一层,然后一些特殊的操作可以在这里进行,例如我们通过 :meth:`on_validate_end` 来通知所有的 ``CheckpointCallback`` 实例在当前的 step 后保存 模型。 """ callback_manager: CallbackManager @@ -138,7 +139,7 @@ def check_evaluate_every(evaluate_every): ``evaluate_every`` 的使用详见 ``Trainer`` 的 ``evaluate_every`` 参数; - 主要在于当参数 ``evaluate_every`` 是一个 callable 的函数时,需要保证其参数的正确性; + 主要在于当参数 ``evaluate_every`` 是一个 Callable 的函数时,需要保证其参数的正确性; """ if not callable(evaluate_every) and (not isinstance(evaluate_every, int) or evaluate_every == 0): raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") diff --git a/fastNLP/core/dataloaders/__init__.py b/fastNLP/core/dataloaders/__init__.py index 84f8b288..06d3f5a8 100644 --- a/fastNLP/core/dataloaders/__init__.py +++ b/fastNLP/core/dataloaders/__init__.py @@ -3,14 +3,20 @@ __all__ = [ 'TorchDataLoader', 'PaddleDataLoader', 'JittorDataLoader', + 'OneflowDataLoader', 'prepare_jittor_dataloader', 'prepare_paddle_dataloader', 'prepare_torch_dataloader', + 'prepare_oneflow_dataloader', - "prepare_dataloader" + "prepare_dataloader", + + "OverfitDataLoader" ] from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader, MixDataLoader from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader -from .prepare_dataloader import prepare_dataloader \ No newline at end of file +from .oneflow_dataloader import OneflowDataLoader, prepare_oneflow_dataloader +from .prepare_dataloader import prepare_dataloader +from .utils import OverfitDataLoader \ No newline at end of file diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index 83555f6e..22aeeec7 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -47,38 +47,35 @@ class JittorDataLoader: * callate_fn 为 ``'auto'`` 时,``JittorDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。 此时可以配套使用 ``JittorDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 * callate_fn 为 ``None`` 时, ``JittorDataLoader`` 默认使用 Jittor DataLoader 自带的 collate_fn - * collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 + * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 dataset 的一条数据;该 Callable 函数还应当返回一个对象。 + :param dataset: 实现了 __getitem__() 和 __len__() 的对象。 + :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 + :param shuffle: 是否打乱数据集, 默认为 ``False``。 + :param drop_last: 当 ``drop_last=True`` 时,``JittorDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; + 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 + :param num_workers: 当 ``num_workers > 0`` 时, ``JittorDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 + 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 + :param buffer_size: 每个进程占用的内存空间,默认为 512M。主要是配合 ``num_workers`` 使用,用户可以自定义每个进程的内存大小。 + :param stop_grad: 是否不使用梯度, 默认 ``True`` 。 + :param keep_numpy_array: 返回的数据是 ``np.array`` 类型而不是 ``jittor.Var`` 类型,默认为 ``False`` + :param endless: 是否让 ``JittorDataLoader`` 无限返回数据,也就是将 dataset 循环使用使得返回数据是没有限制的。默认为 ``False``. + :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. + + * callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, + ``JittorDataLoader`` 调用默认的 Jittor 框架的 ``DataLoader`` 自带的 ``collate_batch`` 作为 callate_fn 的默认值, 其无法处理 + :class:`~fastNLP.core.dataset.DataSet` 的 dataset 对象。 + * callate_fn 为 ``'auto'`` 时,``JittorDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 + 此时可以配套使用 ``JittorDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 + * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 + dataset 的一条数据;该 Callable 函数还应当返回一个对象。 """ def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, collate_fn: Union[None, str, Callable] = "auto") -> None: - """ - :param dataset: 实现了 __getitem__() 和 __len__() 的对象。 - :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 - :param shuffle: 是否打乱数据集, 默认为 ``False``。 - :param drop_last: 当 ``drop_last=True`` 时,``JittorDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; - 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 - :param num_workers: 当 ``num_workers > 0`` 时, ``JittorDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 - 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 - :param buffer_size: 每个进程占用的内存空间,默认为512M。主要是配合num_workers使用,用户可以自定义每个进程的内存大小。 - :param stop_grad: 是否不使用梯度, 默认 ``True`` 。 - :param keep_numpy_array: 返回的数据是 ``np.array`` 类型而不是 ``jittor.Var`` 类型,默认为 ``False`` - :param endless: 是否让 ``JittorDataLoader`` 无限返回数据,也就是将 dataset 循环使用使得返回数据是没有限制的。默认为 ``False``. - :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. - - * callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, - ``JittorDataLoader`` 调用默认的 Jittor 框架的 ``DataLoader`` 自带的 ``collate_batch`` 作为 callate_fn 的默认值, 其无法处理 - :class:`~fastNLP.core.dataset.DataSet` 的 dataset 对象。 - * callate_fn 为 ``'auto'`` 时,``JittorDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 - 此时可以配套使用 ``JittorDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 - * collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 - dataset 的一条数据;该 Callable 函数还应当返回一个对象。 - - """ # TODO 验证支持replacesampler (以后完成) 增加Sampler # 将内部dataset批次设置为1 if isinstance(dataset, Dataset): @@ -136,20 +133,20 @@ class JittorDataLoader: """ 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 - :param field_name: 需要调整的 field 的名称。如果 :class:`~fastNLP.core.Dataset` 的 :class:`~fastNLP.core.Dataset.__getitem__` - 方法返回的是 dict 类型的,则可以直接使用对应的 field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 - ``{'a': {'b': 1}}`` 中的使用 ``('a', 'b')`` 如果 ``__getitem__`` 返回的是 Sequence 类型的,则可以使用 *_0*, *_1* 表示序列中 - 第 **0** 或 **1** 个元素。如果该 field 在数据中没有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 ``_single`` 。 - :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 - field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 - 无意义。 - :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 - :param backend: 可选 ``['raw', 'numpy', 'Jittor', 'paddle', 'jittor', 'auto']`` ,分别代表,输出为 ``list`` , ``numpy.ndarray`` , - ``Jittor.Tensor`` , ``paddle.Tensor`` , ``jittor.Var`` 类型。若 ``pad_val`` 为 ``None`` ,该值无意义 。 - :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 - batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch - 形式,输出将被直接作为结果输出。 - :return: 返回 Collator 自身 + :param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 + field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; + 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 + 如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 ``"_single"`` 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, + 该值无意义。 + :param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 + :param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, + :class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 + 若 ``pad_val`` 为 ``None`` ,该值无意义 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 + batch 形式。 collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 + :return: 返回使用的 collator """ collator = self._get_collator() if isinstance(collator, Collator): @@ -173,16 +170,14 @@ class JittorDataLoader: def set_ignore(self, *field_names) -> Collator: """ - 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 - Example:: + 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略:: - collator.set_ignore('field1', 'field2') + dataloader.set_ignore('field1', 'field2') - :param field_name: 需要调整的 field 的名称。如果 :class:`~fastNLP.core.Dataset` 的 :class:`~fastNLP.core.Dataset.__getitem__` - 方法返回的是 dict 类型的,则可以直接使用对应的 field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 - ``{'a': {'b': 1}}`` 中的使用 ``('a', 'b')`` 如果 ``__getitem__`` 返回的是 Sequence 类型的,则可以使用 *_0*, *_1* 表示序列中 - 第 **0** 或 **1** 个元素。 - :return: 返回 Collator 自身 + :param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 + field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; + 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 + :return: 返回使用的 collator """ collator = self._get_collator() if isinstance(collator, Collator): @@ -193,14 +188,14 @@ class JittorDataLoader: def get_batch_indices(self) -> List[int]: """ - 获取当前 batch 的 idx + 获取当前 ``batch`` 中每条数据对应的索引。 - :return: + :return: 当前 ``batch`` 数据的索引; """ return self.cur_batch_indices -def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = False, +def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = None, drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, collate_fn: Union[None, str, Callable] = "auto", @@ -208,36 +203,37 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa -> Union[Dict[str, JittorDataLoader], JittorDataLoader]: """ ``prepare_jittor_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 :class:`JittorDataLoader` 对象, 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`。 - 根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: - - * 当 ds_or_db 为 ``DataSet`` 时,``prepare_jittor_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 - 帮你实例化一个 :class:`JittorDataLoader` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`。 - * 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_Jittor_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value - 来创建不同的 :class:`JittorDataLoader` 对象;当 key 中包含'train'字符串时,``prepare_jittor_dataloader`` 默认该 value 为 train 数据集, - 会将 batch_size 和 sampler 作为参数,其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。 - 最终根据 ``key: JittorDataLoader`` 组成 ``Dict[key, JittorDataLoader]`` 的字典返回。 + 根据 ``ds_or_db`` 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: + + * 当 ds_or_db 为 :class:`~fastNLP.io.DataSet` 时,``prepare_jittor_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 + 帮你实例化一个 :class:`JittorDataLoader` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`; + * 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_jittor_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value + 来创建不同的 :class:`JittorDataLoader` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_jittor_dataloader`` 默认该 value 为训练数据集, + 会将 ``batch_size`` 和 ``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。 + 最终根据 ``key: JittorDataLoader`` 组成 ``Dict[key, JittorDataLoader]`` 的字典返回; * 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_jittor_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的 - :class:`JittorDataLoader` 对象;当 key 中包含'train'字符串时,``prepare_Jittor_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, - 其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: JittorDataLoader`` 组成 - ``Dict[key, JittorDataLoader]`` 的字典返回。 + :class:`JittorDataLoader` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_jittor_dataloader`` 默认该 value 为训练数据集,会将 ``batch_size`` 和 + ``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。最终根据 ``key: JittorDataLoader`` 组成 + ``Dict[key, JittorDataLoader]`` 的字典返回; - :param ds_or_db: 可以有以下三种取值, + :param ds_or_db: 可以有以下三种取值: - * ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 - * ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 - * ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为:class:`~fastNLP.TorchDataLoader` + * ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; + * ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; + * ds_or_db 为实现了 :meth:`__getitem__` 和 :meth:`__len__` 的对象 ,返回值为 :class:`~fastNLP.core.dataloaders.JittorDataLoader`; :param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 :class:`Dict` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数 设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 默认为 ``16``。 :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 - :param shuffle: 是否打乱数据集, 默认为 ``False``。 + :param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 True , + 其它的为 False 。 :param drop_last: 当 ``drop_last=True`` 时,:class:`JittorDataLoader` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 :param num_workers: 当 ``num_workers > 0`` 时, :class:`JittorDataLoader` 会开启 num_workers 个子进程来处理数据, 可以加快 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 - :param buffer_size: 每个进程占用的内存空间,默认为512M。主要是配合num_workers使用,用户可以自定义每个进程的内存大小。 + :param buffer_size: 每个进程占用的内存空间,默认为512M。主要是配合 ``num_workers`` 使用,用户可以自定义每个进程的内存大小。 :param stop_grad: 是否不使用梯度, 默认 ``True`` 。 - :param keep_numpy_array: 返回的数据是 ``np.array`` 类型而不是 ``jittor.Var`` 类型,默认为 ``False`` + :param keep_numpy_array: 返回的数据是 :class:`np.array` 类型而不是 :class:`ittor.Var` 类型,默认为 ``False`` :param endless: 是否让 :class:`JittorDataLoader` 无限返回数据,也就是将 dataset 循环使用使得返回数据是没有限制的。默认为 ``False``. :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. @@ -246,11 +242,8 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa :class:`~fastNLP.core.dataset.DataSet` 的 dataset 对象。 * callate_fn 为 ``'auto'`` 时,:class:`JittorDataLoader` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 此时可以配套使用 :class:`JittorDataLoader` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 - * collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 + * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 dataset 的一条数据;该 Callable 函数还应当返回一个对象。 - - :return: 返回数据类型为 :class:`Dict[str, JittorDataLoader]`, :class:`JittorDataLoader` 其中之一,根据输入 - ``ds_or_db`` 变化而变化。 """ from fastNLP.io.data_bundle import DataBundle @@ -258,7 +251,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa dl_bundle = {} for name, ds in ds_or_db.iter_datasets(): if 'train' in name: - dl_bundle[name] = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle, + dl_bundle[name] = JittorDataLoader(ds, batch_size=batch_size, shuffle=True if shuffle is None else shuffle, drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, @@ -267,7 +260,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa else: dl_bundle[name] = JittorDataLoader(ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size, - shuffle=shuffle, + shuffle=False if shuffle is None else shuffle, drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, @@ -279,14 +272,14 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa ds_dict = {} for name, ds in ds_or_db.items(): if 'train' in name: - dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle, + dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=True if shuffle is None else shuffle, drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, collate_fn=collate_fn) else: dl = JittorDataLoader(ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size, - shuffle=shuffle, + shuffle=False if shuffle is None else shuffle, drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, @@ -296,7 +289,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa return ds_dict elif isinstance(ds_or_db, HasLenGetitemType): - dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle, + dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=False if shuffle is None else shuffle, drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, collate_fn=collate_fn) diff --git a/fastNLP/core/dataloaders/oneflow_dataloader/__init__.py b/fastNLP/core/dataloaders/oneflow_dataloader/__init__.py new file mode 100644 index 00000000..d17ce91c --- /dev/null +++ b/fastNLP/core/dataloaders/oneflow_dataloader/__init__.py @@ -0,0 +1,6 @@ +__all__ = [ + "OneflowDataLoader", + "prepare_oneflow_dataloader", +] + +from .fdl import OneflowDataLoader, prepare_oneflow_dataloader diff --git a/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py b/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py new file mode 100644 index 00000000..deed9281 --- /dev/null +++ b/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py @@ -0,0 +1,353 @@ +__all__ = [ + 'OneflowDataLoader', + 'prepare_oneflow_dataloader' +] + +from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List, Any +from abc import ABC +from copy import deepcopy + +from fastNLP.core.dataset import DataSet +from fastNLP.core.collators import Collator +from fastNLP.core.dataloaders.utils import indice_collate_wrapper +from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler +from ..utils import _match_param +from ..utils import HasLenGetitemType + +if _NEED_IMPORT_ONEFLOW: + from oneflow.utils.data import DataLoader, Sampler, Dataset +else: + from fastNLP.core.utils.dummy_class import DummyClass as DataLoader + + +class _FDataSet: + """ + 提供给 ``OneflowDataLoader`` 使用的 warp 类,其功能是对 dataset 进行封装,wrap 修改 dataset 的 __getitem__ 函数,增加返回 + 数据的下标 idx 。 + + ..note:: + + 需要注意的是传入 ``__init__`` 的 dataset 需要实现 __getattribute__ 方法才能在 _FDataset 实例化对象中调用 dataset 的方法 + + """ + + def __init__(self, dataset) -> None: + self.dataset = dataset + + def __getitem__(self, item: Union[int, list]) -> Tuple: + return (item, self.dataset[item]) + + def __getattr__(self, item): + try: + return self.dataset.__getattribute__(item) + except AttributeError as e: + raise e + + def __len__(self) -> int: + return len(self.dataset) + + +class OneflowDataLoader(DataLoader): + """ + 提供给 ``oneflow`` 框架使用的 ``DataLoader`` 函数,``OneflowDataLoader`` 提供了 ``Collator`` 来自动检测 dataset 的每个 field 是否可 pad, + 若是可 pad 的 field 则自动 pad 到相同长度,否则只会将相同 field 的数据收集组成一个 batch 返回。 + 具体详见 :class:`~fastNLP.core.collators.Collator`;用户通过 callte_fn 来控制是否使用该功能, collate_fn 只能为 ``['auto', None, Callable]`` + 三种取值。 + + * callate_fn 为 ``'auto'`` 时,``OneflowDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。 + 此时可以配套使用 ``OneflowDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 + * callate_fn 为 ``None`` 时, ``OneflowDataLoadr`` 默认使用 :class:`oneflow.utils.data.DataLoader` 自带的 collate_fn + * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 + dataset 的一条数据;该 Callable 函数还应当返回一个对象。 + + :param dataset: 实现了 __getitem__() 和 __len__() 的对象。 + :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 + :param non_train_batch_size: 非训练数据集的 ``OneflowDataLoader`` 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 ``None`` 有效。 + :param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , + 其它的为 False 。 + :param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , + 默认为 ``None``, 当其不为 ``None`` 时, shuffle 参数无效。 + :param non_train_sampler: 非训练数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , + 默认为None, 当其不为 None 时, shuffle 参数无效。 + :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 + dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``sampler``, ``shuffle`` 参数均失效。 + :param num_workers: 当 ``num_workers > 0`` 时, ``OneflowDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 + 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 + :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. + + * callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, + ``OneflowDataLoader`` 调用默认的 oneflow 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 + :class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 + * callate_fn 为 ``'auto'`` 时,``OneflowDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 + 此时可以配套使用 ``OneflowDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 + * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 + dataset 的一条数据;该 Callable 函数还应当返回一个对象。 + + :param pin_memory: 如果其为 ``True``, 那么 ``OneflowDataLoader`` 会在返回数据张量之前将其 copy 到 cuda 的 pin memory 中。 + :param drop_last: 当 ``drop_last=True`` 时,``OneflowDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; + 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 + :param timeout: 子进程的输出队列获取数据的超时值 + :param worker_init_fn: init 函数,如果不设置为 ``None``,则将会在每个子进程初始化时调用该函数。 + :param multiprocessing_context: 多进程的上下文环境 + :param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` + :param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2`` 意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . + :param persistent_workers: 如果其为 ``True``, ``OneflowDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` + """ + + def __init__(self, dataset, batch_size: int = 16, + shuffle: bool = False, sampler = None, batch_sampler = None, + num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', + pin_memory: bool = False, drop_last: bool = False, + timeout: float = 0, worker_init_fn: Optional[Callable] = None, + multiprocessing_context=None, generator=None, prefetch_factor: int = 2, + persistent_workers: bool = False, **kwargs) -> None: + + if isinstance(dataset, DataSet) and collate_fn is None: + raise ValueError("When use FastNLP DataSet, collate_fn must be not None") + + if not isinstance(dataset, _FDataSet): + dataset = _FDataSet(dataset) + + if num_workers>0 and multiprocessing_context is None: + multiprocessing_context = 'fork' # 这里默认使用fork的方式来启动多进程 + + if batch_sampler is not None: + batch_size = 1 + shuffle = False + sampler = None + elif sampler is None: + sampler = RandomSampler(dataset, shuffle=shuffle) + shuffle = False + + if isinstance(collate_fn, str): + if collate_fn == 'auto': + if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset + collate_fn = deepcopy(dataset.dataset.collator) + collate_fn.set_backend(backend="oneflow") + else: + collate_fn = Collator(backend="oneflow") + else: + raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") + + dl_kwargs = _match_param(OneflowDataLoader.__init__, DataLoader.__init__, fn_name=DataLoader.__name__) + if dl_kwargs is None: + super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, + batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, + pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, generator=generator, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers) + else: + super().__init__(**dl_kwargs) + + self.cur_batch_indices = None + + def __iter__(self): + self.collate_fn = indice_collate_wrapper(self.collate_fn) + for indices, data in super().__iter__(): + self.cur_batch_indices = indices + yield data + + def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, + pad_fn: Callable = None) -> Collator: + """ + 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 + + :param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 + field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; + 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 + 如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, + 该值无意义。 + :param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 + :param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, + :class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 + 若 ``pad_val`` 为 ``None`` ,该值无意义 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 + batch 形式。 collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 + :return: 返回使用的 collator + """ + collator = self._get_collator() + if isinstance(collator, Collator): + collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) + return collator + else: + raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") + + def _get_collator(self): + """ + 如果 collate_fn 是 Collator 对象,得到该对象。如果没有的话,返回 None + + :return: + """ + collator = None + if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator): + collator = self.collate_fn.__wrapped__ + elif isinstance(self.collate_fn, Collator): + collator = self.collate_fn + return collator + + def set_ignore(self, *field_names) -> Collator: + """ + 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略:: + + dataloader.set_ignore('field1', 'field2') + + :param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 + field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; + 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 + :return: 返回使用的 collator + """ + collator = self._get_collator() + if isinstance(collator, Collator): + collator.set_ignore(*field_names) + return collator + else: + raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") + + def get_batch_indices(self) -> List[int]: + """ + 获取当前 ``batch`` 中每条数据对应的索引。 + + :return: 当前 ``batch`` 数据的索引; + """ + return self.cur_batch_indices + + +def prepare_oneflow_dataloader(ds_or_db, + batch_size: int = 16, + shuffle: bool = None, + sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, + batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, + num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', + pin_memory: bool = False, drop_last: bool = False, + timeout: float = 0, worker_init_fn: Optional[Callable] = None, + multiprocessing_context=None, generator=None, prefetch_factor: int = 2, + persistent_workers: bool = False, + non_train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, + non_train_batch_size: int = None) \ + -> Union[OneflowDataLoader, Dict[str, OneflowDataLoader]]: + """ + ``prepare_oneflow_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``OneflowDataloader`` 对象, 详见 :class:`~fastNLP.OneflowDataLoader`。 + 根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: + + * 当 ds_or_db 为 ``DataSet`` 时,``prepare_oneflow_dataloader`` 会将使用的除了 ``non_train_batch_size`` 和 ``non_train_sampler`` 以外的参数来 + 帮你实例化一个 ``OneflowDataLoader`` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.OneflowDataLoader`。 + * 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_oneflow_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value + 来创建不同的 ``OneflowDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_oneflow_dataloader`` 默认该 value 为训练数据集, + 会将 ``batch_size`` 和 ``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。 + 最终根据 ``key: OneflowDataLoader`` 组成 ``Dict[key, OneflowDataLoader]`` 的字典返回。 + * 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_oneflow_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的 + ``OneflowDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_oneflow_dataloader`` 默认该 value 为训练数据集,会将 ``batch_size`` 和 ``sampler`` 作为参数, + 其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。最终根据 ``key: OneflowDataLoader`` 组成 + ``Dict[key, OneflowDataLoader]`` 的字典返回。 + + :param ds_or_db: 可以有以下三种取值, + + * ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, OneflowDataLoader]`` 的字典; + * ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, OneflowDataLoader]`` 的字典; + * ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为 :class:`~fastNLP.core.dataloaders.OneflowDataLoader`; + + :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 + :param non_train_batch_size: 非训练数据集的 ``OneflowDataLoader`` 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 ``None`` 有效。 + :param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , + 其它的为 False 。 + :param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , + 默认为 ``None``, 当其不为 ``None`` 时, shuffle 参数无效。 + :param non_train_sampler: 非训练数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , + 默认为None, 当其不为 None 时, shuffle 参数无效。 + :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List 中的值为 + dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``sampler``, ``shuffle`` 参数均失效。 + :param num_workers: 当 ``num_workers > 0`` 时, ``OneflowDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 + 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 + :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. + + * callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, + ``OneflowDataLoader`` 调用默认的 oneflow 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 + :class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 + * callate_fn 为 ``'auto'`` 时,``OneflowDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 + 此时可以配套使用 ``OneflowDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 + * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 + dataset 的一条数据;该 Callable 函数还应当返回一个对象。 + + :param pin_memory: 如果其为 ``True``, 那么 ``OneflowDataLoader`` 会在返回数据张量之前将其 copy 到 cuda 的 pin memory 中。 + :param drop_last: 当 ``drop_last=True`` 时,``OneflowDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; + 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 + :param timeout: 子进程的输出队列获取数据的超时值 + :param worker_init_fn: init 函数,如果不设置为 ``None``,则将会在每个子进程初始化时调用该函数。 + :param multiprocessing_context: 多进程的上下文环境 + :param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` + :param prefetch_factor: 每个 worker 提前装载的 samples 数量。 ``2`` 意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` 。 + :param persistent_workers: 如果其为 ``True``, ``OneflowDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` + + """ + + from fastNLP.io import DataBundle + + if isinstance(ds_or_db, DataBundle): + dl_bundle = {} + for name, ds in ds_or_db.iter_datasets(): + if 'train' in name: + dl_bundle[name] = OneflowDataLoader(dataset=ds, batch_size=batch_size, + shuffle=True if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler, + num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, + drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, generator=generator, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, + ) + else: + dl_bundle[name] = OneflowDataLoader(dataset=ds, + batch_size=non_train_batch_size if non_train_batch_size else batch_size, + shuffle=False if shuffle is None else shuffle, + sampler=non_train_sampler if non_train_sampler else sampler, + batch_sampler=batch_sampler, + num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, + drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, generator=generator, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, + ) + return dl_bundle + + elif isinstance(ds_or_db, Mapping): + dl_bundle = {} + for name, ds in ds_or_db.items(): + if 'train' in name: + dl_bundle[name] = OneflowDataLoader(dataset=ds, batch_size=batch_size, + shuffle=True if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler, + num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, + drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, generator=generator, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, + ) + else: + dl_bundle[name] = OneflowDataLoader(dataset=ds, + batch_size=non_train_batch_size if non_train_batch_size else batch_size, + shuffle=False if shuffle is None else shuffle, + sampler=non_train_sampler if non_train_sampler else sampler, + batch_sampler=batch_sampler, + num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, + drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, generator=generator, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, + ) + + return dl_bundle + + elif isinstance(ds_or_db, HasLenGetitemType): + dl = OneflowDataLoader(dataset=ds_or_db, batch_size=batch_size, + shuffle=False if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler, + num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, + drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, generator=generator, + prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + ) + return dl + + else: + raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!") diff --git a/fastNLP/core/dataloaders/paddle_dataloader/__init__.py b/fastNLP/core/dataloaders/paddle_dataloader/__init__.py index ab9523e5..a5ae3a68 100644 --- a/fastNLP/core/dataloaders/paddle_dataloader/__init__.py +++ b/fastNLP/core/dataloaders/paddle_dataloader/__init__.py @@ -1,6 +1,6 @@ __all__ = [ + 'PaddleDataLoader', 'prepare_paddle_dataloader', - 'PaddleDataLoader' ] from .fdl import PaddleDataLoader, prepare_paddle_dataloader \ No newline at end of file diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py index 37130e3e..575ffef4 100644 --- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py @@ -39,7 +39,7 @@ class _PaddleDataset(Dataset): def __getattr__(self, item): try: - self.dataset.__getattribute__(item) + return self.dataset.__getattribute__(item) except Exception as e: raise e @@ -53,6 +53,7 @@ class PaddleDataLoader(DataLoader): 1. ``PaddleDataLoader`` 支持输入的 dataset 是无框架的,只要实现了 __getitem__() 和 __len__() 的对象即可, 当不使用 :class:`~fastNLP.core.dataset.DataSet` 时也不需要传入 collate_fn, 只要只需要将 ``collate_fn='auto'`` 就能够自动 探测数据的类型并判断能否 pad 。此时可以调用 ``set_pad`` 和 ``set_ignore`` 方法来设置 field 的 pad_val 或者忽略某个 field 的 pad 操作。 + Example:: from fastNLP import PaddleDataLoader @@ -76,9 +77,46 @@ class PaddleDataLoader(DataLoader): .. note:: 当传入的dataset为fastNLP的DataSet时,collate_fn不能为None。默认可以是"auto"或者自定义callable函数。 - 3. 当 collate_fn 为 ``Callable`` 时,该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 + 3. 当 collate_fn 为 :class:`Callable` 时,该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 dataset 的一条数据;该 Callable 函数还应当返回一个对象。 - + + :param dataset: 实现了 __getitem__() 和 __len__() 的对象。 + :param feed_list: feed Tensor list. + 这个张量能被 ``paddle.static.data`` 创建。 如果 :attr:`return_list` 是 ``False``, 那么 :attr:`feed_list` + 应该被设置。 默认为 ``None `` 。 + :param places: 将数据放进的一个 list 的 place。 :attr:`places` 能为 None. + 如果 :attr:`places` 为 None, 默认放在 CPUPlace 或者 CUDAPlace(0) 设备上。 如果 ``places`` 是一个 list 类型的 字符串, 那么字符串 + 可以是 ``cpu`` , ``gpu:x`` 或者 ``gpu_pinned`` , 其中 ``x`` 是 gpu 的下标。 + :param return_list: 每个设备上的返回值是否为以列表形式显示。 如果 :attr:`return_list=False`, + 每个设备上的返回值值为 str -> Tensor 的 dict, 其中 dict 的 key 为每个 fed Tensors 的名字。 + 如果 :attr:`return_list=True`, 每个设备上的返回值值为 list(Tensor)。 :attr:`return_list` 只能在动态图情况下设置为 ``True`` . + 默认值为 ``True`` 。 + :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 + dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``shuffle`` 参数均失效。 + :param batch_size: 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 None 有效。 + :param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , + 其它的为 False 。 + :param drop_last: 当 ``drop_last=True`` 时,``PaddleDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; + 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 + :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. + + * callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, + ``PaddleDataLoader`` 调用默认的 Paddle 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 + :class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 + * callate_fn 为 ``'auto'`` 时,``PaddleDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 + 此时可以配套使用 ``PaddleDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 + * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 + dataset 的一条数据;该 Callable 函数还应当返回一个对象。 + + :param num_workers: 当 ``num_workers > 0`` 时, ``PaddleDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 + 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 + :param use_buffer_reader: 是否开启 buffer_reader 。如果 ``use_buffer_reader=True`` ,那么 ``PaddleDataLoader`` 会异步地预取下一个 batch 的 + 数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是 ``True``。 + :param use_shared_memory: 是否使用共享内存。当 ``use_shared_memory=True`` 时,将采用共享内存来加快将数据放进进程队列。建议仅当计算机上的 + 共享空间足够大时。(例如 Linux 上的 /dev/shm/ 空间足够大)共享内存仅在多进程模式( ``num_workers>0`` )下生效。 + :param timeout: 从子进程的输出队列获取数据的超时值 + :param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 + :param persistent_workers: 如果其为 ``True``, ``PaddleDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` """ def __init__(self, dataset, feed_list=None, places=None, @@ -88,45 +126,7 @@ class PaddleDataLoader(DataLoader): num_workers: int = 0, use_buffer_reader: bool = True, use_shared_memory: bool = True, timeout: int = 0, worker_init_fn: Callable = None, persistent_workers=False) -> None: - """ - - :param dataset: 实现了 __getitem__() 和 __len__() 的对象。 - :param feed_list: feed Tensor list。 - 这个张量能被 :code:`paddle.static.data()` 创建。 如果 :attr:`return_list` 是 ``False``, 那么 :attr:`feed_list` - 应该被设置。 默认为 ``None`` - :param places: 将数据放进的一个 list 的 place。 :attr:`places` 能为 None。 - 如果 :attr:`places` 为 None, 默认放在 CPUPlace 或者 CUDAPlace(0) 设备上。 如果 ``places`` 是一个 list 类型的 字符串, 那么字符串 - 可以是 ``cpu`` , ``gpu:x`` 或者 ``gpu_pinned`` , 其中 ``x`` 是 gpu 的下标。 - :param return_list: 每个设备上的返回值是否为以列表形式显示。 如果 :attr:`return_list=False`, 每个设备上的返回值值为 str -> Tensor 的 dict, - 其中 dict 的 key 为每个 fed Tensors 的名字。如果 :attr:`return_list` 为 ``True`` , 每个设备上的返回值值为 list(Tensor)。 :attr:`return_list` - 只能在动态图情况下设置为 ``True`` 。默认值为 ``True`` 。 - :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 - dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, shuffle 参数均失效。 - :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 - :param shuffle: 是否将数据打乱,若``shuffle=True`` 则会将dataset打乱;若否则什么也不做。 - :param drop_last: 当 ``drop_last=True`` 时,``PaddleDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; - 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 - :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. - - * callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, - ``PaddleDataLoader`` 调用默认的 Paddle 框架的 ``DataLoader`` 自带的 ``default_collate_fn`` 作为 callate_fn 的默认值, 其无法处理 - :class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 - * callate_fn 为 ``'auto'`` 时,``PaddleDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 - 此时可以配套使用 ``PaddleDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 - * collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 - dataset 的一条数据;该 Callable 函数还应当返回一个对象。 - - :param num_workers: 当 ``num_workers > 0`` 时, ``PaddleDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 - 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 - :param use_buffer_reader: 是否开启 buffer_reader 。如果 ``use_buffer_reader=True`` ,那么 ``PaddleDataLoader`` 会异步地预取下一个 batch 的 - 数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是 ``True``。 - :param use_shared_memory: 是否使用共享内存。当 ``use_shared_memory=True`` 时,将采用共享内存来加快将数据放进进程队列。建议仅当计算机上的 - 共享空间足够大时。(例如 Linux 上的 /dev/shm/ 空间足够大)共享内存仅在多进程模式( num_workers>0 )下生效。 - :param timeout: 从子进程的输出队列获取数据的超时值 - :param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 - :param persistent_workers: 如果其为 ``True``, ``PaddleDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` - """ # FastNLP Datset, collate_fn not None if isinstance(dataset, FDataSet) and collate_fn is None: raise ValueError("When use FastNLP DataSet, collate_fn must be not None") @@ -137,9 +137,11 @@ class PaddleDataLoader(DataLoader): if batch_sampler is None: batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) - batch_size = 1 - shuffle = False - drop_last = False + # 因为无论如何传给 DataLoader 的 batch_sampler 都不是 None + # 所以要恢复默认值防止报错 + batch_size = 1 + shuffle = False + drop_last = False if isinstance(collate_fn, str): if collate_fn == 'auto': @@ -184,20 +186,20 @@ class PaddleDataLoader(DataLoader): """ 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 - :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); - 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 - 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 - :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 - field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 - 无意义。 - :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 - :param backend: 可选['raw', 'numpy', 'Paddle', 'paddle', 'paddle', 'auto'],分别代表,输出为 list, numpy.ndarray, - Paddle.Tensor, paddle.Tensor, paddle.Var 类型。若 pad_val 为 None ,该值无意义 。 - :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 - batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch - 形式,输出将被直接作为结果输出。 - :return: 返回 Collator 自身 + :param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 + field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; + 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 + 如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, + 该值无意义。 + :param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 + :param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, + :class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 + 若 ``pad_val`` 为 ``None`` ,该值无意义 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 + batch 形式。 collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 + :return: 返回使用的 collator """ collator = self._get_collator() if isinstance(collator, Collator): @@ -221,15 +223,14 @@ class PaddleDataLoader(DataLoader): def set_ignore(self, *field_names) -> Collator: """ - 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 - Example:: + 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略:: - collator.set_ignore('field1', 'field2') + dataloader.set_ignore('field1', 'field2') - :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 - __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 - :return: 返回 Collator 自身 + :param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 + field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; + 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 + :return: 返回使用的 collator """ collator = self._get_collator() if isinstance(collator, Collator): @@ -258,58 +259,59 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, non_train_batch_size: int = None) \ -> Union[Dict[str, PaddleDataLoader], PaddleDataLoader]: """ - ``prepare_paddle_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``PaddleDataloader``对象, 详见 :class:`~fastNLP.PaddleDataLoader`。 + ``prepare_paddle_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``PaddleDataloader`` 对象, 详见 :class:`~fastNLP.PaddleDataLoader`。 根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: - * 当 ds_or_db 为 ``DataSet``时,``prepare_paddle_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 - 帮你实例化一个 ``PaddleDataLoader`` 对象并返回该对象。 详见:class:`~fastNLP.core.dataloaders.PaddleDataLoader`。 + * 当 ds_or_db 为 ``DataSet`` 时,``prepare_paddle_dataloader`` 会将除了 ``non_train_batch_size`` 和 ``non_train_sampler`` 以外的参数来 + 帮你实例化一个 ``PaddleDataLoader`` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.PaddleDataLoader`。 * 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_paddle_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value - 来创建不同的 ``PaddleDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_Paddle_dataloader`` 默认该 value 为 train 数据集, - 会将 batch_size 和 sampler 作为参数,其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。 - 最终根据 ``key: PaddleDataLoader`` 组成 ``Dict[key, PaddleDataLoader]`` 的字典返回。 + 来创建不同的 ``PaddleDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_Paddle_dataloader`` 默认该 value 为训练数据集, + 会将 ``batch_size`` 和 ``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。 + 最终根据 ``key: PaddleDataLoader`` 组成 ``Dict[key, PaddleDataLoader]`` 的字典返回。 * 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_paddle_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的 - ``PaddleDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_paddle_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, - 其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: PaddleDataLoader`` 组成 - ``Dict[key, PaddleDataLoader]`` 的字典返回。 + ``PaddleDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_paddle_dataloader`` 默认该 value 为训练数据集,会将 ``batch_size`` 和 ``sampler`` 作为参数, + 其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。最终根据 ``key: PaddleDataLoader`` 组成 + ``Dict[key, PaddleDataLoader]`` 的字典返回。 :param ds_or_db: 可以有以下三种取值, - * ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 - * ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 - * ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为:class:`~fastNLP.TorchDataLoader` + * ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; + * ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; + * ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为 :class:`~fastNLP.TorchDataLoader`; - :param feed_list: (list(Tensor)|tuple(Tensor)): feed Tensor list. - 这个张量能被 :code:`paddle.static.data()` 创建。 如果:attr:`return_list` 是 ``False``, 那么 :attr:`feed_list` - 应该被设置。 默认为 ``None `` - :param places: (list(Place)|tuple(Place)|list(str)|optional): 将数据放进的一个 list 的 place。 :attr:`places` 能为 None. - 如果 :attr:`places` 为 None, 默认放在 CPUPlace 或者 CUDAPlace(0) 设备上。 如果 ``places`` 是一个 list 类型的 字符串, 那么字符串 - 可以是 ``cpu`` , ``gpu:x`` 或者 ``gpu_pinned`` , 其中 ``x`` 是 gpu 的下标。 + :param feed_list: feed Tensor list. + 这个张量能被 ``paddle.static.data`` 创建。 如果 :attr:`return_list` 是 ``False``, 那么 :attr:`feed_list` + 应该被设置。 默认为 ``None `` 。 + :param places: 将数据放进的一个 list 的 place。 :attr:`places` 能为 None. + 如果 :attr:`places` 为 None, 默认放在 CPUPlace 或者 CUDAPlace(0) 设备上。 如果 ``places`` 是一个 list 类型的 字符串, 那么字符串 + 可以是 ``cpu`` , ``gpu:x`` 或者 ``gpu_pinned`` , 其中 ``x`` 是 gpu 的下标。 :param return_list: 每个设备上的返回值是否为以列表形式显示。 如果 :attr:`return_list=False`, - 每个设备上的返回值值为 str -> Tensor 的 dict, 其中 dict 的 key 为每个 fed Tensors 的名字。 - 如果 :attr:`return_list=True`, 每个设备上的返回值值为 list(Tensor)。 :attr:`return_list` 只能在动态图情况下设置为 ``True`` . - 默认值为 ``True`` 。 + 每个设备上的返回值值为 str -> Tensor 的 dict, 其中 dict 的 key 为每个 fed Tensors 的名字。 + 如果 :attr:`return_list=True`, 每个设备上的返回值值为 list(Tensor)。 :attr:`return_list` 只能在动态图情况下设置为 ``True`` . + 默认值为 ``True`` 。 :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 - dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, shuffle 参数均失效。 + dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``shuffle`` 参数均失效。 :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 - :param shuffle: 是否将数据打乱,若``shuffle=True``则会将dataset打乱;若否则什么也不做。 + :param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , + 其它的为 False 。 :param drop_last: 当 ``drop_last=True`` 时,``PaddleDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; - 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 + 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. - * callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, - ``PaddleDataLoader`` 调用默认的 Paddle 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 - :class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 + * callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, + ``PaddleDataLoader`` 调用默认的 Paddle 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 + :class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 * callate_fn 为 ``'auto'`` 时,``PaddleDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 - 此时可以配套使用 ``PaddleDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 - * `collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 - dataset 的一条数据;该 Callable 函数还应当返回一个对象。 - - :param num_workers: 当 ``num_workers > 0`` 时, ``PaddleDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 - 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 - :param use_buffer_reader: 是否开启 buffer_reader 。如果 `use_buffer_reader=True`` ,那么 ``PaddleDataLoader` `会异步的预取下一个 batch 的 - 数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是 ``True``。 + 此时可以配套使用 ``PaddleDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 + * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 + dataset 的一条数据;该 Callable 函数还应当返回一个对象。 + + :param num_workers: 当 ``num_workers > 0`` 时, ``PaddleDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 + 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 + :param use_buffer_reader: 是否开启 buffer_reader 。如果 ``use_buffer_reader=True`` ,那么 ``PaddleDataLoader`` 会异步地预取下一个 batch 的 + 数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是 ``True``。 :param use_shared_memory: 是否使用共享内存。当 ``use_shared_memory=True`` 时,将采用共享内存来加快将数据放进进程队列。建议仅当计算机上的 - 共享空间足够大时。(例如 Linux 上的 /dev/shm/ 空间足够大)共享内存仅在多进程模式( num_workers>0 )下生效。 + 共享空间足够大时。(例如 Linux 上的 /dev/shm/ 空间足够大)共享内存仅在多进程模式( ``num_workers>0`` )下生效。 :param timeout: 从子进程的输出队列获取数据的超时值 :param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 :param persistent_workers: 如果其为 ``True``, ``PaddleDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` @@ -324,7 +326,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, batch_sampler=batch_sampler, batch_size=batch_size, - shuffle=shuffle, + shuffle=True if shuffle is None else shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, @@ -335,7 +337,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, return_list=return_list, batch_sampler=batch_sampler, batch_size=non_train_batch_size if non_train_batch_size else batch_size, - shuffle=shuffle, + shuffle=False if shuffle is None else shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, @@ -348,7 +350,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, for name, ds in ds_or_db.items(): if 'train' in name: dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, - batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, + batch_sampler=batch_sampler, batch_size=batch_size, + shuffle=False if shuffle is None else shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, timeout=timeout, worker_init_fn=worker_init_fn, @@ -357,7 +360,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, batch_sampler=batch_sampler, batch_size=non_train_batch_size if non_train_batch_size else batch_size, - shuffle=shuffle, + shuffle=False if shuffle is None else shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, timeout=timeout, worker_init_fn=worker_init_fn, @@ -367,7 +370,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, elif isinstance(ds_or_db, HasLenGetitemType): dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, - batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, + batch_sampler=batch_sampler, batch_size=batch_size, + shuffle=False if shuffle is None else shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) diff --git a/fastNLP/core/dataloaders/prepare_dataloader.py b/fastNLP/core/dataloaders/prepare_dataloader.py index 5f469f2b..65b739aa 100644 --- a/fastNLP/core/dataloaders/prepare_dataloader.py +++ b/fastNLP/core/dataloaders/prepare_dataloader.py @@ -9,41 +9,44 @@ import sys from .torch_dataloader import prepare_torch_dataloader from .paddle_dataloader import prepare_paddle_dataloader from .jittor_dataloader import prepare_jittor_dataloader +from .oneflow_dataloader import prepare_oneflow_dataloader from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS from ..log import logger -def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False, +def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = None, drop_last: bool = False, collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0, backend: str = 'auto'): """ 自动创建合适的 ``DataLoader`` 对象。例如,检测当当前环境是 ``torch`` 的,则返回 ``TorchDataLoader`` , 是 ``paddle`` 的则 返回 ``PaddleDataLoader`` 。如果有更多需要定制的参数,请直接使用对应的 ``prepare`` 函数,例如 - :func:`~fastNLP.prepare_torch_dataloader` 或 :func:`~fastNLP.prepare_paddle_dataloader` 等。 + :func:`~fastNLP.core.dataloaders.prepare_torch_dataloader` 或 :func:`~fastNLP.core.dataloaders.prepare_paddle_dataloader` 等。 :param dataset: 实现 __getitem__() 和 __len__() 的对象;或这种对象的序列;或字典。 * 为单个数据集对象时,返回一个 DataLoader 。 * 为数据集对象序列时,返回一个序列的 DataLoader 。 - * 为字典型 或 :class:`~fastNLP.io.DataBundle` 数据时,返回 `Dict` 类型的数据。 + * 为字典型 或 :class:`~fastNLP.io.DataBundle` 数据时,返回 :class:`Dict` 类型的数据。 :param batch_size: 批次大小。 - :param shuffle: 是否打乱数据集。 - :param drop_last: 当最后一个 batch 不足 batch_size 数量的是否,是否丢弃。 + :param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , + 其它的为 False 。 + :param drop_last: 当最后一个 batch 不足 ``batch_size`` 数量的是否,是否丢弃。 :param collate_fn: 用于处理一个 batch 的函数,一般包括 padding 和转为 tensor。有以下三种取值: * 为 ``auto`` 时,使用 :class:`~fastNLP.Collator` 进行 padding 和 转tensor 。 - * 为 ``Callable`` 时,应当接受一个 ``batch`` 的数据作为参数,同时输出一个对象 。 + * 为 :class:`Callable` 时,应当接受一个 ``batch`` 的数据作为参数,同时输出一个对象 。 * 为 ``None`` 时,使用各个框架的 DataLoader 的默认 ``collate_fn`` 。 :param num_workers: 使用多少进程进行数据的 fetch 。 - :param backend: 当前支持 ``["auto", "torch", "paddle", "jittor"]`` 四种类型。 + :param backend: 当前支持 ``["auto", "torch", "paddle", "jittor", "oneflow"]`` 四种类型。 - * 为 ``auto`` 时,首先(1) 根据环境变量 "FASTNLP_BACKEND" 进行判断;如果没有设置则,(2)通过当前 + * 为 ``auto`` 时,首先根据环境变量 ``"FASTNLP_BACKEND"`` 进行判断;如果没有设置则通过当前 ``sys.modules`` 中已经 import 的 ``backend`` 进行判定。如果以上均无法判定,则报错。如果找到了 ``backend`` ,则按照下述的方式处理。 - * 为 ``torch`` 时,使用 :func:`~fastNLP.prepare_torch_dataloader` 。 - * 为 ``paddle`` 时,使用 :func:`~fastNLP.prepare_paddle_dataloader` 。 - * 为 ``jittor`` 时,使用 :func:`~fastNLP.prepare_jittor_dataloader` 。 + * 为 ``torch`` 时,使用 :func:`~fastNLP.core.dataloaders.prepare_torch_dataloader` 。 + * 为 ``paddle`` 时,使用 :func:`~fastNLP.core.dataloaders.prepare_paddle_dataloader` 。 + * 为 ``jittor`` 时,使用 :func:`~fastNLP.core.dataloaders.prepare_jittor_dataloader` 。 + * 为 ``oneflow`` 时,使用 :func:`~fastNLP.core.dataloaders.prepare_oneflow_dataloader` 。 :return """ @@ -60,6 +63,10 @@ def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, dro prepare_jittor_dataloader(ds_or_db=dataset, sampler=None, collate_fn=collate_fn, num_workers=num_workers, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) + elif backend == 'oneflow': + return prepare_oneflow_dataloader(ds_or_db=dataset, batch_sampler=None, collate_fn=collate_fn, + num_workers=num_workers, shuffle=shuffle, sampler=None, + batch_size=batch_size) else: raise ValueError(f"Currently we do not support backend:{backend}.") diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 1ee33c66..e73bd818 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -58,9 +58,41 @@ class TorchDataLoader(DataLoader): * callate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。 此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 * callate_fn 为 ``None`` 时, ``TorchDataLoadr`` 默认使用 torch DataLoader 自带的 collate_fn - * collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 + * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 dataset 的一条数据;该 Callable 函数还应当返回一个对象。 + :param dataset: 实现了 __getitem__() 和 __len__() 的对象。 + :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 + :param non_train_batch_size: 非训练数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 ``None`` 有效。 + :param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , + 其它的为 False 。 + :param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , + 默认为 ``None``, 当其不为 ``None`` 时, shuffle 参数无效。 + :param non_train_sampler: 非训练数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , + 默认为None, 当其不为 None 时, shuffle 参数无效。 + :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 + dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``sampler``, ``shuffle`` 参数均失效。 + :param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 + 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 + :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. + + * callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, + ``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 + :class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 + * callate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 + 此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 + * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 + dataset 的一条数据;该 Callable 函数还应当返回一个对象。 + + :param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。 + :param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; + 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 + :param timeout: 子进程的输出队列获取数据的超时值 + :param worker_init_fn: init 函数,如果不设置为 ``None``,则将会在每个子进程初始化时调用该函数。 + :param multiprocessing_context: 多进程的上下文环境 + :param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` + :param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2`` 意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . + :param persistent_workers: 如果其为 ``True``, ``TorchDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` """ def __init__(self, dataset, batch_size: int = 16, @@ -70,44 +102,16 @@ class TorchDataLoader(DataLoader): timeout: float = 0, worker_init_fn: Optional[Callable] = None, multiprocessing_context=None, generator=None, prefetch_factor: int = 2, persistent_workers: bool = False, **kwargs) -> None: - """ - - :param dataset: 实现了 __getitem__() 和 __len__() 的对象。 - :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 - :param shuffle: 是否打乱数据集, 默认为 ``False``。 - :param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , - 默认为None, 当其不为 None 时, shuffle 参数无效。 - :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 - dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, sampler, shuffle 参数均失效。 - :param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 - 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 - :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. - - * callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, - ``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 ``default_collate_fn`` 作为 callate_fn 的默认值, 其无法处理 - :class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 - * callate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 - 此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 - * collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 - dataset 的一条数据;该 Callable 函数还应当返回一个对象。 - - :param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。 - :param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; - 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 - :param timeout: 子进程的输出队列获取数据的超时值 - :param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 - :param multiprocessing_context: 多进程的上下文环境 - :param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` - :param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2``意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . - :param persistent_workers: 如果其为 ``True``, ``TorchDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` - """ if isinstance(dataset, DataSet) and collate_fn is None: raise ValueError("When use FastNLP DataSet, collate_fn must be not None") if not isinstance(dataset, _FDataSet): dataset = _FDataSet(dataset) + if num_workers>0 and multiprocessing_context is None: + multiprocessing_context = 'fork' # 这里默认使用fork的方式来启动多进程 + if batch_sampler is not None: batch_size = 1 shuffle = False @@ -150,20 +154,20 @@ class TorchDataLoader(DataLoader): """ 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 - :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); - 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 - 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 - :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 - field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 - 无意义。 - :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 - :param backend: 可选['raw', 'numpy', 'torch', 'torch', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, - torch.Tensor, torch.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 - :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 - batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch - 形式,输出将被直接作为结果输出。 - :return: 返回 Collator + :param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 + field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; + 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 + 如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, + 该值无意义。 + :param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 + :param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, + :class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 + 若 ``pad_val`` 为 ``None`` ,该值无意义 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 + batch 形式。 collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 + :return: 返回使用的 collator """ collator = self._get_collator() if isinstance(collator, Collator): @@ -187,15 +191,14 @@ class TorchDataLoader(DataLoader): def set_ignore(self, *field_names) -> Collator: """ - 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 - Example:: + 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略:: - collator.set_ignore('field1', 'field2') + dataloader.set_ignore('field1', 'field2') - :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 - __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 - :return: 返回 Collator 自身 + :param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 + field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; + 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 + :return: 返回使用的 collator """ collator = self._get_collator() if isinstance(collator, Collator): @@ -215,7 +218,7 @@ class TorchDataLoader(DataLoader): def prepare_torch_dataloader(ds_or_db, batch_size: int = 16, - shuffle: bool = False, + shuffle: bool = None, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', @@ -227,55 +230,56 @@ def prepare_torch_dataloader(ds_or_db, non_train_batch_size: int = None) \ -> Union[TorchDataLoader, Dict[str, TorchDataLoader]]: """ - ``prepare_torch_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``TorchDataloader``对象, 详见 :class:`~fastNLP.TorchDataLoader`。 + ``prepare_torch_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``TorchDataloader`` 对象, 详见 :class:`~fastNLP.TorchDataLoader`。 根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: - * 当 ds_or_db 为 ``DataSet``时,``prepare_torch_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 - 帮你实例化一个 ``TorchDataLoader`` 对象并返回该对象。 详见:class:`~fastNLP.core.dataloaders.TorchDataLoader`。 + * 当 ds_or_db 为 ``DataSet`` 时,``prepare_torch_dataloader`` 会将使用的除了 ``non_train_batch_size`` 和 ``non_train_sampler`` 以外的参数来 + 帮你实例化一个 ``TorchDataLoader`` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.TorchDataLoader`。 * 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_torch_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value - 来创建不同的 ``TorchDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_torch_dataloader`` 默认该 value 为 train 数据集, - 会将 batch_size 和 sampler 作为参数,其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。 - 最终根据 ``key: TorchDataLoader`` 组成 ``Dict[key, TorchDataLoader]`` 的字典返回。 + 来创建不同的 ``TorchDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_torch_dataloader`` 默认该 value 为训练数据集, + 会将 ``batch_size`` 和 ``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。 + 最终根据 ``key: TorchDataLoader`` 组成 ``Dict[key, TorchDataLoader]`` 的字典返回。 * 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_torch_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的 - ``TorchDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_torch_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, - 其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: TorchDataLoader`` 组成 - ``Dict[key, TorchDataLoader]`` 的字典返回。 + ``TorchDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_torch_dataloader`` 默认该 value 为训练数据集,会将 ``batch_size`` 和 ``sampler`` 作为参数, + 其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。最终根据 ``key: TorchDataLoader`` 组成 + ``Dict[key, TorchDataLoader]`` 的字典返回。 :param ds_or_db: 可以有以下三种取值, - * ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 - * ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 - * ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为:class:`~fastNLP.TorchDataLoader` + * ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; + * ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; + * ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为 :class:`~fastNLP.TorchDataLoader`; :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 - :param non_train_batch_size: 非 'train' 数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 - :param shuffle: 是否打乱数据集, 默认为 ``False``。 + :param non_train_batch_size: 非训练数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 ``None`` 有效。 + :param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , + 其它的为 False 。 :param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , - 默认为None, 当其不为 None 时, shuffle 参数无效。 - :param non_train_sampler: 非 'train' 数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , + 默认为 ``None``, 当其不为 ``None`` 时, shuffle 参数无效。 + :param non_train_sampler: 非训练数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , 默认为None, 当其不为 None 时, shuffle 参数无效。 :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 - dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, sampler, shuffle 参数均失效。 - :param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 + dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``sampler``, ``shuffle`` 参数均失效。 + :param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. - * callate_fn 为 'None' 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, - ``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 - :class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 - * callate_fn 为 ``'auto'`` 时,`TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 - 此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 - * `collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 - dataset 的一条数据;该 Callable 函数还应当返回一个对象。 + * callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, + ``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 + :class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 + * callate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 + 此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 + * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 + dataset 的一条数据;该 Callable 函数还应当返回一个对象。 :param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。 :param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 :param timeout: 子进程的输出队列获取数据的超时值 - :param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 + :param worker_init_fn: init 函数,如果不设置为 ``None``,则将会在每个子进程初始化时调用该函数。 :param multiprocessing_context: 多进程的上下文环境 - :param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个``base_seed`` - :param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2``意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . + :param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` + :param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2`` 意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . :param persistent_workers: 如果其为 ``True``, ``TorchDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` """ @@ -287,7 +291,7 @@ def prepare_torch_dataloader(ds_or_db, for name, ds in ds_or_db.iter_datasets(): if 'train' in name: dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, - shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, + shuffle=True if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, @@ -297,7 +301,7 @@ def prepare_torch_dataloader(ds_or_db, else: dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size, - shuffle=shuffle, + shuffle=False if shuffle is None else shuffle, sampler=non_train_sampler if non_train_sampler else sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, @@ -313,7 +317,7 @@ def prepare_torch_dataloader(ds_or_db, for name, ds in ds_or_db.items(): if 'train' in name: dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, - shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, + shuffle=True if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, @@ -323,7 +327,7 @@ def prepare_torch_dataloader(ds_or_db, else: dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size, - shuffle=shuffle, + shuffle=False if shuffle is None else shuffle, sampler=non_train_sampler if non_train_sampler else sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, @@ -337,7 +341,7 @@ def prepare_torch_dataloader(ds_or_db, elif isinstance(ds_or_db, HasLenGetitemType): dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, - shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, + shuffle=False if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, diff --git a/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py b/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py index 6b9b9f4d..bd5cd176 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py +++ b/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py @@ -101,6 +101,19 @@ class MixDataLoader(DataLoader): """ 针对以下四种情况提供的 ``MixDataLoader``, 目前只支持 ``torch`` 框架的版本, 其中 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: + * 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 ``datasets`` 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 + 接一个的 sample 完所有数据。 + * 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 ``datasets`` 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机 sample + 混合数据集 datasets 的数据组成一个 batch 序列返回。 + * 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 ``datasets`` 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回, + 再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询的采样完。 + * 当 mode 为 ``"Sampler"`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int]; + 且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0 0`` 时, ``MixDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快数据处理速度,但同时 + 也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 + :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 ``None`` 有效。 且 datasets 上所有 dataset 的 batch_size 一致。 + :param drop_last: 当 ``drop_last=True`` 时,``MixDataLoader`` 会扔掉 datasets 中 每个 dataset 最后一个长度小于 ``batch_size`` 的 batch 数据; + 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 + :param ds_ratio: ``ds_ratio`` 是控制 datasets 怎么组成一个混合大数据集的重要参数, 其取值为 ``[None, 'truncate_to_least', 'pad_to_most', List[float], Dict[str, float]]``: + + * ds_ratio 为 ``None``, datasets 数据集序列或字典不进行数据扩充处理。 + * ds_ratio 为 ``'truncate_to_least'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最断长度 ``mix_len``, 其他数据集会被切断 + 到最短长度 ``mix_len``。这种切断不是物理上切断,``MixDataLoader`` 会根据 sampler 不同来采样数据集到指定的最短长度 ``mix_len``。 + * ds_ratio 为 ``'pad_to_most'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最大长度 ``max_len``, 其他其他数据集会扩充 + 到最大长度 ``mix_len``。这种扩充不是物理上扩充, ``MixDataLoader`` 会根据 sampler 不同来重采样 dataset 到指定的最大长度 ``max_len``。 + * ds_ratio 为 ``Dict[str, float]`` 时, datasets 类型也必须为 ``Dict[str, DataSet]``, 其 key 一一对应。 ds_ratio 的 value 是任意大于 0 的浮点数, + 代表着 datasets 的 value 数据进行扩充或者缩减的倍数。 + """ def __init__(self, datasets: Dict = None, mode: str = 'sequential', @@ -119,55 +166,6 @@ class MixDataLoader(DataLoader): num_workers: int = 0, batch_size: int = 16, drop_last=False, ds_ratio: Union[None, str, Dict[str, float]] = None, pin_memory: bool = False) -> None: - """ - - :param datasets: 实现了 __getitem__() 和 __len__() 对象的序列或者字典。 - :param mode: mode 控制 ``MixDataLoader`` 运行模式。 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: - - * 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 - 接一个的 sample 完所有数据。 - * 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机sample - 混合数据集 datasets 的数据组成一个 batch 序列返回。 - * 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 datasets 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回, - 再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询的采样完。 - * 当 mode 为 ``"Sampler"`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int]; - 且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0 0`` 时, ``MixDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快数据处理速度,但同时 - 也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 - :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 且 datasets 上所有 dataset 的 batch_size 一致。 - :param drop_last: 当 ``drop_last=True`` 时,``MixDataLoader`` 会扔掉 datasets 中 每个 dataset 最后一个长度小于 ``batch_size`` 的 batch 数据; - 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 - :param ds_ratio: ``ds_ratio`` 是控制 datasets 怎么组成一个混合大数据集的重要参数, 其取值为 ``[None, 'truncate_to_least', 'pad_to_most', List[float], Dict[str, float]]``: - - * ds_ratio 为 ``None``, datasets 数据集序列或字典不进行数据扩充处理。 - * ds_ratio 为 ``'truncate_to_least'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最断长度 ``mix_len``, 其他数据集会被切断 - 到最短长度 ``mix_len``。这种切断不是物理上切断,``MixDataLoader`` 会根据 sampler 不同来采样数据集到指定的最短长度 ``mix_len``。 - * ds_ratio 为 ``'pad_to_most'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最大长度 ``max_len``, 其他其他数据集会扩充 - 到最大长度 ``mix_len``。这种扩充不是物理上扩充, ``MixDataLoader`` 会根据 sampler 不同来重采样 dataset 到指定的最大长度``max_len``。 - * ds_ratio 为 ``Dict[str, float]`` 时, datasets 类型也必须为 ``Dict[str, DataSet]``, 其 key 一一对应。 ds_ratio 的 value 是任意大于 0 的浮点数, - 代表着 datasets 的 value 数据进行扩充或者缩减的倍数。 - """ # sampler 为 dict,则判断是否与 datasets 的 key 相同 if isinstance(sampler, Dict): for key in datasets.keys(): diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py index d905101f..4a648b99 100644 --- a/fastNLP/core/dataloaders/utils.py +++ b/fastNLP/core/dataloaders/utils.py @@ -1,4 +1,5 @@ -from typing import Callable, Any, Union +import os +from typing import Callable, Any, Union, Sequence from abc import ABC import inspect import ast @@ -6,13 +7,14 @@ import ast from ..log import logger from ..utils.cache_results import get_func_calls, truncate_start_blanks __all__ = [ - "indice_collate_wrapper" + "indice_collate_wrapper", + "OverfitDataLoader" ] def indice_collate_wrapper(func:Callable): """ - 其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 + 其功能是封装一层 collate_fn,将 dataset 取到的 tuple 数据分离开,将 idx 打包为 indices。 :param func: 需要修饰的函数 :return: @@ -111,6 +113,40 @@ class HasLenGetitemType(ABC): return NotImplemented +class OverfitDataLoader: + """ + 实现一个简单的迭代器来模拟实际的 dataloader,从给定的 ``dataloader`` 中取出部分数据,来让 Trainer 实现 overfit 的功能; + """ + + def __init__(self, dataloader, overfit_batches: int, batches=None): + # batches 参数是给重新初始化dataloader使用的 + self.dataloader = dataloader # 需要将实际的 dataloader 挂载到该对象上,从而应付一些对于实际的 dataloader 的操作; + if batches is None: + self.batches = [] + self.overfit_batches = int(overfit_batches) + + if self.overfit_batches > len(dataloader): + logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.") + + for idx, batch in enumerate(dataloader): + if idx < self.overfit_batches or self.overfit_batches <= -1: + self.batches.append(batch) + else: + assert isinstance(batches, list) + self.batches = batches + + def __len__(self): + return len(self.batches) + + def __iter__(self): + for batch in self.batches: + yield batch + + def __getattr__(self, item): + return getattr(self.dataloader, item) + + + if __name__ == '__main__': def demo(*args, **kwargs): pass diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index fff8b5c2..f91bc930 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -1,7 +1,7 @@ r""" :class:`~fastNLP.core.dataset.DataSet` 是 fastNLP 中用于承载数据的容器。可以将 DataSet 看做是一个表格, -每一行是一个 sample (在 fastNLP 中被称为 :mod:`~fastNLP.core.instance` ), -每一列是一个 feature (在 fastNLP 中称为 :mod:`~fastNLP.core.field` )。 +每一行是一个 sample (在 fastNLP 中被称为 :mod:`~fastNLP.core.dataset.instance` ), +每一列是一个 feature (在 fastNLP 中称为 :mod:`~fastNLP.core.dataset.field` )。 .. csv-table:: Following is a demo layout of DataSet :header: "sentence", "words", "seq_len" @@ -11,7 +11,7 @@ r""" "Third instance .", "[Third, instance, .]", 3 "...", "[...]", "..." -在 fastNLP 内部每一行是一个 :class:`~fastNLP.Instance` 对象; 每一列是一个 :class:`~fastNLP.FieldArray` 对象。 +在 fastNLP 内部每一行是一个 :class:`~fastNLP.core.dataset.Instance` 对象; 每一列是一个 :class:`~fastNLP.core.dataset.FieldArray` 对象。 ---------------------------- 1.DataSet的创建 @@ -65,7 +65,7 @@ r""" 2.DataSet 与预处理 -------------------------------------- -常见的预处理有如下几种 +常见的预处理有如下几种: 2.1 从某个文本文件读取内容 -------------------------------------- @@ -97,10 +97,10 @@ r""" # 将句子分成单词形式, 详见DataSet.apply()方法, 可以开启多进程来加快处理, 也可以更改展示的bar,目前支持 ``['rich', 'tqdm', None]``, # 详细内容可以见 :class:`~fastNLP.core.dataset.DataSet`, 需要注意的时匿名函数不支持多进程 dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words', - progress_des='Main',progress_bar='rich') + progress_des='Main',progress_bar='rich') # 或使用DataSet.apply_field() dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words', - progress_des='Main',progress_bar='rich') + progress_des='Main',progress_bar='rich') # 除了匿名函数,也可以定义函数传递进去 def get_words(instance): sentence = instance['sentence'] @@ -145,8 +145,8 @@ r""" # DataSet 的长度 len(dataset) - """ + __all__ = [ "DataSet", "ApplyResultException" @@ -255,34 +255,31 @@ def _multi_proc(ds, _apply_field, func, counter, queue): class DataSet: r""" - fastNLP的数据容器,详细的使用方法见文档 :mod:`fastNLP.core.dataset` - """ + fastNLP的数据容器。 - def __init__(self, data: Union[List[Instance], Dict[str, List[Any]], None] = None): - r""" - 初始化 ``DataSet``, fastNLP的 DataSet 是 key-value 存储形式, 目前支持两种初始化方式,输入 data 分别为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 和 - ``Dict[str, List[Any]]``。 - - * 当 data 为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。 - Instance 详见 :class:`~fastNLP.core.dataset.Instance` 。 - * 当 data 为 ``Dict[str, List[Any]] 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。 + Example:: - :param data: 初始化的内容, 其只能为两种类型,分别为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 和 - ``Dict[str, List[Any]]``。 + from fastNLP.core.dataset import DataSet, Instance + data = {'x': [[1, 0, 1], [0, 1, 1], 'y': [0, 1]} + data1 = [Instance(x=[1,0,1],y=0), Instance(x=[0,1,1],y=1)] + ds = DataSet(data) + ds = DataSet(data1) - * 当 data 为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。 - Instance 详见 :class:`~fastNLP.core.dataset.Instance` 。 - * 当 data 为 ``Dict[str, List[Any]] 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。 + fastNLP的 DataSet 是 key-value 存储形式, 目前支持两种初始化方式,输入 data 分别为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 和 + ``Dict[str, List[Any]]``。 - Example:: + * 当 data 为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。 + Instance 详见 :class:`~fastNLP.core.dataset.Instance` 。 + * 当 data 为 ``Dict[str, List[Any]]`` 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。 - from fastNLP.core.dataset import DataSet, Instance - data = {'x': [[1, 0, 1], [0, 1, 1], 'y': [0, 1]} - data1 = [Instance(x=[1,0,1],y=0), Instance(x=[0,1,1],y=1)] - ds = DataSet(data) - ds = DataSet(data1) + :param data: 初始化的内容,其只能为两种类型,分别为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 和 + ``Dict[str, List[Any]]``。 - """ + * 当 data 为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。 + Instance 详见 :class:`~fastNLP.core.dataset.Instance` 。 + * 当 data 为 ``Dict[str, List[Any]] 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。 + """ + def __init__(self, data: Union[List[Instance], Dict[str, List[Any]], None] = None): self.field_arrays = {} self._collator = Collator() if data is not None: @@ -429,10 +426,9 @@ class DataSet: def append(self, instance: Instance) -> None: r""" - 将一个 instance 对象 append 到 DataSet 后面。详见 :class:`~fastNLP.Instance` - - :param instance: 若 DataSet 不为空,则 instance 应该拥有和 DataSet 完全一样的 field。 + 将一个 ``instance`` 对象 append 到 DataSet 后面。详见 :class:`~fastNLP.core.dataset.Instance` + :param instance: 若 DataSet 不为空,则 instance 应该拥有和 DataSet 完全一样的 field; """ if len(self.field_arrays) == 0: # DataSet has no field yet @@ -445,7 +441,7 @@ class DataSet: "DataSet object has {} fields, but attempt to append an Instance object with {} fields." .format(len(self.field_arrays), len(instance.fields))) for name, field in instance.items(): - assert name in self.field_arrays + assert name in self.field_arrays, f'Field:`{name}` is not found in {self.field_arrays.keys()}' try: self.field_arrays[name].append(field) except Exception as e: @@ -454,10 +450,10 @@ class DataSet: def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None: r""" - 将 fieldarray 添加到 DataSet 中. + 将 ``fieldarray`` 添加到 DataSet 中. - :param field_name: 新加入的 field 的名称 - :param fieldarray: 需要加入 DataSet 的 field 的内容, 详见 :class:`~fastNLP.core.dataset.FieldArray` + :param field_name: 新加入的 field 的名称; + :param fieldarray: 需要加入 DataSet 的 field 的内容, 详见 :class:`~fastNLP.core.dataset.FieldArray` ; :return: """ if not isinstance(fieldarray, FieldArray): @@ -472,8 +468,8 @@ class DataSet: r""" 新增一个 field, 需要注意的是 fields 的长度跟 DataSet 长度一致 - :param field_name: 新增的 field 的名称 - :param fields: 需要新增的 field 的内容 + :param field_name: 新增的 field 的名称; + :param fields: 需要新增的 field 的内容; """ if len(self.field_arrays) != 0: @@ -484,9 +480,9 @@ class DataSet: def delete_instance(self, index: int): r""" - 删除第 ``index `` 个 Instance + 删除第 ``index`` 个 Instance - :param index: 需要删除的 instanc e的 index,序号从 `0` 开始。 + :param index: 需要删除的 instance 的 index,序号从 `0` 开始。 """ assert isinstance(index, int), "Only integer supported." if len(self) <= index: @@ -500,9 +496,9 @@ class DataSet: def delete_field(self, field_name: str): r""" - 删除名为 field_name 的 field + 删除名为 ``field_name`` 的 field - :param field_name: 需要删除的 field 的名称. + :param field_name: 需要删除的 field 的名称; """ if self.has_field(field_name): self.field_arrays.pop(field_name) @@ -512,11 +508,11 @@ class DataSet: def copy_field(self, field_name: str, new_field_name: str): r""" - 深度 copy 名为 field_name 的 field 到 new_field_name + 深度 copy 名为 ``field_name`` 的 field 到 ``new_field_name`` - :param field_name: 需要 copy 的 field。 - :param new_field_name: copy 生成的 field 名称 - :return: self + :param field_name: 需要 copy 的 field; + :param new_field_name: copy 生成的 field 名称; + :return: 数据集自身; """ if not self.has_field(field_name): raise KeyError(f"Field:{field_name} not found in DataSet.") @@ -527,10 +523,10 @@ class DataSet: def has_field(self, field_name: str) -> bool: r""" - 判断 DataSet 中是否有名为 field_name 这个 field + 判断 DataSet 中是否有名为 ``field_name`` 这个 field - :param field_name: field 的名称 - :return: 表示是否有名为 field_name 这个 field + :param field_name: field 的名称; + :return: 表示是否有名为 ``field_name`` 这个 field; """ if isinstance(field_name, str): return field_name in self.field_arrays @@ -538,10 +534,10 @@ class DataSet: def get_field(self, field_name: str) -> FieldArray: r""" - 获取 field_name 这个 field + 获取名为 ``field_name`` 的 field - :param field_name: field 的名称 - :return: :class:`~fastNLP.FieldArray` + :param field_name: field 的名称; + :return: 一个 :class:`~fastNLP.core.dataset.FieldArray` 对象; """ if field_name not in self.field_arrays: raise KeyError("Field name {} not found in DataSet".format(field_name)) @@ -549,17 +545,13 @@ class DataSet: def get_all_fields(self) -> dict: r""" - 返回一个 dict,key 为 field_name, value为对应的 :class:`~fastNLP.FieldArray` - - :return: 返回如上所述的字典 + :return: 一个 dict,key 为 field_name, value为对应的 :class:`~fastNLP.core.dataset.FieldArray` 对象。 """ return self.field_arrays def get_field_names(self) -> list: r""" - 返回一个 list,包含所有 field 的名字 - - :return: 返回如上所述的列表 + :return: 一个 list,包含所有 field 的名字 """ return sorted(self.field_arrays.keys()) @@ -575,8 +567,8 @@ class DataSet: r""" 将某个 field 重新命名. - :param field_name: 原来的 field 名称。 - :param new_field_name: 修改为 new_name。 + :param field_name: 原来的 field 名称; + :param new_field_name: 修改为 new_name; """ if field_name in self.field_arrays: self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) @@ -589,13 +581,13 @@ class DataSet: new_field_name: str = None, num_proc: int = 0, progress_desc: str = None, progress_bar: str = 'rich'): r""" - 将 :class:`DataSet` 每个 ``instance`` 中为 ``field_name`` 的 ``field`` 传给函数 ``func``,并写入到 ``new_field_name`` + 将 :class:`DataSet` 每个 ``instance`` 中为 ``field_name`` 的 field 传给函数 ``func``,并写入到 ``new_field_name`` 中。 - :param field_name: 传入 ``func`` 的 ``field`` 名称; - :param func: 对指定 ``field`` 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 ``field`` 的内容; + :param func: 对指定 fiel` 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; + :param field_name: 传入 ``func`` 的 field 名称; :param new_field_name: 函数执行结果写入的 ``field`` 名称。该函数会将 ``func`` 返回的内容放入到 ``new_field_name`` 对 - 应的 ``field`` 中,注意如果名称与已有的 ``field`` 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 ``field`` ; + 应的 ``field`` 中,注意如果名称与已有的 field 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 field ; :param num_proc: 使用进程的数量。 .. note:: @@ -603,8 +595,8 @@ class DataSet: 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, ``func`` 函数中的打印将不会输出。 - :param progress_desc: 进度条的描述字符,默认为 ``Processing``; - :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 + :param progress_desc: 如果不为 ``None``,则会显示当前正在处理的进度条的名称; + :param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。 :return: 从函数 ``func`` 中得到的返回值; """ assert len(self) != 0, "Null DataSet cannot use apply_field()." @@ -625,26 +617,27 @@ class DataSet: modify_fields: bool = True, num_proc: int = 0, progress_desc: str = None, progress_bar: str = 'rich'): r""" - 将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的field 传给 func,并获取它的返回值。 - func 可以返回一个或多个 field 上的结果。 + 将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的 field 传给 ``func``,并获取它的返回值。 + ``func`` 可以返回一个或多个 field 上的结果。 .. note:: - ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 + ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`~fastNLP.core.dataset.DataSet.apply_more` 中关于 ``apply_more`` 与 ``apply`` 区别的介绍。 - :param field_name: 传入func的是哪个field。 - :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 - :param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True + :param func: 对指定 fiel` 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; + :param field_name: 传入 ``func`` 的 fiel` 名称; + :param new_field_name: 函数执行结果写入的 ``field`` 名称。该函数会将 ``func`` 返回的内容放入到 ``new_field_name`` 对 + 应的 ``field`` 中,注意如果名称与已有的 field 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 field ; :param num_proc: 使用进程的数量。 - + .. note:: - + 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, ``func`` 函数中的打印将不会输出。 - :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 - :param progress_desc: 当显示 progress_bar 时,显示当前正在处理的进度条描述字符 - :return Dict[str:Field]: 返回一个字典 + :param progress_desc: 如果不为 ``None``,则会显示当前正在处理的进度条的名称; + :param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。 + :return: 返回一个字典 """ assert len(self) != 0, "Null DataSet cannot use apply_field()." if not self.has_field(field_name=field_name): @@ -747,7 +740,7 @@ class DataSet: def apply_more(self, func: Callable = None, modify_fields: bool = True, num_proc: int = 0, progress_desc: str = '', progress_bar: str = 'rich'): r""" - 将 ``DataSet`` 中每个 ``Instance`` 传入到func中,并获取它的返回值。func可以返回一个或多个 field 上的结果。 + 将 ``DataSet`` 中每个 ``Instance`` 传入到 ``func`` 中,并获取它的返回值。``func``可以返回一个或多个 field 上的结果。 .. note:: ``apply_more`` 与 ``apply`` 的区别: @@ -767,9 +760,9 @@ class DataSet: 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, ``func`` 函数中的打印将不会输出。 - :param progress_desc: 当 progress_bar 不为 None 时,可以显示当前正在处理的进度条名称 - :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 - :return Dict[str:Field]: 返回一个字典 + :param progress_desc: 当 progress_bar 不为 ``None`` 时,可以显示当前正在处理的进度条名称 + :param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。 + :return: 返回一个字典 """ assert callable(func), "The func is not callable." assert len(self) != 0, "Null DataSet cannot use apply()." @@ -808,10 +801,11 @@ class DataSet: def apply(self, func: Callable = None, new_field_name: str = None, num_proc: int = 0, progress_bar: str = 'rich', progress_desc: str = ''): """ + 将 ``DataSet`` 中每个 ``Instance`` 传入到 ``func`` 中,并获取它的返回值。``func`` 仅能返回一个结果。 :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 - :param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 - 盖之前的field。如果为None则不创建新的field。 + :param new_field_name: 将 ``func`` 返回的内容放入到 ``new_field_name`` 这个 field中 ,如果名称与已有的 field 相同,则覆 + 盖之前的 field。如果为 ``None`` 则不创建新的 field。 :param num_proc: 使用进程的数量。 .. note:: @@ -819,8 +813,8 @@ class DataSet: 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, ``func`` 函数中的打印将不会输出。 - :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 - :param progress_desc: progress bar 显示的值,默认为空。 + :param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。 + :param progress_desc: 如果不为 ``None``,则会显示当前正在处理的进度条的名称。 """ assert callable(func), "The func you provide is not callable." assert len(self) != 0, "Null DataSet cannot use apply()." @@ -838,10 +832,10 @@ class DataSet: def add_seq_len(self, field_name: str, new_field_name='seq_len'): r""" - 将使用 len() 直接对 field_name 中每个元素作用,将其结果作为 sequence length, 并放入 seq_len 这个 field。 + 将使用 :func:`len` 直接对 ``field_name`` 中每个元素作用,将其结果作为 sequence length, 并放入 ``new_field_name`` 这个 field。 :param field_name: 需要处理的 field_name - :param new_field_name: str. 新的 field_name + :param new_field_name: 新的 field_name :return: """ if self.has_field(field_name=field_name): @@ -852,10 +846,10 @@ class DataSet: def drop(self, func: Callable, inplace=True): r""" - 删除某些 Instance。 需要注意的时func 接受一个 Instance ,返回 bool 值。返回值为 True 时, + 删除某些 Instance。 需要注意的是 ``func`` 接受一个 Instance ,返回 bool 值。返回值为 ``True`` 时, 该 Instance 会被移除或者不会包含在返回的 DataSet 中。 - :param func: 接受一个 Instance 作为参数,返回 bool 值。为 True 时删除该 instance + :param func: 接受一个 Instance 作为参数,返回 bool 值。为 ``True`` 时删除该 instance :param inplace: 是否在当前 DataSet 中直接删除 instance;如果为 False,将返回一个新的 DataSet。 :return: DataSet @@ -875,11 +869,11 @@ class DataSet: def split(self, ratio: float, shuffle=True): r""" - 将 DataSet 按照 ratio 的比例拆分,返回两个 DataSet + 将 DataSet 按照 ``ratio`` 的比例拆分,返回两个 DataSet - :param ratio: 0 1, f'DataSet with {len(self)} instance cannot be split.' assert isinstance(ratio, float) @@ -906,9 +900,9 @@ class DataSet: def save(self, path: str) -> None: r""" - 保存DataSet. + 保存 DataSet。 - :param path: 将DataSet存在哪个路径 + :param path: 保存路径; """ with open(path, 'wb') as f: pickle.dump(self, f) @@ -916,10 +910,10 @@ class DataSet: @staticmethod def load(path: str): r""" - 从保存的 DataSet pickle文件的路径中读取DataSet + 从保存的 DataSet pickle 文件的路径中读取 DataSet - :param path: 从哪里读取 DataSet - :return: 读取后的 :class:`~fastNLP.读取后的DataSet`。 + :param path: 读取路径; + :return: 读取出的 DataSet """ with open(path, 'rb') as f: d = pickle.load(f) @@ -928,16 +922,16 @@ class DataSet: def concat(self, dataset: 'DataSet', inplace:bool=True, field_mapping:Dict=None) -> 'DataSet': """ - 将当前 dataset 与输入的 dataset 结合成一个更大的 dataset,需要保证两个 dataset 都包含了相同的 field。结合后的 dataset - 的 field_name 和 _collator 以当前 dataset 为准。当 dataset 中包含的 field 多于当前的 dataset,则多余的 field 会被忽略; - 若 dataset 中未包含所有当前 dataset 含有 field,则会报错。 + 将当前 DataSet 与输入的 ``dataset`` 结合成一个更大的 dataset,需要保证两个 dataset 都包含了相同的 field。结合后的 dataset + 的 field_name 和 _collator 以当前 dataset 为准。若 ``dataset`` 中包含的 field 多于当前的 DataSet,则多余的 field 会被忽略; + 若 ``dataset`` 中未包含所有当前 DataSet 含有 field,则会报错。 - :param dataset: 需要和当前 dataset concat的 dataset - :param inplace: 是否直接将 dataset 组合到当前 dataset 中 - :param field_mapping: 当传入的 dataset 中的 field 名称和当前 dataset 不一致时,需要通过 field_mapping 把输入的 dataset 中的 - field 名称映射到当前 field. field_mapping 为 dict 类型,key 为 dataset 中的 field 名称,value 是需要映射成的名称 + :param dataset: 需要和当前 DataSet 拼接的 ``dataset``; + :param inplace: 是否直接将 ``dataset`` 组合到当前 DataSet 中; + :param field_mapping: 当传入的 ``dataset`` 中的 field 名称和当前 dataset 不一致时,需要通过 ``field_mapping`` 把输入的 ``dataset`` + 中的 field 名称映射到当前 field。``field_mapping`` 为 dict 类型,key 为 11dataset`` 中的 field 名称,value 是需要映射成的名称 - :return: :class:`~fastNLP.core.dataset.DataSet`` + :return: :class:`~fastNLP.core.dataset.DataSet` """ assert isinstance(dataset, DataSet), "Can only concat two datasets." @@ -966,7 +960,8 @@ class DataSet: @classmethod def from_pandas(cls, df): """ - 从 ``pandas.DataFrame`` 中读取数据转为 DataSet + 从 :class:`pandas.DataFrame` 中读取并数据转化为 DataSet + :param df: 使用 pandas 读取的数据 :return: """ @@ -975,7 +970,7 @@ class DataSet: def to_pandas(self): """ - 将 DataSet 数据转为 ``pandas.DataFrame`` 类型的数据 + 将 DataSet 数据转为 :class:`pandas.DataFrame` 类型的数据 :return: """ @@ -1003,23 +998,22 @@ class DataSet: def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, pad_fn: Callable = None) -> Collator: """ - ``DataSet`` 中想要对绑定的 collator 进行调整可以调用此函数。 ``collator`` 为 :class:`~fastNLP.core.collators.Collator` - 时该函数才有效。调用该函数可以对 field 内容的 pad_val, dtype, backend 等进行调整。 - - :param field_name: 需要调整的 field 的名称。如果 DataSet 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); - 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 - 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 - :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 - field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 - 无意义。 - :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 - :param backend: 可选['raw', 'numpy', 'torch', 'torch', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, - torch.Tensor, torch.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 - :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 - batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch - 形式,输出将被直接作为结果输出。 - :return: 返回 Collator + 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 + + :param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 + field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; + 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 + 如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, + 该值无意义。 + :param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 + :param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, + :class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 + 若 ``pad_val`` 为 ``None`` ,该值无意义 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 + batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 + :return: 返回自身的 collator; """ if isinstance(self.collator, Collator): self.collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) @@ -1030,16 +1024,14 @@ class DataSet: def set_ignore(self, *field_names) -> Collator: """ ``DataSet`` 中想要对绑定的 collator 进行调整可以调用此函数。 ``collator`` 为 :class:`~fastNLP.core.collators.Collator` - 时该函数才有效。调用该函数可以设置忽略输出某些 field 的内容,被设置的 field 将在 batch 的输出中被忽略。 - - Example:: + 时该函数才有效。调用该函数可以设置忽略输出某些 field 的内容,被设置的 field 将在 batch 的输出中被忽略:: - collator.set_ignore('field1', 'field2') + dataset.set_ignore('field1', 'field2') - :param field_names: 需要忽略的 field 的名称。如果 DataSet 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 - __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 - :return: 返回 Collator 自身 + :param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 + field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; + 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 + :return: 返回自身的 collator; """ if isinstance(self.collator, Collator): self.collator.set_ignore(*field_names) diff --git a/fastNLP/core/dataset/field.py b/fastNLP/core/dataset/field.py index 6bb94416..e9795885 100644 --- a/fastNLP/core/dataset/field.py +++ b/fastNLP/core/dataset/field.py @@ -14,15 +14,14 @@ import numpy as np class FieldArray: + """ + :class:`~fastNLP.core.dataset.DatSet` 中用于表示列的数据类型。 - def __init__(self, name: str, content): - """ - 初始化 FieldArray - - :param name: 字符串的名称 - :param content: 任意类型的数据 + :param name: 字符串的名称 + :param content: 任意类型的数据 + """ - """ + def __init__(self, name: str, content): if len(content) == 0: raise RuntimeError("Empty fieldarray is not allowed.") _content = content @@ -36,18 +35,15 @@ class FieldArray: def append(self, val: Any) -> None: r""" - :param val: 把该 val append 到 fieldarray。 - :return: - + :param val: 把该 ``val`` 添加到 fieldarray 中。 """ self.content.append(val) def pop(self, index: int) -> None: r""" - 删除该 field 中 index 处的元素 + 删除该 field 中 ``index`` 处的元素 :param index: 从 ``0`` 开始的数据下标。 - :return: """ self.content.pop(index) @@ -60,10 +56,10 @@ class FieldArray: def get(self, indices: Union[int, List[int]]): r""" - 根据给定的 indices 返回内容。 + 根据给定的 ``indices`` 返回内容。 - :param indices: 获取 indices 对应的内容。 - :return: 根据给定的 indices 返回的内容,可能是单个值 或 ``ndarray`` + :param indices: 获取 ``indices`` 对应的内容。 + :return: 根据给定的 ``indices`` 返回的内容,可能是单个值 或 :class:`numpy.ndarray` """ if isinstance(indices, int): if indices == -1: @@ -80,16 +76,16 @@ class FieldArray: r""" 返回长度 - :return length: + :return: """ return len(self.content) def split(self, sep: str = None, inplace: bool = True): r""" - 依次对自身的元素使用 ``.split()`` 方法,应该只有当本 field 的元素为 ``str`` 时,该方法才有用。 + 依次对自身的元素使用 ``.split()`` 方法,应该只有当本 field 的元素为 :class:`str` 时,该方法才有用。 :param sep: 分割符,如果为 ``None`` 则直接调用 ``str.split()``。 - :param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 ``list``。 + :param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 :class:`list`。 :return: List[List[str]] or self """ new_contents = [] @@ -104,10 +100,11 @@ class FieldArray: def int(self, inplace: bool = True): r""" 将本 field 中的值调用 ``int(cell)``. 支持 field 中内容为以下两种情况: - * ['1', '2', ...](即 field 中每个值为 ``str`` 的), - * [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list`` ,``list`` 中的值会被依次转换。) - :param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 ``list``。 + * ['1', '2', ...](即 field 中每个值为 :class:`str` 的), + * [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。) + + :param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。 :return: List[int], List[List[int]], self """ new_contents = [] @@ -126,10 +123,10 @@ class FieldArray: r""" 将本 field 中的值调用 ``float(cell)``. 支持 field 中内容为以下两种情况: - * ['1', '2', ...](即 field 中每个值为 ``str`` 的), - * [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list``,``list`` 中的值会被依次转换。) + * ['1', '2', ...](即 field 中每个值为 :class:`str` 的), + * [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。) - :param inplace: 如果为 ``True``,则将新生成值替换本 ``field``。否则返回 ``list``。 + :param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。 :return: """ new_contents = [] @@ -148,10 +145,10 @@ class FieldArray: r""" 将本field中的值调用 ``bool(cell)``. 支持 field 中内容为以下两种情况 - * ['1', '2', ...](即 field 中每个值为 ``str`` 的), - * [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list``,``list`` 中的值会被依次转换。) + * ['1', '2', ...](即 field 中每个值为 :class:`str` 的), + * [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。) - :param inplace: 如果为 ``True``,则将新生成值替换本 ``field``。否则返回 ``list``。 + :param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。 :return: """ new_contents = [] @@ -169,12 +166,12 @@ class FieldArray: def lower(self, inplace=True): r""" - 将本 field 中的值调用 ``cell.lower()``. 支持 field 中内容为以下两种情况 + 将本 field 中的值调用 ``cell.lower()``, 支持 field 中内容为以下两种情况 - * ['1', '2', ...](即 ``field`` 中每个值为 ``str`` 的), - * [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list``,``list``中的值会被依次转换。) + * ['1', '2', ...](即 field 中每个值为 :class:`str` 的), + * [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。) - :param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 ``list``。 + :param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。 :return: List[int], List[List[int]], self """ new_contents = [] @@ -191,12 +188,12 @@ class FieldArray: def upper(self, inplace=True): r""" - 将本 field 中的值调用 ``cell.lower()``. 支持 field 中内容为以下两种情况 + 将本 field 中的值调用 ``cell.upper()``, 支持 field 中内容为以下两种情况 - * ['1', '2', ...](即 field 中每个值为 ``str`` 的), - * [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list``,``list`` 中的值会被依次转换。) + * ['1', '2', ...](即 field 中每个值为 :class:`str` 的), + * [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。) - :param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 ``list``。 + :param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。 :return: List[int], List[List[int]], self """ new_contents = [] @@ -211,11 +208,11 @@ class FieldArray: raise e return self._after_process(new_contents, inplace=inplace) - def value_count(self): + def value_count(self) -> Counter: r""" - 返回该 field 下不同 value的 数量。多用于统计 label 数量 + 返回该 field 下不同 value 的数量。多用于统计 label 数量 - :return: Counter, key 是 label,value 是出现次数 + :return: 计数结果,key 是 label,value 是出现次数 """ count = Counter() diff --git a/fastNLP/core/dataset/instance.py b/fastNLP/core/dataset/instance.py index be455938..b721472e 100644 --- a/fastNLP/core/dataset/instance.py +++ b/fastNLP/core/dataset/instance.py @@ -1,7 +1,6 @@ r""" -instance 模块实现了 Instance 类在 fastNLP 中对应 sample。一个 sample 可以认为是一个 Instance 类型的对象。 -便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset` 。 - +instance 模块实现了 Instance 类,即在 fastNLP 中 sample 对应的类型。一个 sample 可以认为是一个 Instance 类型的对象。 +便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset.dataset` 。 """ __all__ = [ @@ -15,9 +14,9 @@ from fastNLP.core.utils.utils import pretty_table_printer class Instance(Mapping): r""" Instance 是 fastNLP 中对应一个 sample 的类。每个 sample 在 fastNLP 中是一个 Instance 对象。 - Instance 一般与 :class:`~fastNLP.DataSet` 一起使用, Instance 的初始化如下面的 Example 所示:: + Instance 一般与 :class:`~fastNLP.DataSet` 一起使用, Instance 的初始化如下面的代码所示:: - >>> instance = Instance(input="this is a demo sentence", label='good') # 请补充完整 + >>> instance = Instance(input="this is a demo sentence", label='good') """ diff --git a/fastNLP/core/drivers/__init__.py b/fastNLP/core/drivers/__init__.py index f9be3180..d775a039 100644 --- a/fastNLP/core/drivers/__init__.py +++ b/fastNLP/core/drivers/__init__.py @@ -3,20 +3,34 @@ __all__ = [ 'TorchDriver', "TorchSingleDriver", "TorchDDPDriver", + "DeepSpeedDriver", "PaddleDriver", "PaddleSingleDriver", "PaddleFleetDriver", "JittorDriver", "JittorSingleDriver", "JittorMPIDriver", + 'TorchSingleDriver', + 'TorchDDPDriver', + 'PaddleDriver', + 'PaddleSingleDriver', + 'PaddleFleetDriver', + 'JittorDriver', + 'JittorSingleDriver', + 'JittorMPIDriver', + 'OneflowDriver', + 'OneflowSingleDriver', + 'OneflowDDPDriver', 'torch_seed_everything', 'paddle_seed_everything', + 'oneflow_seed_everything', 'optimizer_state_to_device' ] -from .torch_driver import TorchDriver, TorchSingleDriver, TorchDDPDriver, torch_seed_everything, optimizer_state_to_device +from .torch_driver import TorchDriver, TorchSingleDriver, TorchDDPDriver, DeepSpeedDriver, torch_seed_everything, optimizer_state_to_device from .jittor_driver import JittorDriver, JittorMPIDriver, JittorSingleDriver from .paddle_driver import PaddleDriver, PaddleFleetDriver, PaddleSingleDriver, paddle_seed_everything +from .oneflow_driver import OneflowDriver, OneflowSingleDriver, OneflowDDPDriver, oneflow_seed_everything from .driver import Driver diff --git a/fastNLP/core/drivers/choose_driver.py b/fastNLP/core/drivers/choose_driver.py index 4be1e502..0f173b1c 100644 --- a/fastNLP/core/drivers/choose_driver.py +++ b/fastNLP/core/drivers/choose_driver.py @@ -1,6 +1,7 @@ from typing import Union, Optional, List from .driver import Driver +from ..utils import is_torch_module, is_paddle_module, is_jittor_module, is_oneflow_module def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver: @@ -17,7 +18,19 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, if isinstance(driver, Driver): return driver - if driver in {"torch", "fairscale"}: + if driver == "auto": + if is_torch_module(model): + driver = "torch" + elif is_paddle_module(model): + driver = "paddle" + elif is_jittor_module(model): + driver = "jittor" + elif is_oneflow_module(model): + driver = "oneflow" + else: + raise ValueError(f"Cannot choose driver automatically based on model, please set `driver` specifically.") + + if driver in {"torch", "fairscale", "deepspeed"}: from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver return initialize_torch_driver(driver, device, model, **kwargs) elif driver in {"jittor"}: @@ -26,6 +39,9 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, elif driver in {"paddle"}: from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver return initialize_paddle_driver(driver, device, model, **kwargs) + elif driver in {"oneflow"}: + from fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver import initialize_oneflow_driver + return initialize_oneflow_driver(driver, device, model, **kwargs) else: raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale', " - "'jittor', 'paddle'].") \ No newline at end of file + "'jittor', 'paddle', 'oneflow'].") \ No newline at end of file diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py index 63ac6ec4..ebcd7bfd 100644 --- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from fastNLP.envs.imports import _NEED_IMPORT_JITTOR from fastNLP.core.drivers.driver import Driver from fastNLP.core.dataloaders import JittorDataLoader +from fastNLP.core.dataloaders import OverfitDataLoader from fastNLP.core.samplers import ReproducibleSampler, RandomSampler from fastNLP.core.log import logger from fastNLP.core.utils import apply_to_collection, nullcontext @@ -39,20 +40,22 @@ __all__ = [ class JittorDriver(Driver): r""" - ``Jittor`` 框架的 ``Driver`` + ``Jittor`` 框架的 ``Driver``,是 ``JittorSingleDevice`` 和 ``JittorMPIDriver`` 的父类。 - .. note:: + .. warning:: - 这是一个正在开发中的功能,敬请期待。 + 您不应当直接初始化该类,然后传入给 ``Trainer``,换句话说,您应当使用该类的子类 ``JittorSingleDriver`` 和 ``TorchDDPDriver``,而不是 + 该类本身; - .. todo:: + .. note:: - 实现 fp16 的设置,且支持 cpu 和 gpu 的切换; - 实现用于断点重训的 save 和 load 函数; + 您可以在使用 ``JittorSingleDevice`` 和 ``JittorMPIDriver`` 时使用 ``JittorDriver`` 提供的接口; + :param model: 训练时使用的 **jittor** 模型; + :param fp16: 是否开启混合精度训练; + :param jittor_kwargs: """ - - def __init__(self, model, fp16: bool = False, **kwargs): + def __init__(self, model, fp16: bool = False, jittor_kwargs: Dict = None, **kwargs): if not isinstance(model, Module): raise ValueError(f"Parameter `model` can not be `{type(model)}` in `JittorDriver`, it should be exactly " f"`jittor.Module` type.") @@ -64,12 +67,13 @@ class JittorDriver(Driver): jt.flags.auto_mixed_precision_level = 0 self.fp16 = fp16 self._auto_cast = nullcontext + self._jittor_kwargs = jittor_kwargs if jittor_kwargs is not None else {} # 用来设置是否关闭 auto_param_call 中的参数匹配问题; self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) def check_dataloader_legality(self, dataloader): - if not isinstance(dataloader, (Dataset, JittorDataLoader)): + if not isinstance(dataloader, (Dataset, JittorDataLoader, OverfitDataLoader)): raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`") if len(dataloader) == 0: logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it " @@ -138,26 +142,12 @@ class JittorDriver(Driver): num_consumed_batches = states.pop('num_consumed_batches') if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): sampler_states = sampler.state_dict() - # 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples - # 会造成多余实际消耗的问题。因为 - num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) - if num_consumed_samples_array is not None: - if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 - if dataloader_args.batch_size is not None: - num_consumed_batches = num_consumed_batches * dataloader_args.batch_size - else: # 有可能 batch_size 为 None,就只有损失精度了 - logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " - "it may cause missing some samples when reload.") - num_consumed_batches = sampler_states['num_consumed_samples'] - sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] - assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." + if dataloader_args.batch_size is not None: + sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ + * num_consumed_batches else: - if dataloader_args.batch_size is not None: - sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ - * num_consumed_batches - else: - logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " - "it may cause missing some samples when reload.") + logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " + "it may cause missing some samples when reload.") states['sampler_states'] = sampler_states else: diff --git a/fastNLP/core/drivers/jittor_driver/mpi.py b/fastNLP/core/drivers/jittor_driver/mpi.py index b072b83d..2e3d42c2 100644 --- a/fastNLP/core/drivers/jittor_driver/mpi.py +++ b/fastNLP/core/drivers/jittor_driver/mpi.py @@ -34,10 +34,11 @@ class JittorMPIDriver(JittorDriver): parallel_device: None, is_pull_by_jittor_run: bool = False, fp16: bool = False, + jittor_kwargs: Dict = None, **kwargs ): - super(JittorMPIDriver, self).__init__(model, fp16=fp16, **kwargs) + super(JittorMPIDriver, self).__init__(model, fp16=fp16, jittor_kwargs=jittor_kwargs, **kwargs) raise NotImplementedError("MPI for Jittor is not supported right now.") self.is_pull_by_jittor_run = is_pull_by_jittor_run diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py index 4e9b3447..eda11660 100644 --- a/fastNLP/core/drivers/jittor_driver/single_device.py +++ b/fastNLP/core/drivers/jittor_driver/single_device.py @@ -25,15 +25,6 @@ class JittorSingleDriver(JittorDriver): r""" ``Jittor`` 框架下用于 ``cpu`` 和单卡 ``gpu`` 运算的 ``Driver``。 - .. note:: - - 这是一个正在开发中的功能,敬请期待。 - - .. todo:: - - 支持 cpu 和 gpu 的切换; - 实现断点重训中替换 dataloader 的 set_dist_repro_dataloader 函数 - :param model: 传入给 ``Trainer`` 的 ``model`` 参数; :param device: 训练和模型所在的设备,在 **Jittor** 中,应当为以下值之一:``[None, 'cpu', 'gpu', 'cuda']``; @@ -43,12 +34,13 @@ class JittorSingleDriver(JittorDriver): 表示在显卡设备上进行训练; :param fp16: 是否开启 fp16; + :param jittor_kwargs: """ - def __init__(self, model, device=None, fp16: bool = False, **kwargs): + def __init__(self, model, device=None, fp16: bool = False, jittor_kwargs: Dict = None, **kwargs): if device not in [None, "cpu", "gpu", "cuda"]: raise RuntimeError("Parameter `device` should be one of [None, 'cpu', 'gpu', 'cuda'] .") - super(JittorSingleDriver, self).__init__(model, fp16) + super(JittorSingleDriver, self).__init__(model, fp16, jittor_kwargs=jittor_kwargs) self.model_device = device if device is not None else "cpu" @@ -118,14 +110,14 @@ class JittorSingleDriver(JittorDriver): if args.sampler is None: sampler = RandomSampler(args.dataset, args.shuffle) return replace_sampler(dataloader, sampler) - elif isinstance(args.sampler, JittorRandomSampler): + elif type(args.sampler) is JittorRandomSampler: if getattr(args.sampler, '_num_samples', None) is None \ and getattr(args.sampler, 'rep', False) is False: # 如果本来就是随机的,并且没有定制,直接替换掉吧。 sampler = RandomSampler(args.sampler.dataset, shuffle=True) logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") return replace_sampler(dataloader, sampler) - elif isinstance(args.sampler, JittorSequentialSampler): + elif type(args.sampler) is JittorSequentialSampler: # 需要替换为不要 shuffle 的。 sampler = RandomSampler(args.sampler.dataset, shuffle=False) logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.") diff --git a/fastNLP/core/drivers/jittor_driver/utils.py b/fastNLP/core/drivers/jittor_driver/utils.py index c75526df..af840a09 100644 --- a/fastNLP/core/drivers/jittor_driver/utils.py +++ b/fastNLP/core/drivers/jittor_driver/utils.py @@ -14,6 +14,7 @@ from fastNLP.envs import ( FASTNLP_BACKEND_LAUNCH, FASTNLP_GLOBAL_SEED, ) +from fastNLP.core.samplers import ReproducibleBatchSampler from fastNLP.core.log import logger if _NEED_IMPORT_JITTOR: @@ -63,6 +64,9 @@ def replace_batch_sampler(dataloader, batch_sampler): "or report this bug to us.") def replace_sampler(dataloader: Union["Dataset", "JittorDataLoader"], sampler): + batch_sampler = getattr(dataloader, "sampler") + if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler): + raise RuntimeError("It should not be running here, please report a bug to us.") if isinstance(dataloader, JittorDataLoader): init_params = dict(inspect.signature(dataloader.__init__).parameters) reconstruct_args = {name: getattr(dataloader, name, p.default) for name, p in init_params.items()} diff --git a/fastNLP/core/drivers/oneflow_driver/__init__.py b/fastNLP/core/drivers/oneflow_driver/__init__.py new file mode 100644 index 00000000..12beffc0 --- /dev/null +++ b/fastNLP/core/drivers/oneflow_driver/__init__.py @@ -0,0 +1,18 @@ +__all__ = [ + "OneflowDDPDriver", + "OneflowSingleDriver", + "OneflowDriver", + "oneflow_seed_everything", + "optimizer_state_to_device" +] + +from .ddp import OneflowDDPDriver +from .single_device import OneflowSingleDriver +from .oneflow_driver import OneflowDriver +from .utils import oneflow_seed_everything, optimizer_state_to_device + + + + + + diff --git a/fastNLP/core/drivers/oneflow_driver/ddp.py b/fastNLP/core/drivers/oneflow_driver/ddp.py new file mode 100644 index 00000000..4a285856 --- /dev/null +++ b/fastNLP/core/drivers/oneflow_driver/ddp.py @@ -0,0 +1,323 @@ +import os +from typing import List, Optional, Union, Dict + +from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW + +if _NEED_IMPORT_ONEFLOW: + import oneflow + import oneflow.comm as comm + import oneflow.env as dist_env + from oneflow.nn.parallel import DistributedDataParallel + from oneflow.utils.data import BatchSampler + +__all__ = [ + "OneflowDDPDriver" +] + +from .oneflow_driver import OneflowDriver +from fastNLP.core.drivers.oneflow_driver.utils import ( + replace_sampler, + replace_batch_sampler +) +from fastNLP.core.utils import check_user_specific_params +from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, \ + ReproducibleBatchSampler, \ + re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler +from fastNLP.envs import FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC +from fastNLP.core.log import logger +from fastNLP.core.drivers.oneflow_driver.dist_utils import fastnlp_oneflow_all_gather, fastnlp_oneflow_broadcast_object +from .utils import _check_dataloader_args_for_distributed + + +class OneflowDDPDriver(OneflowDriver): + r""" + ``OneflowDDPDriver`` 实现了动态图下使用 ``DistributedDataParallel`` 进行的数据并行分布式训练。 + + .. note:: + + 您在绝大多数情况下不需要自己使用到该类,通过向 ``Trainer`` 传入正确的参数,您可以方便快速地部署您的分布式训练; + + ``OneflowDDPDriver`` 目前支持两种启动方式: + + 1. 用户不做任何处理,通过运行 ``python -m oneflow.distributed.launch --nproc_per_node 2 train.py`` 启动; + 2. 用户将模型通过 ``DistributedDataParallel`` 处理后,通过运行 ``python -m oneflow.distributed.launch --nproc_per_node 2 train.py`` 启动; + + 注意多机的启动强制要求用户在每一台机器上使用 ``python -m oneflow.distributed.launch`` 启动;因此我们不会在 ``OneflowDDPDriver`` 中保存 + 任何当前有多少台机器的信息; + + :param model: 传入给 ``Trainer`` 的 ``model`` 参数; + :param parallel_device: 该参数无效,**fastNLP** 会自动获取当前进程的设备; + :param fp16: 是否开启 fp16 训练;目前该参数无效; + :param oneflow_kwargs: + * *ddp_kwargs* -- 用于 ``DistributedDataParallel`` 的其它参数,详情可查阅 **oneflow** 的官方文档; + """ + + def __init__( + self, + model, + parallel_device: Optional["oneflow.device"], + fp16: bool = False, + oneflow_kwargs: Dict = None, + **kwargs + ): + + super(OneflowDDPDriver, self).__init__(model, fp16=fp16, oneflow_kwargs=oneflow_kwargs, **kwargs) + + # oneflow 会自己初始化通信组,因此 parallel_device 实际上不起作用,可以通过 current_device 获取设备 + self.model_device = oneflow.device("cuda", oneflow.cuda.current_device()) + self._data_device = self.model_device + + self.global_rank = int(os.environ["RANK"]) + self.world_size = int(os.environ["WORLD_SIZE"]) + + self._ddp_kwargs = self._oneflow_kwargs.get("ddp_kwargs", {}) + check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__, DistributedDataParallel.__name__) + if len(self.model._buffers) != 0 and self._ddp_kwargs.get("broadcast_buffers", None) is None: + logger.info("Notice your model has buffers and you are using `OneflowDDPDriver`, but you do not set " + "'broadcast_buffers' in your trainer. Cause in most situations, this parameter can be set" + " to 'False' to avoid redundant data communication between different processes.") + + self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error") + assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type." + if self.output_from_new_proc not in {"all", "ignore", "only_error"}: + os.makedirs(name=self.output_from_new_proc, exist_ok=True) + self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) + + self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; + self._has_ddpwrapped = False# hasattr(model, ) + + def setup(self): + r""" + 将模型用 ``DistributedDataParallel`` 进行处理; + """ + if self._has_setup: + return + self._has_setup = True + + self.configure_ddp() + self.barrier() + # 初始化 self._pids,从而使得每一个进程都能接受到 rank0 的 send 操作; + # self._pids = [oneflow.tensor(0, dtype=oneflow.int).to(self.data_device) for _ in range(dist_env.get_world_size())] + # comm.all_gather(self._pids, oneflow.tensor(os.getpid(), dtype=oneflow.int).to(self.data_device)) + # local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None + # if local_world_size is None: + # local_world_size = oneflow.tensor(int(os.environ.get("LOCAL_RANK")), dtype=oneflow.int).to(self.data_device) + # comm.all_reduce(local_world_size, op=dist_env.ReduceOp.MAX) + # local_world_size = local_world_size.tolist() + 1 + + # node_rank = self.global_rank // local_world_size + # self._pids = self._pids[node_rank * local_world_size: (node_rank + 1) * local_world_size] + # self._pids = self.tensor_to_numeric(self._pids) + + def configure_ddp(self): + if not hasattr(self.model, "_ddp_state_for_reversed_params"): + self.model.to(self.model_device) + self.model = DistributedDataParallel( + # 注意这里的 self.model_device 是 `oneflow.device` type,因此 self.model_device.index; + self.model, + **self._ddp_kwargs + ) + self._has_ddpwrapped = True + + @property + def master_address(self) -> str: + return os.environ.get("MASTER_ADDR") + + @property + def master_port(self) -> str: + return os.environ.get("MASTER_PORT") + + @property + def world_size(self) -> int: + return self._world_size + + @world_size.setter + def world_size(self, size: int): + self._world_size = size + + @property + def global_rank(self) -> int: + return self._global_rank + + @global_rank.setter + def global_rank(self, rank: int) -> None: + self._global_rank = rank + + @property + def local_rank(self) -> int: # 这个不会受到 all_rank_call_context 的影响 + return int(os.environ.get("LOCAL_RANK", 0)) + + @property + def data_device(self): + return self._data_device + + def set_dist_repro_dataloader(self, dataloader, + dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]] = None, + reproducible: bool = False): + # 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load_checkpoint 函数调用; + # 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 OneflowDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; + if isinstance(dist, ReproducibleBatchSampler): + dist.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) + return replace_batch_sampler(dataloader, dist) + if isinstance(dist, ReproducibleSampler): + dist.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) + return replace_sampler(dataloader, dist) + + # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; + # trainer, evaluator + if dist is None: + if reproducible: + raise RuntimeError("It is not allowed to save checkpoint if the sampler is not allowed to be replaced.") + else: + args = self.get_dataloader_args(dataloader) + if isinstance(args.batch_sampler, ReproducibleBatchSampler): + return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler)) + if isinstance(args.sampler, ReproducibleSampler): + return replace_sampler(dataloader, re_instantiate_sampler(args.sampler)) + return dataloader + # trainer + elif dist == "dist": + args = self.get_dataloader_args(dataloader) + # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; + if isinstance(args.batch_sampler, ReproducibleBatchSampler): + batch_sampler = re_instantiate_sampler(args.batch_sampler) + batch_sampler.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) + return replace_batch_sampler(dataloader, batch_sampler) + elif isinstance(args.sampler, ReproducibleSampler): + sampler = re_instantiate_sampler(args.sampler) + sampler.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) + return replace_sampler(dataloader, sampler) + else: + _check_dataloader_args_for_distributed(args, controller="Trainer") + sampler = RandomSampler( + dataset=args.dataset, + shuffle=args.shuffle, + seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0)) + ) + sampler.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) + return replace_sampler(dataloader, sampler) + # evaluator + elif dist == "unrepeatdist": + args = self.get_dataloader_args(dataloader) + if isinstance(args.sampler, ReproducibleSampler): + sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) + elif not isinstance(args.sampler, UnrepeatedSampler): + _check_dataloader_args_for_distributed(args, controller="Evaluator") + sampler = UnrepeatedSequentialSampler( + dataset=args.dataset + ) + else: + sampler = re_instantiate_sampler(args.sampler) + sampler.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank + ) + batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) + return replace_batch_sampler(dataloader, batch_sampler) + else: + raise ValueError( + "Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") + + def is_global_zero(self): + r""" + :return: 返回当前的进程是否在全局上是进程 0 ; + """ + return self.global_rank == 0 + + def get_model_no_sync_context(self): + r""" + :return: 返回一个 ``context`` 上下文环境,用于关闭各个进程之间的同步;该功能暂时无效,返回一个空的上下文环境; + """ + # TODO 暂时没有在 oneflow 中找到类似的功能; + from fastNLP.core.utils import nullcontext + return nullcontext + return self.model.no_sync + + def unwrap_model(self): + r""" + :return: 返回原始模型; + """ + return self.model + + def get_local_rank(self) -> int: + r""" + :return: 返回当前进程局部的进程编号; + """ + return self.local_rank + + def barrier(self): + r""" + 通过使用该函数来使得各个进程之间同步操作; + """ + if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 + comm.barrier() + + def is_distributed(self): + r""" + :return: 返回当前使用的 driver 是否是分布式的 driver,对于 ``OneflowDDPDriver`` 来说,该函数一定返回 ``True``; + """ + return True + + def broadcast_object(self, obj, src: int = 0, **kwargs): + r""" + 从 src 端将 obj 对象(可能是 tensor ,可能是 object )发送到 dst 处。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 + 传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 + + :param obj: obj,可能是 Tensor 或 嵌套类型的数据 + :param int src: source 的 global rank 。 + :param int dst: target 的 global rank,可以是多个目标 rank + :param group: 所属的 group + :return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回 + 接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 + """ + if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC == 2 直接返回。 + return + return fastnlp_oneflow_broadcast_object(obj, src, device=self.data_device) + + def all_gather(self, obj) -> List: + r""" + 将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 + pickle 进行序列化,接收到之后再反序列化。 + + example:: + + obj = { + 'a': [1, 1], + 'b': [[1, 2], [1, 2]], + 'c': { + 'd': [1, 2] + } + } + -> + [ + {'a': 1, 'b':[1, 2], 'c':{'d': 1}}, + {'a': 1, 'b':[1, 2], 'c':{'d': 2}} + ] + + :param obj: 需要传输的对象,在每个rank上都应该保持相同的结构。 + :param group: + :return: + """ + if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC 表示不执行 + return [obj] + return fastnlp_oneflow_all_gather(obj) diff --git a/fastNLP/core/drivers/oneflow_driver/dist_utils.py b/fastNLP/core/drivers/oneflow_driver/dist_utils.py new file mode 100644 index 00000000..e84df213 --- /dev/null +++ b/fastNLP/core/drivers/oneflow_driver/dist_utils.py @@ -0,0 +1,306 @@ +import io +import pickle +import os +from typing import Any, List + +from fastNLP.core.utils import apply_to_collection, get_oneflow_device +from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW +from fastNLP.envs.env import FASTNLP_NO_SYNC +if _NEED_IMPORT_ONEFLOW: + import oneflow + import oneflow.comm as comm + import oneflow.env as dist_env + +PROTOCOL_VERSION = 1 + +def _validate_output_list_for_rank(my_rank, dst, gather_list): + if dst == my_rank: + if not gather_list: + raise ValueError( + "Argument ``gather_list`` must be specified on destination rank." + ) + elif gather_list: + raise ValueError( + "Argument ``gather_list`` must NOT be specified " + "on non-destination ranks." + ) + + obj = {"protocol_version": PROTOCOL_VERSION, "data": obj} + pickled_bytes = pickle.dumps(obj) + +def fastnlp_oneflow_gather_object(obj, dst=0): + """ + 从其它 rank gather 东西到 dst rank 。 + + Example:: + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> fastnlp_oneflow_gather_object( + gather_objects[dist.get_rank()], + output if dist.get_rank() == 0 else None, + dst=0 + ) + >>> # On rank 0 + >>> output + ['foo', 12, {1: 2}] + + :param obj: 需要发送的 obj 对象,需要是可以 pickable 的对象 + :param dst: 目标的 rank 。 + :return: 在 dst 上面返回 world_size 的 list,依次为 rank 0;rank 1...上 obj + """ + if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: + return [obj] + + if dist_env.get_rank() == dst: + object_gather_list = [None for _ in range(dist_env.get_world_size())] + else: + object_gather_list = None + + # Ensure object_gather_list is specified appopriately. + my_rank = dist_env.get_rank() + _validate_output_list_for_rank(my_rank, dst, object_gather_list) + # 防止 unpickle 的时候出现在了发送的 gpu 上。 + obj = apply_to_collection(obj, oneflow.Tensor, _to_device, device=oneflow.device("cpu")) + input_tensor, local_size = _object_to_tensor(obj) + current_device = oneflow.device("cuda") + input_tensor = input_tensor.to(current_device) + local_size = local_size.to(current_device) + # Gather all local sizes. This is so that we can find the max size, and index + # until the correct size when deserializing the tensors. + group_size = dist_env.get_world_size() + object_sizes_tensor = oneflow.zeros(group_size, dtype=oneflow.long, device=current_device) + object_size_list = [ + object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) + ] + # Allgather tensor sizes. An all-gather is needed here despite this being a + # gather, since each rank needs to broadcast a tensor of the same (maximal) + # size. + comm.all_gather(object_size_list, local_size) + max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] + # Resize tensor to max size across all ranks. + input_tensor = input_tensor.reshape(max_object_size) + # Avoid populating output tensors if the result won't be gathered on this rank. + if my_rank == dst: + coalesced_output_tensor = oneflow.empty( + max_object_size * group_size, dtype=oneflow.uint8, device=current_device + ) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(group_size) + ] + # All ranks call gather with equal-sized tensors. + comm.gather( + input_tensor, + gather_list=output_tensors if my_rank == dst else None, + dst=dst, + ) + if my_rank != dst: + return + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(oneflow.uint8) # type: ignore[call-overload] + tensor_size = object_size_list[i] + object_gather_list[i] = _tensor_to_object(tensor, tensor_size) + + +def _object_to_tensor(obj, device=None): + f = io.BytesIO() + obj = {"protocol_version": PROTOCOL_VERSION, "data": obj} + pickled_bytes = pickle.dumps(obj) + + byte_tensor = oneflow.ByteTensor(list(pickled_bytes)) + local_size = oneflow.LongTensor([byte_tensor.numel()]) + if device is not None: + byte_tensor = byte_tensor.to(device) + local_size = local_size.to(device) + return byte_tensor, local_size + +def _tensor_to_object(tensor, tensor_size): + buf = tensor.detach().cpu().numpy().tobytes()[:tensor_size] + res = pickle.loads(buf) + assert res["protocol_version"] == PROTOCOL_VERSION + return res["data"] + +def send_recv_object(obj, src, cur_rank, device): + r""" + oneflow 中的单点对多点的分发函数; + + 例如将进程 0 上的对象 object 分发到其它进程上; + + Example:: + + cur_rank = int(os.environ.get('LOCAL_RANK', 0)) + + # 拿到 local_device + + send_recv_object(object, 0, cur_rank, local_device) + + :param obj: 一个可以序列化的 python 对象; + :param src: 从哪一个 rank 上发送到其它 rank; + :param cur_rank: 当前的进程的 rank 序号; + :param device: 当前的进程所在的设备; + :param group: 通信组,默认为 None; + :param tag: 将发送与远程接收匹配的标记; + :return: + """ + # src rank send to all other ranks + size = oneflow.LongTensor([0]).to(device) + + if cur_rank == src: + world_size = dist_env.get_world_size() + tensor, size = _object_to_tensor(obj) + tensor = tensor.to(device) + size = size.to(device) + + # 首先同步 obj 的 size 的信息; + comm.broadcast(size, src) + for subrank in range(world_size): + if subrank != src: + comm.send(tensor=tensor, dst=subrank) + else: + comm.broadcast(size, src) + tensor = oneflow.ByteTensor([0] * size).to(device) + comm.recv(tensor=tensor, src=src) + + return _tensor_to_object(tensor.cpu(), size) + + +def _to_device(tensor, device): + return tensor.contiguous().to(device) + + +def fastnlp_oneflow_all_gather(obj: Any, device=None) ->List: + """ + 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 + + example:: + + obj = { + 'a': [1, 1], + 'b': [[1, 2], [1, 2]], + 'c': { + 'd': [1, 2] + } + } + -> + [ + {'a': 1, 'b':[1, 2], 'c':{'d': 1}}, + {'a': 1, 'b':[1, 2], 'c':{'d': 2}} + ] + + :param obj: 任意结构的数据,如果为 tensor ,需要保证每个显卡上的 tensor 的形状是一样的。如果传入的是非 tensor 对象都将直接进行 + 序列化之后进行传输。 + :param device: 当前该参数无意义。 + :param group: + :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 + """ + if int(os.environ.get(FASTNLP_NO_SYNC, "0")) == 2: + return [obj] + + if isinstance(obj, oneflow.Tensor): + objs = [oneflow.zeros_like(obj) for _ in range(dist_env.get_world_size())] + comm.all_gather(objs, obj) + else: + objs = [None for _ in range(dist_env.get_world_size())] + # 防止 unpickle 的时候弄到发送的 gpu 上了 + obj = apply_to_collection(obj, oneflow.Tensor, _to_device, device=oneflow.device("cpu")) + all_gather_object(objs, obj) + return objs + + +def fastnlp_oneflow_broadcast_object(obj, src, device=None): + """ + 将 src 上的 obj 对象广播到其它 rank 上。 + + :param obj: 需要发送的对象 + :param src: 从哪里发出。 + :param device: + :param group: 属于哪个通信 group + :return: + """ + if int(os.environ.get(FASTNLP_NO_SYNC, "0")) == 2: + if src == dist_env.get_rank(): + return obj + else: + return None + + cur_rank = dist_env.get_rank() + if cur_rank == src: + # 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 + obj = apply_to_collection(obj, oneflow.Tensor, _to_device, device=oneflow.device("cpu")) + if device is None: + device = oneflow.cuda.current_device() + device = get_oneflow_device(device) + + if cur_rank == src: + tensor, size = _object_to_tensor(obj, device=device) + else: + size = oneflow.LongTensor([0]).to(device) + + comm.broadcast(size, src=src) + if cur_rank != src: + tensor = oneflow.empty( + size.int().item(), # type: ignore[arg-type] + dtype=oneflow.uint8, + device=device + ) + comm.broadcast(tensor, src=src) + + return _tensor_to_object(tensor, tensor_size=size.item()) + +def all_gather_object(object_list, obj): + """ + + Example:: + >>> # Note: Process group initialization omitted on each rank. + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> all_gather_object(output, gather_objects[dist.get_rank()]) + >>> output + ['foo', 12, {1: 2}] + + :param object_list: + :param obj: + :param group: + :return: + """ + if int(os.environ.get(FASTNLP_NO_SYNC, "0")) == 2: + return [obj] + + current_device = get_oneflow_device(oneflow.cuda.current_device()) + + input_tensor, local_size = _object_to_tensor(obj, device=current_device) + + # Gather all local sizes. This is so that we can find the max size, and index + # until the correct size when deserializing the tensors. + group_size = dist_env.get_world_size() + object_sizes_tensor = oneflow.zeros( + group_size, dtype=oneflow.long, device=current_device + ) + object_size_list = [ + object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) + ] + # Allgather tensor sizes + comm.all_gather(object_size_list, local_size) + max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] + # Resize tensor to max size across all ranks. + input_tensor = input_tensor.reshape(max_object_size) + coalesced_output_tensor = oneflow.empty( + max_object_size * group_size, dtype=oneflow.uint8, device=current_device + ) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(group_size) + ] + comm.all_gather(output_tensors, input_tensor) + # Deserialize outputs back to object. + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(oneflow.uint8) + if tensor.device != oneflow.device("cpu"): + tensor = tensor.cpu() + tensor_size = object_size_list[i] + object_list[i] = _tensor_to_object(tensor, tensor_size) + return object_list diff --git a/fastNLP/core/drivers/oneflow_driver/initialize_oneflow_driver.py b/fastNLP/core/drivers/oneflow_driver/initialize_oneflow_driver.py new file mode 100644 index 00000000..2dab1729 --- /dev/null +++ b/fastNLP/core/drivers/oneflow_driver/initialize_oneflow_driver.py @@ -0,0 +1,70 @@ +import os +from typing import Optional, Union, List, Sequence +from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW +if _NEED_IMPORT_ONEFLOW: + import oneflow + +from .oneflow_driver import OneflowDriver +from .single_device import OneflowSingleDriver +from .ddp import OneflowDDPDriver +from fastNLP.core.log import logger +from fastNLP.envs import FASTNLP_BACKEND_LAUNCH + +__all__ = [] + + +def initialize_oneflow_driver(driver: str, device: Optional[Union[str, "oneflow.device", int, List[int]]], + model: "oneflow.nn.Module", **kwargs) -> OneflowDriver: + r""" + 用来根据参数 ``driver` 和 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去; + + :param driver: 该参数的值应为以下之一:``["oneflow"]``; + :param device: 该参数的格式与 ``Trainer`` 对参数 ``device`` 的要求一致; + :param model: 训练或者评测的具体的模型; + + :return: 返回一个 :class:`~fastNLP.core.OneflowSingleDriver` 或 :class:`~fastNLP.core.OneflowDDPDriver` 实例; + """ + # world_size 和 rank + if FASTNLP_BACKEND_LAUNCH in os.environ: + if device is not None: + logger.rank_zero_warning("Parameter `device` would be ignored when you are using `oneflow.distributed.launch` to pull " + "up your script. ", once=True) + return OneflowDDPDriver(model, None, **kwargs) + + if driver not in {"oneflow"}: + raise ValueError("Parameter `driver` can only be one of these values: ['oneflow'].") + + _could_use_device_num = oneflow.cuda.device_count() + if isinstance(device, str): + device = oneflow.device(device) + elif isinstance(device, int): + if device < 0: + if device != -1: + raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") + device = [oneflow.device(f"cuda:{w}") for w in range(_could_use_device_num)] + elif device >= _could_use_device_num: + print(device, _could_use_device_num) + raise ValueError("The gpu device that parameter `device` specifies is not existed.") + else: + device = oneflow.device(f"cuda:{device}") + elif isinstance(device, Sequence): + device = list(set(device)) + for each in device: + if not isinstance(each, int): + raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.") + elif each < 0: + raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.") + elif each >= _could_use_device_num: + raise ValueError(f"When parameter `device` is 'Sequence' type, the value in it should not be bigger than" + f" the available gpu number:{_could_use_device_num}.") + device = [oneflow.device(f"cuda:{w}") for w in device] + elif device is not None and not isinstance(device, oneflow.device): + raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") + + if driver == "oneflow": # single, ddp, 直接启动。 + if not isinstance(device, List): + return OneflowSingleDriver(model, device, **kwargs) + else: + raise RuntimeError("If you want to run distributed training, please use " + "'python -m oneflow.distributed.launch xxx.py'.") + return OneflowDDPDriver(model, device, **kwargs) \ No newline at end of file diff --git a/fastNLP/core/drivers/oneflow_driver/oneflow_driver.py b/fastNLP/core/drivers/oneflow_driver/oneflow_driver.py new file mode 100644 index 00000000..29027738 --- /dev/null +++ b/fastNLP/core/drivers/oneflow_driver/oneflow_driver.py @@ -0,0 +1,445 @@ +import os +from typing import Union, Dict, Optional, Callable, Tuple +from functools import partial +import numpy as np +import random +from dataclasses import dataclass +from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW +from pathlib import Path +if _NEED_IMPORT_ONEFLOW: + import oneflow + from oneflow.utils.data import DataLoader, Sampler, BatchSampler, Dataset + from oneflow.optim import Optimizer + from oneflow.utils.data import RandomSampler as OneflowRandomSampler + _reduces = { + "sum": oneflow.sum, + "min": oneflow.min, + "max": oneflow.max, + "mean": oneflow.mean + } + + +__all__ = [ + "OneflowDriver" +] + +from .utils import optimizer_state_to_device, DummyGradScaler +from fastNLP.core.drivers.driver import Driver +from fastNLP.core.utils.utils import _get_fun_msg, nullcontext +from fastNLP.core.utils import apply_to_collection, oneflow_move_data_to_device, auto_param_call +from fastNLP.envs import rank_zero_call +from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME +from fastNLP.core.log import logger +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler +from fastNLP.core.dataloaders import OverfitDataLoader + + +class OneflowDriver(Driver): + r""" + 专属于 ``oneflow`` 的 ``driver``,是 ``OneflowSingleDriver`` 和 ``OneflowDDPDriver`` 的父类; + + .. warning:: + + 您不应当直接初始化该类,然后传入给 ``Trainer``,换句话说,您应当使用该类的子类 ``OneflowSingleDriver`` 和 ``OneflowDDPDriver``,而不是 + 该类本身; + + .. note:: + + 您可以在使用 ``OneflowSingleDriver`` 和 ``OneflowDDPDriver`` 时使用 ``OneflowDriver`` 提供的接口; + + """ + def __init__(self, model, fp16: Optional[bool] = False, oneflow_kwargs: Dict = None, **kwargs): + super(OneflowDriver, self).__init__(model) + + """ 进行 fp16 的设置 """ + self._oneflow_kwargs = oneflow_kwargs if oneflow_kwargs is not None else {} + + self.fp16 = fp16 + if fp16: + logger.warn("OneflowDriver of eager mode dose not support fp16 now.``") + # self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not self.fp16) + # self.grad_scaler = _grad_scaler(**self._oneflow_kwargs.get("gradscaler_kwargs", {})) + self.auto_cast = nullcontext + self.grad_scaler = DummyGradScaler() + self.set_grad_to_none = self._oneflow_kwargs.get("set_grad_to_none") + + self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) + + def zero_grad(self): + for optimizer in self.optimizers: + optimizer.zero_grad(self.set_grad_to_none) + + def backward(self, loss): + loss.backward() + # self.grad_scaler.scale(loss).backward() + + def step(self): + for optimizer in self.optimizers: + self.grad_scaler.step(optimizer) + self.grad_scaler.update() + + def check_dataloader_legality(self, dataloader): + if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader): + raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") + if len(dataloader) == 0: + logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it " + "may cause some unexpected exceptions.", once=True) + + @staticmethod + def _check_optimizer_legality(optimizers): + for each_optimizer in optimizers: + if not isinstance(each_optimizer, Optimizer): + raise TypeError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " + f"not {type(each_optimizer)}.") + + @staticmethod + def tensor_to_numeric(tensor, reduce: str = None): + r""" + 将 ``oneflow.Tensor`` 转换成 python 中的数值类型; + + :param tensor: ``oneflow.Tensor``; + :param reduce: 当 tensor 是一个多数值的张量时,应当使用何种归一化操作来转换成单一数值,应当为以下类型之一:``['max', 'min', 'sum', 'mean']``; + :return: 返回一个单一数值,其数值类型是 python 中的基本的数值类型,例如 ``int,float`` 等; + """ + + if tensor is None: + return None + + def _translate(_data): + if _data.numel() == 1: + return _data.item() + if reduce is None: + return _data.tolist() + return _reduces[reduce](_data).item() + + return apply_to_collection( + data=tensor, + dtype=oneflow.Tensor, + function=_translate + ) + + def set_model_mode(self, mode: str): + r""" + 设置模型的状态是 ``train`` 还是 ``eval``; + :param mode: ``'train'`` 或 ``'eval'``; + """ + assert mode in {"train", "eval"} + getattr(self.model, mode)() + + @rank_zero_call + def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs): + """ + 保存当前 driver 的模型到 folder 下。 + + :param filepath: 保存到哪个文件夹; + :param only_state_dict: 是否只保存权重;如果使用 ``DistributedDataParallel`` 启动分布式训练的话,该参数只能为 ``True``; + :return: + """ + model = self.unwrap_model() + if not only_state_dict and self.is_distributed(): + logger.warn("`Cannot save ddp model directly, we will save its state_dict for you.") + only_state_dict = True + + if only_state_dict: + states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} + oneflow.save(states, filepath) + else: + if self.model_device is not None: + if not self.is_distributed(): + self.move_model_to_device(model, oneflow.device("cpu")) + oneflow.save(model, filepath) + if not self.is_distributed(): + self.move_model_to_device(model, self.model_device) + else: + oneflow.save(model, filepath) + + def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs): + """ + 从 folder 中加载权重并赋值到当前 driver 的模型上。 + + :param filepath: 加载权重或模型的路径 + :param load_state_dict: 保存的内容是否只是权重。 + :param kwargs: + :return: + """ + model = self.unwrap_model() + res = oneflow.load(filepath) + if isinstance(res, dict) and only_state_dict is False: + logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use " + f"`only_state_dict=True`") + elif not isinstance(res, dict) and only_state_dict is True: + logger.rank_zero_warning(f"It seems like that {filepath} is not state, you may need to use " + f"`only_state_dict=False`") + if not isinstance(res, dict): + res = res.state_dict() + model.load_state_dict(res) + + @rank_zero_call + def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): + # 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 + # trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; + + # 1. sampler 的状态; + num_consumed_batches = states.pop("num_consumed_batches") + states["sampler_states"] = self.get_sampler_state(dataloader, num_consumed_batches) + + # 2. 保存模型的状态; + if should_save_model: + if not os.path.exists(folder): + os.mkdir(folder) + model_path = folder.joinpath(FASTNLP_MODEL_FILENAME) + self.save_model(model_path, only_state_dict=only_state_dict) + + # 3. 保存 optimizers 的状态; + states["optimizers_state_dict"] = self.get_optimizer_state() + logger.debug("Save optimizer state dict.") + + # # 4. 保存fp16的状态 + # if not isinstance(self.grad_scaler, DummyGradScaler): + # grad_scaler_state_dict = self.grad_scaler.state_dict() + # states['grad_scaler_state_dict'] = grad_scaler_state_dict + + oneflow.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) + + def get_sampler_state(self, dataloader, num_consumed_batches): + dataloader_args = self.get_dataloader_args(dataloader) + if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): + sampler = dataloader_args.batch_sampler + elif dataloader_args.sampler: + sampler = dataloader_args.sampler + else: + raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") + + if hasattr(sampler, "state_dict") and callable(sampler.state_dict): + sampler_states = sampler.state_dict() + if dataloader_args.batch_size is not None: + sampler_states["num_consumed_samples"] = sampler.num_replicas * dataloader_args.batch_size \ + * num_consumed_batches + else: + logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on sampler's " + "`num_consumed_samples`, it may cause missing some samples when reload.") + else: + raise RuntimeError("The sampler has no `state_dict()` method, fastNLP cannot save the training " + "state.") + + return sampler_states + + def load_sampler_state(self, dataloader, sampler_states): + states = {} + dataloader_args = self.get_dataloader_args(dataloader) + if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): + sampler = dataloader_args.batch_sampler + elif isinstance(dataloader_args.sampler, ReproducibleSampler): + sampler = dataloader_args.sampler + elif isinstance(dataloader_args.sampler, OneflowRandomSampler): + sampler = RandomSampler(dataloader_args.sampler.data_source) + logger.debug("Replace oneflow RandomSampler into fastNLP RandomSampler.") + elif self.is_distributed(): + raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our" + "`ReproducibleSampler`.") + else: + sampler = ReproduceBatchSampler( + batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, + batch_size=dataloader_args.batch_size, + drop_last=dataloader_args.drop_last + ) + sampler.load_state_dict(sampler_states) + states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) + + # 修改 trainer_state.batch_idx_in_epoch + # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; + if not isinstance(sampler, ReproducibleBatchSampler): + if dataloader_args.drop_last: + batch_idx_in_epoch = len( + sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size + else: + batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ + (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size + # sampler 是 batch_sampler; + else: + batch_idx_in_epoch = sampler.batch_idx_in_epoch + + states["batch_idx_in_epoch"] = batch_idx_in_epoch + return states + + def get_optimizer_state(self): + optimizers_state_dict = {} + for i in range(len(self.optimizers)): + optimizer: oneflow.optim.Optimizer = self.optimizers[i] + optimizer_state = optimizer.state_dict() + optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], oneflow.device("cpu")) + optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; + return optimizers_state_dict + + def load_optimizer_state(self, states): + assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ + f"checkpoint it is:{len(states)}" + for i in range(len(self.optimizers)): + optimizer: oneflow.optim.Optimizer = self.optimizers[i] + optimizer.load_state_dict(states[f"optimizer{i}"]) + logger.debug("Load optimizer state dict.") + + def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: + states = oneflow.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) + + # 1. 加载 optimizers 的状态; + optimizers_state_dict = states.pop("optimizers_state_dict") + self.load_optimizer_state(optimizers_state_dict) + + # 2. 加载模型状态; + if should_load_model: + self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict) + + # # 3. 加载 fp16 的状态 + # if "grad_scaler_state_dict" in states: + # grad_scaler_state_dict = states.pop("grad_scaler_state_dict") + # if not isinstance(self.grad_scaler, DummyGradScaler): + # self.grad_scaler.load_state_dict(grad_scaler_state_dict) + # logger.debug("Load grad_scaler state dict...") + # elif not isinstance(self.grad_scaler, DummyGradScaler): + # logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " + # f"the training process may be unstable.") + + # 4. 恢复 sampler 的状态; + sampler_states = states.pop("sampler_states") + states_ret = self.load_sampler_state(dataloader, sampler_states) + states.update(states_ret) + + return states + + def get_evaluate_context(self): + r""" + :return: 返回 ``oneflow.no_grad`` 这个 context; + """ + return oneflow.no_grad + + def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: + if isinstance(batch, Dict) and not self.wo_auto_param_call: + return auto_param_call(fn, batch, signature_fn=signature_fn) + else: + return fn(batch) + + def get_model_call_fn(self, fn: str) -> Tuple: + if hasattr(self.model, fn): + fn = getattr(self.model, fn) + if not callable(fn): + raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") + logger.debug(f"Use {_get_fun_msg(fn, with_fp=False)}...") + return fn, None + elif fn in {"train_step", "evaluate_step"}: + logger.debug(f"Use {_get_fun_msg(self.model.forward, with_fp=False)}...") + return self.model, self.model.forward + else: + raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") + + @staticmethod + def move_model_to_device(model: "oneflow.nn.Module", device: "oneflow.device"): + r""" + 将模型迁移到对应的设备上; + """ + if device is not None: + model.to(device) + + def move_data_to_device(self, batch): + """ + 将一个 batch 的数据迁移到对应的设备上; + + :param batch: 一个 batch 的数据,可以是 ``list、dict`` 等; + :return: + """ + return oneflow_move_data_to_device(batch, self.data_device) + + @staticmethod + def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover + global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) + process_seed = oneflow.initial_seed() + + base_seed = process_seed - worker_id + ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) + + np.random.seed(ss.generate_state(4)) + + oneflow_ss, stdlib_ss = ss.spawn(2) + oneflow.manual_seed(oneflow_ss.generate_state(1, dtype=np.uint64)[0]) + + stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() + random.seed(stdlib_seed) + + def set_deterministic_dataloader(self, dataloader: "DataLoader"): + if dataloader.worker_init_fn is None: + dataloader.worker_init_fn = partial(self.worker_init_function, + rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))) + + def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx: int): + # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; + if callable(getattr(dataloader.sampler, "set_epoch", None)): + dataloader.sampler.set_epoch(cur_epoch_idx) + + @staticmethod + def get_dataloader_args(dataloader: "DataLoader"): + """ + 获取 dataloader 的 shuffle 和 drop_last 属性; + """ + + @dataclass + class Res: + dataset: Optional[Dataset] = None + batch_sampler: Optional[BatchSampler] = None + sampler: Optional[Sampler] = None + batch_size: Optional[int] = None + shuffle: Optional[bool] = None + drop_last: Optional[bool] = None + + res = Res() + + # oneflow 的 DataLoader 一定会有 dataset 属性; + res.dataset = dataloader.dataset + + # dataloader 使用的是 sampler; + if dataloader.batch_sampler is None: + res.sampler = dataloader.sampler + res.batch_size = 1 + res.shuffle = True if isinstance(dataloader.sampler, RandomSampler) else False + res.drop_last = False + # dataloader 使用的是 batch_sampler; + else: + res.batch_sampler = dataloader.batch_sampler + if hasattr(dataloader.batch_sampler, "batch_size"): + res.batch_size = getattr(dataloader.batch_sampler, "batch_size") + # 用户使用的是自己的 batch_sampler 并且其没有 "batch_size" 属性; + else: + dataloader_iter = iter(dataloader) + pre_sample = next(dataloader_iter) + res.batch_size = pre_sample.shape[0] + + if hasattr(dataloader.batch_sampler, "sampler"): + res.sampler = dataloader.batch_sampler.sampler + if hasattr(dataloader.batch_sampler.sampler, "shuffle"): + res.shuffle = dataloader.batch_sampler.sampler.shuffle + elif isinstance(dataloader.batch_sampler.sampler, OneflowRandomSampler): + res.shuffle = True + else: + res.shuffle = False + # ReproduceBatchSampler 的情况 + elif hasattr(dataloader.batch_sampler, "batch_sampler"): + batch_sampler = dataloader.batch_sampler.batch_sampler + res.sampler = batch_sampler.sampler + if hasattr(batch_sampler.sampler, "shuffle"): + res.shuffle = dataloader.batch_sampler.sampler.shuffle + elif isinstance(batch_sampler.sampler, OneflowRandomSampler): + res.shuffle = True + else: + res.shuffle = False + else: + # 如果 dataloader.batch_sampler 没有 sampler 这个属性,那么说明其使用的是自己的 batch_sampler,且没有 "sampler" 属性; + # 这种情况下 DataLoader 会自己初始化一个 sampler;我们因此将这个默认初始化的 sampler 挂载到 res 上; + res.sampler = dataloader.sampler + res.shuffle = False + + if hasattr(dataloader.batch_sampler, "drop_last"): + res.drop_last = getattr(dataloader.batch_sampler, "drop_last") + # 用户使用的是自己的 batch_sampler 并且其没有 "drop_last" 属性; + else: + res.drop_last = False + + return res diff --git a/fastNLP/core/drivers/oneflow_driver/single_device.py b/fastNLP/core/drivers/oneflow_driver/single_device.py new file mode 100644 index 00000000..84d77d14 --- /dev/null +++ b/fastNLP/core/drivers/oneflow_driver/single_device.py @@ -0,0 +1,114 @@ +import os +from typing import Dict, Union +from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW + +if _NEED_IMPORT_ONEFLOW: + import oneflow + from oneflow.utils.data import SequentialSampler as OneflowSequentialSampler + from oneflow.utils.data import BatchSampler as OneflowBatchSampler + +__all__ = [ + "OneflowSingleDriver" +] + +from .oneflow_driver import OneflowDriver +from fastNLP.core.drivers.oneflow_driver.utils import replace_sampler, replace_batch_sampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \ + ReproduceBatchSampler +from fastNLP.core.samplers import RandomSampler +from fastNLP.core.log import logger + + +class OneflowSingleDriver(OneflowDriver): + r""" + 用于执行 ``oneflow`` 动态图 cpu 和 单卡 gpu 运算的 ``driver``; + + :param model: 传入给 ``Trainer`` 的 ``model`` 参数; + :param device: oneflow.device,当前进程所使用的设备; + :param fp16: 是否开启 fp16;目前动态图的单卡下该参数无效; + :param oneflow_kwargs: + """ + + def __init__(self, model, device: "oneflow.device", fp16: bool = False, oneflow_kwargs: Dict = None, **kwargs): + cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) + if cuda_visible_devices == "": + device = oneflow.device("cpu") + logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to" + "use `cpu` instead of `gpu` device.") + + super(OneflowSingleDriver, self).__init__(model, fp16=fp16, oneflow_kwargs=oneflow_kwargs, **kwargs) + + if device is None: + logger.debug("device is not set, fastNLP will try to automatically get it.") + try: + device = next(model.parameters()).device + assert isinstance(device, oneflow.device) + except: + raise ValueError("fastNLP cannot get device automatically, please set device explicitly.") + + self.model_device = device + + self.local_rank = 0 + self.global_rank = 0 + self.world_size = 1 + + def setup(self): + r""" + 将模型迁移到相应的设备上; + """ + if self.model_device is not None: + self.model.to(self.model_device) + + def set_dist_repro_dataloader(self, dataloader, + dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None, + reproducible: bool = False): + + # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load_checkpoint 函数调用; + if isinstance(dist, ReproducibleBatchSampler): + return replace_batch_sampler(dataloader, dist) + elif isinstance(dist, ReproducibleSampler): + return replace_sampler(dataloader, dist) + + # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; + args = self.get_dataloader_args(dataloader) + if isinstance(args.batch_sampler, ReproducibleBatchSampler): + batch_sampler = re_instantiate_sampler(args.batch_sampler) + return replace_batch_sampler(dataloader, batch_sampler) + elif isinstance(args.sampler, ReproducibleSampler): + sampler = re_instantiate_sampler(args.sampler) + return replace_sampler(dataloader, sampler) + + if reproducible: + if type(args.batch_sampler) is OneflowBatchSampler: + if type(args.sampler) is OneflowSequentialSampler: + # 需要替换为不要 shuffle 的。 + sampler = RandomSampler(args.sampler.data_source, shuffle=False) + logger.debug("Replace oneflow SequentialSampler into fastNLP RandomSampler.") + return replace_sampler(dataloader, sampler) + batch_sampler = ReproduceBatchSampler( + batch_sampler=args.batch_sampler, + batch_size=args.batch_size, + drop_last=args.drop_last + ) + return replace_batch_sampler(dataloader, batch_sampler) + else: + return dataloader + + def unwrap_model(self): + r""" + :return: 返回模型 + """ + return self.model + + @property + def data_device(self): + r""" + :return: 数据和模型所在的设备; + """ + return self.model_device + + def is_distributed(self): + r""" + :return: 返回当前使用的 driver 是否是分布式的 driver,在 ``OneflowSingleDriver`` 中返回 ``False``; + """ + return False diff --git a/fastNLP/core/drivers/oneflow_driver/utils.py b/fastNLP/core/drivers/oneflow_driver/utils.py new file mode 100644 index 00000000..33019883 --- /dev/null +++ b/fastNLP/core/drivers/oneflow_driver/utils.py @@ -0,0 +1,292 @@ +import os + +from typing import Any, Dict, Optional +from enum import IntEnum +import contextlib +import random +import numpy as np +import inspect + +from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW +from fastNLP.envs.utils import get_global_seed +from fastNLP.envs import ( + get_global_rank, + FASTNLP_BACKEND_LAUNCH, + FASTNLP_GLOBAL_SEED, +) +from fastNLP.core.samplers import ReproducibleBatchSampler +from fastNLP.core.utils import auto_param_call +from fastNLP.core.log import logger + +if _NEED_IMPORT_ONEFLOW: + import oneflow + from oneflow.nn import Module + from oneflow.utils.data import DataLoader + from oneflow.utils.data import RandomSampler as oneflowRandomSampler + from oneflow.utils.data import SequentialSampler as oneflowSequentialSampler + from oneflow.utils.data import BatchSampler as oneflowBatchSampler +else: + from fastNLP.core.utils.dummy_class import DummyClass as Module + + +__all__ = [ + 'oneflow_seed_everything', + 'optimizer_state_to_device' +] + +def oneflow_seed_everything(seed: int = None, add_global_rank_to_seed: bool = True) -> int: + r""" + 为 **oneflow**、**numpy**、**python.random** 伪随机数生成器设置种子。 + + :param seed: 全局随机状态的整数值种子。如果为 ``None`` 则会根据时间戳生成一个种子。 + :param add_global_rank_to_seed: 在分布式训练中,是否在不同 **rank** 中使用不同的随机数。 + 当设置为 ``True`` 时,**FastNLP** 会将种子加上当前的 ``global_rank``。 + """ + max_seed_value = np.iinfo(np.uint32).max + min_seed_value = np.iinfo(np.uint32).min + + if seed is None: + if os.getenv(FASTNLP_BACKEND_LAUNCH) == "1": + seed = 42 + else: + seed = get_global_seed() + logger.info(f"'FASTNLP_GLOBAL_SEED' is set to {seed} automatically.") + if not isinstance(seed, int): + seed = int(seed) + + if not (min_seed_value <= seed <= max_seed_value): + logger.rank_zero_warning("Your seed value is too big or too small for numpy, we will choose a random seed for you.") + seed %= max_seed_value + + os.environ[FASTNLP_GLOBAL_SEED] = f"{seed}" + if add_global_rank_to_seed: + seed += get_global_rank() + + random.seed(seed) + np.random.seed(seed) + oneflow.manual_seed(seed) + oneflow.cuda.manual_seed_all(seed) + return seed + + +class ForwardState(IntEnum): + TRAIN = 0 + VALIDATE = 1 + TEST = 2 + PREDICT = 3 + + +class _DDPWrappingModel(Module): + """ + 该函数用于 DDP 训练时处理用户自己定制的 train_step 等函数; + 之所以要使用这一额外的包裹模型,是因为在使用 DDP 时,必须使用 DistributedDataParallel 的 forward 函数才能实现正常的运行; + 另一方面,我们要求用户在使用我们的框架时,需要针对不用的模式实现不同的处理函数,例如 'train_step', 'evaluate_step' 等; + 然而,当使用 DistributedDataParallel 包裹 model 后,模型看不见其除了 forward 之外的方法;并且当我们尝试在训练过程中主动提取 + `model = model.module`,这同样会导致错误,会使得每一个gpu上的模型参数不同; + + 因此出于以上考虑,我们实现了这一函数; + 对于更详细的解释,可以参考 'pytorch_lightning' 的 ddp 的设计; + """ + + def __init__(self, model: Module): + super(_DDPWrappingModel, self).__init__() + self.model = model + + def forward(self, batch, **kwargs) -> Dict: + """ + pytorch lightning 实现了先 unwrapping_model 的操作,但是感觉对于我们来说没有什么必须要,先写个注释放这里,之后有需求了再看; + """ + fn = kwargs.pop("fastnlp_fn") + signature_fn = kwargs.pop("fastnlp_signature_fn") + wo_auto_param_call = kwargs.pop("wo_auto_param_call") + + if isinstance(batch, Dict) and not wo_auto_param_call: + return auto_param_call(fn, batch, signature_fn=signature_fn) + else: + return fn(batch) + + +class DummyGradScaler: + + def __init__(self, *args, **kwargs): + pass + + def get_scale(self): + return 1.0 + + def is_enabled(self): + return False + + def scale(self, outputs): + return outputs + + def step(self, optimizer, *args, **kwargs): + optimizer.step(*args, **kwargs) + + def update(self, new_scale=None): + pass + + def unscale_(self, optimizer): + pass + + def load_state_dict(self, state_dict): + pass + + def state_dict(self): + return {} + + +def _build_fp16_env(dummy=False): + return + if dummy: + autocast = contextlib.ExitStack + GradScaler = DummyGradScaler + else: + if not oneflow.cuda.is_available(): + raise RuntimeError("Oneflow is not installed in gpu version, please use device='cpu'.") + if oneflow.cuda.get_device_capability(0)[0] < 7: + logger.rank_zero_warning( + "NOTE: your device does NOT support faster training with fp16, " + "please switch to FP32 which is likely to be faster" + ) + try: + from oneflow.amp import GradScaler + from oneflow.cuda.amp import autocast, GradScaler + except ImportError: + raise RuntimeError("torch version too low (less than 1.6)") + return autocast, GradScaler + + +def replace_sampler(dataloader: "DataLoader", sampler): + r""" + 替换 sampler (初始化一个新的 dataloader 的逻辑在于): + + 用户可能继承了 dataloader,定制了自己的 dataloader 类,这也是我们为什么先 `inspect.signature(dataloader)` 而不是直接 + `inspect.signature(DataLoader)` 的原因,因此同时注意到我们在外层重新初始化一个 dataloader 时也是使用的用户传进来的 dataloader + 的类,而不是直接的 DataLoader; + + 如果需要定制自己的 dataloader,保证以下两点: + + 1. 在 __init__ 方法中加入 **kwargs,这是为了方便我们将 sampler 插入到具体的 DataLoader 的构造中; + 2. 在 __init__ 方法中出现的参数,请务必挂为同样名字的实例属性,例如 self.one_arg_name = one_arg_name,这是因为我们只能通过属性 + 来获取实际的参数的值; + + """ + + # 拿到实例属性; + instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')} + + # 'multiprocessing_context' 是 user-defined function; + if getattr(dataloader, 'multiprocessing_context', None) is not None: + instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context + + # 拿到 dataloader '__init__' 函数的默认函数签名; + init_params = dict(inspect.signature(dataloader.__init__).parameters) + + # 防止用户的 DataLoader 是继承了 oneflow 的 DataLoader,然后还是使用了 **kwargs 的方式对父类传参数 + has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) + if has_variadic_kwargs and isinstance(dataloader, DataLoader): + # 防止用户写入了 super().__init__(**kwargs) + for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): + if key not in init_params and key != 'self': + init_params[key] = value + + # 如果初始化dataloader所使用的参数不是默认值,那么我们需要将其记录下来用于重新初始化时设置; + non_default_params = {name for name, p in init_params.items() if + name in instance_attrs and p.default != instance_attrs[name]} + # add `dataset` as it might have been replaced with `*args` + non_default_params.add("dataset") + + reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} + if isinstance(dataloader, DataLoader): + reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None}) + + batch_sampler = getattr(dataloader, "batch_sampler") + if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler): + raise RuntimeError("It should not be running here, please report a bug to us.") + + required_args = { + p.name + for p in init_params.values() + if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) + and p.default is p.empty + and p.name not in reconstruct_args + } + + # 在 attribute 中没有找到这些参数,导致了没有办法重新初始化 + if required_args: + required_args = sorted(required_args) + dataloader_self_name = dataloader.__class__.__name__ + raise Exception( + f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. " + f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its " + f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be " + f"`{dataloader_self_name}`'s attribute." + ) + + # 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; + if not has_variadic_kwargs: + # the dataloader signature does not allow keyword arguments that need to be passed + missing_kwargs = reconstruct_args.keys() - init_params.keys() + if missing_kwargs: + missing_kwargs = sorted(missing_kwargs) + dataloader_self_name = dataloader.__class__.__name__ + raise Exception( + f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found." + ) + # 如果没有kwargs,则保证一下只传入需要的参数 + if not isinstance(dataloader, DataLoader): + reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params} + + return type(dataloader)(**reconstruct_args) + + +def replace_batch_sampler(dataloader, new_batch_sampler): + r""" + 替换一个 dataloader 的 batch_sampler; + """ + params_keys = [k for k in dataloader.__dict__.keys() if not k.startswith("_")] + for k in ["batch_size", "sampler", "drop_last", "batch_sampler", "dataset_kind"]: + if k in params_keys: + params_keys.remove(k) + params = {k: getattr(dataloader, k) for k in params_keys} + params["batch_sampler"] = new_batch_sampler + + if not isinstance(dataloader, DataLoader): + init_params = dict(inspect.signature(dataloader.__init__).parameters) + has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) + if not has_variadic_kwargs: + params = {key:value for key,value in params.items() if key in init_params} + + return type(dataloader)(**params) + + +def optimizer_state_to_device(state, device): + r""" + 将一个 ``optimizer`` 的 ``state_dict`` 迁移到对应的设备; + + :param state: ``optimzier.state_dict()``; + :param device: 要迁移到的目的设备; + :return: 返回迁移后的新的 state_dict; + """ + new_state = {} + for name, param in state.items(): + if isinstance(param, dict): + new_state[name] = optimizer_state_to_device(param, device) + elif isinstance(param, oneflow.Tensor): + new_state[name] = param.to(device).clone() + else: + new_state[name] = param + return new_state + + +def _check_dataloader_args_for_distributed(args, controller='Trainer'): + if type(args.batch_sampler) is not oneflowBatchSampler or (type(args.sampler) not in {oneflowRandomSampler, + oneflowSequentialSampler}): + mode = 'training' if controller == 'Trainer' else 'evaluation' + substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler' + raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause " + f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into " + f"``{substitution}``. The customized sampler should set for distributed running " + f"before initializing ``{controller}`` , and then set the " + f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.") diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 98c07495..137aa9db 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -73,6 +73,7 @@ from .utils import ( _FleetWrappingModel, replace_sampler, replace_batch_sampler, + _check_dataloader_args_for_distributed ) from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object @@ -129,15 +130,15 @@ class PaddleFleetDriver(PaddleDriver): :param is_pull_by_paddle_run: 标记当前进程是否为通过 ``python -m paddle.distributed.launch`` 启动的。 这个参数仅在 :class:`~fastNLP.core.Trainer` 中初始化 driver 时使用 :param fp16: 是否开启混合精度训练; + :param paddle_kwargs: + * *fleet_kwargs* -- 用于在使用 ``PaddleFleetDriver`` 时指定 ``DataParallel`` 和 ``fleet`` 初始化时的参数,包括: + + * *is_collective* -- 是否使用 paddle 集群式的分布式训练方法,目前仅支持为 ``True`` 的情况; + * *role_maker* -- 初始化 ``fleet`` 分布式训练 API 时使用的 ``RoleMaker``; + * 其它用于初始化 ``DataParallel`` 的参数; + * *gradscaler_kwargs* -- 用于 ``fp16=True`` 时,提供给 :class:`paddle.amp.GradScaler` 的参数; + :kwargs: - * *paddle_kwargs* -- 用于在指定 ``driver`` 为 'paddle' 时设定具体 driver 实例的一些参数: - - * fleet_kwargs -- 用于在使用 ``PaddleFleetDriver`` 时指定 ``DataParallel`` 和 ``fleet`` 初始化时的参数,包括: - - * is_collective -- 是否使用 paddle 集群式的分布式训练方法,目前仅支持为 ``True`` 的情况; - * role_maker -- 初始化 ``fleet`` 分布式训练 API 时使用的 ``RoleMaker`` - * 其它用于初始化 ``DataParallel`` 的参数; - * wo_auto_param_call (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为; .. note:: @@ -151,11 +152,12 @@ class PaddleFleetDriver(PaddleDriver): parallel_device: Optional[Union[List[str], str]], is_pull_by_paddle_run: bool = False, fp16: bool = False, + paddle_kwargs: Dict = None, **kwargs ): if USER_CUDA_VISIBLE_DEVICES not in os.environ: - raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using FastNLP.") - super(PaddleFleetDriver, self).__init__(model, fp16=fp16, **kwargs) + raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using fastNLP.") + super(PaddleFleetDriver, self).__init__(model, fp16=fp16, paddle_kwargs=paddle_kwargs, **kwargs) # 如果不是通过 launch 启动,要求用户必须传入 parallel_device if not is_pull_by_paddle_run: @@ -193,17 +195,14 @@ class PaddleFleetDriver(PaddleDriver): self.world_size = None self.global_rank = 0 self.gloo_rendezvous_dir = None - - # 分布式环境的其它参数设置 - paddle_kwargs = kwargs.get("paddle_kwargs", {}) - self._fleet_kwargs = paddle_kwargs.get("fleet_kwargs", {}) + self._fleet_kwargs = self._paddle_kwargs.get("fleet_kwargs", {}) check_user_specific_params(self._fleet_kwargs, DataParallel.__init__, DataParallel.__name__) # fleet.init 中对于分布式策略的设置,详情可以参考 PaddlePaddle 的官方文档 self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy()) self.is_collective = self._fleet_kwargs.pop("is_collective", True) if not self.is_collective: - raise NotImplementedError("FastNLP only support `collective` for distributed training now.") + raise NotImplementedError("fastNLP only support `collective` for distributed training now.") self.role_maker = self._fleet_kwargs.pop("role_maker", None) self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error") @@ -422,8 +421,7 @@ class PaddleFleetDriver(PaddleDriver): # trainer, evaluator if dist is None: if reproducible: - raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our " - "control.") + raise RuntimeError("It is not allowed to save checkpoint if the sampler is not allowed to be replaced.") else: args = self.get_dataloader_args(dataloader) if isinstance(args.batch_sampler, ReproducibleBatchSampler): @@ -454,6 +452,7 @@ class PaddleFleetDriver(PaddleDriver): ) return replace_sampler(dataloader, sampler) else: + _check_dataloader_args_for_distributed(args, controller='Trainer') sampler = RandomSampler( dataset=args.dataset, shuffle=args.shuffle, diff --git a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py index 552fc622..e059e91c 100644 --- a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py @@ -38,7 +38,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) if is_in_paddle_launch_dist(): if user_visible_devices is None: - raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using FastNLP.") + raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using fastNLP.") if device is not None: logger.rank_zero_warning("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull " "up your script. And we will directly get the local device via environment variables.", once=True) diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 4527f1ed..0ba0dc1b 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -19,6 +19,7 @@ from fastNLP.envs import ( rank_zero_call, ) from fastNLP.core.log import logger +from fastNLP.core.dataloaders import OverfitDataLoader from fastNLP.core.samplers import ( ReproducibleBatchSampler, ReproducibleSampler, @@ -55,27 +56,32 @@ class PaddleDriver(Driver): 1. :class:`~fastNLP.core.drivers.PaddleSingleDriver`:实现了使用单卡和 ``cpu`` 训练的具体功能; 2. :class:`~fastNLP.core.drivers.PaddleFleetDriver`:实现了使用 ``fleet`` 分布式训练 API 进行集群式分布式训练的具体功能; - :param model: 训练时使用的 **PaddlePaddle** 模型; - :param fp16: 是否开启混合精度训练; - :kwargs: - * wo_auto_param_call (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为; + .. warning:: - .. note:: + 您不应当直接初始化该类,然后传入给 ``Trainer``,换句话说,您应当使用该类的子类 ``PaddleSingleDriver`` 和 ``PaddleDDPDriver``,而不是 + 该类本身; - 关于该参数的详细说明,请参见 :class:`~fastNLP.core.controllers.Trainer` 中的描述;函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`。 + .. note:: + + 您可以在使用 ``PaddleSingleDriver`` 和 ``PaddleFleetDriver`` 时使用 ``PaddleDriver`` 提供的接口; + + :param model: 训练时使用的 **PaddlePaddle** 模型; + :param fp16: 是否开启混合精度训练; + :param paddle_kwargs: """ - def __init__(self, model: "paddle.nn.Layer", fp16: Optional[bool] = False, **kwargs): + def __init__(self, model: "paddle.nn.Layer", fp16: Optional[bool] = False, paddle_kwargs: Dict = None, **kwargs): if not isinstance(model, paddle.nn.Layer): raise ValueError(f"Parameter `model` can not be `{type(model)}` in `PaddleDriver`, it should be exactly " f"`paddle.nn.Layer` type.") super(PaddleDriver, self).__init__(model) self.fp16 = fp16 + self._paddle_kwargs = paddle_kwargs if paddle_kwargs is not None else {} # scaler的参数 self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) - self.grad_scaler = _grad_scaler() + self.grad_scaler = _grad_scaler(**self._paddle_kwargs.get("gradscaler_kwargs", {})) # 用来设置是否关闭 auto_param_call 中的参数匹配问题; self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) @@ -93,7 +99,7 @@ class PaddleDriver(Driver): self.grad_scaler.update() def check_dataloader_legality(self, dataloader): - if not isinstance(dataloader, DataLoader): + if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader): raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") if dataloader.batch_size is None and dataloader.batch_sampler is None: raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler" @@ -154,7 +160,7 @@ class PaddleDriver(Driver): :param only_state_dict: 是否只保存模型的 ``state_dict``;如果为 ``False``,则会调用 ``paddle.jit.save`` 函数保存整个模型的参数,此时需要传入 ``input_spec`` 参数; :kwargs: - * input_spec -- 描述存储模型 ``forward`` 方法的输入; + * *input_spec* -- 描述存储模型 ``forward`` 方法的输入; 当 ``only_state_dict`` 为 ``False`` 时必须传入,否则加载时会报错。您可以通过 ``InputSpec`` 或者示例 ``Tensor`` 进行描述。详细的使用方法可以参考 **PaddlePaddle** `关于 paddle.jit.save 函数的文档 `_; """ @@ -222,26 +228,12 @@ class PaddleDriver(Driver): num_consumed_batches = states.pop("num_consumed_batches") if hasattr(sampler, "state_dict") and callable(sampler.state_dict): sampler_states = sampler.state_dict() - # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples - # 会造成多余实际消耗的问题。 - num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) - if num_consumed_samples_array is not None: - if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 - if dataloader_args.batch_size is not None: - num_consumed_batches = num_consumed_batches * dataloader_args.batch_size - else: # 有可能 batch_size 为 None,就只有损失精度了 - logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " - "it may cause missing some samples when reload.") - num_consumed_batches = sampler_states['num_consumed_samples'] - sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] - assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." + if dataloader_args.batch_size is not None: + sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ + * num_consumed_batches else: - if dataloader_args.batch_size is not None: - sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ - * num_consumed_batches - else: - logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " - "it may cause missing some samples when reload.") + logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " + "it may cause missing some samples when reload.") else: raise RuntimeError( "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index ba404814..86994b79 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -26,6 +26,11 @@ if _NEED_IMPORT_PADDLE: import paddle from paddle import DataParallel from paddle.fluid.reader import _DatasetKind + from paddle.io import ( + RandomSampler as PaddleRandomSampler, + SequenceSampler as PaddleSequenialSampler, + BatchSampler as PaddleBatchSampler, + ) __all__ = [ "PaddleSingleDriver", @@ -38,6 +43,8 @@ class PaddleSingleDriver(PaddleDriver): :param model: 训练时使用的 **PaddlePaddle** 模型; :param device: 训练使用的设备; :param fp16: 是否开启混合精度训练; + :param paddle_kwargs: + * *gradscaler_kwargs* -- 用于 ``fp16=True`` 时,提供给 :class:`paddle.amp.GradScaler` 的参数; :kwargs: * wo_auto_param_call (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为; @@ -46,7 +53,7 @@ class PaddleSingleDriver(PaddleDriver): 关于该参数的详细说明,请参见 :class:`~fastNLP.core.controllers.Trainer` 中的描述;函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`。 """ - def __init__(self, model: "paddle.nn.Layer", device: Union[str, int], fp16: Optional[bool] = False, **kwargs): + def __init__(self, model: "paddle.nn.Layer", device: Union[str, int], fp16: Optional[bool] = False, paddle_kwargs: Dict = None, **kwargs): if isinstance(model, DataParallel): raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`") @@ -56,7 +63,7 @@ class PaddleSingleDriver(PaddleDriver): logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to" "use `cpu` instead of `gpu` device.") - super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs) + super(PaddleSingleDriver, self).__init__(model, fp16=fp16, paddle_kwargs=paddle_kwargs, **kwargs) if device is None: raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.") @@ -122,19 +129,21 @@ class PaddleSingleDriver(PaddleDriver): return replace_sampler(dataloader, sampler) if reproducible: - if isinstance(args.sampler, paddle.io.RandomSampler): - if getattr(args.sampler, '_num_samples', None) is None \ - and getattr(args.sampler, 'replacements', False) is False \ - and getattr(args.sampler, 'generator', None) is None: - # 如果本来就是随机的,并且没有定制,直接替换掉。 - sampler = RandomSampler(args.sampler.data_source, shuffle=True) - logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") + if type(args.batch_sampler) is PaddleBatchSampler: + if type(args.sampler) is PaddleRandomSampler: + if isinstance(args.sampler, PaddleRandomSampler): + if getattr(args.sampler, '_num_samples', None) is None \ + and getattr(args.sampler, 'replacements', False) is False \ + and getattr(args.sampler, 'generator', None) is None: + # 如果本来就是随机的,并且没有定制,直接替换掉。 + sampler = RandomSampler(args.sampler.data_source, shuffle=True) + logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") + return replace_sampler(dataloader, sampler) + elif type(args.sampler) is PaddleSequenialSampler: + # 需要替换为不要 shuffle 的。 + sampler = RandomSampler(args.sampler.data_source, shuffle=False) + logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.") return replace_sampler(dataloader, sampler) - elif isinstance(args.sampler, paddle.io.SequenceSampler): - # 需要替换为不要 shuffle 的。 - sampler = RandomSampler(args.sampler.data_source, shuffle=False) - logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.") - return replace_sampler(dataloader, sampler) batch_sampler = ReproduceBatchSampler( batch_sampler=args.batch_sampler, batch_size=args.batch_size, diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index b1815fbd..be83e5fe 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -15,6 +15,7 @@ from fastNLP.envs import ( FASTNLP_BACKEND_LAUNCH, FASTNLP_GLOBAL_SEED, ) +from fastNLP.core.samplers import ReproducibleBatchSampler from fastNLP.core.utils import auto_param_call, paddle_to from fastNLP.core.log import logger @@ -23,7 +24,7 @@ if _NEED_IMPORT_PADDLE: import paddle from paddle import nn from paddle.nn import Layer - from paddle.io import DataLoader, BatchSampler + from paddle.io import DataLoader, BatchSampler, RandomSampler, SequenceSampler from paddle.amp import auto_cast, GradScaler else: from fastNLP.core.utils.dummy_class import DummyClass as Layer @@ -129,7 +130,7 @@ def _build_fp16_env(dummy=False): "NOTE: your device does NOT support faster training with fp16, " "please switch to FP32 which is likely to be faster" ) - return auto_cast, GradScaler + return auto_cast, GradScaler def find_free_ports(num): """ @@ -178,23 +179,22 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler # 中寻找;VAR_KEYWORD 代表 **kwargs has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) if has_variadic_kwargs: - init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) - del init_params["self"] + for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): + if key not in init_params and key != 'self': + init_params[key] = value - # 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍; - # 将同时在实例名和参数名中出现且不是默认值的参数收集起来 + # 如果初始化dataloader所使用的参数不是默认值,那么我们需要将其记录下来用于重新初始化时设置; non_default_params = {name for name, p in init_params.items() if name in instance_attrs and p.default != instance_attrs[name]} # add `dataset` as it might have been replaced with `*args` non_default_params.add("dataset") - # 收集不是默认值的参数和它的值 reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} - # persistent_workers 在类中的对应成员带有下划线,因此添加进来 - reconstruct_args.update({ - "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, - "persistent_workers": dataloader._persistent_workers, - }) + if isinstance(dataloader, DataLoader): + reconstruct_args.update({ + "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, + "persistent_workers": dataloader._persistent_workers, + }) # POSITIONAL_OR_KEYWORD 代表一般的参数 # 收集初始化函数中出现的、一般形式的、不带默认值且不在 reconstruct_args 中的参数 @@ -212,9 +212,10 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler required_args = sorted(required_args) dataloader_self_name = dataloader.__class__.__name__ raise Exception( - f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " - "This would fail as some of the `__init__` arguments are not available as instance attributes. " - f"The missing attributes are {required_args}. " + f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. " + f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its " + f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be " + f"`{dataloader_self_name}`'s attribute." ) # 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; @@ -226,10 +227,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler missing_kwargs = sorted(missing_kwargs) dataloader_self_name = dataloader.__class__.__name__ raise Exception( - f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " - "This would fail as it doesn't expose all its attributes in the `__init__` signature. " - f"The missing arguments are {missing_kwargs}. " + f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found." ) + # 如果没有kwargs,则保证一下只传入需要的参数 + if not isinstance(dataloader, DataLoader): + reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params} return type(dataloader)(**reconstruct_args) @@ -237,6 +239,9 @@ def replace_sampler(dataloader, new_sampler): """ 使用 ``new_sampler`` 重新构建一个 ``BatchSampler``,并替换到 ``dataloader`` 中 """ + batch_sampler = getattr(dataloader, "batch_sampler") + if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler): + raise RuntimeError("It should not be running here, please report a bug to us.") new_batch_sampler = deepcopy(dataloader.batch_sampler) new_batch_sampler.sampler = new_sampler return replace_batch_sampler(dataloader, new_batch_sampler) @@ -251,3 +256,14 @@ def optimizer_state_to_device(state, device): else: new_state[name] = param return new_state + +def _check_dataloader_args_for_distributed(args, controller='Trainer'): + if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {RandomSampler, + SequenceSampler}): + mode = 'training' if controller == 'Trainer' else 'evaluation' + substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler' + raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause " + f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into " + f"``{substitution}``. The customized sampler should set for distributed running " + f"before initializing ``{controller}`` , and then set the " + f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.") diff --git a/fastNLP/core/drivers/torch_driver/__init__.py b/fastNLP/core/drivers/torch_driver/__init__.py index 8c24fa53..08026d9e 100644 --- a/fastNLP/core/drivers/torch_driver/__init__.py +++ b/fastNLP/core/drivers/torch_driver/__init__.py @@ -1,6 +1,7 @@ __all__ = [ 'TorchDDPDriver', 'TorchSingleDriver', + 'DeepSpeedDriver', 'TorchDriver', 'torch_seed_everything', 'optimizer_state_to_device' @@ -10,6 +11,7 @@ from .ddp import TorchDDPDriver # todo 实现 fairscale 后再将 fairscale 导入到这里; from .single_device import TorchSingleDriver from .torch_driver import TorchDriver +from .deepspeed import DeepSpeedDriver from .utils import torch_seed_everything, optimizer_state_to_device diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index 43c6bc36..28670071 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -159,6 +159,7 @@ from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, Unrepeated from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC from fastNLP.core.log import logger from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object +from .utils import _check_dataloader_args_for_distributed class TorchDDPDriver(TorchDriver): @@ -234,7 +235,12 @@ class TorchDDPDriver(TorchDriver): :param parallel_device: 用于分布式训练的 ``gpu`` 设备; :param is_pull_by_torch_run: 标志当前的脚本的启动是否由 ``python -m torch.distributed.launch`` 启动的; :param fp16: 是否开启 fp16 训练; - :param kwargs: 其余的一些用于设定 ddp 训练的参数; + :param torch_kwargs: + * *ddp_kwargs* -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入 + {'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等; + * *set_grad_to_none* -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; + * *non_blocking* -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; + * *gradscaler_kwargs* -- 用于 fp16=True 时,提供给 ``torch.amp.cuda.GradScaler`` 的参数; """ def __init__( @@ -243,11 +249,12 @@ class TorchDDPDriver(TorchDriver): parallel_device: Optional[Union[List["torch.device"], "torch.device"]], is_pull_by_torch_run: bool = False, fp16: bool = False, + torch_kwargs: Dict = None, **kwargs ): # 在加入很多东西后,需要注意这里调用 super 函数的位置; - super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs) + super(TorchDDPDriver, self).__init__(model, fp16=fp16, torch_kwargs=torch_kwargs, **kwargs) if isinstance(model, torch.nn.DataParallel): raise ValueError(f"Parameter `model` can not be `DataParallel` in `TorchDDPDriver`, it should be " @@ -417,6 +424,7 @@ class TorchDDPDriver(TorchDriver): os.environ['MASTER_ADDR'] = self.master_address os.environ['MASTER_PORT'] = self.master_port + os.environ["RANK"] = "0" os.environ["LOCAL_RANK"] = str(self.local_rank) os.environ["WORLD_SIZE"] = f"{self.world_size}" @@ -429,6 +437,7 @@ class TorchDDPDriver(TorchDriver): for rank in range(1, len(self.parallel_device)): env_copy = os.environ.copy() env_copy["LOCAL_RANK"] = f"{rank}" + env_copy["RANK"] = f"{rank}" # 如果是多机,一定需要用户自己拉起,因此我们自己使用 open_subprocesses 开启的进程的 FASTNLP_GLOBAL_RANK 一定是 LOCAL_RANK; env_copy[FASTNLP_GLOBAL_RANK] = str(rank) @@ -535,8 +544,7 @@ class TorchDDPDriver(TorchDriver): # trainer, evaluator if dist is None: if reproducible: - raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " - "control.") + raise RuntimeError("It is not allowed to save checkpoint if the sampler is not allowed to be replaced.") else: args = self.get_dataloader_args(dataloader) if isinstance(args.batch_sampler, ReproducibleBatchSampler): @@ -565,6 +573,7 @@ class TorchDDPDriver(TorchDriver): ) return replace_sampler(dataloader, sampler) else: + _check_dataloader_args_for_distributed(args, controller='Trainer') sampler = RandomSampler( dataset=args.dataset, shuffle=args.shuffle, @@ -582,6 +591,7 @@ class TorchDDPDriver(TorchDriver): if isinstance(args.sampler, ReproducibleSampler): sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) elif not isinstance(args.sampler, UnrepeatedSampler): + _check_dataloader_args_for_distributed(args, controller='Evaluator') sampler = UnrepeatedSequentialSampler( dataset=args.dataset ) diff --git a/fastNLP/core/drivers/torch_driver/deepspeed.py b/fastNLP/core/drivers/torch_driver/deepspeed.py new file mode 100644 index 00000000..2fc6e96e --- /dev/null +++ b/fastNLP/core/drivers/torch_driver/deepspeed.py @@ -0,0 +1,445 @@ +import os +import argparse +import logging +from pathlib import Path + +from typing import Union, Dict, List +from .torch_driver import TorchDriver +from .ddp import TorchDDPDriver +from .utils import _create_default_config, _DeepSpeedWrappingModel +from fastNLP.core.utils import nullcontext +from fastNLP.core.log import logger +from fastNLP.envs import( + FASTNLP_DISTRIBUTED_CHECK, + FASTNLP_CHECKPOINT_FILENAME +) +from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED + +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed as dist + from torch.optim import Optimizer + +if _NEED_IMPORT_DEEPSPEED: + import deepspeed + from deepspeed import DeepSpeedEngine, DeepSpeedOptimizer + +__all__ = [ + "DeepSpeedDriver", +] + +class DeepSpeedDriver(TorchDDPDriver): + """ + 实现 ``deepspeed`` 分布式训练的 ``Driver``。 + + .. note:: + + 您在绝大多数情况下不需要自己使用到该类,通过向 ``Trainer`` 传入正确的参数,您可以方便快速地部署您的分布式训练; + + ``DeepSpeedDriver`` 目前支持的三种启动方式: + + 1. 用户自己不进行任何操作,直接使用我们的 ``Trainer``,这时是由我们自己使用 ``open_subprocesses`` 拉起多个进程, + 然后 ``DeepSpeedDriver`` 自己通过调用 ``deepspeed.initialize`` 来初始化模型和同心组;(情况 A) + + .. code-block:: + + trainer = Trainer( + ... + driver='deepspeed', + device=[0, 1] + ) + trainer.run() + + 通过运行 ``python train.py`` 启动; + + 2. 用户同样不在 ``Trainer`` 之外初始化 ``deepspeed``,但是用户自己使用 ``python -m torch.distributed.launch`` 拉起来创建多个进程,这时我们仍旧 + 会通过调用 ``model.initialize`` 来初始化 ``ddp`` 的通信组;(情况 B) + + .. code-block:: + + trainer = Trainer( + ... + driver='deepspeed', + device=None + ) + trainer.run() + + 通过运行 ``deepspeed train.py`` 启动; + + 3. 用户自己在外面初始化 ``deepspeed``,并且通过 ``deepspeed train.py`` 拉起,这时无论是多个进程的拉起和通信组的建立 + 都由用户自己操作,我们只会在 ``driver.setup`` 的时候对 ``DeepSpeedDriver`` 设置一些必要的属性值;(情况 C) + + .. code-block:: + + import deepspeed + + # 初始化 + model, _, _, _ = deepspeed.initialize(model, ...) + + trainer = Trainer( + ... + driver='deepspeed', + device=None + ) + trainer.run() + + 通过运行 ``deepspeed train.py`` 启动; + + :param model: 传入给 ``Trainer`` 的 ``model`` 参数; + :param parallel_device: 用于分布式训练的 ``gpu`` 设备; + :param is_pull_by_torch_run: 标志当前的脚本的启动是否由 ``python -m torch.distributed.launch`` 启动的; + :param fp16: 是否开启 fp16 训练; + :param deepspeed_kwargs: + * *strategy* -- 使用 ZeRO 优化的策略,默认为 ``deepspeed``;目前仅支持以下值: + + * ``deepspeed`` -- 使用 ZeRO 的第二阶段,等同于 ``deepspeed_stage_2``; + * ``deepspeed_stage_1`` -- 使用 ZeRO 的第一阶段,仅将 ``optimizer`` 的状态分散到不同设备上; + * ``deepspeed_stage_2`` -- 使用 ZeRO 的第二阶段,将 ``optimizer`` 和**梯度**分散到不同设备上; + * ``deepspeed_stage_2_offload`` -- 使用 ZeRO 的第二阶段,并且借助 cpu 的内存来进一步节约显存; + * ``deepspeed_stage_3`` -- 使用 ZeRO 的第三阶段,将 ``optimizer`` 、**梯度**和**模型**分散到不同设备上; + * ``deepspeed_stage_3_offload`` -- 使用 ZeRO 的第三阶段,并且借助 cpu 的内存来进一步节约显存; + * ``deepspeed_stage_3_offload_nvme`` -- 使用 ZeRO 的第三阶段,并且借助 NVMe 硬盘来进一步节约显存; + * *logging_level* -- ``deepspeed`` 库的日志等级,默认为 **logging.ERROR**; + * *config* -- ``deepspeed`` 的各项设置;**FastNLP** 允许用户传入自己的设置以增强灵活性,但这会使参数 + 中的 ``optimizer`` 、``strategy`` 、 ``fp16`` 等失效,即当这个参数存在时,**FastNLP** 会用该参数覆盖 + 其它的设置; + """ + # TODO fp16 load_config + def __init__( + self, + model, + parallel_device: Union[List["torch.device"], "torch.device"], + is_pull_by_torch_run = False, + fp16: bool = False, + deepspeed_kwargs: Dict = None, + **kwargs + ): + assert _NEED_IMPORT_DEEPSPEED, "Deepspeed is not imported." + kwargs.pop("torch_kwargs", None) + self._ds_kwargs = deepspeed_kwargs + TorchDriver.__init__(self, model=model, fp16=False, torch_kwargs=deepspeed_kwargs, **kwargs) + self.fp16 = fp16 + + # 如果用户自己在外面初始化 DDP,那么其一定是通过 python -m torch.distributed.launch 拉起的; + self.is_pull_by_torch_run = is_pull_by_torch_run + self.parallel_device = parallel_device + if not is_pull_by_torch_run and parallel_device is None: + raise ValueError( + "Parameter `parallel_device` can not be None when using `TorchDeepSpeedDriver`. This error is caused " + "when your value of parameter `device` is `None` in your `Trainer` instance.") + + # 注意我们在 initialize_torch_driver 中的逻辑就是如果是 is_pull_by_torch_run,那么我们就直接把 parallel_device 置为当前进程的gpu; + if is_pull_by_torch_run: + self.model_device = parallel_device + else: + # 我们的 model_device 一定是 torch.device,而不是一个 list; + self.model_device = parallel_device[self.local_rank] + + # 如果用户自己在外面初始化了 deepspeed; + self.outside_ddp = False + if dist.is_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and \ + "fastnlp_torch_launch_not_ddp" not in os.environ: + # 如果用户自己在外面初始化了 deepspeed,那么我们要求用户传入的模型一定是已经由 DeepSpeedEngine 包裹后的模型; + if not isinstance(model, DeepSpeedEngine): + raise RuntimeError( + "It is not allowed to input a normal model instead of `DeepSpeedEngine` when" + "you initialize the ddp process out of our control.") + + self.outside_ddp = True + self.config = model.config + self.model_device = None + + self._data_device = kwargs.get("data_device", None) + if isinstance(self._data_device, int): + if self._data_device < 0: + raise ValueError("Parameter `data_device` can not be smaller than 0.") + _could_use_device_num = torch.cuda.device_count() + if self._data_device >= _could_use_device_num: + raise ValueError("The gpu device that parameter `device` specifies is not existed.") + self._data_device = torch.device(f"cuda:{self._data_device}") + elif isinstance(self._data_device, str): + self._data_device = torch.device(self._data_device) + elif self._data_device is not None and not isinstance(self._data_device, torch.device): + raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") + + self._master_port = None + # world_size 表示的就是全局的显卡的数量; + self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device) + self.global_rank = 0 + + self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error") + assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type." + if self.output_from_new_proc not in {"all", "ignore", "only_error"}: + os.makedirs(name=self.output_from_new_proc, exist_ok=True) + self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) + + self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; + self._has_ddpwrapped = False # 判断传入的模型是否经过 _has_ddpwrapped 包裹; + self.accumulation_steps = kwargs.get("accumulation_steps", 1) + # 获取 batch_size 以设置 train_micro_batch_size_per_gpu 参数 + train_dl = kwargs.get("train_dataloader", None) + if train_dl is not None: + self.train_micro_batch_size = self.get_dataloader_args(train_dl).batch_size + else: + logger.warning("No `train_dataloader` found, and we will set `train_micro_batch_size_per_gpu`" + "to 1 for deepspeed configuration.") + self.train_micro_batch_size = 1 + + self.strategy = self._ds_kwargs.get("strategy", "deepspeed") + deepspeed_logging_level = self._ds_kwargs.get("logging_level", logging.ERROR) + deepspeed.utils.logging.logger.setLevel(deepspeed_logging_level) + + @staticmethod + def _check_optimizer_legality(optimizers): + for each_optimizer in optimizers: + if not isinstance(each_optimizer, (Optimizer, DeepSpeedOptimizer)): + raise TypeError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' or " + f"'DeepSpeedOptimizer'type, not {type(each_optimizer)}.") + + def setup(self): + r""" + 准备分布式环境,该函数主要做以下两件事情: + + 1. 开启多进程,每个 gpu 设备对应单独的一个进程; + 2. 使用 ``deepspeed.initialize`` 包裹模型; + """ + if len(self.optimizers) != 1: + raise ValueError("Multi optimizers is not supported for `DeepSpeedDriver` right now.") + if self._has_setup: + return + self._has_setup = True + self.setup_config() + # 如果用户需要使用多机模式,那么一定进入到这里; + if self.is_pull_by_torch_run: + if self.outside_ddp: + self.world_size = dist.get_world_size() + self.global_rank = dist.get_rank() + else: + # dist.get_world_size() 只能在 dist.init_process_group 初始化之后进行调用; + self.world_size = int(os.environ.get("WORLD_SIZE")) + self.global_rank = int(os.environ.get("RANK")) + logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}") + + if not dist.is_initialized(): + deepspeed.init_distributed("nccl", distributed_port=self.master_port) + + os.environ["fastnlp_torch_launch_not_ddp"] = "yes" + + # 进入到这里的情况时: + # dist.is_initialized 一定为 False; + # 一定是单机; + # self.parallel_device 一定是 List[torch.device]; + else: + if not dist.is_initialized(): + # 这里主要的问题在于要区分 rank0 和其它 rank 的情况; + self.world_size = len(self.parallel_device) + self.open_subprocess() + self.global_rank = self.local_rank # rank 一定是通过环境变量去获取的; + deepspeed.init_distributed("nccl", distributed_port=self.master_port) + # 用户在这个 trainer 前面又初始化了一个 trainer,并且使用的是 DeepSpeedDriver; + else: + # 如果 `dist.is_initialized() == True`,那么说明 DeepSpeedDriver 在之前已经初始化并且已经 setup 过一次,那么我们需要保证现在 + # 使用的(即之后的)DeepSpeedDriver 的设置和第一个 DeepSpeedDriver 是完全一样的; + pre_num_processes = int(os.environ[FASTNLP_DISTRIBUTED_CHECK]) + if pre_num_processes != len(self.parallel_device): + raise RuntimeError( + "Notice you are using `DeepSpeedDriver` after one instantiated `DeepSpeedDriver`, it is not" + "allowed that your second `DeepSpeedDriver` has a new setting of parameters " + "`num_nodes` and `num_processes`.") + self.world_size = dist.get_world_size() + self.global_rank = dist.get_rank() + + if not self.outside_ddp: + torch.cuda.set_device(self.model_device) + # 不加 dist.broadcast_object_list 会发生设备在 4,5 但是模型会同步到 0,1 的情况 + # 原因未知 + dist.broadcast_object_list(["test"], 0, None) + self.configure_ddp() + + self.barrier() + # 初始化 self._pids,从而使得每一个进程都能接受到 rank0 的 send 操作; + self._pids = [torch.tensor(0, dtype=torch.int).to(self.data_device) for _ in range(dist.get_world_size())] + dist.all_gather(self._pids, torch.tensor(os.getpid(), dtype=torch.int).to(self.data_device)) + local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None + if local_world_size is None: + local_world_size = torch.tensor(int(os.environ.get("LOCAL_RANK")), dtype=torch.int).to(self.data_device) + dist.all_reduce(local_world_size, op=dist.ReduceOp.MAX) + local_world_size = local_world_size.tolist() + 1 + + node_rank = self.global_rank // local_world_size + self._pids = self._pids[node_rank * local_world_size: (node_rank + 1) * local_world_size] + self._pids = self.tensor_to_numeric(self._pids) + + def configure_ddp(self): + + # 设置 deepspeed + if not isinstance(self.model, DeepSpeedEngine): + model=_DeepSpeedWrappingModel(self.model, self.fp16) + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + self.model, ds_optimizer, _, _ = deepspeed.initialize( + args=argparse.Namespace(device_rank=self.model_device.index), + model=model, + optimizer=self.optimizers[0], + model_parameters=model_parameters, + config=self.config, + dist_init_required=False + ) + self._optimizers = [ds_optimizer] + + if self.config.get("activation_checkpointing"): + checkpoint_config = self.config["activation_checkpointing"] + deepspeed.checkpointing.configure( + mpu_=None, + partition_activations=checkpoint_config.get("partition_activations"), + contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"), + checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"), + profile=checkpoint_config.get("profile"), + ) + + self._has_ddpwrapped = True + + def setup_config(self): + + self.config = self._ds_kwargs.get("config") + if self.config is not None: + logger.warning("Notice that you have defined a configuration for deepspeed and parameters like" + "`optimizers`, `strategy` and `fp16` may not take effects.") + return + + if self.strategy == "deepspeed": + self.config = _create_default_config(stage=2) + elif self.strategy == "deepspeed_stage_1": + self.config = _create_default_config(stage=1) + elif self.strategy == "deepspeed_stage_2": + self.config = _create_default_config(stage=2) + elif self.strategy == "deepspeed_stage_2_offload": + self.config = _create_default_config(stage=2, offload_optimizer=True) + elif self.strategy == "deepspeed_stage_3": + self.config = _create_default_config(stage=3) + elif self.strategy == "deepspeed_stage_3_offload": + self.config = _create_default_config( + stage=3, + offload_optimizer=True, + offload_parameters=True, + ) + elif self.strategy == "deepspeed_stage_3_offload_nvme": + self.config = _create_default_config( + stage=3, + offload_optimizer=True, + offload_parameters=True, + remote_device="nvme", + offload_params_device="nvme", + offload_optimizer_device="nvme", + ) + else: + raise ValueError(f"Unknown deepspeed strategy {self.strategy}.") + + # 设置成 max_int 防止 deepspeed 的输出干扰 fastnlp 的输出 + self.config.setdefault("steps_per_print", 2147483647) + self.config["gradient_accumulation_steps"] = self.accumulation_steps + self.config.setdefault("train_micro_batch_size_per_gpu", self.train_micro_batch_size) + + if self.fp16: + if "fp16" not in self.config: + # FP16 is a DeepSpeed standalone AMP implementation + logger.debug("Enabling DeepSpeed FP16.") + # TODO 这部分是否可以像 pytorch-lightning 那样给用户定制 + self.config["fp16"] = { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": True, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1, + } + elif "amp" not in self.config: + logger.debug("Enabling DeepSpeed APEX Implementation.") + self.config["amp"] = {"enabled": True, "opt_level": "O1"} + + def zero_grad(self): + # DeepSpeedEngine.step 包含了 zero_grad 功能 + pass + + def backward(self, loss): + self.model.backward(loss) + + def step(self): + self.model.step() + + def get_model_no_sync_context(self): + r""" + :return: 返回一个 ``context`` 上下文环境,用于关闭各个进程之间的同步;在 ``deepspeed`` 中,返回一个空的上下文 + """ + # 注意此时的 model 是 "DistributedDataParallel" 对象; + return nullcontext + + def save_model(self, filepath: Union[str, Path], only_state_dict: bool = False, **kwargs): + """ + 保存当前 driver 的模型到 folder 下。 + + :param filepath: 保存到哪个文件夹; + :param only_state_dict: 是否只保存权重;在 ``DeepSpeedDriver`` 中该参数无效; + :return: + """ + # deepspeed engine 要求在每个 rank 都调用 save_checkpoint,故去掉了 rank_zero_call 装饰器 + if self.stage_3: + logger.rank_zero_warning( + "When saving the DeepSpeed Stage 3 checkpoint, " + "each worker will save a shard of the checkpoint within a directory. " + # TODO check一下 + # "If a single file is required after training, " + # "see https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#" + # "deepspeed-zero-stage-3-single-file for instructions." + ) + if not only_state_dict: + logger.rank_zero_warning("Only saving state dict is not allowed for `DeepSpeedDriver`. We will save its " + "checkpoint for you instead.") + self.model.save_checkpoint(filepath, **kwargs) + + def load_model(self, filepath: Union[Path, str], only_state_dict: bool = False, **kwargs): + """ + 从 folder 中加载权重并赋值到当前 driver 的模型上。 + + :param filepath: 加载权重或模型的路径 + :param load_state_dict: 保存的内容是否只是权重;在 ``DeepSpeedDriver`` 中该参数无效; + :param kwargs: + :return: + """ + if not only_state_dict: + logger.warning("Only loading state dict is not allowed for `DeepSpeedDriver`. We will load its " + "checkpoint for you instead.") + self.model.load_checkpoint(filepath, **kwargs) + + def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): + # deepspeed engine 要求在每个 rank 都调用 save_checkpoint,故去掉了 rank_zero_call 装饰器 + # 1. 保存 sampler 的状态 + num_consumed_batches = states.pop('num_consumed_batches') + states['sampler_states'] = self.get_sampler_state(dataloader, num_consumed_batches) + + # 2. 保存模型的状态; + if not should_save_model: + logger.rank_zero_warning("Saving checkpoint without model is not allowed for `DeepSpeedDriver`, " + "so we will still save the model for you.") + + self.model.save_checkpoint(Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME), + client_state=states) + + def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: + # 1. 加载模型状态; + if not should_load_model: + logger.rank_zero_warning("Loading checkpoint without model is not allowed for `DeepSpeedDriver`, " + "so we will still load the model for you.") + load_path, states = self.model.load_checkpoint(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) + if load_path is None: + raise RuntimeError(f"Failed to load checkpoint from path: {str(folder)}") + + # 2.恢复 sampler 的状态 + sampler_states = states.pop('sampler_states') + states_ret = self.load_sampler_state(dataloader, sampler_states) + states.update(states_ret) + + return states + + @property + def stage_3(self) -> bool: + return self.config.get("zero_optimization") and self.config.get("zero_optimization").get("stage") == 3 \ No newline at end of file diff --git a/fastNLP/core/drivers/torch_driver/fairscale.py b/fastNLP/core/drivers/torch_driver/fairscale.py index 02dda6a6..304f0bfa 100644 --- a/fastNLP/core/drivers/torch_driver/fairscale.py +++ b/fastNLP/core/drivers/torch_driver/fairscale.py @@ -35,11 +35,12 @@ class FairScaleDriver(TorchDDPDriver): parallel_device: Union[List["torch.device"], "torch.device"], is_pull_by_torch_run = False, fp16: bool = False, + fairscale_kwargs: Dict = None, **kwargs ): assert _NEED_IMPORT_FAIRSCALE, "fairscale is not imported." assert not dist.is_initialized(), "FairScaleDriver does not support initialize distributed by user." - self._fairscale_kwargs = kwargs.get('fairscale_kwargs', {}) + self._fairscale_kwargs = fairscale_kwargs self.fs_type = self._fairscale_kwargs.get('fs_type', 'sdp') # ddp, sdp, fsdp if self.fs_type == 'fsdp': self._fairscale_kwargs['set_grad_to_none'] = self._fairscale_kwargs.get('set_grad_to_none', True) diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py index 0deac4dc..f242b813 100644 --- a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py @@ -8,6 +8,7 @@ from .torch_driver import TorchDriver from .single_device import TorchSingleDriver from .ddp import TorchDDPDriver from .fairscale import FairScaleDriver +from .deepspeed import DeepSpeedDriver from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_BACKEND_LAUNCH from pkg_resources import parse_version @@ -20,7 +21,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi r""" 用来根据参数 ``driver` 和 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去; - :param driver: 该参数的值应为以下之一:``["torch", "fairscale"]``; + :param driver: 该参数的值应为以下之一:``["torch", "fairscale", "deepspeed"]``; :param device: 该参数的格式与 ``Trainer`` 对参数 ``device`` 的要求一致; :param model: 训练或者评测的具体的模型; @@ -37,11 +38,14 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi if driver == 'fairscale': return FairScaleDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), is_pull_by_torch_run=True, **kwargs) + elif driver == 'deepspeed': + return DeepSpeedDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), + is_pull_by_torch_run=True, **kwargs) else: return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), is_pull_by_torch_run=True, **kwargs) - if driver not in {"torch", "fairscale"}: + if driver not in {"torch", "fairscale", "deepspeed"}: raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].") _could_use_device_num = torch.cuda.device_count() @@ -83,4 +87,12 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.") return FairScaleDriver(model, [device], **kwargs) else: - return FairScaleDriver(model, device, **kwargs) \ No newline at end of file + return FairScaleDriver(model, device, **kwargs) + elif driver == "deepspeed": + if not isinstance(device, List): + if device.type == 'cpu': + raise ValueError("You are using `deepspeed` driver, but your chosen `device` is 'cpu'.") + logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.") + return DeepSpeedDriver(model, [device], **kwargs) + else: + return DeepSpeedDriver(model, device, **kwargs) \ No newline at end of file diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py index c36e0f8d..483dc257 100644 --- a/fastNLP/core/drivers/torch_driver/single_device.py +++ b/fastNLP/core/drivers/torch_driver/single_device.py @@ -8,6 +8,7 @@ if _NEED_IMPORT_TORCH: from torch.nn.parallel import DistributedDataParallel from torch.utils.data import RandomSampler as TorchRandomSampler from torch.utils.data import SequentialSampler as TorchSequentialSampler + from torch.utils.data import BatchSampler as TorchBatchSampler __all__ = [ 'TorchSingleDriver' @@ -34,9 +35,13 @@ class TorchSingleDriver(TorchDriver): :param model: 传入给 ``Trainer`` 的 ``model`` 参数; :param device: torch.device,当前进程所使用的设备; :param fp16: 是否开启 fp16; + :param torch_kwargs: + * *set_grad_to_none* -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; + * *non_blocking* -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; + * *gradscaler_kwargs* -- 用于 fp16=True 时,提供给 ``torch.amp.cuda.GradScaler`` 的参数; """ - def __init__(self, model, device: "torch.device", fp16: bool = False, **kwargs): + def __init__(self, model, device: "torch.device", fp16: bool = False, torch_kwargs: Dict = None, **kwargs): if isinstance(model, DistributedDataParallel): raise ValueError("`DistributedDataParallel` is not supported in `TorchSingleDriver`") @@ -46,7 +51,7 @@ class TorchSingleDriver(TorchDriver): logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to" "use `cpu` instead of `gpu` device.") - super(TorchSingleDriver, self).__init__(model, fp16=fp16, **kwargs) + super(TorchSingleDriver, self).__init__(model, fp16=fp16, torch_kwargs=torch_kwargs, **kwargs) if device is None: logger.debug("device is not set, fastNLP will try to automatically get it.") @@ -123,19 +128,20 @@ class TorchSingleDriver(TorchDriver): return replace_sampler(dataloader, sampler) if reproducible: - if isinstance(args.sampler, TorchRandomSampler): - if getattr(args.sampler, '_num_samples', None) is None \ - and getattr(args.sampler, 'replacements', False) is False \ - and getattr(args.sampler, 'generator', None) is None: - # 如果本来就是随机的,并且没有定制,直接替换掉吧。 - sampler = RandomSampler(args.sampler.data_source, shuffle=True) - logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") + if type(args.batch_sampler) is TorchBatchSampler: + if type(args.sampler) is TorchRandomSampler: + if getattr(args.sampler, '_num_samples', None) is None \ + and getattr(args.sampler, 'replacements', False) is False \ + and getattr(args.sampler, 'generator', None) is None: + # 如果本来就是随机的,并且没有定制,直接替换掉吧。 + sampler = RandomSampler(args.sampler.data_source, shuffle=True) + logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") + return replace_sampler(dataloader, sampler) + elif type(args.sampler) is TorchSequentialSampler: + # 需要替换为不要 shuffle 的。 + sampler = RandomSampler(args.sampler.data_source, shuffle=False) + logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.") return replace_sampler(dataloader, sampler) - elif isinstance(args.sampler, TorchSequentialSampler): - # 需要替换为不要 shuffle 的。 - sampler = RandomSampler(args.sampler.data_source, shuffle=False) - logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.") - return replace_sampler(dataloader, sampler) batch_sampler = ReproduceBatchSampler( batch_sampler=args.batch_sampler, batch_size=args.batch_size, diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 96529073..a748aa32 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -31,6 +31,7 @@ from fastNLP.envs import rank_zero_call from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.log import logger from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler +from fastNLP.core.dataloaders import OverfitDataLoader class TorchDriver(Driver): @@ -46,12 +47,15 @@ class TorchDriver(Driver): 您可以在使用 ``TorchSingleDriver`` 和 ``TorchDDPDriver`` 时使用 ``TorchDriver`` 提供的接口; + :param model: 训练时使用的 **pytorch** 模型; + :param fp16: 是否开启混合精度训练; + :param torch_kwargs: """ - def __init__(self, model, fp16: Optional[bool] = False, **kwargs): + def __init__(self, model, fp16: Optional[bool] = False, torch_kwargs: Dict = None, **kwargs): super(TorchDriver, self).__init__(model) """ 进行 fp16 的设置 """ - self._torch_kwargs = kwargs.get("torch_kwargs", {}) + self._torch_kwargs = torch_kwargs if torch_kwargs is not None else {} # 因为 ddp 和 single_device 的混合精度训练的设置是一样的,因此可以统一抽象到这里; self.fp16 = fp16 @@ -92,7 +96,7 @@ class TorchDriver(Driver): self.grad_scaler.update() def check_dataloader_legality(self, dataloader): - if not isinstance(dataloader, DataLoader): + if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader): raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") if len(dataloader) == 0: logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it " @@ -189,7 +193,30 @@ class TorchDriver(Driver): # 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 # trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; - # 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; + # 1. sampler 的状态; + num_consumed_batches = states.pop('num_consumed_batches') + states['sampler_states'] = self.get_sampler_state(dataloader, num_consumed_batches) + + # 2. 保存模型的状态; + if should_save_model: + if not os.path.exists(folder): + os.mkdir(folder) + model_path = folder.joinpath(FASTNLP_MODEL_FILENAME) + self.save_model(model_path, only_state_dict=only_state_dict) + + # 3. 保存 optimizers 的状态; + states["optimizers_state_dict"] = self.get_optimizer_state() + logger.debug("Save optimizer state dict.") + + # 4. 保存fp16的状态 + if not isinstance(self.grad_scaler, DummyGradScaler): + grad_scaler_state_dict = self.grad_scaler.state_dict() + states['grad_scaler_state_dict'] = grad_scaler_state_dict + + torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) + + def get_sampler_state(self, dataloader, num_consumed_batches): + # 因为我们支持 resume training,即精确恢复到具体的一个 batch; # 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `set_` 中将 dataloader 的 # sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; dataloader_args = self.get_dataloader_args(dataloader) @@ -199,53 +226,58 @@ class TorchDriver(Driver): sampler = dataloader_args.sampler else: raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") - num_consumed_batches = states.pop('num_consumed_batches') + if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): sampler_states = sampler.state_dict() - # 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples - # 会造成多余实际消耗的问题。因为 - num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) - if num_consumed_samples_array is not None: - if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 - if dataloader_args.batch_size is not None: - num_consumed_batches = num_consumed_batches * dataloader_args.batch_size - else: # 有可能 batch_size 为 None,就只有损失精度了 - logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " - "it may cause missing some samples when reload.") - num_consumed_batches = sampler_states['num_consumed_samples'] - sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] - assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." + if dataloader_args.batch_size is not None: + sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ + * num_consumed_batches else: - if dataloader_args.batch_size is not None: - sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ - * num_consumed_batches - else: - logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " - "it may cause missing some samples when reload.") - - states['sampler_states'] = sampler_states + logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on sampler's " + "`num_consumed_samples`, it may cause missing some samples when reload.") else: raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training ' 'state.') - # 2. 保存模型的状态; - if should_save_model: - if not os.path.exists(folder): - os.mkdir(folder) - model_path = folder.joinpath(FASTNLP_MODEL_FILENAME) - self.save_model(model_path, only_state_dict=only_state_dict) + return sampler_states - # 3. 保存 optimizers 的状态; - optimizers_state_dict = self.get_optimizer_state() + def load_sampler_state(self, dataloader, sampler_states): + states = {} + dataloader_args = self.get_dataloader_args(dataloader) + if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): + sampler = dataloader_args.batch_sampler + elif isinstance(dataloader_args.sampler, ReproducibleSampler): + sampler = dataloader_args.sampler + elif isinstance(dataloader_args.sampler, TorchRandomSampler): + sampler = RandomSampler(dataloader_args.sampler.data_source) + logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") + elif self.is_distributed(): + raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our" + "`ReproducibleSampler`.") + else: + sampler = ReproduceBatchSampler( + batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, + batch_size=dataloader_args.batch_size, + drop_last=dataloader_args.drop_last + ) + sampler.load_state_dict(sampler_states) + states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) - # 4. 保存fp16的状态 - if not isinstance(self.grad_scaler, DummyGradScaler): - grad_scaler_state_dict = self.grad_scaler.state_dict() - states['grad_scaler_state_dict'] = grad_scaler_state_dict + # 修改 trainer_state.batch_idx_in_epoch + # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; + if not isinstance(sampler, ReproducibleBatchSampler): + if dataloader_args.drop_last: + batch_idx_in_epoch = len( + sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size + else: + batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ + (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size + # sampler 是 batch_sampler; + else: + batch_idx_in_epoch = sampler.batch_idx_in_epoch - logger.debug("Save optimizer state dict") - states["optimizers_state_dict"] = optimizers_state_dict - torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) + states["batch_idx_in_epoch"] = batch_idx_in_epoch + return states def get_optimizer_state(self): optimizers_state_dict = {} @@ -275,7 +307,7 @@ class TorchDriver(Driver): if should_load_model: self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict) - # 3. 加载fp16的状态 + # 3. 加载 fp16 的状态 if "grad_scaler_state_dict" in states: grad_scaler_state_dict = states.pop("grad_scaler_state_dict") if not isinstance(self.grad_scaler, DummyGradScaler): @@ -286,40 +318,9 @@ class TorchDriver(Driver): f"the training process may be unstable.") # 4. 恢复 sampler 的状态; - dataloader_args = self.get_dataloader_args(dataloader) - if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): - sampler = dataloader_args.batch_sampler - elif isinstance(dataloader_args.sampler, ReproducibleSampler): - sampler = dataloader_args.sampler - elif isinstance(dataloader_args.sampler, TorchRandomSampler): - sampler = RandomSampler(dataloader_args.sampler.data_source) - logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") - elif self.is_distributed(): - raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our" - "`ReproducibleSampler`.") - else: - sampler = ReproduceBatchSampler( - batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, - batch_size=dataloader_args.batch_size, - drop_last=dataloader_args.drop_last - ) - sampler.load_state_dict(states.pop('sampler_states')) - states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) - - # 4. 修改 trainer_state.batch_idx_in_epoch - # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; - if not isinstance(sampler, ReproducibleBatchSampler): - if dataloader_args.drop_last: - batch_idx_in_epoch = len( - sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size - else: - batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ - (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size - # sampler 是 batch_sampler; - else: - batch_idx_in_epoch = sampler.batch_idx_in_epoch - - states["batch_idx_in_epoch"] = batch_idx_in_epoch + sampler_states = states.pop('sampler_states') + states_ret = self.load_sampler_state(dataloader, sampler_states) + states.update(states_ret) return states diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py index 2d13a8e8..8c44ea37 100644 --- a/fastNLP/core/drivers/torch_driver/utils.py +++ b/fastNLP/core/drivers/torch_driver/utils.py @@ -1,6 +1,6 @@ import os -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from enum import IntEnum import contextlib import random @@ -14,16 +14,19 @@ from fastNLP.envs import ( FASTNLP_BACKEND_LAUNCH, FASTNLP_GLOBAL_SEED, ) -from fastNLP.core.samplers import re_instantiate_sampler -from fastNLP.core.utils import auto_param_call +from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler +from fastNLP.core.utils import auto_param_call, apply_to_collection from fastNLP.core.log import logger if _NEED_IMPORT_TORCH: import torch # import torch.nn as nn from torch.nn import Module - from torch.utils.data import DataLoader, BatchSampler - from torch.utils.data.sampler import Sampler + from torch.utils.data import DataLoader + from torch.utils.data import RandomSampler as TorchRandomSampler + from torch.utils.data import SequentialSampler as TorchSequentialSampler + from torch.utils.data import BatchSampler as TorchBatchSampler + else: from fastNLP.core.utils.dummy_class import DummyClass as Module @@ -104,6 +107,29 @@ class _DDPWrappingModel(Module): else: return fn(batch) +class _DeepSpeedWrappingModel(_DDPWrappingModel): + """ + 继承 ``_DDPWrappingModel``,区别在于进行 forward 之前先将 float 数据转换为 float16 + """ + + def __init__(self, model: Module, fp16): + super(_DeepSpeedWrappingModel, self).__init__(model) + self.fp16 = fp16 + + def forward(self, batch, **kwargs): + if self.fp16: + batch = self._move_float_tensors_to_half(batch) + + return super().forward(batch, **kwargs) + + @staticmethod + def batch_to(data): + return data.half() + + def _move_float_tensors_to_half(self, batch: Any): + batch = apply_to_collection(batch, (torch.FloatTensor, torch.cuda.FloatTensor), function=self.batch_to) + return batch + class DummyGradScaler: """ @@ -178,28 +204,33 @@ def replace_sampler(dataloader: "DataLoader", sampler): instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')} # 'multiprocessing_context' 是 user-defined function; - instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context + if getattr(dataloader, 'multiprocessing_context', None) is not None: + instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context # 拿到 dataloader '__init__' 函数的默认函数签名; init_params = dict(inspect.signature(dataloader.__init__).parameters) - # 这里为什么要单独弄的原因在于,用户在定制自己的 dataloader 的同时可能为了方便只设定一些参数,而后面直接使用 **kwargs 的方式,这时如果 - # 其在初始化自己的 dataloader 实例的时候加入了一些其它的新的参数(首先这一步是必要的,因为我们只能通过这样加 sampler;另一方面,用户 - # 可能确实通过 **kwargs 加入了一些新的参数),如果假设用户是这样使用的: "super().__init__(**kwargs)",那么我们就只能去 DataLoader - # 中寻找; + # 防止用户的 DataLoader 是继承了 pytorch 的 DataLoader,然后还是使用了 **kwargs 的方式对父类传参数 has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) - if has_variadic_kwargs: - init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) - del init_params["self"] + if has_variadic_kwargs and isinstance(dataloader, DataLoader): + # 防止用户写入了 super().__init__(**kwargs) + for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): + if key not in init_params and key != 'self': + init_params[key] = value - # 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍; + # 如果初始化dataloader所使用的参数不是默认值,那么我们需要将其记录下来用于重新初始化时设置; non_default_params = {name for name, p in init_params.items() if name in instance_attrs and p.default != instance_attrs[name]} # add `dataset` as it might have been replaced with `*args` non_default_params.add("dataset") reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} - reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler)) + if isinstance(dataloader, DataLoader): + reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None}) + + batch_sampler = getattr(dataloader, "batch_sampler") + if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler): + raise RuntimeError("It should not be running here, please report a bug to us.") required_args = { p.name @@ -209,58 +240,32 @@ def replace_sampler(dataloader: "DataLoader", sampler): and p.name not in reconstruct_args } - # 这种错误针对的是 __init__ 中的参数没有用同样名字的 self 挂上; + # 在 attribute 中没有找到这些参数,导致了没有办法重新初始化 if required_args: required_args = sorted(required_args) dataloader_self_name = dataloader.__class__.__name__ raise Exception( - f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. " - "This would fail as some of the `__init__` arguments are not available as instance attributes. " - f"The missing attributes are {required_args}. " - f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or " - "manually add the `DistributedSampler` as: " - f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`." + f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. " + f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its " + f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be " + f"`{dataloader_self_name}`'s attribute." ) # 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; if not has_variadic_kwargs: - # the dataloader signature does not allow keyword arguments that need to be passed missing_kwargs = reconstruct_args.keys() - init_params.keys() if missing_kwargs: missing_kwargs = sorted(missing_kwargs) dataloader_self_name = dataloader.__class__.__name__ raise Exception( - f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. " - "This would fail as it doesn't expose all its attributes in the `__init__` signature. " - f"The missing arguments are {missing_kwargs}. " - f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or " - "manually add the `DistributedSampler` as: " - f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`." + f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found." ) - return type(dataloader)(**reconstruct_args) - + # 如果没有kwargs,则保证一下只传入需要的参数 + if not isinstance(dataloader, DataLoader): + reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params} -def _dataloader_init_kwargs_resolve_sampler( - dataloader: "DataLoader", sampler: Optional["Sampler"] -) -> Dict[str, Any]: - r""" - 此函数用于处理与 DataLoader 关联的采样器、batch_sampler 参数重新实例化; - """ - batch_sampler = getattr(dataloader, "batch_sampler") - # checking the batch sampler type is different than PyTorch default. - if batch_sampler is not None and not isinstance(batch_sampler, BatchSampler): - batch_sampler = re_instantiate_sampler(batch_sampler) - - return { - "sampler": None, - "shuffle": False, - "batch_sampler": batch_sampler, - "batch_size": 1, - "drop_last": False, - } - - return {"sampler": sampler, "shuffle": False, "batch_sampler": None} + return type(dataloader)(**reconstruct_args) def replace_batch_sampler(dataloader, new_batch_sampler): @@ -273,6 +278,13 @@ def replace_batch_sampler(dataloader, new_batch_sampler): params_keys.remove(k) params = {k: getattr(dataloader, k) for k in params_keys} params["batch_sampler"] = new_batch_sampler + + if not isinstance(dataloader, DataLoader): + init_params = dict(inspect.signature(dataloader.__init__).parameters) + has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) + if not has_variadic_kwargs: + params = {key:value for key,value in params.items() if key in init_params} + return type(dataloader)(**params) @@ -295,5 +307,98 @@ def optimizer_state_to_device(state, device): return new_state - - +def _check_dataloader_args_for_distributed(args, controller='Trainer'): + if type(args.batch_sampler) is not TorchBatchSampler or (type(args.sampler) not in {TorchRandomSampler, + TorchSequentialSampler}): + mode = 'training' if controller == 'Trainer' else 'evaluation' + substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler' + raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause " + f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into " + f"``{substitution}``. The customized sampler should set for distributed running " + f"before initializing ``{controller}`` , and then set the " + f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.") + +def _create_default_config( + zero_optimization: bool = True, + zero_allow_untested_optimizer: bool = True, + logging_batch_size_per_gpu: Union[str, int] = "auto", + partition_activations: bool = False, + cpu_checkpointing: bool = False, + contiguous_memory_optimization: bool = False, + synchronize_checkpoint_boundary: bool = False, + offload_optimizer: bool = False, + offload_parameters: bool = False, + offload_params_device: str = "cpu", + nvme_path: str = "/local_nvme", + params_buffer_count: int = 5, + params_buffer_size: int = 100_000_000, + max_in_cpu: int = 1_000_000_000, + offload_optimizer_device: str = "cpu", + optimizer_buffer_count: int = 4, + pin_memory: bool = False, + block_size: int = 1048576, + queue_depth: int = 8, + single_submit: bool = False, + overlap_events: bool = True, + thread_count: int = 1, + stage: int = 2, + contiguous_gradients: bool = True, + overlap_comm: bool = True, + allgather_partitions: bool = True, + reduce_scatter: bool = True, + allgather_bucket_size: int = 200_000_000, + reduce_bucket_size: int = 200_000_000, + sub_group_size: int = 1_000_000_000_000, +) -> Dict: + cfg = { + "activation_checkpointing": { + "partition_activations": partition_activations, + "cpu_checkpointing": cpu_checkpointing, + "contiguous_memory_optimization": contiguous_memory_optimization, + "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary, + }, + "aio": { + "block_size": block_size, + "queue_depth": queue_depth, + "single_submit": single_submit, + "overlap_events": overlap_events, + "thread_count": thread_count, + }, + } + zero_kwargs = { + "stage": stage, + "contiguous_gradients": contiguous_gradients, + "overlap_comm": overlap_comm, + "allgather_partitions": allgather_partitions, + "reduce_scatter": reduce_scatter, + "allgather_bucket_size": allgather_bucket_size, + "reduce_bucket_size": reduce_bucket_size, + "sub_group_size": sub_group_size, + } + if zero_optimization: + zero_config = zero_kwargs + + if offload_optimizer: + zero_config["offload_optimizer"] = { + "device": offload_optimizer_device, + "nvme_path": nvme_path, + "buffer_count": optimizer_buffer_count, + "pin_memory": pin_memory, + } + if offload_parameters: + zero_config["offload_param"] = { + "device": offload_params_device, + "nvme_path": nvme_path, + "buffer_count": params_buffer_count, + "buffer_size": params_buffer_size, + "max_in_cpu": max_in_cpu, + "pin_memory": pin_memory, + } + cfg = { + "zero_allow_untested_optimizer": zero_allow_untested_optimizer, + "zero_optimization": zero_config, + **cfg, + } + if logging_batch_size_per_gpu != "auto": + cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg} + return cfg \ No newline at end of file diff --git a/fastNLP/core/metrics/accuracy.py b/fastNLP/core/metrics/accuracy.py index 47d5e114..fbd826bd 100644 --- a/fastNLP/core/metrics/accuracy.py +++ b/fastNLP/core/metrics/accuracy.py @@ -69,7 +69,7 @@ class Accuracy(Metric): elif pred.ndim == target.ndim + 1: pred = pred.argmax(axis=-1) if seq_len is None and target.ndim > 1: - logger.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") + logger.warning("You are not passing `seq_len` to exclude pad when calculate accuracy.") else: raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or " diff --git a/fastNLP/core/metrics/backend/auto_backend.py b/fastNLP/core/metrics/backend/auto_backend.py index e2515313..f671ad2e 100644 --- a/fastNLP/core/metrics/backend/auto_backend.py +++ b/fastNLP/core/metrics/backend/auto_backend.py @@ -8,6 +8,7 @@ from .backend import Backend from .torch_backend.backend import TorchBackend from .paddle_backend.backend import PaddleBackend from .jittor_backend.backend import JittorBackend +from .oneflow_backend.backend import OneflowBackend class AutoBackend(Backend): @@ -52,6 +53,8 @@ class AutoBackend(Backend): self.__class__ = PaddleBackend elif backend == 'jittor': self.__class__ = JittorBackend + elif backend == 'oneflow': + self.__class__ = OneflowBackend elif backend is None: # 不用做任何事情就可以初始化了 pass diff --git a/fastNLP/core/metrics/backend/oneflow_backend/__init__.py b/fastNLP/core/metrics/backend/oneflow_backend/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastNLP/core/metrics/backend/oneflow_backend/backend.py b/fastNLP/core/metrics/backend/oneflow_backend/backend.py new file mode 100644 index 00000000..6392b09d --- /dev/null +++ b/fastNLP/core/metrics/backend/oneflow_backend/backend.py @@ -0,0 +1,130 @@ +from typing import List + +import numpy as np + +from fastNLP.core.metrics.backend import Backend +from fastNLP.core.metrics.utils import AggregateMethodError +from fastNLP.core.utils import is_in_oneflow_dist +from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW +from fastNLP.core.drivers.oneflow_driver.dist_utils import fastnlp_oneflow_all_gather + + +if _NEED_IMPORT_ONEFLOW: + import oneflow + import oneflow.comm as comm + +__all__ = [] + +class OneflowBackend(Backend): + def __init__(self): + super().__init__() + self._specified = True + + def aggregate(self, tensor, method: str): + """ + 聚集结果,并根据 method 计算后,返回结果 + + :param tensor: 需要聚合的张量 + :param method: 聚合的方法, 目前支持 ``['sum', 'mean', 'max', 'mix']``: + + * method 为 ``'sum'`` 时, 会将多张卡上聚合结果在维度为 `0` 上 累加起来。 + * method 为 ``'mean'`` 时,会将多张卡上聚合结果在维度为 `0` 上取平均值。 + * method 为 ``'max'`` 时,会将多张卡上聚合结果在维度为 `0` 上取最大值。 + * method 为 ``'mix'`` 时,会将多张卡上聚合结果在维度为 `0` 上取最小值。 + + """ + if isinstance(tensor, oneflow.Tensor): + # TODO 暂时没有找到 oneflow 中检测是否初始化了分布式环境的方法 + if is_in_oneflow_dist(): + if method is None: + raise AggregateMethodError(should_have_aggregate_method=True) + tensor = self.all_gather_object(tensor) + if isinstance(tensor[0], oneflow.Tensor): + tensor = oneflow.stack(tensor) + # 第一步, aggregate结果 + if method == 'sum': + tensor = oneflow.sum(tensor, dim=0) + elif method == 'mean': + tensor = oneflow.mean(tensor, dim=0) + elif method == 'max': + tensor, _ = oneflow.max(tensor, dim=0) + elif method == 'min': + tensor, _ = oneflow.min(tensor, dim=0) + else: + raise AggregateMethodError(should_have_aggregate_method=False) + + return tensor + + def create_tensor(self, value: float): + """ + 创建 tensor,并且填入 value 作为值 + + :param value: 创建张量的初始值 + """ + tensor = oneflow.ones(1).fill_(value) + return tensor + + def fill_value(self, tensor, value: float): + """ + 将 tensor 的值设置为 value + + :param tensor: 传入的张量 + :param value: 需要 fill 的值。 + """ + tensor.fill_(value) + return tensor + + def get_scalar(self, tensor) -> float: + """ + 获取 tensor 的 scalar 值 + + :param tensor: 传入的张量 + """ + return tensor.item() + + def tensor2numpy(self, tensor) -> np.array: + """ + 将 tensor 转为 numpy 值, 主要是在 metric 计算中使用 + + :param tensor: 传入的张量 + """ + + if isinstance(tensor, oneflow.Tensor): + return tensor.cpu().detach().numpy() + elif isinstance(tensor, np.ndarray): + return tensor + elif isinstance(tensor, (float, int)): + return tensor + else: + raise ValueError(f"tensor: {tensor} can not convert to ndarray!") + + @staticmethod + def is_distributed() -> bool: + """ + 判断是否为 ddp 状态 + + :return: + """ + return is_in_oneflow_dist() + + def move_tensor_to_device(self, tensor, device): + """ + 将张量移到设备上 + + :param tensor: 需要移动的张量 + :param device: 设备名, 一般为 "cpu", "cuda:0"等字符串 + """ + return tensor.to(device) + + def all_gather_object(self, obj, group=None) -> List: + """ + 给定 obj 将各个 rank 上的 obj 汇总到每个 obj 上。返回一个 list 对象,里面依次为各个 rank 对应的 obj 。 + + :param obj: + :param group: + """ + if self.is_distributed(): + obj_list = fastnlp_oneflow_all_gather(obj) + return obj_list + return [obj] + diff --git a/fastNLP/core/metrics/classify_f1_pre_rec_metric.py b/fastNLP/core/metrics/classify_f1_pre_rec_metric.py index daf325c0..39565f40 100644 --- a/fastNLP/core/metrics/classify_f1_pre_rec_metric.py +++ b/fastNLP/core/metrics/classify_f1_pre_rec_metric.py @@ -156,7 +156,7 @@ class ClassifyFPreRecMetric(Metric): elif pred.ndim == target.ndim + 1: pred = pred.argmax(axis=-1) if seq_len is None and target.ndim > 1: - logger.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") + logger.warning("You are not passing `seq_len` to exclude pad when calculate accuracy.") else: raise RuntimeError(f"when pred have " f"size:{pred.shape}, target should have size: {pred.shape} or " diff --git a/fastNLP/core/metrics/metric.py b/fastNLP/core/metrics/metric.py index 1a69e80c..0f904041 100644 --- a/fastNLP/core/metrics/metric.py +++ b/fastNLP/core/metrics/metric.py @@ -20,7 +20,7 @@ class Metric: :param backend: 目前支持四种类型的 backend, ``[torch, paddle, jittor, auto]``。其中 ``auto`` 表示根据实际调用 Metric.update() 函数时传入的参数决定具体的 ``backend`` ,大部分情况下直接使用 ``auto`` 即可。 :param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, - 当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 :class:`~fastNLP.Evaluator` 中根据 sampler 是否使用分布式 + 当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 :class:`~fastNLP.core.controllers.Evaluator` 中根据 sampler 是否使用分布式 进行自动设置。 """ def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None): @@ -98,7 +98,7 @@ class Metric: return _wrap_get_metric def __setattr__(self, key, value): - if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True: + if getattr(self, '_cannot_change_element', False): if key in self.elements and isinstance(value, (float, int, bool)): self.elements[key].fill_value(value) return @@ -109,6 +109,14 @@ class Metric: raise RuntimeError("Please use register_element() function to add Element.") object.__setattr__(self, key, value) + # 当调用 __getattribute__ 没有找到时才会触发这个, 保留这个的目的只是为了防止 ide 的 warning + def __getattr__(self, name: str) -> Element: + if 'elements' in self.__dict__: + elements = self.__dict__['elements'] + if name in elements: + return elements[name] + raise AttributeError("`{}` object has no attribute `{}`.".format(type(self).__name__, name)) + def _wrap_update(self, update): @functools.wraps(update) def _wrap_update(*args, **kwargs): diff --git a/fastNLP/core/metrics/span_f1_pre_rec_metric.py b/fastNLP/core/metrics/span_f1_pre_rec_metric.py index 9a0b1d9d..b264f93d 100644 --- a/fastNLP/core/metrics/span_f1_pre_rec_metric.py +++ b/fastNLP/core/metrics/span_f1_pre_rec_metric.py @@ -39,7 +39,7 @@ def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encod f"encoding_type." tags = tags.replace(tag, '') # 删除该值 if tags: # 如果不为空,说明出现了未使用的tag - logger.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your " + logger.warning(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your " "encoding_type.") @@ -212,7 +212,7 @@ class SpanFPreRecMetric(Metric): :param backend: 目前支持四种类型的 backend, ``[torch, paddle, jittor, auto]``。其中 ``auto`` 表示根据实际调用 Metric.update() 函数时传入的参数决定具体的 ``backend`` ,大部分情况下直接使用 ``auto`` 即可。 :param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, - 当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 :class:`~fastNLP.Evaluator` 中根据 sampler 是否使用分布式 + 当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 :class:`~fastNLP.core.controllers.Evaluator` 中根据 sampler 是否使用分布式 进行自动设置。 """ def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index f522f997..50276ba1 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -13,7 +13,6 @@ from itertools import chain import numpy as np from fastNLP.core.dataset import DataSet -from fastNLP.envs.utils import get_global_seed from fastNLP.core.log import logger from .utils import create_array from abc import abstractmethod @@ -171,7 +170,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): :param kwargs: fastNLP 保留使用 """ def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True, - drop_last: bool = False, seed: int = None, **kwargs): + drop_last: bool = False, seed: int = 0, **kwargs): super().__init__() self.dataset = dataset @@ -179,7 +178,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): self.batch_size = batch_size self.shuffle = shuffle self.drop_last = drop_last - self.seed = get_global_seed() if seed is None else seed + self.seed = int(seed) self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 @@ -398,7 +397,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): :param kwargs: fastNLP 保留使用 """ def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, - shuffle: bool = True, drop_last: bool = False, seed: int = None, **kwargs): + shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): super().__init__() if isinstance(dataset, DataSet) and isinstance(length, str): length = dataset.get_field(length).content @@ -423,7 +422,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): self.num_batch_per_bucket = num_batch_per_bucket self.shuffle = shuffle self.drop_last = drop_last - self.seed = get_global_seed() if seed is None else seed + self.seed = int(seed) self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index dc396851..e1a06fa1 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -12,7 +12,6 @@ import numpy as np from fastNLP.core.log import logger from fastNLP.core.dataset import DataSet -from fastNLP.envs.utils import get_global_seed class ReproducibleSampler: @@ -66,11 +65,11 @@ class RandomSampler(ReproducibleSampler): :param seed: 随机数种子。 :param kwargs: 用户不需要使用,fastNLP 内部使用 """ - def __init__(self, dataset, shuffle: bool = True, seed: int = None, **kwargs): + def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): super(RandomSampler, self).__init__() self.dataset = dataset self.shuffle = shuffle - self.seed = get_global_seed() if seed is None else seed + self.seed = int(seed) self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 diff --git a/fastNLP/core/samplers/unrepeated_sampler.py b/fastNLP/core/samplers/unrepeated_sampler.py index 22207274..e959a4d0 100644 --- a/fastNLP/core/samplers/unrepeated_sampler.py +++ b/fastNLP/core/samplers/unrepeated_sampler.py @@ -7,7 +7,6 @@ __all__ = [ from typing import List, Union from fastNLP.core.dataset import DataSet -from fastNLP.envs.utils import get_global_seed import numpy as np @@ -28,10 +27,10 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): :param seed: 设置的随机数种子 :param kwargs: fastNLP 保留使用 """ - def __init__(self, dataset, shuffle: bool = False, seed: int = None, **kwargs): + def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): self.dataset = dataset self.shuffle = shuffle - self.seed = get_global_seed() if seed is None else seed + self.seed = int(seed) # 多卡的相关的参数 self.num_replicas = kwargs.get('num_replicas', 1) diff --git a/fastNLP/core/utils/__init__.py b/fastNLP/core/utils/__init__.py index 0857f450..d188bc37 100644 --- a/fastNLP/core/utils/__init__.py +++ b/fastNLP/core/utils/__init__.py @@ -1,5 +1,6 @@ __all__ = [ 'cache_results', + 'is_jittor_module', 'is_jittor_dataset', 'jittor_collate_wraps', 'paddle_to', @@ -9,8 +10,14 @@ __all__ = [ 'is_in_paddle_dist', 'is_in_fnlp_paddle_dist', 'is_in_paddle_launch_dist', + 'is_paddle_module', 'f_rich_progress', 'torch_move_data_to_device', + 'is_torch_module', + 'get_oneflow_device', + 'oneflow_move_data_to_device', + 'is_oneflow_module', + 'is_in_oneflow_dist', 'get_fn_arg_names', 'auto_param_call', 'check_user_specific_params', @@ -28,11 +35,12 @@ __all__ = [ ] from .cache_results import cache_results -from .jittor_utils import is_jittor_dataset, jittor_collate_wraps +from .jittor_utils import is_jittor_dataset, jittor_collate_wraps, is_jittor_module from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \ - is_in_fnlp_paddle_dist, is_in_paddle_launch_dist + is_in_fnlp_paddle_dist, is_in_paddle_launch_dist, is_paddle_module from .rich_progress import f_rich_progress -from .torch_utils import torch_move_data_to_device +from .torch_utils import torch_move_data_to_device, is_torch_module +from .oneflow_utils import oneflow_move_data_to_device, is_oneflow_module, is_in_oneflow_dist, get_oneflow_device from .utils import * from .tqdm_progress import f_tqdm_progress from .seq_len_to_mask import seq_len_to_mask diff --git a/fastNLP/core/utils/jittor_utils.py b/fastNLP/core/utils/jittor_utils.py index f29b1f46..ac00cd22 100644 --- a/fastNLP/core/utils/jittor_utils.py +++ b/fastNLP/core/utils/jittor_utils.py @@ -1,6 +1,7 @@ __all__ = [ + 'is_jittor_module', 'is_jittor_dataset', - 'jittor_collate_wraps' + 'jittor_collate_wraps', ] from collections.abc import Mapping, Callable @@ -13,6 +14,17 @@ if _NEED_IMPORT_JITTOR: from fastNLP.core.dataset import Instance +def is_jittor_module(model) -> bool: + """ + 判断传入的 ``model`` 是否是 :class:`jittor.Module` 类型 + + :param model: 模型; + :return: 当前模型是否为 ``jittor`` 的模型; + """ + try: + return isinstance(model, jt.Module) + except BaseException: + return False def is_jittor_dataset(dataset) -> bool: """ diff --git a/fastNLP/core/utils/oneflow_utils.py b/fastNLP/core/utils/oneflow_utils.py new file mode 100644 index 00000000..f9225466 --- /dev/null +++ b/fastNLP/core/utils/oneflow_utils.py @@ -0,0 +1,69 @@ +import os +from typing import Any, Union, Optional +from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK +from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW + +if _NEED_IMPORT_ONEFLOW: + import oneflow + +__all__ = [ + 'get_oneflow_device' + 'oneflow_move_data_to_device', + 'is_oneflow_module', + 'is_in_oneflow_dist', +] + +from .utils import apply_to_collection + +def get_oneflow_device(device): + """ + 构造一个 :class:`oneflow.device` 实例并返回。 + + :param device: 字符串或 gpu 编号 + :return: :class:`oneflow.device` + """ + if isinstance(device, oneflow.device): + return device + if isinstance(device, int): + return oneflow.device("cuda", device) + if isinstance(device, str): + return oneflow.device(device) + raise RuntimeError(f"Cannot get `oneflow.device` from {device}.") + +def oneflow_move_data_to_device(batch: Any, device: Optional[Union[str, "oneflow.device"]] = None) -> Any: + r""" + 在 **oneflow** 中将数据集合 ``batch`` 传输到给定设备。任何定义方法 ``to(device)`` 的对象都将被移动并且集合中的所有其他对象将保持不变; + + :param batch: 需要迁移的数据; + :param device: 数据应当迁移到的设备;当该参数的值为 ``None`` 时则不执行任何操作; + :return: 迁移到新设备上的数据集合; + """ + if device is None: + return batch + + def batch_to(data: Any) -> Any: + data_output = data.to(device) + if data_output is not None: + return data_output + # user wrongly implemented the `TransferableDataType` and forgot to return `self`. + return data + + return apply_to_collection(batch, dtype=oneflow.Tensor, function=batch_to) + +def is_oneflow_module(model) -> bool: + """ + 判断传入的 ``model`` 是否是 :class:`oneflow.nn.Module` 类型 + + :param model: 模型; + :return: 当前模型是否为 ``oneflow`` 的模型; + """ + try: + return isinstance(model, oneflow.nn.Module) + except BaseException: + return False + +def is_in_oneflow_dist() -> bool: + """ + 判断是否处于 **oneflow** 分布式的进程下。 + """ + return "GLOG_log_dir" in os.environ \ No newline at end of file diff --git a/fastNLP/core/utils/paddle_utils.py b/fastNLP/core/utils/paddle_utils.py index 9e7e73a4..adcbcabd 100644 --- a/fastNLP/core/utils/paddle_utils.py +++ b/fastNLP/core/utils/paddle_utils.py @@ -6,6 +6,7 @@ __all__ = [ "is_in_paddle_dist", "is_in_fnlp_paddle_dist", "is_in_paddle_launch_dist", + "is_paddle_module", ] import os @@ -174,4 +175,16 @@ def is_in_paddle_launch_dist() -> bool: """ 判断是否处于 ``python -m paddle.distributed.launch`` 方法启动的 **paddle** 分布式进程中 """ - return FASTNLP_BACKEND_LAUNCH in os.environ \ No newline at end of file + return FASTNLP_BACKEND_LAUNCH in os.environ + +def is_paddle_module(model) -> bool: + """ + 判断传入的 ``model`` 是否是 :class:`paddle.nn.Layer` 类型 + + :param model: 模型; + :return: 当前模型是否为 ``paddle`` 的模型; + """ + try: + return isinstance(model, paddle.nn.Layer) + except BaseException: + return False \ No newline at end of file diff --git a/fastNLP/core/utils/torch_utils.py b/fastNLP/core/utils/torch_utils.py index 0cef2205..c58715b8 100644 --- a/fastNLP/core/utils/torch_utils.py +++ b/fastNLP/core/utils/torch_utils.py @@ -8,7 +8,8 @@ if _NEED_IMPORT_TORCH: DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD __all__ = [ - 'torch_move_data_to_device' + 'torch_move_data_to_device', + 'is_torch_module', ] from .utils import apply_to_collection @@ -64,3 +65,15 @@ def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.dev dtype = TorchTransferableDataType return apply_to_collection(batch, dtype=dtype, function=batch_to) + +def is_torch_module(model) -> bool: + """ + 判断传入的 ``model`` 是否是 :class:`torch.nn.Module` 类型 + + :param model: 模型; + :return: 当前模型是否为 ``torch`` 的模型; + """ + try: + return isinstance(model, torch.nn.Module) + except BaseException: + return False \ No newline at end of file diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index 11256d45..ec0c87b0 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -554,7 +554,7 @@ def deprecated(help_message: Optional[str] = None): def wrapper(*args, **kwargs): func_hash = hash(deprecated_function) if func_hash not in _emitted_deprecation_warnings: - logger.warn(warning_msg, category=FutureWarning, stacklevel=2) + logger.warning(warning_msg, category=FutureWarning, stacklevel=2) _emitted_deprecation_warnings.add(func_hash) return deprecated_function(*args, **kwargs) @@ -630,7 +630,7 @@ def is_notebook(): def flat_nest_dict(d:Dict, separator:str='#', compress_none_key:bool=True, top_down:bool=False) -> Dict: """ - 讲一个 nested 的 dict 转成 flat 的 dict,例如 + 将一个 nested 的 dict 转成 flat 的 dict,例如 ex:: d = {'test': {'f1': {'f': 0.2, 'rec': 0.1}}} -> {'f#f1#test':0.2, 'rec#f1#test':0.1} diff --git a/fastNLP/embeddings/torch/static_embedding.py b/fastNLP/embeddings/torch/static_embedding.py index de2b231a..cc15c214 100644 --- a/fastNLP/embeddings/torch/static_embedding.py +++ b/fastNLP/embeddings/torch/static_embedding.py @@ -286,7 +286,7 @@ class StaticEmbedding(TokenEmbedding): if word in vocab: index = vocab.to_index(word) if index in matrix: - logger.warn(f"Word has more than one vector in embedding file. Set logger level to " + logger.warning(f"Word has more than one vector in embedding file. Set logger level to " f"DEBUG for detail.") logger.debug(f"Word:{word} occurs again in line:{idx}(starts from 0)") matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim)) @@ -295,7 +295,7 @@ class StaticEmbedding(TokenEmbedding): found_count += 1 except Exception as e: if error == 'ignore': - logger.warn("Error occurred at the {} line.".format(idx)) + logger.warning("Error occurred at the {} line.".format(idx)) else: logger.error("Error occurred at the {} line.".format(idx)) raise e diff --git a/fastNLP/envs/imports.py b/fastNLP/envs/imports.py index 77b642c3..08afc6a5 100644 --- a/fastNLP/envs/imports.py +++ b/fastNLP/envs/imports.py @@ -22,5 +22,7 @@ _NEED_IMPORT_FAIRSCALE = not _IS_WINDOWS and _module_available("fairscale") and _NEED_IMPORT_TORCH = _module_available("torch") and 'torch' in need_import _NEED_IMPORT_JITTOR = _module_available("jittor") and 'jittor' in need_import _NEED_IMPORT_PADDLE = _module_available("paddle") and 'paddle' in need_import +_NEED_IMPORT_DEEPSPEED = _module_available("deepspeed") and 'torch' in need_import +_NEED_IMPORT_ONEFLOW = _module_available("oneflow") and 'oneflow' in need_import _TORCH_GREATER_EQUAL_1_8 = _NEED_IMPORT_TORCH and _compare_version("torch", operator.ge, "1.8.0") diff --git a/fastNLP/envs/set_backend.py b/fastNLP/envs/set_backend.py index 1ef27ff6..45674794 100644 --- a/fastNLP/envs/set_backend.py +++ b/fastNLP/envs/set_backend.py @@ -8,7 +8,7 @@ from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VIS from fastNLP.envs.utils import _module_available, get_gpu_count -SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] +SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor', 'oneflow'] def _set_backend(): @@ -145,6 +145,9 @@ def set_env(global_seed=None): if backend == 'torch': assert _module_available(backend), f"You must have {backend} available to use {backend} backend." + if backend == 'oneflow': + assert _module_available(backend), f"You must have {backend} available to use {backend} backend." + def dump_fastnlp_backend(default:bool = False, backend=None): """ diff --git a/fastNLP/envs/set_env_on_import.py b/fastNLP/envs/set_env_on_import.py index f35f8e54..27686ae3 100644 --- a/fastNLP/envs/set_env_on_import.py +++ b/fastNLP/envs/set_env_on_import.py @@ -50,6 +50,15 @@ def set_env_on_import_jittor(): if 'log_silent' not in os.environ: os.environ['log_silent'] = '1' +def set_env_on_import_oneflow(): + if 'GLOG_log_dir' in os.environ: + os.environ[FASTNLP_GLOBAL_RANK] = os.environ['RANK'] + if int(os.environ.get(FASTNLP_REMOVE_LOCAL_RANK, 1)): + remove_local_rank_in_argv() + + if 'GLOG_log_dir' in os.environ and FASTNLP_DISTRIBUTED_CHECK not in os.environ: + os.environ[FASTNLP_BACKEND_LAUNCH] = '1' + def set_env_on_import(): """ @@ -61,6 +70,7 @@ def set_env_on_import(): set_env_on_import_torch() set_env_on_import_paddle() set_env_on_import_jittor() + set_env_on_import_oneflow() # fastNLP 内部使用的一些变量 if FASTNLP_LAUNCH_TIME not in os.environ: diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index 58538d61..4029e092 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -245,8 +245,9 @@ class DataBundle: """ _progress_desc = progress_desc for name, dataset in self.datasets.items(): - if _progress_desc: - progress_desc = _progress_desc + f' for `{name}`' + if len(_progress_desc) == 0: + _progress_desc = 'Processing' + progress_desc = _progress_desc + f' for `{name}`' if dataset.has_field(field_name=field_name): dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc, progress_desc=progress_desc, progress_bar=progress_bar) @@ -284,8 +285,9 @@ class DataBundle: res = {} _progress_desc = progress_desc for name, dataset in self.datasets.items(): - if _progress_desc: - progress_desc = _progress_desc + f' for `{name}`' + if len(_progress_desc) == 0: + _progress_desc = 'Processing' + progress_desc = _progress_desc + f' for `{name}`' if dataset.has_field(field_name=field_name): res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc, modify_fields=modify_fields, @@ -317,8 +319,9 @@ class DataBundle: """ _progress_desc = progress_desc for name, dataset in self.datasets.items(): - if _progress_desc: - progress_desc = _progress_desc + f' for `{name}`' + if len(_progress_desc) == 0: + _progress_desc = 'Processing' + progress_desc = _progress_desc + f' for `{name}`' dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, progress_bar=progress_bar, progress_desc=progress_desc) return self @@ -349,8 +352,9 @@ class DataBundle: res = {} _progress_desc = progress_desc for name, dataset in self.datasets.items(): - if _progress_desc: - progress_desc = _progress_desc + f' for `{name}`' + if len(_progress_desc) == 0: + _progress_desc = 'Processing' + progress_desc = _progress_desc + f' for `{name}`' res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc, progress_bar=progress_bar, progress_desc=progress_desc) return res diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py index 9080ff28..df82643b 100644 --- a/fastNLP/io/embed_loader.py +++ b/fastNLP/io/embed_loader.py @@ -91,7 +91,7 @@ class EmbedLoader: hit_flags[index] = True except Exception as e: if error == 'ignore': - logger.warn("Error occurred at the {} line.".format(idx)) + logger.warning("Error occurred at the {} line.".format(idx)) else: logging.error("Error occurred at the {} line.".format(idx)) raise e @@ -156,7 +156,7 @@ class EmbedLoader: found_pad = True except Exception as e: if error == 'ignore': - logger.warn("Error occurred at the {} line.".format(idx)) + logger.warning("Error occurred at the {} line.".format(idx)) pass else: logging.error("Error occurred at the {} line.".format(idx)) diff --git a/fastNLP/io/loader/classification.py b/fastNLP/io/loader/classification.py index 4416376f..2ae0b163 100644 --- a/fastNLP/io/loader/classification.py +++ b/fastNLP/io/loader/classification.py @@ -345,7 +345,7 @@ class SST2Loader(Loader): with open(path, 'r', encoding='utf-8') as f: f.readline() # 跳过header if 'test' in os.path.split(path)[1]: - logger.warn("SST2's test file has no target.") + logger.warning("SST2's test file has no target.") for line in f: line = line.strip() if line: diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py index 5595b798..08387df9 100644 --- a/fastNLP/io/loader/matching.py +++ b/fastNLP/io/loader/matching.py @@ -55,7 +55,7 @@ class MNLILoader(Loader): with open(path, 'r', encoding='utf-8') as f: f.readline() # 跳过header if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'): - logger.warn("MNLI's test file has no target.") + logger.warning("MNLI's test file has no target.") for line in f: line = line.strip() if line: @@ -227,7 +227,7 @@ class QNLILoader(JsonLoader): with open(path, 'r', encoding='utf-8') as f: f.readline() # 跳过header if path.endswith("test.tsv"): - logger.warn("QNLI's test file has no target.") + logger.warning("QNLI's test file has no target.") for line in f: line = line.strip() if line: @@ -289,7 +289,7 @@ class RTELoader(Loader): with open(path, 'r', encoding='utf-8') as f: f.readline() # 跳过header if path.endswith("test.tsv"): - logger.warn("RTE's test file has no target.") + logger.warning("RTE's test file has no target.") for line in f: line = line.strip() if line: diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py index a89f2f2b..baebdbaa 100644 --- a/fastNLP/io/pipe/matching.py +++ b/fastNLP/io/pipe/matching.py @@ -146,7 +146,7 @@ class MatchingBertPipe(Pipe): warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \ f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ f"data set but not in train data set!." - logger.warn(warn_msg) + logger.warning(warn_msg) print(warn_msg) has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if @@ -291,7 +291,7 @@ class MatchingPipe(Pipe): warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \ f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ f"data set but not in train data set!." - logger.warn(warn_msg) + logger.warning(warn_msg) print(warn_msg) has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if diff --git a/fastNLP/io/pipe/utils.py b/fastNLP/io/pipe/utils.py index aa28af08..05dd3cf4 100644 --- a/fastNLP/io/pipe/utils.py +++ b/fastNLP/io/pipe/utils.py @@ -138,7 +138,7 @@ def _indexize(data_bundle, input_field_names='words', target_field_names='target f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ f"data set but not in train data set!.\n" \ f"These label(s) are {tgt_vocab._no_create_word}" - logger.warn(warn_msg) + logger.warning(warn_msg) # log.warning(warn_msg) tgt_vocab.index_dataset(*[ds for ds in data_bundle.datasets.values() if ds.has_field(target_field_name)], field_name=target_field_name) data_bundle.set_vocab(tgt_vocab, target_field_name) diff --git a/fastNLP/modules/mix_modules/utils.py b/fastNLP/modules/mix_modules/utils.py index 21d0f05c..04dab056 100644 --- a/fastNLP/modules/mix_modules/utils.py +++ b/fastNLP/modules/mix_modules/utils.py @@ -112,7 +112,7 @@ def _jittor2torch(jittor_var: 'jittor.Var', device: Optional[Union[str, int]] = # 如果outputs有_grad键,可以实现求导 no_gradient = not jittor_var.requires_grad if no_gradient is None else no_gradient if no_gradient == False: - logger.warn("The result tensor will not keep gradients due to differences between jittor and pytorch.") + logger.warning("The result tensor will not keep gradients due to differences between jittor and pytorch.") jittor_numpy = jittor_var.numpy() if not np.issubdtype(jittor_numpy.dtype, np.inexact): no_gradient = True diff --git a/fastNLP/transformers/torch/configuration_utils.py b/fastNLP/transformers/torch/configuration_utils.py index 948d9873..26a80377 100644 --- a/fastNLP/transformers/torch/configuration_utils.py +++ b/fastNLP/transformers/torch/configuration_utils.py @@ -327,7 +327,7 @@ class PretrainedConfig: # Deal with gradient checkpointing if kwargs.get("gradient_checkpointing", False): - logger.warn( + logger.warning( "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 " "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the " "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`." diff --git a/fastNLP/transformers/torch/generation_beam_search.py b/fastNLP/transformers/torch/generation_beam_search.py index 117d9a38..1c3fc592 100644 --- a/fastNLP/transformers/torch/generation_beam_search.py +++ b/fastNLP/transformers/torch/generation_beam_search.py @@ -195,7 +195,7 @@ class BeamSearchScorer(BeamScorer): ) if "max_length" in kwargs: - logger.warn( + logger.warning( "Passing `max_length` to BeamSearchScorer is deprecated and has no effect." "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`" ",or `group_beam_search(...)`." diff --git a/fastNLP/transformers/torch/generation_utils.py b/fastNLP/transformers/torch/generation_utils.py index 0e6fe5c7..29828c15 100644 --- a/fastNLP/transformers/torch/generation_utils.py +++ b/fastNLP/transformers/torch/generation_utils.py @@ -872,7 +872,7 @@ class GenerationMixin: max_length = self.config.max_length elif max_length is not None and max_new_tokens is not None: # Both are set, this is odd, raise a warning - logger.warn( + logger.warning( "Both `max_length` and `max_new_tokens` have been set but they serve the same purpose.", UserWarning ) @@ -1239,7 +1239,7 @@ class GenerationMixin: logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: - logger.warn( + logger.warning( "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) @@ -1475,7 +1475,7 @@ class GenerationMixin: logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: - logger.warn( + logger.warning( "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) @@ -1726,13 +1726,13 @@ class GenerationMixin: logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: - logger.warn( + logger.warning( "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: - logger.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) + logger.warning("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores @@ -2030,7 +2030,7 @@ class GenerationMixin: logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: - logger.warn( + logger.warning( "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) @@ -2325,7 +2325,7 @@ class GenerationMixin: logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: - logger.warn( + logger.warning( "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) diff --git a/fastNLP/transformers/torch/models/auto/auto_factory.py b/fastNLP/transformers/torch/models/auto/auto_factory.py index 9eb8ec69..d0969a5b 100644 --- a/fastNLP/transformers/torch/models/auto/auto_factory.py +++ b/fastNLP/transformers/torch/models/auto/auto_factory.py @@ -401,7 +401,7 @@ class _BaseAutoModelClass: "the option `trust_remote_code=True` to remove this error." ) if kwargs.get("revision", None) is None: - logger.warn( + logger.warning( "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure " "no malicious code has been contributed in a newer revision." ) diff --git a/fastNLP/transformers/torch/models/auto/configuration_auto.py b/fastNLP/transformers/torch/models/auto/configuration_auto.py index 45d3c071..1289071d 100644 --- a/fastNLP/transformers/torch/models/auto/configuration_auto.py +++ b/fastNLP/transformers/torch/models/auto/configuration_auto.py @@ -130,7 +130,7 @@ class _LazyLoadAllMappings(OrderedDict): def _initialize(self): if self._initialized: return - # logger.warn( + # logger.warning( # "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. " # "It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.", # FutureWarning, diff --git a/fastNLP/transformers/torch/models/auto/modeling_auto.py b/fastNLP/transformers/torch/models/auto/modeling_auto.py index aace27a2..dbf4b610 100644 --- a/fastNLP/transformers/torch/models/auto/modeling_auto.py +++ b/fastNLP/transformers/torch/models/auto/modeling_auto.py @@ -306,7 +306,7 @@ AutoModelForSpeechSeq2Seq = auto_class_update( class AutoModelWithLMHead(_AutoModelWithLMHead): @classmethod def from_config(cls, config): - logger.warn( + logger.warning( "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " "`AutoModelForSeq2SeqLM` for encoder-decoder models.", @@ -316,7 +316,7 @@ class AutoModelWithLMHead(_AutoModelWithLMHead): @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - logger.warn( + logger.warning( "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " "`AutoModelForSeq2SeqLM` for encoder-decoder models.", diff --git a/fastNLP/transformers/torch/models/bart/modeling_bart.py b/fastNLP/transformers/torch/models/bart/modeling_bart.py index 7219f49a..377afa41 100644 --- a/fastNLP/transformers/torch/models/bart/modeling_bart.py +++ b/fastNLP/transformers/torch/models/bart/modeling_bart.py @@ -513,7 +513,7 @@ class BartPretrainedModel(PreTrainedModel): class PretrainedBartModel(BartPretrainedModel): def __init_subclass__(self): - logger.warn( + logger.warning( "The class `PretrainedBartModel` has been depreciated, please use `BartPretrainedModel` instead.", FutureWarning, ) diff --git a/fastNLP/transformers/torch/models/bert/modeling_bert.py b/fastNLP/transformers/torch/models/bert/modeling_bert.py index b95da0df..79f1c459 100644 --- a/fastNLP/transformers/torch/models/bert/modeling_bert.py +++ b/fastNLP/transformers/torch/models/bert/modeling_bert.py @@ -1374,7 +1374,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel): """ if "next_sentence_label" in kwargs: - logger.warn( + logger.warning( "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.", FutureWarning, ) diff --git a/fastNLP/transformers/torch/models/cpt/modeling_cpt.py b/fastNLP/transformers/torch/models/cpt/modeling_cpt.py index 2910cc26..df7d477b 100644 --- a/fastNLP/transformers/torch/models/cpt/modeling_cpt.py +++ b/fastNLP/transformers/torch/models/cpt/modeling_cpt.py @@ -724,7 +724,7 @@ class CPTDecoder(CPTPretrainedModel): if getattr(self.config, "gradient_checkpointing", False) and self.training: if use_cache: - logger.warn( + logger.warning( "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " "`use_cache=False`..." ) diff --git a/fastNLP/transformers/torch/tokenization_utils_base.py b/fastNLP/transformers/torch/tokenization_utils_base.py index 3a033c96..a04dbaf1 100644 --- a/fastNLP/transformers/torch/tokenization_utils_base.py +++ b/fastNLP/transformers/torch/tokenization_utils_base.py @@ -312,7 +312,7 @@ class BatchEncoding(UserDict): """ if not self._encodings: raise ValueError("words() is not available when using Python-based tokenizers") - logger.warn( + logger.warning( "`BatchEncoding.words()` property is deprecated and should be replaced with the identical, " "but more self-explanatory `BatchEncoding.word_ids()` property.", FutureWarning, @@ -1601,7 +1601,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not " "supported for this tokenizer. Use a model identifier or the path to a directory instead." ) - logger.warn( + logger.warning( f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is deprecated and " "won't be possible anymore in v5. Use a model identifier or the path to a directory instead.", FutureWarning, @@ -2163,7 +2163,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): # Get padding strategy if padding is False and old_pad_to_max_length: if verbose: - logger.warn( + logger.warning( "The `pad_to_max_length` argument is deprecated and will be removed in a future version, " "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or " "use `padding='max_length'` to pad to a max length. In this case, you can give a specific " @@ -2184,7 +2184,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): "To pad to max length, use `padding='max_length'`." ) if old_pad_to_max_length is not False: - logger.warn("Though `pad_to_max_length` = `True`, it is ignored because `padding`=`True`.") + logger.warning("Though `pad_to_max_length` = `True`, it is ignored because `padding`=`True`.") padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch elif not isinstance(padding, PaddingStrategy): padding_strategy = PaddingStrategy(padding) @@ -2196,7 +2196,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): # Get truncation strategy if truncation is False and old_truncation_strategy != "do_not_truncate": if verbose: - logger.warn( + logger.warning( "The `truncation_strategy` argument is deprecated and will be removed in a future version, " "use `truncation=True` to truncate examples to a max length. You can give a specific " "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the " @@ -3352,7 +3352,7 @@ model_inputs["labels"] = labels["input_ids"] See the documentation of your specific tokenizer for more details on the specific arguments to the tokenizer of choice. For a more complete example, see the implementation of `prepare_seq2seq_batch`. """ - logger.warn(formatted_warning, FutureWarning) + logger.warning(formatted_warning, FutureWarning) # mBART-specific kwargs that should be ignored by other models. kwargs.pop("src_lang", None) kwargs.pop("tgt_lang", None) diff --git a/tests/core/collators/padders/test_get_padder.py b/tests/core/collators/padders/test_get_padder.py index 5996f023..a0e2dfdc 100644 --- a/tests/core/collators/padders/test_get_padder.py +++ b/tests/core/collators/padders/test_get_padder.py @@ -3,7 +3,7 @@ import numpy as np from fastNLP.core.collators.padders.get_padder import get_padder, InconsistencyError, DtypeError, \ _get_element_shape_dtype -from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR +from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR, _NEED_IMPORT_ONEFLOW def test_get_element_shape_dtype(): @@ -14,10 +14,11 @@ def test_get_element_shape_dtype(): catalog = _get_element_shape_dtype([np.zeros(3), np.zeros((2, 1))]) -# @pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle']) -@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'paddle']) +@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'paddle', 'jittor', 'oneflow']) @pytest.mark.torch @pytest.mark.paddle +@pytest.mark.jittor +@pytest.mark.oneflow def test_get_padder_run(backend): if not _NEED_IMPORT_TORCH and backend == 'torch': pytest.skip("No torch") @@ -25,6 +26,8 @@ def test_get_padder_run(backend): pytest.skip("No paddle") if not _NEED_IMPORT_JITTOR and backend == 'jittor': pytest.skip("No jittor") + if not _NEED_IMPORT_ONEFLOW and backend == 'oneflow': + pytest.skip("No oneflow") batch_field = [1, 2, 3] padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') @@ -163,3 +166,57 @@ def test_torch_padder(): assert isinstance(pad_batch, np.ndarray) assert np.shape(pad_batch) == (3, 3, 3) assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12 + +@pytest.mark.oneflow +def test_oneflow_padder(): + if not _NEED_IMPORT_ONEFLOW: + pytest.skip("No oneflow.") + import oneflow + backend = 'oneflow' + target_type = oneflow.Tensor + batch_field = [1, 2, 3] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert (pad_batch == oneflow.LongTensor(batch_field)).sum()==len(batch_field) + + batch_field = [[1], [2, 2], [3, 3, 3]] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert pad_batch.shape == (3, 3) + assert (pad_batch == oneflow.zeros(pad_batch.shape)).sum()==3 + + batch_field = [oneflow.ones((3,3)), oneflow.ones((2,3)), oneflow.ones((1,3))] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert pad_batch.shape == (3, 3, 3) + assert (pad_batch == oneflow.zeros(pad_batch.shape)).sum()==9 + + batch_field = [oneflow.ones((3,3)), oneflow.ones((2,3)), oneflow.ones((1,0))] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert pad_batch.shape == (3, 3, 3) + assert (pad_batch == oneflow.zeros(pad_batch.shape)).sum()==12 + + batch_field = [oneflow.ones((3,3)), oneflow.ones((2,3)), oneflow.ones((1,))] + with pytest.raises(InconsistencyError): + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + + # 可以是 numpy.ndarray + batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert pad_batch.shape == (3, 3, 3) + assert (pad_batch == oneflow.zeros(pad_batch.shape)).sum()==12 + + # 测试 to numpy + batch_field = [oneflow.ones((3,3)), oneflow.ones((2,3)), oneflow.ones((1,0))] + padder = get_padder(batch_field, pad_val=0, backend='numpy', dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, np.ndarray) + assert np.shape(pad_batch) == (3, 3, 3) + assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12 diff --git a/tests/core/collators/padders/test_oneflow_padder.py b/tests/core/collators/padders/test_oneflow_padder.py new file mode 100644 index 00000000..9ad31816 --- /dev/null +++ b/tests/core/collators/padders/test_oneflow_padder.py @@ -0,0 +1,105 @@ +import numpy as np +import pytest + +from fastNLP.core.collators.padders.oneflow_padder import OneflowTensorPadder, OneflowSequencePadder, OneflowNumberPadder +from fastNLP.core.collators.padders.exceptions import DtypeError +from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW + +if _NEED_IMPORT_ONEFLOW: + import oneflow + + +@pytest.mark.oneflow +class TestOneflowNumberPadder: + def test_run(self): + padder = OneflowNumberPadder(pad_val=-1, ele_dtype=int, dtype=int) + a = [1, 2, 3] + t_a = padder(a) + assert isinstance(t_a, oneflow.Tensor) + assert (t_a == oneflow.LongTensor(a)).sum() == 3 + + +@pytest.mark.oneflow +class TestOneflowSequencePadder: + def test_run(self): + padder = OneflowSequencePadder(pad_val=-1, ele_dtype=int, dtype=int) + a = [[1, 2, 3], [3]] + a = padder(a) + shape = a.shape + assert isinstance(a, oneflow.Tensor) + assert tuple(shape) == (2, 3) + b = oneflow.LongTensor([[1, 2, 3], [3, -1, -1]]) + assert (a == b).sum().item() == shape[0]*shape[1] + + def test_dtype_check(self): + padder = OneflowSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) + with pytest.raises(DtypeError): + padder = OneflowSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) + padder = OneflowSequencePadder(pad_val=-1, ele_dtype=oneflow.long, dtype=int) + padder = OneflowSequencePadder(pad_val=-1, ele_dtype=np.int8, dtype=None) + a = padder([[1], [2, 322]]) + assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 + padder = OneflowSequencePadder(pad_val=-1, ele_dtype=np.zeros(2).dtype, dtype=None) + + +@pytest.mark.oneflow +class TestOneflowTensorPadder: + def test_run(self): + padder = OneflowTensorPadder(pad_val=-1, ele_dtype=oneflow.zeros(3).dtype, dtype=int) + a = [oneflow.zeros(3), oneflow.zeros(2), oneflow.zeros(0)] + a = padder(a) + shape = a.shape + assert isinstance(a, oneflow.Tensor) + assert tuple(shape) == (3, 3) + b = oneflow.LongTensor([[0, 0, 0], [0, 0, -1], [-1, -1, -1]]) + assert (a == b).sum().item() == shape[0]*shape[1] + + a = [oneflow.zeros((3, 2)), oneflow.zeros((2, 2)), oneflow.zeros((1, 2))] + a = padder(a) + shape = a.shape + assert isinstance(a, oneflow.Tensor) + assert tuple(shape) == (3, 3, 2) + b = oneflow.LongTensor([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[0, 0], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + a = [oneflow.zeros((3, 2)), oneflow.zeros((2, 2)), oneflow.zeros((1, 1))] + a = padder(a) + shape = a.shape + assert isinstance(a, oneflow.Tensor) + assert tuple(shape) == (3, 3, 2) + b = oneflow.LongTensor([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[0, -1], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + padder = OneflowTensorPadder(pad_val=-1, ele_dtype=oneflow.zeros(3).dtype, dtype=int) + a = [oneflow.zeros((3, 2)), oneflow.zeros((2, 2)), oneflow.zeros((1, 0))] + a = padder(a) + shape = a.shape + assert isinstance(a, oneflow.Tensor) + assert tuple(shape) == (3, 3, 2) + b = oneflow.LongTensor([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[-1, -1], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + padder = OneflowTensorPadder(pad_val=-1, ele_dtype=oneflow.zeros(3).dtype, dtype=None) + a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] + a = padder(a) + shape = a.shape + assert isinstance(a, oneflow.Tensor) + assert tuple(shape) == (3, 3, 2) + b = oneflow.FloatTensor([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[-1, -1], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + def test_dtype_check(self): + padder = OneflowTensorPadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) + with pytest.raises(DtypeError): + padder = OneflowTensorPadder(pad_val=-1, ele_dtype=str, dtype=int) + padder = OneflowTensorPadder(pad_val=-1, ele_dtype=oneflow.long, dtype=int) + padder = OneflowTensorPadder(pad_val=-1, ele_dtype=int, dtype=oneflow.long) + diff --git a/tests/core/collators/test_collator.py b/tests/core/collators/test_collator.py index 8443ef92..d00cbe05 100644 --- a/tests/core/collators/test_collator.py +++ b/tests/core/collators/test_collator.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR +from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR, _NEED_IMPORT_ONEFLOW from fastNLP.core.collators.collator import Collator from ...helpers.utils import Capturing @@ -14,6 +14,10 @@ def _assert_equal(d1, d2): if 'float64' in str(d2.dtype): print(d2.dtype) assert (d1 == d2).all().item() + if 'oneflow' in str(type(d1)): + if 'float64' in str(d2.dtype): + print(d2.dtype) + assert (d1 == d2).all().item() else: assert all(d1 == d2) except TypeError: @@ -43,9 +47,9 @@ def findListDiff(d1, d2): class TestCollator: - @pytest.mark.torch - def test_run(self): - dict_batch = [{ + @staticmethod + def setup_class(cls): + cls.dict_batch = [{ 'str': '1', 'lst_str': ['1'], 'int': 1, @@ -75,17 +79,21 @@ class TestCollator: } ] - list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], - ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + cls.list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + + def test_run_traw(self): raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} collator = Collator(backend='raw') - assert raw_pad_batch == collator(dict_batch) + assert raw_pad_batch == collator(self.dict_batch) collator = Collator(backend='raw') raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], [1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}], [{'1'}, {'2'}]] - findListDiff(raw_pad_lst, collator(list_batch)) + findListDiff(raw_pad_lst, collator(self.list_batch)) + + def test_run_numpy(self): collator = Collator(backend='numpy') numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), @@ -94,36 +102,60 @@ class TestCollator: 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), 'b': np.array([[1, 2], [1, 2]])}} - findDictDiff(numpy_pad_batch, collator(dict_batch)) + findDictDiff(numpy_pad_batch, collator(self.dict_batch)) collator = Collator(backend='numpy') numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], [{'1'}, {'2'}]] - findListDiff(numpy_pad_lst, collator(list_batch)) - - if _NEED_IMPORT_TORCH: - import torch - collator = Collator(backend='torch') - numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), - 'lst_int': torch.LongTensor([[1, 0], [1, 2]]), - 'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), - 'float': torch.FloatTensor([1.1, 2.1]), - 'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), - 'numpy': torch.FloatTensor([[1], [0]]), - 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), - 'b': torch.LongTensor( - [[1, 2], [1, 2]])}} - - findDictDiff(numpy_pad_batch, collator(dict_batch)) - collator = Collator(backend='torch') - torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), - torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), - torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), - torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(torch_pad_lst, collator(list_batch)) + findListDiff(numpy_pad_lst, collator(self.list_batch)) + + @pytest.mark.torch + def test_run_torch(self): + import torch + collator = Collator(backend='torch') + numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), + 'lst_int': torch.LongTensor([[1, 0], [1, 2]]), + 'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + 'float': torch.FloatTensor([1.1, 2.1]), + 'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), + 'numpy': torch.FloatTensor([[1], [0]]), + 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), + 'b': torch.LongTensor( + [[1, 2], [1, 2]])}} + + findDictDiff(numpy_pad_batch, collator(self.dict_batch)) + collator = Collator(backend='torch') + torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), + torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), + torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(torch_pad_lst, collator(self.list_batch)) + + @pytest.mark.oneflow + def test_run_oneflow(self): + import oneflow + collator = Collator(backend='oneflow') + numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': oneflow.LongTensor([1, 2]), + 'lst_int': oneflow.LongTensor([[1, 0], [1, 2]]), + 'nest_lst_int': oneflow.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + 'float': oneflow.FloatTensor([1.1, 2.1]), + 'lst_float': oneflow.FloatTensor([[1.1], [2.1]]), 'bool': oneflow.BoolTensor([True, False]), + 'numpy': oneflow.FloatTensor([[1], [0]]), + 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': oneflow.LongTensor([1, 2]), + 'b': oneflow.LongTensor( + [[1, 2], [1, 2]])}} + + findDictDiff(numpy_pad_batch, collator(self.dict_batch)) + collator = Collator(backend='oneflow') + oneflow_pad_lst = [['1', '2'], [['1'], ['2', '2']], oneflow.LongTensor([1, 2]), oneflow.LongTensor([[1, 0], [2, 2]]), + oneflow.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + oneflow.FloatTensor([1.1, 2.1]), oneflow.FloatTensor([[1.1], [2.1]]), oneflow.BoolTensor([True, False]), + oneflow.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(oneflow_pad_lst, collator(self.list_batch)) def test_pad(self): dict_batch = [{ @@ -366,6 +398,46 @@ def test_torch_dl(): with pytest.raises(KeyError): dl.set_pad('i', pad_val=None) +@pytest.mark.oneflow +def test_oneflow_dl(): + from fastNLP import OneflowDataLoader + from fastNLP import DataSet + import numpy as np + import oneflow + + ds = DataSet({ + 'x': [1, 2], 'y': [[1,2], [3]], 'z':[np.ones((1, 2)), np.ones((2, 3))], + 'i': [{'j': [1, 2]}, {'j': [3]}], 'j': ['a', 'b'] + }) + + dl = OneflowDataLoader(ds, batch_size=2) + batch = next(iter(dl)) + assert 'x' in batch and 'y' in batch and 'z' in batch and 'i' in batch and 'j' in batch + assert batch['z'].dtype == oneflow.float32 + assert isinstance(batch['j'], list) + assert batch['i']['j'].dtype, oneflow.long + + dl.set_ignore('x') + batch = next(iter(dl)) + assert 'x' not in batch and 'y' in batch and 'z' in batch + + dl.set_pad('y', pad_val=None) + batch = next(iter(dl)) + assert 'x' not in batch and 'y' in batch and 'z' in batch + assert isinstance(batch['y'], list) + assert len(batch['y'][0])!=len(batch['y'][1]) # 没有 pad + + dl.set_pad(('i', 'j'), pad_val=None) + batch = next(iter(dl)) + assert 'x' not in batch and 'y' in batch and 'z' in batch + assert isinstance(batch['y'], list) + assert len(batch['y'][0])!=len(batch['y'][1]) # 没有 pad + assert isinstance(batch['i']['j'], list) + assert len(batch['i']['j'][0])!=len(batch['i']['j'][1]) # 没有 pad + + with pytest.raises(KeyError): + dl.set_pad('i', pad_val=None) + def test_compare_tuple(): from fastNLP.core.collators.collator import _compare_tuple diff --git a/tests/core/controllers/_test_trainer_deepspeed.py b/tests/core/controllers/_test_trainer_deepspeed.py new file mode 100644 index 00000000..0c51e47c --- /dev/null +++ b/tests/core/controllers/_test_trainer_deepspeed.py @@ -0,0 +1,95 @@ +""" +这个文件测试多卡情况下使用 deepspeed 的情况:: + + >>> # 测试直接使用多卡 + >>> python _test_trainer_deepspeed.py + >>> # 测试通过 deepspeed 拉起 + >>> deepspeed _test_trainer_deepspeed.py + +""" +import sys +sys.path.append("../../../") +from dataclasses import dataclass + +from fastNLP.core.controllers.trainer import Trainer +from fastNLP.core.metrics.accuracy import Accuracy +from fastNLP.core.callbacks.progress_callback import RichCallback + +from torch.optim import Adam +from torch.utils.data import DataLoader + +from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 +from tests.helpers.datasets.torch_data import TorchArgMaxDataset + +@dataclass +class TrainDeepSpeedConfig: + num_labels: int = 3 + feature_dimension: int = 3 + + batch_size: int = 2 + shuffle: bool = True + evaluate_every = 2 + +def test_trainer_deepspeed( + device, + callbacks, + strategy, + config, + n_epochs=2, +): + model = TorchNormalModel_Classification_1( + num_labels=TrainDeepSpeedConfig.num_labels, + feature_dimension=TrainDeepSpeedConfig.feature_dimension + ) + optimizers = Adam(params=model.parameters(), lr=0.0001) + train_dataloader = DataLoader( + dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 20), + batch_size=TrainDeepSpeedConfig.batch_size, + shuffle=True + ) + val_dataloader = DataLoader( + dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 12), + batch_size=TrainDeepSpeedConfig.batch_size, + shuffle=True + ) + train_dataloader = train_dataloader + evaluate_dataloaders = val_dataloader + evaluate_every = TrainDeepSpeedConfig.evaluate_every + metrics = {"acc": Accuracy()} + if config is not None: + config["train_micro_batch_size_per_gpu"] = TrainDeepSpeedConfig.batch_size + trainer = Trainer( + model=model, + driver="deepspeed", + device=device, + optimizers=optimizers, + train_dataloader=train_dataloader, + evaluate_dataloaders=evaluate_dataloaders, + evaluate_every=evaluate_every, + metrics=metrics, + output_mapping={"preds": "pred"}, + + n_epochs=n_epochs, + callbacks=callbacks, + deepspeed_kwargs={ + "strategy": strategy, + "config": config + } + ) + trainer.run() + +if __name__ == "__main__": + device = [4, 5] + # device = [0,1,3] + callbacks = [ + # RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), + RichCallback(5), + ] + config = None + test_trainer_deepspeed( + device=device, + callbacks=callbacks, + strategy="deepspeed", + config=config, + n_epochs=5, + ) \ No newline at end of file diff --git a/tests/core/controllers/_test_trainer_deepspeed_outside.py b/tests/core/controllers/_test_trainer_deepspeed_outside.py new file mode 100644 index 00000000..6821787e --- /dev/null +++ b/tests/core/controllers/_test_trainer_deepspeed_outside.py @@ -0,0 +1,105 @@ +""" +这个文件测试多卡情况下使用 deepspeed ,且用户自己调用了 deepspeed.initialize 的情况:: + + >>> deepspeed _test_trainer_deepspeed_outside.py + +""" +import os +import sys +sys.path.append("../../../") +from dataclasses import dataclass + +from fastNLP.core.controllers.trainer import Trainer +from fastNLP.core.metrics.accuracy import Accuracy +from fastNLP.core.callbacks.progress_callback import RichCallback +from fastNLP.core.drivers.torch_driver.utils import _create_default_config + +import deepspeed +import torch +from torch.optim import Adam +from torch.utils.data import DataLoader + + +from tests.helpers.models.torch_model import TorchNormalModel_Classification_2 +from tests.helpers.datasets.torch_data import TorchArgMaxDataset + +local_rank = int(os.environ["LOCAL_RANK"]) + +@dataclass +class TrainDeepSpeedConfig: + num_labels: int = 3 + feature_dimension: int = 3 + + batch_size: int = 2 + shuffle: bool = True + evaluate_every = 2 + +def test_trainer_deepspeed( + device, + callbacks, + strategy, + config, + n_epochs=2, +): + model = TorchNormalModel_Classification_2( + num_labels=TrainDeepSpeedConfig.num_labels, + feature_dimension=TrainDeepSpeedConfig.feature_dimension + ) + optimizers = Adam(params=model.parameters(), lr=0.0001) + train_dataloader = DataLoader( + dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 20), + batch_size=TrainDeepSpeedConfig.batch_size, + shuffle=True + ) + val_dataloader = DataLoader( + dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 12), + batch_size=TrainDeepSpeedConfig.batch_size, + shuffle=True + ) + train_dataloader = train_dataloader + evaluate_dataloaders = val_dataloader + evaluate_every = TrainDeepSpeedConfig.evaluate_every + metrics = {"acc": Accuracy()} + if config is not None: + config["train_micro_batch_size_per_gpu"] = TrainDeepSpeedConfig.batch_size + model, optimizers, _, _ = deepspeed.initialize( + model=model, + optimizer=optimizers, + config=config, + ) + trainer = Trainer( + model=model, + driver="deepspeed", + device=device, + data_device=torch.device(f"cuda:{local_rank}"), + optimizers=optimizers, + train_dataloader=train_dataloader, + evaluate_dataloaders=evaluate_dataloaders, + evaluate_every=evaluate_every, + metrics=metrics, + output_mapping={"preds": "pred"}, + + n_epochs=n_epochs, + callbacks=callbacks, + deepspeed_kwargs={ + "strategy": strategy, + "config": config + } + ) + trainer.run() + +if __name__ == "__main__": + device = [0,1] + # device = [0,1,3] + callbacks = [ + # RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), + RichCallback(5), + ] + config = _create_default_config(stage=2) + test_trainer_deepspeed( + device=device, + callbacks=callbacks, + strategy="deepspeed", + config=config, + n_epochs=5, + ) \ No newline at end of file diff --git a/tests/core/controllers/_test_trainer_oneflow.py b/tests/core/controllers/_test_trainer_oneflow.py new file mode 100644 index 00000000..385aded0 --- /dev/null +++ b/tests/core/controllers/_test_trainer_oneflow.py @@ -0,0 +1,96 @@ +""" +测试 oneflow 动态图的多卡训练:: + + >>> # 不使用 DistributedDataParallel 包裹的情况 + >>> python -m oneflow.distributed.launch --nproc_per_node 2 _test_trainer_oneflow.py + >>> # 使用 DistributedDataParallel 包裹的情况 + >>> python -m oneflow.distributed.launch --nproc_per_node 2 _test_trainer_oneflow.py -w +""" +import sys +sys.path.append("../../../") +import os +from dataclasses import dataclass + +from fastNLP.core.controllers.trainer import Trainer +from fastNLP.core.metrics.accuracy import Accuracy +from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW + +if _NEED_IMPORT_ONEFLOW: + import oneflow + from oneflow.nn.parallel import DistributedDataParallel + from oneflow.optim import Adam + from oneflow.utils.data import DataLoader + +from tests.helpers.models.oneflow_model import OneflowNormalModel_Classification_1 +from tests.helpers.datasets.oneflow_data import OneflowArgMaxDataset + +@dataclass +class TrainOneflowConfig: + num_labels: int = 3 + feature_dimension: int = 3 + + batch_size: int = 2 + shuffle: bool = True + evaluate_every = 2 + +def test_trainer_oneflow( + callbacks, + wrapped=False, + n_epochs=2, +): + model = OneflowNormalModel_Classification_1( + num_labels=TrainOneflowConfig.num_labels, + feature_dimension=TrainOneflowConfig.feature_dimension + ) + optimizers = Adam(params=model.parameters(), lr=0.0001) + train_dataloader = DataLoader( + dataset=OneflowArgMaxDataset(20, TrainOneflowConfig.feature_dimension), + batch_size=TrainOneflowConfig.batch_size, + shuffle=True + ) + val_dataloader = DataLoader( + dataset=OneflowArgMaxDataset(12, TrainOneflowConfig.feature_dimension), + batch_size=TrainOneflowConfig.batch_size, + shuffle=True + ) + train_dataloader = train_dataloader + evaluate_dataloaders = val_dataloader + evaluate_every = TrainOneflowConfig.evaluate_every + metrics = {"acc": Accuracy()} + + if wrapped: + model.to(int(os.environ["LOCAL_RANK"])) + model = DistributedDataParallel(model) + + + trainer = Trainer( + model=model, + driver="oneflow", + device=0, + optimizers=optimizers, + train_dataloader=train_dataloader, + evaluate_dataloaders=evaluate_dataloaders, + evaluate_every=evaluate_every, + input_mapping=None, + output_mapping=None, + metrics=metrics, + + n_epochs=n_epochs, + callbacks=callbacks, + ) + trainer.run() + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument( + "-w", + "--wrapped", + default=False, + action="store_true", + help="Use DistributedDataParallal to wrap model first.", + ) + args = parser.parse_args() + + callbacks = [] + test_trainer_oneflow(callbacks, args.wrapped) diff --git a/tests/core/controllers/test_trainer_deepspeed.py b/tests/core/controllers/test_trainer_deepspeed.py new file mode 100644 index 00000000..c718e01d --- /dev/null +++ b/tests/core/controllers/test_trainer_deepspeed.py @@ -0,0 +1,99 @@ +import pytest +from dataclasses import dataclass + +from fastNLP.core.controllers.trainer import Trainer +from fastNLP.core.metrics.accuracy import Accuracy +from fastNLP.core.callbacks.progress_callback import RichCallback +from fastNLP.core.drivers.torch_driver import DeepSpeedDriver +from fastNLP.core.drivers.torch_driver.utils import _create_default_config +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + import torch + from torch.optim import Adam + from torch.utils.data import DataLoader + + +from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 +from tests.helpers.datasets.torch_data import TorchArgMaxDataset +from tests.helpers.utils import magic_argv_env_context + +@dataclass +class TrainDeepSpeedConfig: + num_labels: int = 3 + feature_dimension: int = 3 + + batch_size: int = 2 + shuffle: bool = True + evaluate_every = 2 + +@pytest.mark.deepspeed +class TestTrainer: + @classmethod + def setup_class(cls): + # 不初始化的话从第二个测试例开始会因为环境变量报错。 + torch_model = TorchNormalModel_Classification_1(1, 1) + torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) + device = [torch.device(i) for i in [0,1]] + driver = DeepSpeedDriver( + model=torch_model, + parallel_device=device, + ) + driver.set_optimizers(torch_opt) + driver.setup() + + return driver + + @pytest.mark.parametrize("device", [[0, 1]]) + @pytest.mark.parametrize("callbacks", [[RichCallback(5)]]) + @pytest.mark.parametrize("strategy", ["deepspeed", "deepspeed_stage_1"]) + @pytest.mark.parametrize("config", [None, _create_default_config(stage=1)]) + @magic_argv_env_context + def test_trainer_deepspeed( + self, + device, + callbacks, + strategy, + config, + n_epochs=2, + ): + model = TorchNormalModel_Classification_1( + num_labels=TrainDeepSpeedConfig.num_labels, + feature_dimension=TrainDeepSpeedConfig.feature_dimension + ) + optimizers = Adam(params=model.parameters(), lr=0.0001) + train_dataloader = DataLoader( + dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 20), + batch_size=TrainDeepSpeedConfig.batch_size, + shuffle=True + ) + val_dataloader = DataLoader( + dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 12), + batch_size=TrainDeepSpeedConfig.batch_size, + shuffle=True + ) + train_dataloader = train_dataloader + evaluate_dataloaders = val_dataloader + evaluate_every = TrainDeepSpeedConfig.evaluate_every + metrics = {"acc": Accuracy()} + if config is not None: + config["train_micro_batch_size_per_gpu"] = TrainDeepSpeedConfig.batch_size + trainer = Trainer( + model=model, + driver="torch", + device=device, + optimizers=optimizers, + train_dataloader=train_dataloader, + evaluate_dataloaders=evaluate_dataloaders, + evaluate_every=evaluate_every, + metrics=metrics, + output_mapping={"preds": "pred"}, + + n_epochs=n_epochs, + callbacks=callbacks, + deepspeed_kwargs={ + "strategy": strategy, + "config": config + } + ) + trainer.run() diff --git a/tests/core/controllers/test_trainer_oneflow.py b/tests/core/controllers/test_trainer_oneflow.py new file mode 100644 index 00000000..e5e2433a --- /dev/null +++ b/tests/core/controllers/test_trainer_oneflow.py @@ -0,0 +1,70 @@ +import os +import pytest +from dataclasses import dataclass + +from fastNLP.core.controllers.trainer import Trainer +from fastNLP.core.metrics.accuracy import Accuracy +from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW + +if _NEED_IMPORT_ONEFLOW: + from oneflow.optim import Adam + from oneflow.utils.data import DataLoader + + +from tests.helpers.models.oneflow_model import OneflowNormalModel_Classification_1 +from tests.helpers.datasets.oneflow_data import OneflowArgMaxDataset +from tests.helpers.utils import magic_argv_env_context + +@dataclass +class TrainOneflowConfig: + num_labels: int = 3 + feature_dimension: int = 3 + + batch_size: int = 2 + shuffle: bool = True + evaluate_every = 2 + +@pytest.mark.parametrize("device", ["cpu", 1]) +@pytest.mark.parametrize("callbacks", [[]]) +@pytest.mark.oneflow +@magic_argv_env_context +def test_trainer_oneflow( + device, + callbacks, + n_epochs=2, +): + model = OneflowNormalModel_Classification_1( + num_labels=TrainOneflowConfig.num_labels, + feature_dimension=TrainOneflowConfig.feature_dimension + ) + optimizers = Adam(params=model.parameters(), lr=0.0001) + train_dataloader = DataLoader( + dataset=OneflowArgMaxDataset(20, TrainOneflowConfig.feature_dimension), + batch_size=TrainOneflowConfig.batch_size, + shuffle=True + ) + val_dataloader = DataLoader( + dataset=OneflowArgMaxDataset(12, TrainOneflowConfig.feature_dimension), + batch_size=TrainOneflowConfig.batch_size, + shuffle=True + ) + train_dataloader = train_dataloader + evaluate_dataloaders = val_dataloader + evaluate_every = TrainOneflowConfig.evaluate_every + metrics = {"acc": Accuracy()} + trainer = Trainer( + model=model, + driver="oneflow", + device=device, + optimizers=optimizers, + train_dataloader=train_dataloader, + evaluate_dataloaders=evaluate_dataloaders, + evaluate_every=evaluate_every, + input_mapping=None, + output_mapping=None, + metrics=metrics, + + n_epochs=n_epochs, + callbacks=callbacks, + ) + trainer.run() diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index 752e06d8..78eff36c 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -286,6 +286,9 @@ def test_trainer_specific_params_1( assert trainer.driver.non_blocking is False assert trainer.driver.wo_auto_param_call is True + if dist.is_initialized(): + dist.destroy_process_group() + @pytest.mark.torch @pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1) @@ -332,5 +335,47 @@ def test_trainer_specific_params_2( assert _ddp_kwargs.get("broadcast_buffers") is True assert _ddp_kwargs.get("find_unused_parameters") is True + if dist.is_initialized(): + dist.destroy_process_group() +@pytest.mark.torch +@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1) +@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)]) +@magic_argv_env_context +def test_trainer_w_evaluator_overfit_torch( + model_and_optimizers: TrainerParameters, + driver, + device, + overfit_batches, + num_train_batch_per_epoch +): + """ + 测试一些特殊的参数是否能够正确地传递; + """ + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + overfit_batches=overfit_batches, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders={"dl": model_and_optimizers.evaluate_dataloaders}, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + n_epochs=2, + output_from_new_proc="all", + evaluate_every=-1, + + torch_kwargs={ + "non_blocking": False, + "set_grad_to_none": True + } + + ) + + trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch) + + if dist.is_initialized(): + dist.destroy_process_group() \ No newline at end of file diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index be04bcd3..ce67814e 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -361,5 +361,35 @@ def test_torch_wo_auto_param_call( dist.destroy_process_group() +# 测试 accumulation_steps; +@pytest.mark.torch +@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])]) +@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)]) +@magic_argv_env_context +def test_trainer_overfit_torch( + model_and_optimizers: TrainerParameters, + driver, + device, + overfit_batches, + num_train_batch_per_epoch +): + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + overfit_batches=overfit_batches, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + output_from_new_proc="all", + n_epochs=2, + ) + + trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch) + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/tests/core/dataloaders/oneflow_dataloader/__init__.py b/tests/core/dataloaders/oneflow_dataloader/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/dataloaders/oneflow_dataloader/test_fdl.py b/tests/core/dataloaders/oneflow_dataloader/test_fdl.py new file mode 100644 index 00000000..f6a80d7c --- /dev/null +++ b/tests/core/dataloaders/oneflow_dataloader/test_fdl.py @@ -0,0 +1,169 @@ +import pytest + +from fastNLP.core.dataloaders.oneflow_dataloader import OneflowDataLoader, prepare_oneflow_dataloader +from fastNLP.core.dataset import DataSet +from fastNLP.io.data_bundle import DataBundle +from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW +from tests.helpers.utils import Capturing, recover_logger +from fastNLP import logger +import numpy as np + +if _NEED_IMPORT_ONEFLOW: + import oneflow + + +@pytest.mark.oneflow +class TestFdl: + + def test_init_v1(self): + ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) + fdl = OneflowDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) + # for batch in fdl: + # print(batch) + fdl1 = OneflowDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) + # for batch in fdl1: + # print(batch) + + def test_set_padding(self): + ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) + fdl = OneflowDataLoader(ds, batch_size=3) + fdl.set_pad("x", -1) + for batch in fdl: + assert batch['x'].shape == oneflow.Size([3, 4]) + + def test_get_batch_indices(self): + ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) + fdl = OneflowDataLoader(ds, batch_size=3, shuffle=True) + for batch in fdl: + assert len(fdl.get_batch_indices()) == 3 + + def test_other_dataset(self): + import numpy as np + class _DataSet: + + def __init__(self): + pass + + def __getitem__(self, item): + return np.random.randn(5), [[1, 2], [2, 3, 4]] + + def __len__(self): + return 10 + + def __getattribute__(self, item): + return object.__getattribute__(self, item) + + dataset = _DataSet() + dl = OneflowDataLoader(dataset, batch_size=2, shuffle=True) + # dl.set_inputs('data', 'labels') + # dl.set_pad_val('labels', val=None) + for batch in dl: + assert batch[0].shape == oneflow.Size([2, 5]) + assert batch[1].shape == oneflow.Size([2, 2, 3]) + + def test_default_collate_fn(self): + ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) + with pytest.raises(ValueError): + fdl = OneflowDataLoader(ds, batch_size=3, collate_fn=None) + import numpy as np + class _DataSet: + + def __init__(self): + pass + + def __getitem__(self, item): + return np.random.randn(5), [[1, 2], [2, 3, 4]] + + def __len__(self): + return 10 + + fdl = OneflowDataLoader(_DataSet(), batch_size=3, collate_fn=None, drop_last=True) + for batch in fdl: + assert batch[0].shape == oneflow.Size([3, 5]) + + def test_my_collate_fn(self): + ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) + def collate_fn(batch): + res = {'x': [], 'y': []} + for ins in batch: + res['x'].append(ins['x']) + res['y'].append(ins['y']) + return res + fdl = OneflowDataLoader(ds, collate_fn=collate_fn, batch_size=3, drop_last=True) + for batch in fdl: + assert batch['x'] == [[1, 2], [2, 3, 4], [4, 5, 6, 7]] + assert batch['y'] == [1, 0, 1] + + def test_prepare_oneflow_dataloader(self): + # 测试 fastNLP 的 dataset + ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) + dl = prepare_oneflow_dataloader(ds, batch_size=8, shuffle=True, num_workers=2) + assert isinstance(dl, OneflowDataLoader) + + ds1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) + dbl = DataBundle(datasets={'train': ds, 'val': ds1}) + dl_bundle = prepare_oneflow_dataloader(dbl) + assert isinstance(dl_bundle['train'], OneflowDataLoader) + assert isinstance(dl_bundle['val'], OneflowDataLoader) + + ds_dict = {'train_1': ds, 'val': ds1} + dl_dict = prepare_oneflow_dataloader(ds_dict) + assert isinstance(dl_dict['train_1'], OneflowDataLoader) + assert isinstance(dl_dict['val'], OneflowDataLoader) + + # 测试其他 dataset + class _DataSet: + + def __init__(self): + pass + + def __getitem__(self, item): + return np.random.randn(5), [[1, 2], [2, 3, 4]] + + def __len__(self): + return 10 + + def __getattribute__(self, item): + return object.__getattribute__(self, item) + + ds2 = _DataSet() + dl1 = prepare_oneflow_dataloader(ds2, batch_size=8, shuffle=True, num_workers=2) + assert isinstance(dl1, OneflowDataLoader) + + ds3 = _DataSet() + dbl1 = DataBundle(datasets={'train': ds2, 'val': ds3}) + dl_bundle1 = prepare_oneflow_dataloader(dbl1) + assert isinstance(dl_bundle1['train'], OneflowDataLoader) + assert isinstance(dl_bundle1['val'], OneflowDataLoader) + + ds_dict1 = {'train_1': ds2, 'val': ds3} + dl_dict1 = prepare_oneflow_dataloader(ds_dict1) + assert isinstance(dl_dict1['train_1'], OneflowDataLoader) + assert isinstance(dl_dict1['val'], OneflowDataLoader) + + ds = [[1, [1]], [2, [2, 2]]] + dl = prepare_oneflow_dataloader(ds, batch_size=2) + for batch in dl: + assert (batch[0] == oneflow.LongTensor([1, 2])).sum()==2 + assert (batch[1] == oneflow.LongTensor([[1, 0], [2, 2]])).sum()==4 + + # sequence = [ds, ds1] + # seq_ds = prepare_oneflow_dataloader(sequence) + # assert isinstance(seq_ds[0], OneflowDataLoader) + # assert isinstance(seq_ds[1], OneflowDataLoader) + + def test_get_backend(self): + from fastNLP.core.collators import Collator + from oneflow.utils.data import DataLoader, Dataset + + class MyDatset(DataSet): + def __len__(self): + return 1000 + + def __getitem__(self, item): + return [[1, 0], [1], [1, 2, 4]], [1, 0] + + collate_batch = Collator(backend='auto') + dl = DataLoader(MyDatset(), collate_fn=collate_batch) + for batch in dl: + print(batch) diff --git a/tests/core/drivers/oneflow_driver/__init__.py b/tests/core/drivers/oneflow_driver/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/drivers/oneflow_driver/test_ddp.py b/tests/core/drivers/oneflow_driver/test_ddp.py new file mode 100644 index 00000000..8fa92924 --- /dev/null +++ b/tests/core/drivers/oneflow_driver/test_ddp.py @@ -0,0 +1,948 @@ +import os +import sys +sys.path.append("../../../../") +import pytest +from pathlib import Path + +from fastNLP.core.drivers.oneflow_driver.ddp import OneflowDDPDriver +from fastNLP import prepare_oneflow_dataloader +from fastNLP.core.samplers import ( + RandomSampler, + UnrepeatedSampler, + BucketedBatchSampler, + UnrepeatedRandomSampler, + UnrepeatedSequentialSampler, +) +from tests.helpers.models.oneflow_model import OneflowNormalModel_Classification_1 +from tests.helpers.datasets.oneflow_data import OneflowNormalDataset, OneflowNormalXYDataset +from tests.helpers.utils import recover_logger +from fastNLP.envs.distributed import rank_zero_rm +from fastNLP import logger +from fastNLP.core.drivers.oneflow_driver.dist_utils import fastnlp_oneflow_all_gather +from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW +if _NEED_IMPORT_ONEFLOW: + import oneflow + import oneflow.comm as comm + import oneflow.env as dist_env + from oneflow.utils.data import DataLoader, BatchSampler + +def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all"): + oneflow_model = OneflowNormalModel_Classification_1(labels, features) + oneflow_opt = oneflow.optim.Adam(params=oneflow_model.parameters(), lr=0.01) + device = [oneflow.device("cuda", i) for i in device] + driver = OneflowDDPDriver( + model=oneflow_model, + parallel_device=device, + fp16=fp16, + output_from_new_proc=output_from_new_proc + ) + driver.set_optimizers(oneflow_opt) + driver.setup() + + return driver + +def dataloader_with_bucketedbatchsampler(dataset, length, batch_size, shuffle, drop_last): + """ + 建立一个 batch_sampler 为 BucketedBatchSampler 的 dataloader + """ + dataloader = DataLoader( + dataset=dataset, + batch_sampler=BucketedBatchSampler( + dataset, + length, + batch_size, + shuffle=shuffle, + drop_last=drop_last, + ), + ) + + return dataloader + +def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=0, unrepeated=False): + """ + 建立一个 sampler 为 RandomSampler 的 dataloader + """ + if unrepeated: + sampler = UnrepeatedRandomSampler(dataset, shuffle, seed) + else: + sampler = RandomSampler(dataset, shuffle, seed=seed) + dataloader = DataLoader( + dataset, + sampler=sampler, + drop_last=drop_last, + batch_size=batch_size + ) + return dataloader + +############################################################################ +# +# 测试 OneflowDDPDriver 的一些函数 +# +############################################################################ + +@pytest.mark.oneflow +class TestDDPDriverFunction: + """ + 测试 OneflowDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 + """ + + def test_simple_functions(self): + """ + 简单测试多个函数 + """ + driver = generate_driver(10, 10) + + """ + 测试 move_data_to_device 函数。 + """ + + driver.move_data_to_device(oneflow.rand((32, 64))) + comm.barrier() + + """ + 测试 is_distributed 函数 + """ + assert driver.is_distributed() == True + comm.barrier() + + """ + 测试 get_no_sync_context 函数 + """ + res = driver.get_model_no_sync_context() + comm.barrier() + + """ + 测试 is_global_zero 函数 + """ + driver.is_global_zero() + comm.barrier() + + """ + 测试 unwrap_model 函数 + """ + driver.unwrap_model() + comm.barrier() + + """ + 测试 get_local_rank 函数 + """ + driver.get_local_rank() + comm.barrier() + + """ + 测试 all_gather 函数 + 详细的测试在 test_dist_utils.py 中完成 + """ + obj = { + "rank": driver.global_rank + } + obj_list = driver.all_gather(obj) + for i, res in enumerate(obj_list): + assert res["rank"] == i + + """ + 测试 broadcast_object 函数 + 详细的函数在 test_dist_utils.py 中完成 + """ + if driver.global_rank == 0: + obj = { + "rank": driver.global_rank + } + else: + obj = None + res = driver.broadcast_object(obj, src=0) + assert res["rank"] == 0 + +############################################################################ +# +# 测试 set_dist_repro_dataloader 函数 +# +############################################################################ + +@pytest.mark.oneflow +class TestSetDistReproDataloader: + + @classmethod + def setup_class(cls): + cls.device = [0, 1] + + def setup_method(self): + self.dataset = OneflowNormalDataset(100) + + """ + 传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 + 此时对应 driver.load_checkpoint 中的情况 + """ + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_batch_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 + 此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler + """ + driver = generate_driver(10, 10, device=self.device) + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) + batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) + assert replaced_loader.batch_sampler is batch_sampler + self.check_distributed_sampler(replaced_loader.batch_sampler) + self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) + + comm.barrier() + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 + 此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler + """ + driver = generate_driver(10, 10, device=self.device) + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) + sampler = RandomSampler(self.dataset, shuffle=shuffle) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, sampler, False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert replaced_loader.batch_sampler.sampler is sampler + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) + + comm.barrier() + + """ + 传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` + 参数为 False。此时函数会根据 `reproducible` 的设置进行不同的处理。 + 当 `reproducible` 为 False 时,需要根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定 + 是否重新实例化 dataloader + """ + + def test_with_dist_none_reproducible_true(self): + """ + 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 + 当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错 + """ + driver = generate_driver(10, 10, device=self.device) + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) + with pytest.raises(RuntimeError): + # 应当抛出 RuntimeError + replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, True) + + comm.barrier() + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler + 时的表现 + 此时传入的 dataloader 的 batch_sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler + 和原 dataloader 相同 + """ + driver = generate_driver(10, 10, device=self.device) + dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) + dataloader.batch_sampler.set_distributed( + num_replicas=driver.world_size, + rank=driver.global_rank, + pad=True + ) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) + assert replaced_loader.batch_sampler.batch_size == 4 + self.check_distributed_sampler(dataloader.batch_sampler) + self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) + + comm.barrier() + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_none_reproducible_false_dataloader_reproducible_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现 + 此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 + batch_sampler.sampler 和原 dataloader 相同 + """ + driver = generate_driver(10, 10, device=self.device) + dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) + dataloader.batch_sampler.sampler.set_distributed( + num_replicas=driver.world_size, + rank=driver.global_rank + ) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) + assert replaced_loader.batch_sampler.batch_size == 4 + assert replaced_loader.batch_sampler.drop_last == False + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) + + comm.barrier() + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_none_reproducible_false_dataloader_normal(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 + 此时直接返回原来的 dataloader,不做任何处理。 + """ + driver = generate_driver(10, 10, device=self.device) + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) + + assert replaced_loader is dataloader + comm.barrier() + + """ + 传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数 + 为 True。此时函数会根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定如何重新实例化 dataloader + """ + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler + 的表现 + 此时应该返回一个新的 dataloader,其batch_sampler 和原 dataloader 相同,且应该正确地设置了分布式相关的属性 + """ + driver = generate_driver(10, 10, device=self.device) + dataloader = DataLoader( + dataset=self.dataset, + batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) + ) + dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert replaced_loader.batch_sampler.batch_size == 4 + assert replaced_loader.drop_last == dataloader.drop_last + self.check_distributed_sampler(replaced_loader.batch_sampler) + comm.barrier() + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_dist_dataloader_reproducible_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler + 的表现 + 此时应该返回一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同,且应该正确地设置了分布式相关 + 的属性 + """ + driver = generate_driver(10, 10, device=self.device) + dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) + + assert not (replaced_loader is dataloader) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) + assert replaced_loader.batch_sampler.batch_size == 4 + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + comm.barrier() + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_dist_dataloader_normal(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现 + 此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关 + 的属性 + """ + driver = generate_driver(10, 10, device=self.device) + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + comm.barrier() + + """ + 传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数 + 为 True。此时函数会根据 dataloader 的 sampler 是否为 Unrepeated 和 Reproducible 来决定如何重新实例化 dataloader + """ + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler + 的表现 + 此时应该返回一个新的 dataloader,且将原来的 Sampler 替换为 UnrepeatedRandomSampler,且正确地设置了分布式相关 + 的属性 + """ + driver = generate_driver(10, 10, device=self.device) + dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) + assert replaced_loader.batch_sampler.batch_size == 4 + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + comm.barrier() + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler + 的表现 + 此时应该返回一个新的 dataloader,且重新实例化了原来的 Sampler + """ + driver = generate_driver(10, 10, device=self.device) + dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) + assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) + assert replaced_loader.batch_sampler.batch_size == 4 + assert replaced_loader.drop_last == dataloader.drop_last + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + comm.barrier() + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_unrepeat_dataloader_normal(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现 + 此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关 + 的属性 + """ + driver = generate_driver(10, 10, device=self.device) + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedSequentialSampler) + assert replaced_loader.batch_sampler.batch_size == 4 + assert replaced_loader.drop_last == dataloader.drop_last + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + comm.barrier() + + def check_distributed_sampler(self, sampler): + """ + 测试替换得到的 sampler 或 batch_sampler 的分布式设置是否正确 + """ + assert sampler.num_replicas == dist_env.get_world_size() + assert sampler.rank == dist_env.get_rank() + if not isinstance(sampler, UnrepeatedSampler): + assert sampler.pad == True + + def check_set_dist_repro_dataloader(self, driver, dataloader, replaced_loader, shuffle): + """ + 测试多卡下 set_dist_repro_dataloader 函数的执行结果是否正确 + """ + # 迭代两个 batch + num_replicas = len(self.device) + num_consumed_batches = 2 + already_seen_idx = set() + if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): + sampler_states = replaced_loader.batch_sampler.set_epoch(4) + else: + sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(4) + for idx, batch in enumerate(replaced_loader): + if idx >= num_consumed_batches: + break + already_seen_idx.update(batch.tolist()) + comm.barrier() + if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): + sampler_states = replaced_loader.batch_sampler.state_dict() + else: + sampler_states = replaced_loader.batch_sampler.sampler.state_dict() + + # 重新加载,应该可以输出剩下的内容,且对于 OneflowNormalDataset 来说,排序后应该是一个 range + left_idxes = set() + if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): + batch_size = replaced_loader.batch_sampler.batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas + # 重新改造 dataloader + new_loader = dataloader_with_bucketedbatchsampler( + replaced_loader.dataset, + length=replaced_loader.dataset._data, + batch_size=batch_size, + shuffle=shuffle, + drop_last=False, + ) + new_loader.batch_sampler.set_distributed( + num_replicas=driver.world_size, + rank=driver.global_rank, + pad=True + ) + new_loader.batch_sampler.load_state_dict(sampler_states) + new_loader.batch_sampler.set_epoch(4) + else: + batch_size = replaced_loader.batch_sampler.batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas + # 重新构造 dataloader + new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False) + new_loader.batch_sampler.sampler.set_distributed( + num_replicas=driver.world_size, + rank=driver.global_rank + ) + new_loader.batch_sampler.sampler.load_state_dict(sampler_states) + new_loader.batch_sampler.sampler.set_epoch(4) + for idx, batch in enumerate(new_loader): + left_idxes.update(batch.tolist()) + + assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas + assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas + + +############################################################################ +# +# 测试 save 和 load 相关的功能 +# +############################################################################ +@pytest.mark.oneflow +class TestSaveLoad: + """ + 测试多卡情况下 save 和 load 相关函数的表现 + """ + + def setup_method(self): + self.dataset = OneflowNormalXYDataset(100) + + @pytest.mark.parametrize("only_state_dict", ([True, False])) + def test_save_and_load_model(self, only_state_dict): + """ + 测试 save_model 和 load_model 函数 + """ + try: + path = "model" + + dataloader = DataLoader(self.dataset, batch_size=2) + driver1, driver2 = generate_driver(20, 1), generate_driver(20, 1) + + driver1.save_model(path, only_state_dict) + + # 同步 + comm.barrier() + driver2.load_model(path, only_state_dict) + + for idx, batch in enumerate(dataloader): + batch = driver1.move_data_to_device(batch) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) + + assert oneflow.all(res1["preds"] == res2["preds"]) + finally: + rank_zero_rm(path) + + @pytest.mark.parametrize("only_state_dict", ([True, False])) + @pytest.mark.parametrize("fp16", ([True, False])) + @pytest.mark.parametrize("device", ([[0,1]])) + def test_save_and_load_with_bucketedbatchsampler(self, device, only_state_dict, fp16): + """ + 测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 + """ + + try: + path = "model.ckp" + num_replicas = len(device) + + driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16), \ + generate_driver(20, 1, device=device, fp16=False) + dataloader = dataloader_with_bucketedbatchsampler( + self.dataset, + length=[10 for i in range(len(self.dataset))], + batch_size=4, + shuffle=True, + drop_last=False + ) + dataloader.batch_sampler.set_distributed( + num_replicas=driver1.world_size, + rank=driver1.global_rank, + pad=True + ) + num_consumed_batches = 4 + + already_seen_x_set = set() + already_seen_y_set = set() + driver1.set_sampler_epoch(dataloader, 4) + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) + already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) + + # 同步 + comm.barrier() + + # 保存状态 + sampler_states = dataloader.batch_sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + comm.barrier() + # 加载 + # 更改 batch_size + dataloader = dataloader_with_bucketedbatchsampler( + self.dataset, + length=[10 for i in range(len(self.dataset))], + batch_size=2, + shuffle=True, + drop_last=False + ) + dataloader.batch_sampler.set_distributed( + num_replicas=driver2.world_size, + rank=driver2.global_rank, + pad=True + ) + comm.barrier() + load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) + comm.barrier() + replaced_loader = load_states.pop("dataloader") + + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 batch_sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert replaced_loader.batch_sampler is dataloader.batch_sampler + assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) + if os.environ['FASTNLP_GLOBAL_RANK'] == '0': + assert replaced_loader.batch_sampler.seed == sampler_states["seed"] + assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas + + # # 3. 检查 fp16 是否被加载 + # if fp16: + # assert not isinstance(driver2.grad_scaler, oneflow.cuda.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + driver2.set_sampler_epoch(replaced_loader, 4) + for idx, batch in enumerate(replaced_loader): + + left_x_batches.update(batch["x"].reshape(-1, ).tolist()) + left_y_batches.update(batch["y"].reshape(-1, ).tolist()) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) + assert oneflow.all(res1["preds"] == res2["preds"]) + + assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas + assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas + comm.barrier() + finally: + rank_zero_rm(path) + + @pytest.mark.parametrize("only_state_dict", ([True, False])) + @pytest.mark.parametrize("fp16", ([True, False])) + @pytest.mark.parametrize("device", ([[0,1]])) + def test_save_and_load_with_randomsampler(self, device, only_state_dict, fp16): + """ + 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 + """ + + try: + path = "checkpoints/" + + num_replicas = len(device) + + driver1 = generate_driver(20, 1, device=device, fp16=fp16) + driver2 = generate_driver(20, 1, device=device, fp16=False) + + dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) + dataloader.batch_sampler.sampler.set_distributed( + num_replicas=driver1.world_size, + rank=driver1.global_rank, + pad=True + ) + num_consumed_batches = 4 + + already_seen_x_set = set() + already_seen_y_set = set() + driver1.set_sampler_epoch(dataloader, 4) + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) + already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) + + # 同步 + comm.barrier() + + # 保存状态 + sampler_states = dataloader.batch_sampler.sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + comm.barrier() # 等待save成功 + # 加载 + # 更改 batch_size + dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) + dataloader.batch_sampler.sampler.set_distributed( + num_replicas=driver2.world_size, + rank=driver2.global_rank, + pad=True + ) + load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) + replaced_loader = load_states.pop("dataloader") + + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + if os.environ['FASTNLP_GLOBAL_RANK'] == '0': + assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] + assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] + assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] + assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] + assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas + + # # 3. 检查 fp16 是否被加载 + # if fp16: + # assert not isinstance(driver2.grad_scaler, oneflow.cuda.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + driver2.set_sampler_epoch(replaced_loader, 4) + for idx, batch in enumerate(replaced_loader): + + left_x_batches.update(batch["x"].reshape(-1, ).tolist()) + left_y_batches.update(batch["y"].reshape(-1, ).tolist()) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) + assert oneflow.all(res1["preds"] == res2["preds"]) + + assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas + assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas + + finally: + rank_zero_rm(path) + + +@pytest.mark.oneflow +@pytest.mark.parametrize("shuffle", ([True, False])) +@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17])) +@pytest.mark.parametrize("drop_last", ([True, False])) +def test_shuffle_dataloader(shuffle, batch_size, drop_last, reproducible=True): + try: + # 需要检验一下 set_dist_repro_dataloader 没有修改参数 + num_samples = 200 + dataset = OneflowNormalXYDataset(num_samples) + dl = prepare_oneflow_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last) + model = OneflowNormalModel_Classification_1(10, 32) + device = [oneflow.device("cuda", i) for i in [0, 1]] + + driver = OneflowDDPDriver(model, parallel_device=device) + driver.setup() + dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) + + data = [] + flags = [] + for batch in dl: + flags.append(batch['x'].size(0) == batch_size) + data.extend(batch['x'].reshape(-1).tolist()) + + _num_samples = num_samples//2 + + if drop_last and _num_samples%batch_size != 0: + assert len(data)!=_num_samples + assert all(flags) == True + elif _num_samples%batch_size!=0: + assert flags[-1] is False + else: + assert len(data) == _num_samples + + if not shuffle: + for i in range(1, len(data)-1): + assert data[i]>data[i-1] + else: + flags = [] + for i in range(1, len(data)-1): + flags.append(data[i]>data[i-1]) + assert all(flags) is False + datas = fastnlp_oneflow_all_gather(data) + if drop_last: + assert len(set(datas[0] + datas[1])) == num_samples-_num_samples%batch_size*2 + else: + assert len(set(datas[0] + datas[1])) == num_samples + finally: + pass + + +@pytest.mark.oneflow +@pytest.mark.parametrize("shuffle", ([True, False])) +@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17])) +@pytest.mark.parametrize("drop_last", ([True, False])) +def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible=True): + try: + # 需要检验一下 set_dist_repro_dataloader 没有修改参数 + num_samples = 200 + num_device = 2 + dataset = OneflowNormalXYDataset(num_samples) + sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last, + shuffle=shuffle, num_batch_per_bucket=2) + dl = prepare_oneflow_dataloader(dataset, batch_sampler=sampler) + model = OneflowNormalModel_Classification_1(10, 32) + device = [oneflow.device("cuda", i) for i in [0, 1]] + driver = OneflowDDPDriver(model, parallel_device=device) + driver.setup() + dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) + + data = [] + flags = [] + for batch in dl: + d = batch['x'].reshape(-1).tolist() + diff = max(d) - min(d) + assert diff= num_consumed_batches: + break + already_seen_idx.update(batch.tolist()) + if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): + sampler_states = replaced_loader.batch_sampler.state_dict() + else: + sampler_states = replaced_loader.batch_sampler.sampler.state_dict() + + # 重新加载,应该可以输出剩下的内容,且对于 OneflowNormalDataset 来说,排序后应该是一个 range + left_idxes = set() + if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): + batch_size = replaced_loader.batch_sampler.batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size + # 重新改造 dataloader + new_loader = dataloader_with_randombatchsampler(replaced_loader.dataset, batch_size, shuffle, False) + new_loader.batch_sampler.load_state_dict(sampler_states) + new_loader.batch_sampler.set_epoch(3) + else: + batch_size = replaced_loader.batch_sampler.batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size + # 重新构造 dataloader + new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, False) + new_loader.batch_sampler.sampler.load_state_dict(sampler_states) + new_loader.batch_sampler.sampler.set_epoch(3) + for idx, batch in enumerate(new_loader): + left_idxes.update(batch.tolist()) + + assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) + assert len(left_idxes | already_seen_idx) == len(self.dataset) + +############################################################################ +# +# 测试 save 和 load 相关的功能 +# +############################################################################ + +def generate_random_driver(labels, features, fp16=False, device="cpu"): + """ + 生成driver + """ + model = OneflowNormalModel_Classification_1(labels, features) + opt = oneflow.optim.Adam(params=model.parameters(), lr=0.01) + driver = OneflowSingleDriver(model, device=device, fp16=fp16) + driver.set_optimizers(opt) + driver.setup() + + return driver + +@pytest.mark.oneflow +@pytest.mark.parametrize("only_state_dict", ([True, False])) +def test_save_and_load_model(only_state_dict): + """ + 测试 save_model 和 load_model 函数 + """ + try: + path = "model" + dataset = OneflowNormalXYDataset(20) + dataloader = DataLoader(dataset, batch_size=4) + driver1, driver2 = generate_random_driver(20, 1), generate_random_driver(20, 1) + + driver1.save_model(path, only_state_dict) + driver2.load_model(path, only_state_dict) + + for batch in dataloader: + batch = driver1.move_data_to_device(batch) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) + + assert oneflow.all(res1["preds"] == res2["preds"]) + finally: + rank_zero_rm(path) + +@pytest.mark.oneflow +@pytest.mark.parametrize("only_state_dict", ([True, False])) +@pytest.mark.parametrize("fp16", ([True, False])) +def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): + """ + 测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 + """ + + try: + path = "model.ckp" + dataset = OneflowNormalXYDataset(20) + dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False) + driver1, driver2 = generate_random_driver(20, 1, fp16, "cuda"), generate_random_driver(20, 1, False, "cuda") + + num_consumed_batches = 2 + + already_seen_x_set = set() + already_seen_y_set = set() + driver1.set_sampler_epoch(dataloader, 3) + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) + already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) + + sampler_states = dataloader.batch_sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + # 加载 + # 更改 batch_size + + dataloader = dataloader_with_randombatchsampler(dataset, 2, True, False) + load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) + replaced_loader = load_states.pop("dataloader") + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 batch_sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert replaced_loader.batch_sampler is dataloader.batch_sampler + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) + assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] + assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 + + # # 3. 检查 fp16 是否被加载 + # if fp16: + # assert not isinstance(driver2.grad_scaler, oneflow.cuda.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + driver1.set_sampler_epoch(replaced_loader, 3) + for idx, batch in enumerate(replaced_loader): + + batch = driver2.move_data_to_device(batch) + left_x_batches.update(batch["x"].reshape(-1, ).tolist()) + left_y_batches.update(batch["y"].reshape(-1, ).tolist()) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) + assert oneflow.all(res1["preds"] == res2["preds"]) + + assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) + assert len(left_x_batches | already_seen_x_set) == len(dataset) + assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) + assert len(left_y_batches | already_seen_y_set) == len(dataset) + finally: + rank_zero_rm(path) + +@pytest.mark.oneflow +@pytest.mark.parametrize("only_state_dict", ([True, False])) +@pytest.mark.parametrize("fp16", ([True, False])) +def test_save_and_load_with_randomsampler(only_state_dict, fp16): + """ + 测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况 + """ + + try: + path = "model.ckp" + + driver1, driver2 = generate_random_driver(40, 1, fp16, "cuda"), generate_random_driver(40, 1, False, "cuda") + dataset = OneflowNormalXYDataset(40) + dataloader = dataloader_with_randomsampler(dataset, 4, True, False) + num_consumed_batches = 2 + + already_seen_x_set = set() + already_seen_y_set = set() + driver1.set_sampler_epoch(dataloader, 3) + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) + already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) + + sampler_states = dataloader.batch_sampler.sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + + # 加载 + # 更改 batch_size + dataloader = dataloader_with_randomsampler(dataset, 2, True, False) + load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) + replaced_loader = load_states.pop("dataloader") + + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] + assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] + assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches + assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] + assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] + + # # 3. 检查 fp16 是否被加载 + # if fp16: + # assert not isinstance(driver2.grad_scaler, oneflow.cuda.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + # set epoch + driver2.set_sampler_epoch(replaced_loader, 3) + for idx, batch in enumerate(replaced_loader): + + batch = driver2.move_data_to_device(batch) + left_x_batches.update(batch["x"].reshape(-1, ).tolist()) + left_y_batches.update(batch["y"].reshape(-1, ).tolist()) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) + assert oneflow.all(res1["preds"] == res2["preds"]) + + assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) + assert len(left_x_batches | already_seen_x_set) == len(dataset) + assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) + assert len(left_y_batches | already_seen_y_set) == len(dataset) + finally: + rank_zero_rm(path) + + +@pytest.mark.oneflow +@pytest.mark.parametrize("shuffle", ([True, False])) +@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17])) +@pytest.mark.parametrize("drop_last", ([True, False])) +@pytest.mark.parametrize("reproducible", ([True, False])) +def test_shuffle_dataloader(shuffle, batch_size, drop_last, reproducible): + # 需要检验一下 set_dist_repro_dataloader 没有修改参数 + num_samples = 100 + dataset = OneflowNormalXYDataset(num_samples) + dl = prepare_oneflow_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last) + model = OneflowNormalModel_Classification_1(10, 32) + driver = OneflowSingleDriver(model, device="cpu") + dl = driver.set_dist_repro_dataloader(dataloader=dl, reproducible=reproducible) + + data = [] + flags = [] + for batch in dl: + flags.append(batch['x'].size(0) == batch_size) + data.extend(batch['x'].reshape(-1).tolist()) + + if drop_last and num_samples%batch_size != 0: + assert len(data)!=num_samples + assert all(flags) == True + elif num_samples%batch_size!=0: + assert flags[-1] is False + else: + assert len(data) == num_samples + + if not shuffle: + for i in range(1, len(data)): + assert data[i]>data[i-1] + else: + flags = [] + for i in range(1, len(data)): + flags.append(data[i]>data[i-1]) + assert all(flags) is False + + +@pytest.mark.oneflow +@pytest.mark.parametrize("shuffle", ([True, False])) +@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17])) +@pytest.mark.parametrize("drop_last", ([True, False])) +@pytest.mark.parametrize("reproducible", ([True, False])) +def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible): + # 需要检验一下 set_dist_repro_dataloader 没有修改参数 + num_samples = 100 + dataset = OneflowNormalXYDataset(num_samples) + sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last, + shuffle=shuffle, num_batch_per_bucket=2) + dl = prepare_oneflow_dataloader(dataset, batch_sampler=sampler) + model = OneflowNormalModel_Classification_1(10, 32) + driver = OneflowSingleDriver(model, device="cpu") + dl = driver.set_dist_repro_dataloader(dataloader=dl, reproducible=reproducible) + + data = [] + flags = [] + for batch in dl: + d = batch['x'].reshape(-1).tolist() + diff = max(d) - min(d) + assert diffdata[i-1] + else: + flags = [] + for i in range(1, len(data)-1): + flags.append(data[i]>data[i-1]) + assert all(flags) is False + datas = fastnlp_paddle_all_gather(data) + if drop_last: + assert len(set(datas[0] + datas[1])) == num_samples-_num_samples%batch_size*2 + else: + assert len(set(datas[0] + datas[1])) == num_samples + finally: + dist.barrier() + + @magic_argv_env_context + @pytest.mark.parametrize("shuffle", ([True, False])) + @pytest.mark.parametrize("batch_size", ([1, 3, 16, 17])) + @pytest.mark.parametrize("drop_last", ([True, False])) + def test_batch_sampler_dataloader(self, shuffle, batch_size, drop_last, reproducible=True): + try: + # 需要检验一下 set_dist_repro_dataloader 没有修改参数 + num_samples = 200 + num_device = 2 + dataset = PaddleNormalXYDataset(num_samples) + sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last, + shuffle=shuffle, num_batch_per_bucket=2) + dl = prepare_paddle_dataloader(dataset, batch_sampler=sampler) + self.driver.setup() + dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) + + data = [] + flags = [] + for batch in dl: + d = batch['x'].reshape((-1, )).tolist() + diff = max(d) - min(d) + assert diffdata[i-1] + else: + flags = [] + for i in range(1, len(data)): + flags.append(data[i]>data[i-1]) + assert all(flags) is False + + + @pytest.mark.paddle + @pytest.mark.parametrize("shuffle", ([True, False])) + @pytest.mark.parametrize("batch_size", ([1, 3, 16, 17])) + @pytest.mark.parametrize("drop_last", ([True, False])) + @pytest.mark.parametrize("reproducible", ([True, False])) + def test_batch_sampler_dataloader(self, shuffle, batch_size, drop_last, reproducible): + # 需要检验一下 set_dist_repro_dataloader 没有修改参数 + num_samples = 200 + dataset = PaddleNormalXYDataset(num_samples) + sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last, + shuffle=shuffle, num_batch_per_bucket=2) + dl = prepare_paddle_dataloader(dataset, batch_sampler=sampler) + model = PaddleNormalModel_Classification_1(1, 2) + dl = self.driver.set_dist_repro_dataloader(dataloader=dl, reproducible=reproducible) + + data = [] + flags = [] + for batch in dl: + d = batch['x'].reshape((-1, )).tolist() + diff = max(d) - min(d) + assert diffdata[i-1] + else: + flags = [] + for i in range(1, len(data)-1): + flags.append(data[i]>data[i-1]) + assert all(flags) is False + datas = fastnlp_torch_all_gather(data) + if drop_last: + assert len(set(datas[0] + datas[1])) == num_samples-_num_samples%batch_size*2 + else: + assert len(set(datas[0] + datas[1])) == num_samples + finally: + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + + +@pytest.mark.torch +@magic_argv_env_context +@pytest.mark.parametrize("shuffle", ([True, False])) +@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17])) +@pytest.mark.parametrize("drop_last", ([True, False])) +def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible=True): + try: + # 需要检验一下 set_dist_repro_dataloader 没有修改参数 + num_samples = 200 + num_device = 2 + dataset = TorchNormalXYDataset(num_samples) + sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last, + shuffle=shuffle, num_batch_per_bucket=2) + dl = prepare_torch_dataloader(dataset, batch_sampler=sampler) + model = TorchNormalModel_Classification_1(10, 32) + device = [torch.device(i) for i in [0, 1]] + driver = TorchDDPDriver(model, parallel_device=device) + driver.setup() + dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) + + data = [] + flags = [] + for batch in dl: + d = batch['x'].reshape(-1).tolist() + diff = max(d) - min(d) + assert diff DataParallel.module -> _FleetWrappingModel.model + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + res2 = driver2.model( + batch, + fastnlp_fn=driver2.model.module.model.evaluate_step, + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + + assert torch.equal(res1["preds"], res2["preds"]) + finally: + rank_zero_rm(path) + + # if dist.is_initialized(): + # dist.destroy_process_group() + + @magic_argv_env_context + @pytest.mark.parametrize("only_state_dict", ([True, False])) + @pytest.mark.parametrize("fp16", ([True, False])) + @pytest.mark.parametrize("device", ([[0,1]])) + def test_save_and_load_with_bucketedbatchsampler(self, device, only_state_dict, fp16): + """ + 测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 + """ + + try: + path = "model.ckp" + num_replicas = len(device) + + dataloader = dataloader_with_bucketedbatchsampler( + self.dataset, + length=[10 for i in range(len(self.dataset))], + batch_size=4, + shuffle=True, + drop_last=False + ) + dataloader.batch_sampler.set_distributed( + num_replicas=int(os.getenv("WORLD_SIZE", "1")), + rank=int(os.getenv("RANK", "0")), + pad=True, + ) + num_consumed_batches = 4 + driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16, train_dataloader=dataloader), \ + generate_driver(20, 1, device=device, fp16=False, train_dataloader=dataloader) + + already_seen_x_set = set() + already_seen_y_set = set() + driver1.set_sampler_epoch(dataloader, 4) + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) + already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) + + # 同步 + dist.barrier() + + # 保存状态 + sampler_states = dataloader.batch_sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + dist.barrier() + # 加载 + # 更改 batch_size + dataloader = dataloader_with_bucketedbatchsampler( + self.dataset, + length=[10 for i in range(len(self.dataset))], + batch_size=2, + shuffle=True, + drop_last=False + ) + dataloader.batch_sampler.set_distributed( + num_replicas=driver2.world_size, + rank=driver2.global_rank, + pad=True + ) + dist.barrier() + load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) + dist.barrier() + replaced_loader = load_states.pop("dataloader") + + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 batch_sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert replaced_loader.batch_sampler is dataloader.batch_sampler + assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) + if os.environ['FASTNLP_GLOBAL_RANK'] == '0': + assert replaced_loader.batch_sampler.seed == sampler_states["seed"] + assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + driver2.set_sampler_epoch(replaced_loader, 4) + for idx, batch in enumerate(replaced_loader): + + left_x_batches.update(batch["x"].reshape(-1, ).tolist()) + left_y_batches.update(batch["y"].reshape(-1, ).tolist()) + batch = driver1.move_data_to_device(batch) + res1 = driver1.model( + batch, + fastnlp_fn=driver1.model.module.model.evaluate_step, + # Driver.model -> DataParallel.module -> _FleetWrappingModel.model + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + res2 = driver2.model( + batch, + fastnlp_fn=driver2.model.module.model.evaluate_step, + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + assert torch.equal(res1["preds"], res2["preds"]) + + assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas + assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas + dist.barrier() + finally: + rank_zero_rm(path) + + # if dist.is_initialized(): + # dist.destroy_process_group() + + @magic_argv_env_context + @pytest.mark.parametrize("only_state_dict", ([True, False])) + @pytest.mark.parametrize("fp16", ([True, False])) + @pytest.mark.parametrize("device", ([[0,1]])) + def test_save_and_load_with_randomsampler(self, device, only_state_dict, fp16): + """ + 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 + """ + + try: + path = "checkpoints/" + + num_replicas = len(device) + + dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) + dataloader.batch_sampler.sampler.set_distributed( + num_replicas=int(os.getenv("WORLD_SIZE", "1")), + rank=int(os.getenv("RANK", "0")), + pad=True + ) + num_consumed_batches = 4 + + driver1 = generate_driver(20, 1, device=device, fp16=fp16, train_dataloader=dataloader) + driver2 = generate_driver(20, 1, device=device, fp16=False, train_dataloader=dataloader) + + already_seen_x_set = set() + already_seen_y_set = set() + driver1.set_sampler_epoch(dataloader, 4) + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) + already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) + + # 同步 + dist.barrier() + + # 保存状态 + sampler_states = dataloader.batch_sampler.sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + if only_state_dict: + driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + else: + driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) + dist.barrier() # 等待save成功 + # 加载 + # 更改 batch_size + dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) + dataloader.batch_sampler.sampler.set_distributed( + num_replicas=driver2.world_size, + rank=driver2.global_rank, + pad=True + ) + load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) + replaced_loader = load_states.pop("dataloader") + + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + if os.environ['FASTNLP_GLOBAL_RANK'] == '0': + assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] + assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] + assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] + assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] + assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas + + # 3. 检查 fp16 是否被加载 + if fp16: + assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + driver2.set_sampler_epoch(replaced_loader, 4) + for idx, batch in enumerate(replaced_loader): + + left_x_batches.update(batch["x"].reshape(-1, ).tolist()) + left_y_batches.update(batch["y"].reshape(-1, ).tolist()) + batch = driver1.move_data_to_device(batch) + res1 = driver1.model( + batch, + fastnlp_fn=driver1.model.module.model.evaluate_step, + # Driver.model -> DataParallel.module -> _FleetWrappingModel.model + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + res2 = driver2.model( + batch, + fastnlp_fn=driver2.model.module.model.evaluate_step, + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + assert torch.equal(res1["preds"], res2["preds"]) + + assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas + assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas + + finally: + rank_zero_rm(path) + + # if dist.is_initialized(): + # dist.destroy_process_group() \ No newline at end of file diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py index 73ffbb8d..44141d06 100644 --- a/tests/core/drivers/torch_driver/test_single_device.py +++ b/tests/core/drivers/torch_driver/test_single_device.py @@ -11,6 +11,7 @@ from tests.helpers.datasets.paddle_data import PaddleNormalDataset from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from fastNLP.envs.distributed import rank_zero_rm from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH +from fastNLP import prepare_torch_dataloader, BucketedBatchSampler if _NEED_IMPORT_TORCH: import torch @@ -710,3 +711,85 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): assert len(left_y_batches | already_seen_y_set) == len(dataset) finally: rank_zero_rm(path) + + +@pytest.mark.torch +@pytest.mark.parametrize("shuffle", ([True, False])) +@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17])) +@pytest.mark.parametrize("drop_last", ([True, False])) +@pytest.mark.parametrize("reproducible", ([True, False])) +def test_shuffle_dataloader(shuffle, batch_size, drop_last, reproducible): + # 需要检验一下 set_dist_repro_dataloader 没有修改参数 + num_samples = 100 + dataset = TorchNormalXYDataset(num_samples) + dl = prepare_torch_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last) + model = TorchNormalModel_Classification_1(10, 32) + driver = TorchSingleDriver(model, device="cpu") + dl = driver.set_dist_repro_dataloader(dataloader=dl, reproducible=reproducible) + + data = [] + flags = [] + for batch in dl: + flags.append(batch['x'].size(0) == batch_size) + data.extend(batch['x'].reshape(-1).tolist()) + + if drop_last and num_samples%batch_size != 0: + assert len(data)!=num_samples + assert all(flags) == True + elif num_samples%batch_size!=0: + assert flags[-1] is False + else: + assert len(data) == num_samples + + if not shuffle: + for i in range(1, len(data)): + assert data[i]>data[i-1] + else: + flags = [] + for i in range(1, len(data)): + flags.append(data[i]>data[i-1]) + assert all(flags) is False + + +@pytest.mark.torch +@pytest.mark.parametrize("shuffle", ([True, False])) +@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17])) +@pytest.mark.parametrize("drop_last", ([True, False])) +@pytest.mark.parametrize("reproducible", ([True, False])) +def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible): + # 需要检验一下 set_dist_repro_dataloader 没有修改参数 + num_samples = 100 + dataset = TorchNormalXYDataset(num_samples) + sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last, + shuffle=shuffle, num_batch_per_bucket=2) + dl = prepare_torch_dataloader(dataset, batch_sampler=sampler) + model = TorchNormalModel_Classification_1(10, 32) + driver = TorchSingleDriver(model, device="cpu") + dl = driver.set_dist_repro_dataloader(dataloader=dl, reproducible=reproducible) + + data = [] + flags = [] + for batch in dl: + d = batch['x'].reshape(-1).tolist() + diff = max(d) - min(d) + assert diff\n", + "\n", + "\n", + "\n", + "``ERNIE 2.0`` 则提出了连续学习(``Continual Learning``)的概念,即首先用一个简单的任务来初始化模型,在更新时用前一个任务训练好的参数作为下一个任务模型初始化的参数。这样在训练新的任务时,模型便可以记住之前学习到的知识,使得模型在新任务上获得更好的表现。``ERNIE 2.0`` 分别构建了词法、语法、语义不同级别的预训练任务,并使用不同的 task id 来标示不同的任务,在共计16个中英文任务上都取得了SOTA效果。\n", + "\n", + "\n", + "\n", + "``ERNIE 3.0`` 将自回归和自编码网络融合在一起进行预训练,其中自编码网络采用 ``ERNIE 2.0`` 的多任务学习增量式构建预训练任务,持续进行语义理解学习。其中自编码网络增加了知识增强的预训练任务。自回归网络则基于 ``Tranformer-XL`` 结构,支持长文本语言模型建模,并在多个自然语言处理任务中取得了SOTA的效果。\n", + "\n", + "\n", + "\n", + "接下来,我们将展示如何在 ``FastNLP`` 中使用基于 ``paddle`` 的 ``ERNIE 1.0`` 框架进行中文情感分析。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 使用 tokenizer 处理数据并构造 dataloader\n", + "\n", + "#### 2.1 加载中文数据集 ChnSentiCorp\n", + "\n", + "``ChnSentiCorp`` 数据集是由中国科学院发布的中文句子级情感分析数据集,包含了从网络上获取的酒店、电影、书籍等多个领域的评论,每条评论都被划分为两个标签:消极(``0``)和积极(``1``),可以用于二分类的中文情感分析任务。通过 ``paddlenlp.datasets.load_dataset`` 函数,我们可以加载并查看 ``ChnSentiCorp`` 数据集的内容。" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "训练集大小: 9600\n", + "{'text': '选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般', 'label': 1, 'qid': ''}\n", + "{'text': '15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错', 'label': 1, 'qid': ''}\n", + "{'text': '房间太小。其他的都一般。。。。。。。。。', 'label': 0, 'qid': ''}\n" + ] + } + ], + "source": [ + "from paddlenlp.datasets import load_dataset\n", + "\n", + "train_dataset, val_dataset, test_dataset = load_dataset(\"chnsenticorp\", splits=[\"train\", \"dev\", \"test\"])\n", + "print(\"训练集大小:\", len(train_dataset))\n", + "for i in range(3):\n", + " print(train_dataset[i])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.2 处理数据\n", + "\n", + "可以看到,原本的数据集仅包含中文的文本和标签,这样的数据是无法被模型识别的。同英文文本分类任务一样,我们需要使用 ``tokenizer`` 对文本进行分词并转换为数字形式的结果。我们可以加载已经预训练好的中文分词模型 ``ernie-1.0-base-zh``,将分词的过程写在函数 ``_process`` 中,然后调用数据集的 ``map`` 函数对每一条数据进行分词。其中:\n", + "- 参数 ``max_length`` 代表句子的最大长度;\n", + "- ``padding=\"max_length\"`` 表示将长度不足的结果 padding 至和最大长度相同;\n", + "- ``truncation=True`` 表示将长度过长的句子进行截断。\n", + "\n", + "至此,我们得到了每条数据长度均相同的数据集。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m[2022-06-22 21:31:04,168] [ INFO]\u001b[0m - We are using to load 'ernie-1.0-base-zh'.\u001b[0m\n", + "\u001b[32m[2022-06-22 21:31:04,171] [ INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/vocab.txt\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'text': '选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般', 'label': 1, 'qid': '', 'input_ids': [1, 352, 790, 1252, 409, 283, 509, 5, 250, 196, 113, 10, 58, 518, 4, 9, 128, 70, 1495, 1855, 339, 293, 45, 302, 233, 554, 4, 544, 637, 1134, 774, 6, 494, 2068, 6, 278, 191, 6, 634, 99, 6, 2678, 144, 7, 149, 1573, 62, 12043, 661, 737, 371, 435, 7, 689, 4, 255, 201, 559, 407, 1308, 12043, 2275, 1110, 11, 19, 842, 5, 1207, 878, 4, 196, 198, 321, 96, 4, 16, 93, 291, 464, 1099, 10, 692, 811, 12043, 392, 5, 748, 1134, 10, 213, 220, 5, 4, 201, 559, 723, 595, 12043, 231, 112, 1114, 4, 7, 689, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}\n" + ] + } + ], + "source": [ + "max_len = 128\n", + "model_checkpoint = \"ernie-1.0-base-zh\"\n", + "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n", + "def _process(data):\n", + " data.update(tokenizer(\n", + " data[\"text\"],\n", + " max_length=max_len,\n", + " padding=\"max_length\",\n", + " truncation=True,\n", + " return_attention_mask=True,\n", + " ))\n", + " return data\n", + "\n", + "train_dataset.map(_process, num_workers=5)\n", + "val_dataset.map(_process, num_workers=5)\n", + "test_dataset.map(_process, num_workers=5)\n", + "\n", + "print(train_dataset[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "得到数据集之后,我们便可以将数据集包裹在 ``PaddleDataLoader`` 中,用于之后的训练。``FastNLP`` 提供的 ``PaddleDataLoader`` 拓展了 ``paddle.io.DataLoader`` 的功能,详情可以查看相关的文档。" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP.core import PaddleDataLoader\n", + "import paddle.nn as nn\n", + "\n", + "train_dataloader = PaddleDataLoader(train_dataset, batch_size=32, shuffle=True)\n", + "val_dataloader = PaddleDataLoader(val_dataset, batch_size=32, shuffle=False)\n", + "test_dataloader = PaddleDataLoader(test_dataset, batch_size=1, shuffle=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 模型训练:加载 ERNIE 预训练模型,使用 FastNLP 进行训练\n", + "\n", + "#### 3.1 使用 ERNIE 预训练模型\n", + "\n", + "为了实现文本分类,我们首先需要定义文本分类的模型。``paddlenlp.transformers`` 提供了模型 ``AutoModelForSequenceClassification``,我们可以利用它来加载不同权重的文本分类模型。在 ``FastNLP`` 中,我们可以定义 ``train_step`` 和 ``evaluate_step`` 函数来实现训练和验证过程中的不同行为。\n", + "\n", + "- ``train_step`` 函数在获得返回值 ``logits`` (大小为 ``(batch_size, num_labels)``)后计算交叉熵损失 ``CrossEntropyLoss``,然后将 ``loss`` 放在字典中返回。``FastNLP`` 也支持返回 ``dataclass`` 类型的训练结果,但二者都需要包含名为 **``loss``** 的键或成员。\n", + "- ``evaluate_step`` 函数在获得返回值 ``logits`` 后,将 ``logits`` 和标签 ``label`` 放在字典中返回。\n", + "\n", + "这两个函数的参数均为数据集中字典**键**的子集,``FastNLP`` 会自动进行参数匹配然后输入到模型中。" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m[2022-06-22 21:31:15,577] [ INFO]\u001b[0m - We are using to load 'ernie-1.0-base-zh'.\u001b[0m\n", + "\u001b[32m[2022-06-22 21:31:15,580] [ INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/ernie_v1_chn_base.pdparams\u001b[0m\n" + ] + } + ], + "source": [ + "import paddle.nn as nn\n", + "\n", + "class SeqClsModel(nn.Layer):\n", + " def __init__(self, model_checkpoint, num_labels):\n", + " super(SeqClsModel, self).__init__()\n", + " self.model = AutoModelForSequenceClassification.from_pretrained(\n", + " model_checkpoint,\n", + " num_classes=num_labels,\n", + " )\n", + "\n", + " def forward(self, input_ids, attention_mask, token_type_ids):\n", + " logits = self.model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)\n", + " return logits\n", + "\n", + " def train_step(self, input_ids, attention_mask, token_type_ids, label):\n", + " logits = self(input_ids, attention_mask, token_type_ids)\n", + " loss = nn.CrossEntropyLoss()(logits, label)\n", + " return {\"loss\": loss}\n", + "\n", + " def evaluate_step(self, input_ids, attention_mask, token_type_ids, label):\n", + " logits = self(input_ids, attention_mask, token_type_ids)\n", + " return {'pred': logits, 'target': label}\n", + "\n", + "model = SeqClsModel(model_checkpoint, num_labels=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.2 设置参数并使用 Trainer 开始训练\n", + "\n", + "现在我们可以着手使用 ``FastNLP.Trainer`` 进行训练了。\n", + "\n", + "首先,为了高效地训练 ``ERNIE`` 模型,我们最好为学习率指定一定的策略。``paddlenlp`` 提供的 ``LinearDecayWithWarmup`` 可以令学习率在一段时间内从 0 开始线性地增长(预热),然后再线性地衰减至 0 。在本篇教程中,我们将学习率设置为 ``5e-5``,预热时间为 ``0.1``,然后将得到的的 ``lr_scheduler`` 赋值给 ``AdamW`` 优化器。\n", + "\n", + "其次,我们还可以为 ``Trainer`` 指定多个 ``Callback`` 来在基础的训练过程之外进行额外的定制操作。在本篇教程中,我们使用的 ``Callback`` 有以下三种:\n", + "\n", + "- ``LRSchedCallback`` - 由于我们使用了 ``Scheduler``,因此需要将 ``lr_scheduler`` 传给该 ``Callback`` 以在训练中进行更新。\n", + "- ``LoadBestModelCallback`` - 该 ``Callback`` 会评估结果中的 ``'acc#accuracy'`` 值,保存训练中出现的正确率最高的模型,并在训练结束时加载到模型上,方便对模型进行测试和评估。\n", + "\n", + "在 ``Trainer`` 中,我们还可以设置 ``metrics`` 来衡量模型的表现。``Accuracy`` 能够根据传入的预测值和真实值计算出模型预测的正确率。还记得模型中 ``evaluate_step`` 函数的返回值吗?键 ``pred`` 和 ``target`` 分别为 ``Accuracy.update`` 的参数名,在验证过程中 ``FastNLP`` 会自动将键和参数名匹配从而计算出正确率,这也是我们规定模型需要返回字典类型数据的原因。\n", + "\n", + "``Accuracy`` 的返回值包含三个部分:``acc``、``total`` 和 ``correct``,分别代表 ``正确率``、 ``数据总数`` 和 ``预测正确的数目``,这让您能够直观地知晓训练中模型的变化,``LoadBestModelCallback`` 的参数 ``'acc#accuracy'`` 也正是代表了 ``accuracy`` 指标的 ``acc`` 结果。\n", + "\n", + "在设定好参数之后,调用 ``run`` 函数便可以进行训练和验证了。" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[21:31:16] INFO     Running evaluator sanity check for 2 batches.              trainer.py:631\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[21:31:16]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=4641;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=822054;file://../fastNLP/core/controllers/trainer.py#631\u001b\\\u001b[2m631\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:60 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m60\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.895833,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1075.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.895833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1075.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:120 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m120\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.8975,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1077.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.8975\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1077.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:180 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m180\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.911667,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1094.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.911667\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1094.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:240 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m240\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.9225,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1107.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9225\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1107.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:300 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.9275,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1113.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9275\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1113.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:60 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m60\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.930833,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1117.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.930833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1117.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:120 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m120\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.935833,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1123.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.935833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1123.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:180 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m180\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.935833,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1123.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.935833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1123.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:240 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m240\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.9375,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1125.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9375\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1125.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:300 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.941667,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1130.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.941667\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1130.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[21:34:28] INFO     Loading best model from fnlp-ernie/2022-0 load_best_model_callback.py:111\n",
+       "                    6-22-21_29_12_898095/best_so_far with                                    \n",
+       "                    acc#accuracy: 0.941667...                                                \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[21:34:28]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from fnlp-ernie/\u001b[1;36m2022\u001b[0m-\u001b[1;36m0\u001b[0m \u001b]8;id=340364;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=763898;file://../fastNLP/core/callbacks/load_best_model_callback.py#111\u001b\\\u001b[2m111\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m6\u001b[0m-\u001b[1;36m22\u001b[0m-21_29_12_898095/best_so_far with \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m acc#accuracy: \u001b[1;36m0.941667\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[21:34:34] INFO     Deleting fnlp-ernie/2022-06-22-21_29_12_8 load_best_model_callback.py:131\n",
+       "                    98095/best_so_far...                                                     \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[21:34:34]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Deleting fnlp-ernie/\u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m22\u001b[0m-21_29_12_8 \u001b]8;id=430330;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=508566;file://../fastNLP/core/callbacks/load_best_model_callback.py#131\u001b\\\u001b[2m131\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m 98095/best_so_far\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import LRSchedCallback, LoadBestModelCallback\n", + "from fastNLP import Trainer, Accuracy\n", + "from paddlenlp.transformers import LinearDecayWithWarmup\n", + "\n", + "n_epochs = 2\n", + "num_training_steps = len(train_dataloader) * n_epochs\n", + "lr_scheduler = LinearDecayWithWarmup(5e-5, num_training_steps, 0.1)\n", + "optimizer = paddle.optimizer.AdamW(\n", + " learning_rate=lr_scheduler,\n", + " parameters=model.parameters(),\n", + ")\n", + "callbacks = [\n", + " LRSchedCallback(lr_scheduler, step_on=\"batch\"),\n", + " LoadBestModelCallback(\"acc#accuracy\", larger_better=True, save_folder=\"fnlp-ernie\"),\n", + "]\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver=\"paddle\",\n", + " optimizers=optimizer,\n", + " device=0,\n", + " n_epochs=n_epochs,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=val_dataloader,\n", + " evaluate_every=60,\n", + " metrics={\"accuracy\": Accuracy()},\n", + " callbacks=callbacks,\n", + ")\n", + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.3 测试和评估\n", + "\n", + "现在我们已经得到了一个表现良好的 ``ERNIE`` 模型,接下来可以在测试集上测试模型的效果了。``FastNLP.Evaluator`` 提供了定制函数的功能。我们以 ``test_dataloader`` 初始化一个 ``Evaluator``,然后将写好的测试函数 ``test_batch_step_fn`` 传给参数 ``evaluate_batch_step_fn``,``Evaluate`` 在对每个 batch 进行评估时就会调用我们自定义的 ``test_batch_step_fn`` 函数而不是 ``evaluate_step`` 函数。在这里,我们仅测试 5 条数据并输出文本和对应的标签。" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
text: ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般']\n",
+       "
\n" + ], + "text/plain": [ + "text: ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 0\n",
+       "
\n" + ], + "text/plain": [ + "labels: 0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始\n",
+       "还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫\n",
+       "和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一\n",
+       "集算什么??简直是画蛇添足!!']\n",
+       "
\n" + ], + "text/plain": [ + "text: ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始\n", + "还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫\n", + "和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一\n", + "集算什么??简直是画蛇添足!!']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 0\n",
+       "
\n" + ], + "text/plain": [ + "labels: 0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气\n",
+       "泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。'\n",
+       "]\n",
+       "
\n" + ], + "text/plain": [ + "text: ['还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气\n", + "泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。'\n", + "]\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 0\n",
+       "
\n" + ], + "text/plain": [ + "labels: 0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['交通方便;环境很好;服务态度很好 房间较小']\n",
+       "
\n" + ], + "text/plain": [ + "text: ['交通方便;环境很好;服务态度很好 房间较小']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 1\n",
+       "
\n" + ], + "text/plain": [ + "labels: 1\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['不错,作者的观点很颠覆目前中国父母的教育方式,其实古人们对于教育已经有了很系统的体系\n",
+       "了,可是现在的父母以及祖父母们更多的娇惯纵容孩子,放眼看去自私的孩子是大多数,父母觉得自己的\n",
+       "孩子在外面只要不吃亏就是好事,完全把古人几千年总结的教育古训抛在的九霄云外。所以推荐准妈妈们\n",
+       "可以在等待宝宝降临的时候,好好学习一下,怎么把孩子教育成一个有爱心、有责任心、宽容、大度的人\n",
+       "。']\n",
+       "
\n" + ], + "text/plain": [ + "text: ['不错,作者的观点很颠覆目前中国父母的教育方式,其实古人们对于教育已经有了很系统的体系\n", + "了,可是现在的父母以及祖父母们更多的娇惯纵容孩子,放眼看去自私的孩子是大多数,父母觉得自己的\n", + "孩子在外面只要不吃亏就是好事,完全把古人几千年总结的教育古训抛在的九霄云外。所以推荐准妈妈们\n", + "可以在等待宝宝降临的时候,好好学习一下,怎么把孩子教育成一个有爱心、有责任心、宽容、大度的人\n", + "。']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 1\n",
+       "
\n" + ], + "text/plain": [ + "labels: 1\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{}"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from fastNLP import Evaluator\n",
+    "def test_batch_step_fn(evaluator, batch):\n",
+    "    input_ids = batch[\"input_ids\"]\n",
+    "    attention_mask = batch[\"attention_mask\"]\n",
+    "    token_type_ids = batch[\"token_type_ids\"]\n",
+    "    logits = model(input_ids, attention_mask, token_type_ids)\n",
+    "    predict = logits.argmax().item()\n",
+    "    print(\"text:\", batch['text'])\n",
+    "    print(\"labels:\", predict)\n",
+    "\n",
+    "evaluator = Evaluator(\n",
+    "    model=model,\n",
+    "    dataloaders=test_dataloader,\n",
+    "    driver=\"paddle\",\n",
+    "    device=0,\n",
+    "    evaluate_batch_step_fn=test_batch_step_fn,\n",
+    ")\n",
+    "evaluator.run(5)    "
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.7.13 ('fnlp-paddle')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tutorials/fastnlp_tutorial_paddle_e2.ipynb b/tutorials/fastnlp_tutorial_paddle_e2.ipynb
new file mode 100644
index 00000000..c17be405
--- /dev/null
+++ b/tutorials/fastnlp_tutorial_paddle_e2.ipynb
@@ -0,0 +1,1510 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 使用 paddlenlp 和 FastNLP 训练中文阅读理解任务\n",
+    "\n",
+    "本篇教程属于 **`FastNLP v0.8 tutorial` 的 `paddle examples` 系列**。在本篇教程中,我们将为您展示如何在 `FastNLP` 中通过自定义 `Metric` 和 损失函数来完成进阶的问答任务。\n",
+    "\n",
+    "1. 基础介绍:自然语言处理中的阅读理解任务\n",
+    "\n",
+    "2. 准备工作:加载 `DuReader-robust` 数据集,并使用 `tokenizer` 处理数据\n",
+    "\n",
+    "3. 模型训练:自己定义评测用的 `Metric` 实现更加自由的任务评测"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1. 基础介绍:自然语言处理中的阅读理解任务\n",
+    "\n",
+    "阅读理解任务,顾名思义,就是给出一段文字,然后让模型理解这段文字所含的语义。大部分机器阅读理解任务都采用问答式测评,即设计与文章内容相关的自然语言式问题,让模型理解问题并根据文章作答。与文本分类任务不同的是,在阅读理解任务中我们有时需要需要输入“一对”句子,分别代表问题和上下文;答案的格式也分为多种:\n",
+    "\n",
+    "- 多项选择:让模型从多个答案选项中选出正确答案\n",
+    "- 区间答案:答案为上下文的一段子句,需要模型给出答案的起始位置\n",
+    "- 自由回答:不做限制,让模型自行生成答案\n",
+    "- 完形填空:在原文中挖空部分关键词,让模型补全;这类答案往往不需要问题\n",
+    "\n",
+    "如果您对 `transformers` 有所了解的话,其中的 `ModelForQuestionAnswering` 系列模型就可以用于这项任务。阅读理解模型的泛用性是衡量该技术能否在实际应用中大规模落地的重要指标之一,随着当前技术的进步,许多模型虽然能够在一些测试集上取得较好的性能,但在实际应用中,这些模型仍然难以让人满意。在本篇教程中,我们将会为您展示如何训练一个问答模型。\n",
+    "\n",
+    "在这一领域,`SQuAD` 数据集是一个影响深远的数据集。它的全称是斯坦福问答数据集(Stanford Question Answering Dataset),每条数据包含 `(问题,上下文,答案)` 三部分,规模大(约十万条,2.0又新增了五万条),在提出之后很快成为训练问答任务的经典数据集之一。`SQuAD` 数据集有两个指标来衡量模型的表现:`EM`(Exact Match,精确匹配)和 `F1`(模糊匹配)。前者反应了模型给出的答案中有多少和正确答案完全一致,后者则反应了模型给出的答案中与正确答案重叠的部分,均为越高越好。"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2. 准备工作:加载 DuReader-robust 数据集,并使用 tokenizer 处理数据"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/remote-home/shxing/anaconda3/envs/fnlp-paddle/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2.3.3\n"
+     ]
+    }
+   ],
+   "source": [
+    "import sys\n",
+    "sys.path.append(\"../\")\n",
+    "import paddle\n",
+    "import paddlenlp\n",
+    "\n",
+    "print(paddlenlp.__version__)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "在数据集方面,我们选用 `DuReader-robust` 中文数据集作为训练数据。它是一种抽取式问答数据集,采用 `SQuAD` 数据格式,能够评估真实应用场景下模型的泛用性。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Reusing dataset dureader_robust (/remote-home/shxing/.cache/huggingface/datasets/dureader_robust/plain_text/1.0.0/d462ecadc8c010cee20f57632f1413f272867cd802a91a602df48c7d34eb0c27)\n",
+      "Reusing dataset dureader_robust (/remote-home/shxing/.cache/huggingface/datasets/dureader_robust/plain_text/1.0.0/d462ecadc8c010cee20f57632f1413f272867cd802a91a602df48c7d34eb0c27)\n",
+      "\u001b[32m[2022-06-27 19:22:46,998] [    INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/vocab.txt\u001b[0m\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'id': '0a25cb4bc1ab6f474c699884e04601e4', 'title': '', 'context': '第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。', 'question': '仙剑奇侠传3第几集上天界', 'answers': {'text': ['第35集'], 'answer_start': [0]}}\n",
+      "{'id': '7de192d6adf7d60ba73ba25cf590cc1e', 'title': '', 'context': '选择燃气热水器时,一定要关注这几个问题:1、出水稳定性要好,不能出现忽热忽冷的现象2、快速到达设定的需求水温3、操作要智能、方便4、安全性要好,要装有安全报警装置 市场上燃气热水器品牌众多,购买时还需多加对比和仔细鉴别。方太今年主打的磁化恒温热水器在使用体验方面做了全面升级:9秒速热,可快速进入洗浴模式;水温持久稳定,不会出现忽热忽冷的现象,并通过水量伺服技术将出水温度精确控制在±0.5℃,可满足家里宝贝敏感肌肤洗护需求;配备CO和CH4双气体报警装置更安全(市场上一般多为CO单气体报警)。另外,这款热水器还有智能WIFI互联功能,只需下载个手机APP即可用手机远程操作热水器,实现精准调节水温,满足家人多样化的洗浴需求。当然方太的磁化恒温系列主要的是增加磁化功能,可以有效吸附水中的铁锈、铁屑等微小杂质,防止细菌滋生,使沐浴水质更洁净,长期使用磁化水沐浴更利于身体健康。', 'question': '燃气热水器哪个牌子好', 'answers': {'text': ['方太'], 'answer_start': [110]}}\n",
+      "{'id': 'b9e74d4b9228399b03701d1fe6d52940', 'title': '', 'context': '迈克尔.乔丹在NBA打了15个赛季。他在84年进入nba,期间在1993年10月6日第一次退役改打棒球,95年3月18日重新回归,在99年1月13日第二次退役,后于2001年10月31日复出,在03年最终退役。迈克尔·乔丹(Michael Jordan),1963年2月17日生于纽约布鲁克林,美国著名篮球运动员,司职得分后卫,历史上最伟大的篮球运动员。1984年的NBA选秀大会,乔丹在首轮第3顺位被芝加哥公牛队选中。 1986-87赛季,乔丹场均得到37.1分,首次获得分王称号。1990-91赛季,乔丹连夺常规赛MVP和总决赛MVP称号,率领芝加哥公牛首次夺得NBA总冠军。 1997-98赛季,乔丹获得个人职业生涯第10个得分王,并率领公牛队第六次夺得总冠军。2009年9月11日,乔丹正式入选NBA名人堂。', 'question': '乔丹打了多少个赛季', 'answers': {'text': ['15个'], 'answer_start': [12]}}\n",
+      "训练集大小: 14520\n",
+      "验证集大小: 1417\n"
+     ]
+    }
+   ],
+   "source": [
+    "from paddlenlp.datasets import load_dataset\n",
+    "train_dataset = load_dataset(\"PaddlePaddle/dureader_robust\", splits=\"train\")\n",
+    "val_dataset = load_dataset(\"PaddlePaddle/dureader_robust\", splits=\"validation\")\n",
+    "for i in range(3):\n",
+    "    print(train_dataset[i])\n",
+    "print(\"训练集大小:\", len(train_dataset))\n",
+    "print(\"验证集大小:\", len(val_dataset))\n",
+    "\n",
+    "MODEL_NAME = \"ernie-1.0-base-zh\"\n",
+    "from paddlenlp.transformers import ErnieTokenizer\n",
+    "tokenizer =ErnieTokenizer.from_pretrained(MODEL_NAME)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 2.1 处理训练集\n",
+    "\n",
+    "对于阅读理解任务,数据处理的方式较为麻烦。接下来我们会为您详细讲解处理函数 `_process_train` 的功能,同时也将通过实践展示关于 `tokenizer` 的更多功能,让您更加深入地了解自然语言处理任务。首先让我们向 `tokenizer` 输入一条数据(以列表的形式):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2\n",
+      "dict_keys(['offset_mapping', 'input_ids', 'token_type_ids', 'overflow_to_sample'])\n"
+     ]
+    }
+   ],
+   "source": [
+    "result = tokenizer(\n",
+    "    [train_dataset[0][\"question\"]],\n",
+    "    [train_dataset[0][\"context\"]],\n",
+    "    stride=128,\n",
+    "    max_length=256,\n",
+    "    padding=\"max_length\",\n",
+    "    return_dict=False\n",
+    ")\n",
+    "\n",
+    "print(len(result))\n",
+    "print(result[0].keys())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "首先不难理解的是,模型必须要同时接受问题(`question`)和上下文(`context`)才能够进行阅读理解,因此我们需要将二者同时进行分词(`tokenize`)。所幸,`Tokenizer` 提供了这一功能,当我们调用 `tokenizer` 的时候,其第一个参数名为 `text`,第二个参数名为 `text_pair`,这使得我们可以同时对一对文本进行分词。同时,`tokenizer` 还需要标记出一条数据中哪些属于问题,哪些属于上下文,这一功能则由 `token_type_ids` 完成。`token_type_ids` 会将输入的第一个文本(问题)标记为 `0`,第二个文本(上下文)标记为 `1`,这样模型在训练时便可以将问题和上下文区分开来:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427, 1427, 501, 88, 662, 1906, 4, 561, 125, 311, 1168, 311, 692, 46, 430, 4, 84, 2073, 14, 1264, 3967, 5, 1034, 1020, 1829, 268, 4, 373, 539, 8, 154, 5210, 4, 105, 167, 59, 69, 685, 12043, 539, 8, 883, 1020, 4, 29, 720, 95, 90, 427, 67, 262, 5, 384, 266, 14, 101, 59, 789, 416, 237, 12043, 1097, 373, 616, 37, 1519, 93, 61, 15, 4, 255, 535, 7, 1529, 619, 187, 4, 62, 154, 451, 149, 12043, 539, 8, 253, 223, 3679, 323, 523, 4, 535, 34, 87, 8, 203, 280, 1186, 340, 9, 1097, 373, 5, 262, 203, 623, 704, 12043, 84, 2073, 1137, 358, 334, 702, 5, 262, 203, 4, 334, 702, 405, 360, 653, 129, 178, 7, 568, 28, 15, 125, 280, 518, 9, 1179, 487, 12043, 84, 2073, 1621, 1829, 1034, 1020, 4, 539, 8, 448, 91, 202, 466, 70, 262, 4, 638, 125, 280, 83, 299, 12043, 539, 8, 61, 45, 7, 1537, 176, 4, 84, 2073, 288, 39, 4, 889, 280, 14, 125, 280, 156, 538, 12043, 190, 889, 280, 71, 109, 124, 93, 292, 889, 46, 1248, 4, 518, 48, 883, 125, 12043, 539, 8, 268, 889, 280, 109, 270, 4, 1586, 845, 7, 669, 199, 5, 3964, 3740, 1084, 4, 255, 440, 616, 154, 72, 71, 109, 12043, 49, 61, 283, 3591, 34, 87, 297, 41, 9, 1993, 2602, 518, 52, 706, 109, 2]\n",
+      "['[CLS]', '仙', '剑', '奇', '侠', '传', '3', '第', '几', '集', '上', '天', '界', '[SEP]', '第', '35', '集', '雪', '见', '缓', '缓', '张', '开', '眼', '睛', ',', '景', '天', '又', '惊', '又', '喜', '之', '际', ',', '长', '卿', '和', '紫', '萱', '的', '仙', '船', '驶', '至', ',', '见', '众', '人', '无', '恙', ',', '也', '十', '分', '高', '兴', '。', '众', '人', '登', '船', ',', '用', '尽', '合', '力', '把', '自', '身', '的', '真', '气', '和', '水', '分', '输', '给', '她', '。', '雪', '见', '终', '于', '醒', '过', '来', '了', ',', '但', '却', '一', '脸', '木', '然', ',', '全', '无', '反', '应', '。', '众', '人', '向', '常', '胤', '求', '助', ',', '却', '发', '现', '人', '世', '界', '竟', '没', '有', '雪', '见', '的', '身', '世', '纪', '录', '。', '长', '卿', '询', '问', '清', '微', '的', '身', '世', ',', '清', '微', '语', '带', '双', '关', '说', '一', '切', '上', '了', '天', '界', '便', '有', '答', '案', '。', '长', '卿', '驾', '驶', '仙', '船', ',', '众', '人', '决', '定', '立', '马', '动', '身', ',', '往', '天', '界', '而', '去', '。', '众', '人', '来', '到', '一', '荒', '山', ',', '长', '卿', '指', '出', ',', '魔', '界', '和', '天', '界', '相', '连', '。', '由', '魔', '界', '进', '入', '通', '过', '神', '魔', '之', '井', ',', '便', '可', '登', '天', '。', '众', '人', '至', '魔', '界', '入', '口', ',', '仿', '若', '一', '黑', '色', '的', '蝙', '蝠', '洞', ',', '但', '始', '终', '无', '法', '进', '入', '。', '后', '来', '花', '楹', '发', '现', '只', '要', '有', '翅', '膀', '便', '能', '飞', '入', '[SEP]']\n",
+      "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(result[0][\"input_ids\"])\n",
+    "print(tokenizer.convert_ids_to_tokens(result[0][\"input_ids\"]))\n",
+    "print(result[0][\"token_type_ids\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "根据上面的输出我们可以看出,`tokenizer` 会将数据开头用 `[CLS]` 标记,用 `[SEP]` 来分割句子。同时,根据 `token_type_ids` 得到的 0、1 串,我们也很容易将问题和上下文区分开。顺带一提,如果一条数据进行了 `padding`,那么这部分会被标记为 `0` 。\n",
+    "\n",
+    "在输出的 `keys` 中还有一项名为 `offset_mapping` 的键。该项数据能够表示分词后的每个 `token` 在原文中对应文字或词语的位置。比如我们可以像下面这样将数据打印出来:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (0, 0), (0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 7)]\n",
+      "[1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427]\n",
+      "['[CLS]', '仙', '剑', '奇', '侠', '传', '3', '第', '几', '集', '上', '天', '界', '[SEP]', '第', '35', '集', '雪', '见', '缓']\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(result[0][\"offset_mapping\"][:20])\n",
+    "print(result[0][\"input_ids\"][:20])\n",
+    "print(tokenizer.convert_ids_to_tokens(result[0][\"input_ids\"])[:20])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`[CLS]` 由于是 `tokenizer` 自己添加进去用于标记数据的 `token`,因此它在原文中找不到任何对应的词语,所以给出的位置范围就是 `(0, 0)`;第二个 `token` 对应第一个 `“仙”` 字,因此映射的位置就是 `(0, 1)`;同理,后面的 `[SEP]` 也不对应任何文字,映射的位置为 `(0, 0)`;而接下来的 `token` 对应 **上下文** 中的第一个字 `“第”`,映射出的位置为 `(0, 1)`;再后面的 `token` 对应原文中的两个字符 `35`,因此其位置映射为 `(1, 3)` 。通过这种手段,我们可以更方便地获取 `token` 与原文的对应关系。\n",
+    "\n",
+    "最后,您也许会注意到我们获取的 `result` 长度为 2 。这是文本在分词后长度超过了 `max_length` 256 ,`tokenizer` 将数据分成了两部分所致。在阅读理解任务中,我们不可能像文本分类那样轻易地将一条数据截断,因为答案很可能就出现在后面被丢弃的那部分数据中,因此,我们需要保留所有的数据(当然,您也可以直接丢弃这些超长的数据)。`overflow_to_sample` 则可以标识当前数据在原数据的索引:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[CLS]仙剑奇侠传3第几集上天界[SEP]第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入[SEP]\n",
+      "overflow_to_sample:  0\n",
+      "[CLS]仙剑奇侠传3第几集上天界[SEP]说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]\n",
+      "overflow_to_sample:  0\n"
+     ]
+    }
+   ],
+   "source": [
+    "for res in result:\n",
+    "    tokens = tokenizer.convert_ids_to_tokens(res[\"input_ids\"])\n",
+    "    print(\"\".join(tokens))\n",
+    "    print(\"overflow_to_sample: \", res[\"overflow_to_sample\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "将两条数据均输出之后可以看到,它们都出自我们传入的数据,并且存在一部分重合。`tokenizer` 的 `stride` 参数可以设置重合部分的长度,这也可以帮助模型识别被分割开的两条数据;`overflow_to_sample` 的 `0` 则代表它们来自于第 `0` 条数据。\n",
+    "\n",
+    "基于以上信息,我们处理训练集的思路如下:\n",
+    "\n",
+    "1. 通过 `overflow_to_sample` 来获取原来的数据\n",
+    "2. 通过原数据的 `answers` 找到答案的起始位置\n",
+    "3. 通过 `offset_mapping` 给出的映射关系在分词处理后的数据中找到答案的起始位置,分别记录在 `start_pos` 和 `end_pos` 中;如果没有找到答案(比如答案被截断了),那么答案的起始位置就被标记为 `[CLS]` 的位置。\n",
+    "\n",
+    "这样 `_process_train` 函数就呼之欲出了,我们调用 `train_dataset.map` 函数,并将 `batched` 参数设置为 `True` ,将所有数据批量地进行更新。有一点需要注意的是,**在处理过后数据量会增加**。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'offset_mapping': [(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (0, 0), (0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 23), (23, 24), (24, 25), (25, 26), (26, 27), (27, 28), (28, 29), (29, 30), (30, 31), (31, 32), (32, 33), (33, 34), (34, 35), (35, 36), (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 42), (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 48), (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 60), (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 68), (68, 69), (69, 70), (70, 71), (71, 72), (72, 73), (73, 74), (74, 75), (75, 76), (76, 77), (77, 78), (78, 79), (79, 80), (80, 81), (81, 82), (82, 83), (83, 84), (84, 85), (85, 86), (86, 87), (87, 88), (88, 89), (89, 90), (90, 91), (91, 92), (92, 93), (93, 94), (94, 95), (95, 96), (96, 97), (97, 98), (98, 99), (99, 100), (100, 101), (101, 102), (102, 103), (103, 104), (104, 105), (105, 106), (106, 107), (107, 108), (108, 109), (109, 110), (110, 111), (111, 112), (112, 113), (113, 114), (114, 115), (115, 116), (116, 117), (117, 118), (118, 119), (119, 120), (120, 121), (121, 122), (122, 123), (123, 124), (124, 125), (125, 126), (126, 127), (127, 128), (128, 129), (129, 130), (130, 131), (131, 132), (132, 133), (133, 134), (134, 135), (135, 136), (136, 137), (137, 138), (138, 139), (139, 140), (140, 141), (141, 142), (142, 143), (143, 144), (144, 145), (145, 146), (146, 147), (147, 148), (148, 149), (149, 150), (150, 151), (151, 152), (152, 153), (153, 154), (154, 155), (155, 156), (156, 157), (157, 158), (158, 159), (159, 160), (160, 161), (161, 162), (162, 163), (163, 164), (164, 165), (165, 166), (166, 167), (167, 168), (168, 169), (169, 170), (170, 171), (171, 172), (172, 173), (173, 174), (174, 175), (175, 176), (176, 177), (177, 178), (178, 179), (179, 180), (180, 181), (181, 182), (182, 183), (183, 184), (184, 185), (185, 186), (186, 187), (187, 188), (188, 189), (189, 190), (190, 191), (191, 192), (192, 193), (193, 194), (194, 195), (195, 196), (196, 197), (197, 198), (198, 199), (199, 200), (200, 201), (201, 202), (202, 203), (203, 204), (204, 205), (205, 206), (206, 207), (207, 208), (208, 209), (209, 210), (210, 211), (211, 212), (212, 213), (213, 214), (214, 215), (215, 216), (216, 217), (217, 218), (218, 219), (219, 220), (220, 221), (221, 222), (222, 223), (223, 224), (224, 225), (225, 226), (226, 227), (227, 228), (228, 229), (229, 230), (230, 231), (231, 232), (232, 233), (233, 234), (234, 235), (235, 236), (236, 237), (237, 238), (238, 239), (239, 240), (240, 241), (241, 242), (0, 0)], 'input_ids': [1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427, 1427, 501, 88, 662, 1906, 4, 561, 125, 311, 1168, 311, 692, 46, 430, 4, 84, 2073, 14, 1264, 3967, 5, 1034, 1020, 1829, 268, 4, 373, 539, 8, 154, 5210, 4, 105, 167, 59, 69, 685, 12043, 539, 8, 883, 1020, 4, 29, 720, 95, 90, 427, 67, 262, 5, 384, 266, 14, 101, 59, 789, 416, 237, 12043, 1097, 373, 616, 37, 1519, 93, 61, 15, 4, 255, 535, 7, 1529, 619, 187, 4, 62, 154, 451, 149, 12043, 539, 8, 253, 223, 3679, 323, 523, 4, 535, 34, 87, 8, 203, 280, 1186, 340, 9, 1097, 373, 5, 262, 203, 623, 704, 12043, 84, 2073, 1137, 358, 334, 702, 5, 262, 203, 4, 334, 702, 405, 360, 653, 129, 178, 7, 568, 28, 15, 125, 280, 518, 9, 1179, 487, 12043, 84, 2073, 1621, 1829, 1034, 1020, 4, 539, 8, 448, 91, 202, 466, 70, 262, 4, 638, 125, 280, 83, 299, 12043, 539, 8, 61, 45, 7, 1537, 176, 4, 84, 2073, 288, 39, 4, 889, 280, 14, 125, 280, 156, 538, 12043, 190, 889, 280, 71, 109, 124, 93, 292, 889, 46, 1248, 4, 518, 48, 883, 125, 12043, 539, 8, 268, 889, 280, 109, 270, 4, 1586, 845, 7, 669, 199, 5, 3964, 3740, 1084, 4, 255, 440, 616, 154, 72, 71, 109, 12043, 49, 61, 283, 3591, 34, 87, 297, 41, 9, 1993, 2602, 518, 52, 706, 109, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'overflow_to_sample': 0, 'start_pos': 14, 'end_pos': 16}\n",
+      "处理后的训练集大小: 26198\n"
+     ]
+    }
+   ],
+   "source": [
+    "max_length = 256\n",
+    "doc_stride = 128\n",
+    "def _process_train(data):\n",
+    "\n",
+    "    contexts = [data[i][\"context\"] for i in range(len(data))]\n",
+    "    questions = [data[i][\"question\"] for i in range(len(data))]\n",
+    "\n",
+    "    tokenized_data_list = tokenizer(\n",
+    "        questions,\n",
+    "        contexts,\n",
+    "        stride=doc_stride,\n",
+    "        max_length=max_length,\n",
+    "        padding=\"max_length\",\n",
+    "        return_dict=False\n",
+    "    )\n",
+    "\n",
+    "    for i, tokenized_data in enumerate(tokenized_data_list):\n",
+    "        # 获取 [CLS] 对应的位置\n",
+    "        input_ids = tokenized_data[\"input_ids\"]\n",
+    "        cls_index = input_ids.index(tokenizer.cls_token_id)\n",
+    "\n",
+    "        # 在 tokenize 的过程中,汉字和 token 在位置上并非一一对应的\n",
+    "        # 而 offset mapping 记录了每个 token 在原文中对应的起始位置\n",
+    "        offsets = tokenized_data[\"offset_mapping\"]\n",
+    "        # token_type_ids 记录了一条数据中哪些是问题,哪些是上下文\n",
+    "        token_type_ids = tokenized_data[\"token_type_ids\"]\n",
+    "\n",
+    "        # 一条数据可能因为长度过长而在 tokenized_data 中存在多个结果\n",
+    "        # overflow_to_sample 表示了当前 tokenize_example 属于 data 中的哪一条数据\n",
+    "        sample_index = tokenized_data[\"overflow_to_sample\"]\n",
+    "        answers = data[sample_index][\"answers\"]\n",
+    "\n",
+    "        # answers 和 answer_starts 均为长度为 1 的 list\n",
+    "        # 我们可以计算出答案的结束位置\n",
+    "        start_char = answers[\"answer_start\"][0]\n",
+    "        end_char = start_char + len(answers[\"text\"][0])\n",
+    "\n",
+    "        token_start_index = 0\n",
+    "        while token_type_ids[token_start_index] != 1:\n",
+    "            token_start_index += 1\n",
+    "\n",
+    "        token_end_index = len(input_ids) - 1\n",
+    "        while token_type_ids[token_end_index] != 1:\n",
+    "            token_end_index -= 1\n",
+    "        # 分词后一条数据的结尾一定是 [SEP],因此还需要减一\n",
+    "        token_end_index -= 1\n",
+    "\n",
+    "        if not (offsets[token_start_index][0] <= start_char and\n",
+    "                offsets[token_end_index][1] >= end_char):\n",
+    "            # 如果答案不在这条数据中,则将答案位置标记为 [CLS] 的位置\n",
+    "            tokenized_data_list[i][\"start_pos\"] = cls_index\n",
+    "            tokenized_data_list[i][\"end_pos\"] = cls_index\n",
+    "        else:\n",
+    "            # 否则,我们可以找到答案对应的 token 的起始位置,记录在 start_pos 和 end_pos 中\n",
+    "            while token_start_index < len(offsets) and offsets[\n",
+    "                    token_start_index][0] <= start_char:\n",
+    "                token_start_index += 1\n",
+    "            tokenized_data_list[i][\"start_pos\"] = token_start_index - 1\n",
+    "            while offsets[token_end_index][1] >= end_char:\n",
+    "                token_end_index -= 1\n",
+    "            tokenized_data_list[i][\"end_pos\"] = token_end_index + 1\n",
+    "\n",
+    "    return tokenized_data_list\n",
+    "\n",
+    "train_dataset.map(_process_train, batched=True, num_workers=5)\n",
+    "print(train_dataset[0])\n",
+    "print(\"处理后的训练集大小:\", len(train_dataset))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 2.2 处理验证集\n",
+    "\n",
+    "对于验证集的处理则简单得多,我们只需要保存原数据的 `id` 并将 `offset_mapping` 中不属于上下文的部分设置为 `None` 即可。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "def _process_val(data):\n",
+    "\n",
+    "    contexts = [data[i][\"context\"] for i in range(len(data))]\n",
+    "    questions = [data[i][\"question\"] for i in range(len(data))]\n",
+    "\n",
+    "    tokenized_data_list = tokenizer(\n",
+    "        questions,\n",
+    "        contexts,\n",
+    "        stride=doc_stride,\n",
+    "        max_length=max_length,\n",
+    "        return_dict=False\n",
+    "    )\n",
+    "\n",
+    "    for i, tokenized_data in enumerate(tokenized_data_list):\n",
+    "        token_type_ids = tokenized_data[\"token_type_ids\"]\n",
+    "        # 保存数据对应的 id\n",
+    "        sample_index = tokenized_data[\"overflow_to_sample\"]\n",
+    "        tokenized_data_list[i][\"example_id\"] = data[sample_index][\"id\"]\n",
+    "\n",
+    "        # 将不属于 context 的 offset 设置为 None\n",
+    "        tokenized_data_list[i][\"offset_mapping\"] = [\n",
+    "            (o if token_type_ids[k] == 1 else None)\n",
+    "            for k, o in enumerate(tokenized_data[\"offset_mapping\"])\n",
+    "        ]\n",
+    "\n",
+    "    return tokenized_data_list\n",
+    "\n",
+    "val_dataset.map(_process_val, batched=True, num_workers=5)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 2.3 DataLoader\n",
+    "\n",
+    "最后使用 `PaddleDataLoader` 将数据集包裹起来即可。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP.core import PaddleDataLoader\n", + "\n", + "train_dataloader = PaddleDataLoader(train_dataset, batch_size=32, shuffle=True)\n", + "val_dataloader = PaddleDataLoader(val_dataset, batch_size=16)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 模型训练:自己定义评测用的 Metric 实现更加自由的任务评测\n", + "\n", + "#### 3.1 损失函数\n", + "\n", + "对于阅读理解任务,我们使用的是 `ErnieForQuestionAnswering` 模型。该模型在接受输入后会返回两个值:`start_logits` 和 `end_logits` ,大小均为 `(batch_size, sequence_length)`,反映了每条数据每个词语为答案起始位置的可能性,因此我们需要自定义一个损失函数来计算 `loss`。 `CrossEntropyLossForSquad` 会分别对答案起始位置的预测值和真实值计算交叉熵,最后返回其平均值作为最终的损失。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "class CrossEntropyLossForSquad(paddle.nn.Layer):\n", + " def __init__(self):\n", + " super(CrossEntropyLossForSquad, self).__init__()\n", + "\n", + " def forward(self, start_logits, end_logits, start_pos, end_pos):\n", + " start_pos = paddle.unsqueeze(start_pos, axis=-1)\n", + " end_pos = paddle.unsqueeze(end_pos, axis=-1)\n", + " start_loss = paddle.nn.functional.softmax_with_cross_entropy(\n", + " logits=start_logits, label=start_pos)\n", + " start_loss = paddle.mean(start_loss)\n", + " end_loss = paddle.nn.functional.softmax_with_cross_entropy(\n", + " logits=end_logits, label=end_pos)\n", + " end_loss = paddle.mean(end_loss)\n", + "\n", + " loss = (start_loss + end_loss) / 2\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.2 定义模型\n", + "\n", + "模型的核心则是 `ErnieForQuestionAnswering` 的 `ernie-1.0-base-zh` 预训练模型,同时按照 `FastNLP` 的规定定义 `train_step` 和 `evaluate_step` 函数。这里 `evaluate_step` 函数并没有像文本分类那样直接返回该批次数据的评测结果,这一点我们将在下面为您讲解。" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m[2022-06-27 19:00:15,825] [ INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/ernie_v1_chn_base.pdparams\u001b[0m\n", + "W0627 19:00:15.831080 21543 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.5, Driver API Version: 11.2, Runtime API Version: 11.2\n", + "W0627 19:00:15.843276 21543 gpu_context.cc:306] device: 0, cuDNN Version: 8.1.\n" + ] + } + ], + "source": [ + "from paddlenlp.transformers import ErnieForQuestionAnswering\n", + "\n", + "class QAModel(paddle.nn.Layer):\n", + " def __init__(self, model_checkpoint):\n", + " super(QAModel, self).__init__()\n", + " self.model = ErnieForQuestionAnswering.from_pretrained(model_checkpoint)\n", + " self.loss_func = CrossEntropyLossForSquad()\n", + "\n", + " def forward(self, input_ids, token_type_ids):\n", + " start_logits, end_logits = self.model(input_ids, token_type_ids)\n", + " return start_logits, end_logits\n", + "\n", + " def train_step(self, input_ids, token_type_ids, start_pos, end_pos):\n", + " start_logits, end_logits = self(input_ids, token_type_ids)\n", + " loss = self.loss_func(start_logits, end_logits, start_pos, end_pos)\n", + " return {\"loss\": loss}\n", + "\n", + " def evaluate_step(self, input_ids, token_type_ids):\n", + " start_logits, end_logits = self(input_ids, token_type_ids)\n", + " return {\"start_logits\": start_logits, \"end_logits\": end_logits}\n", + "\n", + "model = QAModel(MODEL_NAME)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.3 自定义 Metric 进行数据的评估\n", + "\n", + "`paddlenlp` 为我们提供了评测 `SQuAD` 格式数据集的函数 `compute_prediction` 和 `squad_evaluate`:\n", + "- `compute_prediction` 函数要求传入原数据 `examples` 、处理后的数据 `features` 和 `features` 对应的结果 `predictions`(一个包含所有数据 `start_logits` 和 `end_logits` 的元组)\n", + "- `squad_evaluate` 要求传入原数据 `examples` 和预测结果 `all_predictions`(通常来自于 `compute_prediction`)\n", + "\n", + "在使用这两个函数的时候,我们需要向其中传入数据集,但显然根据 `fastNLP` 的设计,我们无法在 `evaluate_step` 里实现这一过程,并且 `FastNLP` 也并没有提供计算 `F1` 和 `EM` 的 `Metric`,故我们需要自己定义用于评测的 `Metric`。\n", + "\n", + "在初始化之外,一个 `Metric` 还需要实现三个函数:\n", + "\n", + "1. `reset` - 该函数会在验证数据集的迭代之前被调用,用于清空数据;在我们自定义的 `Metric` 中,我们需要将 `all_start_logits` 和 `all_end_logits` 清空,重新收集每个 `batch` 的结果。\n", + "2. `update` - 该函数会在在每个 `batch` 得到结果后被调用,用于更新 `Metric` 的状态;它的参数即为 `evaluate_step` 返回的内容。我们在这里将得到的 `start_logits` 和 `end_logits` 收集起来。\n", + "3. `get_metric` - 该函数会在数据集被迭代完毕后调用,用于计算评测的结果。现在我们有了整个验证集的 `all_start_logits` 和 `all_end_logits` ,将他们传入 `compute_predictions` 函数得到预测的结果,并继续使用 `squad_evaluate` 函数得到评测的结果。\n", + " - 注:`suqad_evaluate` 函数会自己输出评测结果,为了不让其干扰 `FastNLP` 输出,这里我们使用 `contextlib.redirect_stdout(None)` 将函数的标准输出屏蔽掉。\n", + "\n", + "综上,`SquadEvaluateMetric` 实现的评估过程是:将验证集中所有数据的 `logits` 收集起来,然后统一传入 `compute_prediction` 和 `squad_evaluate` 中进行评估。值得一提的是,`paddlenlp.datasets.load_dataset` 返回的结果是一个 `MapDataset` 类型,其 `data` 成员为加载时的数据,`new_data` 为经过 `map` 函数处理后更新的数据,因此可以分别作为 `examples` 和 `features` 传入。" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP.core import Metric\n", + "from paddlenlp.metrics.squad import squad_evaluate, compute_prediction\n", + "import contextlib\n", + "\n", + "class SquadEvaluateMetric(Metric):\n", + " def __init__(self, examples, features, testing=False):\n", + " super(SquadEvaluateMetric, self).__init__(\"paddle\", False)\n", + " self.examples = examples\n", + " self.features = features\n", + " self.all_start_logits = []\n", + " self.all_end_logits = []\n", + " self.testing = testing\n", + "\n", + " def reset(self):\n", + " self.all_start_logits = []\n", + " self.all_end_logits = []\n", + "\n", + " def update(self, start_logits, end_logits):\n", + " for start, end in zip(start_logits, end_logits):\n", + " self.all_start_logits.append(start.numpy())\n", + " self.all_end_logits.append(end.numpy())\n", + "\n", + " def get_metric(self):\n", + " all_predictions, _, _ = compute_prediction(\n", + " self.examples, self.features[:len(self.all_start_logits)],\n", + " (self.all_start_logits, self.all_end_logits),\n", + " False, 20, 30\n", + " )\n", + " with contextlib.redirect_stdout(None):\n", + " result = squad_evaluate(\n", + " examples=self.examples,\n", + " preds=all_predictions,\n", + " is_whitespace_splited=False\n", + " )\n", + "\n", + " if self.testing:\n", + " self.print_predictions(all_predictions)\n", + " return result\n", + "\n", + " def print_predictions(self, preds):\n", + " for i, data in enumerate(self.examples):\n", + " if i >= 5:\n", + " break\n", + " print()\n", + " print(\"原文:\", data[\"context\"])\n", + " print(\"问题:\", data[\"question\"], \\\n", + " \"答案:\", preds[data[\"id\"]], \\\n", + " \"正确答案:\", data[\"answers\"][\"text\"])\n", + "\n", + "metric = SquadEvaluateMetric(\n", + " val_dataloader.dataset.data,\n", + " val_dataloader.dataset.new_data,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.4 训练\n", + "\n", + "至此所有的准备工作已经完成,可以使用 `Trainer` 进行训练了。学习率我们依旧采用线性预热策略 `LinearDecayWithWarmup`,优化器为 `AdamW`;回调模块我们选择 `LRSchedCallback` 更新学习率和 `LoadBestModelCallback` 监视评测结果的 `f1` 分数。初始化好 `Trainer` 之后,就将训练的过程交给 `FastNLP` 吧。" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[19:04:54] INFO     Running evaluator sanity check for 2 batches.              trainer.py:631\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[19:04:54]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=367046;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=96810;file://../fastNLP/core/controllers/trainer.py#631\u001b\\\u001b[2m631\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:100 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m100\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 49.25899788285109,\n",
+       "  \"f1#squad\": 66.55559127349602,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 49.25899788285109,\n",
+       "  \"HasAns_f1#squad\": 66.55559127349602,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m49.25899788285109\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m66.55559127349602\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m49.25899788285109\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m66.55559127349602\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:200 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m200\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 57.37473535638673,\n",
+       "  \"f1#squad\": 70.93036525200617,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 57.37473535638673,\n",
+       "  \"HasAns_f1#squad\": 70.93036525200617,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m57.37473535638673\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m70.93036525200617\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m57.37473535638673\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m70.93036525200617\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:300 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 63.86732533521524,\n",
+       "  \"f1#squad\": 78.62546663568186,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 63.86732533521524,\n",
+       "  \"HasAns_f1#squad\": 78.62546663568186,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m63.86732533521524\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m78.62546663568186\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m63.86732533521524\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m78.62546663568186\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:400 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m400\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 64.92589978828511,\n",
+       "  \"f1#squad\": 79.36746074079691,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 64.92589978828511,\n",
+       "  \"HasAns_f1#squad\": 79.36746074079691,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m64.92589978828511\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.36746074079691\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m64.92589978828511\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.36746074079691\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:500 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m500\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 65.70218772053634,\n",
+       "  \"f1#squad\": 80.33295482054824,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 65.70218772053634,\n",
+       "  \"HasAns_f1#squad\": 80.33295482054824,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:600 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m600\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 65.41990119971771,\n",
+       "  \"f1#squad\": 79.7483487059053,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 65.41990119971771,\n",
+       "  \"HasAns_f1#squad\": 79.7483487059053,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.41990119971771\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.7483487059053\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.41990119971771\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.7483487059053\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:700 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m700\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 66.61961891319689,\n",
+       "  \"f1#squad\": 80.32432238994133,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 66.61961891319689,\n",
+       "  \"HasAns_f1#squad\": 80.32432238994133,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m66.61961891319689\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m80.32432238994133\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m66.61961891319689\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m80.32432238994133\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:800 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m800\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 65.84333098094567,\n",
+       "  \"f1#squad\": 79.23169801265415,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 65.84333098094567,\n",
+       "  \"HasAns_f1#squad\": 79.23169801265415,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.84333098094567\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.23169801265415\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.84333098094567\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.23169801265415\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[19:20:28] INFO     Loading best model from fnlp-ernie-squad/ load_best_model_callback.py:111\n",
+       "                    2022-06-27-19_00_15_388554/best_so_far                                   \n",
+       "                    with f1#squad: 80.33295482054824...                                      \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[19:20:28]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from fnlp-ernie-squad/ \u001b]8;id=163935;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=31503;file://../fastNLP/core/callbacks/load_best_model_callback.py#111\u001b\\\u001b[2m111\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m27\u001b[0m-19_00_15_388554/best_so_far \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m with f1#squad: \u001b[1;36m80.33295482054824\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
           INFO     Deleting fnlp-ernie-squad/2022-06-27-19_0 load_best_model_callback.py:131\n",
+       "                    0_15_388554/best_so_far...                                               \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Deleting fnlp-ernie-squad/\u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m27\u001b[0m-19_0 \u001b]8;id=560859;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=573263;file://../fastNLP/core/callbacks/load_best_model_callback.py#131\u001b\\\u001b[2m131\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m 0_15_388554/best_so_far\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import Trainer, LRSchedCallback, LoadBestModelCallback\n", + "from paddlenlp.transformers import LinearDecayWithWarmup\n", + "\n", + "n_epochs = 1\n", + "num_training_steps = len(train_dataloader) * n_epochs\n", + "lr_scheduler = LinearDecayWithWarmup(3e-5, num_training_steps, 0.1)\n", + "optimizer = paddle.optimizer.AdamW(\n", + " learning_rate=lr_scheduler,\n", + " parameters=model.parameters(),\n", + ")\n", + "callbacks=[\n", + " LRSchedCallback(lr_scheduler, step_on=\"batch\"),\n", + " LoadBestModelCallback(\"f1#squad\", larger_better=True, save_folder=\"fnlp-ernie-squad\")\n", + "]\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=val_dataloader,\n", + " device=1,\n", + " optimizers=optimizer,\n", + " n_epochs=n_epochs,\n", + " callbacks=callbacks,\n", + " evaluate_every=100,\n", + " metrics={\"squad\": metric},\n", + ")\n", + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.5 测试\n", + "\n", + "最后,我们可以使用 `Evaluator` 查看我们训练的结果。我们在之前为 `SquadEvaluateMetric` 设置了 `testing` 参数来在测试阶段进行输出,可以看到,训练的结果还是比较不错的。" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 爬行垫根据中间材料的不同可以分为:XPE爬行垫、EPE爬行垫、EVA爬行垫、PVC爬行垫;其中XPE爬\n",
+       "行垫、EPE爬行垫都属于PE材料加保鲜膜复合而成,都是无异味的环保材料,但是XPE爬行垫是品质较好的爬\n",
+       "行垫,韩国进口爬行垫都是这种爬行垫,而EPE爬行垫是国内厂家为了减低成本,使用EPE(珍珠棉)作为原料生\n",
+       "产的一款爬行垫,该材料弹性差,易碎,开孔发泡防水性弱。EVA爬行垫、PVC爬行垫是用EVA或PVC作为原材料\n",
+       "与保鲜膜复合的而成的爬行垫,或者把图案转印在原材料上,这两款爬行垫通常有异味,如果是图案转印的爬\n",
+       "行垫,油墨外露容易脱落。 \n",
+       "当时我儿子爬的时候,我们也买了垫子,但是始终有味。最后就没用了,铺的就的薄毯子让他爬。\n",
+       "
\n" + ], + "text/plain": [ + "原文: 爬行垫根据中间材料的不同可以分为:XPE爬行垫、EPE爬行垫、EVA爬行垫、PVC爬行垫;其中XPE爬\n", + "行垫、EPE爬行垫都属于PE材料加保鲜膜复合而成,都是无异味的环保材料,但是XPE爬行垫是品质较好的爬\n", + "行垫,韩国进口爬行垫都是这种爬行垫,而EPE爬行垫是国内厂家为了减低成本,使用EPE(珍珠棉)作为原料生\n", + "产的一款爬行垫,该材料弹性差,易碎,开孔发泡防水性弱。EVA爬行垫、PVC爬行垫是用EVA或PVC作为原材料\n", + "与保鲜膜复合的而成的爬行垫,或者把图案转印在原材料上,这两款爬行垫通常有异味,如果是图案转印的爬\n", + "行垫,油墨外露容易脱落。 \n", + "当时我儿子爬的时候,我们也买了垫子,但是始终有味。最后就没用了,铺的就的薄毯子让他爬。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 爬行垫什么材质的好 答案: EPE(珍珠棉 正确答案: ['XPE']\n",
+       "
\n" + ], + "text/plain": [ + "问题: 爬行垫什么材质的好 答案: EPE(珍珠棉 正确答案: ['XPE']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 真实情况是160-162。她平时谎报的168是因为不离脚穿高水台恨天高(15厘米) 图1她穿着高水台恨\n",
+       "天高和刘亦菲一样高,(刘亦菲对外报身高172)范冰冰礼服下厚厚的高水台暴露了她的心机,对比一下两者的\n",
+       "鞋子吧 图2 穿着高水台恨天高才和刘德华谢霆锋持平,如果她真的有168,那么加上鞋高,刘和谢都要有180?\n",
+       "明显是不可能的。所以刘德华对外报的身高174减去10-15厘米才是范冰冰的真实身高 图3,范冰冰有一次脱\n",
+       "鞋上场,这个最说明问题了,看看她的身体比例吧。还有目测一下她手上鞋子的鞋跟有多高多厚吧,至少超过\n",
+       "10厘米。\n",
+       "
\n" + ], + "text/plain": [ + "原文: 真实情况是160-162。她平时谎报的168是因为不离脚穿高水台恨天高(15厘米) 图1她穿着高水台恨\n", + "天高和刘亦菲一样高,(刘亦菲对外报身高172)范冰冰礼服下厚厚的高水台暴露了她的心机,对比一下两者的\n", + "鞋子吧 图2 穿着高水台恨天高才和刘德华谢霆锋持平,如果她真的有168,那么加上鞋高,刘和谢都要有180?\n", + "明显是不可能的。所以刘德华对外报的身高174减去10-15厘米才是范冰冰的真实身高 图3,范冰冰有一次脱\n", + "鞋上场,这个最说明问题了,看看她的身体比例吧。还有目测一下她手上鞋子的鞋跟有多高多厚吧,至少超过\n", + "10厘米。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 范冰冰多高真实身高 答案: 160-162 正确答案: ['160-162']\n",
+       "
\n" + ], + "text/plain": [ + "问题: 范冰冰多高真实身高 答案: 160-162 正确答案: ['160-162']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 防水作为目前高端手机的标配,特别是苹果也支持防水之后,国产大多数高端旗舰手机都已经支持防\n",
+       "水。虽然我们真的不会故意把手机放入水中,但是有了防水之后,用户心里会多一重安全感。那么近日最为\n",
+       "火热的小米6防水吗?小米6的防水级别又是多少呢? 小编查询了很多资料发现,小米6确实是防水的,但是为\n",
+       "了保持低调,同时为了不被别人说防水等级不够,很多资料都没有标注小米是否防水。根据评测资料显示,小\n",
+       "米6是支持IP68级的防水,是绝对能够满足日常生活中的防水需求的。\n",
+       "
\n" + ], + "text/plain": [ + "原文: 防水作为目前高端手机的标配,特别是苹果也支持防水之后,国产大多数高端旗舰手机都已经支持防\n", + "水。虽然我们真的不会故意把手机放入水中,但是有了防水之后,用户心里会多一重安全感。那么近日最为\n", + "火热的小米6防水吗?小米6的防水级别又是多少呢? 小编查询了很多资料发现,小米6确实是防水的,但是为\n", + "了保持低调,同时为了不被别人说防水等级不够,很多资料都没有标注小米是否防水。根据评测资料显示,小\n", + "米6是支持IP68级的防水,是绝对能够满足日常生活中的防水需求的。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 小米6防水等级 答案: IP68级 正确答案: ['IP68级']\n",
+       "
\n" + ], + "text/plain": [ + "问题: 小米6防水等级 答案: IP68级 正确答案: ['IP68级']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 这位朋友你好,女性出现妊娠反应一般是从6-12周左右,也就是女性怀孕1个多月就会开始出现反应,\n",
+       "第3个月的时候,妊辰反应基本结束。 而大部分女性怀孕初期都会出现恶心、呕吐的感觉,这些症状都是因\n",
+       "人而异的,除非恶心、呕吐的非常厉害,才需要就医,否则这些都是刚怀孕的的正常症状。1-3个月的时候可\n",
+       "以观察一下自己的皮肤,一般女性怀孕初期可能会产生皮肤色素沉淀或是腹壁产生妊娠纹,特别是在怀孕的\n",
+       "后期更加明显。 还有很多女性怀孕初期会出现疲倦、嗜睡的情况。怀孕三个月的时候,膀胱会受到日益胀\n",
+       "大的子宫的压迫,容量会变小,所以怀孕期间也会有尿频的现象出现。月经停止也是刚怀孕最容易出现的症\n",
+       "状,只要是平时月经正常的女性,在性行为后超过正常经期两周,就有可能是怀孕了。 如果你想判断自己是\n",
+       "否怀孕,可以看看自己有没有这些反应。当然这也只是多数人的怀孕表现,也有部分女性怀孕表现并不完全\n",
+       "是这样,如果你无法确定自己是否怀孕,最好去医院检查一下。\n",
+       "
\n" + ], + "text/plain": [ + "原文: 这位朋友你好,女性出现妊娠反应一般是从6-12周左右,也就是女性怀孕1个多月就会开始出现反应,\n", + "第3个月的时候,妊辰反应基本结束。 而大部分女性怀孕初期都会出现恶心、呕吐的感觉,这些症状都是因\n", + "人而异的,除非恶心、呕吐的非常厉害,才需要就医,否则这些都是刚怀孕的的正常症状。1-3个月的时候可\n", + "以观察一下自己的皮肤,一般女性怀孕初期可能会产生皮肤色素沉淀或是腹壁产生妊娠纹,特别是在怀孕的\n", + "后期更加明显。 还有很多女性怀孕初期会出现疲倦、嗜睡的情况。怀孕三个月的时候,膀胱会受到日益胀\n", + "大的子宫的压迫,容量会变小,所以怀孕期间也会有尿频的现象出现。月经停止也是刚怀孕最容易出现的症\n", + "状,只要是平时月经正常的女性,在性行为后超过正常经期两周,就有可能是怀孕了。 如果你想判断自己是\n", + "否怀孕,可以看看自己有没有这些反应。当然这也只是多数人的怀孕表现,也有部分女性怀孕表现并不完全\n", + "是这样,如果你无法确定自己是否怀孕,最好去医院检查一下。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 怀孕多久会有反应 答案: 6-12周左右 正确答案: ['6-12周左右', '6-12周', '1个多月']\n",
+       "
\n" + ], + "text/plain": [ + "问题: 怀孕多久会有反应 答案: 6-12周左右 正确答案: ['6-12周左右', '6-12周', '1个多月']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 【东奥会计在线——中级会计职称频道推荐】根据《关于提高科技型中小企业研究开发费用税前加计\n",
+       "扣除比例的通知》的规定,研发费加计扣除比例提高到75%。|财政部、国家税务总局、科技部发布《关于提\n",
+       "高科技型中小企业研究开发费用税前加计扣除比例的通知》。|通知称,为进一步激励中小企业加大研发投\n",
+       "入,支持科技创新,就提高科技型中小企业研究开发费用(以下简称研发费用)税前加计扣除比例有关问题发\n",
+       "布通知。|通知明确,科技型中小企业开展研发活动中实际发生的研发费用,未形成无形资产计入当期损益的\n",
+       ",在按规定据实扣除的基础上,在2017年1月1日至2019年12月31日期间,再按照实际发生额的75%在税前加计\n",
+       "扣除;形成无形资产的,在上述期间按照无形资产成本的175%在税前摊销。|科技型中小企业享受研发费用税\n",
+       "前加计扣除政策的其他政策口径按照《财政部国家税务总局科技部关于完善研究开发费用税前加计扣除政\n",
+       "策的通知》(财税〔2015〕119号)规定执行。|科技型中小企业条件和管理办法由科技部、财政部和国家税\n",
+       "务总局另行发布。科技、财政和税务部门应建立信息共享机制,及时共享科技型中小企业的相关信息,加强\n",
+       "协调配合,保障优惠政策落实到位。|上一篇文章:关于2016年度企业研究开发费用税前加计扣除政策企业所\n",
+       "得税纳税申报问题的公告 下一篇文章:关于提高科技型中小企业研究开发费用税前加计扣除比例的通知\n",
+       "
\n" + ], + "text/plain": [ + "原文: 【东奥会计在线——中级会计职称频道推荐】根据《关于提高科技型中小企业研究开发费用税前加计\n", + "扣除比例的通知》的规定,研发费加计扣除比例提高到75%。|财政部、国家税务总局、科技部发布《关于提\n", + "高科技型中小企业研究开发费用税前加计扣除比例的通知》。|通知称,为进一步激励中小企业加大研发投\n", + "入,支持科技创新,就提高科技型中小企业研究开发费用(以下简称研发费用)税前加计扣除比例有关问题发\n", + "布通知。|通知明确,科技型中小企业开展研发活动中实际发生的研发费用,未形成无形资产计入当期损益的\n", + ",在按规定据实扣除的基础上,在2017年1月1日至2019年12月31日期间,再按照实际发生额的75%在税前加计\n", + "扣除;形成无形资产的,在上述期间按照无形资产成本的175%在税前摊销。|科技型中小企业享受研发费用税\n", + "前加计扣除政策的其他政策口径按照《财政部国家税务总局科技部关于完善研究开发费用税前加计扣除政\n", + "策的通知》(财税〔2015〕119号)规定执行。|科技型中小企业条件和管理办法由科技部、财政部和国家税\n", + "务总局另行发布。科技、财政和税务部门应建立信息共享机制,及时共享科技型中小企业的相关信息,加强\n", + "协调配合,保障优惠政策落实到位。|上一篇文章:关于2016年度企业研究开发费用税前加计扣除政策企业所\n", + "得税纳税申报问题的公告 下一篇文章:关于提高科技型中小企业研究开发费用税前加计扣除比例的通知\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 研发费用加计扣除比例 答案: 75% 正确答案: ['75%']\n",
+       "
\n" + ], + "text/plain": [ + "问题: 研发费用加计扣除比例 答案: 75% 正确答案: ['75%']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
{\n",
+       "    'exact#squad': 65.70218772053634,\n",
+       "    'f1#squad': 80.33295482054824,\n",
+       "    'total#squad': 1417,\n",
+       "    'HasAns_exact#squad': 65.70218772053634,\n",
+       "    'HasAns_f1#squad': 80.33295482054824,\n",
+       "    'HasAns_total#squad': 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[32m'exact#squad'\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[32m'f1#squad'\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[32m'total#squad'\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[32m'HasAns_exact#squad'\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[32m'HasAns_f1#squad'\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[32m'HasAns_total#squad'\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import Evaluator\n", + "evaluator = Evaluator(\n", + " model=model,\n", + " dataloaders=val_dataloader,\n", + " device=1,\n", + " metrics={\n", + " \"squad\": SquadEvaluateMetric(\n", + " val_dataloader.dataset.data,\n", + " val_dataloader.dataset.new_data,\n", + " testing=True,\n", + " ),\n", + " },\n", + ")\n", + "result = evaluator.run()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.7.13 ('fnlp-paddle')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorials/figures/paddle-ernie-1.0-masking-levels.png b/tutorials/figures/paddle-ernie-1.0-masking-levels.png new file mode 100644 index 00000000..ff2519c4 Binary files /dev/null and b/tutorials/figures/paddle-ernie-1.0-masking-levels.png differ diff --git a/tutorials/figures/paddle-ernie-1.0-masking.png b/tutorials/figures/paddle-ernie-1.0-masking.png new file mode 100644 index 00000000..ed003a2f Binary files /dev/null and b/tutorials/figures/paddle-ernie-1.0-masking.png differ diff --git a/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png b/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png new file mode 100644 index 00000000..d45f65d8 Binary files /dev/null and b/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png differ diff --git a/tutorials/figures/paddle-ernie-3.0-framework.png b/tutorials/figures/paddle-ernie-3.0-framework.png new file mode 100644 index 00000000..f50ddb1c Binary files /dev/null and b/tutorials/figures/paddle-ernie-3.0-framework.png differ