|
|
@@ -5,7 +5,7 @@ |
|
|
|
# Unless required by applicable law or agreed to in writing, |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
from copy import deepcopy |
|
|
|
from copy import copy, deepcopy |
|
|
|
from typing import Callable, Dict, Tuple |
|
|
|
|
|
|
|
from .. import module as Float |
|
|
@@ -49,19 +49,24 @@ def _get_convert_dict() -> Tuple[ |
|
|
|
_float2qat_dict, _qat2quantized_dict = _get_convert_dict() |
|
|
|
|
|
|
|
|
|
|
|
def quantize(module: Module, inplace: bool = True): |
|
|
|
def quantize(module: Module, inplace: bool = True, mapping: dict = None): |
|
|
|
r""" |
|
|
|
Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule` |
|
|
|
through :meth:`~.Module.apply`. |
|
|
|
|
|
|
|
:param module: root module to do convert recursively. |
|
|
|
:param inplace: whether to convert submodules in-place. |
|
|
|
:param mapping: a dict indicating how to convert custom modules from QATModule to |
|
|
|
QuantizedModule. Will be combined with internal default convert mapping dict. |
|
|
|
""" |
|
|
|
|
|
|
|
if not inplace: |
|
|
|
module = deepcopy(module) |
|
|
|
|
|
|
|
qat_modules = tuple(_qat2quantized_dict.keys()) |
|
|
|
convert_dict = copy(_qat2quantized_dict) |
|
|
|
if mapping is not None: |
|
|
|
convert_dict.update(mapping) |
|
|
|
qat_modules = tuple(convert_dict.keys()) |
|
|
|
|
|
|
|
def is_qat(mod: Module): |
|
|
|
return isinstance(mod, qat_modules) |
|
|
@@ -70,7 +75,7 @@ def quantize(module: Module, inplace: bool = True): |
|
|
|
for key, submodule, parent in list( |
|
|
|
module._flatten(with_key=True, with_parent=True, predicate=is_qat) |
|
|
|
): |
|
|
|
new_mod = _qat2quantized_dict[type(submodule)].from_qat_module(submodule) |
|
|
|
new_mod = convert_dict[type(submodule)].from_qat_module(submodule) |
|
|
|
if isinstance(parent, Float.Sequential): |
|
|
|
# cannot use setattr to be compatible with Sequential's ``__setitem__`` |
|
|
|
parent[int(key.split(".")[-1])] = new_mod |
|
|
@@ -81,7 +86,10 @@ def quantize(module: Module, inplace: bool = True): |
|
|
|
|
|
|
|
|
|
|
|
def quantize_qat( |
|
|
|
module: Module, inplace: bool = True, qconfig: QConfig = ema_fakequant_qconfig, |
|
|
|
module: Module, |
|
|
|
inplace: bool = True, |
|
|
|
qconfig: QConfig = ema_fakequant_qconfig, |
|
|
|
mapping: dict = None, |
|
|
|
): |
|
|
|
r""" |
|
|
|
Recursively convert float :class:`~.Module` to :class:`~.QATModule` |
|
|
@@ -91,12 +99,17 @@ def quantize_qat( |
|
|
|
:param inplace: whether to convert submodules in-place. |
|
|
|
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. |
|
|
|
default is ``ema_fakequant_qconfig``. |
|
|
|
:param mapping: a dict indicating how to convert custom modules from Module to QATModule. |
|
|
|
Will be combined with internal default convert mapping dict. |
|
|
|
""" |
|
|
|
|
|
|
|
if not inplace: |
|
|
|
module = deepcopy(module) |
|
|
|
|
|
|
|
quantable_modules = tuple(_float2qat_dict.keys()) |
|
|
|
convert_dict = copy(_float2qat_dict) |
|
|
|
if mapping is not None: |
|
|
|
convert_dict.update(mapping) |
|
|
|
quantable_modules = tuple(convert_dict.keys()) |
|
|
|
|
|
|
|
def is_quantable(mod: Module): |
|
|
|
return isinstance(mod, quantable_modules) |
|
|
@@ -106,10 +119,10 @@ def quantize_qat( |
|
|
|
module._flatten(with_key=True, with_parent=True, predicate=is_quantable) |
|
|
|
): |
|
|
|
# only convert top quantable module. |
|
|
|
if is_quantable(parent) or submodule.quantize_diabled: |
|
|
|
if is_quantable(parent) or submodule.quantize_disabled: |
|
|
|
continue |
|
|
|
|
|
|
|
new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule) |
|
|
|
new_mod = convert_dict[type(submodule)].from_float_module(submodule) |
|
|
|
if isinstance(parent, Float.Sequential): |
|
|
|
# cannot use setattr to be compatible with Sequential's ``__setitem__`` |
|
|
|
parent[int(key.split(".")[-1])] = new_mod |
|
|
|