Browse Source

fix(mge/quantization): fix QATModule filter in `reset_qconfig` and `hook_qat_module`

GitOrigin-RevId: 92e9f36ca4
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
4130dcd355
1 changed files with 6 additions and 4 deletions
  1. +6
    -4
      imperative/python/megengine/quantization/quantize.py

+ 6
- 4
imperative/python/megengine/quantization/quantize.py View File

@@ -51,10 +51,6 @@ _float2qat_dict, _qat2quantized_dict = _get_convert_dict()
qat_modules = tuple(_qat2quantized_dict.keys())


def is_qat(mod: Module):
return isinstance(mod, qat_modules)


def quantize(module: Module, inplace: bool = True, mapping: dict = None):
r"""
Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule`
@@ -157,6 +153,9 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
inst.set_qparams(q_dict)
return inst

def is_qat(mod: Module):
return isinstance(mod, QATModule)

for m in list(module._flatten(predicate=is_qat)):
if m.with_weight:
weight_q_dict = m.get_weight_qparams()
@@ -193,6 +192,9 @@ def hook_qat_module(module: Module, func: Callable):
Add hooks for all :class:`~.QATModule` submodule
"""

def is_qat(mod: Module):
return isinstance(mod, QATModule)

hooks = []
for submodule in list(module._flatten(predicate=is_qat)):
hooks.append(submodule.register_forward_hook(func))


Loading…
Cancel
Save