You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_module.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. from functools import partial
  2. import numpy as np
  3. import pytest
  4. import megengine as mge
  5. import megengine.functional as F
  6. import megengine.module as Float
  7. import megengine.module.qat as QAT
  8. import megengine.module.quantized as Q
  9. from megengine import Parameter, Tensor
  10. from megengine.core.tensor import dtype
  11. from megengine.quantization import (
  12. FakeQuantize,
  13. MinMaxObserver,
  14. QConfig,
  15. QuantMode,
  16. create_qparams,
  17. )
  18. from megengine.quantization.quantize import (
  19. disable_fake_quant,
  20. disable_observer,
  21. propagate_qconfig,
  22. )
  23. min_max_fakequant_qconfig = QConfig(
  24. weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
  25. act_observer=partial(MinMaxObserver, dtype="qint8"),
  26. weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
  27. act_fake_quant=partial(FakeQuantize, dtype="qint8"),
  28. )
  29. def gen_inp_scale():
  30. return np.float32(np.random.rand() + 1)
  31. min_val = np.random.randint(-127, 0, size=(2,)).astype("float32")
  32. max_val = np.random.randint(1, 127, size=(2,)).astype("float32")
  33. weight_scale = np.float32(np.max([-min_val[0], max_val[0]]) / 254 * 2)
  34. act_scale = np.float32(np.max([-min_val[1], max_val[1]]) / 255 * 2)
  35. def quant(x, scale):
  36. inp_dtype = dtype.qint8(scale)
  37. return x.astype(inp_dtype)
  38. def fake_quant(x, scale, qmin, qmax):
  39. x = x / scale
  40. x = F.round(x)
  41. x = F.clip(x, qmin, qmax)
  42. x = x * scale
  43. return x
  44. fake_quant_act = partial(fake_quant, qmin=-128, qmax=127)
  45. fake_quant_weight = partial(fake_quant, qmin=-127, qmax=127)
  46. fake_quant_bias = partial(fake_quant, qmin=-(2 ** 31), qmax=2 ** 31 - 1)
  47. def init_qat_net(net):
  48. if net.with_weight:
  49. net.weight_observer.min_val[...] = Tensor(min_val[0])
  50. net.weight_observer.max_val[...] = Tensor(max_val[0])
  51. if net.with_act:
  52. net.act_observer.min_val[...] = Tensor(min_val[1])
  53. net.act_observer.max_val[...] = Tensor(max_val[1])
  54. def test_quant_stub():
  55. normal_net = Float.QuantStub()
  56. normal_net.eval()
  57. qat_from_float = QAT.QuantStub.from_float_module(normal_net)
  58. qat_from_float.eval()
  59. disable_observer(qat_from_float)
  60. disable_fake_quant(qat_from_float)
  61. qat_net = QAT.QuantStub()
  62. qat_net.eval()
  63. disable_observer(qat_net)
  64. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  65. init_qat_net(qat_net)
  66. q_net = Q.QuantStub.from_qat_module(qat_net)
  67. q_net.eval()
  68. x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  69. normal = normal_net(x)
  70. qat_without_fakequant = qat_from_float(x)
  71. fake_quant_normal = fake_quant_act(normal_net(x), act_scale)
  72. qat = qat_net(x)
  73. q = q_net(x).numpy() * act_scale
  74. np.testing.assert_allclose(qat_without_fakequant, normal)
  75. np.testing.assert_allclose(qat, fake_quant_normal)
  76. np.testing.assert_allclose(q, fake_quant_normal.numpy())
  77. def test_dequant_stub():
  78. normal_net = Float.DequantStub()
  79. normal_net.eval()
  80. qat_from_float = QAT.DequantStub.from_float_module(normal_net)
  81. qat_from_float.eval()
  82. disable_fake_quant(qat_from_float)
  83. disable_observer(qat_from_float)
  84. qat_net = QAT.DequantStub()
  85. qat_net.eval()
  86. disable_observer(qat_net)
  87. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  88. init_qat_net(qat_net)
  89. q_net = Q.DequantStub.from_qat_module(qat_net)
  90. q_net.eval()
  91. x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  92. inp_scale = gen_inp_scale()
  93. x = fake_quant_act(x, inp_scale)
  94. x.qparams.scale = inp_scale
  95. normal = normal_net(x)
  96. qat_without_fakequant = qat_from_float(x)
  97. fake_quant_normal = normal_net(x)
  98. qat = qat_net(x)
  99. q = q_net(quant(x, inp_scale)).numpy()
  100. np.testing.assert_allclose(qat_without_fakequant, normal)
  101. np.testing.assert_allclose(qat, fake_quant_normal)
  102. np.testing.assert_allclose(q, fake_quant_normal.numpy())
  103. @pytest.mark.parametrize("kind", ["COS", "RELU", "ADD", "MUL", "FUSE_ADD_RELU"])
  104. def test_elemwise(kind):
  105. normal_net = Float.Elemwise(kind)
  106. normal_net.eval()
  107. qat_from_float = QAT.Elemwise.from_float_module(normal_net)
  108. qat_from_float.eval()
  109. disable_observer(qat_from_float)
  110. disable_fake_quant(qat_from_float)
  111. qat_net = QAT.Elemwise(kind)
  112. qat_net.eval()
  113. disable_observer(qat_net)
  114. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  115. init_qat_net(qat_net)
  116. q_net = Q.Elemwise.from_qat_module(qat_net)
  117. q_net.eval()
  118. x1_scale = np.float32(np.random.rand() + 1)
  119. x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  120. x1 = fake_quant_act(x1, x1_scale)
  121. x1.qparams.scale = x1_scale
  122. x2_scale = np.float32(np.random.rand() + 1)
  123. x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  124. x2 = fake_quant_act(x2, x2_scale)
  125. x2.qparams.scale = x2_scale
  126. x1_int8 = quant(x1, x1_scale)
  127. x2_int8 = quant(x2, x2_scale)
  128. # test correctness of `Float`, `QAT` and `Quantized`
  129. if kind in ("ADD", "MUL", "FUSE_ADD_RELU"):
  130. normal = normal_net(x1, x2)
  131. qat_without_fakequant = qat_from_float(x1, x2)
  132. fake_quant_normal = fake_quant_act(normal_net(x1, x2), act_scale)
  133. qat = qat_net(x1, x2)
  134. q = q_net(x1_int8, x2_int8).numpy() * act_scale
  135. else:
  136. normal = normal_net(x1)
  137. qat_without_fakequant = qat_from_float(x1)
  138. fake_quant_normal = fake_quant_act(normal_net(x1), act_scale)
  139. qat = qat_net(x1)
  140. q = q_net(x1_int8).numpy() * act_scale
  141. np.testing.assert_allclose(qat_without_fakequant, normal)
  142. np.testing.assert_allclose(qat, fake_quant_normal)
  143. np.testing.assert_allclose(q, fake_quant_normal.numpy())
  144. def test_linear():
  145. normal_net = Float.Linear(3, 3, bias=True)
  146. normal_net.eval()
  147. qat_net = QAT.Linear(3, 3, bias=True)
  148. qat_net.eval()
  149. disable_observer(qat_net)
  150. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  151. init_qat_net(qat_net)
  152. x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  153. inp_scale = gen_inp_scale()
  154. x = fake_quant_act(x, inp_scale)
  155. x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
  156. x_int8 = quant(x, inp_scale)
  157. weight = np.random.normal(size=(3, 3)).astype("float32")
  158. bias = np.random.normal(size=(3,)).astype("float32")
  159. normal_net.weight[...] = fake_quant_weight(weight, weight_scale)
  160. normal_net.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
  161. qat_net.weight[...] = Parameter(weight)
  162. qat_net.bias[...] = Parameter(bias)
  163. qat_from_float = QAT.Linear.from_float_module(normal_net)
  164. qat_from_float.eval()
  165. disable_fake_quant(qat_from_float)
  166. disable_observer(qat_from_float)
  167. q_net = Q.Linear.from_qat_module(qat_net)
  168. q_net.eval()
  169. normal = normal_net(x)
  170. qat_without_fakequant = qat_from_float(x)
  171. fake_quant_normal = fake_quant_act(normal_net(x), act_scale)
  172. qat = qat_net(x)
  173. q = q_net(x_int8).numpy() * act_scale
  174. np.testing.assert_allclose(qat_without_fakequant, normal)
  175. np.testing.assert_allclose(qat, fake_quant_normal.numpy())
  176. np.testing.assert_allclose(q, fake_quant_normal.numpy())
  177. @pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"])
  178. def test_conv(module):
  179. normal_net = getattr(Float, module)(3, 3, 3, 1, 1, 1, bias=True)
  180. normal_net.eval()
  181. qat_net = getattr(QAT, module)(3, 3, 3, 1, 1, 1, bias=True)
  182. qat_net.eval()
  183. disable_observer(qat_net)
  184. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  185. init_qat_net(qat_net)
  186. x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
  187. inp_scale = gen_inp_scale()
  188. x = fake_quant_act(x, inp_scale)
  189. x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
  190. x_int8 = quant(x, inp_scale)
  191. weight = np.random.normal(size=(3, 3, 3, 3)).astype("float32")
  192. bias = np.random.normal(size=(1, 3, 1, 1)).astype("float32")
  193. if module in ("ConvBn2d", "ConvBnRelu2d"):
  194. normal_net.conv.weight[...] = fake_quant_weight(weight, weight_scale)
  195. normal_net.conv.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
  196. qat_net.conv.weight[...] = Parameter(weight)
  197. qat_net.conv.bias[...] = Parameter(bias)
  198. else:
  199. normal_net.weight[...] = fake_quant_weight(weight, weight_scale)
  200. normal_net.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
  201. qat_net.weight[...] = Parameter(weight)
  202. qat_net.bias[...] = Parameter(bias)
  203. qat_from_float = getattr(QAT, module).from_float_module(normal_net)
  204. qat_from_float.eval()
  205. disable_observer(qat_from_float)
  206. disable_fake_quant(qat_from_float)
  207. q_net = getattr(Q, module).from_qat_module(qat_net)
  208. q_net.eval()
  209. normal = normal_net(x)
  210. qat_without_fakequant = qat_from_float(x)
  211. fake_quant_normal = fake_quant_act(normal_net(x), act_scale)
  212. qat = qat_net(x)
  213. q = q_net(x_int8).numpy() * act_scale
  214. np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-5)
  215. np.testing.assert_allclose(qat, fake_quant_normal, atol=act_scale)
  216. np.testing.assert_allclose(q, fake_quant_normal.numpy(), atol=act_scale)
  217. def test_concat():
  218. normal_net = Float.Concat()
  219. normal_net.eval()
  220. qat_net = QAT.Concat()
  221. qat_net.eval()
  222. disable_observer(qat_net)
  223. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  224. init_qat_net(qat_net)
  225. inps = []
  226. inps_int8 = []
  227. for i in range(3):
  228. inp_scale = gen_inp_scale()
  229. inps.append(mge.tensor(np.random.normal(size=(3, 3)).astype("float32")))
  230. inps[i] = fake_quant_act(inps[i], inp_scale)
  231. inps[i].qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
  232. inps_int8.append(quant(inps[i], inp_scale))
  233. qat_from_float = QAT.Concat.from_float_module(normal_net)
  234. qat_from_float.eval()
  235. disable_fake_quant(qat_from_float)
  236. disable_observer(qat_from_float)
  237. q_net = Q.Concat.from_qat_module(qat_net)
  238. q_net.eval()
  239. normal = normal_net(inps)
  240. qat_without_fakequant = qat_from_float(inps)
  241. fake_quant_normal = fake_quant_act(normal_net(inps), act_scale)
  242. qat = qat_net(inps)
  243. q = q_net(inps_int8).numpy() * act_scale
  244. np.testing.assert_allclose(qat_without_fakequant, normal)
  245. np.testing.assert_allclose(qat, fake_quant_normal.numpy())
  246. np.testing.assert_allclose(q, fake_quant_normal.numpy())

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台