|
|
@@ -6,7 +6,7 @@ |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
from copy import deepcopy |
|
|
|
from typing import Dict, Tuple |
|
|
|
from typing import Callable, Dict, Tuple |
|
|
|
|
|
|
|
from .. import module as Float |
|
|
|
from ..module import Module |
|
|
@@ -48,7 +48,7 @@ def _get_convert_dict() -> Tuple[ |
|
|
|
_float2qat_dict, _qat2quantized_dict = _get_convert_dict() |
|
|
|
|
|
|
|
|
|
|
|
def quantize(module: Module, inplace=True): |
|
|
|
def quantize(module: Module, inplace: bool = True): |
|
|
|
r""" |
|
|
|
Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule` |
|
|
|
through :meth:`~.Module.apply`. |
|
|
@@ -80,7 +80,9 @@ def quantize(module: Module, inplace=True): |
|
|
|
|
|
|
|
|
|
|
|
def quantize_qat( |
|
|
|
module: Module, inplace=True, qconfig: QConfig = ema_fakequant_qconfig |
|
|
|
module: Module, |
|
|
|
inplace: bool = True, |
|
|
|
qconfig: QConfig = ema_fakequant_qconfig, |
|
|
|
): |
|
|
|
r""" |
|
|
|
Recursively convert float :class:`~.Module` to :class:`~.QATModule` |
|
|
@@ -105,7 +107,7 @@ def quantize_qat( |
|
|
|
module._flatten(with_key=True, with_parent=True, predicate=is_quantable) |
|
|
|
): |
|
|
|
# only convert top quantable module. |
|
|
|
if is_quantable(parent): |
|
|
|
if is_quantable(parent) or submodule.quantize_diabled: |
|
|
|
continue |
|
|
|
|
|
|
|
new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule) |
|
|
@@ -136,12 +138,12 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig): |
|
|
|
|
|
|
|
def disable_fake_quant(module: Module): |
|
|
|
r""" |
|
|
|
Recursively disable `module` fake quantization in QATModule through :meth:`~.Module.apply` |
|
|
|
Recursively disable ``module`` fake quantization in QATModule through :meth:`~.Module.apply` |
|
|
|
|
|
|
|
:param module: root module to do disable fake quantization recursively. |
|
|
|
""" |
|
|
|
|
|
|
|
def fn(mod): |
|
|
|
def fn(mod: Module): |
|
|
|
if isinstance(mod, QATModule): |
|
|
|
mod.act_fake_quant.disable() |
|
|
|
mod.weight_fake_quant.disable() |
|
|
@@ -151,12 +153,12 @@ def disable_fake_quant(module: Module): |
|
|
|
|
|
|
|
def disable_observer(module: Module): |
|
|
|
r""" |
|
|
|
Recursively disable `module` observer in QATModule through :meth:`~.Module.apply` |
|
|
|
Recursively disable ``module`` observer in QATModule through :meth:`~.Module.apply` |
|
|
|
|
|
|
|
:param module: root module to do disable observer recursively. |
|
|
|
""" |
|
|
|
|
|
|
|
def fn(mod): |
|
|
|
def fn(mod: Module): |
|
|
|
if isinstance(mod, QATModule): |
|
|
|
mod.act_observer.disable() |
|
|
|
mod.weight_observer.disable() |
|
|
@@ -166,12 +168,12 @@ def disable_observer(module: Module): |
|
|
|
|
|
|
|
def enable_fake_quant(module: Module): |
|
|
|
r""" |
|
|
|
Recursively enable `module` fake quantization in QATModule through :meth:`~.Module.apply` |
|
|
|
Recursively enable ``module`` fake quantization in QATModule through :meth:`~.Module.apply` |
|
|
|
|
|
|
|
:param module: root module to do enable fake quantization recursively. |
|
|
|
""" |
|
|
|
|
|
|
|
def fn(mod): |
|
|
|
def fn(mod: Module): |
|
|
|
if isinstance(mod, QATModule): |
|
|
|
mod.act_fake_quant.enable() |
|
|
|
mod.weight_fake_quant.enable() |
|
|
@@ -181,12 +183,12 @@ def enable_fake_quant(module: Module): |
|
|
|
|
|
|
|
def enable_observer(module: Module): |
|
|
|
r""" |
|
|
|
Recursively enable `module` observer in QATModule through :meth:`~.Module.apply` |
|
|
|
Recursively enable ``module`` observer in QATModule through :meth:`~.Module.apply` |
|
|
|
|
|
|
|
:param module: root module to do enable observer recursively. |
|
|
|
""" |
|
|
|
|
|
|
|
def fn(mod): |
|
|
|
def fn(mod: Module): |
|
|
|
if isinstance(mod, QATModule): |
|
|
|
mod.act_observer.enable() |
|
|
|
mod.weight_observer.enable() |
|
|
|