You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_module.py 8.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. import numpy as np
  2. import pytest
  3. import megengine as mge
  4. import megengine.functional as F
  5. import megengine.module as Float
  6. import megengine.module.qat as QAT
  7. import megengine.module.quantized as Q
  8. from megengine.core.tensor import dtype
  9. from megengine.quantization import min_max_fakequant_qconfig
  10. from megengine.quantization.quantize import (
  11. disable_fake_quant,
  12. disable_observer,
  13. propagate_qconfig,
  14. )
"""
Calculate testing scales based on ``min_max_fakequant_qconfig``
"""

# Input-activation scale used to fake-quantize test inputs; the "+ 1" keeps
# it in [1, 2) and away from degenerate near-zero scales.
inp_scale = np.float32(np.random.rand() + 1)

# Random (min, max) observer statistics: index 0 seeds the weight observer,
# index 1 seeds the activation observer (see ``init_qat_net`` below).
min_val = np.random.randint(-127, 0, size=(2,)).astype("float32")
max_val = np.random.randint(1, 127, size=(2,)).astype("float32")

# Symmetric qint8 scales derived from the ranges above: weights divide by
# 254 and activations by 255 — NOTE(review): presumably mirrors the scale
# formula of the observers in ``min_max_fakequant_qconfig``; confirm there.
weight_scale = np.float32(np.max([-min_val[0], max_val[0]]) / 254 * 2)
act_scale = np.float32(np.max([-min_val[1], max_val[1]]) / 255 * 2)
  23. def quant(x, scale):
  24. inp_dtype = dtype.qint8(scale)
  25. return x.astype(inp_dtype)
  26. def fake_quant(x, scale):
  27. x = x / scale
  28. x = F.round(x)
  29. x = F.clip(x, -128, 127)
  30. x = x * scale
  31. return x
  32. def init_qat_net(net):
  33. if net.with_weight:
  34. net.weight_observer.min_val.set_value(min_val[0])
  35. net.weight_observer.max_val.set_value(max_val[0])
  36. if net.with_act:
  37. net.act_observer.min_val.set_value(min_val[1])
  38. net.act_observer.max_val.set_value(max_val[1])
  39. def test_quant_stub():
  40. normal_net = Float.QuantStub()
  41. normal_net.eval()
  42. qat_from_float = QAT.QuantStub.from_float_module(normal_net)
  43. qat_from_float.eval()
  44. disable_observer(qat_from_float)
  45. disable_fake_quant(qat_from_float)
  46. qat_net = QAT.QuantStub()
  47. qat_net.eval()
  48. disable_observer(qat_net)
  49. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  50. init_qat_net(qat_net)
  51. q_net = Q.QuantStub.from_qat_module(qat_net)
  52. q_net.eval()
  53. x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  54. normal = normal_net(x)
  55. qat_without_fakequant = qat_from_float(x)
  56. fake_quant_normal = fake_quant(normal_net(x), act_scale)
  57. qat = qat_net(x)
  58. q = q_net(x).numpy() * act_scale
  59. np.testing.assert_allclose(qat_without_fakequant, normal)
  60. np.testing.assert_allclose(qat, fake_quant_normal)
  61. np.testing.assert_allclose(q, fake_quant_normal.numpy())
  62. def test_dequant_stub():
  63. normal_net = Float.DequantStub()
  64. normal_net.eval()
  65. qat_from_float = QAT.DequantStub.from_float_module(normal_net)
  66. qat_from_float.eval()
  67. disable_fake_quant(qat_from_float)
  68. disable_observer(qat_from_float)
  69. qat_net = QAT.DequantStub()
  70. qat_net.eval()
  71. disable_observer(qat_net)
  72. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  73. init_qat_net(qat_net)
  74. q_net = Q.DequantStub.from_qat_module(qat_net)
  75. q_net.eval()
  76. x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  77. x = fake_quant(x, inp_scale)
  78. x.q_dict["scale"] = inp_scale
  79. normal = normal_net(x)
  80. qat_without_fakequant = qat_from_float(x)
  81. fake_quant_normal = normal_net(x)
  82. qat = qat_net(x)
  83. q = q_net(quant(x, inp_scale)).numpy()
  84. np.testing.assert_allclose(qat_without_fakequant, normal)
  85. np.testing.assert_allclose(qat, fake_quant_normal)
  86. np.testing.assert_allclose(q, fake_quant_normal.numpy())
  87. @pytest.mark.parametrize("kind", ["COS", "RELU", "ADD", "MUL", "FUSE_ADD_RELU"])
  88. def test_elemwise(kind):
  89. normal_net = Float.Elemwise(kind)
  90. normal_net.eval()
  91. qat_from_float = QAT.Elemwise.from_float_module(normal_net)
  92. qat_from_float.eval()
  93. disable_observer(qat_from_float)
  94. disable_fake_quant(qat_from_float)
  95. qat_net = QAT.Elemwise(kind)
  96. qat_net.eval()
  97. disable_observer(qat_net)
  98. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  99. init_qat_net(qat_net)
  100. q_net = Q.Elemwise.from_qat_module(qat_net)
  101. q_net.eval()
  102. x1_scale = np.float32(np.random.rand() + 1)
  103. x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  104. x1 = fake_quant(x1, x1_scale)
  105. x1.q_dict["scale"] = x1_scale
  106. x2_scale = np.float32(np.random.rand() + 1)
  107. x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  108. x2 = fake_quant(x2, x2_scale)
  109. x2.q_dict["scale"] = x2_scale
  110. x1_int8 = quant(x1, x1_scale)
  111. x2_int8 = quant(x2, x2_scale)
  112. # test correctness of `Float`, `QAT` and `Quantized`
  113. if kind in ("ADD", "MUL", "FUSE_ADD_RELU"):
  114. normal = normal_net(x1, x2)
  115. qat_without_fakequant = qat_from_float(x1, x2)
  116. fake_quant_normal = fake_quant(normal_net(x1, x2), act_scale)
  117. qat = qat_net(x1, x2)
  118. q = q_net(x1_int8, x2_int8).numpy() * act_scale
  119. else:
  120. normal = normal_net(x1)
  121. qat_without_fakequant = qat_from_float(x1)
  122. fake_quant_normal = fake_quant(normal_net(x1), act_scale)
  123. qat = qat_net(x1)
  124. q = q_net(x1_int8).numpy() * act_scale
  125. np.testing.assert_allclose(qat_without_fakequant, normal)
  126. np.testing.assert_allclose(qat, fake_quant_normal)
  127. np.testing.assert_allclose(q, fake_quant_normal.numpy())
  128. def test_linear():
  129. normal_net = Float.Linear(3, 3, bias=True)
  130. normal_net.eval()
  131. qat_net = QAT.Linear(3, 3, bias=True)
  132. qat_net.eval()
  133. disable_observer(qat_net)
  134. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  135. init_qat_net(qat_net)
  136. x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  137. x = fake_quant(x, inp_scale)
  138. x.q_dict["scale"] = inp_scale
  139. x_int8 = quant(x, inp_scale)
  140. weight = np.random.normal(size=(3, 3)).astype("float32")
  141. bias = np.random.normal(size=(3,)).astype("float32")
  142. normal_net.weight.set_value(fake_quant(weight, weight_scale))
  143. normal_net.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
  144. qat_net.weight.set_value(weight)
  145. qat_net.bias.set_value(bias)
  146. qat_from_float = QAT.Linear.from_float_module(normal_net)
  147. qat_from_float.eval()
  148. disable_fake_quant(qat_from_float)
  149. disable_observer(qat_from_float)
  150. q_net = Q.Linear.from_qat_module(qat_net)
  151. q_net.eval()
  152. normal = normal_net(x)
  153. qat_without_fakequant = qat_from_float(x)
  154. fake_quant_normal = fake_quant(normal_net(x), act_scale)
  155. qat = qat_net(x)
  156. q = q_net(x_int8).numpy() * act_scale
  157. np.testing.assert_allclose(qat_without_fakequant, normal)
  158. np.testing.assert_allclose(qat, fake_quant_normal)
  159. np.testing.assert_allclose(q, fake_quant_normal.numpy())
  160. @pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"])
  161. def test_conv(module):
  162. normal_net = getattr(Float, module)(3, 3, 3, 1, 1, 1, bias=True)
  163. normal_net.eval()
  164. qat_net = getattr(QAT, module)(3, 3, 3, 1, 1, 1, bias=True)
  165. qat_net.eval()
  166. disable_observer(qat_net)
  167. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  168. init_qat_net(qat_net)
  169. x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
  170. x = fake_quant(x, inp_scale)
  171. x.q_dict["scale"] = inp_scale
  172. x_int8 = quant(x, inp_scale)
  173. weight = np.random.normal(size=(3, 3, 3, 3)).astype("float32")
  174. bias = np.random.normal(size=(1, 3, 1, 1)).astype("float32")
  175. if module in ("ConvBn2d", "ConvBnRelu2d"):
  176. normal_net.conv.weight.set_value(fake_quant(weight, weight_scale))
  177. normal_net.conv.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
  178. qat_net.conv.weight.set_value(weight)
  179. qat_net.conv.bias.set_value(bias)
  180. else:
  181. normal_net.weight.set_value(fake_quant(weight, weight_scale))
  182. normal_net.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
  183. qat_net.weight.set_value(weight)
  184. qat_net.bias.set_value(bias)
  185. qat_from_float = getattr(QAT, module).from_float_module(normal_net)
  186. qat_from_float.eval()
  187. disable_observer(qat_from_float)
  188. disable_fake_quant(qat_from_float)
  189. q_net = getattr(Q, module).from_qat_module(qat_net)
  190. q_net.eval()
  191. normal = normal_net(x)
  192. qat_without_fakequant = qat_from_float(x)
  193. fake_quant_normal = fake_quant(normal_net(x), act_scale)
  194. qat = qat_net(x)
  195. q = q_net(x_int8).numpy() * act_scale
  196. np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-6)
  197. np.testing.assert_allclose(qat, fake_quant_normal)
  198. np.testing.assert_allclose(q, fake_quant_normal.numpy())

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台