
mutator.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch.mutator import Mutator
from pytorch.mutables import LayerChoice, InputChoice, MutableScope

class StackedLSTMCell(nn.Module):
    def __init__(self, layers, size, bias):
        super().__init__()
        self.lstm_num_layers = layers
        self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
                                           for _ in range(self.lstm_num_layers)])

    def forward(self, inputs, hidden):
        prev_c, prev_h = hidden
        next_c, next_h = [], []
        for i, m in enumerate(self.lstm_modules):
            curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i]))
            next_c.append(curr_c)
            next_h.append(curr_h)
            # current implementation only supports batch size equals 1,
            # but the algorithm does not necessarily have this limitation
            inputs = curr_h[-1].view(1, -1)
        return next_c, next_h
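
For reference, here is a minimal sketch of stepping the stacked cell (not part of the original file; the sizes are illustrative and the zero-state initialization mirrors what ``EnasMutator._initialize`` does below). The state travels as two lists of ``(1, size)`` tensors, one entry per layer, and the batch size is fixed at 1 as the comment above notes; the ``(c, h)`` ordering follows the class's own convention.

cell = StackedLSTMCell(layers=2, size=64, bias=False)   # illustrative sizes
inputs = torch.zeros(1, 64)                             # batch size must be 1
c = [torch.zeros(1, 64) for _ in range(2)]              # one state tensor per layer
h = [torch.zeros(1, 64) for _ in range(2)]
c, h = cell(inputs, (c, h))                             # returns updated per-layer lists
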
class EnasMutator(Mutator):
    """
    A mutator that mutates the graph with RL.

    A hedged usage sketch follows the class definition below.

    Parameters
    ----------
    model : nn.Module
        PyTorch model.
    lstm_size : int
        Controller LSTM hidden units.
    lstm_num_layers : int
        Number of layers for the stacked LSTM.
    tanh_constant : float
        Logits will be equal to ``tanh_constant * tanh(logits)``. ``tanh`` is not applied if this value is ``None``.
    cell_exit_extra_step : bool
        If true, the RL controller will perform an extra step at the exit of each MutableScope, dump the hidden state
        and mark it as the hidden state of this MutableScope. This is to align with the original implementation of the paper.
    skip_target : float
        Target probability that a skip-connect will appear.
    temperature : float
        Temperature constant that divides the logits.
    branch_bias : float
        Manual bias applied to make some operations more likely to be chosen.
        Currently this is implemented with a hardcoded match rule that aligns with the original repo.
        If a mutable has ``reduce`` in its key, all of its op choices
        that contain ``conv`` in their typename initially receive a bias of ``+self.branch_bias``, while the others
        receive a bias of ``-self.branch_bias``.
    entropy_reduction : str
        Can be one of ``sum`` and ``mean``. How the entropy of a multi-input-choice is reduced.
    """

    def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
                 skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction="sum"):
        super().__init__(model)
        self.lstm_size = lstm_size
        self.lstm_num_layers = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.temperature = temperature
        self.cell_exit_extra_step = cell_exit_extra_step
        self.skip_target = skip_target
        self.branch_bias = branch_bias

        self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
        self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
        self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)  # pylint: disable=not-callable
        assert entropy_reduction in ["sum", "mean"], "Entropy reduction must be one of sum and mean."
        self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
        self.bias_dict = nn.ParameterDict()

        self.max_layer_choice = 0
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                if self.max_layer_choice == 0:
                    self.max_layer_choice = len(mutable)
                assert self.max_layer_choice == len(mutable), \
                    "ENAS mutator requires all layer choices to have the same number of candidates."
                # We are judging by keys and module types to add biases to layer choices. Needs refactor.
                if "reduce" in mutable.key:
                    def is_conv(choice):
                        return "conv" in str(type(choice)).lower()
                    bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias  # pylint: disable=not-callable
                                         for choice in mutable])
                    self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False)

        self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
        self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False)

    def sample_search(self):
        self._initialize()
        self._sample(self.mutables)
        return self._choices

    def sample_final(self):
        return self.sample_search()

    def _sample(self, tree):
        mutable = tree.mutable
        if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
            self._choices[mutable.key] = self._sample_layer_choice(mutable)
        elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
            self._choices[mutable.key] = self._sample_input_choice(mutable)
        for child in tree.children:
            self._sample(child)
        if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
            if self.cell_exit_extra_step:
                self._lstm_next_step()
            self._mark_anchor(mutable.key)

    def _initialize(self):
        self._choices = dict()
        self._anchors_hid = dict()
        self._inputs = self.g_emb.data
        self._c = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self._h = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self.sample_log_prob = 0
        self.sample_entropy = 0
        self.sample_skip_penalty = 0

    def _lstm_next_step(self):
        self._c, self._h = self.lstm(self._inputs, (self._c, self._h))

    def _mark_anchor(self, key):
        self._anchors_hid[key] = self._h[-1]

    def _sample_layer_choice(self, mutable):
        self._lstm_next_step()
        logit = self.soft(self._h[-1])
        if self.temperature is not None:
            logit /= self.temperature
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        if mutable.key in self.bias_dict:
            logit += self.bias_dict[mutable.key]
        # Sample one op from the softmax over candidates, then feed its
        # embedding back as the next controller input.
        branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
        log_prob = self.cross_entropy_loss(logit, branch_id)
        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
        self.sample_entropy += self.entropy_reduction(entropy)
        self._inputs = self.embedding(branch_id)
        return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)

    def _sample_input_choice(self, mutable):
        query, anchors = [], []
        for label in mutable.choose_from:
            if label not in self._anchors_hid:
                self._lstm_next_step()
                self._mark_anchor(label)  # empty loop, fill not found
            query.append(self.attn_anchor(self._anchors_hid[label]))
            anchors.append(self._anchors_hid[label])
        query = torch.cat(query, 0)
        query = torch.tanh(query + self.attn_query(self._h[-1]))
        query = self.v_attn(query)
        if self.temperature is not None:
            query /= self.temperature
        if self.tanh_constant is not None:
            query = self.tanh_constant * torch.tanh(query)

        if mutable.n_chosen is None:
            # Any number of inputs may be chosen: make an independent
            # skip/no-skip decision per candidate, penalize deviation from the
            # target skip probability, and average the chosen anchors.
            logit = torch.cat([-query, query], 1)  # pylint: disable=invalid-unary-operand-type
            skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip_prob = torch.sigmoid(logit)
            kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
            log_prob = self.cross_entropy_loss(logit, skip)
            self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0)
        else:
            # Exactly one input: sample from a softmax over the attention scores.
            assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
            logit = query.view(1, -1)
            index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1)
            log_prob = self.cross_entropy_loss(logit, index)
            self._inputs = anchors[index.item()]

        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
        self.sample_entropy += self.entropy_reduction(entropy)
        return skip.bool()
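
To show how the controller is driven end to end, here is a hedged usage sketch, not part of the original file. ``ToyCell``, the channel sizes, the learning rate, and the reward value are all hypothetical; it assumes the base ``Mutator`` exposes a ``reset()`` that caches a fresh ``sample_search()`` result, as in NNI's PyTorch NAS API, so each call samples a new child architecture and repopulates ``sample_log_prob``, ``sample_entropy``, and ``sample_skip_penalty`` for a REINFORCE-style controller update.

class ToyCell(nn.Module):  # hypothetical one-choice search space
    def __init__(self):
        super().__init__()
        self.op = LayerChoice([nn.Conv2d(16, 16, 3, padding=1),
                               nn.MaxPool2d(3, stride=1, padding=1)])

    def forward(self, x):
        return self.op(x)

model = ToyCell()
mutator = EnasMutator(model)
optimizer = torch.optim.Adam(mutator.parameters(), lr=3.5e-4)

mutator.reset()  # sample one architecture (assumes the NNI Mutator API)
reward = 0.5     # e.g. validation metric of the sampled child model
loss = -mutator.sample_log_prob * reward + mutator.sample_skip_penalty
optimizer.zero_grad()
loss.backward()
optimizer.step()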
