def forward_qat(self, x):
    """QAT forward pass: fake-quantize the weight, run the linear op,
    then fake-quantize the activation.

    Both fake-quant steps go through ``apply_fakequant_with_observer``
    so the observers see the float values while the returned tensors
    carry the simulated quantization error.
    """
    # Simulate weight quantization before the matmul.
    qweight = self.apply_fakequant_with_observer(
        self.weight, self.weight_fake_quant, self.weight_observer
    )
    # Bias is left in float here; quantization of bias happens at
    # conversion time (see the quantized module), not during QAT.
    out = self._calc_linear(x, qweight, self.bias)
    # Simulate activation quantization on the result.
    return self.apply_fakequant_with_observer(
        out, self.act_fake_quant, self.act_observer
    )
class Linear(Module):
    r"""Applies a quantized linear transformation to the input.

    Instances are normally produced by converting a QAT module via its
    ``to_quantized`` method rather than being constructed directly:
    the constructor leaves :attr:`weight` and :attr:`bias` as ``None``,
    and the converter fills them in with quantized parameters.

    :param dtype: quantized data type of the output tensor.
    """

    def __init__(
        self, dtype: np.dtype = None,
    ):
        super().__init__()
        # Populated by ``to_quantized``; a freshly constructed module
        # cannot run ``forward`` until these are assigned.
        self.weight = None
        self.bias = None
        self.output_dtype = dtype

    def forward(self, inp):
        # Quantized modules are inference-only: the integer kernels have
        # no gradient path, so running in training mode is an error.
        if self.training:
            raise ValueError("quantized module only support inference.")
        inp_scale = mgb.dtype.get_scale(inp.dtype)
        w_scale = mgb.dtype.get_scale(self.weight.dtype)
        if self.bias is None:
            bias = None
        else:
            # Bias is accumulated in int32 with scale = input_scale *
            # weight_scale, matching the product scale of the quantized
            # matmul; only compute the dtype when a bias actually exists.
            bias = self.bias.astype(mgb.dtype.qint32(inp_scale * w_scale))
        return F.linear(inp, self.weight, bias).astype(self.output_dtype)