
mutator.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from collections import defaultdict

import numpy as np
import torch

from .base_mutator import BaseMutator
from .mutables import LayerChoice, InputChoice
from .utils import to_list

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Mutator(BaseMutator):

    def __init__(self, model):
        super().__init__(model)
        self._cache = dict()
        self._connect_all = False

    def sample_search(self):
        """
        Override to implement this method to iterate over mutables and make decisions.

        Returns
        -------
        dict
            A mapping from key of mutables to decisions.
        """
        raise NotImplementedError

    def sample_final(self):
        """
        Override to implement this method to iterate over mutables and make decisions that are final
        for export and retraining.

        Returns
        -------
        dict
            A mapping from key of mutables to decisions.
        """
        raise NotImplementedError

    def reset(self):
        """
        Reset the mutator by calling `sample_search` to resample (for search). Stores the result in a local
        variable so that `on_forward_layer_choice` and `on_forward_input_choice` can use the decision directly.
        """
        self._cache = self.sample_search()

    def export(self):
        """
        Resample (for final) and return results.

        Returns
        -------
        dict
            A mapping from key of mutables to decisions.
        """
        sampled = self.sample_final()
        result = dict()
        for mutable in self.mutables:
            if not isinstance(mutable, (LayerChoice, InputChoice)):
                # not supported as built-in
                continue
            result[mutable.key] = self._convert_mutable_decision_to_human_readable(mutable, sampled.pop(mutable.key))
        if sampled:
            raise ValueError("Unexpected keys returned from 'sample_final()': %s" % list(sampled.keys()))
        return result

    def status(self):
        """
        Return the current selection status of the mutator.

        Returns
        -------
        dict
            A mapping from key of mutables to decisions. All weights (boolean type and float type)
            are converted into real number values. Numpy arrays and tensors are converted into lists.
        """
        data = dict()
        for k, v in self._cache.items():
            if torch.is_tensor(v):
                v = v.detach().cpu().numpy()
            if isinstance(v, np.ndarray):
                v = v.astype(np.float32).tolist()
            data[k] = v
        return data

    def graph(self, inputs):
        """
        Return the model supernet graph.

        Parameters
        ----------
        inputs : tuple of tensor
            Inputs that will be fed into the network.

        Returns
        -------
        dict
            Containing ``node``, in TensorBoard GraphDef format.
            Additional key ``mutable`` is a map from key to list of modules.
        """
        if not torch.__version__.startswith("1.4"):
            logger.warning("Graph is only tested with PyTorch 1.4. Other versions might not work.")
        from nni._graph_utils import build_graph
        from google.protobuf import json_format
        # protobuf should be installed as long as tensorboard is installed
        try:
            self._connect_all = True
            graph_def, _ = build_graph(self.model, inputs, verbose=False)
            result = json_format.MessageToDict(graph_def)
        finally:
            self._connect_all = False

        # `mutable` is to map the keys to a list of corresponding modules.
        # A key can be linked to multiple modules, use `dedup=False` to find them all.
        result["mutable"] = defaultdict(list)
        for mutable in self.mutables.traverse(deduplicate=False):
            # A module will be represented in the format of
            # [{"type": "Net", "name": ""}, {"type": "Cell", "name": "cell1"}, {"type": "Conv2d", "name": "conv"}]
            # which will be concatenated into Net/Cell[cell1]/Conv2d[conv] in the frontend.
            # This format is aligned with the scope name jit gives.
            modules = mutable.name.split(".")
            path = [
                {"type": self.model.__class__.__name__, "name": ""}
            ]
            m = self.model
            for module in modules:
                m = getattr(m, module)
                path.append({
                    "type": m.__class__.__name__,
                    "name": module
                })
            result["mutable"][mutable.key].append(path)
        return result

    def on_forward_layer_choice(self, mutable, *args, **kwargs):
        """
        By default, this method retrieves the decision obtained previously, and selects certain operations.
        Only operations with non-zero weight will be executed. The results will be added to a list.
        Then it will reduce the list of all tensor outputs with the policy specified in `mutable.reduction`.

        Parameters
        ----------
        mutable : LayerChoice
            Layer choice module.
        args : list of torch.Tensor
            Inputs
        kwargs : dict
            Inputs

        Returns
        -------
        tuple of torch.Tensor and torch.Tensor
            Output and mask.
        """
        if self._connect_all:
            return self._all_connect_tensor_reduction(mutable.reduction,
                                                      [op(*args, **kwargs) for op in mutable]), \
                torch.ones(len(mutable)).bool()

        def _map_fn(op, args, kwargs):
            return op(*args, **kwargs)

        mask = self._get_decision(mutable)
        assert len(mask) == len(mutable), \
            "Invalid mask, expected {} to be of length {}.".format(mask, len(mutable))
        out, mask = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask)
        return self._tensor_reduction(mutable.reduction, out), mask

    def on_forward_input_choice(self, mutable, tensor_list):
        """
        By default, this method retrieves the decision obtained previously, and selects certain tensors.
        Then it will reduce the list of all tensor outputs with the policy specified in `mutable.reduction`.

        Parameters
        ----------
        mutable : InputChoice
            Input choice module.
        tensor_list : list of torch.Tensor
            Tensor list to apply the decision on.

        Returns
        -------
        tuple of torch.Tensor and torch.Tensor
            Output and mask.
        """
        if self._connect_all:
            return self._all_connect_tensor_reduction(mutable.reduction, tensor_list), \
                torch.ones(mutable.n_candidates).bool()
        mask = self._get_decision(mutable)
        assert len(mask) == mutable.n_candidates, \
            "Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates)
        out, mask = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
        return self._tensor_reduction(mutable.reduction, out), mask

    def _select_with_mask(self, map_fn, candidates, mask):
        """
        Select masked tensors and return a list of tensors.

        Parameters
        ----------
        map_fn : function
            Convert candidates to target candidates. Can be simply identity.
        candidates : list of torch.Tensor
            Tensor list to apply the decision on.
        mask : list-like object
            Can be a list, a numpy array or a tensor (recommended). Needs to
            have the same length as ``candidates``.

        Returns
        -------
        tuple of list of torch.Tensor and torch.Tensor
            Output and mask.
        """
        if (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], bool)) or \
                (isinstance(mask, np.ndarray) and mask.dtype == np.bool) or \
                "BoolTensor" in mask.type():
            out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
        elif (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], (float, int))) or \
                (isinstance(mask, np.ndarray) and mask.dtype in (np.float32, np.float64, np.int32, np.int64)) or \
                "FloatTensor" in mask.type():
            out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m]
        else:
            raise ValueError("Unrecognized mask '%s'" % mask)
        if not torch.is_tensor(mask):
            mask = torch.tensor(mask)  # pylint: disable=not-callable
        return out, mask

    def _tensor_reduction(self, reduction_type, tensor_list):
        if reduction_type == "none":
            return tensor_list
        if not tensor_list:
            return None  # empty. return None for now
        if len(tensor_list) == 1:
            return tensor_list[0]
        if reduction_type == "sum":
            return sum(tensor_list)
        if reduction_type == "mean":
            return sum(tensor_list) / len(tensor_list)
        if reduction_type == "concat":
            return torch.cat(tensor_list, dim=1)
        raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))

    def _all_connect_tensor_reduction(self, reduction_type, tensor_list):
        if reduction_type == "none":
            return tensor_list
        if reduction_type == "concat":
            return torch.cat(tensor_list, dim=1)
        return torch.stack(tensor_list).sum(0)

    def _get_decision(self, mutable):
        """
        By default, this method checks whether `mutable.key` is already in the decision cache,
        and returns the result without double-checking.

        Parameters
        ----------
        mutable : Mutable

        Returns
        -------
        object
        """
        if mutable.key not in self._cache:
            raise ValueError("\"{}\" not found in decision cache.".format(mutable.key))
        result = self._cache[mutable.key]
        logger.debug("Decision %s: %s", mutable.key, result)
        return result

    def _convert_mutable_decision_to_human_readable(self, mutable, sampled):
        # Assert the existence of mutable.key in returned architecture.
        # Also check if there is anything extra.
        multihot_list = to_list(sampled)
        converted = None
        # If it's a boolean array, we can do optimization.
        if all([t == 0 or t == 1 for t in multihot_list]):
            if isinstance(mutable, LayerChoice):
                assert len(multihot_list) == len(mutable), \
                    "Results returned from 'sample_final()' (%s: %s) either too short or too long." \
                    % (mutable.key, multihot_list)
                # check if all modules have different names and they indeed have names
                if len(set(mutable.names)) == len(mutable) and not all(d.isdigit() for d in mutable.names):
                    converted = [name for i, name in enumerate(mutable.names) if multihot_list[i]]
                else:
                    converted = [i for i in range(len(multihot_list)) if multihot_list[i]]
            if isinstance(mutable, InputChoice):
                assert len(multihot_list) == mutable.n_candidates, \
                    "Results returned from 'sample_final()' (%s: %s) either too short or too long." \
                    % (mutable.key, multihot_list)
                # check if all input candidates have different names
                if len(set(mutable.choose_from)) == mutable.n_candidates:
                    converted = [name for i, name in enumerate(mutable.choose_from) if multihot_list[i]]
                else:
                    converted = [i for i in range(len(multihot_list)) if multihot_list[i]]
        if converted is not None:
            # if only one element, then remove the bracket
            if len(converted) == 1:
                converted = converted[0]
        else:
            # do nothing
            converted = multihot_list
        return converted
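
For orientation, below is a minimal sketch of how a concrete strategy might subclass `Mutator`: it overrides `sample_search` and `sample_final` to draw a one-hot boolean mask for every `LayerChoice` and a multi-hot mask for every `InputChoice`, which is the format `on_forward_layer_choice` and `on_forward_input_choice` expect. The class name `RandomMutatorSketch` and the uniform-random sampling policy are illustrative assumptions, not part of mutator.py.

# Illustrative sketch (not part of mutator.py): a random-sampling mutator.
# The relative import paths mirror the package layout used above and are assumptions.
import torch

from .mutator import Mutator                      # the class defined in this file
from .mutables import LayerChoice, InputChoice


class RandomMutatorSketch(Mutator):
    """Samples one random architecture per call; name and policy are illustrative."""

    def sample_search(self):
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                # one-hot boolean mask over candidate operations
                index = torch.randint(high=len(mutable), size=(1,))
                result[mutable.key] = torch.nn.functional.one_hot(
                    index, num_classes=len(mutable)).view(-1).bool()
            elif isinstance(mutable, InputChoice):
                # multi-hot boolean mask over candidate inputs; pick one if
                # the mutable does not constrain how many inputs to choose
                n_chosen = mutable.n_chosen or 1
                perm = torch.randperm(mutable.n_candidates)
                mask = torch.zeros(mutable.n_candidates, dtype=torch.bool)
                mask[perm[:n_chosen]] = True
                result[mutable.key] = mask
        return result

    def sample_final(self):
        # For a purely random strategy, the final sample is just another draw.
        return self.sample_search()

A trainer using such a mutator would call `mutator.reset()` before each forward pass so that the forward hooks pick up the freshly sampled masks, and `mutator.export()` at the end of search to obtain the human-readable architecture produced by `_convert_mutable_decision_to_human_readable`.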
