You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_module.py 8.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. from functools import partial
  2. import numpy as np
  3. import pytest
  4. import megengine as mge
  5. import megengine.functional as F
  6. import megengine.module as Float
  7. import megengine.module.qat as QAT
  8. import megengine.module.quantized as Q
  9. from megengine import Parameter, Tensor
  10. from megengine.core.tensor import dtype
  11. from megengine.quantization import FakeQuantize, MinMaxObserver, QConfig
  12. from megengine.quantization.quantize import (
  13. disable_fake_quant,
  14. disable_observer,
  15. propagate_qconfig,
  16. )
  17. min_max_fakequant_qconfig = QConfig(
  18. weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
  19. act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False),
  20. weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
  21. act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
  22. )
  23. inp_scale = np.float32(np.random.rand() + 1)
  24. min_val = np.random.randint(-127, 0, size=(2,)).astype("float32")
  25. max_val = np.random.randint(1, 127, size=(2,)).astype("float32")
  26. weight_scale = np.float32(np.max([-min_val[0], max_val[0]]) / 254 * 2)
  27. act_scale = np.float32(np.max([-min_val[1], max_val[1]]) / 255 * 2)
  28. def quant(x, scale):
  29. inp_dtype = dtype.qint8(scale)
  30. return x.astype(inp_dtype)
  31. def fake_quant(x, scale, qmin, qmax):
  32. x = x / scale
  33. x = F.round(x)
  34. x = F.clip(x, qmin, qmax)
  35. x = x * scale
  36. return x
  37. fake_quant_act = partial(fake_quant, qmin=-128, qmax=127)
  38. fake_quant_weight = partial(fake_quant, qmin=-127, qmax=127)
  39. fake_quant_bias = partial(fake_quant, qmin=-(2 ** 31), qmax=2 ** 31 - 1)
  40. def init_qat_net(net):
  41. if net.with_weight:
  42. net.weight_observer.min_val[...] = Tensor(min_val[0])
  43. net.weight_observer.max_val[...] = Tensor(max_val[0])
  44. if net.with_act:
  45. net.act_observer.min_val[...] = Tensor(min_val[1])
  46. net.act_observer.max_val[...] = Tensor(max_val[1])
  47. def test_quant_stub():
  48. normal_net = Float.QuantStub()
  49. normal_net.eval()
  50. qat_from_float = QAT.QuantStub.from_float_module(normal_net)
  51. qat_from_float.eval()
  52. disable_observer(qat_from_float)
  53. disable_fake_quant(qat_from_float)
  54. qat_net = QAT.QuantStub()
  55. qat_net.eval()
  56. disable_observer(qat_net)
  57. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  58. init_qat_net(qat_net)
  59. q_net = Q.QuantStub.from_qat_module(qat_net)
  60. q_net.eval()
  61. x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  62. normal = normal_net(x)
  63. qat_without_fakequant = qat_from_float(x)
  64. fake_quant_normal = fake_quant_act(normal_net(x), act_scale)
  65. qat = qat_net(x)
  66. q = q_net(x).numpy() * act_scale
  67. np.testing.assert_allclose(qat_without_fakequant, normal)
  68. np.testing.assert_allclose(qat, fake_quant_normal)
  69. np.testing.assert_allclose(q, fake_quant_normal.numpy())
  70. def test_dequant_stub():
  71. normal_net = Float.DequantStub()
  72. normal_net.eval()
  73. qat_from_float = QAT.DequantStub.from_float_module(normal_net)
  74. qat_from_float.eval()
  75. disable_fake_quant(qat_from_float)
  76. disable_observer(qat_from_float)
  77. qat_net = QAT.DequantStub()
  78. qat_net.eval()
  79. disable_observer(qat_net)
  80. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  81. init_qat_net(qat_net)
  82. q_net = Q.DequantStub.from_qat_module(qat_net)
  83. q_net.eval()
  84. x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  85. x = fake_quant_act(x, inp_scale)
  86. x.q_dict["scale"] = inp_scale
  87. normal = normal_net(x)
  88. qat_without_fakequant = qat_from_float(x)
  89. fake_quant_normal = normal_net(x)
  90. qat = qat_net(x)
  91. q = q_net(quant(x, inp_scale)).numpy()
  92. np.testing.assert_allclose(qat_without_fakequant, normal)
  93. np.testing.assert_allclose(qat, fake_quant_normal)
  94. np.testing.assert_allclose(q, fake_quant_normal.numpy())
  95. @pytest.mark.parametrize("kind", ["COS", "RELU", "ADD", "MUL", "FUSE_ADD_RELU"])
  96. def test_elemwise(kind):
  97. normal_net = Float.Elemwise(kind)
  98. normal_net.eval()
  99. qat_from_float = QAT.Elemwise.from_float_module(normal_net)
  100. qat_from_float.eval()
  101. disable_observer(qat_from_float)
  102. disable_fake_quant(qat_from_float)
  103. qat_net = QAT.Elemwise(kind)
  104. qat_net.eval()
  105. disable_observer(qat_net)
  106. propagate_qconfig(qat_net, min_max_fakequant_qconfig)
  107. init_qat_net(qat_net)
  108. q_net = Q.Elemwise.from_qat_module(qat_net)
  109. q_net.eval()
  110. x1_scale = np.float32(np.random.rand() + 1)
  111. x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  112. x1 = fake_quant_act(x1, x1_scale)
  113. x1.q_dict["scale"] = x1_scale
  114. x2_scale = np.float32(np.random.rand() + 1)
  115. x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  116. x2 = fake_quant_act(x2, x2_scale)
  117. x2.q_dict["scale"] = x2_scale
  118. x1_int8 = quant(x1, x1_scale)
  119. x2_int8 = quant(x2, x2_scale)
  120. # test correctness of `Float`, `QAT` and `Quantized`
  121. if kind in ("ADD", "MUL", "FUSE_ADD_RELU"):
  122. normal = normal_net(x1, x2)
  123. qat_without_fakequant = qat_from_float(x1, x2)
  124. fake_quant_normal = fake_quant_act(normal_net(x1, x2), act_scale)
  125. qat = qat_net(x1, x2)
  126. q = q_net(x1_int8, x2_int8).numpy() * act_scale
  127. else:
  128. normal = normal_net(x1)
  129. qat_without_fakequant = qat_from_float(x1)
  130. fake_quant_normal = fake_quant_act(normal_net(x1), act_scale)
  131. qat = qat_net(x1)
  132. q = q_net(x1_int8).numpy() * act_scale
  133. np.testing.assert_allclose(qat_without_fakequant, normal)
  134. np.testing.assert_allclose(qat, fake_quant_normal)
  135. np.testing.assert_allclose(q, fake_quant_normal.numpy())
