You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fake_quant.py 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from copy import deepcopy
  9. from typing import Union
  10. from ..core.tensor.dtype import QuantDtypeMeta
  11. from ..quantization.fake_quant import QParamsModuleMixin, _FakeQuantize
  12. from ..quantization.utils import QParams, QuantMode, fake_quant_tensor
  class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
      r"""A module to do quant and dequant according to :attr:`~.FakeQuantize.qparams`.

      Quantization parameters are stored on the module by :meth:`set_qparams`
      and are used by :meth:`fake_quant_forward` whenever no explicit
      ``qparams`` argument is given.

      Args:
          dtype: quantization dtype, given either by name or as a
              :class:`QuantDtypeMeta`; forwarded unchanged to the base
              ``_FakeQuantize`` constructor (assumed to resolve it into
              ``self.dtype`` — confirm in the base class).
          enable: forwarded unchanged to the base ``_FakeQuantize`` constructor.
          kwargs: extra keyword arguments forwarded to the base constructor.
      """

      def __init__(
          self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
      ):
          super().__init__(dtype, enable, **kwargs)
          # No quantization parameters until ``set_qparams`` is called.
          self.qparams = None

      def _check_dtype_meta(self, qparams: QParams) -> None:
          """Raise ``AssertionError`` unless ``qparams.dtype_meta`` is ``self.dtype``.

          An explicit ``raise`` is used instead of ``assert`` so the check is
          not stripped when Python runs with ``-O``. The exception type stays
          ``AssertionError``, so callers observe the same error as before.
          """
          # Identity comparison, as in the original checks.
          if qparams.dtype_meta is not self.dtype:
              raise AssertionError(
                  "input qparams' dtype is not equal to self.dtype.\n"
                  "qparams.dtype_meta={}\nself.dtype={}".format(
                      qparams.dtype_meta, self.dtype
                  )
              )

      def fake_quant_forward(self, inp, qparams: Union[QParams, None] = None):
          """Quantize and then dequantize ``inp`` with ``fake_quant_tensor``.

          Args:
              inp: input forwarded to ``fake_quant_tensor``.
              qparams: quantization parameters to use. When ``None``, the
                  parameters returned by :meth:`get_qparams` are used.

          Returns:
              The result of ``fake_quant_tensor(inp, qparams)``.

          Raises:
              AssertionError: if no ``qparams`` is given and none has been
                  stored yet, or if ``qparams.dtype_meta`` is not ``self.dtype``.
          """
          if qparams is None:
              qparams = self.get_qparams()
              if qparams is None:
                  # Fail with a clear message instead of an AttributeError on None.
                  raise AssertionError(
                      "qparams is not initialized; call set_qparams() first"
                  )
          self._check_dtype_meta(qparams)
          return fake_quant_tensor(inp, qparams)

      def get_qparams(self) -> Union[QParams, None]:
          """Return the stored :class:`QParams`, or ``None`` if none was set."""
          return self.qparams

      def set_qparams(self, qparams: QParams) -> None:
          r"""Initialize :attr:`~.FakeQuantize.qparams`.

          A new :class:`QParams` built from ``qparams`` is stored. The argument
          object itself is left unmodified.

          Args:
              qparams: used to set initial ``scale`` and ``zero_point``. Its
                  ``dtype_meta`` may be ``None``, in which case ``self.dtype``
                  is used; otherwise it must be ``self.dtype``.

          Raises:
              AssertionError: if ``qparams.scale`` is ``None`` or
                  ``qparams.dtype_meta`` is set but is not ``self.dtype``.
          """
          if qparams.scale is None:
              raise AssertionError("Can not get an initialized scale")
          if qparams.dtype_meta is None:
              # Use the module's dtype for the stored copy only; do not write
              # it back into the caller's ``qparams`` object.
              dtype_meta = self.dtype
          else:
              self._check_dtype_meta(qparams)
              dtype_meta = qparams.dtype_meta
          self.qparams = QParams(
              qparams.mode, dtype_meta, qparams.scale, qparams.zero_point
          )