You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

task_branching.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. """
  2. Copyright 2020 Tianshu AI Platform. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. =============================================================
  13. """
  14. import torch
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. from kamal.core.engine.engine import Engine
  18. from kamal.core.engine.hooks import FeatureHook
  19. from kamal.core import tasks
  20. from kamal.utils import move_to_device, set_mode
  21. from kamal.core.hub import meta
  22. from kamal import vision
  23. import kamal
  24. from kamal.utils import set_mode
  25. import typing
  26. import time
  27. from copy import deepcopy
  28. import random
  29. import numpy as np
  30. from collections import defaultdict
  31. import numbers
  32. class BranchySegNet(nn.Module):
  33. def __init__(self, out_channels, segnet_fn=vision.models.segmentation.segnet_vgg16_bn):
  34. super(BranchySegNet, self).__init__()
  35. channels=[512, 512, 256, 128, 64]
  36. self.register_buffer( 'branch_indices', torch.zeros((len(out_channels),)) )
  37. self.student_b_decoders_list = nn.ModuleList()
  38. self.student_adaptors_list = nn.ModuleList()
  39. ses = []
  40. for i in range(5):
  41. se = int(channels[i]/4)
  42. ses.append(16 if se < 16 else se)
  43. for oc in out_channels:
  44. segnet = self.get_segnet( oc, segnet_fn )
  45. decoders = nn.ModuleList(deepcopy(list(segnet.children())[5:]))
  46. adaptors = nn.ModuleList()
  47. for i in range(5):
  48. adaptor = nn.Sequential(
  49. nn.Conv2d(channels[i], ses[i], kernel_size=1, bias=False),
  50. nn.ReLU(),
  51. nn.Conv2d(ses[i], channels[i], kernel_size=1, bias=False),
  52. nn.Sigmoid()
  53. )
  54. adaptors.append(adaptor)
  55. self.student_b_decoders_list.append(decoders)
  56. self.student_adaptors_list.append(adaptors)
  57. self.student_encoders = nn.ModuleList(deepcopy(list(segnet.children())[0:5]))
  58. self.student_decoders = nn.ModuleList(deepcopy(list(segnet.children())[5:]))
  59. def set_branch(self, branch_indices):
  60. assert len(branch_indices)==len(self.student_b_decoders_list)
  61. self.branch_indices = torch.from_numpy( np.array( branch_indices ) ).to(self.branch_indices.device)
  62. def get_segnet(self, oc, segnet_fn):
  63. return segnet_fn( num_classes=oc, pretrained_backbone=True )
  64. def forward(self, inputs):
  65. output_list = []
  66. down1, indices_1, unpool_shape1 = self.student_encoders[0](inputs)
  67. down2, indices_2, unpool_shape2 = self.student_encoders[1](down1)
  68. down3, indices_3, unpool_shape3 = self.student_encoders[2](down2)
  69. down4, indices_4, unpool_shape4 = self.student_encoders[3](down3)
  70. down5, indices_5, unpool_shape5 = self.student_encoders[4](down4)
  71. up5 = self.student_decoders[0](down5, indices_5, unpool_shape5)
  72. up4 = self.student_decoders[1](up5, indices_4, unpool_shape4)
  73. up3 = self.student_decoders[2](up4, indices_3, unpool_shape3)
  74. up2 = self.student_decoders[3](up3, indices_2, unpool_shape2)
  75. up1 = self.student_decoders[4](up2, indices_1, unpool_shape1)
  76. decoder_features = [down5, up5, up4, up3, up2]
  77. decoder_indices = [indices_5, indices_4, indices_3, indices_2, indices_1]
  78. decoder_shapes = [unpool_shape5, unpool_shape4, unpool_shape3, unpool_shape2, unpool_shape1]
  79. # Mimic teachers.
  80. for i in range(len(self.branch_indices)):
  81. out_idx = self.branch_indices[i]
  82. output = decoder_features[out_idx]
  83. output = output * self.student_adaptors_list[i][out_idx](F.avg_pool2d(output, output.shape[2:3]))
  84. for j in range(out_idx, 5):
  85. output = self.student_b_decoders_list[i][j](
  86. output,
  87. decoder_indices[j],
  88. decoder_shapes[j]
  89. )
  90. output_list.append( output )
  91. return output_list
  92. class JointSegNet(nn.Module):
  93. """The online student model to learn from any number of single teacher with 'SegNet' structure.
  94. **Parameters:**
  95. - **teachers** (list of 'Module' object): Teachers with 'SegNet' structure to learn from.
  96. - **indices** (list of int): Where to branch out for each task.
  97. - **phase** (string): Should be 'block' or 'finetune'. Useful only in training mode.
  98. - **channels** (list of int, optional): Parameter to build adaptor modules, corresponding to that of 'SegNet'.
  99. """
  100. def __init__(self, teachers, student=None, channels=[512, 512, 256, 128, 64]):
  101. super(JointSegNet, self).__init__()
  102. self.register_buffer( 'branch_indices', torch.zeros((2,)) )
  103. if student is None:
  104. student = teachers[0]
  105. self.student_encoders = nn.ModuleList(deepcopy(list(teachers[0].children())[0:5]))
  106. self.student_decoders = nn.ModuleList(deepcopy(list(teachers[0].children())[5:]))
  107. self.student_b_decoders_list = nn.ModuleList()
  108. self.student_adaptors_list = nn.ModuleList()
  109. ses = []
  110. for i in range(5):
  111. se = int(channels[i]/4)
  112. ses.append(16 if se < 16 else se)
  113. for teacher in teachers:
  114. decoders = nn.ModuleList(deepcopy(list(teacher.children())[5:]))
  115. adaptors = nn.ModuleList()
  116. for i in range(5):
  117. adaptor = nn.Sequential(
  118. nn.Conv2d(channels[i], ses[i], kernel_size=1, bias=False),
  119. nn.ReLU(),
  120. nn.Conv2d(ses[i], channels[i], kernel_size=1, bias=False),
  121. nn.Sigmoid()
  122. )
  123. adaptors.append(adaptor)
  124. self.student_b_decoders_list.append(decoders)
  125. self.student_adaptors_list.append(adaptors)
  126. def set_branch(self, branch_indices):
  127. assert len(branch_indices)==len(self.student_b_decoders_list)
  128. self.branch_indices = torch.from_numpy( np.array( branch_indices ) ).to(self.branch_indices.device)
  129. def forward(self, inputs):
  130. output_list = []
  131. down1, indices_1, unpool_shape1 = self.student_encoders[0](inputs)
  132. down2, indices_2, unpool_shape2 = self.student_encoders[1](down1)
  133. down3, indices_3, unpool_shape3 = self.student_encoders[2](down2)
  134. down4, indices_4, unpool_shape4 = self.student_encoders[3](down3)
  135. down5, indices_5, unpool_shape5 = self.student_encoders[4](down4)
  136. up5 = self.student_decoders[0](down5, indices_5, unpool_shape5)
  137. up4 = self.student_decoders[1](up5, indices_4, unpool_shape4)
  138. up3 = self.student_decoders[2](up4, indices_3, unpool_shape3)
  139. up2 = self.student_decoders[3](up3, indices_2, unpool_shape2)
  140. up1 = self.student_decoders[4](up2, indices_1, unpool_shape1)
  141. decoder_features = [down5, up5, up4, up3, up2]
  142. decoder_indices = [indices_5, indices_4, indices_3, indices_2, indices_1]
  143. decoder_shapes = [unpool_shape5, unpool_shape4, unpool_shape3, unpool_shape2, unpool_shape1]
  144. # Mimic teachers.
  145. for i in range(len(self.branch_indices)):
  146. out_idx = self.branch_indices[i]
  147. output = decoder_features[out_idx]
  148. output = output * self.student_adaptors_list[i][out_idx](F.avg_pool2d(output, output.shape[2:3]))
  149. for j in range(out_idx, 5):
  150. output = self.student_b_decoders_list[i][j](
  151. output,
  152. decoder_indices[j],
  153. decoder_shapes[j]
  154. )
  155. output_list.append( output )
  156. return output_list
  157. class TaskBranchingAmalgamator(Engine):
  158. def setup(
  159. self,
  160. joint_student: JointSegNet,
  161. teachers,
  162. tasks,
  163. dataloader: torch.utils.data.DataLoader,
  164. optimizer: torch.optim.Optimizer,
  165. device=None,
  166. ):
  167. if device is None:
  168. device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
  169. self._device = device
  170. self._dataloader = dataloader
  171. self.student = self.model = joint_student.to(self._device)
  172. self.teachers = nn.ModuleList(teachers).to(self._device)
  173. self.tasks = tasks
  174. self.optimizer = optimizer
  175. self.is_finetuning=False
  176. @property
  177. def device(self):
  178. return self._device
  179. def run(self, max_iter, start_iter=0, epoch_length=None, stage_callback=None ):
  180. # Branching
  181. with set_mode(self.student, training=True), \
  182. set_mode(self.teachers, training=False):
  183. super( TaskBranchingAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter//2, epoch_length=epoch_length)
  184. branch = self.find_the_best_branch( self._dataloader )
  185. self.logger.info("[Task Branching] the best branch indices: %s"%(branch))
  186. if stage_callback is not None:
  187. stage_callback()
  188. # Finetuning
  189. self.is_finetuning = True
  190. with set_mode(self.student, training=True), \
  191. set_mode(self.teachers, training=False):
  192. super( TaskBranchingAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=max_iter-max_iter//2, max_iter=max_iter, epoch_length=epoch_length)
  193. def find_the_best_branch(self, dataloader):
  194. with set_mode(self.student, training=False), \
  195. set_mode(self.teachers, training=False), \
  196. torch.no_grad():
  197. n_blocks = len(self.student.student_decoders)
  198. branch_loss = { task: [0 for _ in range(n_blocks)] for task in self.tasks }
  199. for batch in dataloader:
  200. batch = move_to_device(batch, self.device)
  201. data = batch if isinstance(batch, torch.Tensor) else batch[0]
  202. for candidate_branch in range( n_blocks ):
  203. self.student.set_branch( [candidate_branch for _ in range(len(self.teachers))] )
  204. s_out_list = self.student( data )
  205. t_out_list = [ teacher( data ) for teacher in self.teachers ]
  206. for task, s_out, t_out in zip( self.tasks, s_out_list, t_out_list ):
  207. task_loss = task.get_loss( s_out, t_out )
  208. branch_loss[task][candidate_branch] += sum(task_loss.values())
  209. best_brach = []
  210. for task in self.tasks:
  211. best_brach.append( int(np.argmin( branch_loss[task] )) )
  212. self.student.set_branch(best_brach)
  213. return best_brach
  214. @property
  215. def device(self):
  216. return self._device
  217. def step_fn(self, engine, batch):
  218. start_time = time.perf_counter()
  219. batch = move_to_device(batch, self._device)
  220. data = batch[0]
  221. #data = batch if isinstance(batch, torch.Tensor) else batch[0]
  222. data, None
  223. n_blocks = len(self.student.student_decoders)
  224. if not self.is_finetuning:
  225. rand_branch_indices = [ random.randint(0, n_blocks-1) for _ in range(len(self.teachers)) ]
  226. self.student.set_branch(rand_branch_indices)
  227. s_out_list = self.student( data )
  228. with torch.no_grad():
  229. t_out_list = [ teacher( data ) for teacher in self.teachers ]
  230. loss_dict = {}
  231. for task, s_out, t_out in zip( self.tasks, s_out_list, t_out_list ):
  232. task_loss = task.get_loss( s_out, t_out )
  233. loss_dict.update( task_loss )
  234. loss = sum(loss_dict.values())
  235. self.optimizer.zero_grad()
  236. loss.backward()
  237. self.optimizer.step()
  238. step_time = time.perf_counter() - start_time
  239. metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() }
  240. metrics.update({
  241. 'total_loss': loss.item(),
  242. 'step_time': step_time,
  243. 'lr': float( self.optimizer.param_groups[0]['lr'] ),
  244. 'branch': self.student.branch_indices.cpu().numpy().tolist()
  245. })
  246. return metrics

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已为产学研等各领域近千家单位及个人提供AI应用赋能。

Contributors (1)