You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pcdartsmutator.py 7.2 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import logging
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from collections import OrderedDict
  8. from pytorch.mutator import Mutator
  9. from pytorch.mutables import LayerChoice, InputChoice
  10. _logger = logging.getLogger(__name__)
  11. class PCdartsMutator(Mutator):
  12. """
  13. Connects the model in a PC-DARTS (differentiable) way.
  14. Two connections are automatically inserted for each LayerChoice and InputChoice, when these connections are selected by softmax function.
  15. Ops on the LayerChoice are selected by max top-k probabilities. But channels in the all candicate predecessors on the InputChoice are weighted sum
  16. There is no op on this LayerChoice (namely a ``ZeroOp``), in which case, every element in the exported choice list is ``false``
  17. (not chosen).
  18. All input choice will be fully connected in the search phase. On exporting, the input choice will choose inputs based
  19. on keys in ``choose_from``. If the keys were to be keys of LayerChoices, the top logit of the corresponding LayerChoice
  20. will join the competition of input choice to compete against other logits. Otherwise, the logit will be assumed 0.
  21. It's possible to cut branches by setting parameter ``choices`` in a particular position to ``-inf``. After softmax, the
  22. value would be 0. Framework will ignore 0 values and not connect. Note that the gradient on the ``-inf`` location will
  23. be 0. Since manipulations with ``-inf`` will be ``nan``, you need to handle the gradient update phase carefully.
  24. Attributes
  25. ----------
  26. choices: ParameterDict
  27. dict that maps keys of LayerChoices to weighted-connection float tensors.
  28. """
  29. def __init__(self, model):
  30. super().__init__(model)
  31. self.choices = nn.ParameterDict()
  32. for mutable in self.mutables:
  33. if isinstance(mutable, LayerChoice):
  34. self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1))
  35. if isinstance(mutable, InputChoice):
  36. self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.n_candidates))
  37. def device(self):
  38. for v in self.choices.values():
  39. return v.device
  40. def sample_search(self):
  41. result = dict()
  42. for mutable in self.mutables:
  43. if isinstance(mutable, LayerChoice):
  44. result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1]
  45. elif isinstance(mutable, InputChoice):
  46. result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)
  47. return result
  48. def sample_final(self):
  49. result = dict()
  50. edges_max = dict()
  51. choices = dict()
  52. for mutable in self.mutables:
  53. if isinstance(mutable, LayerChoice):
  54. # multiply the normalized coefficients together to select top-1 op in each LayerChoice
  55. predecessor_idx = int(mutable.key[-1])
  56. inputchoice_key = mutable.key[:-2] + "switch"
  57. choices[mutable.key] = self.choices[mutable.key] * self.choices[inputchoice_key][predecessor_idx]
  58. for mutable in self.mutables:
  59. if isinstance(mutable, LayerChoice):
  60. # select non-none top-1 op
  61. max_val, index = torch.max(F.softmax(choices[mutable.key], dim=-1)[:-1], 0)
  62. edges_max[mutable.key] = max_val
  63. result[mutable.key] = F.one_hot(index, num_classes=len(mutable)).view(-1).bool()
  64. for mutable in self.mutables:
  65. if isinstance(mutable, InputChoice):
  66. if mutable.n_chosen is not None:
  67. weights = []
  68. for src_key in mutable.choose_from:
  69. if src_key not in edges_max:
  70. _logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key)
  71. weights.append(edges_max.get(src_key, 0.))
  72. weights = torch.tensor(weights) # pylint: disable=not-callable
  73. # select top-2 strongest predecessor
  74. _, topk_edge_indices = torch.topk(weights, mutable.n_chosen)
  75. selected_multihot = []
  76. for i, src_key in enumerate(mutable.choose_from):
  77. if i not in topk_edge_indices and src_key in result:
  78. # If an edge is never selected, there is no need to calculate any op on this edge.
  79. # This is to eliminate redundant calculation.
  80. result[src_key] = torch.zeros_like(result[src_key])
  81. selected_multihot.append(i in topk_edge_indices)
  82. result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable
  83. else:
  84. result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable
  85. return result
  86. def _generate_search_space(self):
  87. """
  88. Generate search space from mutables.
  89. Here is the search space format:
  90. ::
  91. { key_name: {"_type": "layer_choice",
  92. "_value": ["conv1", "conv2"]} }
  93. { key_name: {"_type": "input_choice",
  94. "_value": {"candidates": ["in1", "in2"],
  95. "n_chosen": 1}} }
  96. Returns
  97. -------
  98. dict
  99. the generated search space
  100. """
  101. res = OrderedDict()
  102. res["op_list"] = OrderedDict()
  103. res["search_space"] = {"reduction_cell": OrderedDict(), "normal_cell": OrderedDict()}
  104. keys = []
  105. for mutable in self.mutables:
  106. # for now we only generate flattened search space
  107. if (len(res["search_space"]["reduction_cell"]) + len(res["search_space"]["normal_cell"])) >= 36:
  108. break
  109. if isinstance(mutable, LayerChoice):
  110. key = mutable.key
  111. if key not in keys:
  112. val = mutable.names
  113. if not res["op_list"]:
  114. res["op_list"] = {"_type": "layer_choice", "_value": val + ["none"]}
  115. node_type = "normal_cell" if "normal" in key else "reduction_cell"
  116. res["search_space"][node_type][key] = "op_list"
  117. keys.append(key)
  118. elif isinstance(mutable, InputChoice):
  119. key = mutable.key
  120. if key not in keys:
  121. node_type = "normal_cell" if "normal" in key else "reduction_cell"
  122. res["search_space"][node_type][key] = {"_type": "input_choice",
  123. "_value": {"candidates": mutable.choose_from,
  124. "n_chosen": mutable.n_chosen}}
  125. keys.append(key)
  126. else:
  127. raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
  128. return res

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已为产学研等各领域近千家单位及个人提供AI应用赋能。