You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

quantize.py 9.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. from copy import copy, deepcopy
  2. from functools import partial
  3. from typing import Callable
  4. import numpy as np
  5. from .. import module as Float
  6. from ..functional import concat, norm
  7. from ..logger import get_logger
  8. from ..module import Module
  9. from ..module import qat as QAT
  10. from ..module import quantized as Quantized
  11. from ..module.qat import QATModule
  12. from ..module.quantized import QuantizedModule
  13. from ..tensor import Tensor
  14. from ..utils.module_utils import set_expand_structure
  15. from .qconfig import QConfig, ema_fakequant_qconfig
# Module-level logger for this file; ``apply_easy_quant`` uses it to warn when
# its scale search finds no better scale.
logger = get_logger(__name__)
  17. def _get_quantable_module_names():
  18. def is_quantable(key: str):
  19. value = getattr(Quantized, key)
  20. return (
  21. isinstance(value, type)
  22. and issubclass(value, QuantizedModule)
  23. and value != QuantizedModule
  24. )
  25. # source should have all quantable modules' names
  26. quantable_module_names = [key for key in dir(Quantized) if is_quantable(key)]
  27. return quantable_module_names
  28. def _get_convert_dict():
  29. quantable_module_names = _get_quantable_module_names()
  30. quantable_modules = [getattr(Float, key) for key in quantable_module_names]
  31. qat_modules = [getattr(QAT, key) for key in quantable_module_names]
  32. quantized_modules = [getattr(Quantized, key) for key in quantable_module_names]
  33. float2qat_dict = dict(zip(quantable_modules, qat_modules))
  34. qat2quantized_dict = dict(zip(qat_modules, quantized_modules))
  35. return float2qat_dict, qat2quantized_dict
  36. _float2qat_dict, _qat2quantized_dict = _get_convert_dict()
  37. qat_modules = tuple(_qat2quantized_dict.keys())
  38. def quantize(module: Module, inplace: bool = True, mapping: dict = None):
  39. r"""Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule`
  40. through :meth:`~.Module.apply`.
  41. Args:
  42. module: root module to do convert recursively.
  43. inplace: whether to convert submodules in-place.
  44. mapping: a dict indicating how to convert custom modules from QATModule to
  45. QuantizedModule. Will be combined with internal default convert mapping dict.
  46. """
  47. if not inplace:
  48. module = deepcopy(module)
  49. convert_dict = copy(_qat2quantized_dict)
  50. if mapping is not None:
  51. convert_dict.update(mapping)
  52. qat_modules = tuple(convert_dict.keys())
  53. def is_qat(mod: Module):
  54. return isinstance(mod, qat_modules)
  55. # must use list to avoid replacement influencing successor modules
  56. for key, submodule, parent in list(
  57. module._flatten(with_key=True, with_parent=True, predicate=is_qat)
  58. ):
  59. new_mod = convert_dict[type(submodule)].from_qat_module(submodule)
  60. set_expand_structure(module, key, new_mod)
  61. return module
  62. def quantize_qat(
  63. module: Module,
  64. inplace: bool = True,
  65. qconfig: QConfig = ema_fakequant_qconfig,
  66. mapping: dict = None,
  67. ):
  68. r"""Recursively convert float :class:`~.Module` to :class:`~.QATModule`
  69. through :meth:`~.Module.apply` and set qconfig relatively.
  70. Args:
  71. module: root module to do convert recursively.
  72. inplace: whether to convert submodules in-place.
  73. qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
  74. default is ``ema_fakequant_qconfig``.
  75. mapping: a dict indicating how to convert custom modules from Module to QATModule.
  76. Will be combined with internal default convert mapping dict.
  77. """
  78. if not inplace:
  79. module = deepcopy(module)
  80. convert_dict = copy(_float2qat_dict)
  81. if mapping is not None:
  82. convert_dict.update(mapping)
  83. quantable_modules = tuple(convert_dict.keys())
  84. def is_quantable(mod: Module):
  85. return isinstance(mod, quantable_modules)
  86. # must use list to avoid replacement influencing successor modules
  87. for key, submodule, parent in list(
  88. module._flatten(with_key=True, with_parent=True, predicate=is_quantable)
  89. ):
  90. # only convert top quantable module.
  91. if is_quantable(parent) or submodule.quantize_disabled:
  92. continue
  93. new_mod = convert_dict[type(submodule)].from_float_module(submodule)
  94. set_expand_structure(module, key, new_mod)
  95. propagate_qconfig(module, qconfig)
  96. return module
  97. def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
  98. r"""Reset :class:`~._FakeQuantize` and :class:`~.Observer` according to ``qconfig``
  99. Args:
  100. module: root module to reset recursively.
  101. qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
  102. inplace: whether to reset submodules in-place.
  103. """
  104. if not inplace:
  105. module = deepcopy(module)
  106. def safe_call(func, qparams):
  107. inst = func() if func is not None else None
  108. if inst is not None and getattr(inst, "set_qparams", None) is not None:
  109. inst.set_qparams(qparams)
  110. return inst
  111. def is_qat(mod: Module):
  112. return isinstance(mod, QATModule)
  113. for m in list(module._flatten(predicate=is_qat)):
  114. if m.with_weight:
  115. weight_params = m.get_weight_qparams()
  116. m.weight_observer = safe_call(qconfig.weight_observer, weight_params)
  117. m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_params)
  118. if m.with_act:
  119. act_params = m.get_activation_qparams()
  120. m.act_observer = safe_call(qconfig.act_observer, act_params)
  121. m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_params)
  122. return module
  123. def _propagate(module: Module, func_str: str, *args, **kargs):
  124. def fn(mod: Module):
  125. if isinstance(mod, QATModule):
  126. getattr(mod, func_str)(*args, **kargs)
  127. module.apply(fn)
  128. def propagate_qconfig(module: QATModule, qconfig: QConfig):
  129. r"""Recursively set ``module``'s qconfig through :meth:`~.Module.apply`.
  130. Args:
  131. module: root module to traverse recursively.
  132. qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig.
  133. """
  134. _propagate(module, "set_qconfig", qconfig)
  135. def hook_qat_module(module: Module, func: Callable):
  136. r"""Add hooks for all :class:`~.QATModule` submodule"""
  137. def is_qat(mod: Module):
  138. return isinstance(mod, QATModule)
  139. hooks = []
  140. for submodule in list(module._flatten(predicate=is_qat)):
  141. hooks.append(submodule.register_forward_hook(func))
  142. return hooks
  143. def apply_easy_quant(
  144. module: Module, data: Tensor, start: float = 0.8, stop: float = 1.2, num: int = 40
  145. ):
  146. r"""Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669.
  147. Search for optimal scales.
  148. Args:
  149. module: root module.
  150. data: input tensor used to search optimal scale.
  151. start: lower bound of the search interval.
  152. stop: upper bound of the search interval.
  153. num: number of samples to search.
  154. module: Module:
  155. """
  156. batch_size = data.shape[0]
  157. def get_cosine(x, y):
  158. ndim = len(x.shape)
  159. axis = tuple(range(1, ndim))
  160. up = (x * y).sum(axis=axis)
  161. down = norm(x, axis=axis) * norm(y, axis=axis)
  162. sim = up / down
  163. return sim.mean(axis=0)
  164. def search(mod, inputs, outputs, where):
  165. mod._forward_hooks.clear()
  166. normal_in = [_[:batch_size] for _ in inputs]
  167. fakequant_in = [_[batch_size:] for _ in inputs]
  168. disable_fake_quant(mod)
  169. normal_out = mod(*normal_in)
  170. enable_fake_quant(mod)
  171. ob = getattr(mod, where)
  172. if ob is None:
  173. return
  174. orig_scale = ob.orig_scale
  175. cosine = optimal = 0
  176. for scale in np.linspace(start * orig_scale, stop * orig_scale, num):
  177. ob.scale = scale
  178. fakequant_out = mod(*fakequant_in)
  179. dis = get_cosine(normal_out, fakequant_out)
  180. if dis > cosine:
  181. cosine = dis
  182. optimal = scale
  183. if optimal == 0:
  184. logger.warning("EasyQuant finds no better scale")
  185. else:
  186. ob.scale = optimal
  187. fakequant_out = outputs[batch_size:]
  188. return concat([normal_out, fakequant_out])
  189. data = concat([data, data])
  190. hook_qat_module(module, partial(search, where="weight_observer"))
  191. module(data)
  192. hook_qat_module(module, partial(search, where="act_observer"))
  193. module(data)
  194. return module
  195. def disable_fake_quant(module: Module):
  196. r"""Recursively disable ``module`` fake quantization in QATModule through :meth:`~.Module.apply`
  197. Args:
  198. module: root module to do disable fake quantization recursively.
  199. """
  200. _propagate(module, "set_fake_quant", False)
  201. def disable_observer(module: Module):
  202. r"""Recursively disable ``module`` observer in QATModule through :meth:`~.Module.apply`
  203. Args:
  204. module: root module to do disable observer recursively.
  205. """
  206. _propagate(module, "set_observer", False)
  207. def enable_fake_quant(module: Module):
  208. r"""Recursively enable ``module`` fake quantization in QATModule through :meth:`~.Module.apply`
  209. Args:
  210. module: root module to do enable fake quantization recursively.
  211. """
  212. _propagate(module, "set_fake_quant", True)
  213. def enable_observer(module: Module):
  214. r"""Recursively enable ``module`` observer in QATModule through :meth:`~.Module.apply`
  215. Args:
  216. module: root module to do enable observer recursively.
  217. """
  218. _propagate(module, "set_observer", True)