
helper_callbacks_torch.py

from copy import deepcopy

from fastNLP.core.callbacks.callback import Callback
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    import torch


class RecordAccumulationStepsCallback_Torch(Callback):
    """
    Use this callback to test whether Trainer implements ``accumulation_steps`` correctly:

    1. On every batch, check that the model was updated correctly (its parameters should only
       change once every ``accumulation_steps`` batches);
    2. Check that the optimizer's gradients are zeroed only at the right moments.
    """

    def __init__(self, accumulation_steps: int):
        self.accumulation_steps = accumulation_steps
        self.last_batch_params = None
        self.equal = 0

    def on_train_batch_end(self, trainer):
        # Note: at this point trainer.global_forward_batches is already one greater than its
        # value when trainer last called batch_step_fn;
        if trainer.global_forward_batches % trainer.accumulation_steps == 0:
            # The model's parameters should differ from those of the previous batch;
            cur_batch_params = deepcopy(next(trainer.driver.unwrap_model().parameters()).cpu().detach())
            if self.last_batch_params is not None:
                assert not cur_batch_params.equal(self.last_batch_params)
                if cur_batch_params.equal(self.last_batch_params):
                    self.equal += 1

            # The optimizer's gradients should have been zeroed;
            optimizers = trainer.driver.optimizers
            for optimizer in optimizers:
                param_groups = optimizer.param_groups
                for group in param_groups:
                    for p in group['params']:
                        assert p.grad is None or p.grad.equal(torch.zeros_like(p.grad))
        else:
            # The model's parameters should be the same as in the previous batch;
            cur_batch_params = deepcopy(next(trainer.driver.unwrap_model().parameters()).cpu().detach())
            if self.last_batch_params is not None:
                assert cur_batch_params.equal(self.last_batch_params)

            # The optimizer's gradients should NOT have been zeroed;
            optimizers = trainer.driver.optimizers
            for optimizer in optimizers:
                param_groups = optimizer.param_groups
                for group in param_groups:
                    for p in group['params']:
                        assert p.grad is not None and not p.grad.equal(torch.zeros_like(p.grad))

        self.last_batch_params = cur_batch_params

    def on_train_end(self, trainer):
        print(f"\n equal num: {self.equal}.\n")
        print(f"\ntotal_batch_num: {trainer.n_batches}.\n")