You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_module.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from functools import partial
  9. import numpy as np
  10. import pytest
  11. import megengine as mge
  12. import megengine.functional as F
  13. import megengine.module as Float
  14. import megengine.module.qat as QAT
  15. import megengine.module.quantized as Q
  16. from megengine import Parameter, Tensor
  17. from megengine.core.tensor import dtype
  18. from megengine.quantization import (
  19. FakeQuantize,
  20. MinMaxObserver,
  21. QConfig,
  22. QuantMode,
  23. create_qparams,
  24. )
  25. from megengine.quantization.quantize import (
  26. disable_fake_quant,
  27. disable_observer,
  28. propagate_qconfig,
  29. )
# Shared QConfig used by every test below: MinMax observers plus fake
# quantization. Weights use the narrowed int8 range ("qint8_narrow",
# i.e. [-127, 127]); activations use the full int8 range ("qint8",
# i.e. [-128, 127]).
min_max_fakequant_qconfig = QConfig(
    weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
    act_observer=partial(MinMaxObserver, dtype="qint8"),
    weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
    act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)
  36. def gen_inp_scale():
  37. return np.float32(np.random.rand() + 1)
# Randomized observer statistics shared by all tests: index 0 feeds the
# weight observer, index 1 the activation observer (see init_qat_net).
min_val = np.random.randint(-127, 0, size=(2,)).astype("float32")
max_val = np.random.randint(1, 127, size=(2,)).astype("float32")
# Symmetric scales derived from the ranges above. The weight dtype
# "qint8_narrow" spans 254 steps ([-127, 127]); the activation dtype
# "qint8" spans 255 steps ([-128, 127]).
weight_scale = np.float32(np.max([-min_val[0], max_val[0]]) / 254 * 2)
act_scale = np.float32(np.max([-min_val[1], max_val[1]]) / 255 * 2)
  42. def quant(x, scale):
  43. inp_dtype = dtype.qint8(scale)
  44. return x.astype(inp_dtype)
  45. def fake_quant(x, scale, qmin, qmax):
  46. x = x / scale
  47. x = F.round(x)
  48. x = F.clip(x, qmin, qmax)
  49. x = x * scale
  50. return x
# Per-tensor-kind clamp ranges bound once: activations use full qint8
# ([-128, 127]), weights use narrowed qint8 ([-127, 127]), and bias uses
# the qint32 range ([-2**31, 2**31 - 1]).
fake_quant_act = partial(fake_quant, qmin=-128, qmax=127)
fake_quant_weight = partial(fake_quant, qmin=-127, qmax=127)
fake_quant_bias = partial(fake_quant, qmin=-(2 ** 31), qmax=2 ** 31 - 1)
  54. def init_qat_net(net):
  55. if net.with_weight:
  56. net.weight_observer.min_val[...] = Tensor(min_val[0])
  57. net.weight_observer.max_val[...] = Tensor(max_val[0])
  58. if net.with_act:
  59. net.act_observer.min_val[...] = Tensor(min_val[1])
  60. net.act_observer.max_val[...] = Tensor(max_val[1])
  61. def test_quant_stub():
  62. normal_net = Float.QuantStub()
  63. normal_net.eval()
  64. qat_from_float = QAT.QuantStub.from_float_module(normal_net)
  65. qat_from_float.eval()
  66. disable_observer(qat_from_float)
  67. disable_fake_quant(qat_from_float)
  68. qat_net = QAT.QuantStub()
  69. qat_net.eval()
  70. disable_observer(qat_net)
  71. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  72. init_qat_net(qat_net)
  73. q_net = Q.QuantStub.from_qat_module(qat_net)
  74. q_net.eval()
  75. x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  76. normal = normal_net(x)
  77. qat_without_fakequant = qat_from_float(x)
  78. fake_quant_normal = fake_quant_act(normal_net(x), act_scale)
  79. qat = qat_net(x)
  80. q = q_net(x).numpy() * act_scale
  81. np.testing.assert_allclose(qat_without_fakequant, normal)
  82. np.testing.assert_allclose(qat, fake_quant_normal)
  83. np.testing.assert_allclose(q, fake_quant_normal.numpy())
  84. def test_dequant_stub():
  85. normal_net = Float.DequantStub()
  86. normal_net.eval()
  87. qat_from_float = QAT.DequantStub.from_float_module(normal_net)
  88. qat_from_float.eval()
  89. disable_fake_quant(qat_from_float)
  90. disable_observer(qat_from_float)
  91. qat_net = QAT.DequantStub()
  92. qat_net.eval()
  93. disable_observer(qat_net)
  94. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  95. init_qat_net(qat_net)
  96. q_net = Q.DequantStub.from_qat_module(qat_net)
  97. q_net.eval()
  98. x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  99. inp_scale = gen_inp_scale()
  100. x = fake_quant_act(x, inp_scale)
  101. x.qparams.scale = inp_scale
  102. normal = normal_net(x)
  103. qat_without_fakequant = qat_from_float(x)
  104. fake_quant_normal = normal_net(x)
  105. qat = qat_net(x)
  106. q = q_net(quant(x, inp_scale)).numpy()
  107. np.testing.assert_allclose(qat_without_fakequant, normal)
  108. np.testing.assert_allclose(qat, fake_quant_normal)
  109. np.testing.assert_allclose(q, fake_quant_normal.numpy())
