You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tutorial_9_callback.rst 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. ===================================================
  2. 使用 Callback 自定义你的训练过程
  3. ===================================================
  4. - `什么是Callback`_
  5. - `使用 Callback`_
  6. - `fastNLP 中的 Callback`_
  7. - `自定义 Callback`_
  8. 什么是Callback
  9. ---------------------
  10. :class:`~fastNLP.core.callback.Callback` 是与 :class:`~fastNLP.core.trainer.Trainer` 紧密结合的模块,利用 Callback 可以在 :class:`~fastNLP.core.trainer.Trainer` 训练时,加入自定义的操作,比如梯度裁剪,学习率调节,测试模型的性能等。定义的 Callback 会在训练的特定阶段被调用。
  11. fastNLP 中提供了很多常用的 :class:`~fastNLP.core.callback.Callback` ,开箱即用。
  12. 使用 Callback
  13. ---------------------
  14. 使用 Callback 很简单,将需要的 callback 按 list 存储,以对应参数 ``callbacks`` 传入对应的 Trainer。Trainer 在训练时就会自动执行这些 Callback 指定的操作了。
  15. .. code-block:: python
  16. from fastNLP import (Callback, EarlyStopCallback,
  17. Trainer, CrossEntropyLoss, AccuracyMetric)
  18. from fastNLP.models import CNNText
  19. import torch.cuda
  20. # prepare data
  21. def get_data():
  22. from fastNLP.io import ChnSentiCorpPipe as pipe
  23. data = pipe().process_from_file()
  24. print(data)
  25. data.rename_field('chars', 'words')
  26. train_data = data.get_dataset('train')
  27. dev_data = data.get_dataset('dev')
  28. test_data = data.get_dataset('test')
  29. vocab = data.get_vocab('words')
  30. tgt_vocab = data.get_vocab('target')
  31. return train_data, dev_data, test_data, vocab, tgt_vocab
  32. # prepare model
  33. train_data, dev_data, _, vocab, tgt_vocab = get_data()
  34. device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
  35. model = CNNText((len(vocab),50), num_classes=len(tgt_vocab))
  36. # define callback
  37. callbacks=[EarlyStopCallback(5)]
  38. # pass callbacks to Trainer
  39. def train_with_callback(cb_list):
  40. trainer = Trainer(
  41. device=device,
  42. n_epochs=3,
  43. model=model,
  44. train_data=train_data,
  45. dev_data=dev_data,
  46. loss=CrossEntropyLoss(),
  47. metrics=AccuracyMetric(),
  48. callbacks=cb_list,
  49. check_code_level=-1
  50. )
  51. trainer.train()
  52. train_with_callback(callbacks)
  53. fastNLP 中的 Callback
  54. ---------------------
  55. fastNLP 中提供了很多常用的 Callback,如梯度裁剪,训练时早停和测试验证集,fitlog 等等。具体 Callback 请参考 :mod:`fastNLP.core.callback`
  56. .. code-block:: python
  57. from fastNLP import EarlyStopCallback, GradientClipCallback, EvaluateCallback
  58. callbacks = [
  59. EarlyStopCallback(5),
  60. GradientClipCallback(clip_value=5, clip_type='value'),
  61. EvaluateCallback(dev_data)
  62. ]
  63. train_with_callback(callbacks)
  64. 自定义 Callback
  65. ---------------------
  66. 这里我们以一个简单的 Callback作为例子,它的作用是打印每一个 Epoch 平均训练 loss。
  67. 1. 创建 Callback
  68. 要自定义 Callback,我们要实现一个类,继承 :class:`~fastNLP.core.callback.Callback` 。这里我们定义 ``MyCallBack`` ,继承 fastNLP.Callback 。
  69. 2. 指定 Callback 调用的阶段
  70. Callback 中所有以 `on_` 开头的类方法会在 Trainer 的训练中在特定阶段调用。 如 on_train_begin() 会在训练开始时被调用,on_epoch_end()
  71. 会在每个 epoch 结束时调用。 具体有哪些类方法,参见 :class:`~fastNLP.core.callback.Callback` 文档。这里, MyCallBack 在求得loss时调用 on_backward_begin() 记录
  72. 当前 loss,在每一个 epoch 结束时调用 on_epoch_end() ,求当前 epoch 平均loss并输出。
  73. 3. 使用 Callback 的属性访问 Trainer 的内部信息
  74. 为了方便使用,可以使用 :class:`~fastNLP.core.callback.Callback` 的属性,访问 :class:`~fastNLP.core.trainer.Trainer` 中的对应信息,如 optimizer, epoch, n_epochs,分别对应训练时的优化器,
  75. 当前 epoch 数,和总 epoch 数。 具体可访问的属性,参见 :class:`~fastNLP.core.callback.Callback` 。这里, MyCallBack 为了求平均 loss ,需要知道当前 epoch 的总步
  76. 数,可以通过 self.step 属性得到当前训练了多少步。
  77. .. code-block:: python
  78. from fastNLP import Callback
  79. from fastNLP import logger
  80. class MyCallBack(Callback):
  81. """Print average loss in each epoch"""
  82. def __init__(self):
  83. super().__init__()
  84. self.total_loss = 0
  85. self.start_step = 0
  86. def on_backward_begin(self, loss):
  87. self.total_loss += loss.item()
  88. def on_epoch_end(self):
  89. n_steps = self.step - self.start_step
  90. avg_loss = self.total_loss / n_steps
  91. logger.info('Avg loss at epoch %d, %.6f', self.epoch, avg_loss)
  92. self.start_step = self.step
  93. callbacks = [MyCallBack()]
  94. train_with_callback(callbacks)
  95. ----------------------------------
  96. 代码下载
  97. ----------------------------------
  98. `点击下载 IPython Notebook 文件 <https://sourcegraph.com/github.com/fastnlp/fastNLP@master/-/raw/tutorials/tutorial_9_callback.ipynb>`_)