|
|
@@ -51,10 +51,6 @@ _float2qat_dict, _qat2quantized_dict = _get_convert_dict() |
|
|
|
qat_modules = tuple(_qat2quantized_dict.keys()) |
|
|
|
|
|
|
|
|
|
|
|
def is_qat(mod: Module): |
|
|
|
return isinstance(mod, qat_modules) |
|
|
|
|
|
|
|
|
|
|
|
def quantize(module: Module, inplace: bool = True, mapping: dict = None): |
|
|
|
r""" |
|
|
|
Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule` |
|
|
@@ -157,6 +153,9 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True): |
|
|
|
inst.set_qparams(q_dict) |
|
|
|
return inst |
|
|
|
|
|
|
|
def is_qat(mod: Module): |
|
|
|
return isinstance(mod, QATModule) |
|
|
|
|
|
|
|
for m in list(module._flatten(predicate=is_qat)): |
|
|
|
if m.with_weight: |
|
|
|
weight_q_dict = m.get_weight_qparams() |
|
|
@@ -193,6 +192,9 @@ def hook_qat_module(module: Module, func: Callable): |
|
|
|
Add hooks for all :class:`~.QATModule` submodule |
|
|
|
""" |
|
|
|
|
|
|
|
def is_qat(mod: Module): |
|
|
|
return isinstance(mod, QATModule) |
|
|
|
|
|
|
|
hooks = [] |
|
|
|
for submodule in list(module._flatten(predicate=is_qat)): |
|
|
|
hooks.append(submodule.register_forward_hook(func)) |
|
|
|