@pytest.mark.parametrize("kind", ["cos", "relu", "add", "mul", "fuse_add_relu"])
def test_elemwise(kind):
    """Check Float/QAT/Quantized Elemwise agreement for unary and binary ops."""
    normal_net = Float.Elemwise(kind)
    normal_net.eval()
    # QAT converted from float with quantization disabled must match float.
    qat_from_float = QAT.Elemwise.from_float_module(normal_net)
    qat_from_float.eval()
    disable_observer(qat_from_float)
    disable_fake_quant(qat_from_float)
    # Fresh QAT module with frozen observers seeded from the shared stats.
    qat_net = QAT.Elemwise(kind)
    qat_net.eval()
    disable_observer(qat_net)
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)
    q_net = Q.Elemwise.from_qat_module(qat_net)
    q_net.eval()
    # Two independently fake-quantized inputs, each tagged with its scale.
    x1_scale = np.float32(np.random.rand() + 1)
    x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x1 = fake_quant_act(x1, x1_scale)
    x1.qparams.scale = x1_scale
    x2_scale = np.float32(np.random.rand() + 1)
    x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x2 = fake_quant_act(x2, x2_scale)
    x2.qparams.scale = x2_scale
    x1_int8 = quant(x1, x1_scale)
    x2_int8 = quant(x2, x2_scale)
    # test correctness of `Float`, `QAT` and `Quantized`
    if kind in ("add", "mul", "fuse_add_relu"):
        # Binary kinds take two inputs.
        normal = normal_net(x1, x2)
        qat_without_fakequant = qat_from_float(x1, x2)
        fake_quant_normal = fake_quant_act(normal_net(x1, x2), act_scale)
        qat = qat_net(x1, x2)
        # Rescale int8 output back to float for comparison.
        q = q_net(x1_int8, x2_int8).numpy() * act_scale
    else:
        # Unary kinds ("cos", "relu") take a single input.
        normal = normal_net(x1)
        qat_without_fakequant = qat_from_float(x1)
        fake_quant_normal = fake_quant_act(normal_net(x1), act_scale)
        qat = qat_net(x1)
        q = q_net(x1_int8).numpy() * act_scale
    np.testing.assert_allclose(qat_without_fakequant, normal)
    np.testing.assert_allclose(qat, fake_quant_normal)
    np.testing.assert_allclose(q, fake_quant_normal.numpy())
def test_linear():
    """Check Float/QAT/Quantized Linear agreement with shared parameters.

    The float module receives pre-fake-quantized weight/bias so its output
    matches what the quantized path computes; the QAT module receives the
    raw parameters and applies fake quantization itself.
    """
    normal_net = Float.Linear(3, 3, bias=True)
    normal_net.eval()
    qat_net = QAT.Linear(3, 3, bias=True)
    qat_net.eval()
    disable_observer(qat_net)
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)
    # Fake-quantized input, tagged with symmetric qint8 qparams.
    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    inp_scale = gen_inp_scale()
    x = fake_quant_act(x, inp_scale)
    x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
    x_int8 = quant(x, inp_scale)
    weight = np.random.normal(size=(3, 3)).astype("float32")
    bias = np.random.normal(size=(3,)).astype("float32")
    # Bias is fake-quantized at inp_scale * weight_scale, matching the
    # usual int32-bias convention for quantized linear layers.
    normal_net.weight[...] = fake_quant_weight(weight, weight_scale)
    normal_net.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
    qat_net.weight[...] = Parameter(weight)
    qat_net.bias[...] = Parameter(bias)
    # NOTE(review): conversion happens only after the parameters are set —
    # presumably from_float_module copies them; keep this ordering.
    qat_from_float = QAT.Linear.from_float_module(normal_net)
    qat_from_float.eval()
    disable_fake_quant(qat_from_float)
    disable_observer(qat_from_float)
    q_net = Q.Linear.from_qat_module(qat_net)
    q_net.eval()
    normal = normal_net(x)
    qat_without_fakequant = qat_from_float(x)
    fake_quant_normal = fake_quant_act(normal_net(x), act_scale)
    qat = qat_net(x)
    # Rescale int8 output back to float for comparison.
    q = q_net(x_int8).numpy() * act_scale
    np.testing.assert_allclose(qat_without_fakequant, normal)
    np.testing.assert_allclose(qat, fake_quant_normal.numpy())
    np.testing.assert_allclose(q, fake_quant_normal.numpy())
@pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"])
@pytest.mark.parametrize("padding_mode", ["zeros", "reflect", "replicate"])
def test_conv(module, padding_mode):
    """Check Float/QAT/Quantized conv variants agree for each padding mode."""
    normal_net = getattr(Float, module)(
        3, 3, 3, 1, 1, 1, bias=True, padding_mode=padding_mode
    )
    normal_net.eval()
    qat_net = getattr(QAT, module)(
        3, 3, 3, 1, 1, 1, bias=True, padding_mode=padding_mode
    )
    qat_net.eval()
    disable_observer(qat_net)
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)
    # Fake-quantized NCHW input, tagged with symmetric qint8 qparams.
    x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
    inp_scale = gen_inp_scale()
    x = fake_quant_act(x, inp_scale)
    x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
    x_int8 = quant(x, inp_scale)
    weight = np.random.normal(size=(3, 3, 3, 3)).astype("float32")
    bias = np.random.normal(size=(1, 3, 1, 1)).astype("float32")
    # The ConvBn variants hold the conv parameters on a `.conv` submodule.
    if module in ("ConvBn2d", "ConvBnRelu2d"):
        normal_net.conv.weight[...] = fake_quant_weight(weight, weight_scale)
        normal_net.conv.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
        qat_net.conv.weight[...] = Parameter(weight)
        qat_net.conv.bias[...] = Parameter(bias)
    else:
        normal_net.weight[...] = fake_quant_weight(weight, weight_scale)
        normal_net.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
        qat_net.weight[...] = Parameter(weight)
        qat_net.bias[...] = Parameter(bias)
    # Convert after the parameters are in place (same ordering as test_linear).
    qat_from_float = getattr(QAT, module).from_float_module(normal_net)
    qat_from_float.eval()
    disable_observer(qat_from_float)
    disable_fake_quant(qat_from_float)
    q_net = getattr(Q, module).from_qat_module(qat_net)
    q_net.eval()
    normal = normal_net(x)
    qat_without_fakequant = qat_from_float(x)
    fake_quant_normal = fake_quant_act(normal_net(x), act_scale)
    qat = qat_net(x)
    # Rescale int8 output back to float for comparison.
    q = q_net(x_int8).numpy() * act_scale
    # Conv accumulation differs slightly between paths, so allow a small
    # float tolerance, and one quantization step (act_scale) for the
    # quantized comparisons.
    np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-5)
    np.testing.assert_allclose(qat, fake_quant_normal, atol=act_scale)
    np.testing.assert_allclose(q, fake_quant_normal.numpy(), atol=act_scale)
  229. def test_concat():
  230. normal_net = Float.Concat()
  231. normal_net.eval()
  232. qat_net = QAT.Concat()
  233. qat_net.eval()
  234. disable_observer(qat_net)
  235. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  236. init_qat_net(qat_net)
  237. inps = []
  238. inps_int8 = []
  239. for i in range(3):
  240. inp_scale = gen_inp_scale()
  241. inps.append(mge.tensor(np.random.normal(size=(3, 3)).astype("float32")))
  242. inps[i] = fake_quant_act(inps[i], inp_scale)
  243. inps[i].qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
  244. inps_int8.append(quant(inps[i], inp_scale))
  245. qat_from_float = QAT.Concat.from_float_module(normal_net)
  246. qat_from_float.eval()
  247. disable_fake_quant(qat_from_float)
  248. disable_observer(qat_from_float)
  249. q_net = Q.Concat.from_qat_module(qat_net)
  250. q_net.eval()
  251. normal = normal_net(inps)
  252. qat_without_fakequant = qat_from_float(inps)
  253. fake_quant_normal = fake_quant_act(normal_net(inps), act_scale)
  254. qat = qat_net(inps)
  255. q = q_net(inps_int8).numpy() * act_scale
  256. np.testing.assert_allclose(qat_without_fakequant, normal)
  257. np.testing.assert_allclose(qat, fake_quant_normal.numpy())
  258. np.testing.assert_allclose(q, fake_quant_normal.numpy())