Browse Source

feat(mge/quantization): use extra act_fakequant to decide whether to do bias fakequant

GitOrigin-RevId: bf54012155
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
ce88e6c4f7
6 changed files with 32 additions and 21 deletions
  1. +1
    -3
      imperative/python/megengine/module/qat/batch_matmul_activation.py
  2. +1
    -5
      imperative/python/megengine/module/qat/conv.py
  3. +1
    -5
      imperative/python/megengine/module/qat/conv_bn.py
  4. +3
    -7
      imperative/python/megengine/module/qat/linear.py
  5. +19
    -0
      imperative/python/megengine/module/qat/module.py
  6. +7
    -1
      imperative/python/megengine/quantization/__init__.py

+ 1
- 3
imperative/python/megengine/module/qat/batch_matmul_activation.py View File

@@ -5,8 +5,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from ...quantization.utils import fake_quant_bias
from .. import batch_matmul_activation as Float
from .module import QATModule

@@ -18,7 +16,7 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule):

def forward(self, inp):
w_qat = self.apply_quant_weight(self.weight)
b_qat = fake_quant_bias(self.bias, inp, w_qat)
b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))

@classmethod


+ 1
- 5
imperative/python/megengine/module/qat/conv.py View File

@@ -6,7 +6,6 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import functional as F
from ...quantization.utils import fake_quant_bias
from .. import conv as Float
from .module import QATModule

@@ -19,10 +18,7 @@ class Conv2d(Float.Conv2d, QATModule):

def calc_conv_qat(self, inp):
w_qat = self.apply_quant_weight(self.weight)
if self.weight_fake_quant and self.weight_fake_quant.enabled:
b_qat = fake_quant_bias(self.bias, inp, w_qat)
else:
b_qat = self.bias
b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
conv = self.calc_conv(inp, w_qat, b_qat)
return conv



+ 1
- 5
imperative/python/megengine/module/qat/conv_bn.py View File

@@ -6,7 +6,6 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...functional import ones, relu, sqrt, sum, zeros
from ...quantization.utils import fake_quant_bias
from .. import conv_bn as Float
from .module import QATModule

@@ -122,10 +121,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

w_qat = self.apply_quant_weight(w_fold)
if self.weight_fake_quant and self.weight_fake_quant.enabled:
b_qat = fake_quant_bias(b_fold, inp, w_qat)
else:
b_qat = b_fold
b_qat = self.apply_quant_bias(b_fold, inp, w_qat)
conv = self.conv.calc_conv(inp, w_qat, b_qat)
if not (self.training and approx):
return conv


+ 3
- 7
imperative/python/megengine/module/qat/linear.py View File

@@ -5,7 +5,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...quantization.utils import fake_quant_bias
from .. import linear as Float
from .module import QATModule

@@ -22,13 +21,10 @@ class Linear(Float.Linear, QATModule):

"""

def forward(self, x):
def forward(self, inp):
w_qat = self.apply_quant_weight(self.weight)
if self.weight_fake_quant and self.weight_fake_quant.enabled:
b_qat = fake_quant_bias(self.bias, x, w_qat)
else:
b_qat = self.bias
return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))
b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))

@classmethod
def from_float_module(cls, float_module: Float.Linear):


+ 19
- 0
imperative/python/megengine/module/qat/module.py View File

@@ -11,6 +11,7 @@ from abc import abstractmethod
from ...quantization.fake_quant import FakeQuantize
from ...quantization.observer import Observer
from ...quantization.qconfig import QConfig
from ...quantization.utils import fake_quant_bias
from ...tensor import Tensor
from ..module import Module

@@ -107,6 +108,24 @@ class QATModule(Module):
target, self.act_fake_quant, self.act_observer
)

def apply_quant_bias(self, target: Tensor, inp: Tensor, w_qat: Tensor):
    r"""
    Use :func:`~.fake_quant_bias` to process ``target``. Only valid when
    ``act_fake_quant`` and ``weight_fake_quant`` are both enabled.
    """
    # The bias shares the activation's dtype, so the activation fake-quant
    # switch participates in the decision alongside the weight one.
    act_fq = self.act_fake_quant
    wt_fq = self.weight_fake_quant
    if act_fq and act_fq.enabled and wt_fq and wt_fq.enabled:
        return fake_quant_bias(target, inp, w_qat)
    # Either fake-quant path is disabled: pass the bias through untouched.
    return target

def _get_method_result(
self, method: str, fake_quant: FakeQuantize, observer: Observer
):


+ 7
- 1
imperative/python/megengine/quantization/__init__.py View File

@@ -30,4 +30,10 @@ from .quantize import (
quantize_qat,
reset_qconfig,
)
from .utils import QParams, QuantMode, create_qparams
from .utils import (
QParams,
QuantMode,
create_qparams,
fake_quant_bias,
fake_quant_tensor,
)

Loading…
Cancel
Save