Browse Source

fix(mge/quantization): fix QATModule filter in `reset_qconfig` and `hook_qat_module`

GitOrigin-RevId: 92e9f36ca4
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
4130dcd355
1 changed files with 6 additions and 4 deletions
  1. +6
    -4
      imperative/python/megengine/quantization/quantize.py

+ 6
- 4
imperative/python/megengine/quantization/quantize.py View File

@@ -51,10 +51,6 @@ _float2qat_dict, _qat2quantized_dict = _get_convert_dict()
qat_modules = tuple(_qat2quantized_dict.keys()) qat_modules = tuple(_qat2quantized_dict.keys())




def is_qat(mod: Module):
return isinstance(mod, qat_modules)


def quantize(module: Module, inplace: bool = True, mapping: dict = None): def quantize(module: Module, inplace: bool = True, mapping: dict = None):
r""" r"""
Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule` Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule`
@@ -157,6 +153,9 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
inst.set_qparams(q_dict) inst.set_qparams(q_dict)
return inst return inst


def is_qat(mod: Module):
return isinstance(mod, QATModule)

for m in list(module._flatten(predicate=is_qat)): for m in list(module._flatten(predicate=is_qat)):
if m.with_weight: if m.with_weight:
weight_q_dict = m.get_weight_qparams() weight_q_dict = m.get_weight_qparams()
@@ -193,6 +192,9 @@ def hook_qat_module(module: Module, func: Callable):
Add hooks for all :class:`~.QATModule` submodule Add hooks for all :class:`~.QATModule` submodule
""" """


def is_qat(mod: Module):
return isinstance(mod, QATModule)

hooks = [] hooks = []
for submodule in list(module._flatten(predicate=is_qat)): for submodule in list(module._flatten(predicate=is_qat)):
hooks.append(submodule.register_forward_hook(func)) hooks.append(submodule.register_forward_hook(func))


Loading…
Cancel
Save