def test_linear():
    """Check Float/QAT/Quantized Linear agree with each other.

    Verifies that:
      * the QAT module converted from the float net (observer/fake-quant
        disabled) matches the float net;
      * the QAT net with pinned observer stats matches the float net whose
        weight/bias were fake-quantized by hand;
      * the fully quantized net, fed int8 input, matches the same reference
        after rescaling its int output by ``act_scale``.
    """
    normal_net = Float.Linear(3, 3, bias=True)
    normal_net.eval()
    qat_net = QAT.Linear(3, 3, bias=True)
    qat_net.eval()
    disable_observer(qat_net)
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)
    # Input is pre-fake-quantized and tagged with its scale.
    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x = fake_quant_act(x, inp_scale)
    x.q_dict["scale"] = inp_scale
    x_int8 = quant(x, inp_scale)
    weight = np.random.normal(size=(3, 3)).astype("float32")
    bias = np.random.normal(size=(3,)).astype("float32")
    # Hand-quantize the float net's parameters so it mirrors what the QAT
    # net computes; bias uses the combined input*weight scale.
    normal_net.weight[...] = fake_quant_weight(weight, weight_scale)
    normal_net.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
    qat_net.weight[...] = Parameter(weight)
    qat_net.bias[...] = Parameter(bias)
    # Convert after the parameters are set so the converted module carries
    # the already fake-quantized weights.
    qat_from_float = QAT.Linear.from_float_module(normal_net)
    qat_from_float.eval()
    disable_fake_quant(qat_from_float)
    disable_observer(qat_from_float)
    q_net = Q.Linear.from_qat_module(qat_net)
    q_net.eval()
    normal = normal_net(x)
    qat_without_fakequant = qat_from_float(x)
    fake_quant_normal = fake_quant_act(normal_net(x), act_scale)
    qat = qat_net(x)
    q = q_net(x_int8).numpy() * act_scale
    np.testing.assert_allclose(qat_without_fakequant, normal)
    np.testing.assert_allclose(qat, fake_quant_normal.numpy())
    np.testing.assert_allclose(q, fake_quant_normal.numpy())
@pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"])
def test_conv(module):
    """Check Float/QAT/Quantized conv variants agree with each other.

    Same structure as ``test_linear``: the float net gets hand-fake-quantized
    parameters, the QAT net gets the raw parameters plus pinned observer
    stats, and the quantized net is converted from the QAT net.
    """
    normal_net = getattr(Float, module)(3, 3, 3, 1, 1, 1, bias=True)
    normal_net.eval()
    qat_net = getattr(QAT, module)(3, 3, 3, 1, 1, 1, bias=True)
    qat_net.eval()
    disable_observer(qat_net)
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)
    # NCHW input, pre-fake-quantized and tagged with its scale.
    x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
    x = fake_quant_act(x, inp_scale)
    x.q_dict["scale"] = inp_scale
    x_int8 = quant(x, inp_scale)
    weight = np.random.normal(size=(3, 3, 3, 3)).astype("float32")
    bias = np.random.normal(size=(1, 3, 1, 1)).astype("float32")
    # ConvBn variants wrap the conv in a submodule, so their parameters live
    # one attribute level deeper.
    if module in ("ConvBn2d", "ConvBnRelu2d"):
        normal_net.conv.weight[...] = fake_quant_weight(weight, weight_scale)
        normal_net.conv.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
        qat_net.conv.weight[...] = Parameter(weight)
        qat_net.conv.bias[...] = Parameter(bias)
    else:
        normal_net.weight[...] = fake_quant_weight(weight, weight_scale)
        normal_net.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
        qat_net.weight[...] = Parameter(weight)
        qat_net.bias[...] = Parameter(bias)
    qat_from_float = getattr(QAT, module).from_float_module(normal_net)
    qat_from_float.eval()
    disable_observer(qat_from_float)
    disable_fake_quant(qat_from_float)
    q_net = getattr(Q, module).from_qat_module(qat_net)
    q_net.eval()
    normal = normal_net(x)
    qat_without_fakequant = qat_from_float(x)
    fake_quant_normal = fake_quant_act(normal_net(x), act_scale)
    qat = qat_net(x)
    q = q_net(x_int8).numpy() * act_scale
    # NOTE(review): tolerances are looser than in the other tests — up to one
    # quantization step (act_scale); presumably to absorb rounding differences
    # in the conv/BN paths. Confirm against the framework's accuracy notes.
    np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-5)
    np.testing.assert_allclose(qat, fake_quant_normal, atol=act_scale)
    np.testing.assert_allclose(q, fake_quant_normal.numpy(), atol=act_scale)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台