You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

trainer.py 6.6 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import json
  4. import logging
  5. import os
  6. import time
  7. from abc import abstractmethod
  8. import torch
  9. from .base_trainer import BaseTrainer
  10. _logger = logging.getLogger(__name__)
  class TorchTensorEncoder(json.JSONEncoder):
      """
      JSON encoder that additionally serializes ``torch.Tensor`` objects.

      Tensors are converted with ``Tensor.tolist()``, which yields a nested list, or a plain
      Python scalar for 0-dim tensors. Every other object is delegated to
      :class:`json.JSONEncoder`, which raises ``TypeError`` for unsupported types.
      """

      def default(self, o):  # pylint: disable=method-hidden
          """
          Return a JSON-serializable version of ``o``.

          Parameters
          ----------
          o : object
              Object the base encoder could not serialize on its own.

          Returns
          -------
          list or scalar
              ``o.tolist()`` when ``o`` is a tensor.

          Raises
          ------
          TypeError
              If ``o`` is not a tensor and the base encoder cannot serialize it.
          """
          if isinstance(o, torch.Tensor):
              olist = o.tolist()
              # Inspect the tensor itself rather than ``olist``:
              # - ``tolist()`` returns a bare scalar for 0-dim tensors, which is not iterable.
              # - For N-d tensors it returns nested lists, whose rows never compare equal to 0 or 1.
              # Empty tensors are skipped so the hint is not emitted vacuously.
              if o.dtype != torch.bool and o.numel() > 0 and bool(((o == 0) | (o == 1)).all()):
                  _logger.warning("Every element in %s is either 0 or 1. "
                                  "You might consider convert it into bool.", olist)
              return olist
          return super().default(o)
  class Trainer(BaseTrainer):
      """
      A trainer with some helper functions implemented. To implement a new trainer,
      users need to implement :meth:`train_one_epoch`, :meth:`validate_one_epoch` and :meth:`checkpoint`.

      Parameters
      ----------
      model : nn.Module
          Model with mutables. Moved to ``device`` on construction.
      mutator : BaseMutator
          A mutator object that has been initialized with the model. Moved to ``device`` on construction.
      loss : callable
          Called with logits and targets. Returns a loss tensor.
          If the object has a ``to`` method (e.g. an ``nn.Module`` loss), it is moved to ``device``;
          plain functions are used as-is.
          See `PyTorch loss functions`_ for examples.
      metrics : callable
          Called with logits and targets. Returns a dict that maps metrics keys to metrics data. For example,

          .. code-block:: python

              def metrics_fn(output, target):
                  return {"acc1": accuracy(output, target, topk=1), "acc5": accuracy(output, target, topk=5)}

      optimizer : Optimizer
          Optimizer that optimizes the model.
      num_epochs : int
          Number of epochs of training.
      dataset_train : torch.utils.data.Dataset
          Dataset of training. If not otherwise specified, ``dataset_train`` and ``dataset_valid`` should be standard
          PyTorch Dataset. See `torch.utils.data`_ for examples.
      dataset_valid : torch.utils.data.Dataset
          Dataset of validation/testing.
      batch_size : int
          Batch size.
      workers : int
          Number of workers used in data preprocessing.
      device : torch.device
          Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, the trainer
          uses CUDA if ``torch.cuda.is_available()`` and falls back to CPU otherwise.
      log_frequency : int
          Number of mini-batches between metric logs. Only stored here, for use by subclasses.
      callbacks : list of Callback
          Callbacks to plug into the trainer. See Callbacks.
          Each callback's ``build(model, optimizer, mutator, trainer)`` is called during construction.
          ``on_epoch_begin(epoch)`` / ``on_epoch_end(epoch)`` are called around every epoch in :meth:`train`.
          ``None`` is treated as an empty list.

      .. _`PyTorch loss functions`: https://pytorch.org/docs/stable/nn.html#loss-functions
      .. _`torch.utils.data`: https://pytorch.org/docs/stable/data.html
      """

      def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs,
                   dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks):
          # Auto-select the device when the caller does not pin one.
          self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
          self.model = model
          self.mutator = mutator
          self.loss = loss
          self.metrics = metrics
          self.optimizer = optimizer
          self.model.to(self.device)
          self.mutator.to(self.device)
          # ``loss`` is documented as any callable. Only objects that support ``.to``
          # (such as ``nn.Module`` losses) can be moved. Plain functions like
          # ``torch.nn.functional.cross_entropy`` have no ``to`` attribute, and calling it
          # unconditionally would raise AttributeError.
          if hasattr(self.loss, "to"):
              self.loss.to(self.device)
          self.num_epochs = num_epochs
          self.dataset_train = dataset_train
          self.dataset_valid = dataset_valid
          self.batch_size = batch_size
          self.workers = workers
          self.log_frequency = log_frequency
          # self.log_dir = os.path.join("logs", str(time.time()))
          # os.makedirs(self.log_dir, exist_ok=True)
          # self.status_writer = open(os.path.join(self.log_dir, "log"), "w")
          self.callbacks = callbacks if callbacks is not None else []
          # Give every callback a handle on the objects it may need to inspect or modify.
          for callback in self.callbacks:
              callback.build(self.model, self.optimizer, self.mutator, self)

      @abstractmethod
      def train_one_epoch(self, epoch):
          """
          Train one epoch.

          Parameters
          ----------
          epoch : int
              Epoch number starting from 0.
          """
          pass

      @abstractmethod
      def validate_one_epoch(self, epoch):
          """
          Validate one epoch.

          Parameters
          ----------
          epoch : int
              Epoch number starting from 0. :meth:`validate` passes ``-1``.
          """
          pass

      def train(self, validate=True):
          """
          Train ``num_epochs``.

          Trigger callbacks at the start and the end of each epoch.

          Parameters
          ----------
          validate : bool
              If ``True``, call :meth:`validate_one_epoch` after every training epoch.
          """
          for epoch in range(self.num_epochs):
              for callback in self.callbacks:
                  callback.on_epoch_begin(epoch)

              # training (logged epochs are 1-based, ``epoch`` itself is 0-based)
              _logger.info("Epoch %d Training", epoch + 1)
              self.train_one_epoch(epoch)

              if validate:
                  # validation
                  _logger.info("Epoch %d Validating", epoch + 1)
                  self.validate_one_epoch(epoch)

              for callback in self.callbacks:
                  callback.on_epoch_end(epoch)

      def validate(self):
          """
          Do one validation pass outside the training loop.

          Calls :meth:`validate_one_epoch` with ``epoch=-1`` to mark it as not tied to a training epoch.
          """
          self.validate_one_epoch(-1)

      def export(self, file):
          """
          Call ``mutator.export()`` and dump the architecture to ``file``.

          Parameters
          ----------
          file : str
              A file path. Expected to be a JSON. Any existing file is overwritten.
              Tensors in the export are serialized via ``TorchTensorEncoder``.
          """
          mutator_export = self.mutator.export()
          with open(file, "w") as f:
              json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)

      def checkpoint(self):
          """
          Return trainer checkpoint.

          Raises
          ------
          NotImplementedError
              Always; subclasses must override this.
          """
          raise NotImplementedError("Not implemented yet")

      # The graph output format has not been finalized yet.
      # def enable_visualization(self):
      #     """
      #     Enable visualization. Write graph and training log to folder ``logs/<timestamp>``.
      #     """
      #     sample = None
      #     for x, _ in self.train_loader:
      #         sample = x.to(self.device)[:2]
      #         break
      #     if sample is None:
      #         _logger.warning("Sample is %s.", sample)
      #     _logger.info("Creating graph json, writing to %s. Visualization enabled.", self.log_dir)
      #     with open(os.path.join(self.log_dir, "graph.json"), "w") as f:
      #         json.dump(self.mutator.graph(sample), f)
      #     self.visualization_enabled = True

      # def _write_graph_status(self):
      #     if hasattr(self, "visualization_enabled") and self.visualization_enabled:
      #         print(json.dumps(self.mutator.status()), file=self.status_writer, flush=True)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能