
helper_callbacks_torch.py

import torch
from copy import deepcopy

from fastNLP.core.callbacks.callback import Callback


class RecordAccumulationStepsCallback_Torch(Callback):
    """
    A callback used to test whether the Trainer implements `accumulation_steps` correctly:

    1. On every batch, check that the model was updated correctly (the model parameters
       should only change once every `accumulation_steps` batches);
    2. Check that the optimizer's gradients are zeroed only at the right moments.
    """

    def __init__(self, accumulation_steps: int):
        self.accumulation_steps = accumulation_steps
        self.last_batch_params = None
        self.equal = 0

    def on_train_batch_end(self, trainer):
        # Note that trainer.global_forward_batches is already one larger here than it was
        # when the trainer last called batch_step_fn;
        if trainer.global_forward_batches % trainer.accumulation_steps == 0:
            # The model parameters should differ from those of the previous batch;
            cur_batch_params = deepcopy(next(trainer.driver.unwrap_model().parameters()).cpu().detach())
            if self.last_batch_params is not None:
                assert not cur_batch_params.equal(self.last_batch_params)
                if cur_batch_params.equal(self.last_batch_params):
                    self.equal += 1

            # The optimizer gradients should have been zeroed;
            optimizers = trainer.driver.optimizers
            for optimizer in optimizers:
                param_groups = optimizer.param_groups
                for group in param_groups:
                    for p in group['params']:
                        assert p.grad is None or p.grad.equal(torch.zeros_like(p.grad))
        else:
            # The model parameters should be identical to those of the previous batch;
            cur_batch_params = deepcopy(next(trainer.driver.unwrap_model().parameters()).cpu().detach())
            if self.last_batch_params is not None:
                assert cur_batch_params.equal(self.last_batch_params)

            # The optimizer gradients should NOT have been zeroed;
            optimizers = trainer.driver.optimizers
            for optimizer in optimizers:
                param_groups = optimizer.param_groups
                for group in param_groups:
                    for p in group['params']:
                        assert p.grad is not None and not p.grad.equal(torch.zeros_like(p.grad))

        self.last_batch_params = cur_batch_params

    def on_train_end(self, trainer):
        print(f"\n equal num: {self.equal}.\n")
        print(f"\ntotal_batch_num: {trainer.total_batches}.\n")
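
For context (this is not part of the helper file above): the invariants the callback asserts correspond to the standard gradient-accumulation pattern. Below is a minimal, self-contained PyTorch sketch of that pattern, independent of fastNLP; the model, loss, and hyperparameters are arbitrary illustrative choices, not anything taken from the fastNLP tests.

import torch

accumulation_steps = 4
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for batch_idx in range(1, 17):
    x = torch.randn(8, 10)
    y = torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    # Scale the loss so the accumulated gradient matches a full-batch gradient.
    (loss / accumulation_steps).backward()

    if batch_idx % accumulation_steps == 0:
        # Only on these batches do the parameters change and the gradients get
        # cleared, which is exactly what on_train_batch_end asserts above.
        optimizer.step()
        optimizer.zero_grad()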