
test_module.py

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.module as Float
import megengine.module.qat as QAT
import megengine.module.quantized as Q
from megengine.core.tensor import dtype
from megengine.quantization import min_max_fakequant_qconfig
from megengine.quantization.quantize import disable_observer, propagate_qconfig
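
# Float, QAT and Q alias the float, quantization-aware-training and fully
# quantized module namespaces, so each layer can be built in all three forms.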
  11. """
  12. Calculate testing scales based on ``min_max_fakequant_qconfig``
  13. """
  14. inp_scale = np.float32(np.random.rand() + 1)
  15. min_val = np.random.randint(-127, 0, size=(2,)).astype("float32")
  16. max_val = np.random.randint(1, 127, size=(2,)).astype("float32")
  17. weight_scale = np.float32(np.max([-min_val[0], max_val[0]]) / 254 * 2)
  18. act_scale = np.float32(np.max([-min_val[1], max_val[1]]) / 255 * 2)
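
# weight_scale covers a symmetric qint8 range (2 * max_abs / 254, i.e. max_abs / 127,
# mapping weights to [-127, 127]); act_scale uses the full 255-step [-128, 127] range.
# These are the scales the min/max observers are expected to derive from min_val/max_val.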


def quant(x, scale):
    inp_dtype = dtype.qint8(scale)
    return x.astype(inp_dtype)


def fake_quant(x, scale):
    x = x / scale
    x = F.round(x)
    x = F.clip(x, -128, 127)
    x = x * scale
    return x
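
# quant converts a float tensor into a real qint8 tensor, while fake_quant only
# mimics int8 quantization in the float domain: divide by the scale, round to
# the nearest step, clamp to [-128, 127], then scale back up.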


def init_qat_net(net):
    if net.with_weight:
        net.weight_observer.min_val.set_value(min_val[0])
        net.weight_observer.max_val.set_value(max_val[0])
    if net.with_act:
        net.act_observer.min_val.set_value(min_val[1])
        net.act_observer.max_val.set_value(max_val[1])
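
# Each test below follows the same pattern: build a float module, a QAT module
# configured with min_max_fakequant_qconfig and the min/max values above, and the
# corresponding quantized module obtained via from_qat_module, then check that
# the three produce matching outputs.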


def test_quant_stub():
    normal_net = Float.QuantStub()
    normal_net.eval()
    qat_net = QAT.QuantStub()
    qat_net.eval()
    disable_observer(qat_net)
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)
    q_net = Q.QuantStub.from_qat_module(qat_net)
    q_net.eval()
    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    normal_out = fake_quant(normal_net(x), act_scale)
    qat_out = qat_net(x)
    q_out = q_net(x).numpy() * act_scale
    np.testing.assert_allclose(qat_out, normal_out)
    np.testing.assert_allclose(q_out, normal_out.numpy())


def test_dequant_stub():
    normal_net = Float.DequantStub()
    normal_net.eval()
    qat_net = QAT.DequantStub()
    qat_net.eval()
    disable_observer(qat_net)
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)
    q_net = Q.DequantStub.from_qat_module(qat_net)
    q_net.eval()
    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x = fake_quant(x, inp_scale)
    x.q_dict["scale"] = inp_scale
    normal_out = normal_net(x)
    qat_out = qat_net(x)
    q_out = q_net(quant(x, inp_scale)).numpy()
    np.testing.assert_allclose(qat_out, normal_out)
    np.testing.assert_allclose(q_out, normal_out.numpy())


@pytest.mark.parametrize("kind", ["COS", "RELU", "ADD", "MUL", "FUSE_ADD_RELU"])
def test_elemwise(kind):
    normal_net = Float.Elemwise(kind)
    normal_net.eval()
    qat_net = QAT.Elemwise(kind)
    qat_net.eval()
    disable_observer(qat_net)
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)
    q_net = Q.Elemwise.from_qat_module(qat_net)
    q_net.eval()
    x1_scale = np.float32(np.random.rand() + 1)
    x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x1 = fake_quant(x1, x1_scale)
    x1.q_dict["scale"] = x1_scale
    x2_scale = np.float32(np.random.rand() + 1)
    x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x2 = fake_quant(x2, x2_scale)
    x2.q_dict["scale"] = x2_scale
    x1_int8 = quant(x1, x1_scale)
    x2_int8 = quant(x2, x2_scale)
    if kind in ("ADD", "MUL", "FUSE_ADD_RELU"):
        normal_out = fake_quant(normal_net(x1, x2), act_scale)
        qat_out = qat_net(x1, x2)
        q_out = q_net(x1_int8, x2_int8).numpy() * act_scale
    else:
        normal_out = fake_quant(normal_net(x1), act_scale)
        qat_out = qat_net(x1)
        q_out = q_net(x1_int8).numpy() * act_scale
    np.testing.assert_allclose(qat_out, normal_out)
    np.testing.assert_allclose(q_out, normal_out.numpy())


def test_linear():
    normal_net = Float.Linear(3, 3, bias=True)
    normal_net.eval()
    qat_net = QAT.Linear(3, 3, bias=True)
    qat_net.eval()
    disable_observer(qat_net)
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)
    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x = fake_quant(x, inp_scale)
    x.q_dict["scale"] = inp_scale
    x_int8 = quant(x, inp_scale)
    weight = np.random.normal(size=(3, 3)).astype("float32")
    bias = np.random.normal(size=(3,)).astype("float32")
    normal_net.weight.set_value(fake_quant(weight, weight_scale))
    normal_net.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
    qat_net.weight.set_value(weight)
    qat_net.bias.set_value(bias)
    q_net = Q.Linear.from_qat_module(qat_net)
    q_net.eval()
    normal_out = fake_quant(normal_net(x), act_scale)
    qat_out = qat_net(x)
    q_out = q_net(x_int8).numpy() * act_scale
    np.testing.assert_allclose(qat_out, normal_out)
    np.testing.assert_allclose(q_out, normal_out.numpy())


@pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"])
def test_conv(module):
    normal_net = getattr(Float, module)(3, 3, 3, 1, 1, 1, bias=True)
    normal_net.eval()
    qat_net = getattr(QAT, module)(3, 3, 3, 1, 1, 1, bias=True)
    qat_net.eval()
    disable_observer(qat_net)
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)
    x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
    x = fake_quant(x, inp_scale)
    x.q_dict["scale"] = inp_scale
    x_int8 = quant(x, inp_scale)
    weight = np.random.normal(size=(3, 3, 3, 3)).astype("float32")
    bias = np.random.normal(size=(1, 3, 1, 1)).astype("float32")
    if module in ("ConvBn2d", "ConvBnRelu2d"):
        normal_net.conv.weight.set_value(fake_quant(weight, weight_scale))
        normal_net.conv.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
        qat_net.conv.weight.set_value(weight)
        qat_net.conv.bias.set_value(bias)
    else:
        normal_net.weight.set_value(fake_quant(weight, weight_scale))
        normal_net.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
        qat_net.weight.set_value(weight)
        qat_net.bias.set_value(bias)
    q_net = getattr(Q, module).from_qat_module(qat_net)
    q_net.eval()
    normal_out = fake_quant(normal_net(x), act_scale)
    qat_out = qat_net(x)
    q_out = q_net(x_int8).numpy() * act_scale
    np.testing.assert_allclose(qat_out, normal_out)
    np.testing.assert_allclose(q_out, normal_out.numpy())
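
To try the suite locally, one option is to invoke pytest programmatically. The runner below is a minimal sketch: it assumes the file is saved under the name shown above (test_module.py) and that MegEngine and pytest are installed.

import pytest

if __name__ == "__main__":
    # -v prints one line per test, including each parametrized case.
    raise SystemExit(pytest.main(["-v", "test_module.py"]))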

The MegEngine installation package bundles the CUDA environment needed to run code on the GPU, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit MegStudio.
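
Since CUDA support ships with the package, whether a GPU is actually usable comes down to the local hardware and driver. A quick check along these lines can confirm it (a minimal sketch; megengine.is_cuda_available is assumed to be present in the installed release):

import megengine as mge

# Report whether the bundled CUDA environment can see a usable GPU.
if mge.is_cuda_available():
    print("CUDA device detected: GPU execution is possible.")
else:
    print("No usable CUDA device: code will run on the CPU.")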