Browse Source

feat(mge/quantization): add `mapping` parameter for custom modules

GitOrigin-RevId: a4de4261d0
tags/v1.0.0-rc1
Megvii Engine Team Xinran Xu 4 years ago
parent
commit
22853fa20c
3 changed files with 50 additions and 11 deletions
  1. +3
    -3
      python_module/megengine/module/module.py
  2. +21
    -8
      python_module/megengine/quantization/quantize.py
  3. +26
    -0
      python_module/test/unit/quantization/quantize.py

+ 3
- 3
python_module/megengine/module/module.py View File

@@ -60,7 +60,7 @@ class Module(metaclass=ABCMeta):
def __init__(self):
# runtime attributes
self.training = True
self.quantize_diabled = False
self.quantize_disabled = False

# hooks
self._forward_pre_hooks = OrderedDict()
@@ -328,12 +328,12 @@ class Module(metaclass=ABCMeta):

def disable_quantize(self, value=True):
r"""
Set ``module``'s ``quantize_diabled`` attribute and return ``module``.
Set ``module``'s ``quantize_disabled`` attribute and return ``module``.
Could be used as a decorator.
"""

def fn(module: Module) -> None:
module.quantize_diabled = value
module.quantize_disabled = value

self.apply(fn)



+ 21
- 8
python_module/megengine/quantization/quantize.py View File

@@ -5,7 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy
from copy import copy, deepcopy
from typing import Callable, Dict, Tuple

from .. import module as Float
@@ -49,19 +49,24 @@ def _get_convert_dict() -> Tuple[
_float2qat_dict, _qat2quantized_dict = _get_convert_dict()


def quantize(module: Module, inplace: bool = True):
def quantize(module: Module, inplace: bool = True, mapping: dict = None):
r"""
Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule`
through :meth:`~.Module.apply`.

:param module: root module to do convert recursively.
:param inplace: whether to convert submodules in-place.
:param mapping: a dict indicating how to convert custom modules from QATModule to
QuantizedModule. Will be combined with internal default convert mapping dict.
"""

if not inplace:
module = deepcopy(module)

qat_modules = tuple(_qat2quantized_dict.keys())
convert_dict = copy(_qat2quantized_dict)
if mapping is not None:
convert_dict.update(mapping)
qat_modules = tuple(convert_dict.keys())

def is_qat(mod: Module):
return isinstance(mod, qat_modules)
@@ -70,7 +75,7 @@ def quantize(module: Module, inplace: bool = True):
for key, submodule, parent in list(
module._flatten(with_key=True, with_parent=True, predicate=is_qat)
):
new_mod = _qat2quantized_dict[type(submodule)].from_qat_module(submodule)
new_mod = convert_dict[type(submodule)].from_qat_module(submodule)
if isinstance(parent, Float.Sequential):
# cannnot use setattr to be compatible with Sequential's ``__setitem__``
parent[int(key.split(".")[-1])] = new_mod
@@ -81,7 +86,10 @@ def quantize(module: Module, inplace: bool = True):


def quantize_qat(
module: Module, inplace: bool = True, qconfig: QConfig = ema_fakequant_qconfig,
module: Module,
inplace: bool = True,
qconfig: QConfig = ema_fakequant_qconfig,
mapping: dict = None,
):
r"""
Recursively convert float :class:`~.Module` to :class:`~.QATModule`
@@ -91,12 +99,17 @@ def quantize_qat(
:param inplace: whether to convert submodules in-place.
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
default is ``ema_fakequant_qconfig``.
:param mapping: a dict indicating how to convert custom modules from Module to QATModule.
Will be combined with internal default convert mapping dict.
"""

if not inplace:
module = deepcopy(module)

quantable_modules = tuple(_float2qat_dict.keys())
convert_dict = copy(_float2qat_dict)
if mapping is not None:
convert_dict.update(mapping)
quantable_modules = tuple(convert_dict.keys())

def is_quantable(mod: Module):
return isinstance(mod, quantable_modules)
@@ -106,10 +119,10 @@ def quantize_qat(
module._flatten(with_key=True, with_parent=True, predicate=is_quantable)
):
# only convert top quantable module.
if is_quantable(parent) or submodule.quantize_diabled:
if is_quantable(parent) or submodule.quantize_disabled:
continue

new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule)
new_mod = convert_dict[type(submodule)].from_float_module(submodule)
if isinstance(parent, Float.Sequential):
# cannnot use setattr to be compatible with Sequential's ``__setitem__``
parent[int(key.split(".")[-1])] = new_mod


+ 26
- 0
python_module/test/unit/quantization/quantize.py View File

@@ -52,3 +52,29 @@ def test_disable_quantize():
qat_net = quantize_qat(net, inplace=False)
assert isinstance(qat_net.conv, Float.ConvBnRelu2d)
assert isinstance(qat_net.conv.conv, Float.Conv2d)


def test_convert_with_custom_mapping():
class FloatExample(Float.Module):
def forward(self, x):
return x

class QATExample(QAT.QATModule):
def forward(self, x):
return x

@classmethod
def from_float_module(cls, float_module):
return cls()

class Net(Float.Module):
def __init__(self):
super().__init__()
self.example = FloatExample()

def forward(self, x):
return self.example(x)

net = Net()
qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
assert isinstance(qat_net.example, QATExample)

Loading…
Cancel
Save