
feat(mge/quant): add TQT quant method

GitOrigin-RevId: 00b1616e73
tags/v0.5.0
Author: Megvii Engine Team, Xu Xinran
Commit: fda9599a84
8 changed files with 203 additions and 30 deletions
  1. python_module/megengine/core/function.py  (+1, -0)
  2. python_module/megengine/module/qat/module.py  (+8, -2)
  3. python_module/megengine/quantization/__init__.py  (+1, -0)
  4. python_module/megengine/quantization/fake_quant.py  (+106, -25)
  5. python_module/megengine/quantization/observer.py  (+2, -0)
  6. python_module/megengine/quantization/qconfig.py  (+8, -2)
  7. python_module/test/unit/core/test_function.py  (+0, -1)
  8. python_module/test/unit/quantization/test_TQT.py  (+77, -0)

python_module/megengine/core/function.py  (+1, -0)

@@ -154,6 +154,7 @@ class Function(metaclass=ABCMeta):
         memo[id(self)] = result
         for k, v in self.__dict__.items():
             setattr(result, k, copy.deepcopy(v, memo))
+        setattr(result, "saved_tensors", tmp)
         self.saved_tensors = tmp
         return result
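For orientation, here is a minimal sketch of the `__deepcopy__` pattern that the added line completes; everything except the `setattr(result, "saved_tensors", tmp)` line is an assumption paraphrased from the context lines above, not the real MegEngine source. With the added line, `copy.deepcopy(fn).saved_tensors` matches `fn.saved_tensors` instead of falling back to the class default, which is why the assertion removed from `test_function.py` further down no longer applies.

    import copy

    class Function:
        # sketch only: the real MegEngine class has more machinery
        saved_tensors = None

        def save_for_backward(self, *tensors):
            self.saved_tensors = tensors

        def __deepcopy__(self, memo):
            cls = self.__class__
            result = cls.__new__(cls)
            # assumed setup: saved tensors are set aside so the generic
            # deep copy below does not duplicate them
            tmp = self.__dict__.pop("saved_tensors", None)
            memo[id(self)] = result
            for k, v in self.__dict__.items():
                setattr(result, k, copy.deepcopy(v, memo))
            setattr(result, "saved_tensors", tmp)  # the line this commit adds
            self.saved_tensors = tmp
            return result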




python_module/megengine/module/qat/module.py  (+8, -2)

@@ -77,13 +77,19 @@ class QATModule(Module):
r""" r"""
Get weight's quantization dtype as the method from ``qconfig``. Get weight's quantization dtype as the method from ``qconfig``.
""" """
return self.weight_observer.get_dtype()
if hasattr(self.act_fake_quant, "get_dtype"):
return self.weight_fake_quant.get_dtype()
else:
return self.weight_observer.get_dtype()


def get_activation_dtype(self): def get_activation_dtype(self):
r""" r"""
Get activation's quantization dtype as the method from ``qconfig``. Get activation's quantization dtype as the method from ``qconfig``.
""" """
return self.act_observer.get_dtype()
if hasattr(self.act_fake_quant, "get_dtype"):
return self.act_fake_quant.get_dtype()
else:
return self.act_observer.get_dtype()


@classmethod @classmethod
@abstractmethod @abstractmethod
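The `hasattr` dispatch lets a `QATModule` prefer a fake-quant class that owns its own dtype information (the `TQT` class added below defines `get_dtype()` from its learned scale) and fall back to the observer's statistics otherwise.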


python_module/megengine/quantization/__init__.py  (+1, -0)

@@ -12,4 +12,5 @@ from .qconfig import (
     calibration_qconfig,
     ema_fakequant_qconfig,
     min_max_fakequant_qconfig,
+    tqt_quant_qconfig,
 )

python_module/megengine/quantization/fake_quant.py  (+106, -25)

@@ -5,17 +5,20 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import copy
+import math
+
+import numpy as np
+
 from .. import functional as F
-from .._internal.dtype import _metadata_dict
+from .._internal.dtype import _metadata_dict, get_quantized_dtype
+from ..core import Buffer, Function, Parameter
+from ..jit import sideeffect
 from ..module import Module
 from .observer import ObserverMode, Round


-class FakeQuantize(Module):
-    r"""
-    A module to do quant and dequant according to observer's scale and zero_point.
-    """
-
+class _FakeQuantize(Module):
     def __init__(self, dtype: str, enable: bool = True):
         super().__init__()
         if not dtype in _metadata_dict.keys():
@@ -35,25 +38,103 @@ class FakeQuantize(Module):
     def disable(self):
         self.enabled = False

+    def fake_quant_forward(self, inp, q_dict):
+        return inp
+
+    def normal_foward(self, inp, q_dict):
+        return inp
+
     def forward(self, inp, q_dict):
         if self.enabled:
-            if q_dict["mode"] == ObserverMode.SYMMERTIC:
-                scale = q_dict["scale"]
-                # Quant
-                oup = Round()(inp / scale)
-                # clip
-                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-                # DeQuant
-                oup = (oup) * scale
-                return oup
-            else:
-                scale = q_dict["scale"]
-                zero_point = q_dict["zero_point"]
-                # Quant
-                oup = Round()(inp / scale) + zero_point
-                # clip
-                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-                # DeQuant
-                oup = (oup - zero_point) * scale
-                return oup
+            return self.fake_quant_forward(inp, q_dict)
+        else:
+            return self.normal_foward(inp, q_dict)
+
+
+class TQT_Function(Function):
+    def __init__(self, lowerbound, upperbound):
+        super().__init__()
+        self.lowerbound = lowerbound
+        self.upperbound = upperbound
+
+    def forward(self, inp, scale):
+        t = 2 ** scale
+        # t = F.maximum(t, 1e-4)
+        inp_scaled = inp / t
+        inp_clipped = F.maximum(F.minimum(inp_scaled, self.upperbound), self.lowerbound)
+        inp_rounded = F.round(inp_clipped)
+        inp_flq = inp_rounded * t
+        self.save_for_backward(inp_scaled, inp_rounded, t)
+        return inp_flq
+
+    def backward(self, grad_inp_flq):
+        (inp_scaled, inp_rounded, t) = self.saved_tensors
+        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
+            inp_scaled > self.upperbound + 0.5
+        )  # mask for accumulating the gradients of |data_scaled|>L
+        mask_quant = F.abs(
+            mask_clip - 1
+        )  # mask for accumulating the gradients with |data_scaled|<=L
+        grad_quant = (
+            grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
+        )  # gradient within |data_scaled|<=L
+        grad_clip = (
+            grad_inp_flq * mask_clip * inp_rounded
+        )  # gradient with |data_scaled|>L
+        grad_s = grad_clip.sum() + grad_quant.sum()
+        # dL/ds = dL/dt * t * ln(2)
+        grad_s = grad_s * t * math.log(2)
+        grad_inp = grad_inp_flq * mask_quant
+        return grad_inp, grad_s
+
+
+class TQT(_FakeQuantize):
+    """
+    TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
+    for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks
+    """
+
+    def __init__(self, dtype: str, enable: bool = True):
+        super().__init__(dtype, enable)
+        self.scale = Parameter(0.0, dtype=np.float32)
+
+    def fake_quant_forward(self, inp, q_dict):
+        # when enabled, TQT does the fake-quant forward and finetunes the scale
+        return TQT_Function(self.qmin, self.qmax)(inp, self.scale)
+
+    def normal_foward(self, inp, q_dict):
+        # when disabled, TQT does a normal forward and initializes the scale weight
+        tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
+        tmp_scale = F.log(tmp_scale / 127) / F.log(2)
+        F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
         return inp
+
+    def get_dtype(self):
+        return get_quantized_dtype(self.dtype, 2 ** self.scale.numpy()[0], None)
+
+
+class FakeQuantize(_FakeQuantize):
+    r"""
+    A module to do quant and dequant according to observer's scale and zero_point.
+    """
+
+    def fake_quant_forward(self, inp, q_dict):
+        if q_dict["mode"] == ObserverMode.SYMMERTIC:
+            scale = q_dict["scale"]
+            # Quant
+            oup = Round()(inp / scale)
+            # clip
+            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
+            # DeQuant
+            oup = (oup) * scale
+            return oup
+        else:
+            scale = q_dict["scale"]
+            zero_point = q_dict["zero_point"]
+            # Quant
+            oup = Round()(inp / scale) + zero_point
+            # clip
+            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
+            # DeQuant
+            oup = (oup - zero_point) * scale
+            return oup
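As a reading aid (not part of the diff), the maths implemented by `TQT_Function` can be restated as follows. With the learned parameter $s$ and threshold $t = 2^{s}$, the forward pass is

    \hat{x} = \operatorname{round}\!\big(\operatorname{clip}(x/t,\ q_{\min},\ q_{\max})\big)\cdot t, \qquad t = 2^{s}.

The backward pass uses the straight-through estimate through `round`; with the masks defined above (the clipped region is $x/t < q_{\min}-0.5$ or $x/t > q_{\max}+0.5$),

    \frac{\partial \hat{x}}{\partial x} =
    \begin{cases}
      1 & q_{\min}-0.5 \le x/t \le q_{\max}+0.5 \\
      0 & \text{otherwise,}
    \end{cases}
    \qquad
    \frac{\partial \hat{x}}{\partial t} =
    \begin{cases}
      \operatorname{round}(x/t) - x/t & \text{inside the clip range} \\
      \operatorname{round}\!\big(\operatorname{clip}(x/t)\big) & \text{outside (i.e. } q_{\min} \text{ or } q_{\max}\text{),}
    \end{cases}

and the gradient for the stored log2-scale is accumulated over all elements as $\partial L/\partial s = (\partial L/\partial t)\, t \ln 2$, which matches `grad_s = grad_s * t * math.log(2)` above.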

python_module/megengine/quantization/observer.py  (+2, -0)

@@ -107,6 +107,8 @@ class MinMaxObserver(Observer):
         min_val = F.minimum(0.0, inp_min_val)
         max_val = F.maximum(0.0, inp_max_val)
         q_dict = create_observer_dict(self.mode)
+        q_dict["min_val"] = inp_min_val
+        q_dict["max_val"] = inp_max_val
         if self.mode == ObserverMode.SYMMERTIC:
             symmetric_max_vals = F.maximum(-min_val, max_val)
             # use maximun to avoid scale too small at the begin
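The two raw statistics stored here are what `TQT.normal_foward` above reads while fake-quant is disabled: it initialises the learned log2-scale from the observer's range, roughly

    s_0 = \log_2\!\left(\frac{\max(|x_{\min}|,\ |x_{\max}|)}{127}\right),

where the constant 127 is hard-coded in `normal_foward` in this commit.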


python_module/megengine/quantization/qconfig.py  (+8, -2)

@@ -1,12 +1,12 @@
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
+#'
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ..module import Module
-from .fake_quant import FakeQuantize
+from .fake_quant import TQT, FakeQuantize
 from .observer import (
     ExponentialMovingAverageObserver,
     HistogramObserver,
@@ -52,6 +52,12 @@ class QConfig:
         self.fake_quant = fake_quant


+tqt_quant_qconfig = QConfig(
+    weight_observer=ExponentialMovingAverageObserver,
+    act_observer=ExponentialMovingAverageObserver,
+    fake_quant=TQT,
+)
+
 # Default QAT QConfigs
 min_max_fakequant_qconfig = QConfig(
     weight_observer=MinMaxObserver,
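A hedged usage sketch (not part of the diff): the new preset can be imported as `tqt_quant_qconfig`, or a custom `QConfig` can be assembled from the same pieces added in this commit; the import paths below follow the files touched above.

    # illustrative only: pair the new TQT fake-quant with MinMax observers
    from megengine.quantization.fake_quant import TQT
    from megengine.quantization.observer import MinMaxObserver
    from megengine.quantization.qconfig import QConfig

    my_tqt_qconfig = QConfig(
        weight_observer=MinMaxObserver,
        act_observer=MinMaxObserver,
        fake_quant=TQT,
    )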


python_module/test/unit/core/test_function.py  (+0, -1)

@@ -96,7 +96,6 @@ def test_deepcopy():
     origin = Sigmoid(0)
     new = copy.deepcopy(Sigmoid(0))
     assert new.param == origin.param
-    assert new.saved_tensors == None


 def test_save_context():


python_module/test/unit/quantization/test_TQT.py  (+77, -0)

@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine._internal as mgb
+from megengine.core import tensor
+from megengine.quantization.fake_quant import TQT_Function
+from megengine.test import assertTensorClose
+
+
+class numpy_TQT_Function:
+    def __init__(self, lowerbound, upperbound):
+        super().__init__()
+        self.lowerbound = lowerbound
+        self.upperbound = upperbound
+
+    def forward(self, inp, scale):
+        t = 2 ** scale
+        # t = F.maximum(t, 1e-4)
+        inp_scaled = inp / t
+        inp_clipped = np.maximum(
+            np.minimum(inp_scaled, self.upperbound), self.lowerbound
+        )
+        inp_rounded = np.round(inp_clipped)
+        inp_flq = inp_rounded * t
+        self.saved_tensors = (inp_scaled, inp_rounded, t)
+        return inp_flq
+
+    def backward(self, grad_inp_flq):
+        (inp_scaled, inp_rounded, t) = self.saved_tensors
+        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
+            inp_scaled > self.upperbound + 0.5
+        )  # mask for accumulating the gradients of |data_scaled|>L
+        mask_quant = np.abs(
+            mask_clip - 1
+        )  # mask for accumulating the gradients with |data_scaled|<=L
+        grad_quant = (
+            grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
+        )  # gradient within |data_scaled|<=L
+        grad_clip = (
+            grad_inp_flq * mask_clip * inp_rounded
+        )  # gradient with |data_scaled|>L
+        grad_s = grad_clip.sum() + grad_quant.sum()
+        # dL/ds = dL/dt * t * ln(2)
+        grad_s = grad_s * t * np.log(2)
+        grad_inp = grad_inp_flq * mask_quant
+        return grad_inp, grad_s
+
+
+def test_TQT():
+    f = TQT_Function(-127, 127)
+    nf = numpy_TQT_Function(-127, 127)
+
+    def check_inp(a, b, c, a_np, b_np, c_np):
+        assertTensorClose(
+            f.forward(a, b).numpy(), nf.forward(a_np, b_np).astype("float32")
+        )
+        c1, c2 = f.backward(c)
+        c1_np, c2_np = nf.backward(c_np)
+        assertTensorClose(c1.numpy(), c1_np.astype("float32"))
+        assertTensorClose(c2.numpy(), c2_np.astype("float32"))
+
+    a = tensor()
+    b = tensor()
+    a_np = np.random.random((4, 3)).astype("float32")
+    b_np = np.random.random((1)).astype("float32")
+    a.set_value(a_np)
+    b.set_value(b_np)
+    check_inp(a, b, b, a_np, b_np, b_np)
