GitOrigin-RevId: f16fbba2b7
tags/v0.5.0
@@ -496,8 +496,11 @@ class QATModule(Module): | |||||
self, target: Tensor, fq: "FakeQuantize", obs: "Observer" | self, target: Tensor, fq: "FakeQuantize", obs: "Observer" | ||||
): | ): | ||||
oup = self.apply_observer(target, obs) | oup = self.apply_observer(target, obs) | ||||
scale, zero_point = obs.get_qparams() | |||||
return fq(oup, scale, zero_point) | |||||
if self.quantizing == self.QATMode.CALIBRATION: | |||||
return oup | |||||
else: | |||||
scale, zero_point = obs.get_qparams() | |||||
return fq(oup, scale, zero_point) | |||||
def set_qat_mode(self, mode: QATMode): | def set_qat_mode(self, mode: QATMode): | ||||
r""" | r""" | ||||
@@ -524,11 +527,7 @@ class QATModule(Module): | |||||
""" | """ | ||||
def __call__(self, *args, **kwargs): | def __call__(self, *args, **kwargs): | ||||
if self.quantizing == self.QATMode.QAT: | |||||
return self.forward_qat(*args, **kwargs) | |||||
elif self.quantizing == self.QATMode.CALIBRATION: | |||||
# TODO implement the CALIBRATION | |||||
assert False | |||||
return None | |||||
else: | |||||
if self.quantizing == self.QATMode.DISABLED: | |||||
return self.forward(*args, **kwargs) | return self.forward(*args, **kwargs) | ||||
else: | |||||
return self.forward_qat(*args, **kwargs) |
@@ -20,11 +20,9 @@ class Concat(Module): | |||||
A :class:`~.Module` to do quantized concat, inference only. | A :class:`~.Module` to do quantized concat, inference only. | ||||
""" | """ | ||||
def __init__(self): | |||||
def __init__(self, dtype=None): | |||||
super().__init__() | super().__init__() | ||||
self.scale = 1.0 | |||||
self.zero_point = 0.0 | |||||
self.output_dtype = mgb.dtype.qint8(self.scale) | |||||
self.output_dtype = dtype | |||||
def forward(self, inps: Iterable[Tensor], axis: int = 0): | def forward(self, inps: Iterable[Tensor], axis: int = 0): | ||||
if self.training: | if self.training: | ||||
@@ -39,7 +37,4 @@ def to_quantized(float_module): | |||||
Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | ||||
implemented here to avoid circular import. | implemented here to avoid circular import. | ||||
""" | """ | ||||
qmod = Concat() | |||||
qmod.output_dtype = float_module.act_observer.get_dtype() | |||||
qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() | |||||
return qmod | |||||
return Concat(float_module.act_observer.get_dtype()) |
@@ -34,6 +34,7 @@ class _ConvBnActivation2d(Conv2d): | |||||
groups: int = 1, | groups: int = 1, | ||||
conv_mode: str = "CROSS_CORRELATION", | conv_mode: str = "CROSS_CORRELATION", | ||||
compute_mode: str = "DEFAULT", | compute_mode: str = "DEFAULT", | ||||
dtype=None, | |||||
): | ): | ||||
super().__init__( | super().__init__( | ||||
in_channels, | in_channels, | ||||
@@ -47,11 +48,7 @@ class _ConvBnActivation2d(Conv2d): | |||||
conv_mode, | conv_mode, | ||||
compute_mode, | compute_mode, | ||||
) | ) | ||||
self.scale = 1.0 | |||||
self.zero_point = 0.0 | |||||
self.output_dtype = mgb.dtype.qint8(self.scale) | |||||
self.weight = self.weight.astype(self.output_dtype) | |||||
self.bias = self.bias.astype(mgb.dtype.qint32(self.scale)) | |||||
self.output_dtype = dtype | |||||
def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"): | def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"): | ||||
inp_scale = mgb.dtype.get_scale(inp.dtype) | inp_scale = mgb.dtype.get_scale(inp.dtype) | ||||
@@ -87,6 +84,7 @@ class ConvBnRelu2d(_ConvBnActivation2d): | |||||
def to_quantized(quantized_class, float_module): | def to_quantized(quantized_class, float_module): | ||||
output_dtype = float_module.act_observer.get_dtype() | |||||
qconv = quantized_class( | qconv = quantized_class( | ||||
float_module.conv.in_channels, | float_module.conv.in_channels, | ||||
float_module.conv.out_channels, | float_module.conv.out_channels, | ||||
@@ -95,15 +93,14 @@ def to_quantized(quantized_class, float_module): | |||||
float_module.conv.padding, | float_module.conv.padding, | ||||
float_module.conv.dilation, | float_module.conv.dilation, | ||||
float_module.conv.groups, | float_module.conv.groups, | ||||
dtype=output_dtype, | |||||
) | ) | ||||
w_fold, b_fold = float_module.fold_weight_bias( | w_fold, b_fold = float_module.fold_weight_bias( | ||||
float_module.bn.running_mean, float_module.bn.running_var | float_module.bn.running_mean, float_module.bn.running_var | ||||
) | ) | ||||
weight = w_fold.astype(float_module.weight_observer.get_dtype()) | weight = w_fold.astype(float_module.weight_observer.get_dtype()) | ||||
qconv.output_dtype = float_module.act_observer.get_dtype() | |||||
qconv.weight = Parameter(weight.numpy()) | qconv.weight = Parameter(weight.numpy()) | ||||
qconv.bias = Parameter(b_fold.numpy()) | qconv.bias = Parameter(b_fold.numpy()) | ||||
qconv.scale, qconv.zero_point = float_module.act_observer.get_qparams() | |||||
return qconv | return qconv | ||||
@@ -34,12 +34,10 @@ class Elemwise(Module): | |||||
_elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode | _elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode | ||||
def __init__(self, method): | |||||
def __init__(self, method, dtype=None): | |||||
super().__init__() | super().__init__() | ||||
self.method = self._elemwise_multi_type_mode.convert("Q" + method) | self.method = self._elemwise_multi_type_mode.convert("Q" + method) | ||||
self.scale = 1.0 | |||||
self.zero_point = 0.0 | |||||
self.output_dtype = mgb.dtype.qint8(self.scale) | |||||
self.output_dtype = dtype | |||||
def forward(self, *inps): | def forward(self, *inps): | ||||
if self.training: | if self.training: | ||||
@@ -53,7 +51,4 @@ def to_quantized(float_module): | |||||
Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | ||||
implemented here to avoid circular import. | implemented here to avoid circular import. | ||||
""" | """ | ||||
qmod = Elemwise(float_module.method.name) | |||||
qmod.output_dtype = float_module.act_observer.get_dtype() | |||||
qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() | |||||
return qmod | |||||
return Elemwise(float_module.method.name, float_module.act_observer.get_dtype()) |
@@ -16,11 +16,9 @@ class QuantStub(Module): | |||||
A helper quantize operation on input and inference only. | A helper quantize operation on input and inference only. | ||||
""" | """ | ||||
def __init__(self): | |||||
def __init__(self, dtype=None): | |||||
super().__init__() | super().__init__() | ||||
self.scale = 1.0 | |||||
self.zero_point = 0.0 | |||||
self.output_dtype = mgb.dtype.qint8(self.scale) | |||||
self.output_dtype = dtype | |||||
def forward(self, inp): | def forward(self, inp): | ||||
if self.training: | if self.training: | ||||
@@ -45,10 +43,7 @@ def to_quantized(float_module): | |||||
Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | ||||
implemented here to avoid circular import. | implemented here to avoid circular import. | ||||
""" | """ | ||||
qmod = QuantStub() | |||||
qmod.output_dtype = float_module.act_observer.get_dtype() | |||||
qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() | |||||
return qmod | |||||
return QuantStub(float_module.act_observer.get_dtype()) | |||||
@register_method_to_class(Float.DequantStub) | @register_method_to_class(Float.DequantStub) | ||||
@@ -57,5 +52,4 @@ def to_quantized(float_module): | |||||
Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | ||||
implemented here to avoid circular import. | implemented here to avoid circular import. | ||||
""" | """ | ||||
qmod = DequantStub() | |||||
return qmod | |||||
return DequantStub() |
@@ -14,5 +14,6 @@ from .quantize import ( | |||||
enable_fake_quant, | enable_fake_quant, | ||||
enable_observer, | enable_observer, | ||||
quantize, | quantize, | ||||
quantize_calibration, | |||||
quantize_qat, | quantize_qat, | ||||
) | ) |
@@ -11,7 +11,7 @@ import numpy as np | |||||
from .. import functional as F | from .. import functional as F | ||||
from .._internal.dtype import _metadata_dict, get_quantized_dtype | from .._internal.dtype import _metadata_dict, get_quantized_dtype | ||||
from ..core import Buffer, Function, ones, tensor, zeros | |||||
from ..core import Buffer, Function, tensor | |||||
from ..module import Module | from ..module import Module | ||||
@@ -34,6 +34,8 @@ def quantize(module: Module, inplace=True): | |||||
else: | else: | ||||
setattr(parent, key.split(".")[-1], submodule.to_quantized()) | setattr(parent, key.split(".")[-1], submodule.to_quantized()) | ||||
return module | |||||
def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig): | def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig): | ||||
r""" | r""" | ||||
@@ -53,6 +55,25 @@ def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig): | |||||
module.apply(fn) | module.apply(fn) | ||||
def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfig): | |||||
r""" | |||||
Recursively convert `module` to `calibration` mode through :meth:`~.Module.apply` | |||||
and set qconfig relatively. | |||||
:param module: root module to do convert recursively. | |||||
:param qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig. | |||||
default is :any:`~.qconfig.ema_fakequant_qconfig`. | |||||
""" | |||||
def fn(mod: Module): | |||||
if isinstance(mod, QATModule): | |||||
mod.set_qat_mode(QATModule.QATMode.CALIBRATION) | |||||
mod.set_qconfig(qconfig) | |||||
module.apply(fn) | |||||
enable_observer(module) | |||||
def disable_fake_quant(module: Module): | def disable_fake_quant(module: Module): | ||||
r""" | r""" | ||||
Recursively disable `module` fake quantization in QATModule through :meth:`~.Module.apply` | Recursively disable `module` fake quantization in QATModule through :meth:`~.Module.apply` | ||||