You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_fake_quant.py 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import numpy as np
  10. import pytest
  11. import megengine as mge
  12. import megengine.functional as F
  13. from megengine import tensor
  14. from megengine.core.autodiff.grad import Function, Grad
  15. from megengine.core.tensor.dtype import QuantDtypeMeta
  16. from megengine.core.tensor.utils import make_shape_tuple
  17. from megengine.quantization.internal_fake_quant import *
  18. from megengine.quantization.utils import (
  19. QuantMode,
  20. create_qparams,
  21. fake_quant_tensor,
  22. lsq_forward,
  23. tqt_forward,
  24. )
  25. class TQT_numpy:
  26. def __init__(self, lowerbound, upperbound):
  27. super().__init__()
  28. self.lowerbound = lowerbound
  29. self.upperbound = upperbound
  30. def forward(self, inp, scale):
  31. t = 2 ** scale
  32. # t = F.maximum(t, 1e-4)
  33. inp_scaled = inp / t
  34. inp_clipped = np.maximum(
  35. np.minimum(inp_scaled, self.upperbound), self.lowerbound
  36. )
  37. inp_rounded = np.round(inp_clipped)
  38. inp_flq = inp_rounded * t
  39. self.saved_tensors = (inp_scaled, inp_rounded, t)
  40. return inp_flq
  41. def backward(self, grad_inp_flq):
  42. (inp_scaled, inp_rounded, t) = self.saved_tensors
  43. mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
  44. inp_scaled > self.upperbound + 0.5
  45. ) # mask for accumulating the gradients of |data_scaled|>L
  46. mask_quant = np.abs(
  47. mask_clip - 1
  48. ) # mask for accumulating the gradients with |data_scaled|<=L
  49. grad_quant = (
  50. grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
  51. ) # gradient within |data_scaled|<=L
  52. grad_clip = (
  53. grad_inp_flq * mask_clip * inp_rounded
  54. ) # gradient with | data_scaled|>L
  55. grad_s = grad_clip.sum() + grad_quant.sum()
  56. # dL/ds = dL/dt * t * ln(2)
  57. grad_s = grad_s * t * np.log(2)
  58. grad_inp = grad_inp_flq * mask_quant
  59. return grad_inp, grad_s
  60. def test_tqt():
  61. g = []
  62. def cb(grad):
  63. g.append(grad)
  64. x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32")
  65. s = np.random.rand(1) - 1
  66. g_y = np.ones(shape=(1, 2, 3, 4), dtype="float32")
  67. n = TQT_numpy(-127, 127)
  68. y_np = n.forward(x, s)
  69. g_x_np, g_s_np = n.backward(g_y)
  70. x = mge.tensor(x, dtype="float32")
  71. s = mge.tensor(s, dtype="float32")
  72. g_y = mge.tensor(g_y, dtype="float32")
  73. grad = Grad().wrt(x, s, callback=cb)
  74. y = tqt_forward(-127, 127, x, s)
  75. grad(y, g_y)
  76. g_x, g_s = g
  77. np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5)
  78. np.testing.assert_allclose(g_x.numpy(), g_x_np, rtol=1e-5, atol=1e-5)
  79. np.testing.assert_allclose(g_s.numpy(), g_s_np, rtol=5e-5, atol=5e-5)
  80. def _save_to(self, name="grad"):
  81. def callback(grad):
  82. setattr(self, name, grad)
  83. return callback
  84. class Round(Function):
  85. def forward(self, x):
  86. return F.round(x)
  87. def backward(self, output_grads):
  88. return output_grads
  89. def fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax):
  90. oup = Round()(inp / scale) + zero_point
  91. oup = F.minimum(F.maximum(oup, qmin), qmax)
  92. oup = (oup - zero_point) * scale
  93. return oup
  94. def test_fakequant():
  95. qmin = -126
  96. qmax = 129
  97. test_dtype = QuantDtypeMeta("test_qint8", None, "int8", qmin, qmax)
  98. def run(zero_point, scale):
  99. qparams = create_qparams(QuantMode.ASYMMERTIC, test_dtype, scale, zero_point)
  100. inp_data = np.random.uniform(low=-512.0, high=512.0, size=(1, 32, 32, 32))
  101. inp = tensor(inp_data, dtype=np.float32)
  102. # test forward
  103. oup = fake_quant_tensor(inp, qparams).numpy()
  104. oup_gt = fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax).numpy()
  105. assert np.allclose(oup, oup_gt)
  106. assert oup.shape == oup_gt.shape
  107. # test backward
  108. x = tensor(inp_data, dtype=np.float32)
  109. grad = Grad().wrt(x, callback=_save_to(x))
  110. y = fake_quant_tensor(x, qparams)
  111. grad(y, tensor(F.ones_like(x)))
  112. x1 = tensor(inp_data, dtype=np.float32)
  113. grad = Grad().wrt(x1, callback=_save_to(x1))
  114. y1 = fake_quant_tensor_gt(x1, scale, zero_point, qmin, qmax)
  115. grad(y1, tensor(F.ones_like(x1)))
  116. assert np.allclose(x.grad.numpy(), x1.grad.numpy())
  117. assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape)
  118. zero_point = tensor([1.0], dtype=np.float32)
  119. scale = tensor([4.0], dtype=np.float32)
  120. run(zero_point, scale)
  121. zero_point = tensor(1.0 * np.ones((1, 32, 1, 1)), dtype=np.float32)
  122. scale = tensor(4.0 * np.ones((1, 32, 1, 1)), dtype=np.float32)
  123. run(zero_point, scale)
  124. class LSQ_numpy:
  125. def __init__(self, lowerbound, upperbound):
  126. super().__init__()
  127. self.lowerbound = lowerbound
  128. self.upperbound = upperbound
  129. def forward(self, inp, scale, zero_point, grad_scale):
  130. inp_scaled = inp / scale + zero_point
  131. inp_clipped = np.maximum(
  132. np.minimum(inp_scaled, self.upperbound), self.lowerbound
  133. )
  134. inp_rounded = np.floor(inp_clipped + 0.5)
  135. inp_flq = (inp_rounded - zero_point) * scale
  136. self.saved_tensors = (inp_scaled, inp_rounded, scale, grad_scale)
  137. return inp_flq
  138. def backward(self, grad_inp_flq):
  139. (inp_scaled, inp_rounded, scale, grad_scale) = self.saved_tensors
  140. ind_small = inp_scaled < self.lowerbound
  141. ind_big = inp_scaled > self.upperbound
  142. ind_middle = np.logical_xor(ind_small, ind_big)
  143. ind_middle = np.abs(ind_middle - 1)
  144. grad_s = (
  145. ind_small * self.lowerbound
  146. + ind_big * self.upperbound
  147. + ind_middle * (-inp_scaled + inp_rounded)
  148. )
  149. grad_s = grad_s * grad_scale * grad_inp_flq
  150. grad_s = grad_s.sum()
  151. grad_inp = grad_inp_flq * ind_middle
  152. return grad_inp, grad_s
  153. def test_lsq():
  154. g = []
  155. def cb(grad):
  156. g.append(grad)
  157. # FIXME: use random number when LSQ is fixed
  158. # x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32")
  159. # s = np.random.rand(1)
  160. x = np.array(
  161. [
  162. [
  163. [
  164. [4.0, 38.0, -121.0, 38.0],
  165. [15.0, -115.0, -112.0, 24.0],
  166. [23.0, -65.0, 109.0, -115.0],
  167. ],
  168. [
  169. [-66.0, -90.0, -45.0, -101.0],
  170. [68.0, -98.0, 108.0, -79.0],
  171. [54.0, 63.0, -10.0, -50.0],
  172. ],
  173. ]
  174. ],
  175. dtype="float32",
  176. )
  177. s = np.array([0.02918224], dtype="float32")
  178. eps = np.array([1e-5], dtype="float32")
  179. s = np.abs(s) if np.abs(s) > eps else eps
  180. zero_point = np.array([1.0], dtype="float32")
  181. grad_s = np.array([2.0], dtype="float32")
  182. g_y = np.ones(shape=(1, 2, 3, 4), dtype="float32")
  183. n = LSQ_numpy(-127, 127)
  184. y_np = n.forward(x, s, zero_point, grad_s)
  185. g_x_np, g_s_np = n.backward(g_y)
  186. x = mge.tensor(x, dtype="float32")
  187. s = mge.tensor(s, dtype="float32")
  188. zero_point = mge.tensor(zero_point, dtype="float32")
  189. grad_s = mge.tensor(grad_s, dtype="float32")
  190. g_y = mge.tensor(g_y, dtype="float32")
  191. grad = Grad().wrt(x, s, callback=cb)
  192. y = lsq_forward(-127, 127, x, s, zero_point, grad_s)
  193. grad(y, g_y)
  194. g_x, g_s = g
  195. np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-7, atol=1e-7)
  196. np.testing.assert_allclose(g_x.numpy(), g_x_np, rtol=1e-7, atol=1e-7)
  197. np.testing.assert_allclose(g_s.numpy(), g_s_np, rtol=5e-7, atol=5e-7)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台