You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

dartsmutator.py 6.3 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import logging
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from collections import OrderedDict
  8. from pytorch.mutator import Mutator
  9. from pytorch.mutables import LayerChoice, InputChoice
  10. _logger = logging.getLogger(__name__)
  11. class DartsMutator(Mutator):
  12. """
  13. Connects the model in a DARTS (differentiable) way.
  14. An extra connection is automatically inserted for each LayerChoice, when this connection is selected, there is no
  15. op on this LayerChoice (namely a ``ZeroOp``), in which case, every element in the exported choice list is ``false``
  16. (not chosen).
  17. All input choice will be fully connected in the search phase. On exporting, the input choice will choose inputs based
  18. on keys in ``choose_from``. If the keys were to be keys of LayerChoices, the top logit of the corresponding LayerChoice
  19. will join the competition of input choice to compete against other logits. Otherwise, the logit will be assumed 0.
  20. It's possible to cut branches by setting parameter ``choices`` in a particular position to ``-inf``. After softmax, the
  21. value would be 0. Framework will ignore 0 values and not connect. Note that the gradient on the ``-inf`` location will
  22. be 0. Since manipulations with ``-inf`` will be ``nan``, you need to handle the gradient update phase carefully.
  23. Attributes
  24. ----------
  25. choices: ParameterDict
  26. dict that maps keys of LayerChoices to weighted-connection float tensors.
  27. """
  28. def __init__(self, model):
  29. super().__init__(model)
  30. self.choices = nn.ParameterDict()
  31. for mutable in self.mutables:
  32. if isinstance(mutable, LayerChoice):
  33. self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1))
  34. def device(self):
  35. for v in self.choices.values():
  36. return v.device
  37. def sample_search(self):
  38. result = dict()
  39. for mutable in self.mutables:
  40. if isinstance(mutable, LayerChoice):
  41. result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1]
  42. elif isinstance(mutable, InputChoice):
  43. result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())
  44. return result
  45. def sample_final(self):
  46. result = dict()
  47. edges_max = dict()
  48. for mutable in self.mutables:
  49. if isinstance(mutable, LayerChoice):
  50. max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0)
  51. edges_max[mutable.key] = max_val
  52. result[mutable.key] = F.one_hot(index, num_classes=len(mutable)).view(-1).bool()
  53. for mutable in self.mutables:
  54. if isinstance(mutable, InputChoice):
  55. if mutable.n_chosen is not None:
  56. weights = []
  57. for src_key in mutable.choose_from:
  58. if src_key not in edges_max:
  59. _logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key)
  60. weights.append(edges_max.get(src_key, 0.))
  61. weights = torch.tensor(weights) # pylint: disable=not-callable
  62. _, topk_edge_indices = torch.topk(weights, mutable.n_chosen)
  63. selected_multihot = []
  64. for i, src_key in enumerate(mutable.choose_from):
  65. if i not in topk_edge_indices and src_key in result:
  66. # If an edge is never selected, there is no need to calculate any op on this edge.
  67. # This is to eliminate redundant calculation.
  68. result[src_key] = torch.zeros_like(result[src_key])
  69. selected_multihot.append(i in topk_edge_indices)
  70. result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable
  71. else:
  72. result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable
  73. return result
  74. def _generate_search_space(self):
  75. """
  76. Generate search space from mutables.
  77. Here is the search space format:
  78. ::
  79. { key_name: {"_type": "layer_choice",
  80. "_value": ["conv1", "conv2"]} }
  81. { key_name: {"_type": "input_choice",
  82. "_value": {"candidates": ["in1", "in2"],
  83. "n_chosen": 1}} }
  84. Returns
  85. -------
  86. dict
  87. the generated search space
  88. """
  89. res = OrderedDict()
  90. res["op_list"] = OrderedDict()
  91. res["search_space"] = OrderedDict()
  92. # res["normal_cell"] = OrderedDict(),
  93. # res["reduction_cell"] = OrderedDict()
  94. keys = []
  95. for mutable in self.mutables:
  96. # for now we only generate flattened search space
  97. if (len(res["search_space"])) >= 36:
  98. break
  99. if isinstance(mutable, LayerChoice):
  100. key = mutable.key
  101. if key not in keys:
  102. val = mutable.names
  103. if not res["op_list"]:
  104. res["op_list"] = {"_type": "layer_choice", "_value": val + ["none"]}
  105. # node_type = "normal_cell" if "normal" in key else "reduction_cell"
  106. res["search_space"][key] = "op_list"
  107. keys.append(key)
  108. elif isinstance(mutable, InputChoice):
  109. key = mutable.key
  110. if key not in keys:
  111. # node_type = "normal_cell" if "normal" in key else "reduction_cell"
  112. res["search_space"][key] = {"_type": "input_choice",
  113. "_value": {"candidates": mutable.choose_from,
  114. "n_chosen": mutable.n_chosen}}
  115. keys.append(key)
  116. else:
  117. raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
  118. return res

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。