You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fixed.py 6.3 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import json
  4. import logging
  5. import sys
  6. sys.path.append('..'+ '/' + '..')
  7. from pytorch.mutables import InputChoice, LayerChoice, MutableScope
  8. from pytorch.mutator import Mutator
  9. from pytorch.utils import to_list
  10. _logger = logging.getLogger(__name__)
  11. #_logger.setLevel(logging.INFO)
  12. class FixedArchitecture(Mutator):
  13. """
  14. Fixed architecture mutator that always selects a certain graph.
  15. Parameters
  16. ----------
  17. model : nn.Module
  18. A mutable network.
  19. fixed_arc : dict
  20. Preloaded architecture object.
  21. strict : bool
  22. Force everything that appears in ``fixed_arc`` to be used at least once.
  23. """
  24. def __init__(self, model, fixed_arc, strict=True):
  25. super().__init__(model)
  26. self._fixed_arc = fixed_arc
  27. mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)])
  28. fixed_arc_keys = set(self._fixed_arc.keys())
  29. if fixed_arc_keys - mutable_keys:
  30. raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
  31. if mutable_keys - fixed_arc_keys:
  32. raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
  33. self._fixed_arc = self._from_human_readable_architecture(self._fixed_arc)
  34. def _from_human_readable_architecture(self, human_arc):
  35. # convert from an exported architecture
  36. #print('human_arc',human_arc)
  37. result_arc = {k: to_list(v) for k, v in human_arc.items()} # there could be tensors, numpy arrays, etc.
  38. #print('result_arc',result_arc)
  39. # First, convert non-list to list, because there could be {"op1": 0} or {"op1": "conv"},
  40. # which means {"op1": [0, ]} ir {"op1": ["conv", ]}
  41. result_arc = {k: v['_value'] if isinstance(v['_value'], list) else [v['_value']] for k, v in result_arc.items()}
  42. # Second, infer which ones are multi-hot arrays and which ones are in human-readable format.
  43. # This is non-trivial, since if an array in [0, 1], we cannot know for sure it means [false, true] or [true, true].
  44. # Here, we assume an multihot array has to be a boolean array or a float array and matches the length.
  45. for mutable in self.mutables:
  46. if mutable.key not in result_arc:
  47. continue # skip silently
  48. choice_arr = result_arc[mutable.key]
  49. if all(isinstance(v, bool) for v in choice_arr) or all(isinstance(v, float) for v in choice_arr):
  50. if (isinstance(mutable, LayerChoice) and len(mutable) == len(choice_arr)) or \
  51. (isinstance(mutable, InputChoice) and mutable.n_candidates == len(choice_arr)):
  52. # multihot, do nothing
  53. continue
  54. if isinstance(mutable, LayerChoice):
  55. choice_arr = [mutable.names.index(val) if isinstance(val, str) else val for val in choice_arr]
  56. choice_arr = [i in choice_arr for i in range(len(mutable))]
  57. elif isinstance(mutable, InputChoice):
  58. choice_arr = [mutable.choose_from.index(val) if isinstance(val, str) else val for val in choice_arr]
  59. choice_arr = [i in choice_arr for i in range(mutable.n_candidates)]
  60. result_arc[mutable.key] = choice_arr
  61. return result_arc
  62. def sample_search(self):
  63. """
  64. Always returns the fixed architecture.
  65. """
  66. return self._fixed_arc
  67. def sample_final(self):
  68. """
  69. Always returns the fixed architecture.
  70. """
  71. return self._fixed_arc
  72. def replace_layer_choice(self, module=None, prefix=""):
  73. """
  74. Replace layer choices with selected candidates. It's done with best effort.
  75. In case of weighted choices or multiple choices. if some of the choices on weighted with zero, delete them.
  76. If single choice, replace the module with a normal module.
  77. Parameters
  78. ----------
  79. module : nn.Module
  80. Module to be processed.
  81. prefix : str
  82. Module name under global namespace.
  83. """
  84. if module is None:
  85. module = self.model
  86. for name, mutable in module.named_children():
  87. global_name = (prefix + "." if prefix else "") + name
  88. if isinstance(mutable, LayerChoice):
  89. chosen = self._fixed_arc[mutable.key]
  90. if sum(chosen) == 1 and max(chosen) == 1 and not mutable.return_mask:
  91. # sum is one, max is one, there has to be an only one
  92. # this is compatible with both integer arrays, boolean arrays and float arrays
  93. _logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1))
  94. setattr(module, name, mutable[chosen.index(1)])
  95. else:
  96. if mutable.return_mask:
  97. _logger.info("`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, " \
  98. "LayerChoice will not be replaced.")
  99. # remove unused parameters
  100. for ch, n in zip(chosen, mutable.names):
  101. if ch == 0 and not isinstance(ch, float):
  102. setattr(mutable, n, None)
  103. else:
  104. self.replace_layer_choice(mutable, global_name)
  105. def apply_fixed_architecture(model, fixed_arc):
  106. """
  107. Load architecture from `fixed_arc` and apply to model.
  108. Parameters
  109. ----------
  110. model : torch.nn.Module
  111. Model with mutables.
  112. fixed_arc : str or dict
  113. Path to the JSON that stores the architecture, or dict that stores the exported architecture.
  114. Returns
  115. -------
  116. FixedArchitecture
  117. Mutator that is responsible for fixes the graph.
  118. """
  119. if isinstance(fixed_arc, str):
  120. with open(fixed_arc) as f:
  121. fixed_arc = json.load(f)
  122. architecture = FixedArchitecture(model, fixed_arc)
  123. architecture.reset()
  124. # for the convenience of parameters counting
  125. architecture.replace_layer_choice()
  126. return architecture

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能