Browse Source

feat(mge/quantization): use extra act_fakequant to decide whether to do bias fakequant

GitOrigin-RevId: bf54012155
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
ce88e6c4f7
6 changed files with 32 additions and 21 deletions
  1. +1
    -3
      imperative/python/megengine/module/qat/batch_matmul_activation.py
  2. +1
    -5
      imperative/python/megengine/module/qat/conv.py
  3. +1
    -5
      imperative/python/megengine/module/qat/conv_bn.py
  4. +3
    -7
      imperative/python/megengine/module/qat/linear.py
  5. +19
    -0
      imperative/python/megengine/module/qat/module.py
  6. +7
    -1
      imperative/python/megengine/quantization/__init__.py

+ 1
- 3
imperative/python/megengine/module/qat/batch_matmul_activation.py View File

@@ -5,8 +5,6 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from ...quantization.utils import fake_quant_bias
from .. import batch_matmul_activation as Float from .. import batch_matmul_activation as Float
from .module import QATModule from .module import QATModule


@@ -18,7 +16,7 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule):


def forward(self, inp): def forward(self, inp):
w_qat = self.apply_quant_weight(self.weight) w_qat = self.apply_quant_weight(self.weight)
b_qat = fake_quant_bias(self.bias, inp, w_qat)
b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat)) return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))


@classmethod @classmethod


+ 1
- 5
imperative/python/megengine/module/qat/conv.py View File

@@ -6,7 +6,6 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import functional as F from ... import functional as F
from ...quantization.utils import fake_quant_bias
from .. import conv as Float from .. import conv as Float
from .module import QATModule from .module import QATModule


@@ -19,10 +18,7 @@ class Conv2d(Float.Conv2d, QATModule):


def calc_conv_qat(self, inp): def calc_conv_qat(self, inp):
w_qat = self.apply_quant_weight(self.weight) w_qat = self.apply_quant_weight(self.weight)
if self.weight_fake_quant and self.weight_fake_quant.enabled:
b_qat = fake_quant_bias(self.bias, inp, w_qat)
else:
b_qat = self.bias
b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
conv = self.calc_conv(inp, w_qat, b_qat) conv = self.calc_conv(inp, w_qat, b_qat)
return conv return conv




+ 1
- 5
imperative/python/megengine/module/qat/conv_bn.py View File

@@ -6,7 +6,6 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...functional import ones, relu, sqrt, sum, zeros from ...functional import ones, relu, sqrt, sum, zeros
from ...quantization.utils import fake_quant_bias
from .. import conv_bn as Float from .. import conv_bn as Float
from .module import QATModule from .module import QATModule


@@ -122,10 +121,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd


w_qat = self.apply_quant_weight(w_fold) w_qat = self.apply_quant_weight(w_fold)
if self.weight_fake_quant and self.weight_fake_quant.enabled:
b_qat = fake_quant_bias(b_fold, inp, w_qat)
else:
b_qat = b_fold
b_qat = self.apply_quant_bias(b_fold, inp, w_qat)
conv = self.conv.calc_conv(inp, w_qat, b_qat) conv = self.conv.calc_conv(inp, w_qat, b_qat)
if not (self.training and approx): if not (self.training and approx):
return conv return conv


+ 3
- 7
imperative/python/megengine/module/qat/linear.py View File

@@ -5,7 +5,6 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...quantization.utils import fake_quant_bias
from .. import linear as Float from .. import linear as Float
from .module import QATModule from .module import QATModule


@@ -22,13 +21,10 @@ class Linear(Float.Linear, QATModule):


""" """


def forward(self, x):
def forward(self, inp):
w_qat = self.apply_quant_weight(self.weight) w_qat = self.apply_quant_weight(self.weight)
if self.weight_fake_quant and self.weight_fake_quant.enabled:
b_qat = fake_quant_bias(self.bias, x, w_qat)
else:
b_qat = self.bias
return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))
b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))


@classmethod @classmethod
def from_float_module(cls, float_module: Float.Linear): def from_float_module(cls, float_module: Float.Linear):


+ 19
- 0
imperative/python/megengine/module/qat/module.py View File

@@ -11,6 +11,7 @@ from abc import abstractmethod
from ...quantization.fake_quant import FakeQuantize from ...quantization.fake_quant import FakeQuantize
from ...quantization.observer import Observer from ...quantization.observer import Observer
from ...quantization.qconfig import QConfig from ...quantization.qconfig import QConfig
from ...quantization.utils import fake_quant_bias
from ...tensor import Tensor from ...tensor import Tensor
from ..module import Module from ..module import Module


@@ -107,6 +108,24 @@ class QATModule(Module):
target, self.act_fake_quant, self.act_observer target, self.act_fake_quant, self.act_observer
) )


def apply_quant_bias(self, target: Tensor, inp: Tensor, w_qat: Tensor):
r"""
Use :func:`~.fake_quant_bias` to process ``target``. Only valid when
``act_fake_quant`` and ``weight_fake_quant`` are both enabled.
"""
# bias should have the same dtype as activation, so act_fake_quant can also
# decide whether to do bias fakequant
if (
self.act_fake_quant
and self.act_fake_quant.enabled
and self.weight_fake_quant
and self.weight_fake_quant.enabled
):
b_qat = fake_quant_bias(target, inp, w_qat)
else:
b_qat = target
return b_qat

def _get_method_result( def _get_method_result(
self, method: str, fake_quant: FakeQuantize, observer: Observer self, method: str, fake_quant: FakeQuantize, observer: Observer
): ):


+ 7
- 1
imperative/python/megengine/quantization/__init__.py View File

@@ -30,4 +30,10 @@ from .quantize import (
quantize_qat, quantize_qat,
reset_qconfig, reset_qconfig,
) )
from .utils import QParams, QuantMode, create_qparams
from .utils import (
QParams,
QuantMode,
create_qparams,
fake_quant_bias,
fake_quant_tensor,
)

Loading…
Cancel
Save