You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

quantize.py 9.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from copy import copy, deepcopy
  9. from functools import partial
  10. from typing import Callable
  11. import numpy as np
  12. from .. import module as Float
  13. from ..functional import concat, norm
  14. from ..logger import get_logger
  15. from ..module import Module
  16. from ..module import qat as QAT
  17. from ..module import quantized as Quantized
  18. from ..module.qat import QATModule
  19. from ..module.quantized import QuantizedModule
  20. from ..tensor import Tensor
  21. from ..utils.module_utils import set_expand_structure
  22. from .qconfig import QConfig, ema_fakequant_qconfig
  23. logger = get_logger(__name__)
  def _get_quantable_module_names():
      r"""
      Collect the names of all quantable modules.

      A name is kept when the attribute of that name on ``Quantized`` is a class
      that subclasses :class:`~.QuantizedModule` without being
      :class:`~.QuantizedModule` itself.

      :return: list of attribute names, in the order given by ``dir(Quantized)``.
      """
      names = []
      # ``Quantized`` is the source of truth for quantable modules' names
      for attr_name in dir(Quantized):
          candidate = getattr(Quantized, attr_name)
          if not isinstance(candidate, type):
              continue
          if not issubclass(candidate, QuantizedModule):
              continue
          # the abstract base itself is not a convertible module
          if candidate != QuantizedModule:
              names.append(attr_name)
      return names
  def _get_convert_dict():
      r"""
      Build the default conversion tables between module flavours.

      For every quantable module name, the class of that name is looked up on
      ``Float``, ``QAT`` and ``Quantized``; classes sharing a name are paired.

      :return: tuple ``(float2qat_dict, qat2quantized_dict)`` mapping float classes
          to QAT classes, and QAT classes to quantized classes.
      """
      names = _get_quantable_module_names()

      def classes_from(namespace):
          # same name order for every namespace, so the lists line up by index
          return [getattr(namespace, name) for name in names]

      float_classes = classes_from(Float)
      qat_classes = classes_from(QAT)
      quantized_classes = classes_from(Quantized)

      float_to_qat = {src: dst for src, dst in zip(float_classes, qat_classes)}
      qat_to_quantized = {
          src: dst for src, dst in zip(qat_classes, quantized_classes)
      }
      return float_to_qat, qat_to_quantized
  # Default conversion tables, built once when this module is imported.
  _float2qat_dict, _qat2quantized_dict = _get_convert_dict()
  # QAT module classes that have a quantized counterpart in the default table.
  qat_modules = tuple(_qat2quantized_dict)
  def quantize(module: Module, inplace: bool = True, mapping: dict = None):
      r"""
      Convert the :class:`~.QATModule` instances found under ``module`` to
      :class:`~.QuantizedModule`.

      Candidates are gathered with ``module._flatten``; each one is replaced by the
      result of ``from_qat_module`` called on the class mapped to its type.

      :param module: root module to convert.
      :param inplace: when ``False``, a deep copy of ``module`` is converted instead
          of ``module`` itself.
      :param mapping: optional dict from custom QATModule classes to QuantizedModule
          classes. Its entries are merged over the internal default mapping.
      :return: the converted root module.
      """
      root = module if inplace else deepcopy(module)

      table = copy(_qat2quantized_dict)
      if mapping is not None:
          table.update(mapping)
      qat_types = tuple(table.keys())

      # take a snapshot first, so replacements cannot influence successor modules
      matches = list(
          root._flatten(
              with_key=True,
              with_parent=True,
              predicate=lambda mod: isinstance(mod, qat_types),
          )
      )
      # NOTE(review): the predicate accepts subclasses of the mapped types, while
      # the lookup below requires the exact type to be a key of the table.
      for key, submodule, _parent in matches:
          quantized_cls = table[type(submodule)]
          set_expand_structure(root, key, quantized_cls.from_qat_module(submodule))
      return root
  def quantize_qat(
      module: Module,
      inplace: bool = True,
      qconfig: QConfig = ema_fakequant_qconfig,
      mapping: dict = None,
  ):
      r"""
      Convert the float :class:`~.Module` instances found under ``module`` to
      :class:`~.QATModule` and then propagate ``qconfig`` over the result.

      Candidates are gathered with ``module._flatten``. A candidate is skipped when
      its parent is itself quantable or when its ``quantize_disabled`` flag is set.

      :param module: root module to convert.
      :param inplace: when ``False``, a deep copy of ``module`` is converted instead
          of ``module`` itself.
      :param qconfig: an instance of :class:`~.QConfig` to be set as submodules'
          qconfig. Default is ``ema_fakequant_qconfig``.
      :param mapping: optional dict from custom Module classes to QATModule classes.
          Its entries are merged over the internal default mapping.
      :return: the converted root module.
      """
      root = module if inplace else deepcopy(module)

      table = copy(_float2qat_dict)
      if mapping is not None:
          table.update(mapping)
      float_types = tuple(table.keys())

      def has_qat_counterpart(mod: Module):
          return isinstance(mod, float_types)

      # take a snapshot first, so replacements cannot influence successor modules
      candidates = list(
          root._flatten(
              with_key=True, with_parent=True, predicate=has_qat_counterpart
          )
      )
      for key, submodule, parent in candidates:
          # only the top quantable module is converted, and only when not opted out
          if not has_qat_counterpart(parent) and not submodule.quantize_disabled:
              qat_cls = table[type(submodule)]
              set_expand_structure(root, key, qat_cls.from_float_module(submodule))

      propagate_qconfig(root, qconfig)
      return root
  def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
      r"""
      Rebuild the observers and fake-quant objects of every :class:`~.QATModule`
      found under ``module`` from the factories stored in ``qconfig``.

      The qparams currently held by each module are read before its objects are
      replaced, and handed to every new object that exposes ``set_qparams``.

      :param module: root module to reset.
      :param qconfig: an instance of :class:`~.QConfig` providing the factories.
      :param inplace: when ``False``, a deep copy of ``module`` is reset instead of
          ``module`` itself.
      :return: the reset root module.
      """
      root = module if inplace else deepcopy(module)

      def build(factory, qparams):
          # a missing factory means the slot is reset to ``None``
          if factory is None:
              return None
          instance = factory()
          if instance is None:
              return None
          if getattr(instance, "set_qparams", None) is not None:
              instance.set_qparams(qparams)
          return instance

      qat_submodules = list(
          root._flatten(predicate=lambda mod: isinstance(mod, QATModule))
      )
      for qat_mod in qat_submodules:
          if qat_mod.with_weight:
              weight_qparams = qat_mod.get_weight_qparams()
              qat_mod.weight_observer = build(qconfig.weight_observer, weight_qparams)
              qat_mod.weight_fake_quant = build(
                  qconfig.weight_fake_quant, weight_qparams
              )
          if qat_mod.with_act:
              act_qparams = qat_mod.get_activation_qparams()
              qat_mod.act_observer = build(qconfig.act_observer, act_qparams)
              qat_mod.act_fake_quant = build(qconfig.act_fake_quant, act_qparams)
      return root
  def _propagate(module: Module, func_str: str, *args, **kargs):
      r"""
      Call the method named ``func_str`` on every :class:`~.QATModule` visited by
      ``module.apply``.

      :param module: root module handed to :meth:`~.Module.apply`.
      :param func_str: name of the method to look up on each QATModule.
      :param args: positional arguments forwarded to that method.
      :param kargs: keyword arguments forwarded to that method.
      """

      def call_if_qat(mod: Module):
          # modules that are not QAT modules are left alone
          if not isinstance(mod, QATModule):
              return
          method = getattr(mod, func_str)
          method(*args, **kargs)

      module.apply(call_if_qat)
  def propagate_qconfig(module: QATModule, qconfig: QConfig):
      r"""
      Set ``qconfig`` on ``module`` recursively by calling ``set_qconfig`` on each
      :class:`~.QATModule` it contains.

      :param module: root module to traverse recursively.
      :param qconfig: an instance of :class:`~.QConfig` to be set as submodules'
          qconfig.
      """
      method_name = "set_qconfig"
      _propagate(module, method_name, qconfig)
  def hook_qat_module(module: Module, func: Callable):
      r"""
      Register ``func`` as a forward hook on every :class:`~.QATModule` found under
      ``module``.

      :param module: root module searched with ``module._flatten``.
      :param func: callable passed to ``register_forward_hook`` of each QATModule.
      :return: list of the values returned by ``register_forward_hook``, one per
          hooked module.
      """
      # materialize the matches before registering anything
      qat_submodules = list(
          module._flatten(predicate=lambda mod: isinstance(mod, QATModule))
      )
      return [qat_mod.register_forward_hook(func) for qat_mod in qat_submodules]
  def apply_easy_quant(
      module: Module, data: Tensor, start: float = 0.8, stop: float = 1.2, num: int = 40
  ):
      r"""
      Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669.
      Search for optimal scales.

      The input is concatenated with itself and run through ``module`` twice: once
      with hooks searching the ``weight_observer`` scales, then once with hooks
      searching the ``act_observer`` scales.

      :param module: root module.
      :param data: input tensor used to search optimal scale.
      :param start: lower bound of the search interval, as a factor of the
          observer's ``orig_scale``.
      :param stop: upper bound of the search interval, as a factor of the
          observer's ``orig_scale``.
      :param num: number of samples to search.
      :return: ``module``.
      """
      # size of the original input, read before the input is doubled below
      batch_size = data.shape[0]

      def get_cosine(x, y):
          # cosine similarity over all axes but the first, averaged over axis 0
          reduce_axes = tuple(range(1, len(x.shape)))
          dot = (x * y).sum(axis=reduce_axes)
          magnitude = norm(x, axis=reduce_axes) * norm(y, axis=reduce_axes)
          return (dot / magnitude).mean(axis=0)

      def search(mod, inputs, outputs, where):
          # drop the forward hooks of ``mod`` so the calls below do not re-enter
          mod._forward_hooks.clear()

          reference_in = [tensor[:batch_size] for tensor in inputs]
          candidate_in = [tensor[batch_size:] for tensor in inputs]

          # reference output is computed with fake quantization switched off
          disable_fake_quant(mod)
          reference_out = mod(*reference_in)
          enable_fake_quant(mod)

          observer = getattr(mod, where)
          if observer is None:
              return

          base_scale = observer.orig_scale
          best_cosine = 0
          best_scale = 0
          for scale in np.linspace(start * base_scale, stop * base_scale, num):
              observer.scale = scale
              candidate_out = mod(*candidate_in)
              similarity = get_cosine(reference_out, candidate_out)
              if similarity > best_cosine:
                  best_cosine = similarity
                  best_scale = scale

          if best_scale == 0:
              # NOTE(review): the observer keeps the last scale tried in this case
              logger.warning("EasyQuant finds no better scale")
          else:
              observer.scale = best_scale

          # reference result followed by the tail of the hook's original output
          return concat([reference_out, outputs[batch_size:]])

      doubled = concat([data, data])
      for observer_attr in ("weight_observer", "act_observer"):
          hook_qat_module(module, partial(search, where=observer_attr))
          module(doubled)
      return module
  def disable_fake_quant(module: Module):
      r"""
      Turn fake quantization off for every QATModule under ``module`` by calling
      ``set_fake_quant(False)`` on each of them.

      :param module: root module to traverse recursively.
      """
      enabled = False
      _propagate(module, "set_fake_quant", enabled)
  def disable_observer(module: Module):
      r"""
      Turn observers off for every QATModule under ``module`` by calling
      ``set_observer(False)`` on each of them.

      :param module: root module to traverse recursively.
      """
      enabled = False
      _propagate(module, "set_observer", enabled)
  def enable_fake_quant(module: Module):
      r"""
      Turn fake quantization on for every QATModule under ``module`` by calling
      ``set_fake_quant(True)`` on each of them.

      :param module: root module to traverse recursively.
      """
      enabled = True
      _propagate(module, "set_fake_quant", enabled)
  def enable_observer(module: Module):
      r"""
      Turn observers on for every QATModule under ``module`` by calling
      ``set_observer(True)`` on each of them.

      :param module: root module to traverse recursively.
      """
      enabled = True
      _propagate(module, "set_observer", enabled)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台