from copy import deepcopy

from fastNLP.core.callbacks.callback import Callback
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    import torch


class RecordAccumulationStepsCallback_Torch(Callback):
    """
    A callback used to test whether the Trainer implements ``accumulation_steps`` correctly:

    1. On every batch, check that the model was updated correctly (the model parameters should
       only change once every ``accumulation_steps`` batches);
    2. Check that the optimizer's gradients are zeroed only at the right moments.
    """

    def __init__(self, accumulation_steps: int):
        self.accumulation_steps = accumulation_steps
        self.last_batch_params = None
        self.equal = 0  # counts how many times the parameters were unexpectedly equal

    def on_train_batch_end(self, trainer):
        # Note that at this point trainer.global_forward_batches is already one larger than it was
        # the last time the trainer called batch_step_fn.
        if trainer.global_forward_batches % trainer.accumulation_steps == 0:
            # The model parameters should differ from those of the previous batch.
            cur_batch_params = deepcopy(next(trainer.driver.unwrap_model().parameters()).cpu().detach())
            if self.last_batch_params is not None:
                assert not cur_batch_params.equal(self.last_batch_params)

                if cur_batch_params.equal(self.last_batch_params):
                    self.equal += 1

            # The optimizer's gradients should have been zeroed.
            optimizers = trainer.driver.optimizers
            for optimizer in optimizers:
                param_groups = optimizer.param_groups
                for group in param_groups:
                    for p in group['params']:
                        assert p.grad is None or p.grad.equal(torch.zeros_like(p.grad))
        else:
            # The model parameters should be identical to those of the previous batch.
            cur_batch_params = deepcopy(next(trainer.driver.unwrap_model().parameters()).cpu().detach())
            if self.last_batch_params is not None:
                assert cur_batch_params.equal(self.last_batch_params)

            # The optimizer's gradients should not have been zeroed.
            optimizers = trainer.driver.optimizers
            for optimizer in optimizers:
                param_groups = optimizer.param_groups
                for group in param_groups:
                    for p in group['params']:
                        assert p.grad is not None and not p.grad.equal(torch.zeros_like(p.grad))

        self.last_batch_params = cur_batch_params

    def on_train_end(self, trainer):
        print(f"\n equal num: {self.equal}.\n")
        print(f"\ntotal_batch_num: {trainer.n_batches}.\n")
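

# A minimal usage sketch (not part of the original test). It assumes the fastNLP 1.0 style
# ``Trainer`` keyword arguments (``driver``, ``train_dataloader``, ``optimizers``,
# ``accumulation_steps``, ``callbacks``); ``model`` and ``train_dataloader`` below are
# hypothetical placeholders that a concrete test would have to provide:
#
#     from fastNLP import Trainer
#
#     callback = RecordAccumulationStepsCallback_Torch(accumulation_steps=3)
#     trainer = Trainer(
#         model=model,                        # a torch.nn.Module exposing a ``train_step``
#         driver="torch",
#         train_dataloader=train_dataloader,
#         optimizers=torch.optim.SGD(model.parameters(), lr=0.01),
#         accumulation_steps=3,
#         callbacks=[callback],
#     )
#     trainer.run()
#
# With asserts enabled, the callback raises as soon as the parameters change (or fail to change)
# at the wrong batch; the ``equal`` counter printed in ``on_train_end`` is only informative when
# asserts are stripped (e.g. under ``python -O``).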