Browse Source

feat(mge/quantization): add `quantize_disabled` attribute in Module

GitOrigin-RevId: f108f03c5a
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
ab91302515
3 changed files with 25 additions and 14 deletions
  1. +11
    -0
      python_module/megengine/module/module.py
  2. +0
    -2
      python_module/megengine/module/qat/module.py
  3. +14
    -12
      python_module/megengine/quantization/quantize.py

+ 11
- 0
python_module/megengine/module/module.py View File

@@ -57,6 +57,7 @@ class Module(metaclass=ABCMeta):

def __init__(self):
self.training = True
self.quantize_diabled = False

@abstractmethod
def forward(self, inputs):
@@ -312,6 +313,16 @@ class Module(metaclass=ABCMeta):
"""
self.train(False)

def disable_quantize(self, value=True):
r"""
Set ``module``'s ``quantize_diabled`` attribute and return ``module``.
Could be used as a decorator.
"""
def fn(module: Module) -> None:
module.quantize_diabled = value

self.apply(fn)

def state_dict(self, rst=None, prefix="", keep_var=False):
r"""Returns a dictionary containing whole states of the module.
"""


+ 0
- 2
python_module/megengine/module/qat/module.py View File

@@ -26,8 +26,6 @@ class QATModule(Module):
def __init__(self):
super().__init__()

self.scale = None

self.weight_observer = None # type: Observer
self.act_observer = None # type: Observer



+ 14
- 12
python_module/megengine/quantization/quantize.py View File

@@ -6,7 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy
from typing import Dict, Tuple
from typing import Callable, Dict, Tuple

from .. import module as Float
from ..module import Module
@@ -48,7 +48,7 @@ def _get_convert_dict() -> Tuple[
_float2qat_dict, _qat2quantized_dict = _get_convert_dict()


def quantize(module: Module, inplace=True):
def quantize(module: Module, inplace: bool = True):
r"""
Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule`
through :meth:`~.Module.apply`.
@@ -80,7 +80,9 @@ def quantize(module: Module, inplace=True):


def quantize_qat(
module: Module, inplace=True, qconfig: QConfig = ema_fakequant_qconfig
module: Module,
inplace: bool = True,
qconfig: QConfig = ema_fakequant_qconfig,
):
r"""
Recursively convert float :class:`~.Module` to :class:`~.QATModule`
@@ -105,7 +107,7 @@ def quantize_qat(
module._flatten(with_key=True, with_parent=True, predicate=is_quantable)
):
# only convert top quantable module.
if is_quantable(parent):
if is_quantable(parent) or submodule.quantize_diabled:
continue

new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule)
@@ -136,12 +138,12 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig):

def disable_fake_quant(module: Module):
r"""
Recursively disable `module` fake quantization in QATModule through :meth:`~.Module.apply`
Recursively disable ``module`` fake quantization in QATModule through :meth:`~.Module.apply`

:param module: root module to do disable fake quantization recursively.
"""

def fn(mod):
def fn(mod: Module):
if isinstance(mod, QATModule):
mod.act_fake_quant.disable()
mod.weight_fake_quant.disable()
@@ -151,12 +153,12 @@ def disable_fake_quant(module: Module):

def disable_observer(module: Module):
r"""
Recursively disable `module` observer in QATModule through :meth:`~.Module.apply`
Recursively disable ``module`` observer in QATModule through :meth:`~.Module.apply`

:param module: root module to do disable observer recursively.
"""

def fn(mod):
def fn(mod: Module):
if isinstance(mod, QATModule):
mod.act_observer.disable()
mod.weight_observer.disable()
@@ -166,12 +168,12 @@ def disable_observer(module: Module):

def enable_fake_quant(module: Module):
r"""
Recursively enable `module` fake quantization in QATModule through :meth:`~.Module.apply`
Recursively enable ``module`` fake quantization in QATModule through :meth:`~.Module.apply`

:param module: root module to do enable fake quantization recursively.
"""

def fn(mod):
def fn(mod: Module):
if isinstance(mod, QATModule):
mod.act_fake_quant.enable()
mod.weight_fake_quant.enable()
@@ -181,12 +183,12 @@ def enable_fake_quant(module: Module):

def enable_observer(module: Module):
r"""
Recursively enable `module` observer in QATModule through :meth:`~.Module.apply`
Recursively enable ``module`` observer in QATModule through :meth:`~.Module.apply`

:param module: root module to do enable observer recursively.
"""

def fn(mod):
def fn(mod: Module):
if isinstance(mod, QATModule):
mod.act_observer.enable()
mod.weight_observer.enable()


Loading…
Cancel
Save