You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fixed.py 6.1 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import json
  4. import logging
  5. from .mutables import InputChoice, LayerChoice, MutableScope
  6. from .mutator import Mutator
  7. from .utils import to_list
  8. _logger = logging.getLogger(__name__)
  9. _logger.setLevel(logging.INFO)
  10. class FixedArchitecture(Mutator):
  11. """
  12. Fixed architecture mutator that always selects a certain graph.
  13. Parameters
  14. ----------
  15. model : nn.Module
  16. A mutable network.
  17. fixed_arc : dict
  18. Preloaded architecture object.
  19. strict : bool
  20. Force everything that appears in ``fixed_arc`` to be used at least once.
  21. """
  22. def __init__(self, model, fixed_arc, strict=True):
  23. super().__init__(model)
  24. self._fixed_arc = fixed_arc
  25. mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)])
  26. fixed_arc_keys = set(self._fixed_arc.keys())
  27. if fixed_arc_keys - mutable_keys:
  28. raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
  29. if mutable_keys - fixed_arc_keys:
  30. raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
  31. self._fixed_arc = self._from_human_readable_architecture(self._fixed_arc)
  32. def _from_human_readable_architecture(self, human_arc):
  33. # convert from an exported architecture
  34. result_arc = {k: to_list(v) for k, v in human_arc.items()} # there could be tensors, numpy arrays, etc.
  35. # First, convert non-list to list, because there could be {"op1": 0} or {"op1": "conv"},
  36. # which means {"op1": [0, ]} ir {"op1": ["conv", ]}
  37. result_arc = {k: v if isinstance(v, list) else [v] for k, v in result_arc.items()}
  38. # Second, infer which ones are multi-hot arrays and which ones are in human-readable format.
  39. # This is non-trivial, since if an array in [0, 1], we cannot know for sure it means [false, true] or [true, true].
  40. # Here, we assume an multihot array has to be a boolean array or a float array and matches the length.
  41. for mutable in self.mutables:
  42. if mutable.key not in result_arc:
  43. continue # skip silently
  44. choice_arr = result_arc[mutable.key]
  45. if all(isinstance(v, bool) for v in choice_arr) or all(isinstance(v, float) for v in choice_arr):
  46. if (isinstance(mutable, LayerChoice) and len(mutable) == len(choice_arr)) or \
  47. (isinstance(mutable, InputChoice) and mutable.n_candidates == len(choice_arr)):
  48. # multihot, do nothing
  49. continue
  50. if isinstance(mutable, LayerChoice):
  51. choice_arr = [mutable.names.index(val) if isinstance(val, str) else val for val in choice_arr]
  52. choice_arr = [i in choice_arr for i in range(len(mutable))]
  53. elif isinstance(mutable, InputChoice):
  54. choice_arr = [mutable.choose_from.index(val) if isinstance(val, str) else val for val in choice_arr]
  55. choice_arr = [i in choice_arr for i in range(mutable.n_candidates)]
  56. result_arc[mutable.key] = choice_arr
  57. return result_arc
  58. def sample_search(self):
  59. """
  60. Always returns the fixed architecture.
  61. """
  62. return self._fixed_arc
  63. def sample_final(self):
  64. """
  65. Always returns the fixed architecture.
  66. """
  67. return self._fixed_arc
  68. def replace_layer_choice(self, module=None, prefix=""):
  69. """
  70. Replace layer choices with selected candidates. It's done with best effort.
  71. In case of weighted choices or multiple choices. if some of the choices on weighted with zero, delete them.
  72. If single choice, replace the module with a normal module.
  73. Parameters
  74. ----------
  75. module : nn.Module
  76. Module to be processed.
  77. prefix : str
  78. Module name under global namespace.
  79. """
  80. if module is None:
  81. module = self.model
  82. for name, mutable in module.named_children():
  83. global_name = (prefix + "." if prefix else "") + name
  84. if isinstance(mutable, LayerChoice):
  85. chosen = self._fixed_arc[mutable.key]
  86. if sum(chosen) == 1 and max(chosen) == 1 and not mutable.return_mask:
  87. # sum is one, max is one, there has to be an only one
  88. # this is compatible with both integer arrays, boolean arrays and float arrays
  89. _logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1))
  90. setattr(module, name, mutable[chosen.index(1)])
  91. else:
  92. if mutable.return_mask:
  93. _logger.info("`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, " \
  94. "LayerChoice will not be replaced.")
  95. # remove unused parameters
  96. for ch, n in zip(chosen, mutable.names):
  97. if ch == 0 and not isinstance(ch, float):
  98. setattr(mutable, n, None)
  99. else:
  100. self.replace_layer_choice(mutable, global_name)
  101. def apply_fixed_architecture(model, fixed_arc):
  102. """
  103. Load architecture from `fixed_arc` and apply to model.
  104. Parameters
  105. ----------
  106. model : torch.nn.Module
  107. Model with mutables.
  108. fixed_arc : str or dict
  109. Path to the JSON that stores the architecture, or dict that stores the exported architecture.
  110. Returns
  111. -------
  112. FixedArchitecture
  113. Mutator that is responsible for fixes the graph.
  114. """
  115. if isinstance(fixed_arc, str):
  116. with open(fixed_arc) as f:
  117. fixed_arc = json.load(f)
  118. architecture = FixedArchitecture(model, fixed_arc)
  119. architecture.reset()
  120. # for the convenience of parameters counting
  121. architecture.replace_layer_choice()
  122. return architecture

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能