@@ -49,6 +49,8 @@ class QATModule(Module): | |||||
def _apply_fakequant_with_observer( | def _apply_fakequant_with_observer( | ||||
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | ||||
): | ): | ||||
if observer is None: | |||||
return target | |||||
oup = observer(target) | oup = observer(target) | ||||
if fake_quant is None: | if fake_quant is None: | ||||
return oup | return oup | ||||
@@ -76,7 +78,7 @@ class QATModule(Module): | |||||
r""" | r""" | ||||
Get weight's quantization dtype as the method from ``qconfig``. | Get weight's quantization dtype as the method from ``qconfig``. | ||||
""" | """ | ||||
if hasattr(self.act_fake_quant, "get_dtype"): | |||||
if hasattr(self.weight_fake_quant, "get_dtype"): | |||||
return self.weight_fake_quant.get_dtype() | return self.weight_fake_quant.get_dtype() | ||||
else: | else: | ||||
return self.weight_observer.get_dtype() | return self.weight_observer.get_dtype() | ||||
@@ -5,7 +5,9 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from .fake_quant import FakeQuantize | from .fake_quant import FakeQuantize | ||||
from .internal_fake_quant import * | |||||
from .observer import HistogramObserver, Observer, ObserverMode | from .observer import HistogramObserver, Observer, ObserverMode | ||||
from .qconfig import ( | from .qconfig import ( | ||||
QConfig, | QConfig, | ||||
@@ -19,6 +19,15 @@ from .observer import ObserverMode, Round | |||||
class _FakeQuantize(Module): | class _FakeQuantize(Module): | ||||
r""" | |||||
A Basic Fake Quant module. | |||||
:param dtype: A string indicating the target quantization type of input. | |||||
:param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, | |||||
instead of 1 greater. Usually True for weight and False for activation. | |||||
:param enable: Whether to do ``normal_forward`` or ``fake_quant_forward``. | |||||
""" | |||||
def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | ||||
super().__init__() | super().__init__() | ||||
if not dtype in _metadata_dict.keys(): | if not dtype in _metadata_dict.keys(): | ||||
@@ -92,9 +101,9 @@ class TQT_Function(Function): | |||||
class TQT(_FakeQuantize): | class TQT(_FakeQuantize): | ||||
""" | |||||
r""" | |||||
TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds | TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds | ||||
for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks | |||||
for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks. | |||||
""" | """ | ||||
def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | ||||
@@ -119,11 +128,6 @@ class TQT(_FakeQuantize): | |||||
class FakeQuantize(_FakeQuantize): | class FakeQuantize(_FakeQuantize): | ||||
r""" | r""" | ||||
A module to do quant and dequant according to observer's scale and zero_point. | A module to do quant and dequant according to observer's scale and zero_point. | ||||
:param dtype: A string indicating the target quantization type of input. | |||||
:param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, | |||||
instead of 1 greater. Usually True for weight and False for activation. | |||||
:param enable: Whether do ``normal_forward`` or ``fake_quant_forward``. | |||||
""" | """ | ||||
def fake_quant_forward(self, inp, q_dict): | def fake_quant_forward(self, inp, q_dict): | ||||
@@ -0,0 +1,19 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import copy | |||||
import math | |||||
from functools import partial | |||||
import numpy as np | |||||
from .. import functional as F | |||||
from ..core import Function | |||||
from .fake_quant import _FakeQuantize | |||||
from .observer import MinMaxObserver | |||||
from .qconfig import QConfig | |||||
@@ -13,6 +13,7 @@ import megengine as mge | |||||
import megengine._internal as mgb | import megengine._internal as mgb | ||||
from megengine.core import tensor | from megengine.core import tensor | ||||
from megengine.quantization.fake_quant import TQT_Function | from megengine.quantization.fake_quant import TQT_Function | ||||
from megengine.quantization.internal_fake_quant import * | |||||
from megengine.test import assertTensorClose | from megengine.test import assertTensorClose | ||||
@@ -75,3 +76,5 @@ def test_TQT(): | |||||
a.set_value(a_np) | a.set_value(a_np) | ||||
b.set_value(b_np) | b.set_value(b_np) | ||||
check_inp(a, b, b, a_np, b_np, b_np) | check_inp(a, b, b, a_np, b_np, b_np) | ||||