
layerwise_amalgamation.py 5.5 kB

"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""
import time
import typing

import torch
import torch.nn as nn
import torch.nn.functional as F

from kamal.core import tasks
from kamal.core.engine.engine import Engine
from kamal.core.engine.hooks import FeatureHook
from kamal.utils import move_to_device, set_mode


class AmalBlock(nn.Module):
    """Amalgamation block: encodes concatenated teacher features into a common
    representation, decodes it back for reconstruction, and aligns the student
    feature to the same space."""

    def __init__(self, cs, cts):
        super(AmalBlock, self).__init__()
        self.cs, self.cts = cs, cts
        # 1x1 convolutions: encode teacher features, align the student feature, decode back.
        self.enc = nn.Conv2d(in_channels=sum(self.cts), out_channels=self.cs, kernel_size=1, stride=1, padding=0, bias=True)
        self.fam = nn.Conv2d(in_channels=self.cs, out_channels=self.cs, kernel_size=1, stride=1, padding=0, bias=True)
        self.dec = nn.Conv2d(in_channels=self.cs, out_channels=sum(self.cts), kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, fs, fts):
        # Common representation shared by all teachers.
        rep = self.enc(torch.cat(fts, dim=1))
        # Reconstruct each teacher's feature map from the common representation.
        _fts = self.dec(rep)
        _fts = torch.split(_fts, self.cts, dim=1)
        # Transform the student feature for alignment with the common representation.
        _fs = self.fam(fs)
        return rep, _fs, _fts


class LayerWiseAmalgamator(Engine):

    def setup(
        self,
        student,
        teachers,
        layer_groups: typing.Sequence[typing.Sequence],
        layer_channels: typing.Sequence[typing.Sequence],
        dataloader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        weights=[1., 1., 1.],
        device=None,
    ):
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self._device = device
        self._dataloader = dataloader
        self.model = self.student = student.to(self.device)
        self.teachers = nn.ModuleList(teachers).to(self.device)
        self.optimizer = optimizer
        self._weights = weights
        amal_blocks = []
        # Each group pairs one student layer with the matching teacher layers;
        # the corresponding entry of layer_channels lists their channel counts
        # (student channels first, then each teacher's).
        for group, C in zip(layer_groups, layer_channels):
            hooks = [FeatureHook(layer) for layer in group]
            amal_block = AmalBlock(cs=C[0], cts=C[1:]).to(self.device).train()
            amal_blocks.append((amal_block, hooks, C))
        self._amal_blocks = amal_blocks

    @property
    def device(self):
        return self._device

    def run(self, max_iter, start_iter=0, epoch_length=None):
        # The amalgamation blocks get their own optimizer, mirroring the type and
        # learning rate of the student optimizer, plus a cosine annealing schedule.
        block_params = []
        for block, _, _ in self._amal_blocks:
            block_params.extend(list(block.parameters()))
        if isinstance(self.optimizer, torch.optim.SGD):
            self._amal_optimizer = torch.optim.SGD(block_params, lr=self.optimizer.param_groups[0]['lr'], momentum=0.9, weight_decay=1e-4)
        else:
            self._amal_optimizer = torch.optim.Adam(block_params, lr=self.optimizer.param_groups[0]['lr'], weight_decay=1e-4)
        self._amal_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._amal_optimizer, T_max=max_iter)
        # The student is trained while the teachers stay frozen in eval mode.
        with set_mode(self.student, training=True), \
             set_mode(self.teachers, training=False):
            super(LayerWiseAmalgamator, self).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

    def step_fn(self, engine, batch):
        start_time = time.perf_counter()
        batch = move_to_device(batch, self._device)
        data = batch[0]
        s_out = self.student(data)
        with torch.no_grad():
            t_out = [teacher(data) for teacher in self.teachers]
        loss_amal = 0
        loss_recons = 0
        for amal_block, hooks, C in self._amal_blocks:
            # Hooked features: the first belongs to the student, the rest to the teachers.
            features = [h.feat_out for h in hooks]
            fs, fts = features[0], features[1:]
            rep, _fs, _fts = amal_block(fs, fts)
            # Align the student feature to the common representation (kept fixed via detach).
            loss_amal += F.mse_loss(_fs, rep.detach())
            # Reconstruct each teacher feature from the common representation.
            loss_recons += sum([F.mse_loss(_ft, ft) for (_ft, ft) in zip(_fts, fts)])
        # Knowledge distillation against the concatenated teacher outputs.
        loss_kd = tasks.loss.kldiv(s_out, torch.cat(t_out, dim=1))
        #loss_kd = F.mse_loss( s_out, torch.cat( t_out, dim=1 ) )
        loss_dict = {"loss_kd": self._weights[0] * loss_kd,
                     "loss_amal": self._weights[1] * loss_amal,
                     "loss_recons": self._weights[2] * loss_recons}
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        self._amal_optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self._amal_optimizer.step()
        self._amal_scheduler.step()
        step_time = time.perf_counter() - start_time
        metrics = {loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items()}
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float(self.optimizer.param_groups[0]['lr'])
        })
        return metrics
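
For context, a minimal usage sketch of LayerWiseAmalgamator, assuming two torchvision ResNet-18 teachers whose 100-class outputs are concatenated into a 200-class student. The model choices, layer pairing, channel counts, synthetic data, and the no-argument constructor are illustrative assumptions; this file does not show how the Engine base class is instantiated or which dataset is used.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision.models import resnet18

# Two teachers with disjoint 100-class label spaces; the student covers both.
teacher1 = resnet18(num_classes=100)
teacher2 = resnet18(num_classes=100)
student = resnet18(num_classes=200)

# One matched layer per network (student first, then teachers) and the
# corresponding channel counts, in the order expected by setup().
layer_groups = [[student.layer4, teacher1.layer4, teacher2.layer4]]
layer_channels = [[512, 512, 512]]

# Synthetic data just to make the sketch self-contained; step_fn only reads batch[0].
images = torch.randn(16, 3, 224, 224)
labels = torch.zeros(16, dtype=torch.long)
dataloader = DataLoader(TensorDataset(images, labels), batch_size=4)

optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)

trainer = LayerWiseAmalgamator()  # assumption: constructor arguments depend on the Engine base class
trainer.setup(
    student=student,
    teachers=[teacher1, teacher2],
    layer_groups=layer_groups,
    layer_channels=layer_channels,
    dataloader=dataloader,
    optimizer=optimizer,
    weights=[1., 1., 1.],
)
trainer.run(max_iter=len(dataloader))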

A suite of platforms and tools including a one-stop algorithm development platform, a high-performance distributed deep learning framework, an advanced algorithm model library, a visual model knowledge-amalgamation platform, and a data visualization and analysis platform. It has built distinctive strengths in technologies such as efficient distributed model training, data processing and visual analysis, and model knowledge amalgamation and lightweighting, and currently provides AI application enablement to nearly a thousand organizations and individuals across industry, academia, and research.
