- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- from abc import abstractmethod
-
- import numpy as np
-
- from .. import functional as F
- from .._internal.dtype import _metadata_dict, get_quantized_dtype
- from ..core import Buffer, Function, ones, tensor, zeros
- from ..module import Module
-
-
- class Round(Function):
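-     r"""
-     Rounds to the nearest integer in the forward pass while passing gradients
-     through unchanged in the backward pass (a straight-through estimator), so
-     the rounding step stays trainable during quantization.
-     """
-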
-     def forward(self, x):
-         return x.round()
-
-     def backward(self, output_grads):
-         return output_grads
-
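- # A minimal sketch of Round's straight-through behavior (assuming an
- # autograd-tracked tensor `x`):
- #
- #     y = Round()(x)   # forward: y = round(x)
- #     # backward: dy/dx is treated as 1, so gradients reach x untouched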
-
- class Observer(Module):
- r"""
- A base class for Observer Module.
-
- :param dtype: a string indicating to collect scale and zero_point of which dtype
- """
-
-     def __init__(self, dtype="qint8"):
-         super().__init__()
-         if dtype not in _metadata_dict.keys():
-             raise ValueError(
-                 "unknown dtype: {}, only support {}".format(
-                     dtype, _metadata_dict.keys()
-                 )
-             )
-         self.dtype = dtype
-         self.qmin = _metadata_dict[dtype].qmin
-         self.qmax = _metadata_dict[dtype].qmax
-         self.zero_point, self.scale = None, None
-         self.enabled = True
-
-     def get_dtype(self):
-         scale, zero_point = self.get_qparams()
-         numpy_scale = None if scale is None else scale.numpy()[0]
-         numpy_zero_point = None if zero_point is None else zero_point.numpy()[0]
-         return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
-
-     def enable(self):
-         self.enabled = True
-
-     def disable(self):
-         self.enabled = False
-
-     @abstractmethod
-     def forward(self, x):
-         pass
-
-     @abstractmethod
-     def get_qparams(self, **kwargs):
-         pass
-
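- # Typical Observer life cycle (a sketch; MyObserver stands in for any concrete
- # subclass such as MinMaxObserver below):
- #
- #     obs = MyObserver(dtype="qint8")
- #     y = obs(x)                             # forward() records statistics of x
- #     scale, zero_point = obs.get_qparams()
- #     qdtype = obs.get_dtype()               # quantized dtype carrying the qparams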
-
- class IdentityObserver(Observer):
- r"""
- An test Observer that always return scale:1 and zero_point:0.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.zero_point = ones((1), dtype="float32")
- self.scale = zeros((1), dtype="float32")
-
-     def forward(self, x):
-         return x
-
-     def get_qparams(self):
-         return self.scale, self.zero_point
-
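- # A minimal usage sketch (with a hypothetical tensor `x`) showing the no-op behavior:
- #
- #     obs = IdentityObserver()
- #     y = obs(x)                             # returns x unchanged
- #     scale, zero_point = obs.get_qparams()  # always (1, 0)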
-
- class MinMaxObserver(Observer):
-     def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.symmetric = symmetric
-         if self.symmetric:
-             # assert qmin + qmax == -1, 'when reduce_range, qmin + qmax should equal -1'
-             self.zero_point = tensor((self.qmin + self.qmax + 1) // 2)
-
-         self.min_val = Buffer(0.0, dtype=np.float32)
-         self.max_val = Buffer(0.0, dtype=np.float32)
-         self.scale_limit = eps
-         # flags used by cond_take: the first forward pass sees first_flag, after
-         # which the buffer is overwritten with not_flag
-         self.first_flag = Buffer(np.array([1, 0], dtype=np.int32))
-         self.not_flag = Buffer(np.array([0, 1], dtype=np.int32))
-
-     def set_min_max(self, tmp_min, tmp_max):
-         # FIXME: cond_take will destroy shape, use reshape to reset shape
-         tmp_min = tmp_min.reshape(1)
-         tmp_max = tmp_max.reshape(1)
-         if self.training:
-             F.zero_grad(
-                 F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
-             )
-             F.zero_grad(
-                 F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
-             )
-             F.zero_grad(
-                 F.add_update(
-                     self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0
-                 )
-             )
-
-         # FIXME: add_update is applied after the whole trace procedure in `symbolic=True`
-         # mode, so use tmp_min/tmp_max to calculate and save scale/zero_point for
-         # further use in FakeQuant.
-         self.set_scale_zero_point(tmp_min, tmp_max)
-
-     def set_scale_zero_point(self, tmp_min, tmp_max):
-         if self.symmetric:
-             symmetric_max_vals = F.maximum(-tmp_min, tmp_max)
-             # use maximum to avoid the scale being too small at the beginning
-             self.scale = F.maximum(
-                 symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
-             )
-             # zero_point = self.zero_point
-         else:
-             # use maximum to avoid the scale being too small at the beginning
-             self.scale = F.maximum(
-                 (tmp_max - tmp_min) / (self.qmax - self.qmin), self.scale_limit
-             )
-             # calculate zero_point
-             self.zero_point = self.qmin - Round()(tmp_min / self.scale)
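-
-     # Worked example for set_scale_zero_point (a sketch, assuming asymmetric qint8
-     # with qmin=-128, qmax=127): tmp_min=-1.0 and tmp_max=3.0 give
-     # scale = 4.0 / 255 ≈ 0.0157 and zero_point = -128 - round(-1.0 / 0.0157)
-     # = -128 - (-64) = -64.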
-
-     def get_qparams(self):
-         # scale and zero_point are runtime tensors rather than Buffers,
-         # so they need to be re-calculated if min_val and max_val are loaded.
-         if self.scale is None:
-             self.set_scale_zero_point(self.min_val, self.max_val)
-
-         return self.scale, self.zero_point
-
-     def forward(self, x_orig):
-         if self.enabled:
-             # stop gradient
-             x = F.zero_grad(x_orig)
-             # find max and min
-             tmp_min, _ = F.cond_take(
-                 self.first_flag, F.concat([x.min(), F.minimum(self.min_val, x.min())])
-             )
-             tmp_max, _ = F.cond_take(
-                 self.first_flag, F.concat([x.max(), F.maximum(self.max_val, x.max())])
-             )
-             self.set_min_max(tmp_min, tmp_max)
-         return x_orig
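-
-     # Note on the cond_take trick above: first_flag is [1, 0] on the first call,
-     # so cond_take selects the raw batch min/max; set_min_max then overwrites the
-     # flag with not_flag ([0, 1]), so later calls select the running min/max branch.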
-
-
- class ExponentialMovingAverageObserver(MinMaxObserver):
-     def __init__(self, momentum=0.9, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.momentum = Buffer(momentum)
-
-     def set_momentum(self, momentum):
-         self.momentum.set_value(momentum)
-
-     def forward(self, x_orig):
-         if self.enabled:
-             # stop gradient
-             x = F.zero_grad(x_orig)
-             # Exponential Moving Average
-             tmp_min, _ = F.cond_take(
-                 self.first_flag,
-                 F.concat(
-                     [
-                         x.min(),
-                         self.momentum * self.min_val + (1 - self.momentum) * x.min(),
-                     ]
-                 ),
-             )
-             tmp_max, _ = F.cond_take(
-                 self.first_flag,
-                 F.concat(
-                     [
-                         x.max(),
-                         self.momentum * self.max_val + (1 - self.momentum) * x.max(),
-                     ]
-                 ),
-             )
-             self.set_min_max(tmp_min, tmp_max)
-         return x_orig
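-
-     # EMA smoothing sketch: with momentum=0.9, a running max_val of 2.0, and a new
-     # batch max of 4.0, the update is 0.9 * 2.0 + 0.1 * 4.0 = 2.2, damping
-     # batch-to-batch noise in the observed range.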