Browse Source

feat(mge/quantization): add `mapping` parameter for custom modules

GitOrigin-RevId: a4de4261d0
tags/v1.0.0-rc1
Megvii Engine Team Xinran Xu 4 years ago
parent
commit
22853fa20c
3 changed files with 50 additions and 11 deletions
  1. +3
    -3
      python_module/megengine/module/module.py
  2. +21
    -8
      python_module/megengine/quantization/quantize.py
  3. +26
    -0
      python_module/test/unit/quantization/quantize.py

+ 3
- 3
python_module/megengine/module/module.py View File

@@ -60,7 +60,7 @@ class Module(metaclass=ABCMeta):
def __init__(self): def __init__(self):
# runtime attributes # runtime attributes
self.training = True self.training = True
self.quantize_diabled = False
self.quantize_disabled = False


# hooks # hooks
self._forward_pre_hooks = OrderedDict() self._forward_pre_hooks = OrderedDict()
@@ -328,12 +328,12 @@ class Module(metaclass=ABCMeta):


def disable_quantize(self, value=True): def disable_quantize(self, value=True):
r""" r"""
Set ``module``'s ``quantize_diabled`` attribute and return ``module``.
Set ``module``'s ``quantize_disabled`` attribute and return ``module``.
Could be used as a decorator. Could be used as a decorator.
""" """


def fn(module: Module) -> None: def fn(module: Module) -> None:
module.quantize_diabled = value
module.quantize_disabled = value


self.apply(fn) self.apply(fn)




+ 21
- 8
python_module/megengine/quantization/quantize.py View File

@@ -5,7 +5,7 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy
from copy import copy, deepcopy
from typing import Callable, Dict, Tuple from typing import Callable, Dict, Tuple


from .. import module as Float from .. import module as Float
@@ -49,19 +49,24 @@ def _get_convert_dict() -> Tuple[
_float2qat_dict, _qat2quantized_dict = _get_convert_dict() _float2qat_dict, _qat2quantized_dict = _get_convert_dict()




def quantize(module: Module, inplace: bool = True):
def quantize(module: Module, inplace: bool = True, mapping: dict = None):
r""" r"""
Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule` Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule`
through :meth:`~.Module.apply`. through :meth:`~.Module.apply`.


:param module: root module to do convert recursively. :param module: root module to do convert recursively.
:param inplace: whether to convert submodules in-place. :param inplace: whether to convert submodules in-place.
:param mapping: a dict indicating how to convert custom modules from QATModule to
QuantizedModule. Will be combined with internal default convert mapping dict.
""" """


if not inplace: if not inplace:
module = deepcopy(module) module = deepcopy(module)


qat_modules = tuple(_qat2quantized_dict.keys())
convert_dict = copy(_qat2quantized_dict)
if mapping is not None:
convert_dict.update(mapping)
qat_modules = tuple(convert_dict.keys())


def is_qat(mod: Module): def is_qat(mod: Module):
return isinstance(mod, qat_modules) return isinstance(mod, qat_modules)
@@ -70,7 +75,7 @@ def quantize(module: Module, inplace: bool = True):
for key, submodule, parent in list( for key, submodule, parent in list(
module._flatten(with_key=True, with_parent=True, predicate=is_qat) module._flatten(with_key=True, with_parent=True, predicate=is_qat)
): ):
new_mod = _qat2quantized_dict[type(submodule)].from_qat_module(submodule)
new_mod = convert_dict[type(submodule)].from_qat_module(submodule)
if isinstance(parent, Float.Sequential): if isinstance(parent, Float.Sequential):
# cannnot use setattr to be compatible with Sequential's ``__setitem__`` # cannnot use setattr to be compatible with Sequential's ``__setitem__``
parent[int(key.split(".")[-1])] = new_mod parent[int(key.split(".")[-1])] = new_mod
@@ -81,7 +86,10 @@ def quantize(module: Module, inplace: bool = True):




def quantize_qat( def quantize_qat(
module: Module, inplace: bool = True, qconfig: QConfig = ema_fakequant_qconfig,
module: Module,
inplace: bool = True,
qconfig: QConfig = ema_fakequant_qconfig,
mapping: dict = None,
): ):
r""" r"""
Recursively convert float :class:`~.Module` to :class:`~.QATModule` Recursively convert float :class:`~.Module` to :class:`~.QATModule`
@@ -91,12 +99,17 @@ def quantize_qat(
:param inplace: whether to convert submodules in-place. :param inplace: whether to convert submodules in-place.
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
default is ``ema_fakequant_qconfig``. default is ``ema_fakequant_qconfig``.
:param mapping: a dict indicating how to convert custom modules from Module to QATModule.
Will be combined with internal default convert mapping dict.
""" """


if not inplace: if not inplace:
module = deepcopy(module) module = deepcopy(module)


quantable_modules = tuple(_float2qat_dict.keys())
convert_dict = copy(_float2qat_dict)
if mapping is not None:
convert_dict.update(mapping)
quantable_modules = tuple(convert_dict.keys())


def is_quantable(mod: Module): def is_quantable(mod: Module):
return isinstance(mod, quantable_modules) return isinstance(mod, quantable_modules)
@@ -106,10 +119,10 @@ def quantize_qat(
module._flatten(with_key=True, with_parent=True, predicate=is_quantable) module._flatten(with_key=True, with_parent=True, predicate=is_quantable)
): ):
# only convert top quantable module. # only convert top quantable module.
if is_quantable(parent) or submodule.quantize_diabled:
if is_quantable(parent) or submodule.quantize_disabled:
continue continue


new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule)
new_mod = convert_dict[type(submodule)].from_float_module(submodule)
if isinstance(parent, Float.Sequential): if isinstance(parent, Float.Sequential):
# cannnot use setattr to be compatible with Sequential's ``__setitem__`` # cannnot use setattr to be compatible with Sequential's ``__setitem__``
parent[int(key.split(".")[-1])] = new_mod parent[int(key.split(".")[-1])] = new_mod


+ 26
- 0
python_module/test/unit/quantization/quantize.py View File

@@ -52,3 +52,29 @@ def test_disable_quantize():
qat_net = quantize_qat(net, inplace=False) qat_net = quantize_qat(net, inplace=False)
assert isinstance(qat_net.conv, Float.ConvBnRelu2d) assert isinstance(qat_net.conv, Float.ConvBnRelu2d)
assert isinstance(qat_net.conv.conv, Float.Conv2d) assert isinstance(qat_net.conv.conv, Float.Conv2d)


def test_convert_with_custom_mapping():
class FloatExample(Float.Module):
def forward(self, x):
return x

class QATExample(QAT.QATModule):
def forward(self, x):
return x

@classmethod
def from_float_module(cls, float_module):
return cls()

class Net(Float.Module):
def __init__(self):
super().__init__()
self.example = FloatExample()

def forward(self, x):
return self.example(x)

net = Net()
qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
assert isinstance(qat_net.example, QATExample)

Loading…
Cancel
Save