You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base_mutator.py 5.4 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import logging
  4. import torch.nn as nn
  5. from .mutables import Mutable, MutableScope, InputChoice
  6. from .utils import StructuredMutableTreeNode
  7. logger = logging.getLogger(__name__)
  8. logger.setLevel(logging.INFO)
  9. class BaseMutator(nn.Module):
  10. """
  11. A mutator is responsible for mutating a graph by obtaining the search space from the network and implementing
  12. callbacks that are called in ``forward`` in mutables.
  13. Parameters
  14. ----------
  15. model : nn.Module
  16. PyTorch model to apply mutator on.
  17. """
  18. def __init__(self, model):
  19. super().__init__()
  20. self.__dict__["model"] = model
  21. self._structured_mutables = self._parse_search_space(self.model)
  22. def _parse_search_space(self, module, root=None, prefix="", memo=None, nested_detection=None):
  23. if memo is None:
  24. memo = set()
  25. if root is None:
  26. root = StructuredMutableTreeNode(None)
  27. if module not in memo:
  28. memo.add(module)
  29. if isinstance(module, Mutable):
  30. if nested_detection is not None:
  31. raise RuntimeError("Cannot have nested search space. Error at {} in {}"
  32. .format(module, nested_detection))
  33. module.name = prefix
  34. module.set_mutator(self)
  35. root = root.add_child(module)
  36. if not isinstance(module, MutableScope):
  37. nested_detection = module
  38. if isinstance(module, InputChoice):
  39. for k in module.choose_from:
  40. if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]:
  41. raise RuntimeError("'{}' required by '{}' not found in keys that appeared before, and is not NO_KEY."
  42. .format(k, module.key))
  43. for name, submodule in module._modules.items():
  44. if submodule is None:
  45. continue
  46. submodule_prefix = prefix + ("." if prefix else "") + name
  47. self._parse_search_space(submodule, root, submodule_prefix, memo=memo,
  48. nested_detection=nested_detection)
  49. return root
  50. @property
  51. def mutables(self):
  52. """
  53. A generator of all modules inheriting :class:`~nni.nas.pytorch.mutables.Mutable`.
  54. Modules are yielded in the order that they are defined in ``__init__``.
  55. For mutables with their keys appearing multiple times, only the first one will appear.
  56. """
  57. return self._structured_mutables
  58. @property
  59. def undedup_mutables(self):
  60. return self._structured_mutables.traverse(deduplicate=False)
  61. def forward(self, *inputs):
  62. """
  63. Warnings
  64. --------
  65. Don't call forward of a mutator.
  66. """
  67. raise RuntimeError("Forward is undefined for mutators.")
  68. def __setattr__(self, name, value):
  69. if name == "model":
  70. raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to "
  71. "include you network, as it will include all parameters in model into the mutator.")
  72. return super().__setattr__(name, value)
  73. def enter_mutable_scope(self, mutable_scope):
  74. """
  75. Callback when forward of a MutableScope is entered.
  76. Parameters
  77. ----------
  78. mutable_scope : MutableScope
  79. The mutable scope that is entered.
  80. """
  81. pass
  82. def exit_mutable_scope(self, mutable_scope):
  83. """
  84. Callback when forward of a MutableScope is exited.
  85. Parameters
  86. ----------
  87. mutable_scope : MutableScope
  88. The mutable scope that is exited.
  89. """
  90. pass
  91. def on_forward_layer_choice(self, mutable, *args, **kwargs):
  92. """
  93. Callbacks of forward in LayerChoice.
  94. Parameters
  95. ----------
  96. mutable : LayerChoice
  97. Module whose forward is called.
  98. args : list of torch.Tensor
  99. The arguments of its forward function.
  100. kwargs : dict
  101. The keyword arguments of its forward function.
  102. Returns
  103. -------
  104. tuple of torch.Tensor and torch.Tensor
  105. Output tensor and mask.
  106. """
  107. raise NotImplementedError
  108. def on_forward_input_choice(self, mutable, tensor_list):
  109. """
  110. Callbacks of forward in InputChoice.
  111. Parameters
  112. ----------
  113. mutable : InputChoice
  114. Mutable that is called.
  115. tensor_list : list of torch.Tensor
  116. The arguments mutable is called with.
  117. Returns
  118. -------
  119. tuple of torch.Tensor and torch.Tensor
  120. Output tensor and mask.
  121. """
  122. raise NotImplementedError
  123. def export(self):
  124. """
  125. Export the data of all decisions. This should output the decisions of all the mutables, so that the whole
  126. network can be fully determined with these decisions for further training from scratch.
  127. Returns
  128. -------
  129. dict
  130. Mappings from mutable keys to decisions.
  131. """
  132. raise NotImplementedError

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。