Browse Source

feat(quant): support nnie quant

GitOrigin-RevId: 8ca3f828bd
tags/v1.0.0-rc1
Megvii Engine Team 5 years ago
parent
commit
7a8a2830a8
5 changed files with 38 additions and 8 deletions
  1. +3
    -1
      python_module/megengine/module/qat/module.py
  2. +2
    -0
      python_module/megengine/quantization/__init__.py
  3. +11
    -7
      python_module/megengine/quantization/fake_quant.py
  4. +19
    -0
      python_module/megengine/quantization/internal_fake_quant.py
  5. +3
    -0
      python_module/test/unit/quantization/test_fake_quant.py

+ 3
- 1
python_module/megengine/module/qat/module.py View File

@@ -49,6 +49,8 @@ class QATModule(Module):
 def _apply_fakequant_with_observer(
     self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
 ):
+    if observer is None:
+        return target
     oup = observer(target)
     if fake_quant is None:
         return oup
@@ -76,7 +78,7 @@ class QATModule(Module):
     r"""
     Get weight's quantization dtype as the method from ``qconfig``.
     """
-    if hasattr(self.act_fake_quant, "get_dtype"):
+    if hasattr(self.weight_fake_quant, "get_dtype"):
         return self.weight_fake_quant.get_dtype()
     else:
         return self.weight_observer.get_dtype()


+ 2
- 0
python_module/megengine/quantization/__init__.py View File

@@ -5,7 +5,9 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 from .fake_quant import FakeQuantize
+from .internal_fake_quant import *
 from .observer import HistogramObserver, Observer, ObserverMode
 from .qconfig import (
     QConfig,


+ 11
- 7
python_module/megengine/quantization/fake_quant.py View File

@@ -19,6 +19,15 @@ from .observer import ObserverMode, Round




 class _FakeQuantize(Module):
+    r"""
+    A Basic Fake Quant module.
+
+    :param dtype: A string indicating the target quantization type of input.
+    :param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``,
+        instead of 1 greater. Usually True for weight and False for activation.
+    :param enable: Whether do ``normal_forward`` or ``fake_quant_forward``.
+    """
+
     def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
         super().__init__()
         if not dtype in _metadata_dict.keys():
@@ -92,9 +101,9 @@ class TQT_Function(Function):




 class TQT(_FakeQuantize):
-    """
+    r"""
     TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
-    for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks
+    for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.
     """

     def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
@@ -119,11 +128,6 @@ class TQT(_FakeQuantize):
 class FakeQuantize(_FakeQuantize):
     r"""
     A module to do quant and dequant according to observer's scale and zero_point.
-
-    :param dtype: A string indicating the target quantization type of input.
-    :param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``,
-        instead of 1 greater. Usually True for weight and False for activation.
-    :param enable: Whether do ``normal_forward`` or ``fake_quant_forward``.
     """

     def fake_quant_forward(self, inp, q_dict):


+ 19
- 0
python_module/megengine/quantization/internal_fake_quant.py View File

@@ -0,0 +1,19 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
import math
from functools import partial

import numpy as np

from .. import functional as F
from ..core import Function
from .fake_quant import _FakeQuantize
from .observer import MinMaxObserver
from .qconfig import QConfig


python_module/test/unit/quantization/test_TQT.py → python_module/test/unit/quantization/test_fake_quant.py View File

@@ -13,6 +13,7 @@ import megengine as mge
 import megengine._internal as mgb
 from megengine.core import tensor
 from megengine.quantization.fake_quant import TQT_Function
+from megengine.quantization.internal_fake_quant import *
 from megengine.test import assertTensorClose




@@ -75,3 +76,5 @@ def test_TQT():
     a.set_value(a_np)
     b.set_value(b_np)
     check_inp(a, b, b, a_np, b_np, b_np)



Loading…
Cancel
Save