You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

trainer.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. """
  2. Copyright 2020 Tianshu AI Platform. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. =============================================================
  13. """
  14. import torch
  15. import torch.nn as nn
  16. from kamal.core.engine.engine import Engine, Event, DefaultEvents, State
  17. from kamal.core import tasks
  18. from kamal.utils import set_mode, move_to_device, get_logger, split_batch
  19. from typing import Callable, Mapping, Any, Sequence
  20. import time
  21. import weakref
  class BasicTrainer(Engine):
      """Supervised trainer that fits a model on a dataloader using a task's loss.

      Usage: call :meth:`setup` to attach the model, task, data and optimizer,
      then :meth:`run` to train.
      """

      def __init__(self, logger=None, tb_writer=None):
          # No trainer-specific state here; everything is handled by Engine.
          super().__init__(logger=logger, tb_writer=tb_writer)

      def setup(self,
                model: torch.nn.Module,
                task: tasks.Task,
                dataloader: torch.utils.data.DataLoader,
                optimizer: torch.optim.Optimizer,
                device: torch.device=None):
          """Attach the training components and return ``self`` for chaining.

          Args:
              model: network to be trained.
              task: a single task, or a sequence of tasks which is wrapped in
                  ``tasks.TaskCompose``.
              dataloader: source of training batches.
              optimizer: optimizer applied after each backward pass.
              device: target device; when None, CUDA is used if available,
                  otherwise CPU.

          Returns:
              This trainer instance.
          """
          self.device = (
              device if device is not None
              else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
          )
          self.task = tasks.TaskCompose(task) if isinstance(task, Sequence) else task
          self.model = model
          self.dataloader = dataloader
          self.optimizer = optimizer
          return self

      def run(self, max_iter, start_iter=0, epoch_length=None):
          """Train ``self.model`` by running :meth:`step_fn` over ``self.dataloader``.

          The model is moved to ``self.device`` and ``Engine.run`` is executed
          inside a ``set_mode(self.model, training=True)`` context.
          """
          self.model.to(self.device)
          with set_mode(self.model, training=True):
              super().run(
                  self.step_fn,
                  self.dataloader,
                  start_iter=start_iter,
                  max_iter=max_iter,
                  epoch_length=epoch_length,
              )

      def step_fn(self, engine, batch):
          """Perform one optimisation step on ``batch`` and report scalar metrics.

          Args:
              engine: the engine invoking this callback (not used here).
              batch: a raw dataloader batch; it is moved to ``self.device`` and
                  split into inputs and targets with ``split_batch``.

          Returns:
              dict mapping each loss name returned by ``task.get_loss`` to its
              Python float value, plus ``'total_loss'`` (sum of all losses),
              ``'step_time'`` (seconds, via ``time.perf_counter``) and ``'lr'``
              (learning rate of the optimizer's first param group).
          """
          tic = time.perf_counter()
          inputs, targets = split_batch(move_to_device(batch, self.device))
          losses = self.task.get_loss(self.model(inputs), targets)
          total = sum(losses.values())

          self.optimizer.zero_grad()
          total.backward()
          self.optimizer.step()
          # The timer stops here; the .item() calls below happen afterwards.
          elapsed = time.perf_counter() - tic

          metrics = {}
          for name, value in losses.items():
              metrics[name] = value.item()
          metrics['total_loss'] = total.item()
          metrics['step_time'] = elapsed
          metrics['lr'] = float(self.optimizer.param_groups[0]['lr'])
          return metrics
  class KDTrainer(BasicTrainer):
      """Knowledge-distillation trainer: fits a student to its teacher(s)' outputs.

      The student is optimised with ``task.get_loss(student_outputs, teacher_outputs)``.
      Teachers are kept in eval mode during :meth:`run` and are evaluated
      without autograd, so they never receive gradients.
      """

      def setup(self,
                student: torch.nn.Module,
                teacher: torch.nn.Module,
                task: tasks.Task,
                dataloader: torch.utils.data.DataLoader,
                optimizer: torch.optim.Optimizer,
                device: torch.device=None):
          """Attach student, teacher(s) and training components; return ``self``.

          Args:
              student: network being trained (stored as both ``self.model`` and
                  ``self.student``).
              teacher: a module, or a list/tuple of modules. A one-element
                  list/tuple is unwrapped; longer ones are wrapped in
                  ``nn.ModuleList``.
              task: task (or sequence of tasks) providing the distillation loss.
              dataloader: source of training batches.
              optimizer: optimizer for the student's parameters.
              device: target device; defaults as in ``BasicTrainer.setup``.

          Returns:
              This trainer instance.
          """
          super().setup(
              model=student, task=task, dataloader=dataloader, optimizer=optimizer, device=device)
          if isinstance(teacher, (list, tuple)):
              # Several teachers are grouped into one module so run() can move
              # and mode-switch them together.
              teacher = teacher[0] if len(teacher) == 1 else nn.ModuleList(teacher)
          self.student = self.model
          self.teacher = teacher
          return self

      def run(self, max_iter, start_iter=0, epoch_length=None):
          """Train the student with :meth:`step_fn` while the teacher stays in eval mode."""
          self.student.to(self.device)
          self.teacher.to(self.device)
          with set_mode(self.student, training=True), \
                  set_mode(self.teacher, training=False):
              # Deliberately skip BasicTrainer.run (which only handles
              # self.model) and call the next class in the MRO (Engine) directly,
              # because both student and teacher modes are managed here.
              super(BasicTrainer, self).run(
                  self.step_fn, self.dataloader,
                  start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

      def step_fn(self, engine, batch):
          """Perform one distillation step on ``batch`` and report scalar metrics.

          Args:
              engine: the engine invoking this callback (not used here).
              batch: a raw dataloader batch; it is moved to ``self.device`` and
                  split with ``split_batch``. Only the inputs are used.

          Returns:
              dict mapping each loss name from ``task.get_loss`` to its float
              value, plus ``'total_loss'``, ``'step_time'`` (seconds) and ``'lr'``
              (learning rate of the first param group).
          """
          start_time = time.perf_counter()
          batch = move_to_device(batch, self.device)
          # Ground-truth targets are not used by the distillation loss.
          inputs, _ = split_batch(batch)
          outputs = self.model(inputs)
          # The teacher is frozen: evaluate it without autograd so no graph is
          # built (saving memory/compute) and loss.backward() cannot accumulate
          # gradients into teacher parameters.
          with torch.no_grad():
              if isinstance(self.teacher, nn.ModuleList):
                  soft_targets = [t(inputs) for t in self.teacher]
              else:
                  soft_targets = self.teacher(inputs)
          loss_dict = self.task.get_loss(outputs, soft_targets)
          loss = sum(loss_dict.values())
          self.optimizer.zero_grad()
          loss.backward()
          self.optimizer.step()
          step_time = time.perf_counter() - start_time
          metrics = {name: value.item() for name, value in loss_dict.items()}
          metrics.update({
              'total_loss': loss.item(),
              'step_time': step_time,
              'lr': float(self.optimizer.param_groups[0]['lr']),
          })
          return metrics

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。

Contributors (1)