You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

pdartsmutator.py 8.8 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. import copy
  2. import numpy as np
  3. import torch
  4. import logging
  5. from collections import OrderedDict
  6. from torch import nn
  7. from pytorch.darts.dartsmutator import DartsMutator
  8. from pytorch.mutables import LayerChoice, InputChoice
  # Module-level logger; PdartsMutator.export reports skip-connection restriction through it.
  logger = logging.getLogger(name=__name__)
  class PdartsMutator(DartsMutator):
      """
      Mutator used together with PdartsTrainer: it computes the operation weights and
      drops the weakest operations at the end of each P-DARTS epoch (stage).

      Operations whose switch is ``False`` are physically removed from every
      ``LayerChoice`` in the model, so ``self.choices[key]`` holds one weight per
      *remaining* operation plus one trailing weight for the implicit zero operation.
      """

      # Index of the "skipconnect" candidate inside a switch list (the zero operation
      # is not part of the switches, so skipconnect sits at position 2).
      _SKIP_INDEX = 2

      def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
          """
          Parameters
          ----------
          model : nn.Module
              Model containing the LayerChoice / InputChoice mutables.
          pdarts_epoch_index : int
              Index of the current P-DARTS epoch; selects the entry of
              ``pdarts_num_to_drop`` used by :meth:`drop_paths`.
          pdarts_num_to_drop : sequence of int
              Number of operations to drop per LayerChoice at the end of each epoch.
          switches : dict or None
              Maps a LayerChoice key to a list of booleans, one per original candidate
              operation; ``False`` marks an operation dropped in an earlier epoch. Keys
              that are missing keep every operation. When a dict is given it is used
              (and updated) in place; ``None`` starts from a fresh, empty dict.
          """
          self.pdarts_epoch_index = pdarts_epoch_index
          self.pdarts_num_to_drop = pdarts_num_to_drop
          # save the last two switches and choices for restrict skip
          self.last_two_switches = None
          self.last_two_choices = None
          # A fresh dict per instance: the former ``switches={}`` default object was
          # shared between all instances and mutated below.
          self.switches = {} if switches is None else switches
          super(PdartsMutator, self).__init__(model)
          # key -> original indexes of the operations still physically present in the
          # LayerChoice, in order. Captured once so that export() stays aligned with
          # the pruned modules even if self.switches is modified later.
          self._active_indexes = {}

          # This loop goes through mutables with distinct keys; it resizes the weights
          # of each LayerChoice to the operations that survived previous epochs.
          for mutable in self.mutables:
              if isinstance(mutable, LayerChoice):
                  switches = self.switches.get(mutable.key, [True for _ in range(len(mutable))])
                  operations_count = int(np.sum(switches))
                  # +1 for the zero operation: it is not a candidate in the network's
                  # LayerChoice (nor in the switches), but it still has a weight.
                  self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(operations_count + 1))
                  self.switches[mutable.key] = switches
                  self._active_indexes[mutable.key] = [i for i, on in enumerate(switches) if on]

          # Update LayerChoice instances in the model: physically remove dropped
          # operations (several modules may share one key).
          for module in self.model.modules():
              if isinstance(module, LayerChoice):
                  switches = self.switches.get(module.key)
                  choices = self.choices[module.key]
                  # choices has one extra entry for the zero operation, so the module is
                  # pruned once len(module) == len(choices) - 1. (A strict ">" here skipped
                  # pruning when exactly one operation had been dropped.)
                  if len(module) >= len(choices):
                      # From last to first, so earlier indexes stay valid after a deletion.
                      for index in reversed(range(len(switches))):
                          if not switches[index]:
                              del module[index]
                  assert len(module) < len(choices), "Failed to remove dropped choices."

      def export(self, last=False, switches=None):
          """
          Export the chosen operation of each LayerChoice, mapped back onto the
          original candidate list (dropped operations are reported as ``False``).

          Cannot rely on ``super().export()`` because P-DARTS has deleted some of the
          choices, so the sampled results are shorter than the original lists.

          Parameters
          ----------
          last : bool
              ``True`` in the last P-DARTS epoch: the normal cell is then restricted to
              at most two skip-connections.
          switches : dict or None
              Switches of the final selection, updated in place by the restriction
              (:meth:`delete_min_sk` expects one enabled operation per key). Required
              when ``last`` is ``True``.

          Returns
          -------
          dict
              The result of ``sample_final()``, with each LayerChoice entry replaced by
              a bool tensor as long as the original candidate list.

          Raises
          ------
          ValueError
              If ``last`` is ``True`` but ``switches`` is ``None``.
          """
          if last:
              if switches is None:
                  raise ValueError("switches must be provided when last is True.")
              # Restrict up to 2 skipconnect (normal cell only).
              name = "normal"
              max_num = 2
              skip_num = self.check_skip_num(name, switches)
              logger.info("Initially, the number of skipconnect is %s.", skip_num)
              while skip_num > max_num:
                  logger.info("Restricting %s skipconnect to %s.", skip_num, max_num)
                  logger.info("Original normal_switch is %s.", switches)
                  # Zeroes the weakest skip probability and clears its switch.
                  switches = self.delete_min_sk(name, switches)
                  logger.info("Restricted normal_switch is %s.", switches)
                  skip_num = self.check_skip_num(name, switches)

          results = super().sample_final()
          for mutable in self.mutables:
              if isinstance(mutable, LayerChoice):
                  trained_result = results[mutable.key]
                  full_result = torch.zeros(len(self.switches[mutable.key]), dtype=torch.bool)
                  # Dropped operations stay False; each present operation is written back
                  # to its original index using the mapping captured at construction time.
                  for position, index in enumerate(self._active_indexes[mutable.key]):
                      full_result[index] = bool(trained_result[position])
                  results[mutable.key] = full_result
          return results

      def drop_paths(self):
          """
          Prepare the switches for the next P-DARTS epoch.

          This method is called when a P-DARTS epoch is finished. For every key, the
          ``pdarts_num_to_drop[pdarts_epoch_index]`` remaining operations with the
          lowest weights are switched off; the zero operation (last weight) is never
          considered. Candidate operations with a ``False`` switch are dropped in the
          next epoch.

          Returns
          -------
          dict
              A deep copy of ``self.switches`` with the dropped operations set to False.
          """
          all_switches = copy.deepcopy(self.switches)
          for key, switches in all_switches.items():
              # Original indexes of the operations that currently have a weight.
              active = [j for j, on in enumerate(switches) if on]
              # Weights of the remaining operations, without the trailing zero-op weight.
              weights = self.choices[key].data.cpu().numpy()[:-1]
              num_to_drop = self.pdarts_num_to_drop[self.pdarts_epoch_index]
              for idx in np.argsort(weights)[:num_to_drop]:
                  switches[active[idx]] = False
          return all_switches

      def check_skip_num(self, name, switches):
          """Count the keys containing ``name`` whose skipconnect switch is still on."""
          return sum(1 for key in switches if name in key and switches[key][self._SKIP_INDEX])

      def delete_min_sk(self, name, switches):
          """
          Switch off the weakest remaining skipconnect among the keys containing ``name``.

          Candidates are the keys whose skipconnect switch is still on in ``switches``.
          Their skipconnect *probabilities* (softmax over the key's weights) are
          compared; raw weights of different keys are not comparable. On the selected
          key the skipconnect weight is set to ``-inf`` (probability 0, so
          ``sample_final`` cannot pick it), and its switch is cleared in both
          ``switches`` and ``self.switches``.

          Returns
          -------
          dict
              ``switches``, updated in place (unchanged if no candidate exists).
          """
          def _get_sk_idx(key):
              # Position of skipconnect inside self.choices[key], or -1 if it is already
              # switched off. self.choices[key] only has weights for operations enabled in
              # self.switches[key] (the current epoch, which has more operations on than
              # ``switches``), so the position is the count of enabled ops before it.
              if not switches[key][self._SKIP_INDEX]:
                  return -1
              return sum(1 for on in self.switches[key][:self._SKIP_INDEX] if on)

          candidates = []
          # Parameters are leaf tensors that require grad: reading them for numpy-style
          # comparison and writing them in place must both happen under no_grad.
          with torch.no_grad():
              for key in switches:
                  if name not in key:
                      continue
                  idx = _get_sk_idx(key)
                  if idx == -1:
                      continue
                  prob = torch.softmax(self.choices[key], dim=-1)[idx].item()
                  candidates.append((prob, key, idx))
              if not candidates:
                  return switches
              # min() returns the first of equal probabilities (same tie-break as argmin).
              _, min_key, min_idx = min(candidates, key=lambda item: item[0])
              self.choices[min_key][min_idx] = float("-inf")
          # self.switches holds the current epoch's switches and ``switches`` the final
          # selection; clear skipconnect in both.
          self.switches[min_key][self._SKIP_INDEX] = False
          switches[min_key][self._SKIP_INDEX] = False
          return switches

      def _generate_search_space(self):
          """
          Generate a flattened search space from the mutables.

          Returns
          -------
          OrderedDict
              ``{"op_list": {"_type": "layer_choice", "_value": [<op names>..., "none"]},
              "search_space": {"reduction_cell": {...}, "normal_cell": {...}}}``.
              Every LayerChoice key maps to the string ``"op_list"`` (the operation
              names are taken from the first LayerChoice seen), and every InputChoice
              key maps to ``{"_type": "input_choice", "_value": {"candidates": ...,
              "n_chosen": ...}}``. A key goes to ``"normal_cell"`` when it contains
              ``"normal"``, otherwise to ``"reduction_cell"``. Collection stops once 36
              keys have been gathered.

          Raises
          ------
          TypeError
              If a mutable is neither a LayerChoice nor an InputChoice.
          """
          res = OrderedDict()
          res["op_list"] = OrderedDict()
          res["search_space"] = {"reduction_cell": OrderedDict(), "normal_cell": OrderedDict()}
          cells = res["search_space"]
          seen = set()
          for mutable in self.mutables:
              # For now we only generate a flattened search space.
              # NOTE(review): 36 presumably = 2 cells x (14 LayerChoices + 4 InputChoices).
              if len(cells["reduction_cell"]) + len(cells["normal_cell"]) >= 36:
                  break
              if isinstance(mutable, LayerChoice):
                  key = mutable.key
                  if key not in seen:
                      if not res["op_list"]:
                          res["op_list"] = {"_type": "layer_choice", "_value": mutable.names + ["none"]}
                      node_type = "normal_cell" if "normal" in key else "reduction_cell"
                      cells[node_type][key] = "op_list"
                      seen.add(key)
              elif isinstance(mutable, InputChoice):
                  key = mutable.key
                  if key not in seen:
                      node_type = "normal_cell" if "normal" in key else "reduction_cell"
                      cells[node_type][key] = {"_type": "input_choice",
                                               "_value": {"candidates": mutable.choose_from,
                                                          "n_chosen": mutable.n_chosen}}
                      seen.add(key)
              else:
                  raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
          return res

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。