Browse Source

feat(mge/module): add linear quantization module

GitOrigin-RevId: d0c96a9411
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
855c49ca7d
3 changed files with 78 additions and 4 deletions
  1. +17
    -4
      python_module/megengine/module/linear.py
  2. +1
    -0
      python_module/megengine/module/quantized/__init__.py
  3. +60
    -0
      python_module/megengine/module/quantized/linear.py

+ 17
- 4
python_module/megengine/module/linear.py View File

@@ -8,13 +8,13 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from .. import functional as F
from ..core import Parameter
from ..functional import linear
from . import init
from .module import Module
from .module import QATModule


class Linear(Module):
class Linear(QATModule):
r"""Applies a linear transformation to the input. For instance, if input
is x, then output y is:

@@ -55,5 +55,18 @@ class Linear(Module):
if self.bias is not None:
init.zeros_(self.bias)

def _calc_linear(self, x, weight, bias):
    """Run the underlying linear op.

    Shared hook so that ``forward`` (float weights) and ``forward_qat``
    (fake-quantized weights) execute the exact same computation.
    """
    out = F.linear(x, weight, bias)
    return out

def forward(self, x):
    """Apply the linear transformation ``x @ weight.T + bias``.

    Routes through ``_calc_linear`` so the QAT path (``forward_qat``)
    can reuse the same op with fake-quantized weights.
    """
    # Fix: the diff residue left two consecutive return statements
    # (the old ``return linear(...)`` before the new helper call), so the
    # intended helper-based return was unreachable. Keep only the
    # commit's intended body.
    return self._calc_linear(x, self.weight, self.bias)

def forward_qat(self, x):
    """QAT forward pass.

    Fake-quantizes the weight, runs the shared linear op, then
    fake-quantizes the resulting activation.
    """
    quantized_weight = self.apply_fakequant_with_observer(
        self.weight, self.weight_fake_quant, self.weight_observer
    )
    activation = self._calc_linear(x, quantized_weight, self.bias)
    return self.apply_fakequant_with_observer(
        activation, self.act_fake_quant, self.act_observer
    )

+ 1
- 0
python_module/megengine/module/quantized/__init__.py View File

@@ -8,4 +8,5 @@
from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
from .quant_dequant import DequantStub, QuantStub

+ 60
- 0
python_module/megengine/module/quantized/linear.py View File

@@ -0,0 +1,60 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

import megengine._internal as mgb

from ... import functional as F
from ... import module as Float
from ...core import Parameter
from ...quantization.utils import register_method_to_class
from ..module import Module


class Linear(Module):
    r"""Inference-only linear layer operating on quantized tensors.

    Instances are normally produced by converting a QAT module via the
    ``to_quantized`` method rather than being constructed directly;
    ``weight`` and ``bias`` are filled in by that conversion.

    :param dtype: data type of the output tensor.
    """

    def __init__(
        self, dtype: np.dtype = None,
    ):
        super().__init__()
        # Populated by ``to_quantized`` with already-quantized parameters.
        self.weight = None
        self.bias = None
        self.output_dtype = dtype

    def forward(self, inp):
        # Quantized modules have no training-mode semantics.
        if self.training:
            raise ValueError("quantized module only support inference.")
        # The bias is accumulated in qint32 with scale = inp_scale * w_scale,
        # matching the scale of the int32 matmul accumulator.
        scale_in = mgb.dtype.get_scale(inp.dtype)
        scale_w = mgb.dtype.get_scale(self.weight.dtype)
        dtype_bias = mgb.dtype.qint32(scale_in * scale_w)
        bias = self.bias.astype(dtype_bias) if self.bias is not None else None
        return F.linear(inp, self.weight, bias).astype(self.output_dtype)


@register_method_to_class(Float.Linear)
def to_quantized(float_module):
    r"""
    Convert a float/QAT ``Linear`` into its quantized counterpart.

    Registered as :class:`~.module.QATModule`'s ``to_quantized`` method;
    it lives in this module to avoid a circular import.
    """
    qmod = Linear(dtype=float_module.act_observer.get_dtype())
    quantized_weight = float_module.weight.astype(
        float_module.weight_observer.get_dtype()
    )
    qmod.weight = Parameter(quantized_weight.numpy())
    # Bias stays float here; ``forward`` converts it to qint32 at run time.
    if float_module.bias is not None:
        qmod.bias = Parameter(float_module.bias.numpy())
    return qmod

Loading…
Cancel
Save