callbacks.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os

import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)


class Callback:
    """
    Callback provides an easy way to react to events like the begin/end of epochs.
    """

    def __init__(self):
        self.model = None
        self.optimizer = None
        self.mutator = None
        self.trainer = None

    def build(self, model, optimizer, mutator, trainer):
        """
        Callback needs to be built with model, optimizer, mutator and trainer,
        to get updates from them.

        Parameters
        ----------
        model : nn.Module
            Model to be trained.
        optimizer : Optimizer
            Optimizer used to train the model.
        mutator : nn.Module
            Mutator that mutates the model.
        trainer : BaseTrainer
            Trainer that is to call the callback.
        """
        self.model = model
        self.optimizer = optimizer
        self.mutator = mutator
        self.trainer = trainer

    def on_epoch_begin(self, epoch):
        """
        Implement this to do something at the beginning of each epoch.

        Parameters
        ----------
        epoch : int
            Epoch number, starting from 0.
        """
        pass

    def on_epoch_end(self, epoch):
        """
        Implement this to do something at the end of each epoch.

        Parameters
        ----------
        epoch : int
            Epoch number, starting from 0.
        """
        pass

    def on_batch_begin(self, epoch):
        pass

    def on_batch_end(self, epoch):
        pass


class LRSchedulerCallback(Callback):
    """
    Calls the scheduler at the end of every epoch.

    Parameters
    ----------
    scheduler : LRScheduler
        Scheduler to be called.
    """

    def __init__(self, scheduler, mode="epoch"):
        super().__init__()
        assert mode == "epoch"
        self.scheduler = scheduler
        self.mode = mode

    def on_epoch_end(self, epoch):
        """
        Call ``self.scheduler.step()`` on epoch end.
        """
        self.scheduler.step()


class ArchitectureCheckpoint(Callback):
    """
    Calls ``trainer.export()`` at the end of every epoch.

    Parameters
    ----------
    checkpoint_dir : str
        Location to save checkpoints.
    """

    def __init__(self, checkpoint_dir):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def on_epoch_end(self, epoch):
        """
        Dump to ``<checkpoint_dir>/epoch_{number}.json`` on epoch end.
        """
        dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))
        _logger.info("Saving architecture to %s", dest_path)
        self.trainer.export(dest_path)


class BestArchitectureCheckpoint(Callback):
    """
    Calls ``trainer.export()`` at the end of the final epoch.

    Parameters
    ----------
    checkpoint_path : str
        Location to save the checkpoint.
    epoches : int
        Total number of training epochs; the export happens after the last one.
    """

    def __init__(self, checkpoint_path, epoches):
        super().__init__()
        self.epoches = epoches
        self.checkpoint_path = checkpoint_path

    def on_epoch_end(self, epoch):
        """
        Dump to ``checkpoint_path`` at the end of the last epoch.
        """
        if epoch == self.epoches - 1:
            _logger.info("Saving architecture to %s", self.checkpoint_path)
            self.trainer.export(self.checkpoint_path)


class ModelCheckpoint(Callback):
    """
    Saves the model and optimizer state at the end of every epoch.

    Parameters
    ----------
    checkpoint_dir : str
        Location to save checkpoints.
    """

    def __init__(self, checkpoint_dir):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def on_epoch_end(self, epoch):
        """
        Dump to ``<checkpoint_dir>/epoch_{number}.pth.tar`` on every epoch end.
        ``DataParallel`` objects will have their inner modules exported.
        """
        if isinstance(self.model, nn.DataParallel):
            child_model_state_dict = self.model.module.state_dict()
        else:
            child_model_state_dict = self.model.state_dict()
        save_state = {'child_model_state_dict': child_model_state_dict,
                      'optimizer_state_dict': self.optimizer.state_dict(),
                      'epoch': epoch}
        dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.pth.tar".format(epoch))
        _logger.info("Saving model to %s", dest_path)
        torch.save(save_state, dest_path)
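The ``Callback`` base class is meant to be subclassed; all hooks are no-ops by default, so an implementation only overrides the events it cares about. A minimal sketch of a custom callback (the ``EpochTimer`` name is illustrative and not part of the file above):

import time

class EpochTimer(Callback):
    """Example callback that logs how long each epoch takes."""

    def __init__(self):
        super().__init__()
        self._start = None

    def on_epoch_begin(self, epoch):
        # Record the wall-clock time when the epoch starts.
        self._start = time.time()

    def on_epoch_end(self, epoch):
        # By this point build() has populated self.model/self.trainer if needed.
        _logger.info("Epoch %d took %.1f seconds", epoch, time.time() - self._start)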

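How the hooks are driven depends on the trainer. A hedged sketch of the calling convention, with a hypothetical ``run_training`` loop standing in for the real trainer (the actual training step is elided):

def run_training(model, optimizer, mutator, trainer, callbacks, num_epochs):
    # Hand every callback references to the objects it may need.
    for cb in callbacks:
        cb.build(model, optimizer, mutator, trainer)
    for epoch in range(num_epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        # ... one epoch of training/validation happens here ...
        for cb in callbacks:
            cb.on_epoch_end(epoch)

# A typical callback set, assuming a torch.optim.lr_scheduler.StepLR scheduler:
#   callbacks = [LRSchedulerCallback(scheduler),
#                ArchitectureCheckpoint("./checkpoints"),
#                ModelCheckpoint("./checkpoints")]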
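The files written by ``ModelCheckpoint`` are plain ``torch.save`` dictionaries and can be restored with ``torch.load``. A minimal sketch, with keys matching the dict written by ``ModelCheckpoint.on_epoch_end`` (the ``resume_from_checkpoint`` helper is illustrative):

def resume_from_checkpoint(model, optimizer, path):
    # Load the dict saved by ModelCheckpoint and restore both states.
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state['child_model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    # Return the next epoch to run, since 'epoch' records the completed one.
    return state['epoch'] + 1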