
observer.py 5.5 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod

import numpy as np

from .. import functional as F
from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, ones, tensor, zeros
from ..module import Module


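# Round implements a straight-through estimator: forward rounds to the nearest
# integer, while backward passes gradients through unchanged, so rounding does
# not block gradient flow during quantization-aware training.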
class Round(Function):
    def forward(self, x):
        return x.round()

    def backward(self, output_grads):
        return output_grads

class Observer(Module):
    r"""
    A base class for Observer Module.

    :param dtype: a string indicating which dtype to collect scale and zero_point for
    """

    def __init__(self, dtype="qint8"):
        super().__init__()
        if dtype not in _metadata_dict.keys():
            raise ValueError(
                "unknown dtype: {}, only support {}".format(
                    dtype, _metadata_dict.keys()
                )
            )
        self.dtype = dtype
        self.qmin = _metadata_dict[dtype].qmin
        self.qmax = _metadata_dict[dtype].qmax
        self.enabled = True

    def get_dtype(self):
        scale, zero_point = self.get_qparams()
        numpy_scale = None if scale is None else scale.numpy()[0]
        numpy_zero_point = None if zero_point is None else zero_point.numpy()[0]
        return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def train(self, mode: bool = True) -> None:
        super().train(mode)
        if mode:
            self.enable()
        else:
            self.disable()

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def get_qparams(self, **kwargs):
        pass


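# MinMaxObserver tracks the running minimum and maximum of every tensor passed
# through it and derives scale/zero_point from that range.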
class MinMaxObserver(Observer):
    def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.symmetric = symmetric
        if self.symmetric:
            # assert qmin + qmax == -1, 'when reduce_range, qmin + qmax should equal -1'
            self.zero_point = tensor((self.qmin + self.qmax + 1) // 2)

        self.min_val = Buffer(0.0, dtype=np.float32)
        self.max_val = Buffer(0.0, dtype=np.float32)
        self.scale_limit = eps
        # flags used by cond_take: the first call selects with first_flag, and
        # set_min_max then switches the selection to not_flag
        self.first_flag = Buffer(np.array([1, 0], dtype=np.int32))
        self.not_flag = Buffer(np.array([0, 1], dtype=np.int32))

    def set_min_max(self, tmp_min, tmp_max):
        # FIXME: cond_take will destroy shape, use reshape to reset shape
        tmp_min = tmp_min.reshape(1)
        tmp_max = tmp_max.reshape(1)
        F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
        F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
        F.add_update(self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0)

    def get_qparams(self):
        if self.symmetric:
            symmetric_max_vals = F.maximum(-self.min_val, self.max_val)
            # use maximum to avoid the scale being too small at the beginning
            scale = F.maximum(
                symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
            )
            zero_point = self.zero_point
        else:
            # use maximum to avoid the scale being too small at the beginning
            scale = F.maximum(
                (self.max_val - self.min_val) / (self.qmax - self.qmin),
                self.scale_limit,
            )
            # calculate zero_point
            zero_point = self.qmin - Round()(self.min_val / scale)

        return scale, zero_point
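
    # Worked example (assuming qint8 maps to qmin=-128, qmax=127): for an
    # observed range [-0.5, 1.0], symmetric mode gives symmetric_max_vals = 1.0,
    # scale = max(1.0 / 127.5, eps) ≈ 0.00784 and zero_point = 0; asymmetric
    # mode gives scale = max(1.5 / 255, eps) ≈ 0.00588 and
    # zero_point = -128 - round(-0.5 / 0.00588) = -43.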

    def forward(self, x_orig):
        if self.enabled:
            # stop gradient
            x = F.zero_grad(x_orig)
            # find max and min
            tmp_min, _ = F.cond_take(
                self.first_flag, F.concat([x.min(), F.minimum(self.min_val, x.min())])
            )
            tmp_max, _ = F.cond_take(
                self.first_flag, F.concat([x.max(), F.maximum(self.max_val, x.max())])
            )
            self.set_min_max(tmp_min, tmp_max)
        return x_orig


class ExponentialMovingAverageObserver(MinMaxObserver):
    def __init__(self, momentum=0.9, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.momentum = Buffer(momentum)

    def set_momentum(self, momentum):
        self.momentum.set_value(momentum)
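
    # EMA update: new_min = momentum * min_val + (1 - momentum) * batch_min
    # (and likewise for max), so the tracked range follows recent batches with
    # smoothing controlled by momentum.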

    def forward(self, x_orig):
        if self.enabled:
            # stop gradient
            x = F.zero_grad(x_orig)
            # exponential moving average
            tmp_min, _ = F.cond_take(
                self.first_flag,
                F.concat(
                    [
                        x.min(),
                        self.momentum * self.min_val + (1 - self.momentum) * x.min(),
                    ]
                ),
            )
            tmp_max, _ = F.cond_take(
                self.first_flag,
                F.concat(
                    [
                        x.max(),
                        self.momentum * self.max_val + (1 - self.momentum) * x.max(),
                    ]
                ),
            )
            self.set_min_max(tmp_min, tmp_max)
        return x_orig
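
A minimal usage sketch, assuming this file is importable as megengine.quantization.observer (the exact path may differ across MegEngine versions): feed a few batches through the observer, then read back the quantization parameters.

import numpy as np

from megengine import tensor
from megengine.quantization.observer import MinMaxObserver  # assumed import path

obs = MinMaxObserver(symmetric=True, dtype="qint8")

# forward() only records min/max statistics and returns its input unchanged
for _ in range(4):
    batch = tensor(np.random.randn(8, 16).astype("float32"))
    obs(batch)

scale, zero_point = obs.get_qparams()
print(scale.numpy(), zero_point.numpy())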
