You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mutables.py 14 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import logging
  4. import warnings
  5. from collections import OrderedDict
  6. import torch.nn as nn
  7. from .utils import global_mutable_counting
  8. logger = logging.getLogger(__name__)
  9. logger.setLevel(logging.INFO)
  10. class Mutable(nn.Module):
  11. """
  12. Mutable is designed to function as a normal layer, with all necessary operators' weights.
  13. States and weights of architectures should be included in mutator, instead of the layer itself.
  14. Mutable has a key, which marks the identity of the mutable. This key can be used by users to share
  15. decisions among different mutables. In mutator's implementation, mutators should use the key to
  16. distinguish different mutables. Mutables that share the same key should be "similar" to each other.
  17. Currently the default scope for keys is global. By default, the keys uses a global counter from 1 to
  18. produce unique ids.
  19. Parameters
  20. ----------
  21. key : str
  22. The key of mutable.
  23. Notes
  24. -----
  25. The counter is program level, but mutables are model level. In case multiple models are defined, and
  26. you want to have `counter` starting from 1 in the second model, it's recommended to assign keys manually
  27. instead of using automatic keys.
  28. """
  29. def __init__(self, key=None):
  30. super().__init__()
  31. if key is not None:
  32. if not isinstance(key, str):
  33. key = str(key)
  34. logger.warning("Warning: key \"%s\" is not string, converted to string.", key)
  35. self._key = key
  36. else:
  37. self._key = self.__class__.__name__ + str(global_mutable_counting())
  38. self.init_hook = self.forward_hook = None
  39. def __deepcopy__(self, memodict=None):
  40. raise NotImplementedError("Deep copy doesn't work for mutables.")
  41. def __call__(self, *args, **kwargs):
  42. self._check_built()
  43. return super().__call__(*args, **kwargs)
  44. def set_mutator(self, mutator):
  45. if "mutator" in self.__dict__:
  46. raise RuntimeError("`set_mutator` is called more than once. Did you parse the search space multiple times? "
  47. "Or did you apply multiple fixed architectures?")
  48. self.__dict__["mutator"] = mutator
  49. @property
  50. def key(self):
  51. """
  52. Read-only property of key.
  53. """
  54. return self._key
  55. @property
  56. def name(self):
  57. """
  58. After the search space is parsed, it will be the module name of the mutable.
  59. """
  60. return self._name if hasattr(self, "_name") else "_key"
  61. @name.setter
  62. def name(self, name):
  63. self._name = name
  64. def _check_built(self):
  65. if not hasattr(self, "mutator"):
  66. raise ValueError(
  67. "Mutator not set for {}. You might have forgotten to initialize and apply your mutator. "
  68. "Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
  69. "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))
  70. class MutableScope(Mutable):
  71. """
  72. Mutable scope marks a subgraph/submodule to help mutators make better decisions.
  73. If not annotated with mutable scope, search space will be flattened as a list. However, some mutators might
  74. need to leverage the concept of a "cell". So if a module is defined as a mutable scope, everything in it will
  75. look like "sub-search-space" in the scope. Scopes can be nested.
  76. There are two ways mutators can use mutable scope. One is to traverse the search space as a tree during initialization
  77. and reset. The other is to implement `enter_mutable_scope` and `exit_mutable_scope`. They are called before and after
  78. the forward method of the class inheriting mutable scope.
  79. Mutable scopes are also mutables that are listed in the mutator.mutables (search space), but they are not supposed
  80. to appear in the dict of choices.
  81. Parameters
  82. ----------
  83. key : str
  84. Key of mutable scope.
  85. """
  86. def __init__(self, key):
  87. super().__init__(key=key)
  88. def __call__(self, *args, **kwargs):
  89. try:
  90. self._check_built()
  91. self.mutator.enter_mutable_scope(self)
  92. return super().__call__(*args, **kwargs)
  93. finally:
  94. self.mutator.exit_mutable_scope(self)
  95. class LayerChoice(Mutable):
  96. """
  97. Layer choice selects one of the ``op_candidates``, then apply it on inputs and return results.
  98. In rare cases, it can also select zero or many.
  99. Layer choice does not allow itself to be nested.
  100. Parameters
  101. ----------
  102. op_candidates : list of nn.Module or OrderedDict
  103. A module list to be selected from.
  104. reduction : str
  105. ``mean``, ``concat``, ``sum`` or ``none``. Policy if multiples are selected.
  106. If ``none``, a list is returned. ``mean`` returns the average. ``sum`` returns the sum.
  107. ``concat`` concatenate the list at dimension 1.
  108. return_mask : bool
  109. If ``return_mask``, return output tensor and a mask. Otherwise return tensor only.
  110. key : str
  111. Key of the input choice.
  112. Attributes
  113. ----------
  114. length : int
  115. Deprecated. Number of ops to choose from. ``len(layer_choice)`` is recommended.
  116. names : list of str
  117. Names of candidates.
  118. choices : list of Module
  119. Deprecated. A list of all candidate modules in the layer choice module.
  120. ``list(layer_choice)`` is recommended, which will serve the same purpose.
  121. Notes
  122. -----
  123. ``op_candidates`` can be a list of modules or a ordered dict of named modules, for example,
  124. .. code-block:: python
  125. self.op_choice = LayerChoice(OrderedDict([
  126. ("conv3x3", nn.Conv2d(3, 16, 128)),
  127. ("conv5x5", nn.Conv2d(5, 16, 128)),
  128. ("conv7x7", nn.Conv2d(7, 16, 128))
  129. ]))
  130. Elements in layer choice can be modified or deleted. Use ``del self.op_choice["conv5x5"]`` or
  131. ``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
  132. """
  133. def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None):
  134. super().__init__(key=key)
  135. self.names = []
  136. if isinstance(op_candidates, OrderedDict):
  137. for name, module in op_candidates.items():
  138. assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
  139. "Please don't use a reserved name '{}' for your module.".format(name)
  140. self.add_module(name, module)
  141. self.names.append(name)
  142. elif isinstance(op_candidates, list):
  143. for i, module in enumerate(op_candidates):
  144. self.add_module(str(i), module)
  145. self.names.append(str(i))
  146. else:
  147. raise TypeError("Unsupported op_candidates type: {}".format(type(op_candidates)))
  148. self.reduction = reduction
  149. self.return_mask = return_mask
  150. def __getitem__(self, idx):
  151. if isinstance(idx, str):
  152. return self._modules[idx]
  153. return list(self)[idx]
  154. def __setitem__(self, idx, module):
  155. key = idx if isinstance(idx, str) else self.names[idx]
  156. return setattr(self, key, module)
  157. def __delitem__(self, idx):
  158. if isinstance(idx, slice):
  159. for key in self.names[idx]:
  160. delattr(self, key)
  161. else:
  162. if isinstance(idx, str):
  163. key, idx = idx, self.names.index(idx)
  164. else:
  165. key = self.names[idx]
  166. delattr(self, key)
  167. del self.names[idx]
  168. @property
  169. def length(self):
  170. warnings.warn("layer_choice.length is deprecated. Use `len(layer_choice)` instead.", DeprecationWarning)
  171. return len(self)
  172. def __len__(self):
  173. return len(self.names)
  174. def __iter__(self):
  175. return map(lambda name: self._modules[name], self.names)
  176. @property
  177. def choices(self):
  178. warnings.warn("layer_choice.choices is deprecated. Use `list(layer_choice)` instead.", DeprecationWarning)
  179. return list(self)
  180. def forward(self, *args, **kwargs):
  181. """
  182. Returns
  183. -------
  184. tuple of tensors
  185. Output and selection mask. If ``return_mask`` is ``False``, only output is returned.
  186. """
  187. out, mask = self.mutator.on_forward_layer_choice(self, *args, **kwargs)
  188. if self.return_mask:
  189. return out, mask
  190. return out
  191. class InputChoice(Mutable):
  192. """
  193. Input choice selects ``n_chosen`` inputs from ``choose_from`` (contains ``n_candidates`` keys). For beginners,
  194. use ``n_candidates`` instead of ``choose_from`` is a safe option. To get the most power out of it, you might want to
  195. know about ``choose_from``.
  196. The keys in ``choose_from`` can be keys that appear in past mutables, or ``NO_KEY`` if there are no suitable ones.
  197. The keys are designed to be the keys of the sources. To help mutators make better decisions,
  198. mutators might be interested in how the tensors to choose from come into place. For example, the tensor is the
  199. output of some operator, some node, some cell, or some module. If this operator happens to be a mutable (e.g.,
  200. ``LayerChoice`` or ``InputChoice``), it has a key naturally that can be used as a source key. If it's a
  201. module/submodule, it needs to be annotated with a key: that's where a :class:`MutableScope` is needed.
  202. In the example below, ``input_choice`` is a 4-choose-any. The first 3 is semantically output of cell1, output of cell2,
  203. output of cell3 with respectively. Notice that an extra max pooling is followed by cell1, indicating x1 is not
  204. "actually" the direct output of cell1.
  205. .. code-block:: python
  206. class Cell(MutableScope):
  207. pass
  208. class Net(nn.Module):
  209. def __init__(self):
  210. self.cell1 = Cell("cell1")
  211. self.cell2 = Cell("cell2")
  212. self.op = LayerChoice([conv3x3(), conv5x5()], key="op")
  213. self.input_choice = InputChoice(choose_from=["cell1", "cell2", "op", InputChoice.NO_KEY])
  214. def forward(self, x):
  215. x1 = max_pooling(self.cell1(x))
  216. x2 = self.cell2(x)
  217. x3 = self.op(x)
  218. x4 = torch.zeros_like(x)
  219. return self.input_choice([x1, x2, x3, x4])
  220. Parameters
  221. ----------
  222. n_candidates : int
  223. Number of inputs to choose from.
  224. choose_from : list of str
  225. List of source keys to choose from. At least of one of ``choose_from`` and ``n_candidates`` must be fulfilled.
  226. If ``n_candidates`` has a value but ``choose_from`` is None, it will be automatically treated as ``n_candidates``
  227. number of empty string.
  228. n_chosen : int
  229. Recommended inputs to choose. If None, mutator is instructed to select any.
  230. reduction : str
  231. ``mean``, ``concat``, ``sum`` or ``none``. See :class:`LayerChoice`.
  232. return_mask : bool
  233. If ``return_mask``, return output tensor and a mask. Otherwise return tensor only.
  234. key : str
  235. Key of the input choice.
  236. """
  237. NO_KEY = ""
  238. def __init__(self, n_candidates=None, choose_from=None, n_chosen=None,
  239. reduction="sum", return_mask=False, key=None):
  240. super().__init__(key=key)
  241. # precondition check
  242. assert n_candidates is not None or choose_from is not None, "At least one of `n_candidates` and `choose_from`" \
  243. "must be not None."
  244. if choose_from is not None and n_candidates is None:
  245. n_candidates = len(choose_from)
  246. elif choose_from is None and n_candidates is not None:
  247. choose_from = [self.NO_KEY] * n_candidates
  248. assert n_candidates == len(choose_from), "Number of candidates must be equal to the length of `choose_from`."
  249. assert n_candidates > 0, "Number of candidates must be greater than 0."
  250. assert n_chosen is None or 0 <= n_chosen <= n_candidates, "Expected selected number must be None or no more " \
  251. "than number of candidates."
  252. self.n_candidates = n_candidates
  253. self.choose_from = choose_from.copy()
  254. self.n_chosen = n_chosen
  255. self.reduction = reduction
  256. self.return_mask = return_mask
  257. def forward(self, optional_inputs):
  258. """
  259. Forward method of LayerChoice.
  260. Parameters
  261. ----------
  262. optional_inputs : list or dict
  263. Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of
  264. ``choose_from`` in initialization. As a list, inputs must follow the semantic order that is the same as
  265. ``choose_from``.
  266. Returns
  267. -------
  268. tuple of tensors
  269. Output and selection mask. If ``return_mask`` is ``False``, only output is returned.
  270. """
  271. optional_input_list = optional_inputs
  272. if isinstance(optional_inputs, dict):
  273. optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
  274. assert isinstance(optional_input_list, list), \
  275. "Optional input list must be a list, not a {}.".format(type(optional_input_list))
  276. assert len(optional_inputs) == self.n_candidates, \
  277. "Length of the input list must be equal to number of candidates."
  278. out, mask = self.mutator.on_forward_input_choice(self, optional_input_list)
  279. if self.return_mask:
  280. return out, mask
  281. return out

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。